fix: correct run.sh script parameters and disable HNSW code (v0.3.1)

- Fix syntax error in run.sh: remove extra quote and correct --log-path to --log-file
- Comment out HNSW algorithm implementation in enc.rs and algorithms.rs to simplify codebase
- Bump version to 0.3.1 in Cargo.toml
- Remove HNSW implementation guide and test files
- Add comprehensive project writeup documentation
2025-08-06 21:40:12 +08:00
parent 46b3562de0
commit c2d423445d
10 changed files with 454 additions and 947 deletions

Cargo.lock (generated)

@@ -410,7 +410,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hfe_knn"
version = "0.3.0"
version = "0.3.1"
dependencies = [
"anyhow",
"bincode 2.0.1",

Cargo.toml

@@ -1,6 +1,6 @@
[package]
name = "hfe_knn"
version = "0.3.0"
version = "0.3.1"
edition = "2024"
[dependencies]


@@ -1,143 +0,0 @@
# HNSW Search Layer Implementation Guide
## Goal
Implement the standard HNSW greedy search algorithm, but with ciphertext distance computation, matching the logic and performance of the plaintext version.
## Key data structures
### Input parameters
- `query: &EncryptedQuery<T>` - the encrypted query point
- `entry_points: Vec<usize>` - node indices of the entry points
- `ef: usize` - candidate set size during the search
- `layer: usize` - the layer currently being searched
- `zero: &T` - an encrypted zero value (used in distance computation)
### Suggested internal data structures
```rust
// Candidate queue: nodes waiting to be explored
let mut candidates: Vec<(usize, EncryptedNeighbor<T>)> = Vec::new();
// Result set: the best ef candidates found so far
let mut w: Vec<(usize, EncryptedNeighbor<T>)> = Vec::new();
// Visited markers
let mut visited: HashSet<usize> = HashSet::new();
```
The `EncryptedNeighbor<T>` struct is already defined:
```rust
pub struct EncryptedNeighbor<T> {
pub distance: T, // ciphertext distance
pub index: FheUint8, // ciphertext index
}
```
## Implementation steps
### Step 1: Initialize candidates
```rust
for &ep in &entry_points {
if ep < self.nodes.len() && self.nodes[ep].level >= layer {
visited.insert(ep);
let distance = euclidean_distance(query, &self.nodes[ep].encrypted_point, zero);
let neighbor = EncryptedNeighbor {
distance,
index: self.nodes[ep].encrypted_point.index.clone(),
};
candidates.push((ep, neighbor.clone()));
w.push((ep, neighbor));
}
}
```
### Step 2: Main search loop
```rust
while !candidates.is_empty() {
// 2.1 Find the candidate with the smallest distance
// Hint: the EncryptedNeighbor entries in candidates must be sorted by distance;
// encrypted_selection_sort or a similar method can be used
// 2.2 Remove the minimum-distance candidate and make it the current exploration point
let current = /* node index of the minimum-distance point removed from candidates */;
// 2.3 Pruning check (optional, but affects performance):
// if w.len() >= ef and current's distance > the farthest distance in w, break
// 2.4 Explore the current node's neighbors
for &neighbor_idx in &self.nodes[current].neighbors[layer] {
if !visited.contains(&neighbor_idx) && neighbor_idx < self.nodes.len() {
visited.insert(neighbor_idx);
let distance = euclidean_distance(query, &self.nodes[neighbor_idx].encrypted_point, zero);
let encrypted_neighbor = EncryptedNeighbor {
distance,
index: self.nodes[neighbor_idx].encrypted_point.index.clone(),
};
// Add to the candidate queue
candidates.push((neighbor_idx, encrypted_neighbor.clone()));
// Maintain the result set w
if w.len() < ef {
w.push((neighbor_idx, encrypted_neighbor));
} else {
// The result set is full; the farthest point must be replaced
w.push((neighbor_idx, encrypted_neighbor));
// Sort w and keep only the ef nearest points
// Hint: convert to Vec<EncryptedNeighbor>, sort, then rebuild w
}
}
}
}
```
### Step 3: Return results
```rust
w.into_iter().map(|(node_idx, _)| node_idx).collect()
```
## Performance tuning suggestions
### 1. Reduce the number of ciphertext sorts
- **Problem**: every sort is expensive (~2-3 minutes)
- **Strategies**:
  - Sort only when necessary (candidate queue management, result set maintenance)
  - Consider batch processing instead of element-by-element comparisons
  - Some algorithmic accuracy can be traded away for performance
### 2. Candidate queue management
- **Plaintext version**: BinaryHeap with O(log n) insertion and removal
- **Ciphertext version**: only sorting is available, O(n log n)
- **Optimization**: consider capping the candidate queue size to avoid unbounded growth
### 3. Pruning strategy
- **Ideal**: stop when `current_distance > farthest_w_distance && w.len() >= ef`
- **Reality**: the result of a ciphertext comparison cannot be inspected directly; see the sketch below
- **Trade-off**: complex pruning can be skipped, making the algorithm more exhaustive but slightly slower
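An encrypted comparison yields an encrypted boolean that can select between values but cannot drive control flow, which is why the ideal `break` above is unavailable. A minimal sketch of the branchless pattern, assuming the standard tfhe-rs `gt`/`if_then_else` high-level API (this helper is illustrative, not part of the project):
```rust
use tfhe::prelude::*;
use tfhe::{FheBool, FheInt14};

// Branchless ciphertext minimum: the comparison result stays encrypted,
// so it selects a value instead of deciding a branch.
fn encrypted_min(a: &FheInt14, b: &FheInt14) -> FheInt14 {
    let a_gt_b: FheBool = a.gt(b);
    a_gt_b.if_then_else(b, a)
}
```
Branching on the comparison would require decrypting it, which would leak information about the encrypted distances.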
## Debugging hints
### 1. Verify initialization
Make sure the entry_points are correctly initialized into candidates and w.
### 2. Verify neighbor exploration
Check that `self.nodes[current].neighbors[layer]` is traversed correctly.
### 3. Verify the visited logic
Make sure the same node is never visited twice.
### 4. Verify result set maintenance
Make sure w never exceeds ef entries and holds the nearest points.
## Expected performance targets
- **Plaintext version**: milliseconds
- **Ciphertext version target**: 15-20 minutes (versus the current 100+ minutes)
- **Accuracy target**: 80%+ (versus the current 30%)
## Available helper functions
- `euclidean_distance(query, point, zero)` - compute the ciphertext Euclidean distance
- `encrypted_selection_sort(distances, k)` - ciphertext selection sort that extracts the k smallest values
- `EncryptedNeighbor` - struct wrapping a distance and an index
## Plaintext reference
See the `search_layer` function in `src/bin/plain.rs` to understand the logic flow of the standard HNSW algorithm; a condensed sketch follows.
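For orientation, here is a condensed plaintext version of the greedy loop (a sketch under simplifying assumptions: `f64` vectors and linear scans instead of a heap; the authoritative reference remains `src/bin/plain.rs`):
```rust
use std::collections::HashSet;

fn dist(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
}

/// nodes[i] = (vector, per-layer neighbor lists)
fn search_layer_plain(
    nodes: &[(Vec<f64>, Vec<Vec<usize>>)],
    query: &[f64],
    entry_points: Vec<usize>,
    ef: usize,
    layer: usize,
) -> Vec<usize> {
    let mut visited: HashSet<usize> = entry_points.iter().copied().collect();
    let mut candidates: Vec<(f64, usize)> = entry_points
        .iter()
        .map(|&ep| (dist(query, &nodes[ep].0), ep))
        .collect();
    let mut w = candidates.clone();
    // Pop the nearest open candidate until the queue is exhausted.
    while let Some(pos) = candidates
        .iter()
        .enumerate()
        .min_by(|(_, x), (_, y)| x.0.partial_cmp(&y.0).unwrap())
        .map(|(i, _)| i)
    {
        let (d, current) = candidates.swap_remove(pos);
        // Prune: stop once the nearest open candidate is farther away
        // than the worst member of a full result set.
        let worst = w.iter().map(|p| p.0).fold(f64::MIN, f64::max);
        if w.len() >= ef && d > worst {
            break;
        }
        for &n in &nodes[current].1[layer] {
            if visited.insert(n) {
                let dn = dist(query, &nodes[n].0);
                candidates.push((dn, n));
                w.push((dn, n));
                if w.len() > ef {
                    // Drop the farthest point to keep |w| == ef.
                    let far = w
                        .iter()
                        .enumerate()
                        .max_by(|(_, x), (_, y)| x.0.partial_cmp(&y.0).unwrap())
                        .map(|(i, _)| i)
                        .unwrap();
                    w.swap_remove(far);
                }
            }
        }
    }
    w.into_iter().map(|(_, idx)| idx).collect()
}
```
In the ciphertext version every comparison above becomes an encrypted operation, which is what makes the sort-based queue management and the missing pruning so costly.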

run.sh

@@ -20,4 +20,4 @@ chmod +x "${SCRIPT_DIR}/test"
"${SCRIPT_DIR}/test" \
--dataset "$DATASET_FILE" \
--predictions "$PREDICTIONS_RESULT_FILE" \
--log-path "/home/admin/workspace/job/logs/user.log'"
--log-file "/home/admin/workspace/job/logs/user.log"

algorithms.rs

@@ -1,5 +1,5 @@
use crate::EncryptedQuery;
use crate::data::{EncryptedNeighbor, EncryptedPoint, FheHnswGraph};
use crate::data::{EncryptedNeighbor, EncryptedPoint};
use crate::logging::{format_duration, print_progress_bar};
use rayon::prelude::*;
use std::time::Instant;
@@ -322,93 +322,59 @@ fn encrypted_conditional_swap(
b.index = new_b_index;
}
/// Perform HNSW approximate nearest-neighbor search
///
/// # Arguments
/// * `graph` - the prebuilt FHE HNSW graph
/// * `query` - the encrypted query point
/// * `k` - number of nearest neighbors to return
/// * `zero` - an encrypted zero value
///
/// # Returns
/// * encrypted indices of the k nearest neighbors
pub fn perform_hnsw_search(
graph: &FheHnswGraph,
query: &EncryptedQuery,
k: usize,
zero: &FheInt14,
) -> Vec<FheUint8> {
println!("🚀 Starting HNSW approximate search...");
if graph.nodes.is_empty() {
println!("❌ Empty HNSW graph");
return Vec::new();
}
let Some(entry_point) = graph.entry_point else {
println!("❌ No entry point in HNSW graph");
return Vec::new();
};
println!(
"🔍 HNSW search from entry point {} at level {}",
entry_point, graph.max_level
);
let mut current_candidates = vec![entry_point];
// Search layer by layer from the top level down to layer 1
for layer in (1..=graph.max_level).rev() {
println!("🔍 Searching layer {layer} with ef=1...");
let layer_start = Instant::now();
current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero);
println!(
"✅ Layer {} search completed in {}, {} candidates",
layer,
format_duration(layer_start.elapsed()),
current_candidates.len()
);
}
// Final search at layer 0
println!("🔍 Final search at layer 0 with ef={}...", 10);
let final_search_start = Instant::now();
let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero);
println!(
"✅ Final layer search completed in {}, {} candidates",
format_duration(final_search_start.elapsed()),
final_candidates.len()
);
// Compute distances for the final candidates and sort
println!("🔢 Computing distances for final candidates...");
let distance_start = Instant::now();
let mut distances = Vec::new();
for (i, &candidate) in final_candidates.iter().enumerate().take(graph.nodes.len()) {
if candidate < graph.nodes.len() {
if i % 10 == 0 && i > 0 {
println!(
"🔢 Processed {}/{} final candidates",
i,
final_candidates.len().min(graph.nodes.len())
);
}
let distance = euclidean_distance(query, &graph.nodes[candidate].encrypted_point, zero);
distances.push(EncryptedNeighbor {
distance,
index: graph.nodes[candidate].encrypted_point.index.clone(),
});
}
}
println!(
"✅ Distance computation completed in {}",
format_duration(distance_start.elapsed())
);
// Select the best k
println!("📊 Selecting top {} from {} candidates", k, distances.len());
let len = distances.len();
encrypted_selection_sort(&mut distances, k.min(len));
distances.iter().take(k).map(|n| n.index.clone()).collect()
}
///// Perform HNSW approximate nearest-neighbor search
/////
///// # Arguments
///// * `graph` - the prebuilt FHE HNSW graph
///// * `query` - the encrypted query point
///// * `zero` - an encrypted zero value
/////
///// # Returns
///// * indices of the 10 nearest neighbors
//pub fn perform_hnsw_search(
// graph: &FheHnswGraph,
// query: &EncryptedQuery,
// zero: &FheInt14,
//) -> Vec<usize> {
// println!("🚀 Starting HNSW approximate search...");
//
// if graph.nodes.is_empty() {
// println!("❌ Empty HNSW graph");
// return Vec::new();
// }
//
// let Some(entry_point) = graph.entry_point else {
// println!("❌ No entry point in HNSW graph");
// return Vec::new();
// };
//
// println!(
// "🔍 HNSW search from entry point {} at level {}",
// entry_point, graph.max_level
// );
//
// let mut current_candidates = vec![entry_point];
//
// // Search layer by layer from the top level down to layer 1
// for layer in (1..=graph.max_level).rev() {
// println!("🔍 Searching layer {layer} with ef=1...");
// let layer_start = Instant::now();
// current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero);
// println!(
// "✅ Layer {} search completed in {}, {} candidates",
// layer,
// format_duration(layer_start.elapsed()),
// current_candidates.len()
// );
// }
//
// // Final search at layer 0
// let final_search_start = Instant::now();
// let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero);
// println!(
// "✅ Final layer search completed in {}, {} candidates",
// format_duration(final_search_start.elapsed()),
// final_candidates.len()
// );
// final_candidates
//}

enc.rs

@@ -2,7 +2,6 @@ use anyhow::Result;
use chrono::Local;
use clap::Parser;
use log::info;
use rand::Rng;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::time::Instant;
@@ -11,9 +10,9 @@ use tfhe::{ConfigBuilder, FheInt14, generate_keys, set_server_key};
// Import from our library modules
use hfe_knn::{
Dataset, EncryptedNeighbor, EncryptedPoint, PlaintextHnswNode, Prediction, ScaleInt,
build_fhe_hnsw_from_plaintext, compute_distances, decrypt_indices, encrypt_point,
encrypt_query, format_duration, perform_hnsw_search, perform_knn_selection, print_progress_bar,
Dataset, EncryptedNeighbor, EncryptedPoint, Prediction, ScaleInt, compute_distances,
decrypt_indices, encrypt_point, encrypt_query, format_duration, perform_knn_selection,
print_progress_bar,
};
#[derive(Parser)]
@@ -87,97 +86,6 @@ fn debug_compute_distances(
encrypted_distances
}
/// Build the plaintext HNSW graph
fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) {
println!("🔨 Building HNSW graph from {} points...", data.len());
let max_connections = 8; // tuned parameter: fewer neighbors to cut ciphertext operations
let max_level = 3; // tuned parameter: keep a 3-level structure
let mut nodes = Vec::new();
let mut entry_point = None;
let mut current_max_level = 0;
// Create all nodes
for (idx, vector) in data.iter().enumerate() {
let level = select_level_for_node(max_level);
current_max_level = current_max_level.max(level);
let node = PlaintextHnswNode {
vector: vector.clone(),
level,
neighbors: vec![Vec::new(); level + 1],
};
nodes.push(node);
if idx == 0 {
entry_point = Some(idx);
}
}
// Build connections for each node (simplified implementation)
for node_id in 0..nodes.len() {
print_progress_bar(node_id + 1, nodes.len(), "Building connections");
let node_vector = nodes[node_id].vector.clone();
let node_level = nodes[node_id].level;
// Find the nearest neighbors at each layer
for layer in 0..=node_level {
let mut distances: Vec<(f64, usize)> = nodes
.iter()
.enumerate()
.filter(|(idx, n)| *idx != node_id && n.level >= layer)
.map(|(idx, n)| {
let dist = euclidean_distance_plaintext(&node_vector, &n.vector);
(dist, idx)
})
.collect();
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
// Add the nearest neighbors (capped connection count)
for &(_, neighbor_id) in distances.iter().take(max_connections) {
nodes[node_id].neighbors[layer].push(neighbor_id);
}
}
// Update the entry point
if nodes[node_id].level > nodes[entry_point.unwrap()].level {
entry_point = Some(node_id);
}
}
println!(); // newline
println!(
"✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}",
nodes.len(),
entry_point,
current_max_level
);
(nodes, entry_point, current_max_level)
}
/// Pick the level for a node
fn select_level_for_node(max_level: usize) -> usize {
let mut rng = rand::rng();
let mut level = 0;
while level < max_level && rng.random::<f32>() < 0.6 {
// tuned parameter: raise the level probability to 0.6
level += 1;
}
level
}
/// Plaintext Euclidean distance
fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f64>()
.sqrt()
}
fn main() -> Result<()> {
let args = Args::parse();
let start_time = Instant::now();
@@ -235,67 +143,62 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
let query_encrypted = encrypt_query(&dataset.query, client_key);
let k = 10; // Number of nearest neighbors
let encrypted_neighbors = if args.algorithm == "hnsw" {
// HNSW algorithm path
println!("🚀 Using HNSW algorithm...");
// if args.algorithm == "hnsw" {
// // HNSW algorithm path
// println!("🚀 Using HNSW algorithm...");
//
// // 1. Build the plaintext HNSW graph
// let (plaintext_nodes, entry_point, max_level) =
// build_plaintext_hnsw_graph(&dataset.data);
//
// // 2. Convert to an encrypted HNSW graph
// println!("🔐 Converting to encrypted HNSW graph...");
// let fhe_graph =
// build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
//
// // 3. Run the HNSW search
// let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
// let answer = perform_hnsw_search(&fhe_graph, &query_encrypted, &encrypted_zero);
// results.push(Prediction { answer });
// } else {
// Traditional algorithm path
// Encrypt all training points
println!("🔐 Encrypting training points...");
let points_encrypted: Vec<EncryptedPoint> = dataset
.data
.iter()
.enumerate()
.map(|(idx, coords)| encrypt_point(coords, idx + 1, client_key))
.collect();
// 1. Build the plaintext HNSW graph
let (plaintext_nodes, entry_point, max_level) =
build_plaintext_hnsw_graph(&dataset.data);
// 2. Convert to an encrypted HNSW graph
println!("🔐 Converting to encrypted HNSW graph...");
let fhe_graph =
build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
// 3. Run the HNSW search
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero)
// Compute distances
let mut distances = if args.debug {
debug_compute_distances(&dataset.query, &dataset.data, &points_encrypted, client_key)
} else {
// Traditional algorithm path
// Encrypt all training points
println!("🔐 Encrypting training points...");
let points_encrypted: Vec<EncryptedPoint> = dataset
.data
.iter()
.enumerate()
.map(|(idx, coords)| encrypt_point(coords, idx + 1, client_key))
.collect();
// Compute distances
let mut distances = if args.debug {
debug_compute_distances(
&dataset.query,
&dataset.data,
&points_encrypted,
client_key,
)
} else {
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
};
// Perform KNN selection using the specified algorithm
let max_distance = if args.algorithm == "bitonic" {
Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // the correct maximum value for FheInt14
} else {
None
};
let max_index = if args.algorithm == "bitonic" {
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
} else {
None
};
perform_knn_selection(
&mut distances,
k,
&args.algorithm,
max_distance.as_ref(),
max_index.as_ref(),
)
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
};
// Perform KNN selection using the specified algorithm
let max_distance = if args.algorithm == "bitonic" {
Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // the correct maximum value for FheInt14
} else {
None
};
let max_index = if args.algorithm == "bitonic" {
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
} else {
None
};
let encrypted_neighbors = perform_knn_selection(
&mut distances,
k,
&args.algorithm,
max_distance.as_ref(),
max_index.as_ref(),
);
// Decrypt the results
println!("🔓 Decrypting results...");
let decrypted_indices = decrypt_indices(&encrypted_neighbors, client_key);

src/bin/plain.rs

@@ -52,7 +52,8 @@ struct Prediction {
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
// Simulate the scaling and precision loss of the encrypted version
let scaled_distance: f64 = a.iter()
let scaled_distance: f64 = a
.iter()
.zip(b.iter())
.map(|(x, y)| {
// Scale coordinates: multiply by 10, then drop the fractional part
@@ -61,7 +62,7 @@ fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
(scaled_x - scaled_y).powi(2)
})
.sum::<f64>();
scaled_distance
}
@@ -280,78 +281,80 @@ impl HNSWGraph {
}
}
fn main() -> Result<()> {
let args = Args::parse();
let file = File::open(&args.dataset)?;
let reader = BufReader::new(file);
let mut results = Vec::new();
for line in reader.lines() {
let line = line?;
let dataset: Dataset = serde_json::from_str(&line)?;
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
results.push(Prediction { answer: nearest });
}
let mut output_file = File::create(&args.predictions)?;
for result in results {
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
}
Ok(())
}
// fn main() -> Result<()> {
// let args = Args::parse();
//
// let file = File::open(&args.dataset)?;
// let reader = BufReader::new(file);
// let mut results = Vec::new();
//
// println!("🔧 HNSW Parameters:");
// println!(" max_connections: {}", args.max_connections);
// println!(" max_level: {}", args.max_level);
// println!(" level_prob: {}", args.level_prob);
// println!(" ef_upper: {}", args.ef_upper);
// println!(" ef_bottom: {}", args.ef_bottom);
// let mut results = Vec::new();
//
// for line in reader.lines() {
// let line = line?;
// let dataset: Dataset = serde_json::from_str(&line)?;
//
// let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
// for data_point in &dataset.data {
// hnsw.insert_node(data_point.clone(), args.level_prob);
// }
//
// let nearest = if let Some(entry_point) = hnsw.entry_point {
// let mut search_results = vec![entry_point];
//
// // Upper-layer search, using the ef_upper parameter
// for layer in (1..=hnsw.nodes[entry_point].level).rev() {
// search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
// }
//
// // Bottom-layer search, using the ef_bottom parameter
// let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
// final_results
// .into_iter()
// .take(10)
// .map(|idx| idx + 1)
// .collect()
// } else {
// Vec::new()
// };
//
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
// results.push(Prediction { answer: nearest });
// }
//
// let mut output_file = File::create(&args.predictions)?;
// for result in results {
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
// println!("{}", serde_json::to_string(&result)?);
// }
//
// Ok(())
// }
fn main() -> Result<()> {
let args = Args::parse();
let file = File::open(&args.dataset)?;
let reader = BufReader::new(file);
let mut results = Vec::new();
println!("🔧 HNSW Parameters:");
println!(" max_connections: {}", args.max_connections);
println!(" max_level: {}", args.max_level);
println!(" level_prob: {}", args.level_prob);
println!(" ef_upper: {}", args.ef_upper);
println!(" ef_bottom: {}", args.ef_bottom);
for line in reader.lines() {
let line = line?;
let dataset: Dataset = serde_json::from_str(&line)?;
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
for data_point in &dataset.data {
hnsw.insert_node(data_point.clone(), args.level_prob);
}
let nearest = if let Some(entry_point) = hnsw.entry_point {
let mut search_results = vec![entry_point];
// Upper-layer search, using the ef_upper parameter
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
search_results =
hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
}
// Bottom-layer search, using the ef_bottom parameter
let final_results =
hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
final_results
.into_iter()
.take(10)
.map(|idx| idx + 1)
.collect()
} else {
Vec::new()
};
results.push(Prediction { answer: nearest });
}
let mut output_file = File::create(&args.predictions)?;
for result in results {
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
println!("{}", serde_json::to_string(&result)?);
}
Ok(())
}

data.rs

@@ -1,5 +1,4 @@
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use tfhe::prelude::*;
use tfhe::{ClientKey, FheInt14, FheUint8};
@@ -130,165 +129,250 @@ pub fn decrypt_indices(encrypted_indices: &[FheUint8], client_key: &ClientKey) -
.collect()
}
/// Ciphertext HNSW node
#[derive(Clone)]
pub struct FheHnswNode {
pub encrypted_point: EncryptedPoint,
pub level: usize,
pub neighbors: Vec<Vec<usize>>, // plaintext neighbor indices (fixed during preprocessing)
}
/// Ciphertext HNSW graph
#[derive(Clone)]
pub struct FheHnswGraph {
pub nodes: Vec<FheHnswNode>,
pub entry_point: Option<usize>,
pub max_level: usize,
}
impl FheHnswGraph {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
entry_point: None,
max_level: 0,
}
}
/// Ciphertext search-layer function
///
/// Core algorithm: standard HNSW greedy search, but with ciphertext distance computation
///
/// Algorithm flow:
/// 1. Initialize the candidate queue (candidates) and result set (w)
/// 2. Starting from entry_points, compute ciphertext distances to the query
/// 3. Greedy search loop:
///    - pick the minimum-distance point in candidates as the current exploration point
///    - explore all of its neighbors
///    - add unvisited neighbors to candidates and w
///    - keep w no larger than ef (drop the farthest point)
///    - pruning: stop if the current point is farther than the farthest point in w and w is full
/// 4. Return the node indices in w
///
/// Key challenges:
/// - a priority queue must be emulated with nothing but ciphertext sorting
/// - the pruning condition is hard to evaluate on ciphertexts; accuracy must be traded against performance
/// - candidate queue management must stay efficient to avoid excessive ciphertext operations
pub fn search_layer(
&self,
_query: &EncryptedQuery,
entry_points: Vec<usize>,
ef: usize,
layer: usize,
_zero: &FheInt14,
) -> Vec<usize> {
let _visited: HashSet<usize> = HashSet::new();
// TODO: design suitable data structures
// candidates: candidate queue holding (node index, ciphertext distance) or similar
// w: result set, the best ef candidates found so far
println!(
"🔍 Initializing search layer {} with {} entry points",
layer,
entry_points.len()
);
// TODO: Step 1 - initialize candidates
// For each entry_point:
// - check the node exists and that level >= layer
// - compute the ciphertext distance to the query: euclidean_distance(query, &self.nodes[ep].encrypted_point, zero)
// - add the node to the visited set
// - push (node index, distance) onto candidates and w
println!(
"🔍 Starting search with {} initial candidates",
// TODO: report the actual number of initialized candidates
0
);
// TODO: Step 2 - main search loop
// while candidates is non-empty:
// - find the minimum-distance point in candidates (requires ciphertext sorting or similar)
// - remove it from candidates and make it the current exploration point
//
// - pruning check (optional, complex):
//   if w.len() >= ef and current's distance > the farthest distance in w, break
//
// - explore current's neighbors:
//   for neighbor in self.nodes[current].neighbors[layer]:
//     - if neighbor has not been visited:
//       - mark it visited
//       - compute its ciphertext distance to the query
//       - push neighbor onto candidates
//       - maintain the result set w:
//         if w.len() < ef: push onto w directly
//         else: push onto w, sort, and drop the farthest point so w stays at size ef
println!(
"🔍 Found {} total candidates (ef={})",
// TODO: report the final number of candidates found
0,
ef
);
// TODO: Step 3 - return results
// extract the node indices from w and return them
// w.into_iter().map(|(node_idx, _)| node_idx).collect()
// Temporarily return an empty result so the code compiles
Vec::new()
}
}
impl Default for FheHnswGraph {
fn default() -> Self {
Self::new()
}
}
/// Helper struct for building the ciphertext HNSW graph from plaintext data
#[derive(Clone)]
pub struct PlaintextHnswNode {
pub vector: Vec<f64>,
pub level: usize,
pub neighbors: Vec<Vec<usize>>,
}
/// Build the ciphertext HNSW graph from plaintext data
pub fn build_fhe_hnsw_from_plaintext(
plaintext_nodes: &[PlaintextHnswNode],
plaintext_entry_point: Option<usize>,
plaintext_max_level: usize,
client_key: &ClientKey,
) -> FheHnswGraph {
use crate::logging::print_progress_bar;
let mut fhe_nodes = Vec::new();
let total_nodes = plaintext_nodes.len();
println!("🔐 Encrypting {total_nodes} HNSW nodes...");
for (idx, plain_node) in plaintext_nodes.iter().enumerate() {
print_progress_bar(idx + 1, total_nodes, "Encrypting nodes");
let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key);
let fhe_node = FheHnswNode {
encrypted_point,
level: plain_node.level,
neighbors: plain_node.neighbors.clone(),
};
fhe_nodes.push(fhe_node);
}
println!(); // newline
println!(
"✅ FHE HNSW graph created with {} encrypted nodes",
fhe_nodes.len()
);
FheHnswGraph {
nodes: fhe_nodes,
entry_point: plaintext_entry_point,
max_level: plaintext_max_level,
}
}
///// Ciphertext HNSW node
//#[derive(Clone)]
//pub struct FheHnswNode {
// pub encrypted_point: EncryptedPoint,
// pub level: usize,
// pub neighbors: Vec<Vec<usize>>, // plaintext neighbor indices (fixed during preprocessing)
//}
//
///// Ciphertext HNSW graph
//#[derive(Clone)]
//pub struct FheHnswGraph {
// pub nodes: Vec<FheHnswNode>,
// pub entry_point: Option<usize>,
// pub max_level: usize,
// pub distances: [Option<FheInt14>; 100],
//}
//
//impl FheHnswGraph {
// pub fn new() -> Self {
// Self {
// nodes: Vec::new(),
// entry_point: None,
// max_level: 0,
// distances: [const { None };100]
// }
// }
//
// /// Search-layer function
// ///
// /// Core algorithm: standard HNSW greedy search
// ///
// /// Algorithm flow:
// /// 1. Initialize the candidate queue (candidates) and result set (w)
// /// 2. Starting from entry_points, compute ciphertext distances to the query
// /// 3. Greedy search loop:
// ///    - pick the minimum-distance point in candidates as the current exploration point
// ///    - explore all of its neighbors
// ///    - add unvisited neighbors to candidates and w
// ///    - keep w no larger than ef (drop the farthest point)
// ///    - pruning: stop if the current point is farther than the farthest point in w and w is full
// /// 4. Return the node indices in w
// ///
// /// Key challenges:
// /// - a priority queue must be emulated with nothing but ciphertext sorting
// /// - the pruning condition is hard to evaluate on ciphertexts; accuracy must be traded against performance
// /// - candidate queue management must stay efficient to avoid excessive ciphertext operations
// pub fn search_layer(
// &self,
// query: &EncryptedQuery,
// entry_points: Vec<usize>,
// ef: usize,
// layer: usize,
// zero: &FheInt14,
// ) -> Vec<usize> {
// let visited: HashSet<usize> = HashSet::new();
// let mut candidate =
//
// // TODO: design suitable data structures
// // candidates: candidate queue holding (node index, ciphertext distance) or similar
// // w: result set, the best ef candidates found so far
//
// // TODO: Step 1 - initialize candidates
// // For each entry_point:
// // - check the node exists and that level >= layer
// // - compute the ciphertext distance to the query: euclidean_distance(query, &self.nodes[ep].encrypted_point, zero)
// // - add the node to the visited set
// // - push (node index, distance) onto candidates and w
//
// println!(
// "🔍 Starting search with {} initial candidates",
// // TODO: report the actual number of initialized candidates
// 0
// );
//
// // TODO: Step 2 - main search loop
// // while candidates is non-empty:
// // - find the minimum-distance point in candidates (requires ciphertext sorting or similar)
// // - remove it from candidates and make it the current exploration point
// //
// // - pruning check (optional, complex):
// //   if w.len() >= ef and current's distance > the farthest distance in w, break
// //
// // - explore current's neighbors:
// //   for neighbor in self.nodes[current].neighbors[layer]:
// //     - if neighbor has not been visited:
// //       - mark it visited
// //       - compute its ciphertext distance to the query
// //       - push neighbor onto candidates
// //       - maintain the result set w:
// //         if w.len() < ef: push onto w directly
// //         else: push onto w, sort, and drop the farthest point so w stays at size ef
//
// println!(
// "🔍 Found {} total candidates (ef={})",
// // TODO: report the final number of candidates found
// 0,
// ef
// );
//
// // TODO: Step 3 - return results
// // extract the node indices from w and return them
// // w.into_iter().map(|(node_idx, _)| node_idx).collect()
//
// // Temporarily return an empty result so the code compiles
// Vec::new()
// }
//}
//
//impl Default for FheHnswGraph {
// fn default() -> Self {
// Self::new()
// }
//}
//
///// Helper struct for building the ciphertext HNSW graph from plaintext data
//#[derive(Clone)]
//pub struct PlaintextHnswNode {
// pub vector: Vec<f64>,
// pub level: usize,
// pub neighbors: Vec<Vec<usize>>,
//}
//
///// Build the ciphertext HNSW graph from plaintext data
//pub fn build_fhe_hnsw_from_plaintext(
// plaintext_nodes: &[PlaintextHnswNode],
// plaintext_entry_point: Option<usize>,
// plaintext_max_level: usize,
// client_key: &ClientKey,
//) -> FheHnswGraph {
// use crate::logging::print_progress_bar;
//
// let mut fhe_nodes = Vec::new();
// let total_nodes = plaintext_nodes.len();
//
// println!("🔐 Encrypting {total_nodes} HNSW nodes...");
//
// for (idx, plain_node) in plaintext_nodes.iter().enumerate() {
// print_progress_bar(idx + 1, total_nodes, "Encrypting nodes");
// let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key);
// let fhe_node = FheHnswNode {
// encrypted_point,
// level: plain_node.level,
// neighbors: plain_node.neighbors.clone(),
// };
// fhe_nodes.push(fhe_node);
// }
//
//
// FheHnswGraph {
// nodes: fhe_nodes,
// entry_point: plaintext_entry_point,
// max_level: plaintext_max_level,
// }
//}
//
///// Build the plaintext HNSW graph
//pub fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) {
// println!("🔨 Building HNSW graph from {} points...", data.len());
//
// let max_connections = 8; // tuned parameter: fewer neighbors to cut ciphertext operations
// let max_level = 3; // tuned parameter: keep a 3-level structure
// let mut nodes = Vec::new();
// let mut entry_point = None;
// let mut current_max_level = 0;
//
// // Create all nodes
// for (idx, vector) in data.iter().enumerate() {
// let level = select_level_for_node(max_level);
// current_max_level = current_max_level.max(level);
//
// let node = PlaintextHnswNode {
// vector: vector.clone(),
// level,
// neighbors: vec![Vec::new(); level + 1],
// };
// nodes.push(node);
//
// if idx == 0 {
// entry_point = Some(idx);
// }
// }
//
// // Build connections for each node (simplified implementation)
// for node_id in 0..nodes.len() {
// print_progress_bar(node_id + 1, nodes.len(), "Building connections");
//
// let node_vector = nodes[node_id].vector.clone();
// let node_level = nodes[node_id].level;
//
// // Find the nearest neighbors at each layer
// for layer in 0..=node_level {
// let mut distances: Vec<(f64, usize)> = nodes
// .iter()
// .enumerate()
// .filter(|(idx, n)| *idx != node_id && n.level >= layer)
// .map(|(idx, n)| {
// let dist = euclidean_distance_plaintext(&node_vector, &n.vector);
// (dist, idx)
// })
// .collect();
//
// distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
//
// // Add the nearest neighbors (capped connection count)
// for &(_, neighbor_id) in distances.iter().take(max_connections) {
// nodes[node_id].neighbors[layer].push(neighbor_id);
// }
// }
//
// // Update the entry point
// if nodes[node_id].level > nodes[entry_point.unwrap()].level {
// entry_point = Some(node_id);
// }
// }
//
// println!(); // newline
// println!(
// "✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}",
// nodes.len(),
// entry_point,
// current_max_level
// );
//
// (nodes, entry_point, current_max_level)
//}
//
///// Pick the level for a node
//pub fn select_level_for_node(max_level: usize) -> usize {
// let mut rng = rand::rng();
// let mut level = 0;
// while level < max_level && rng.random::<f32>() < 0.6 {
// // tuned parameter: raise the level probability to 0.6
// level += 1;
// }
// level
//}
//
//
//
///// Plaintext Euclidean distance
//fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 {
// a.iter()
// .zip(b.iter())
// .map(|(x, y)| (x - y).powi(2))
// .sum::<f64>()
// .sqrt()
//}


@@ -1,338 +0,0 @@
#!/usr/bin/env python3
"""
FHE HNSW并行测试脚本 - 多进程版本
测试密文态HNSW实现的稳定性和准确率
支持多进程并行执行充分利用服务器CPU资源
运行100次记录≥90%准确率的成功率和结果长度分布
"""
import subprocess
import json
import time
import os
import multiprocessing as mp
from pathlib import Path
from collections import defaultdict
# Ground-truth answer
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
def calculate_accuracy(result, correct):
"""计算准确率:匹配元素数量 / 总元素数量"""
matches = len(set(result) & set(correct))
return matches / len(correct) * 100
def run_single_test(test_id, data_bit_width="12", timeout_minutes=30):
"""运行单次FHE HNSW测试供多进程调用"""
# 为每个进程创建独立的输出文件
output_file = f"./test_fhe_output_{test_id}_{os.getpid()}.jsonl"
cmd = [
"./enc",
"--algorithm",
"hnsw",
"--data-bit-width",
data_bit_width,
"--predictions",
output_file,
]
start_time = time.time()
try:
result = subprocess.run(
cmd, capture_output=True, text=True, timeout=timeout_minutes * 60
)
test_time = time.time() - start_time
if result.returncode != 0:
return {
"test_id": test_id,
"result": None,
"error": f"Command failed: {result.stderr[:200]}",
"time_minutes": test_time / 60,
}
# Read the result
if not Path(output_file).exists():
return {
"test_id": test_id,
"result": None,
"error": "Output file not found",
"time_minutes": test_time / 60,
}
with open(output_file, "r") as f:
output = json.load(f)
answer = output["answer"]
# Clean up the temporary file
Path(output_file).unlink(missing_ok=True)
# Compute accuracy
accuracy = calculate_accuracy(answer, CORRECT_ANSWER)
return {
"test_id": test_id,
"result": answer,
"length": len(answer),
"accuracy": accuracy,
"success": accuracy >= 90,
"time_minutes": test_time / 60,
"error": None,
}
except subprocess.TimeoutExpired:
# Clean up the temporary file
Path(output_file).unlink(missing_ok=True)
return {
"test_id": test_id,
"result": None,
"error": "timeout",
"time_minutes": timeout_minutes,
}
except Exception as e:
# Clean up the temporary file
Path(output_file).unlink(missing_ok=True)
return {
"test_id": test_id,
"result": None,
"error": str(e),
"time_minutes": (time.time() - start_time) / 60,
}
def print_progress(completed, total, start_time):
"""打印进度信息"""
if completed == 0:
return
elapsed = time.time() - start_time
avg_time = elapsed / completed
remaining_time = avg_time * (total - completed)
print(
f"\r🔄 进度: {completed}/{total} ({completed/total*100:.1f}%) "
f"| 已用时: {elapsed/3600:.1f}h | 预计剩余: {remaining_time/3600:.1f}h",
end="",
flush=True,
)
def main():
"""主测试函数"""
print("🚀 FHE HNSW 并行批量测试脚本")
print(f"🎯 正确答案: {CORRECT_ANSWER}")
print("📊 运行100次测试记录准确率和结果长度分布")
print("🔧 支持多进程并行执行")
# Check that the ./enc binary exists
if not Path("./enc").exists():
print("\n❌ 找不到 ./enc 二进制文件,请先编译项目:")
print(" cargo build --release --bin enc")
return
# Get the CPU core count and choose the number of processes
cpu_cores = mp.cpu_count()
# FHE operations are memory-intensive, so use roughly half the core count to avoid running out of memory
num_processes = 8
timeout_minutes = 45 # raise the per-test timeout to 45 minutes
print(f"💻 使用 {num_processes} 个并行进程")
print(f"⏰ 单次测试超时时间: {timeout_minutes} 分钟")
print(
f"⏱️ 预计总时间: {100 * 15 / num_processes / 60:.1f}-{100 * 25 / num_processes / 60:.1f} 小时"
)
print()
print("=" * 80)
print("🔬 开始FHE HNSW并行测试 (data-bit-width=12)")
print("=" * 80)
start_time = time.time()
# Create the process pool and run the tests
test_ids = list(range(1, 101)) # 1-100
with mp.Pool(processes=num_processes) as pool:
# Create the async tasks
async_results = []
for test_id in test_ids:
async_result = pool.apply_async(
run_single_test, (test_id, "12", timeout_minutes)
)
async_results.append(async_result)
# Collect results and show progress
results = []
completed = 0
print_progress(completed, len(test_ids), start_time)
for async_result in async_results:
try:
result = async_result.get(
timeout=timeout_minutes * 60 + 60
) # one extra minute of buffer
results.append(result)
completed += 1
print_progress(completed, len(test_ids), start_time)
except mp.TimeoutError:
# Process-level timeout
results.append(
{
"test_id": len(results) + 1,
"result": None,
"error": "process_timeout",
"time_minutes": timeout_minutes,
}
)
completed += 1
print_progress(completed, len(test_ids), start_time)
print() # newline
total_elapsed = time.time() - start_time
# Analyze the results
print("\n" + "=" * 80)
print("📈 Test result analysis")
print("=" * 80)
# Aggregation variables
valid_results = [r for r in results if r["result"] is not None]
success_count = sum(1 for r in valid_results if r["success"])
length_distribution = defaultdict(int)
error_distribution = defaultdict(int)
# Tally error types
for r in results:
if r["error"]:
error_type = r["error"]
if "timeout" in error_type.lower():
error_distribution["timeout"] += 1
elif "failed" in error_type.lower():
error_distribution["failed"] += 1
else:
error_distribution["other"] += 1
else:
length_distribution[r["length"]] += 1
# Basic statistics
total_tests = len(results)
valid_tests = len(valid_results)
if valid_tests == 0:
print("❌ 没有有效的测试结果")
return
success_rate = success_count / valid_tests * 100
avg_accuracy = sum(r["accuracy"] for r in valid_results) / valid_tests
avg_length = sum(r["length"] for r in valid_results) / valid_tests
avg_time_per_test = sum(r["time_minutes"] for r in results) / total_tests
print(f"总测试次数: {total_tests}")
print(f"有效测试次数: {valid_tests}")
print(f"成功次数 (≥90%准确率): {success_count}")
print(f"成功率: {success_rate:.1f}%")
print(f"平均准确率: {avg_accuracy:.1f}%")
print(f"平均结果长度: {avg_length:.1f}")
print(f"平均每次测试时间: {avg_time_per_test:.1f}分钟")
print(f"总测试时间: {total_elapsed/3600:.1f}小时")
print(f"并行加速比: {100 * avg_time_per_test / 60 / (total_elapsed/3600):.1f}x")
# Error statistics
if error_distribution:
print("\n❌ Error distribution:")
for error_type, count in error_distribution.items():
print(f" {error_type}: {count}")
# Result length distribution
if length_distribution:
print("\n📊 Result length distribution:")
for length in sorted(length_distribution.keys()):
count = length_distribution[length]
percentage = count / valid_tests * 100
bar = "" * (count // 2) if count > 0 else ""
print(f" 长度 {length:2d}: {count:3d}次 ({percentage:5.1f}%) {bar}")
# Conclusion
print()
if success_rate >= 50:
print("✅ Test passed! The FHE HNSW implementation is stable")
print(f" - success rate: {success_rate:.1f}% (≥50%)")
print(f" - average accuracy: {avg_accuracy:.1f}%")
if avg_length >= 9.5:
print(f" - average result length: {avg_length:.1f} (close to 10)")
else:
print(f" - average result length: {avg_length:.1f} (needs improvement)")
else:
print("❌ Test failed; further optimization is needed")
print(f" - success rate: {success_rate:.1f}% (<50%)")
print(f" - average accuracy: {avg_accuracy:.1f}%")
# Save the detailed report
report_file = "fhe_hnsw_parallel_test_report.json"
with open(report_file, "w") as f:
json.dump(
{
"summary": {
"total_tests": total_tests,
"valid_tests": valid_tests,
"success_count": success_count,
"success_rate": success_rate,
"avg_accuracy": avg_accuracy,
"avg_length": avg_length,
"avg_time_minutes": avg_time_per_test,
"total_time_hours": total_elapsed / 3600,
"num_processes": num_processes,
"cpu_cores": cpu_cores,
"speedup": 100 * avg_time_per_test / 60 / (total_elapsed / 3600),
},
"length_distribution": dict(length_distribution),
"error_distribution": dict(error_distribution),
"detailed_results": results,
"correct_answer": CORRECT_ANSWER,
},
f,
indent=2,
)
print(f"\n📁 详细报告已保存到: {report_file}")
# Show the best and worst few results
if valid_results:
print("\n🏆 Top 5 results by accuracy:")
best_results = sorted(valid_results, key=lambda x: x["accuracy"], reverse=True)[
:5
]
for i, r in enumerate(best_results, 1):
print(
f" {i}. test #{r['test_id']:3d}: accuracy {r['accuracy']:5.1f}%, length {r['length']:2d} ({r['time_minutes']:4.1f} min)"
)
print("\n⚠️ 最低准确率的5个结果:")
worst_results = sorted(valid_results, key=lambda x: x["accuracy"])[:5]
for i, r in enumerate(worst_results, 1):
print(
f" {i}. test #{r['test_id']:3d}: accuracy {r['accuracy']:5.1f}%, length {r['length']:2d} ({r['time_minutes']:4.1f} min)"
)
if __name__ == "__main__":
try:
# Set the multiprocessing start method (fork is the Linux default, but spawn is safer)
mp.set_start_method("spawn", force=True)
main()
except KeyboardInterrupt:
print("\n\n⏹️ 测试被用户中断")
# 清理可能的临时文件
for f in Path(".").glob("test_fhe_output_*.jsonl"):
f.unlink(missing_ok=True)
except Exception as e:
print(f"\n💥 程序异常: {e}")
# 清理可能的临时文件
for f in Path(".").glob("test_fhe_output_*.jsonl"):
f.unlink(missing_ok=True)

writeup.md (new file)

@@ -0,0 +1,32 @@
# Team 0xfa Writeup
## Fully homomorphic encryption scheme
For the encryption scheme we chose TFHE; for the library, the relatively mature tfhe-rs.
### Scheme parameters
- Message bits: 2
- Carry bits: 2
- Noise distribution: TUniform (tweaked uniform)
- Bootstrap failure probability: ≤ 2^-128 (CPU backend)
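These match the defaults built by `ConfigBuilder::default()` in recent tfhe-rs releases (an assumption on our part; the project code is not shown setting custom parameters). A minimal key-setup sketch using the same imports that appear in enc.rs:
```rust
use tfhe::{ConfigBuilder, generate_keys, set_server_key};

fn main() {
    // Default config: 2 message bits, 2 carry bits, TUniform noise.
    let config = ConfigBuilder::default().build();
    let (client_key, server_key) = generate_keys(config);
    // Install the server key so homomorphic operations can run on this thread.
    set_server_key(server_key);
    let _client_key = client_key; // kept by the client for encrypt/decrypt
}
```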
## kNN implementation details
We split the Euclidean distance formula:
> sum((a-b)^2) = sum(a^2) - sum(2a\*b) + sum(b^2)
This reduces the number of ciphertext multiplications and additions.
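Since sum(a^2) comes from the query alone, it is identical for every candidate point and cannot change the ranking, while sum(b^2) can be computed once per point. A minimal plaintext sketch verifying the identity (illustrative only; these function names are not from the project):
```rust
// Plaintext check of the identity behind the encrypted distance:
// sum((a_i - b_i)^2) = sum(a_i^2) - 2*sum(a_i*b_i) + sum(b_i^2)
fn squared_distance_direct(a: &[i64], b: &[i64]) -> i64 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

fn squared_distance_decomposed(a: &[i64], b: &[i64]) -> i64 {
    let a2: i64 = a.iter().map(|x| x * x).sum(); // query norm: same for every point
    let b2: i64 = b.iter().map(|y| y * y).sum(); // computable once per point
    let ab: i64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    a2 - 2 * ab + b2
}

fn main() {
    let (a, b) = (vec![3, 1, 4], vec![2, 7, 1]);
    assert_eq!(
        squared_distance_direct(&a, &b),
        squared_distance_decomposed(&a, &b)
    );
}
```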
For the selection step we implemented bitonic sort, padding the 100 distance results with maximum-value sentinels up to 128 entries.
The network is then sorted in parallel, and the top ten ciphertexts are selected.
Both operations use the rayon library for multi-core parallel computation; a sketch of the padding step follows.
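A minimal sketch of the padding step (an illustration under assumptions: the `pad_for_bitonic` helper is hypothetical, the non-generic `EncryptedNeighbor` form matches its use in algorithms.rs, and the sentinels mirror the `8191i16`/`255u8` encryptions visible in the enc.rs diff):
```rust
use hfe_knn::EncryptedNeighbor;
use tfhe::{FheInt14, FheUint8};

// Pad the encrypted neighbors up to the next power of two (100 -> 128) so the
// bitonic network has a valid size. Maximum-value sentinels always sort to the
// end and are discarded when the top k results are taken.
fn pad_for_bitonic(
    mut distances: Vec<EncryptedNeighbor>,
    max_distance: &FheInt14, // an encryption of 8191, the largest FheInt14 value
    max_index: &FheUint8,    // an encryption of 255
) -> Vec<EncryptedNeighbor> {
    let target = distances.len().next_power_of_two();
    while distances.len() < target {
        distances.push(EncryptedNeighbor {
            distance: max_distance.clone(),
            index: max_index.clone(),
        });
    }
    distances
}
```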
## Local test results
On a local i9-10920X (12 cores / 24 threads), the total run time is about 9 min (4 min + 5 min).