diff --git a/src/bin/plain.rs b/src/bin/plain.rs index 9bb62bd..d2f76c2 100644 --- a/src/bin/plain.rs +++ b/src/bin/plain.rs @@ -1,6 +1,9 @@ +#![allow(dead_code)] use anyhow::Result; use clap::Parser; +use rand::Rng; use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; use std::fs::File; use std::io::{BufRead, BufReader, Write}; @@ -8,9 +11,9 @@ use std::io::{BufRead, BufReader, Write}; #[command(name = "hfe_knn")] #[command(about = "FHE-based KNN classifier")] struct Args { - #[arg(long)] + #[arg(long, default_value = "./dataset/train.jsonl")] dataset: String, - #[arg(long)] + #[arg(long, default_value = "./dataset/answer1.jsonl")] predictions: String, } @@ -45,27 +48,268 @@ fn knn_classify(query: &[f64], data: &[Vec], k: usize) -> Vec { distances.into_iter().take(k).map(|(_, idx)| idx).collect() } +#[derive(Clone)] +struct HNSWNode { + vector: Vec, + level: usize, + neighbors: Vec>, +} + +#[derive(Clone)] +struct HNSWGraph { + nodes: Vec, + entry_point: Option, + max_level: usize, + max_connections: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq)] +struct OrderedFloat(f64); + +impl Eq for OrderedFloat {} + +impl PartialOrd for OrderedFloat { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for OrderedFloat { + fn cmp(&self, other: &Self) -> Ordering { + self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal) + } +} + +impl HNSWGraph { + fn new() -> Self { + Self { + nodes: Vec::new(), + entry_point: None, + max_level: 3, + max_connections: 16, + } + } + + fn insert_node(&mut self, vector: Vec) -> usize { + let level = self.select_level(); + let node_id = self.nodes.len(); + + let node = HNSWNode { + vector, + level, + neighbors: vec![Vec::new(); level + 1], + }; + + // 先添加节点到向量中 + self.nodes.push(node); + + if node_id == 0 { + self.entry_point = Some(node_id); + return node_id; + } + + let mut current_candidates = vec![self.entry_point.unwrap()]; + + // 从最高层搜索到目标层+1 + for lc in (level + 
1..=self.max_level).rev() { + current_candidates = + self.search_layer(&self.nodes[node_id].vector, current_candidates, 1, lc); + } + + // 从目标层到第0层建立连接 + for lc in (0..=level).rev() { + current_candidates = self.search_layer( + &self.nodes[node_id].vector, + current_candidates, + self.max_connections, + lc, + ); + + for &candidate_id in ¤t_candidates { + self.connect_nodes(node_id, candidate_id, lc); + } + } + + // 更新入口点 + if level > self.nodes[self.entry_point.unwrap()].level { + self.entry_point = Some(node_id); + } + + node_id + } + + fn search_layer( + &self, + query: &[f64], + entry_points: Vec, + ef: usize, + layer: usize, + ) -> Vec { + let mut visited = std::collections::HashSet::new(); + let mut candidates = std::collections::BinaryHeap::new(); + let mut w = std::collections::BinaryHeap::new(); + + // 初始化候选点 + for &ep in &entry_points { + if ep < self.nodes.len() && self.nodes[ep].level >= layer { + let dist = euclidean_distance(query, &self.nodes[ep].vector); + candidates.push(std::cmp::Reverse((OrderedFloat(dist), ep))); + w.push((OrderedFloat(dist), ep)); + visited.insert(ep); + } + } + + while let Some(std::cmp::Reverse((current_dist, current))) = candidates.pop() { + // 如果当前距离已经比最远的结果距离大,停止搜索 + if let Some(&(farthest_dist, _)) = w.iter().max() { + if current_dist > farthest_dist && w.len() >= ef { + break; + } + } + + // 探索当前节点的邻居 + if current < self.nodes.len() && layer < self.nodes[current].neighbors.len() { + for &neighbor in &self.nodes[current].neighbors[layer] { + if !visited.contains(&neighbor) && neighbor < self.nodes.len() { + visited.insert(neighbor); + let dist = euclidean_distance(query, &self.nodes[neighbor].vector); + let ordered_dist = OrderedFloat(dist); + + if w.len() < ef { + candidates.push(std::cmp::Reverse((ordered_dist, neighbor))); + w.push((ordered_dist, neighbor)); + } else if let Some(&(farthest_dist, _)) = w.iter().max() { + if ordered_dist < farthest_dist { + candidates.push(std::cmp::Reverse((ordered_dist, neighbor))); + 
w.push((ordered_dist, neighbor)); + + // 移除最远的点 + if let Some(max_item) = w.iter().max().copied() { + w.retain(|&x| x != max_item); + } + } + } + } + } + } + } + + // 返回按距离排序的结果 + let mut results: Vec<_> = w.into_iter().collect(); + results.sort_by(|a, b| a.0.cmp(&b.0)); + results.into_iter().map(|(_, idx)| idx).collect() + } + + fn select_level(&self) -> usize { + let mut rng = rand::rng(); + let mut level = 0; + while level < self.max_level && rng.random::() < 0.5 { + level += 1; + } + level + } + + fn connect_nodes(&mut self, node1: usize, node2: usize, layer: usize) { + self.nodes[node1].neighbors[layer].push(node2); + self.nodes[node2].neighbors[layer].push(node1); + + if self.nodes[node1].neighbors[layer].len() > self.max_connections { + self.prune_node_connections(node1, layer); + } + if self.nodes[node2].neighbors[layer].len() > self.max_connections { + self.prune_node_connections(node2, layer); + } + } + + fn prune_node_connections(&mut self, node_id: usize, _layer: usize) { + let node_vector = self.nodes[node_id].vector.clone(); + let nodes = self.nodes.clone(); // Create immutable reference before mutable borrow + + // 计算所有邻居的距离并排序 + let neighbors = &mut self.nodes[node_id].neighbors[_layer]; + let mut neighbor_distances: Vec<(f64, usize)> = neighbors + .iter() + .map(|&neighbor_id| { + let dist = euclidean_distance(&node_vector, &nodes[neighbor_id].vector); + (dist, neighbor_id) + }) + .collect(); + + // 按距离排序 + neighbor_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + + // 更新邻居列表,只保留距离最近的 + *neighbors = neighbor_distances + .into_iter() + .take(self.max_connections) + .map(|(_, neighbor_id)| neighbor_id) + .collect(); + } +} + +// fn main() -> Result<()> { +// let args = Args::parse(); +// +// let file = File::open(&args.dataset)?; +// let reader = BufReader::new(file); +// +// let mut results = Vec::new(); +// +// for line in reader.lines() { +// let line = line?; +// let dataset: Dataset = serde_json::from_str(&line)?; +// +// let nearest 
= knn_classify(&dataset.query, &dataset.data, 10); +// results.push(Prediction { answer: nearest }); +// } +// +// let mut output_file = File::create(&args.predictions)?; +// for result in results { +// writeln!(output_file, "{}", serde_json::to_string(&result)?)?; +// } +// +// Ok(()) +// } fn main() -> Result<()> { let args = Args::parse(); - let file = File::open(&args.dataset)?; let reader = BufReader::new(file); - let mut results = Vec::new(); for line in reader.lines() { let line = line?; let dataset: Dataset = serde_json::from_str(&line)?; - let nearest = knn_classify(&dataset.query, &dataset.data, 10); + let mut hnsw = HNSWGraph::new(); + for data_point in &dataset.data { + hnsw.insert_node(data_point.clone()); + } + + let nearest = if let Some(entry_point) = hnsw.entry_point { + let mut search_results = vec![entry_point]; + + for layer in (1..=hnsw.nodes[entry_point].level).rev() { + search_results = hnsw.search_layer(&dataset.query, search_results, 1, layer); + } + + let final_results = hnsw.search_layer(&dataset.query, search_results, 16, 0); + final_results + .into_iter() + .take(10) + .map(|idx| idx + 1) + .collect() + } else { + Vec::new() + }; + results.push(Prediction { answer: nearest }); } let mut output_file = File::create(&args.predictions)?; for result in results { writeln!(output_file, "{}", serde_json::to_string(&result)?)?; + println!("{}", serde_json::to_string(&result)?); } Ok(()) } - diff --git a/testhnsw.py b/testhnsw.py new file mode 100644 index 0000000..83a9aa5 --- /dev/null +++ b/testhnsw.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +import subprocess +import json +from collections import Counter + + +def load_answers(filepath): + with open(filepath, "r") as f: + data = json.load(f) + return data["answer"] + + +def run_plain_binary(): + result = subprocess.run( + ["cargo", "r", "-r", "--bin", "plain"], capture_output=True, text=True, cwd="." 
#!/usr/bin/env python3
"""Run the `plain` binary repeatedly and score its HNSW predictions against
ground truth. HNSW level selection is randomized, so accuracy varies per run;
this reports min/max/mean and the full distribution."""
import subprocess
import json
from collections import Counter


def load_answers(filepath):
    """Return the list of per-query answers stored in `filepath`.

    Accepts both a single JSON object ({"answer": [...]}) and a JSON-lines
    file with one {"answer": [...]} object per line — the format the Rust
    binary writes. A bare ``json.load`` crashes on the multi-line case.
    """
    with open(filepath, "r") as f:
        lines = [line for line in f if line.strip()]
    if len(lines) == 1:
        # Single-object file: preserve the original return shape.
        return json.loads(lines[0])["answer"]
    return [json.loads(line)["answer"] for line in lines]


def run_plain_binary():
    """Run the release binary; on success read back the predictions it wrote.

    NOTE(review): dataset/answer1.jsonl is the binary's default --predictions
    output path, so this reads the freshly written predictions of this run —
    confirm the default stays in sync with src/bin/plain.rs.
    """
    result = subprocess.run(
        ["cargo", "r", "-r", "--bin", "plain"], capture_output=True, text=True, cwd="."
    )
    if result.returncode == 0:
        return load_answers("dataset/answer1.jsonl")
    return None


def compare_answers(predictions, ground_truth):
    """Count positions where prediction and ground truth agree exactly.

    Returns 0 when predictions are missing/empty or the lengths differ.
    """
    if not predictions or len(predictions) != len(ground_truth):
        return 0
    return sum(1 for p, gt in zip(predictions, ground_truth) if p == gt)


def main():
    ground_truth = load_answers("dataset/answer.jsonl")

    num_runs = 100
    accuracies = []

    for _ in range(num_runs):
        predictions = run_plain_binary()
        if predictions is not None:
            accuracies.append(compare_answers(predictions, ground_truth))

    # Guard: min()/max() and the mean division crash on an empty list
    # if every run failed to build or execute.
    if not accuracies:
        print("All runs failed; no accuracies to report.")
        return

    print(f"\nResults ({len(accuracies)} runs):")
    print(
        f"Min: {min(accuracies)}, Max: {max(accuracies)}, "
        f"Mean: {sum(accuracies) / len(accuracies):.2f}"
    )

    counter = Counter(accuracies)
    print("Distribution:")
    for correct_count in sorted(counter.keys()):
        print(f"  {correct_count} correct: {counter[correct_count]} times")


if __name__ == "__main__":
    main()