feat: remove hnsw
parent 7fae6b23b7
commit 25675228f4

64 CLAUDE.md Normal file
@@ -0,0 +1,64 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

This project implements a K-Nearest Neighbors (KNN) algorithm using Fully Homomorphic Encryption (FHE) in Rust. The implementation uses TFHE-rs for cryptographic operations and operates on a 10-dimensional synthetic dataset with 100 training points.

## Key Architecture

- **Homomorphic Encryption**: Uses the TFHE-rs library for fully homomorphic encryption operations
- **Data Processing**: Synthetic dataset features are scaled by 10x so that one decimal place of precision is preserved as integers
- **KNN Implementation**: Complete implementation with multiple algorithms (see the distance sketch after this list):
  - `euclidean_distance()`: Optimized distance calculation using precomputed squares
  - `perform_knn_selection()`: Supports selection sort, bitonic sort, and heap-based selection
  - `encrypted_bitonic_sort()`: Parallel bitonic sort with power-of-2 padding
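
As a point of reference for the precomputed-squares optimization, here is a minimal plaintext sketch of the underlying identity ‖q − p‖² = ‖q‖² − 2·(q·p) + ‖p‖²: with ‖q‖² and ‖p‖² precomputed, each training point costs only one dot product. The type and field names are illustrative, not this crate's actual API; the real code evaluates the same arithmetic over encrypted integers.

```rust
// Plaintext sketch of the precomputed-squares trick (illustrative names).
struct ScaledQuery {
    features: Vec<i64>,
    sq_sum: i64, // precomputed sum of squared features
}

struct ScaledPoint {
    features: Vec<i64>,
    sq_sum: i64, // precomputed at encryption time
}

/// Squared Euclidean distance via ||q||^2 - 2*(q . p) + ||p||^2.
/// The square root is skipped: it does not change the ranking.
fn squared_distance(q: &ScaledQuery, p: &ScaledPoint) -> i64 {
    let dot: i64 = q.features.iter().zip(&p.features).map(|(a, b)| a * b).sum();
    q.sq_sum - 2 * dot + p.sq_sum
}
```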

## Build and Development Commands

```bash
# Build the project
cargo build

# Run with different algorithms
cargo run --bin enc                         # Default algorithm (bitonic)
cargo run --bin enc -- --algorithm=bitonic  # Bitonic sort (fastest for large datasets)
cargo run --bin enc -- --algorithm=heap     # Heap-based selection
cargo run --bin enc -- --debug              # Debug mode with plaintext verification
cargo run --bin plain                       # Plaintext version for comparison

# Development commands - ALWAYS use cargo check for verification
cargo check   # Use this for code verification, NOT cargo run
cargo test
cargo fmt
cargo clippy
```

## Data Structure

The project processes a synthetic 10-dimensional dataset with these key data structures:

- `EncryptedQuery`: Query point with precomputed values for optimization
- `EncryptedPoint`: Training data points with precomputed squared sums
- `EncryptedNeighbor`: Distance and index pairs for KNN results
- A custom deserializer converts float values to scaled integers (×10) for FHE compatibility; a sketch follows this list
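
A minimal sketch of what such a ×10 scaling deserializer can look like with serde, assuming the features arrive as JSON floats; the struct, field, and function names here are illustrative rather than the repository's actual definitions:

```rust
use serde::{Deserialize, Deserializer};

/// Deserialize a JSON array of floats into x10-scaled integers,
/// e.g. 3.14 -> 31 (one decimal place preserved).
fn scale_by_ten<'de, D>(deserializer: D) -> Result<Vec<i16>, D::Error>
where
    D: Deserializer<'de>,
{
    let raw = Vec::<f64>::deserialize(deserializer)?;
    Ok(raw.into_iter().map(|x| (x * 10.0).round() as i16).collect())
}

#[derive(Deserialize)]
struct DataPoint {
    #[serde(deserialize_with = "scale_by_ten")]
    features: Vec<i16>,
}
```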

## Dataset

- **Training Data**: `dataset/train.jsonl` containing one query point and 100 10-dimensional training points
- **Results**: `dataset/answer.jsonl` and `dataset/answer1.jsonl` contain KNN classification results in JSON format

## Important Technical Notes

- **FheInt14 Range**: Valid range is -8192 to 8191 (14-bit signed, since 8192 = 2^13). Values outside this range (like i16::MAX = 32767) will overflow; see the range-check sketch after this list
- **Bitonic Sort**: Requires `up=true` for ascending order so the smallest distances come first. Using `false` yields the largest distances, which is wrong for KNN
- **Performance**: Bitonic sort is fastest for larger datasets thanks to parallel processing, but requires power-of-2 padding
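
A minimal sketch of a guard that validates scaled values before encryption, assuming the i16 representation used for the scaled features; the constant and function names are illustrative:

```rust
// FheInt14 bounds: 14-bit signed two's complement.
const FHE_INT14_MIN: i16 = -8192; // -(2^13)
const FHE_INT14_MAX: i16 = 8191; // 2^13 - 1

/// Fail loudly instead of silently overflowing the FheInt14 domain.
fn check_fhe_int14(value: i16) -> Result<i16, String> {
    if (FHE_INT14_MIN..=FHE_INT14_MAX).contains(&value) {
        Ok(value)
    } else {
        Err(format!(
            "{value} is outside the FheInt14 range [{FHE_INT14_MIN}, {FHE_INT14_MAX}]"
        ))
    }
}
```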

## Git Workflow Instructions

**IMPORTANT**: When the user asks to "write commit" or "帮我写commit" ("help me write a commit"):

- Do NOT add any files to the staging area
- The user has already staged the files they want to commit
- Only create the commit, with an appropriate message for the staged changes
@@ -321,60 +321,3 @@ fn encrypted_conditional_swap(
    a.index = new_a_index;
    b.index = new_b_index;
}
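
For orientation, `encrypted_conditional_swap` (whose tail appears above) is the FHE analogue of the compare-exchange step that bitonic sort is built from. A plaintext sketch of that step, illustrating why the `up=true` direction noted in CLAUDE.md puts the smallest distance first; this is illustrative only, not this repository's code:

```rust
/// Plaintext compare-exchange: with `up == true` the pair ends in
/// ascending order, so the smaller (nearer) distance lands first --
/// exactly the ordering KNN needs.
fn compare_exchange(dists: &mut [i64], i: usize, j: usize, up: bool) {
    if (dists[i] > dists[j]) == up {
        dists.swap(i, j);
    }
}
```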

///// Perform an HNSW approximate nearest-neighbor search
/////
///// # Arguments
///// * `graph` - Pre-built FHE HNSW graph structure
///// * `query` - Encrypted query point
///// * `zero` - Encrypted zero value
/////
///// # Returns
///// * Indices of the 10 nearest neighbors
//pub fn perform_hnsw_search(
//    graph: &FheHnswGraph,
//    query: &EncryptedQuery,
//    zero: &FheInt14,
//) -> Vec<usize> {
//    println!("🚀 Starting HNSW approximate search...");
//
//    if graph.nodes.is_empty() {
//        println!("❌ Empty HNSW graph");
//        return Vec::new();
//    }
//
//    let Some(entry_point) = graph.entry_point else {
//        println!("❌ No entry point in HNSW graph");
//        return Vec::new();
//    };
//
//    println!(
//        "🔍 HNSW search from entry point {} at level {}",
//        entry_point, graph.max_level
//    );
//
//    let mut current_candidates = vec![entry_point];
//
//    // Search layer by layer from the top level down to level 1
//    for layer in (1..=graph.max_level).rev() {
//        println!("🔍 Searching layer {layer} with ef=1...");
//        let layer_start = Instant::now();
//        current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero);
//        println!(
//            "✅ Layer {} search completed in {}, {} candidates",
//            layer,
//            format_duration(layer_start.elapsed()),
//            current_candidates.len()
//        );
//    }
//
//    // Final search at layer 0
//    let final_search_start = Instant::now();
//    let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero);
//    println!(
//        "✅ Final layer search completed in {}, {} candidates",
//        format_duration(final_search_start.elapsed()),
//        final_candidates.len()
//    );
//    final_candidates
//}

@@ -26,7 +26,7 @@ struct Args {
    #[arg(
        long,
        default_value = "bitonic",
-       help = "Algorithm: selection, bitonic, heap, hnsw"
+       help = "Algorithm: selection, bitonic, heap"
    )]
    algorithm: String,
    #[arg(long, help = "Enable debug mode (plaintext calculation first)")]
@@ -143,24 +143,6 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instant
    let query_encrypted = encrypt_query(&dataset.query, client_key);

    let k = 10; // Number of nearest neighbors
    // if args.algorithm == "hnsw" {
    //     // HNSW algorithm path
    //     println!("🚀 Using HNSW algorithm...");
    //
    //     // 1. Build the plaintext HNSW graph
    //     let (plaintext_nodes, entry_point, max_level) =
    //         build_plaintext_hnsw_graph(&dataset.data);
    //
    //     // 2. Convert to an encrypted HNSW graph
    //     println!("🔐 Converting to encrypted HNSW graph...");
    //     let fhe_graph =
    //         build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
    //
    //     // 3. Run the HNSW search
    //     let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
    //     let answer = perform_hnsw_search(&fhe_graph, &query_encrypted, &encrypted_zero);
    //     results.push(Prediction { answer });
    // } else {
    // Traditional algorithm path
    // Encrypt all training points
    println!("🔐 Encrypting training points...");

284 src/bin/plain.rs
@@ -1,9 +1,6 @@
#![allow(dead_code)]
use anyhow::Result;
use clap::Parser;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};

@@ -15,28 +12,6 @@ struct Args {
    dataset: String,
    #[arg(long, default_value = "./dataset/answer1.jsonl")]
    predictions: String,
    #[arg(
        long,
        default_value = "8",
        help = "Max connections per node (M parameter)"
    )]
    max_connections: usize,
    #[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
    max_level: usize,
    #[arg(long, default_value = "0.6", help = "Level selection probability")]
    level_prob: f32,
    #[arg(
        long,
        default_value = "1",
        help = "ef parameter for upper layer search"
    )]
    ef_upper: usize,
    #[arg(
        long,
        default_value = "10",
        help = "ef parameter for bottom layer search"
    )]
    ef_bottom: usize,
}

#[derive(Deserialize)]
@@ -82,278 +57,25 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
    distances.into_iter().take(k).map(|(_, idx)| idx).collect()
}

#[derive(Clone)]
struct HNSWNode {
    vector: Vec<f64>,
    level: usize,
    neighbors: Vec<Vec<usize>>,
}

#[derive(Clone)]
struct HNSWGraph {
    nodes: Vec<HNSWNode>,
    entry_point: Option<usize>,
    max_level: usize,
    max_connections: usize,
}

#[derive(Debug, Clone, Copy, PartialEq)]
struct OrderedFloat(f64);

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal)
    }
}

impl HNSWGraph {
    fn new(max_level: usize, max_connections: usize) -> Self {
        Self {
            nodes: Vec::new(),
            entry_point: None,
            max_level,
            max_connections,
        }
    }

    fn insert_node(&mut self, vector: Vec<f64>, level_prob: f32) -> usize {
        let level = self.select_level(level_prob);
        let node_id = self.nodes.len();

        let node = HNSWNode {
            vector,
            level,
            neighbors: vec![Vec::new(); level + 1],
        };

        // Push the node into the vector first
        self.nodes.push(node);

        if node_id == 0 {
            self.entry_point = Some(node_id);
            return node_id;
        }

        let mut current_candidates = vec![self.entry_point.unwrap()];

        // Search from the top level down to target level + 1
        for lc in (level + 1..=self.max_level).rev() {
            current_candidates =
                self.search_layer(&self.nodes[node_id].vector, current_candidates, 1, lc);
        }

        // Build connections from the target level down to level 0
        for lc in (0..=level).rev() {
            current_candidates = self.search_layer(
                &self.nodes[node_id].vector,
                current_candidates,
                self.max_connections,
                lc,
            );

            for &candidate_id in &current_candidates {
                self.connect_nodes(node_id, candidate_id, lc);
            }
        }

        // Update the entry point
        if level > self.nodes[self.entry_point.unwrap()].level {
            self.entry_point = Some(node_id);
        }

        node_id
    }

    fn search_layer(
        &self,
        query: &[f64],
        entry_points: Vec<usize>,
        ef: usize,
        layer: usize,
    ) -> Vec<usize> {
        let mut visited = std::collections::HashSet::new();
        let mut candidates = std::collections::BinaryHeap::new();
        let mut w = std::collections::BinaryHeap::new();

        // Initialize candidate points
        for &ep in &entry_points {
            if ep < self.nodes.len() && self.nodes[ep].level >= layer {
                let dist = euclidean_distance(query, &self.nodes[ep].vector);
                candidates.push(std::cmp::Reverse((OrderedFloat(dist), ep)));
                w.push((OrderedFloat(dist), ep));
                visited.insert(ep);
            }
        }

        while let Some(std::cmp::Reverse((current_dist, current))) = candidates.pop() {
            // Stop searching once the current distance exceeds the farthest result distance
            if let Some(&(farthest_dist, _)) = w.iter().max() {
                if current_dist > farthest_dist && w.len() >= ef {
                    break;
                }
            }

            // Explore the current node's neighbors
            if current < self.nodes.len() && layer < self.nodes[current].neighbors.len() {
                for &neighbor in &self.nodes[current].neighbors[layer] {
                    if !visited.contains(&neighbor) && neighbor < self.nodes.len() {
                        visited.insert(neighbor);
                        let dist = euclidean_distance(query, &self.nodes[neighbor].vector);
                        let ordered_dist = OrderedFloat(dist);

                        if w.len() < ef {
                            candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
                            w.push((ordered_dist, neighbor));
                        } else if let Some(&(farthest_dist, _)) = w.iter().max() {
                            if ordered_dist < farthest_dist {
                                candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
                                w.push((ordered_dist, neighbor));

                                // Remove the farthest point
                                if let Some(max_item) = w.iter().max().copied() {
                                    w.retain(|&x| x != max_item);
                                }
                            }
                        }
                    }
                }
            }
        }

        // Return results sorted by distance
        let mut results: Vec<_> = w.into_iter().collect();
        results.sort_by(|a, b| a.0.cmp(&b.0));
        results.into_iter().map(|(_, idx)| idx).collect()
    }

    fn select_level(&self, level_prob: f32) -> usize {
        let mut rng = rand::rng();
        let mut level = 0;
        while level < self.max_level && rng.random::<f32>() < level_prob {
            level += 1;
        }
        level
    }

    fn connect_nodes(&mut self, node1: usize, node2: usize, layer: usize) {
        self.nodes[node1].neighbors[layer].push(node2);
        self.nodes[node2].neighbors[layer].push(node1);

        if self.nodes[node1].neighbors[layer].len() > self.max_connections {
            self.prune_node_connections(node1, layer);
        }
        if self.nodes[node2].neighbors[layer].len() > self.max_connections {
            self.prune_node_connections(node2, layer);
        }
    }

    fn prune_node_connections(&mut self, node_id: usize, _layer: usize) {
        let node_vector = self.nodes[node_id].vector.clone();
        let nodes = self.nodes.clone(); // Create an immutable copy before the mutable borrow

        // Compute distances to all neighbors and sort
        let neighbors = &mut self.nodes[node_id].neighbors[_layer];
        let mut neighbor_distances: Vec<(f64, usize)> = neighbors
            .iter()
            .map(|&neighbor_id| {
                let dist = euclidean_distance(&node_vector, &nodes[neighbor_id].vector);
                (dist, neighbor_id)
            })
            .collect();

        // Sort by distance
        neighbor_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

        // Update the neighbor list, keeping only the closest
        *neighbors = neighbor_distances
            .into_iter()
            .take(self.max_connections)
            .map(|(_, neighbor_id)| neighbor_id)
            .collect();
    }
}

// fn main() -> Result<()> {
//     let args = Args::parse();
//
//     let file = File::open(&args.dataset)?;
//     let reader = BufReader::new(file);
//
//     let mut results = Vec::new();
//
//     for line in reader.lines() {
//         let line = line?;
//         let dataset: Dataset = serde_json::from_str(&line)?;
//
//         let nearest = knn_classify(&dataset.query, &dataset.data, 10);
//         results.push(Prediction { answer: nearest });
//     }
//
//     let mut output_file = File::create(&args.predictions)?;
//     for result in results {
//         writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
//     }
//
//     Ok(())
// }
fn main() -> Result<()> {
    let args = Args::parse();

    let file = File::open(&args.dataset)?;
    let reader = BufReader::new(file);
    let mut results = Vec::new();

    println!("🔧 HNSW Parameters:");
    println!("  max_connections: {}", args.max_connections);
    println!("  max_level: {}", args.max_level);
    println!("  level_prob: {}", args.level_prob);
    println!("  ef_upper: {}", args.ef_upper);
    println!("  ef_bottom: {}", args.ef_bottom);
    let mut results = Vec::new();

    for line in reader.lines() {
        let line = line?;
        let dataset: Dataset = serde_json::from_str(&line)?;

        let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
        for data_point in &dataset.data {
            hnsw.insert_node(data_point.clone(), args.level_prob);
        }

        let nearest = if let Some(entry_point) = hnsw.entry_point {
            let mut search_results = vec![entry_point];

            // Upper-layer search uses the ef_upper parameter
            for layer in (1..=hnsw.nodes[entry_point].level).rev() {
                search_results =
                    hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
            }

            // Bottom-layer search uses the ef_bottom parameter
            let final_results =
                hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
            final_results
                .into_iter()
                .take(10)
                .map(|idx| idx + 1)
                .collect()
        } else {
            Vec::new()
        };

        let nearest = knn_classify(&dataset.query, &dataset.data, 10);
        results.push(Prediction { answer: nearest });
    }

    let mut output_file = File::create(&args.predictions)?;
    for result in results {
        writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
        println!("{}", serde_json::to_string(&result)?);
    }

    Ok(())
248 src/data.rs
@@ -128,251 +128,3 @@ pub fn decrypt_indices(encrypted_indices: &[FheUint8], client_key: &ClientKey) -
        })
        .collect()
}

///// Ciphertext HNSW node
//#[derive(Clone)]
//pub struct FheHnswNode {
//    pub encrypted_point: EncryptedPoint,
//    pub level: usize,
//    pub neighbors: Vec<Vec<usize>>, // plaintext neighbor indices (fixed during preprocessing)
//}
//
///// Ciphertext HNSW graph
//#[derive(Clone)]
//pub struct FheHnswGraph {
//    pub nodes: Vec<FheHnswNode>,
//    pub entry_point: Option<usize>,
//    pub max_level: usize,
//    pub distances: [Option<FheInt14>; 100],
//}
//
//impl FheHnswGraph {
//    pub fn new() -> Self {
//        Self {
//            nodes: Vec::new(),
//            entry_point: None,
//            max_level: 0,
//            distances: [const { None }; 100],
//        }
//    }
//
//    /// Layer-search function
//    ///
//    /// Core algorithm: implements the standard HNSW greedy search
//    ///
//    /// Algorithm flow:
//    /// 1. Initialize the candidate queue (candidates) and the result set (w)
//    /// 2. Starting from entry_points, compute encrypted distances to the query
//    /// 3. Greedy search loop:
//    ///    - Pick the point in candidates with the smallest distance as the current point
//    ///    - Explore all of that point's neighbors
//    ///    - Add unvisited neighbors to candidates and w
//    ///    - Keep w no larger than ef (remove the farthest point)
//    ///    - Pruning: if the current point is farther than the farthest point in w and w is full, stop
//    /// 4. Return the node indices in w
//    ///
//    /// Key challenges:
//    /// - A priority queue must be simulated using only ciphertext sorting
//    /// - The pruning condition is hard to evaluate on ciphertexts; accuracy must be traded against performance
//    /// - Candidate-queue management must be efficient to avoid excessive ciphertext operations
//    pub fn search_layer(
//        &self,
//        query: &EncryptedQuery,
//        entry_points: Vec<usize>,
//        ef: usize,
//        layer: usize,
//        zero: &FheInt14,
//    ) -> Vec<usize> {
//        let visited: HashSet<usize> = HashSet::new();
//        let mut candidate =
//
//        // TODO: design suitable data structures
//        // candidates: candidate queue storing (node index, encrypted distance) or similar
//        // w: result set holding the best ef candidates found so far
//
//        // TODO: Step 1 - initialize candidates
//        // For each entry_point:
//        //   - check the node exists and level >= layer
//        //   - compute the encrypted distance to the query: euclidean_distance(query, &self.nodes[ep].encrypted_point, zero)
//        //   - add the node to the visited set
//        //   - push (node index, distance) into candidates and w
//
//        println!(
//            "🔍 Starting search with {} initial candidates",
//            // TODO: report the actual number of initialized candidates
//            0
//        );
//
//        // TODO: Step 2 - main search loop
//        // while candidates is non-empty:
//        //   - find the point in candidates with the smallest distance (needs ciphertext sorting or similar)
//        //   - remove it from candidates and make it the current point
//        //
//        //   - pruning check (optional, complex):
//        //     if w.len() >= ef and current's distance > the farthest distance in w, break
//        //
//        //   - explore current's neighbors:
//        //     for neighbor in self.nodes[current].neighbors[layer]:
//        //       - if neighbor has not been visited:
//        //         - mark it visited
//        //         - compute the encrypted distance to the query
//        //         - push neighbor into candidates
//        //         - manage the result set w:
//        //           if w.len() < ef: push into w directly
//        //           else: push into w, sort, drop the farthest point, keep w at size ef
//
//        println!(
//            "🔍 Found {} total candidates (ef={})",
//            // TODO: report the final number of candidates found
//            0,
//            ef
//        );
//
//        // TODO: Step 3 - return results
//        // Extract the node indices from w and return them
//        // w.into_iter().map(|(node_idx, _)| node_idx).collect()
//
//        // Temporary empty result to keep the code compiling
//        Vec::new()
//    }
//}
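
The ciphertext priority-queue challenge called out in the comments above can be approached with an encrypted min-scan built from homomorphic comparisons and selects. A rough sketch against TFHE-rs's high-level API, assuming the `FheOrd` comparison and `if_then_else` select exposed through `tfhe::prelude` and a server key already set; the helper itself is illustrative, not part of this repository:

```rust
use tfhe::prelude::*;
use tfhe::FheInt14;

/// Homomorphically fold encrypted distances down to their minimum.
/// Each comparison yields an encrypted bool, so nothing about the
/// ordering leaks; cost is linear in the number of candidates.
/// (Tracking the argmin would apply the same select to encrypted
/// indices carried alongside the distances.)
fn encrypted_min(distances: &[FheInt14]) -> FheInt14 {
    let mut min = distances[0].clone();
    for d in &distances[1..] {
        let d_is_smaller = d.lt(&min); // encrypted comparison -> FheBool
        min = d_is_smaller.if_then_else(d, &min); // encrypted select (cmux)
    }
    min
}
```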
//
//impl Default for FheHnswGraph {
//    fn default() -> Self {
//        Self::new()
//    }
//}
//
///// Helper struct for building the ciphertext HNSW graph from plaintext data
//#[derive(Clone)]
//pub struct PlaintextHnswNode {
//    pub vector: Vec<f64>,
//    pub level: usize,
//    pub neighbors: Vec<Vec<usize>>,
//}
//
///// Build the ciphertext HNSW graph from plaintext data
//pub fn build_fhe_hnsw_from_plaintext(
//    plaintext_nodes: &[PlaintextHnswNode],
//    plaintext_entry_point: Option<usize>,
//    plaintext_max_level: usize,
//    client_key: &ClientKey,
//) -> FheHnswGraph {
//    use crate::logging::print_progress_bar;
//
//    let mut fhe_nodes = Vec::new();
//    let total_nodes = plaintext_nodes.len();
//
//    println!("🔐 Encrypting {total_nodes} HNSW nodes...");
//
//    for (idx, plain_node) in plaintext_nodes.iter().enumerate() {
//        print_progress_bar(idx + 1, total_nodes, "Encrypting nodes");
//        let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key);
//        let fhe_node = FheHnswNode {
//            encrypted_point,
//            level: plain_node.level,
//            neighbors: plain_node.neighbors.clone(),
//        };
//        fhe_nodes.push(fhe_node);
//    }
//
//    FheHnswGraph {
//        nodes: fhe_nodes,
//        entry_point: plaintext_entry_point,
//        max_level: plaintext_max_level,
//    }
//}
//
///// Build the plaintext HNSW graph
//pub fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) {
//    println!("🔨 Building HNSW graph from {} points...", data.len());
//
//    let max_connections = 8; // tuned parameter: fewer neighbors to cut ciphertext operations
//    let max_level = 3; // tuned parameter: keep a 3-level structure
//    let mut nodes = Vec::new();
//    let mut entry_point = None;
//    let mut current_max_level = 0;
//
//    // Create all nodes
//    for (idx, vector) in data.iter().enumerate() {
//        let level = select_level_for_node(max_level);
//        current_max_level = current_max_level.max(level);
//
//        let node = PlaintextHnswNode {
//            vector: vector.clone(),
//            level,
//            neighbors: vec![Vec::new(); level + 1],
//        };
//        nodes.push(node);
//
//        if idx == 0 {
//            entry_point = Some(idx);
//        }
//    }
//
//    // Build connections for every node (simplified implementation)
//    for node_id in 0..nodes.len() {
//        print_progress_bar(node_id + 1, nodes.len(), "Building connections");
//
//        let node_vector = nodes[node_id].vector.clone();
//        let node_level = nodes[node_id].level;
//
//        // Find the nearest neighbors for each layer
//        for layer in 0..=node_level {
//            let mut distances: Vec<(f64, usize)> = nodes
//                .iter()
//                .enumerate()
//                .filter(|(idx, n)| *idx != node_id && n.level >= layer)
//                .map(|(idx, n)| {
//                    let dist = euclidean_distance_plaintext(&node_vector, &n.vector);
//                    (dist, idx)
//                })
//                .collect();
//
//            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
//
//            // Add the nearest neighbors (limit the connection count)
//            for &(_, neighbor_id) in distances.iter().take(max_connections) {
//                nodes[node_id].neighbors[layer].push(neighbor_id);
//            }
//        }
//
//        // Update the entry point
//        if nodes[node_id].level > nodes[entry_point.unwrap()].level {
//            entry_point = Some(node_id);
//        }
//    }
//
//    println!(); // newline
//    println!(
//        "✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}",
//        nodes.len(),
//        entry_point,
//        current_max_level
//    );
//
//    (nodes, entry_point, current_max_level)
//}
//
///// Select the level for a node
//pub fn select_level_for_node(max_level: usize) -> usize {
//    let mut rng = rand::rng();
//    let mut level = 0;
//    while level < max_level && rng.random::<f32>() < 0.6 {
//        // tuned parameter: raise the level probability to 0.6
//        level += 1;
//    }
//    level
//}
//
///// Plaintext Euclidean distance calculation
//fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 {
//    a.iter()
//        .zip(b.iter())
//        .map(|(x, y)| (x - y).powi(2))
//        .sum::<f64>()
//        .sqrt()
//}
3 src/main.rs Normal file
@@ -0,0 +1,3 @@
fn main() {
    todo!();
}
304 testhnsw.py
@@ -1,304 +0,0 @@
#!/usr/bin/env python3
"""
HNSW parameter-optimization test script (controlled-variable method).
Tunes parameters for the single-query scenario over 100 10-dimensional vectors.
Each configuration is run 100 times, recording the rate of runs with >=90% accuracy.
"""

import subprocess
import json
from pathlib import Path

# The correct answer
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]


def calculate_accuracy(result, correct):
    """Accuracy = number of matching elements / total number of elements."""
    matches = len(set(result) & set(correct))
    return matches / len(correct) * 100


def run_hnsw_test(max_connections, max_level, level_prob, ef_upper, ef_bottom):
    """Run one HNSW test and return the resulting answer list."""
    cmd = [
        "cargo",
        "run",
        "--bin",
        "plain",
        "--",
        "--max-connections",
        str(max_connections),
        "--max-level",
        str(max_level),
        "--level-prob",
        str(level_prob),
        "--ef-upper",
        str(ef_upper),
        "--ef-bottom",
        str(ef_bottom),
        "--predictions",
        "./test_output.jsonl",
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        if result.returncode != 0:
            return None

        with open("./test_output.jsonl", "r") as f:
            output = json.load(f)
        return output["answer"]
    except Exception:
        return None


def test_config_stability(config, runs=100):
    """Test a configuration's stability: run 100 times and count runs with >=90% accuracy."""
    success_count = 0
    accuracies = []

    print("  testing... ", end="", flush=True)

    for i in range(runs):
        if i % 20 == 0 and i > 0:
            print(f"{i}/100 ", end="", flush=True)

        result = run_hnsw_test(
            config["max_connections"],
            config["max_level"],
            config["level_prob"],
            config["ef_upper"],
            config["ef_bottom"],
        )

        if result is not None:
            accuracy = calculate_accuracy(result, CORRECT_ANSWER)
            accuracies.append(accuracy)
            if accuracy >= 90:
                success_count += 1

    success_rate = success_count / len(accuracies) * 100 if accuracies else 0
    avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0

    print("done!")
    return success_rate, avg_accuracy, len(accuracies)


def main():
    """Main test function - evaluate the optimal parameter combination."""
    print("🚀 HNSW optimal-parameter-combination test")
    print(f"🎯 Correct answer: {CORRECT_ANSWER}")
    print("📊 Run the optimal combination 100 times and record the rate of >=90%-accuracy runs")
    print()

    # Optimal parameter combination
    optimal_config = {
        "max_connections": 8,
        "max_level": 3,
        "level_prob": 0.6,
        "ef_upper": 1,
        "ef_bottom": 10,
    }

    print("=" * 80)
    print("🏆 Testing the optimal parameter combination")
    print("=" * 80)
    print(
        f"Configuration: M={optimal_config['max_connections']}, L={optimal_config['max_level']}, "
        + f"P={optimal_config['level_prob']}, ef_upper={optimal_config['ef_upper']}, "
        + f"ef_bottom={optimal_config['ef_bottom']}"
    )
    print()

    print("optimal combination ", end=" ")
    success_rate, avg_acc, total_runs = test_config_stability(optimal_config)
    status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"

    print(
        f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}%  {status}"
    )

    print()
    print("=" * 80)
    print("📈 Result analysis")
    print("=" * 80)

    if success_rate >= 50:
        print("✅ Optimal parameter combination passed!")
        print(f"   - success rate: {success_rate:.1f}% (>=50%)")
        print(f"   - average accuracy: {avg_acc:.1f}%")
    else:
        print("❌ Optimal parameter combination did not reach a 50% success rate")
        print(f"   - success rate: {success_rate:.1f}% (<50%)")
        print(f"   - average accuracy: {avg_acc:.1f}%")

    return

    # Controlled-variable test code below - commented out
    """
    # Controlled-variable tests - change one parameter at a time
    test_results = []

    print("=" * 80)
    print("1️⃣ Testing the impact of the max_connections (M) parameter")
    print("=" * 80)

    for m in [4, 8, 12, 16, 20]:
        config = base_config.copy()
        config["max_connections"] = m
        config_name = f"M={m}"

        print(f"{config_name:<15}", end=" ")
        success_rate, avg_acc, total_runs = test_config_stability(config)
        status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"

        print(
            f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}%  {status}"
        )

        test_results.append(
            {
                "param": "max_connections",
                "value": m,
                "config": config,
                "success_rate": success_rate,
                "avg_accuracy": avg_acc,
                "pass": success_rate >= 50,
            }
        )

    print()
    print("=" * 80)
    print("2️⃣ Testing the impact of the max_level (L) parameter")
    print("=" * 80)

    for l in [3, 4, 5, 6, 7]:
        config = base_config.copy()
        config["max_level"] = l
        config_name = f"L={l}"

        print(f"{config_name:<15}", end=" ")
        success_rate, avg_acc, total_runs = test_config_stability(config)
        status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"

        print(
            f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}%  {status}"
        )

        test_results.append(
            {
                "param": "max_level",
                "value": l,
                "config": config,
                "success_rate": success_rate,
                "avg_accuracy": avg_acc,
                "pass": success_rate >= 50,
            }
        )

    print()
    print("=" * 80)
    print("3️⃣ Testing the impact of the level_prob (P) parameter")
    print("=" * 80)

    for p in [0.2, 0.3, 0.4, 0.5, 0.6]:
        config = base_config.copy()
        config["level_prob"] = p
        config_name = f"P={p}"

        print(f"{config_name:<15}", end=" ")
        success_rate, avg_acc, total_runs = test_config_stability(config)
        status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"

        print(
            f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}%  {status}"
        )

        test_results.append(
            {
                "param": "level_prob",
                "value": p,
                "config": config,
                "success_rate": success_rate,
                "avg_accuracy": avg_acc,
                "pass": success_rate >= 50,
            }
        )

    print()
    print("=" * 80)
    print("4️⃣ Testing the impact of the ef_bottom parameter")
    print("=" * 80)

    for ef in [10, 16, 25, 40, 60]:
        config = base_config.copy()
        config["ef_bottom"] = ef
        config_name = f"ef_b={ef}"

        print(f"{config_name:<15}", end=" ")
        success_rate, avg_acc, total_runs = test_config_stability(config)
        status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"

        print(
            f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}%  {status}"
        )

        test_results.append(
            {
                "param": "ef_bottom",
                "value": ef,
                "config": config,
                "success_rate": success_rate,
                "avg_accuracy": avg_acc,
                "pass": success_rate >= 50,
            }
        )

    # Summary report
    print()
    print("=" * 80)
    print("📈 Test summary")
    print("=" * 80)

    passed_configs = [r for r in test_results if r["pass"]]
    print(f"Total configurations tested: {len(test_results)}")
    print(
        f"Configurations with success rate >=50%: {len(passed_configs)} ({len(passed_configs)/len(test_results)*100:.1f}%)"
    )
    print()

    if passed_configs:
        print("🏆 Configurations with success rate >=50%:")
        for result in passed_configs:
            print(
                f"   {result['param']}={result['value']}: success rate {result['success_rate']:.1f}%, average accuracy {result['avg_accuracy']:.1f}%"
            )

        # Find the best value for each parameter
        print()
        print("🎯 Recommended best value per parameter:")
        for param_name in ["max_connections", "max_level", "level_prob", "ef_bottom"]:
            param_results = [r for r in test_results if r["param"] == param_name]
            best_result = max(param_results, key=lambda x: x["success_rate"])
            print(
                f"   {param_name}: {best_result['value']} (success rate: {best_result['success_rate']:.1f}%)"
            )
    else:
        print("❌ No configuration reached a 50% success rate")
        print("📊 Top 5 configurations by success rate:")
        top_configs = sorted(
            test_results, key=lambda x: x["success_rate"], reverse=True
        )[:5]
        for i, result in enumerate(top_configs, 1):
            print(
                f"   {i}. {result['param']}={result['value']}: success rate {result['success_rate']:.1f}%"
            )

    # Clean up temporary files
    Path("./test_output.jsonl").unlink(missing_ok=True)
    """


if __name__ == "__main__":
    main()