Implemented plaintext version of the HNSW algorithm

This commit is contained in:
2025-07-18 22:16:57 +08:00
parent f78764da13
commit 979f6d17d7
2 changed files with 303 additions and 6 deletions

View File

@@ -1,6 +1,9 @@
#![allow(dead_code)]
use anyhow::Result;
use clap::Parser;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
@@ -8,9 +11,9 @@ use std::io::{BufRead, BufReader, Write};
#[command(name = "hfe_knn")]
#[command(about = "FHE-based KNN classifier")]
struct Args {
#[arg(long)]
#[arg(long, default_value = "./dataset/train.jsonl")]
dataset: String,
#[arg(long)]
#[arg(long, default_value = "./dataset/answer1.jsonl")]
predictions: String,
}
@@ -45,27 +48,268 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
}
/// A single point stored in the HNSW graph.
#[derive(Clone)]
struct HNSWNode {
    // Coordinates of the point in feature space.
    vector: Vec<f64>,
    // Highest layer index this node appears on (0 = base layer).
    level: usize,
    // neighbors[l] holds the ids of this node's neighbors on layer l;
    // the outer Vec has `level + 1` entries.
    neighbors: Vec<Vec<usize>>,
}
/// Plaintext (non-encrypted) HNSW index over `Vec<f64>` points.
#[derive(Clone)]
struct HNSWGraph {
    // All inserted nodes; a node's id is its index in this vector.
    nodes: Vec<HNSWNode>,
    // Id of the node where searches start (None while the graph is empty).
    entry_point: Option<usize>,
    // Highest layer a node may be assigned to (layers are 0..=max_level).
    max_level: usize,
    // Maximum neighbors kept per node per layer before pruning.
    max_connections: usize,
}
/// Wrapper giving `f64` a total order so distances can live in a `BinaryHeap`.
///
/// The original derived `PartialEq` (so `NaN != NaN`) while also implementing
/// `Eq`, violating the `Eq` contract, and its `Ord` collapsed incomparable
/// values to `Equal`. `f64::total_cmp` (IEEE 754 totalOrder) supplies one
/// consistent ordering for `PartialEq`, `Eq`, `PartialOrd` and `Ord`.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f64);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.0.total_cmp(&other.0) == Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> Ordering {
        // Total order: handles NaN deterministically instead of faking Equal.
        self.0.total_cmp(&other.0)
    }
}
impl HNSWGraph {
fn new() -> Self {
Self {
nodes: Vec::new(),
entry_point: None,
max_level: 3,
max_connections: 16,
}
}
fn insert_node(&mut self, vector: Vec<f64>) -> usize {
let level = self.select_level();
let node_id = self.nodes.len();
let node = HNSWNode {
vector,
level,
neighbors: vec![Vec::new(); level + 1],
};
// 先添加节点到向量中
self.nodes.push(node);
if node_id == 0 {
self.entry_point = Some(node_id);
return node_id;
}
let mut current_candidates = vec![self.entry_point.unwrap()];
// 从最高层搜索到目标层+1
for lc in (level + 1..=self.max_level).rev() {
current_candidates =
self.search_layer(&self.nodes[node_id].vector, current_candidates, 1, lc);
}
// 从目标层到第0层建立连接
for lc in (0..=level).rev() {
current_candidates = self.search_layer(
&self.nodes[node_id].vector,
current_candidates,
self.max_connections,
lc,
);
for &candidate_id in &current_candidates {
self.connect_nodes(node_id, candidate_id, lc);
}
}
// 更新入口点
if level > self.nodes[self.entry_point.unwrap()].level {
self.entry_point = Some(node_id);
}
node_id
}
fn search_layer(
&self,
query: &[f64],
entry_points: Vec<usize>,
ef: usize,
layer: usize,
) -> Vec<usize> {
let mut visited = std::collections::HashSet::new();
let mut candidates = std::collections::BinaryHeap::new();
let mut w = std::collections::BinaryHeap::new();
// 初始化候选点
for &ep in &entry_points {
if ep < self.nodes.len() && self.nodes[ep].level >= layer {
let dist = euclidean_distance(query, &self.nodes[ep].vector);
candidates.push(std::cmp::Reverse((OrderedFloat(dist), ep)));
w.push((OrderedFloat(dist), ep));
visited.insert(ep);
}
}
while let Some(std::cmp::Reverse((current_dist, current))) = candidates.pop() {
// 如果当前距离已经比最远的结果距离大,停止搜索
if let Some(&(farthest_dist, _)) = w.iter().max() {
if current_dist > farthest_dist && w.len() >= ef {
break;
}
}
// 探索当前节点的邻居
if current < self.nodes.len() && layer < self.nodes[current].neighbors.len() {
for &neighbor in &self.nodes[current].neighbors[layer] {
if !visited.contains(&neighbor) && neighbor < self.nodes.len() {
visited.insert(neighbor);
let dist = euclidean_distance(query, &self.nodes[neighbor].vector);
let ordered_dist = OrderedFloat(dist);
if w.len() < ef {
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
w.push((ordered_dist, neighbor));
} else if let Some(&(farthest_dist, _)) = w.iter().max() {
if ordered_dist < farthest_dist {
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
w.push((ordered_dist, neighbor));
// 移除最远的点
if let Some(max_item) = w.iter().max().copied() {
w.retain(|&x| x != max_item);
}
}
}
}
}
}
}
// 返回按距离排序的结果
let mut results: Vec<_> = w.into_iter().collect();
results.sort_by(|a, b| a.0.cmp(&b.0));
results.into_iter().map(|(_, idx)| idx).collect()
}
fn select_level(&self) -> usize {
let mut rng = rand::rng();
let mut level = 0;
while level < self.max_level && rng.random::<f32>() < 0.5 {
level += 1;
}
level
}
fn connect_nodes(&mut self, node1: usize, node2: usize, layer: usize) {
self.nodes[node1].neighbors[layer].push(node2);
self.nodes[node2].neighbors[layer].push(node1);
if self.nodes[node1].neighbors[layer].len() > self.max_connections {
self.prune_node_connections(node1, layer);
}
if self.nodes[node2].neighbors[layer].len() > self.max_connections {
self.prune_node_connections(node2, layer);
}
}
fn prune_node_connections(&mut self, node_id: usize, _layer: usize) {
let node_vector = self.nodes[node_id].vector.clone();
let nodes = self.nodes.clone(); // Create immutable reference before mutable borrow
// 计算所有邻居的距离并排序
let neighbors = &mut self.nodes[node_id].neighbors[_layer];
let mut neighbor_distances: Vec<(f64, usize)> = neighbors
.iter()
.map(|&neighbor_id| {
let dist = euclidean_distance(&node_vector, &nodes[neighbor_id].vector);
(dist, neighbor_id)
})
.collect();
// 按距离排序
neighbor_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
// 更新邻居列表,只保留距离最近的
*neighbors = neighbor_distances
.into_iter()
.take(self.max_connections)
.map(|(_, neighbor_id)| neighbor_id)
.collect();
}
}
// fn main() -> Result<()> {
// let args = Args::parse();
//
// let file = File::open(&args.dataset)?;
// let reader = BufReader::new(file);
//
// let mut results = Vec::new();
//
// for line in reader.lines() {
// let line = line?;
// let dataset: Dataset = serde_json::from_str(&line)?;
//
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
// results.push(Prediction { answer: nearest });
// }
//
// let mut output_file = File::create(&args.predictions)?;
// for result in results {
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
// }
//
// Ok(())
// }
/// Reads the JSONL dataset, answers each query with a per-query HNSW index,
/// and writes the 1-based ids of the nearest neighbors to the predictions
/// file (also echoed to stdout).
fn main() -> Result<()> {
    let args = Args::parse();
    let file = File::open(&args.dataset)?;
    let reader = BufReader::new(file);
    let mut results = Vec::new();
    for line in reader.lines() {
        let line = line?;
        let dataset: Dataset = serde_json::from_str(&line)?;
        // NOTE: the original also ran `knn_classify` here, but its result was
        // immediately shadowed by the HNSW answer — dead O(n log n) work per
        // line, removed.
        // Build a fresh index over this query's candidate set.
        let mut hnsw = HNSWGraph::new();
        for data_point in &dataset.data {
            hnsw.insert_node(data_point.clone());
        }
        let nearest = if let Some(entry_point) = hnsw.entry_point {
            // Greedy descent from the entry point's top layer down to layer 1,
            // keeping a single best candidate per layer...
            let mut search_results = vec![entry_point];
            for layer in (1..=hnsw.nodes[entry_point].level).rev() {
                search_results = hnsw.search_layer(&dataset.query, search_results, 1, layer);
            }
            // ...then a wider (ef = 16) search on the base layer.
            let final_results = hnsw.search_layer(&dataset.query, search_results, 16, 0);
            final_results
                .into_iter()
                .take(10)
                .map(|idx| idx + 1) // answers use 1-based indices
                .collect()
        } else {
            Vec::new()
        };
        results.push(Prediction { answer: nearest });
    }
    let mut output_file = File::create(&args.predictions)?;
    for result in &results {
        // Serialize once; the original called `to_string` twice per result.
        let json = serde_json::to_string(result)?;
        writeln!(output_file, "{}", json)?;
        println!("{}", json);
    }
    Ok(())
}

53
testhnsw.py Normal file
View File

@@ -0,0 +1,53 @@
#!/usr/bin/env python3
import subprocess
import json
from collections import Counter
def load_answers(filepath):
    """Load the "answer" field(s) from a JSON or JSON-Lines file.

    The dataset files carry a ``.jsonl`` extension; ``json.load`` (as
    originally used) raises on a file containing one JSON document per line,
    so each non-empty line is parsed independently.  A single-document file
    keeps the original behaviour of returning that document's "answer" value
    directly; otherwise a list of per-line answers is returned.
    """
    with open(filepath, "r") as f:
        documents = [line for line in f if line.strip()]
    answers = [json.loads(line)["answer"] for line in documents]
    # Single document: original return shape; multiple: one answer per line.
    return answers[0] if len(answers) == 1 else answers
def run_plain_binary():
    """Build and run the `plain` cargo binary; return its answers, or None on failure."""
    proc = subprocess.run(
        ["cargo", "r", "-r", "--bin", "plain"], capture_output=True, text=True, cwd="."
    )
    if proc.returncode != 0:
        return None
    # The binary writes its predictions to answer1.jsonl, so read them back.
    return load_answers("dataset/answer1.jsonl")
def compare_answers(predictions, ground_truth):
    """Count the positions at which prediction and ground truth match exactly.

    Returns 0 when predictions are missing/empty or the lengths disagree.
    """
    if not predictions or len(predictions) != len(ground_truth):
        return 0
    matches = 0
    for predicted, expected in zip(predictions, ground_truth):
        if predicted == expected:
            matches += 1
    return matches
def main():
    """Run the `plain` binary repeatedly and report accuracy statistics.

    Each run's predictions are scored against dataset/answer.jsonl; the
    per-run correct counts are summarized (min/max/mean and a distribution).
    """
    ground_truth = load_answers("dataset/answer.jsonl")
    num_runs = 100
    accuracies = []
    for _ in range(num_runs):
        predictions = run_plain_binary()
        if predictions is not None:
            accuracies.append(compare_answers(predictions, ground_truth))
    if not accuracies:
        # Every run failed: the original would crash here, since min()/max()
        # and the mean's division raise on an empty list.
        print("No successful runs.")
        return
    print(f"\nResults ({len(accuracies)} runs):")
    print(
        f"Min: {min(accuracies)}, Max: {max(accuracies)}, Mean: {sum(accuracies)/len(accuracies):.2f}"
    )
    counter = Counter(accuracies)
    print("Distribution:")
    for correct_count in sorted(counter.keys()):
        print(f" {correct_count} correct: {counter[correct_count]} times")


if __name__ == "__main__":
    main()