Implemented plaintext version of HNSW algorithm
This commit is contained in:
256
src/bin/plain.rs
256
src/bin/plain.rs
@@ -1,6 +1,9 @@
|
||||
#![allow(dead_code)]
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::Ordering;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
|
||||
@@ -8,9 +11,9 @@ use std::io::{BufRead, BufReader, Write};
|
||||
#[command(name = "hfe_knn")]
#[command(about = "FHE-based KNN classifier")]
// Command-line arguments for the plaintext KNN/HNSW classifier.
// NOTE(review): binary name is "hfe_knn" while the about-text says
// "FHE-based" — possibly a transposition; confirm intended name.
struct Args {
    // Input dataset path (JSONL, one query record per line).
    #[arg(long, default_value = "./dataset/train.jsonl")]
    dataset: String,
    // Output path for predictions (JSONL, one answer per line).
    #[arg(long, default_value = "./dataset/answer1.jsonl")]
    predictions: String,
}
|
||||
|
||||
@@ -45,27 +48,268 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
|
||||
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
|
||||
}
|
||||
|
||||
/// A single point stored in the HNSW graph.
#[derive(Clone)]
struct HNSWNode {
    /// The stored feature vector.
    vector: Vec<f64>,
    /// Highest layer this node participates in.
    level: usize,
    /// Per-layer adjacency lists, indexed 0..=level.
    neighbors: Vec<Vec<usize>>,
}
|
||||
|
||||
/// A small in-memory HNSW index over `f64` vectors.
#[derive(Clone)]
struct HNSWGraph {
    /// All inserted nodes; a node's id is its index here.
    nodes: Vec<HNSWNode>,
    /// Id of the node searches start from; `None` while empty.
    entry_point: Option<usize>,
    /// Upper bound on any node's layer.
    max_level: usize,
    /// Degree cap per node per layer.
    max_connections: usize,
}
|
||||
|
||||
/// Wrapper giving `f64` a total order so distances can live in heaps.
///
/// Uses `f64::total_cmp` so NaNs order deterministically. The original
/// derived `PartialEq` but made `cmp` treat incomparable values as
/// `Equal`, leaving `Eq`/`Ord` inconsistent with `PartialEq` for NaN;
/// here equality is defined through `cmp` so all four traits agree.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f64);

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> Ordering {
        // total_cmp: -NaN < -inf < ... < +inf < +NaN, fully ordered.
        self.0.total_cmp(&other.0)
    }
}
|
||||
|
||||
impl HNSWGraph {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
nodes: Vec::new(),
|
||||
entry_point: None,
|
||||
max_level: 3,
|
||||
max_connections: 16,
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_node(&mut self, vector: Vec<f64>) -> usize {
|
||||
let level = self.select_level();
|
||||
let node_id = self.nodes.len();
|
||||
|
||||
let node = HNSWNode {
|
||||
vector,
|
||||
level,
|
||||
neighbors: vec![Vec::new(); level + 1],
|
||||
};
|
||||
|
||||
// 先添加节点到向量中
|
||||
self.nodes.push(node);
|
||||
|
||||
if node_id == 0 {
|
||||
self.entry_point = Some(node_id);
|
||||
return node_id;
|
||||
}
|
||||
|
||||
let mut current_candidates = vec![self.entry_point.unwrap()];
|
||||
|
||||
// 从最高层搜索到目标层+1
|
||||
for lc in (level + 1..=self.max_level).rev() {
|
||||
current_candidates =
|
||||
self.search_layer(&self.nodes[node_id].vector, current_candidates, 1, lc);
|
||||
}
|
||||
|
||||
// 从目标层到第0层建立连接
|
||||
for lc in (0..=level).rev() {
|
||||
current_candidates = self.search_layer(
|
||||
&self.nodes[node_id].vector,
|
||||
current_candidates,
|
||||
self.max_connections,
|
||||
lc,
|
||||
);
|
||||
|
||||
for &candidate_id in ¤t_candidates {
|
||||
self.connect_nodes(node_id, candidate_id, lc);
|
||||
}
|
||||
}
|
||||
|
||||
// 更新入口点
|
||||
if level > self.nodes[self.entry_point.unwrap()].level {
|
||||
self.entry_point = Some(node_id);
|
||||
}
|
||||
|
||||
node_id
|
||||
}
|
||||
|
||||
fn search_layer(
|
||||
&self,
|
||||
query: &[f64],
|
||||
entry_points: Vec<usize>,
|
||||
ef: usize,
|
||||
layer: usize,
|
||||
) -> Vec<usize> {
|
||||
let mut visited = std::collections::HashSet::new();
|
||||
let mut candidates = std::collections::BinaryHeap::new();
|
||||
let mut w = std::collections::BinaryHeap::new();
|
||||
|
||||
// 初始化候选点
|
||||
for &ep in &entry_points {
|
||||
if ep < self.nodes.len() && self.nodes[ep].level >= layer {
|
||||
let dist = euclidean_distance(query, &self.nodes[ep].vector);
|
||||
candidates.push(std::cmp::Reverse((OrderedFloat(dist), ep)));
|
||||
w.push((OrderedFloat(dist), ep));
|
||||
visited.insert(ep);
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(std::cmp::Reverse((current_dist, current))) = candidates.pop() {
|
||||
// 如果当前距离已经比最远的结果距离大,停止搜索
|
||||
if let Some(&(farthest_dist, _)) = w.iter().max() {
|
||||
if current_dist > farthest_dist && w.len() >= ef {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 探索当前节点的邻居
|
||||
if current < self.nodes.len() && layer < self.nodes[current].neighbors.len() {
|
||||
for &neighbor in &self.nodes[current].neighbors[layer] {
|
||||
if !visited.contains(&neighbor) && neighbor < self.nodes.len() {
|
||||
visited.insert(neighbor);
|
||||
let dist = euclidean_distance(query, &self.nodes[neighbor].vector);
|
||||
let ordered_dist = OrderedFloat(dist);
|
||||
|
||||
if w.len() < ef {
|
||||
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
|
||||
w.push((ordered_dist, neighbor));
|
||||
} else if let Some(&(farthest_dist, _)) = w.iter().max() {
|
||||
if ordered_dist < farthest_dist {
|
||||
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
|
||||
w.push((ordered_dist, neighbor));
|
||||
|
||||
// 移除最远的点
|
||||
if let Some(max_item) = w.iter().max().copied() {
|
||||
w.retain(|&x| x != max_item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 返回按距离排序的结果
|
||||
let mut results: Vec<_> = w.into_iter().collect();
|
||||
results.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
results.into_iter().map(|(_, idx)| idx).collect()
|
||||
}
|
||||
|
||||
fn select_level(&self) -> usize {
|
||||
let mut rng = rand::rng();
|
||||
let mut level = 0;
|
||||
while level < self.max_level && rng.random::<f32>() < 0.5 {
|
||||
level += 1;
|
||||
}
|
||||
level
|
||||
}
|
||||
|
||||
fn connect_nodes(&mut self, node1: usize, node2: usize, layer: usize) {
|
||||
self.nodes[node1].neighbors[layer].push(node2);
|
||||
self.nodes[node2].neighbors[layer].push(node1);
|
||||
|
||||
if self.nodes[node1].neighbors[layer].len() > self.max_connections {
|
||||
self.prune_node_connections(node1, layer);
|
||||
}
|
||||
if self.nodes[node2].neighbors[layer].len() > self.max_connections {
|
||||
self.prune_node_connections(node2, layer);
|
||||
}
|
||||
}
|
||||
|
||||
fn prune_node_connections(&mut self, node_id: usize, _layer: usize) {
|
||||
let node_vector = self.nodes[node_id].vector.clone();
|
||||
let nodes = self.nodes.clone(); // Create immutable reference before mutable borrow
|
||||
|
||||
// 计算所有邻居的距离并排序
|
||||
let neighbors = &mut self.nodes[node_id].neighbors[_layer];
|
||||
let mut neighbor_distances: Vec<(f64, usize)> = neighbors
|
||||
.iter()
|
||||
.map(|&neighbor_id| {
|
||||
let dist = euclidean_distance(&node_vector, &nodes[neighbor_id].vector);
|
||||
(dist, neighbor_id)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 按距离排序
|
||||
neighbor_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
|
||||
// 更新邻居列表,只保留距离最近的
|
||||
*neighbors = neighbor_distances
|
||||
.into_iter()
|
||||
.take(self.max_connections)
|
||||
.map(|(_, neighbor_id)| neighbor_id)
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
|
||||
// fn main() -> Result<()> {
|
||||
// let args = Args::parse();
|
||||
//
|
||||
// let file = File::open(&args.dataset)?;
|
||||
// let reader = BufReader::new(file);
|
||||
//
|
||||
// let mut results = Vec::new();
|
||||
//
|
||||
// for line in reader.lines() {
|
||||
// let line = line?;
|
||||
// let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
//
|
||||
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||
// results.push(Prediction { answer: nearest });
|
||||
// }
|
||||
//
|
||||
// let mut output_file = File::create(&args.predictions)?;
|
||||
// for result in results {
|
||||
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
// }
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let file = File::open(&args.dataset)?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
|
||||
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||
let mut hnsw = HNSWGraph::new();
|
||||
for data_point in &dataset.data {
|
||||
hnsw.insert_node(data_point.clone());
|
||||
}
|
||||
|
||||
let nearest = if let Some(entry_point) = hnsw.entry_point {
|
||||
let mut search_results = vec![entry_point];
|
||||
|
||||
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
||||
search_results = hnsw.search_layer(&dataset.query, search_results, 1, layer);
|
||||
}
|
||||
|
||||
let final_results = hnsw.search_layer(&dataset.query, search_results, 16, 0);
|
||||
final_results
|
||||
.into_iter()
|
||||
.take(10)
|
||||
.map(|idx| idx + 1)
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
results.push(Prediction { answer: nearest });
|
||||
}
|
||||
|
||||
let mut output_file = File::create(&args.predictions)?;
|
||||
for result in results {
|
||||
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
println!("{}", serde_json::to_string(&result)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
53
testhnsw.py
Normal file
53
testhnsw.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
import subprocess
|
||||
import json
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def load_answers(filepath):
    """Return the ``answer`` field of a JSON file.

    NOTE(review): despite the .jsonl extension the whole file is parsed
    as a single JSON document; confirm the files really hold exactly one
    object each, otherwise ``json.load`` will raise.
    """
    with open(filepath, "r") as f:
        return json.load(f)["answer"]
||||
|
||||
|
||||
def run_plain_binary():
    """Build and run the ``plain`` release binary; return its answers.

    The binary writes its predictions to dataset/answer1.jsonl, so on a
    zero exit status we load that file; any failure yields None.
    """
    proc = subprocess.run(
        ["cargo", "r", "-r", "--bin", "plain"], capture_output=True, text=True, cwd="."
    )
    if proc.returncode != 0:
        return None
    # The program outputs the same results as answer1.jsonl
    return load_answers("dataset/answer1.jsonl")
|
||||
|
||||
|
||||
def compare_answers(predictions, ground_truth):
    """Count positions where a prediction equals the ground truth.

    Returns 0 when ``predictions`` is falsy or the two sequences have
    different lengths.
    """
    if not predictions:
        return 0
    if len(predictions) != len(ground_truth):
        return 0
    matches = 0
    for pred, truth in zip(predictions, ground_truth):
        if pred == truth:
            matches += 1
    return matches
|
||||
|
||||
|
||||
def main():
    """Run the plain binary repeatedly and report accuracy statistics."""
    ground_truth = load_answers("dataset/answer.jsonl")

    num_runs = 100
    accuracies = []

    for _ in range(num_runs):
        predictions = run_plain_binary()
        if predictions is not None:
            accuracy = compare_answers(predictions, ground_truth)
            accuracies.append(accuracy)

    # Guard: if every run failed (cargo missing, build error, ...),
    # min()/max() raise ValueError and the mean divides by zero.
    if not accuracies:
        print("No successful runs.")
        return

    print(f"\nResults ({len(accuracies)} runs):")
    print(
        f"Min: {min(accuracies)}, Max: {max(accuracies)}, Mean: {sum(accuracies)/len(accuracies):.2f}"
    )

    counter = Counter(accuracies)
    print("Distribution:")
    for correct_count in sorted(counter.keys()):
        print(f"  {correct_count} correct: {counter[correct_count]} times")
|
||||
|
||||
|
||||
# Entry point: run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user