feat: remove hnsw

sangge 2025-08-27 13:13:20 +08:00
parent 7fae6b23b7
commit 25675228f4
7 changed files with 71 additions and 909 deletions

CLAUDE.md Normal file

@@ -0,0 +1,64 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
This project implements a K-Nearest Neighbors (KNN) algorithm using Fully Homomorphic Encryption (FHE) in Rust. The implementation uses TFHE-rs for cryptographic operations and operates on a 10-dimensional synthetic dataset with 100 training points.
## Key Architecture
- **Homomorphic Encryption**: Uses TFHE-rs library for fully homomorphic encryption operations
- **Data Processing**: Synthetic dataset features are scaled by 10x to preserve decimal precision as integers
- **KNN Implementation**: Complete implementation with multiple algorithms:
- `euclidean_distance()`: Optimized distance calculation using precomputed squares (see the sketch after this list)
- `perform_knn_selection()`: Supports selection sort, bitonic sort, and heap-based selection
- `encrypted_bitonic_sort()`: Parallel bitonic sort with power-of-2 padding
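The precomputed-squares trick expands the squared distance as d² = Σqᵢ² + Σpᵢ² − 2Σqᵢpᵢ, so each point's squared sum is encrypted once and only the dot product remains per comparison. A minimal plaintext sketch of that algebra (the real code operates on TFHE-rs ciphertexts; the struct and field names here are illustrative assumptions, not the repository's definitions):
```rust
// Plaintext sketch of the precomputed-squares optimization behind
// euclidean_distance(): d^2 = sum(q_i^2) + sum(p_i^2) - 2 * sum(q_i * p_i).
struct Query {
    values: Vec<i64>,
    sum_of_squares: i64, // precomputed once per query
}

struct Point {
    values: Vec<i64>,
    sum_of_squares: i64, // precomputed once per training point
}

fn squared_distance(q: &Query, p: &Point) -> i64 {
    let dot: i64 = q.values.iter().zip(&p.values).map(|(a, b)| a * b).sum();
    q.sum_of_squares + p.sum_of_squares - 2 * dot
}
```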
## Build and Development Commands
```bash
# Build the project
cargo build
# Run with different algorithms
cargo run --bin enc # Default selection sort
cargo run --bin enc -- --algorithm=bitonic # Bitonic sort (fastest for large datasets)
cargo run --bin enc -- --algorithm=heap # Heap-based selection
cargo run --bin enc -- --debug # Debug mode with plaintext verification
cargo run --bin plain # Plaintext version for comparison
# Development commands - ALWAYS use cargo check for verification
cargo check # Use this for code verification, NOT cargo run
cargo test
cargo fmt
cargo clippy
```
## Data Structure
The project processes a synthetic 10-dimensional dataset using these key data structures:
- `EncryptedQuery`: Query point with precomputed values for optimization
- `EncryptedPoint`: Training data points with precomputed squared sums
- `EncryptedNeighbor`: Distance and index pairs for KNN results
- Custom deserializer converts float values to scaled integers (×10) for FHE compatibility (sketched below)
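A minimal serde sketch of the ×10 scaling conversion, assuming illustrative type and field names (the repository's actual deserializer may differ):
```rust
use serde::{Deserialize, Deserializer};

// Deserialize a list of floats into x10-scaled integers so decimal
// precision survives the move to FHE-friendly integer arithmetic.
fn scaled_by_ten<'de, D>(deserializer: D) -> Result<Vec<i16>, D::Error>
where
    D: Deserializer<'de>,
{
    let floats = Vec::<f64>::deserialize(deserializer)?;
    Ok(floats.iter().map(|f| (f * 10.0).round() as i16).collect())
}

#[derive(Deserialize)]
struct RawPoint {
    #[serde(deserialize_with = "scaled_by_ten")]
    features: Vec<i16>, // e.g. 3.7 arrives as 37
}
```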
## Dataset
- **Training Data**: `dataset/train.jsonl`, where each line holds one query point and 100 10-dimensional training points (an illustrative line shape follows this list)
- **Results**: `dataset/answer.jsonl` and `dataset/answer1.jsonl` contain KNN classification results in JSON format
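The exact schema isn't reproduced in this commit, but since `plain` reads `dataset.query` and `dataset.data` from each line, a `train.jsonl` line plausibly has this shape (values shortened for illustration; real lines carry 10-dimensional vectors and 100 training points):
```json
{"query": [1.2, 0.5, 3.4], "data": [[0.9, 1.1, 2.8], [2.0, 0.3, 1.7]]}
```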
## Important Technical Notes
- **FheInt14 Range**: Valid range is -8192 to 8191 (14-bit signed: -2^13 to 2^13 - 1). Values outside this range (like i16::MAX = 32767) will overflow (see the range-guard sketch after this list)
- **Bitonic Sort**: Requires `up=true` for ascending order to get smallest distances first. Using `false` gives largest distances (wrong for KNN)
- **Performance**: Bitonic sort is fastest for larger datasets due to parallel processing, but requires power-of-2 padding
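A hedged sketch of guarding ×10-scaled inputs against FheInt14 overflow before encryption; the constants mirror the range note above, and the helper name is hypothetical:
```rust
// FheInt14 holds 14-bit signed values: -2^13 ..= 2^13 - 1.
const FHE_INT14_MIN: i32 = -8192;
const FHE_INT14_MAX: i32 = 8191;

// Scale a raw feature by 10 and reject anything that would overflow FheInt14.
fn checked_scale(value: f64) -> Option<i16> {
    let scaled = (value * 10.0).round() as i32;
    if (FHE_INT14_MIN..=FHE_INT14_MAX).contains(&scaled) {
        Some(scaled as i16)
    } else {
        None // e.g. i16::MAX = 32767 fits in i16 but overflows FheInt14
    }
}
```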
## Git Workflow Instructions
**IMPORTANT**: When the user asks to "write commit" or "帮我写commit" ("help me write a commit"):
- Do NOT add any files to staging area
- The user has already staged the files they want to commit
- Only create the commit with appropriate message for the staged changes


@@ -321,60 +321,3 @@ fn encrypted_conditional_swap(
a.index = new_a_index;
b.index = new_b_index;
}
///// Perform an HNSW approximate nearest-neighbor search
/////
///// # Arguments
///// * `graph` - Prebuilt FHE HNSW graph structure
///// * `query` - Encrypted query point
///// * `zero` - Encrypted zero value
/////
///// # Returns
///// * List of indices of the 10 nearest neighbors
//pub fn perform_hnsw_search(
// graph: &FheHnswGraph,
// query: &EncryptedQuery,
// zero: &FheInt14,
//) -> Vec<usize> {
// println!("🚀 Starting HNSW approximate search...");
//
// if graph.nodes.is_empty() {
// println!("❌ Empty HNSW graph");
// return Vec::new();
// }
//
// let Some(entry_point) = graph.entry_point else {
// println!("❌ No entry point in HNSW graph");
// return Vec::new();
// };
//
// println!(
// "🔍 HNSW search from entry point {} at level {}",
// entry_point, graph.max_level
// );
//
// let mut current_candidates = vec![entry_point];
//
// // Search layer by layer from the top level down to level 1
// for layer in (1..=graph.max_level).rev() {
// println!("🔍 Searching layer {layer} with ef=1...");
// let layer_start = Instant::now();
// current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero);
// println!(
// "✅ Layer {} search completed in {}, {} candidates",
// layer,
// format_duration(layer_start.elapsed()),
// current_candidates.len()
// );
// }
//
// // Final search at level 0
// let final_search_start = Instant::now();
// let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero);
// println!(
// "✅ Final layer search completed in {}, {} candidates",
// format_duration(final_search_start.elapsed()),
// final_candidates.len()
// );
// final_candidates
//}


@@ -26,7 +26,7 @@ struct Args {
#[arg(
long,
default_value = "bitonic",
help = "Algorithm: selection, bitonic, heap, hnsw"
help = "Algorithm: selection, bitonic, heap"
)]
algorithm: String,
#[arg(long, help = "Enable debug mode (plaintext calculation first)")]
@@ -143,24 +143,6 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
let query_encrypted = encrypt_query(&dataset.query, client_key);
let k = 10; // Number of nearest neighbors
// if args.algorithm == "hnsw" {
// // HNSW algorithm path
// println!("🚀 Using HNSW algorithm...");
//
// // 1. Build the plaintext HNSW graph
// let (plaintext_nodes, entry_point, max_level) =
// build_plaintext_hnsw_graph(&dataset.data);
//
// // 2. Convert to an encrypted HNSW graph
// println!("🔐 Converting to encrypted HNSW graph...");
// let fhe_graph =
// build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
//
// // 3. Run the HNSW search
// let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
// let answer = perform_hnsw_search(&fhe_graph, &query_encrypted, &encrypted_zero);
// results.push(Prediction { answer });
// } else {
// Traditional algorithm path
// Encrypt all training points
println!("🔐 Encrypting training points...");


@@ -1,9 +1,6 @@
#![allow(dead_code)]
use anyhow::Result;
use clap::Parser;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
@@ -15,28 +12,6 @@ struct Args {
dataset: String,
#[arg(long, default_value = "./dataset/answer1.jsonl")]
predictions: String,
#[arg(
long,
default_value = "8",
help = "Max connections per node (M parameter)"
)]
max_connections: usize,
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
max_level: usize,
#[arg(long, default_value = "0.6", help = "Level selection probability")]
level_prob: f32,
#[arg(
long,
default_value = "1",
help = "ef parameter for upper layer search"
)]
ef_upper: usize,
#[arg(
long,
default_value = "10",
help = "ef parameter for bottom layer search"
)]
ef_bottom: usize,
}
#[derive(Deserialize)]
@@ -82,278 +57,25 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
}
#[derive(Clone)]
struct HNSWNode {
vector: Vec<f64>,
level: usize,
neighbors: Vec<Vec<usize>>,
}
#[derive(Clone)]
struct HNSWGraph {
nodes: Vec<HNSWNode>,
entry_point: Option<usize>,
max_level: usize,
max_connections: usize,
}
#[derive(Debug, Clone, Copy, PartialEq)]
struct OrderedFloat(f64);
impl Eq for OrderedFloat {}
impl PartialOrd for OrderedFloat {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrderedFloat {
fn cmp(&self, other: &Self) -> Ordering {
self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal)
}
}
impl HNSWGraph {
fn new(max_level: usize, max_connections: usize) -> Self {
Self {
nodes: Vec::new(),
entry_point: None,
max_level,
max_connections,
}
}
fn insert_node(&mut self, vector: Vec<f64>, level_prob: f32) -> usize {
let level = self.select_level(level_prob);
let node_id = self.nodes.len();
let node = HNSWNode {
vector,
level,
neighbors: vec![Vec::new(); level + 1],
};
// Push the node into the vector first
self.nodes.push(node);
if node_id == 0 {
self.entry_point = Some(node_id);
return node_id;
}
let mut current_candidates = vec![self.entry_point.unwrap()];
// Search from the top level down to target level + 1
for lc in (level + 1..=self.max_level).rev() {
current_candidates =
self.search_layer(&self.nodes[node_id].vector, current_candidates, 1, lc);
}
// Build connections from the target level down to level 0
for lc in (0..=level).rev() {
current_candidates = self.search_layer(
&self.nodes[node_id].vector,
current_candidates,
self.max_connections,
lc,
);
for &candidate_id in &current_candidates {
self.connect_nodes(node_id, candidate_id, lc);
}
}
// Update the entry point
if level > self.nodes[self.entry_point.unwrap()].level {
self.entry_point = Some(node_id);
}
node_id
}
fn search_layer(
&self,
query: &[f64],
entry_points: Vec<usize>,
ef: usize,
layer: usize,
) -> Vec<usize> {
let mut visited = std::collections::HashSet::new();
let mut candidates = std::collections::BinaryHeap::new();
let mut w = std::collections::BinaryHeap::new();
// Initialize the candidate set
for &ep in &entry_points {
if ep < self.nodes.len() && self.nodes[ep].level >= layer {
let dist = euclidean_distance(query, &self.nodes[ep].vector);
candidates.push(std::cmp::Reverse((OrderedFloat(dist), ep)));
w.push((OrderedFloat(dist), ep));
visited.insert(ep);
}
}
while let Some(std::cmp::Reverse((current_dist, current))) = candidates.pop() {
// Stop searching if the current distance already exceeds the farthest result's distance
if let Some(&(farthest_dist, _)) = w.iter().max() {
if current_dist > farthest_dist && w.len() >= ef {
break;
}
}
// Explore the current node's neighbors
if current < self.nodes.len() && layer < self.nodes[current].neighbors.len() {
for &neighbor in &self.nodes[current].neighbors[layer] {
if !visited.contains(&neighbor) && neighbor < self.nodes.len() {
visited.insert(neighbor);
let dist = euclidean_distance(query, &self.nodes[neighbor].vector);
let ordered_dist = OrderedFloat(dist);
if w.len() < ef {
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
w.push((ordered_dist, neighbor));
} else if let Some(&(farthest_dist, _)) = w.iter().max() {
if ordered_dist < farthest_dist {
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
w.push((ordered_dist, neighbor));
// Remove the farthest point
if let Some(max_item) = w.iter().max().copied() {
w.retain(|&x| x != max_item);
}
}
}
}
}
}
}
// Return the results sorted by distance
let mut results: Vec<_> = w.into_iter().collect();
results.sort_by(|a, b| a.0.cmp(&b.0));
results.into_iter().map(|(_, idx)| idx).collect()
}
fn select_level(&self, level_prob: f32) -> usize {
let mut rng = rand::rng();
let mut level = 0;
while level < self.max_level && rng.random::<f32>() < level_prob {
level += 1;
}
level
}
fn connect_nodes(&mut self, node1: usize, node2: usize, layer: usize) {
self.nodes[node1].neighbors[layer].push(node2);
self.nodes[node2].neighbors[layer].push(node1);
if self.nodes[node1].neighbors[layer].len() > self.max_connections {
self.prune_node_connections(node1, layer);
}
if self.nodes[node2].neighbors[layer].len() > self.max_connections {
self.prune_node_connections(node2, layer);
}
}
fn prune_node_connections(&mut self, node_id: usize, _layer: usize) {
let node_vector = self.nodes[node_id].vector.clone();
let nodes = self.nodes.clone(); // Create immutable reference before mutable borrow
// Compute distances to all neighbors and sort them
let neighbors = &mut self.nodes[node_id].neighbors[_layer];
let mut neighbor_distances: Vec<(f64, usize)> = neighbors
.iter()
.map(|&neighbor_id| {
let dist = euclidean_distance(&node_vector, &nodes[neighbor_id].vector);
(dist, neighbor_id)
})
.collect();
// Sort by distance
neighbor_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
// Update the neighbor list, keeping only the nearest
*neighbors = neighbor_distances
.into_iter()
.take(self.max_connections)
.map(|(_, neighbor_id)| neighbor_id)
.collect();
}
}
// fn main() -> Result<()> {
// let args = Args::parse();
//
// let file = File::open(&args.dataset)?;
// let reader = BufReader::new(file);
//
// let mut results = Vec::new();
//
// for line in reader.lines() {
// let line = line?;
// let dataset: Dataset = serde_json::from_str(&line)?;
//
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
// results.push(Prediction { answer: nearest });
// }
//
// let mut output_file = File::create(&args.predictions)?;
// for result in results {
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
// }
//
// Ok(())
// }
fn main() -> Result<()> {
let args = Args::parse();
let file = File::open(&args.dataset)?;
let reader = BufReader::new(file);
let mut results = Vec::new();
println!("🔧 HNSW Parameters:");
println!(" max_connections: {}", args.max_connections);
println!(" max_level: {}", args.max_level);
println!(" level_prob: {}", args.level_prob);
println!(" ef_upper: {}", args.ef_upper);
println!(" ef_bottom: {}", args.ef_bottom);
let mut results = Vec::new();
for line in reader.lines() {
let line = line?;
let dataset: Dataset = serde_json::from_str(&line)?;
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
for data_point in &dataset.data {
hnsw.insert_node(data_point.clone(), args.level_prob);
}
let nearest = if let Some(entry_point) = hnsw.entry_point {
let mut search_results = vec![entry_point];
// Upper-layer search uses the ef_upper parameter
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
search_results =
hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
}
// Bottom-layer search uses the ef_bottom parameter
let final_results =
hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
final_results
.into_iter()
.take(10)
.map(|idx| idx + 1)
.collect()
} else {
Vec::new()
};
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
results.push(Prediction { answer: nearest });
}
let mut output_file = File::create(&args.predictions)?;
for result in results {
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
println!("{}", serde_json::to_string(&result)?);
}
Ok(())


@@ -128,251 +128,3 @@ pub fn decrypt_indices(encrypted_indices: &[FheUint8], client_key: &ClientKey) -
})
.collect()
}
///// Ciphertext HNSW node
//#[derive(Clone)]
//pub struct FheHnswNode {
// pub encrypted_point: EncryptedPoint,
// pub level: usize,
// pub neighbors: Vec<Vec<usize>>, // Plaintext neighbor indices (fixed during preprocessing)
//}
//
///// Ciphertext HNSW graph
//#[derive(Clone)]
//pub struct FheHnswGraph {
// pub nodes: Vec<FheHnswNode>,
// pub entry_point: Option<usize>,
// pub max_level: usize,
// pub distances: [Option<FheInt14>; 100],
//}
//
//impl FheHnswGraph {
// pub fn new() -> Self {
// Self {
// nodes: Vec::new(),
// entry_point: None,
// max_level: 0,
// distances: [const { None };100]
// }
// }
//
// /// Layer-search function
// ///
// /// Core algorithm implementing the standard HNSW greedy search
// ///
// /// Algorithm flow:
// /// 1. Initialize the candidate queue (candidates) and the result set (w)
// /// 2. Starting from entry_points, compute ciphertext distances to query
// /// 3. Greedy search loop:
// /// - Pick the point in candidates with the smallest distance as the current exploration point
// /// - Explore all of that node's neighbors
// /// - Add unvisited neighbors to candidates and w
// /// - Keep w no larger than ef by removing the farthest point
// /// - Pruning: stop when the current point is farther than the farthest point in w and w is full
// /// 4. Return the node indices held in w
// ///
// /// Key challenges:
// /// - A priority queue must be emulated with only ciphertext sorting available
// /// - The pruning condition is hard to evaluate on ciphertexts; accuracy must be traded against performance
// /// - Candidate-queue management must stay efficient to avoid excessive ciphertext operations
// pub fn search_layer(
// &self,
// query: &EncryptedQuery,
// entry_points: Vec<usize>,
// ef: usize,
// layer: usize,
// zero: &FheInt14,
// ) -> Vec<usize> {
// let visited: HashSet<usize> = HashSet::new();
// let mut candidate =
//
// // TODO: design suitable data structures
// // candidates: candidate queue storing (node index, ciphertext distance) or a similar structure
// // w: result set maintaining the best ef candidates found so far
//
// // TODO: Step 1 - initialize the candidates
// // For each entry_point:
// // - check that the node exists and its level >= layer
// // - compute the ciphertext distance to query: euclidean_distance(query, &self.nodes[ep].encrypted_point, zero)
// // - add the node to the visited set
// // - push (node index, distance) into candidates and w
//
// println!(
// "🔍 Starting search with {} initial candidates",
// // TODO: print the actual number of initialized candidates
// 0
// );
//
// // TODO: Step 2 - main search loop
// // while candidates is not empty:
// // - find the point in candidates with the smallest distance (needs ciphertext sorting or another method)
// // - remove it from candidates and make it the current exploration point `current`
// //
// // - pruning check (optional, complex):
// //   if w.len() >= ef and current's distance > the farthest distance in w, then break
// //
// // - explore current's neighbors
// //   for neighbor in self.nodes[current].neighbors[layer]:
// //   - if neighbor has not been visited:
// //     - mark it visited
// //     - compute its ciphertext distance to query
// //     - push neighbor into candidates
// //     - manage the result set w:
// //       if w.len() < ef: push into w directly
// //       else: push into w, sort, and drop the farthest point to keep w at size ef
//
// println!(
// "🔍 Found {} total candidates (ef={})",
// // TODO: print the final number of candidates found
// 0,
// ef
// );
//
// // TODO: Step 3 - return the results
// // extract the node indices from w and return them
// // w.into_iter().map(|(node_idx, _)| node_idx).collect()
//
// // Temporarily return an empty result so the code compiles
// Vec::new()
// }
//}
//
//impl Default for FheHnswGraph {
// fn default() -> Self {
// Self::new()
// }
//}
//
///// Helper structure for building a ciphertext HNSW graph from plaintext data
//#[derive(Clone)]
//pub struct PlaintextHnswNode {
// pub vector: Vec<f64>,
// pub level: usize,
// pub neighbors: Vec<Vec<usize>>,
//}
//
///// Build a ciphertext HNSW graph from plaintext data
//pub fn build_fhe_hnsw_from_plaintext(
// plaintext_nodes: &[PlaintextHnswNode],
// plaintext_entry_point: Option<usize>,
// plaintext_max_level: usize,
// client_key: &ClientKey,
//) -> FheHnswGraph {
// use crate::logging::print_progress_bar;
//
// let mut fhe_nodes = Vec::new();
// let total_nodes = plaintext_nodes.len();
//
// println!("🔐 Encrypting {total_nodes} HNSW nodes...");
//
// for (idx, plain_node) in plaintext_nodes.iter().enumerate() {
// print_progress_bar(idx + 1, total_nodes, "Encrypting nodes");
// let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key);
// let fhe_node = FheHnswNode {
// encrypted_point,
// level: plain_node.level,
// neighbors: plain_node.neighbors.clone(),
// };
// fhe_nodes.push(fhe_node);
// }
//
//
// FheHnswGraph {
// nodes: fhe_nodes,
// entry_point: plaintext_entry_point,
// max_level: plaintext_max_level,
// distances: [const { None }; 100], // field required by the struct definition above
// }
//}
//
///// Build the plaintext HNSW graph
//pub fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) {
// println!("🔨 Building HNSW graph from {} points...", data.len());
//
// let max_connections = 8; // Optimal parameter: fewer neighbors to reduce ciphertext operations
// let max_level = 3; // Optimal parameter: keep a 3-level structure
// let mut nodes = Vec::new();
// let mut entry_point = None;
// let mut current_max_level = 0;
//
// // Create all nodes
// for (idx, vector) in data.iter().enumerate() {
// let level = select_level_for_node(max_level);
// current_max_level = current_max_level.max(level);
//
// let node = PlaintextHnswNode {
// vector: vector.clone(),
// level,
// neighbors: vec![Vec::new(); level + 1],
// };
// nodes.push(node);
//
// if idx == 0 {
// entry_point = Some(idx);
// }
// }
//
// // Build connections for each node (simplified implementation)
// for node_id in 0..nodes.len() {
// print_progress_bar(node_id + 1, nodes.len(), "Building connections");
//
// let node_vector = nodes[node_id].vector.clone();
// let node_level = nodes[node_id].level;
//
// // Find the nearest neighbors at each layer
// for layer in 0..=node_level {
// let mut distances: Vec<(f64, usize)> = nodes
// .iter()
// .enumerate()
// .filter(|(idx, n)| *idx != node_id && n.level >= layer)
// .map(|(idx, n)| {
// let dist = euclidean_distance_plaintext(&node_vector, &n.vector);
// (dist, idx)
// })
// .collect();
//
// distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
//
// // Add the nearest neighbors (limiting the connection count)
// for &(_, neighbor_id) in distances.iter().take(max_connections) {
// nodes[node_id].neighbors[layer].push(neighbor_id);
// }
// }
//
// // Update the entry point
// if nodes[node_id].level > nodes[entry_point.unwrap()].level {
// entry_point = Some(node_id);
// }
// }
//
// println!(); // newline
// println!(
// "✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}",
// nodes.len(),
// entry_point,
// current_max_level
// );
//
// (nodes, entry_point, current_max_level)
//}
//
///// Select the level for a node
//pub fn select_level_for_node(max_level: usize) -> usize {
// let mut rng = rand::rng();
// let mut level = 0;
// while level < max_level && rng.random::<f32>() < 0.6 {
// // Optimal parameter: raise the level probability to 0.6
// level += 1;
// }
// level
//}
//
//
//
///// Plaintext Euclidean distance computation
//fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 {
// a.iter()
// .zip(b.iter())
// .map(|(x, y)| (x - y).powi(2))
// .sum::<f64>()
// .sqrt()
//}

src/main.rs Normal file

@@ -0,0 +1,3 @@
fn main() {
todo!();
}


@@ -1,304 +0,0 @@
#!/usr/bin/env python3
"""
HNSW参数优化测试脚本 - 控制变量法
针对100个10维向量单次查询的场景进行参数调优
每个配置运行100次记录90%准确率的成功率
"""
import subprocess
import json
from pathlib import Path
# Ground-truth answer
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
def calculate_accuracy(result, correct):
"""计算准确率:匹配元素数量 / 总元素数量"""
matches = len(set(result) & set(correct))
return matches / len(correct) * 100
def run_hnsw_test(max_connections, max_level, level_prob, ef_upper, ef_bottom):
"""运行HNSW测试并返回结果"""
cmd = [
"cargo",
"run",
"--bin",
"plain",
"--",
"--max-connections",
str(max_connections),
"--max-level",
str(max_level),
"--level-prob",
str(level_prob),
"--ef-upper",
str(ef_upper),
"--ef-bottom",
str(ef_bottom),
"--predictions",
"./test_output.jsonl",
]
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode != 0:
return None
with open("./test_output.jsonl", "r") as f:
output = json.load(f)
return output["answer"]
except Exception:
return None
def test_config_stability(config, runs=100):
"""测试配置的稳定性运行100次统计≥90%准确率的次数"""
success_count = 0
accuracies = []
print(f" 测试中... ", end="", flush=True)
for i in range(runs):
if i % 20 == 0 and i > 0:
print(f"{i}/100 ", end="", flush=True)
result = run_hnsw_test(
config["max_connections"],
config["max_level"],
config["level_prob"],
config["ef_upper"],
config["ef_bottom"],
)
if result is not None:
accuracy = calculate_accuracy(result, CORRECT_ANSWER)
accuracies.append(accuracy)
if accuracy >= 90:
success_count += 1
success_rate = success_count / len(accuracies) * 100 if accuracies else 0
avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
print(f"完成!")
return success_rate, avg_accuracy, len(accuracies)
def main():
"""主测试函数 - 测试最优参数组合"""
print("🚀 HNSW最优参数组合测试")
print(f"🎯 正确答案: {CORRECT_ANSWER}")
print("📊 测试最优参数组合运行100次记录≥90%准确率的成功率")
print()
# Optimal parameter combination
optimal_config = {
"max_connections": 8,
"max_level": 3,
"level_prob": 0.6,
"ef_upper": 1,
"ef_bottom": 10,
}
print("=" * 80)
print("🏆 测试最优参数组合")
print("=" * 80)
print(
f"参数配置: M={optimal_config['max_connections']}, L={optimal_config['max_level']}, "
+ f"P={optimal_config['level_prob']}, ef_upper={optimal_config['ef_upper']}, "
+ f"ef_bottom={optimal_config['ef_bottom']}"
)
print()
print("最优组合 ", end=" ")
success_rate, avg_acc, total_runs = test_config_stability(optimal_config)
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
print(
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
)
print()
print("=" * 80)
print("📈 测试结果分析")
print("=" * 80)
if success_rate >= 50:
print("✅ Optimal parameter combination passed!")
print(f" - Success rate: {success_rate:.1f}% (≥50%)")
print(f" - Average accuracy: {avg_acc:.1f}%")
else:
print("❌ Optimal parameter combination fell short of a 50% success rate")
print(f" - Success rate: {success_rate:.1f}% (<50%)")
print(f" - Average accuracy: {avg_acc:.1f}%")
return
# The controlled-variable tests below are commented out
"""
# Controlled-variable testing - change only one parameter at a time
test_results = []
print("=" * 80)
print("1测试 max_connections (M) 参数影响")
print("=" * 80)
for m in [4, 8, 12, 16, 20]:
config = base_config.copy()
config["max_connections"] = m
config_name = f"M={m}"
print(f"{config_name:<15}", end=" ")
success_rate, avg_acc, total_runs = test_config_stability(config)
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
print(
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
)
test_results.append(
{
"param": "max_connections",
"value": m,
"config": config,
"success_rate": success_rate,
"avg_accuracy": avg_acc,
"pass": success_rate >= 50,
}
)
print()
print("=" * 80)
print("2测试 max_level (L) 参数影响")
print("=" * 80)
for l in [3, 4, 5, 6, 7]:
config = base_config.copy()
config["max_level"] = l
config_name = f"L={l}"
print(f"{config_name:<15}", end=" ")
success_rate, avg_acc, total_runs = test_config_stability(config)
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
print(
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
)
test_results.append(
{
"param": "max_level",
"value": l,
"config": config,
"success_rate": success_rate,
"avg_accuracy": avg_acc,
"pass": success_rate >= 50,
}
)
print()
print("=" * 80)
print("3测试 level_prob (P) 参数影响")
print("=" * 80)
for p in [0.2, 0.3, 0.4, 0.5, 0.6]:
config = base_config.copy()
config["level_prob"] = p
config_name = f"P={p}"
print(f"{config_name:<15}", end=" ")
success_rate, avg_acc, total_runs = test_config_stability(config)
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
print(
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
)
test_results.append(
{
"param": "level_prob",
"value": p,
"config": config,
"success_rate": success_rate,
"avg_accuracy": avg_acc,
"pass": success_rate >= 50,
}
)
print()
print("=" * 80)
print("4⃣ 测试 ef_bottom 参数影响")
print("=" * 80)
for ef in [10, 16, 25, 40, 60]:
config = base_config.copy()
config["ef_bottom"] = ef
config_name = f"ef_b={ef}"
print(f"{config_name:<15}", end=" ")
success_rate, avg_acc, total_runs = test_config_stability(config)
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
print(
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
)
test_results.append(
{
"param": "ef_bottom",
"value": ef,
"config": config,
"success_rate": success_rate,
"avg_accuracy": avg_acc,
"pass": success_rate >= 50,
}
)
# Summary report
print()
print("=" * 80)
print("📈 测试总结")
print("=" * 80)
passed_configs = [r for r in test_results if r["pass"]]
print(f"总测试配置数: {len(test_results)}")
print(
f"成功率≥50%的配置: {len(passed_configs)} ({len(passed_configs)/len(test_results)*100:.1f}%)"
)
print()
if passed_configs:
print("🏆 成功率≥50%的参数配置:")
for result in passed_configs:
print(
f" {result['param']}={result['value']}: 成功率 {result['success_rate']:.1f}%, 平均准确率 {result['avg_accuracy']:.1f}%"
)
# Find the best value for each parameter
print()
print("🎯 各参数最佳值推荐:")
for param_name in ["max_connections", "max_level", "level_prob", "ef_bottom"]:
param_results = [r for r in test_results if r["param"] == param_name]
best_result = max(param_results, key=lambda x: x["success_rate"])
print(
f" {param_name}: {best_result['value']} (成功率: {best_result['success_rate']:.1f}%)"
)
else:
print("❌ 没有配置的成功率达到50%")
print("📊 按成功率排序的前5个配置:")
top_configs = sorted(
test_results, key=lambda x: x["success_rate"], reverse=True
)[:5]
for i, result in enumerate(top_configs, 1):
print(
f" {i}. {result['param']}={result['value']}: 成功率 {result['success_rate']:.1f}%"
)
# Clean up the temporary file
Path("./test_output.jsonl").unlink(missing_ok=True)
"""
if __name__ == "__main__":
main()