fix: correct run.sh script parameters and disable HNSW code (v0.3.1)
- Fix syntax error in run.sh: remove extra quote and correct --log-path to --log-file - Comment out HNSW algorithm implementation in enc.rs and algorithms.rs to simplify codebase - Bump version to 0.3.1 in Cargo.toml - Remove HNSW implementation guide and test files - Add comprehensive project writeup documentation
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -410,7 +410,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "hfe_knn"
|
||||
version = "0.3.0"
|
||||
version = "0.3.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode 2.0.1",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "hfe_knn"
|
||||
version = "0.3.0"
|
||||
version = "0.3.1"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
# HNSW Search Layer 实现指南
|
||||
|
||||
## 目标
|
||||
实现标准的HNSW贪心搜索算法,但使用密文距离计算,匹配明文版本的逻辑和性能。
|
||||
|
||||
## 关键数据结构
|
||||
|
||||
### 输入参数
|
||||
- `query: &EncryptedQuery<T>` - 加密的查询点
|
||||
- `entry_points: Vec<usize>` - 入口点的节点索引列表
|
||||
- `ef: usize` - 搜索时的候选集大小
|
||||
- `layer: usize` - 当前搜索的层级
|
||||
- `zero: &T` - 加密的零值(用于距离计算)
|
||||
|
||||
### 内部数据结构建议
|
||||
```rust
|
||||
// 候选队列:存储待探索的节点
|
||||
let mut candidates: Vec<(usize, EncryptedNeighbor<T>)> = Vec::new();
|
||||
// 结果集:维护当前最好的ef个候选点
|
||||
let mut w: Vec<(usize, EncryptedNeighbor<T>)> = Vec::new();
|
||||
// 访问标记
|
||||
let mut visited: HashSet<usize> = HashSet::new();
|
||||
```
|
||||
|
||||
其中 `EncryptedNeighbor<T>` 结构已定义:
|
||||
```rust
|
||||
pub struct EncryptedNeighbor<T> {
|
||||
pub distance: T, // 密文距离
|
||||
pub index: FheUint8, // 密文索引
|
||||
}
|
||||
```
|
||||
|
||||
## 实现步骤
|
||||
|
||||
### Step 1: 初始化候选点
|
||||
```rust
|
||||
for &ep in &entry_points {
|
||||
if ep < self.nodes.len() && self.nodes[ep].level >= layer {
|
||||
visited.insert(ep);
|
||||
let distance = euclidean_distance(query, &self.nodes[ep].encrypted_point, zero);
|
||||
let neighbor = EncryptedNeighbor {
|
||||
distance,
|
||||
index: self.nodes[ep].encrypted_point.index.clone(),
|
||||
};
|
||||
candidates.push((ep, neighbor.clone()));
|
||||
w.push((ep, neighbor));
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: 主搜索循环
|
||||
```rust
|
||||
while !candidates.is_empty() {
|
||||
// 2.1 找到距离最小的候选点
|
||||
// 提示:需要对candidates中的EncryptedNeighbor按distance排序
|
||||
// 可以使用 encrypted_selection_sort 或其他方法
|
||||
|
||||
// 2.2 移除最小距离的候选点作为当前探索点
|
||||
let current = /* 从candidates中移除最小距离点的节点索引 */;
|
||||
|
||||
// 2.3 剪枝检查(可选,但会影响性能)
|
||||
// 如果w.len() >= ef 且 current的距离 > w中最远点的距离,则break
|
||||
|
||||
// 2.4 探索当前节点的邻居
|
||||
for &neighbor_idx in &self.nodes[current].neighbors[layer] {
|
||||
if !visited.contains(&neighbor_idx) && neighbor_idx < self.nodes.len() {
|
||||
visited.insert(neighbor_idx);
|
||||
let distance = euclidean_distance(query, &self.nodes[neighbor_idx].encrypted_point, zero);
|
||||
let encrypted_neighbor = EncryptedNeighbor {
|
||||
distance,
|
||||
index: self.nodes[neighbor_idx].encrypted_point.index.clone(),
|
||||
};
|
||||
|
||||
// 加入候选队列
|
||||
candidates.push((neighbor_idx, encrypted_neighbor.clone()));
|
||||
|
||||
// 管理结果集w
|
||||
if w.len() < ef {
|
||||
w.push((neighbor_idx, encrypted_neighbor));
|
||||
} else {
|
||||
// 结果集已满,需要替换最远的点
|
||||
w.push((neighbor_idx, encrypted_neighbor));
|
||||
// 排序w,只保留前ef个最近的点
|
||||
// 提示:可以先转换为Vec<EncryptedNeighbor>,排序后重建w
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 3: 返回结果
|
||||
```rust
|
||||
w.into_iter().map(|(node_idx, _)| node_idx).collect()
|
||||
```
|
||||
|
||||
## 性能优化建议
|
||||
|
||||
### 1. 减少密文排序次数
|
||||
- **问题**:每次排序都很昂贵(~2-3分钟)
|
||||
- **策略**:
|
||||
- 只在必要时排序(如候选队列管理、结果集维护)
|
||||
- 考虑批量处理而不是逐个比较
|
||||
- 可以适当牺牲一些算法精确性来换取性能
|
||||
|
||||
### 2. 候选队列管理
|
||||
- **明文版本**:使用BinaryHeap,O(log n)插入和删除
|
||||
- **密文版本**:只能用排序,O(n log n)
|
||||
- **优化**:考虑限制候选队列大小,避免无限增长
|
||||
|
||||
### 3. 剪枝策略
|
||||
- **理想**:`current_distance > farthest_w_distance && w.len() >= ef` 则停止
|
||||
- **现实**:密文比较结果无法直接判断
|
||||
- **权衡**:可以跳过复杂剪枝,让算法更彻底但稍慢
|
||||
|
||||
## 调试提示
|
||||
|
||||
### 1. 验证初始化
|
||||
确保entry_points正确初始化到candidates和w中
|
||||
|
||||
### 2. 验证邻居探索
|
||||
检查是否正确遍历`self.nodes[current].neighbors[layer]`
|
||||
|
||||
### 3. 验证visited逻辑
|
||||
确保不重复访问同一节点
|
||||
|
||||
### 4. 验证结果集管理
|
||||
确保w的大小不超过ef,且包含距离最近的点
|
||||
|
||||
## 期望性能目标
|
||||
|
||||
- **明文版本**:毫秒级
|
||||
- **密文版本目标**:15-20分钟(相比当前的100+分钟)
|
||||
- **准确率目标**:80%+(相比当前的30%)
|
||||
|
||||
## 可用的工具函数
|
||||
|
||||
- `euclidean_distance(query, point, zero)` - 计算密文欧几里得距离
|
||||
- `encrypted_selection_sort(distances, k)` - 密文选择排序,获取前k个最小值
|
||||
- `EncryptedNeighbor` - 包装距离和索引的结构体
|
||||
|
||||
## 明文版本参考
|
||||
|
||||
参考 `src/bin/plain.rs` 中的 `search_layer` 函数实现,理解标准HNSW算法的逻辑流程。
|
||||
2
run.sh
2
run.sh
@@ -20,4 +20,4 @@ chmod +x "${SCRIPT_DIR}/test"
|
||||
"${SCRIPT_DIR}/test" \
|
||||
--dataset "$DATASET_FILE" \
|
||||
--predictions "$PREDICTIONS_RESULT_FILE" \
|
||||
--log-path "/home/admin/workspace/job/logs/user.log'"
|
||||
--log-file "/home/admin/workspace/job/logs/user.log"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::EncryptedQuery;
|
||||
use crate::data::{EncryptedNeighbor, EncryptedPoint, FheHnswGraph};
|
||||
use crate::data::{EncryptedNeighbor, EncryptedPoint};
|
||||
use crate::logging::{format_duration, print_progress_bar};
|
||||
use rayon::prelude::*;
|
||||
use std::time::Instant;
|
||||
@@ -322,93 +322,59 @@ fn encrypted_conditional_swap(
|
||||
b.index = new_b_index;
|
||||
}
|
||||
|
||||
/// 执行HNSW近似最近邻搜索
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `graph` - 预构建的FHE HNSW图结构
|
||||
/// * `query` - 加密的查询点
|
||||
/// * `k` - 返回的最近邻数量
|
||||
/// * `zero` - 加密的零值
|
||||
///
|
||||
/// # Returns
|
||||
/// * k个最近邻的加密索引列表
|
||||
pub fn perform_hnsw_search(
|
||||
graph: &FheHnswGraph,
|
||||
query: &EncryptedQuery,
|
||||
k: usize,
|
||||
zero: &FheInt14,
|
||||
) -> Vec<FheUint8> {
|
||||
println!("🚀 Starting HNSW approximate search...");
|
||||
|
||||
if graph.nodes.is_empty() {
|
||||
println!("❌ Empty HNSW graph");
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let Some(entry_point) = graph.entry_point else {
|
||||
println!("❌ No entry point in HNSW graph");
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
println!(
|
||||
"🔍 HNSW search from entry point {} at level {}",
|
||||
entry_point, graph.max_level
|
||||
);
|
||||
|
||||
let mut current_candidates = vec![entry_point];
|
||||
|
||||
// 从最高层逐层搜索到第1层
|
||||
for layer in (1..=graph.max_level).rev() {
|
||||
println!("🔍 Searching layer {layer} with ef=1...");
|
||||
let layer_start = Instant::now();
|
||||
current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero);
|
||||
println!(
|
||||
"✅ Layer {} search completed in {}, {} candidates",
|
||||
layer,
|
||||
format_duration(layer_start.elapsed()),
|
||||
current_candidates.len()
|
||||
);
|
||||
}
|
||||
|
||||
// 在第0层进行最终搜索
|
||||
println!("🔍 Final search at layer 0 with ef={}...", 10);
|
||||
let final_search_start = Instant::now();
|
||||
let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero);
|
||||
println!(
|
||||
"✅ Final layer search completed in {}, {} candidates",
|
||||
format_duration(final_search_start.elapsed()),
|
||||
final_candidates.len()
|
||||
);
|
||||
|
||||
// 计算最终候选点的距离并排序
|
||||
println!("🔢 Computing distances for final candidates...");
|
||||
let distance_start = Instant::now();
|
||||
let mut distances = Vec::new();
|
||||
for (i, &candidate) in final_candidates.iter().enumerate().take(graph.nodes.len()) {
|
||||
if candidate < graph.nodes.len() {
|
||||
if i % 10 == 0 && i > 0 {
|
||||
println!(
|
||||
"🔢 Processed {}/{} final candidates",
|
||||
i,
|
||||
final_candidates.len().min(graph.nodes.len())
|
||||
);
|
||||
}
|
||||
let distance = euclidean_distance(query, &graph.nodes[candidate].encrypted_point, zero);
|
||||
distances.push(EncryptedNeighbor {
|
||||
distance,
|
||||
index: graph.nodes[candidate].encrypted_point.index.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"✅ Distance computation completed in {}",
|
||||
format_duration(distance_start.elapsed())
|
||||
);
|
||||
|
||||
// 选择最好的k个
|
||||
println!("📊 Selecting top {} from {} candidates", k, distances.len());
|
||||
let len = distances.len();
|
||||
encrypted_selection_sort(&mut distances, k.min(len));
|
||||
|
||||
distances.iter().take(k).map(|n| n.index.clone()).collect()
|
||||
}
|
||||
///// 执行HNSW近似最近邻搜索
|
||||
/////
|
||||
///// # Arguments
|
||||
///// * `graph` - 预构建的FHE HNSW图结构
|
||||
///// * `query` - 加密的查询点
|
||||
///// * `zero` - 加密的零值
|
||||
/////
|
||||
///// # Returns
|
||||
///// * 10个最近邻的索引列表
|
||||
//pub fn perform_hnsw_search(
|
||||
// graph: &FheHnswGraph,
|
||||
// query: &EncryptedQuery,
|
||||
// zero: &FheInt14,
|
||||
//) -> Vec<usize> {
|
||||
// println!("🚀 Starting HNSW approximate search...");
|
||||
//
|
||||
// if graph.nodes.is_empty() {
|
||||
// println!("❌ Empty HNSW graph");
|
||||
// return Vec::new();
|
||||
// }
|
||||
//
|
||||
// let Some(entry_point) = graph.entry_point else {
|
||||
// println!("❌ No entry point in HNSW graph");
|
||||
// return Vec::new();
|
||||
// };
|
||||
//
|
||||
// println!(
|
||||
// "🔍 HNSW search from entry point {} at level {}",
|
||||
// entry_point, graph.max_level
|
||||
// );
|
||||
//
|
||||
// let mut current_candidates = vec![entry_point];
|
||||
//
|
||||
// // 从最高层逐层搜索到第1层
|
||||
// for layer in (1..=graph.max_level).rev() {
|
||||
// println!("🔍 Searching layer {layer} with ef=1...");
|
||||
// let layer_start = Instant::now();
|
||||
// current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero);
|
||||
// println!(
|
||||
// "✅ Layer {} search completed in {}, {} candidates",
|
||||
// layer,
|
||||
// format_duration(layer_start.elapsed()),
|
||||
// current_candidates.len()
|
||||
// );
|
||||
// }
|
||||
//
|
||||
// // 在第0层进行最终搜索
|
||||
// let final_search_start = Instant::now();
|
||||
// let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero);
|
||||
// println!(
|
||||
// "✅ Final layer search completed in {}, {} candidates",
|
||||
// format_duration(final_search_start.elapsed()),
|
||||
// final_candidates.len()
|
||||
// );
|
||||
// final_candidates
|
||||
//}
|
||||
|
||||
207
src/bin/enc.rs
207
src/bin/enc.rs
@@ -2,7 +2,6 @@ use anyhow::Result;
|
||||
use chrono::Local;
|
||||
use clap::Parser;
|
||||
use log::info;
|
||||
use rand::Rng;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::time::Instant;
|
||||
@@ -11,9 +10,9 @@ use tfhe::{ConfigBuilder, FheInt14, generate_keys, set_server_key};
|
||||
|
||||
// Import from our library modules
|
||||
use hfe_knn::{
|
||||
Dataset, EncryptedNeighbor, EncryptedPoint, PlaintextHnswNode, Prediction, ScaleInt,
|
||||
build_fhe_hnsw_from_plaintext, compute_distances, decrypt_indices, encrypt_point,
|
||||
encrypt_query, format_duration, perform_hnsw_search, perform_knn_selection, print_progress_bar,
|
||||
Dataset, EncryptedNeighbor, EncryptedPoint, Prediction, ScaleInt, compute_distances,
|
||||
decrypt_indices, encrypt_point, encrypt_query, format_duration, perform_knn_selection,
|
||||
print_progress_bar,
|
||||
};
|
||||
|
||||
#[derive(Parser)]
|
||||
@@ -87,97 +86,6 @@ fn debug_compute_distances(
|
||||
encrypted_distances
|
||||
}
|
||||
|
||||
/// 构建明文HNSW图
|
||||
fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) {
|
||||
println!("🔨 Building HNSW graph from {} points...", data.len());
|
||||
|
||||
let max_connections = 8; // 最优参数:减少邻居数以降低密文运算量
|
||||
let max_level = 3; // 最优参数:保持3层结构
|
||||
let mut nodes = Vec::new();
|
||||
let mut entry_point = None;
|
||||
let mut current_max_level = 0;
|
||||
|
||||
// 创建所有节点
|
||||
for (idx, vector) in data.iter().enumerate() {
|
||||
let level = select_level_for_node(max_level);
|
||||
current_max_level = current_max_level.max(level);
|
||||
|
||||
let node = PlaintextHnswNode {
|
||||
vector: vector.clone(),
|
||||
level,
|
||||
neighbors: vec![Vec::new(); level + 1],
|
||||
};
|
||||
nodes.push(node);
|
||||
|
||||
if idx == 0 {
|
||||
entry_point = Some(idx);
|
||||
}
|
||||
}
|
||||
|
||||
// 为每个节点建立连接(简化版实现)
|
||||
for node_id in 0..nodes.len() {
|
||||
print_progress_bar(node_id + 1, nodes.len(), "Building connections");
|
||||
|
||||
let node_vector = nodes[node_id].vector.clone();
|
||||
let node_level = nodes[node_id].level;
|
||||
|
||||
// 为每层找到最近的邻居
|
||||
for layer in 0..=node_level {
|
||||
let mut distances: Vec<(f64, usize)> = nodes
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(idx, n)| *idx != node_id && n.level >= layer)
|
||||
.map(|(idx, n)| {
|
||||
let dist = euclidean_distance_plaintext(&node_vector, &n.vector);
|
||||
(dist, idx)
|
||||
})
|
||||
.collect();
|
||||
|
||||
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
|
||||
// 添加最近的邻居(限制连接数)
|
||||
for &(_, neighbor_id) in distances.iter().take(max_connections) {
|
||||
nodes[node_id].neighbors[layer].push(neighbor_id);
|
||||
}
|
||||
}
|
||||
|
||||
// 更新入口点
|
||||
if nodes[node_id].level > nodes[entry_point.unwrap()].level {
|
||||
entry_point = Some(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
println!(); // 新行
|
||||
println!(
|
||||
"✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}",
|
||||
nodes.len(),
|
||||
entry_point,
|
||||
current_max_level
|
||||
);
|
||||
|
||||
(nodes, entry_point, current_max_level)
|
||||
}
|
||||
|
||||
/// 选择节点的层级
|
||||
fn select_level_for_node(max_level: usize) -> usize {
|
||||
let mut rng = rand::rng();
|
||||
let mut level = 0;
|
||||
while level < max_level && rng.random::<f32>() < 0.6 {
|
||||
// 最优参数:提高层级概率到0.6
|
||||
level += 1;
|
||||
}
|
||||
level
|
||||
}
|
||||
|
||||
/// 明文欧几里得距离计算
|
||||
fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| (x - y).powi(2))
|
||||
.sum::<f64>()
|
||||
.sqrt()
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let start_time = Instant::now();
|
||||
@@ -235,67 +143,62 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
|
||||
let query_encrypted = encrypt_query(&dataset.query, client_key);
|
||||
|
||||
let k = 10; // Number of nearest neighbors
|
||||
let encrypted_neighbors = if args.algorithm == "hnsw" {
|
||||
// HNSW 算法路径
|
||||
println!("🚀 Using HNSW algorithm...");
|
||||
// if args.algorithm == "hnsw" {
|
||||
// // HNSW 算法路径
|
||||
// println!("🚀 Using HNSW algorithm...");
|
||||
//
|
||||
// // 1. 构建明文HNSW图
|
||||
// let (plaintext_nodes, entry_point, max_level) =
|
||||
// build_plaintext_hnsw_graph(&dataset.data);
|
||||
//
|
||||
// // 2. 转换为加密HNSW图
|
||||
// println!("🔐 Converting to encrypted HNSW graph...");
|
||||
// let fhe_graph =
|
||||
// build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
|
||||
//
|
||||
// // 3. 执行HNSW搜索
|
||||
// let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
|
||||
// let answer = perform_hnsw_search(&fhe_graph, &query_encrypted, &encrypted_zero);
|
||||
// results.push(Prediction { answer });
|
||||
// } else {
|
||||
// 传统算法路径
|
||||
// Encrypt all training points
|
||||
println!("🔐 Encrypting training points...");
|
||||
let points_encrypted: Vec<EncryptedPoint> = dataset
|
||||
.data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, coords)| encrypt_point(coords, idx + 1, client_key))
|
||||
.collect();
|
||||
|
||||
// 1. 构建明文HNSW图
|
||||
let (plaintext_nodes, entry_point, max_level) =
|
||||
build_plaintext_hnsw_graph(&dataset.data);
|
||||
|
||||
// 2. 转换为加密HNSW图
|
||||
println!("🔐 Converting to encrypted HNSW graph...");
|
||||
let fhe_graph =
|
||||
build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
|
||||
|
||||
// 3. 执行HNSW搜索
|
||||
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
|
||||
perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero)
|
||||
// Compute distances
|
||||
let mut distances = if args.debug {
|
||||
debug_compute_distances(&dataset.query, &dataset.data, &points_encrypted, client_key)
|
||||
} else {
|
||||
// 传统算法路径
|
||||
// Encrypt all training points
|
||||
println!("🔐 Encrypting training points...");
|
||||
let points_encrypted: Vec<EncryptedPoint> = dataset
|
||||
.data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, coords)| encrypt_point(coords, idx + 1, client_key))
|
||||
.collect();
|
||||
|
||||
// Compute distances
|
||||
let mut distances = if args.debug {
|
||||
debug_compute_distances(
|
||||
&dataset.query,
|
||||
&dataset.data,
|
||||
&points_encrypted,
|
||||
client_key,
|
||||
)
|
||||
} else {
|
||||
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
|
||||
compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
|
||||
};
|
||||
|
||||
// Perform KNN selection using the specified algorithm
|
||||
let max_distance = if args.algorithm == "bitonic" {
|
||||
Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // FheInt14的正确最大值
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let max_index = if args.algorithm == "bitonic" {
|
||||
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
perform_knn_selection(
|
||||
&mut distances,
|
||||
k,
|
||||
&args.algorithm,
|
||||
max_distance.as_ref(),
|
||||
max_index.as_ref(),
|
||||
)
|
||||
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
|
||||
compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
|
||||
};
|
||||
|
||||
// Perform KNN selection using the specified algorithm
|
||||
let max_distance = if args.algorithm == "bitonic" {
|
||||
Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // FheInt14的正确最大值
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let max_index = if args.algorithm == "bitonic" {
|
||||
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let encrypted_neighbors = perform_knn_selection(
|
||||
&mut distances,
|
||||
k,
|
||||
&args.algorithm,
|
||||
max_distance.as_ref(),
|
||||
max_index.as_ref(),
|
||||
);
|
||||
|
||||
// Decrypt the results
|
||||
println!("🔓 Decrypting results...");
|
||||
let decrypted_indices = decrypt_indices(&encrypted_neighbors, client_key);
|
||||
|
||||
117
src/bin/plain.rs
117
src/bin/plain.rs
@@ -52,7 +52,8 @@ struct Prediction {
|
||||
|
||||
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
|
||||
// 模拟加密版本的缩放和精度损失
|
||||
let scaled_distance: f64 = a.iter()
|
||||
let scaled_distance: f64 = a
|
||||
.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| {
|
||||
// 缩放坐标(乘以10,然后舍弃小数)
|
||||
@@ -61,7 +62,7 @@ fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
|
||||
(scaled_x - scaled_y).powi(2)
|
||||
})
|
||||
.sum::<f64>();
|
||||
|
||||
|
||||
scaled_distance
|
||||
}
|
||||
|
||||
@@ -280,78 +281,80 @@ impl HNSWGraph {
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let file = File::open(&args.dataset)?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
|
||||
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||
results.push(Prediction { answer: nearest });
|
||||
}
|
||||
|
||||
let mut output_file = File::create(&args.predictions)?;
|
||||
for result in results {
|
||||
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
// fn main() -> Result<()> {
|
||||
// let args = Args::parse();
|
||||
//
|
||||
// let file = File::open(&args.dataset)?;
|
||||
// let reader = BufReader::new(file);
|
||||
// let mut results = Vec::new();
|
||||
//
|
||||
// println!("🔧 HNSW Parameters:");
|
||||
// println!(" max_connections: {}", args.max_connections);
|
||||
// println!(" max_level: {}", args.max_level);
|
||||
// println!(" level_prob: {}", args.level_prob);
|
||||
// println!(" ef_upper: {}", args.ef_upper);
|
||||
// println!(" ef_bottom: {}", args.ef_bottom);
|
||||
// let mut results = Vec::new();
|
||||
//
|
||||
// for line in reader.lines() {
|
||||
// let line = line?;
|
||||
// let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
//
|
||||
// let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
|
||||
// for data_point in &dataset.data {
|
||||
// hnsw.insert_node(data_point.clone(), args.level_prob);
|
||||
// }
|
||||
//
|
||||
// let nearest = if let Some(entry_point) = hnsw.entry_point {
|
||||
// let mut search_results = vec![entry_point];
|
||||
//
|
||||
// // 上层搜索使用ef_upper参数
|
||||
// for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
||||
// search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
|
||||
// }
|
||||
//
|
||||
// // 底层搜索使用ef_bottom参数
|
||||
// let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
|
||||
// final_results
|
||||
// .into_iter()
|
||||
// .take(10)
|
||||
// .map(|idx| idx + 1)
|
||||
// .collect()
|
||||
// } else {
|
||||
// Vec::new()
|
||||
// };
|
||||
//
|
||||
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||
// results.push(Prediction { answer: nearest });
|
||||
// }
|
||||
//
|
||||
// let mut output_file = File::create(&args.predictions)?;
|
||||
// for result in results {
|
||||
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
// println!("{}", serde_json::to_string(&result)?);
|
||||
// }
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let file = File::open(&args.dataset)?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut results = Vec::new();
|
||||
|
||||
println!("🔧 HNSW Parameters:");
|
||||
println!(" max_connections: {}", args.max_connections);
|
||||
println!(" max_level: {}", args.max_level);
|
||||
println!(" level_prob: {}", args.level_prob);
|
||||
println!(" ef_upper: {}", args.ef_upper);
|
||||
println!(" ef_bottom: {}", args.ef_bottom);
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
|
||||
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
|
||||
for data_point in &dataset.data {
|
||||
hnsw.insert_node(data_point.clone(), args.level_prob);
|
||||
}
|
||||
|
||||
let nearest = if let Some(entry_point) = hnsw.entry_point {
|
||||
let mut search_results = vec![entry_point];
|
||||
|
||||
// 上层搜索使用ef_upper参数
|
||||
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
||||
search_results =
|
||||
hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
|
||||
}
|
||||
|
||||
// 底层搜索使用ef_bottom参数
|
||||
let final_results =
|
||||
hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
|
||||
final_results
|
||||
.into_iter()
|
||||
.take(10)
|
||||
.map(|idx| idx + 1)
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
results.push(Prediction { answer: nearest });
|
||||
}
|
||||
|
||||
let mut output_file = File::create(&args.predictions)?;
|
||||
for result in results {
|
||||
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
println!("{}", serde_json::to_string(&result)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
410
src/data.rs
410
src/data.rs
@@ -1,5 +1,4 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::{ClientKey, FheInt14, FheUint8};
|
||||
|
||||
@@ -130,165 +129,250 @@ pub fn decrypt_indices(encrypted_indices: &[FheUint8], client_key: &ClientKey) -
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 密文版 HNSW 节点
|
||||
#[derive(Clone)]
|
||||
pub struct FheHnswNode {
|
||||
pub encrypted_point: EncryptedPoint,
|
||||
pub level: usize,
|
||||
pub neighbors: Vec<Vec<usize>>, // 明文邻居索引(预处理时确定)
|
||||
}
|
||||
|
||||
/// 密文版 HNSW 图
|
||||
#[derive(Clone)]
|
||||
pub struct FheHnswGraph {
|
||||
pub nodes: Vec<FheHnswNode>,
|
||||
pub entry_point: Option<usize>,
|
||||
pub max_level: usize,
|
||||
}
|
||||
|
||||
impl FheHnswGraph {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
nodes: Vec::new(),
|
||||
entry_point: None,
|
||||
max_level: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// 密文版搜索层函数
|
||||
///
|
||||
/// 核心算法:实现标准HNSW贪心搜索,但使用密文距离计算
|
||||
///
|
||||
/// 算法流程:
|
||||
/// 1. 初始化候选队列(candidates)和结果集(w)
|
||||
/// 2. 从entry_points开始,计算到query的密文距离
|
||||
/// 3. 贪心搜索循环:
|
||||
/// - 从candidates中选择距离最小的点作为当前探索点
|
||||
/// - 探索该点的所有邻居
|
||||
/// - 将未访问的邻居加入candidates和w
|
||||
/// - 维护w的大小不超过ef(移除最远的点)
|
||||
/// - 实现剪枝:如果当前点比w中最远点还远且w已满,则停止
|
||||
/// 4. 返回w中的节点索引
|
||||
///
|
||||
/// 关键挑战:
|
||||
/// - 需要模拟优先队列但只能用密文排序
|
||||
/// - 剪枝条件难以在密文下判断,需要权衡准确性和性能
|
||||
/// - 候选队列管理要高效,避免过多密文运算
|
||||
pub fn search_layer(
|
||||
&self,
|
||||
_query: &EncryptedQuery,
|
||||
entry_points: Vec<usize>,
|
||||
ef: usize,
|
||||
layer: usize,
|
||||
_zero: &FheInt14,
|
||||
) -> Vec<usize> {
|
||||
let _visited: HashSet<usize> = HashSet::new();
|
||||
|
||||
// TODO: 设计合适的数据结构
|
||||
// candidates: 候选队列,存储(节点索引, 密文距离)或类似结构
|
||||
// w: 结果集,维护当前找到的最好的ef个候选点
|
||||
|
||||
println!(
|
||||
"🔍 Initializing search layer {} with {} entry points",
|
||||
layer,
|
||||
entry_points.len()
|
||||
);
|
||||
|
||||
// TODO: Step 1 - 初始化候选点
|
||||
// 对每个entry_point:
|
||||
// - 检查节点是否存在且level >= layer
|
||||
// - 计算到query的密文距离:euclidean_distance(query, &self.nodes[ep].encrypted_point, zero)
|
||||
// - 将节点加入visited集合
|
||||
// - 将(节点索引, 距离)加入candidates和w
|
||||
|
||||
println!(
|
||||
"🔍 Starting search with {} initial candidates",
|
||||
// TODO: 显示实际初始化的候选点数量
|
||||
0
|
||||
);
|
||||
|
||||
// TODO: Step 2 - 主搜索循环
|
||||
// while candidates不为空:
|
||||
// - 从candidates中找到距离最小的点(需要密文排序或其他方法)
|
||||
// - 将该点从candidates中移除,设为当前探索点current
|
||||
//
|
||||
// - 剪枝检查(可选,复杂):
|
||||
// 如果w.len() >= ef且current的距离 > w中最远点的距离,则break
|
||||
//
|
||||
// - 探索current的邻居:
|
||||
// for neighbor in self.nodes[current].neighbors[layer]:
|
||||
// - 如果neighbor未访问过:
|
||||
// - 标记为已访问
|
||||
// - 计算到query的密文距离
|
||||
// - 将neighbor加入candidates
|
||||
// - 管理结果集w:
|
||||
// if w.len() < ef: 直接加入w
|
||||
// else: 加入w后排序,移除最远的点,保持w大小为ef
|
||||
|
||||
println!(
|
||||
"🔍 Found {} total candidates (ef={})",
|
||||
// TODO: 显示最终找到的候选点数量
|
||||
0,
|
||||
ef
|
||||
);
|
||||
|
||||
// TODO: Step 3 - 返回结果
|
||||
// 从w中提取节点索引并返回
|
||||
// w.into_iter().map(|(node_idx, _)| node_idx).collect()
|
||||
|
||||
// 临时返回空结果,避免编译错误
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FheHnswGraph {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// 从明文数据构建密文 HNSW 图的辅助结构
|
||||
#[derive(Clone)]
|
||||
pub struct PlaintextHnswNode {
|
||||
pub vector: Vec<f64>,
|
||||
pub level: usize,
|
||||
pub neighbors: Vec<Vec<usize>>,
|
||||
}
|
||||
|
||||
/// 从明文数据构建密文 HNSW 图
|
||||
pub fn build_fhe_hnsw_from_plaintext(
|
||||
plaintext_nodes: &[PlaintextHnswNode],
|
||||
plaintext_entry_point: Option<usize>,
|
||||
plaintext_max_level: usize,
|
||||
client_key: &ClientKey,
|
||||
) -> FheHnswGraph {
|
||||
use crate::logging::print_progress_bar;
|
||||
|
||||
let mut fhe_nodes = Vec::new();
|
||||
let total_nodes = plaintext_nodes.len();
|
||||
|
||||
println!("🔐 Encrypting {total_nodes} HNSW nodes...");
|
||||
|
||||
for (idx, plain_node) in plaintext_nodes.iter().enumerate() {
|
||||
print_progress_bar(idx + 1, total_nodes, "Encrypting nodes");
|
||||
let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key);
|
||||
let fhe_node = FheHnswNode {
|
||||
encrypted_point,
|
||||
level: plain_node.level,
|
||||
neighbors: plain_node.neighbors.clone(),
|
||||
};
|
||||
fhe_nodes.push(fhe_node);
|
||||
}
|
||||
|
||||
println!(); // 新行
|
||||
println!(
|
||||
"✅ FHE HNSW graph created with {} encrypted nodes",
|
||||
fhe_nodes.len()
|
||||
);
|
||||
|
||||
FheHnswGraph {
|
||||
nodes: fhe_nodes,
|
||||
entry_point: plaintext_entry_point,
|
||||
max_level: plaintext_max_level,
|
||||
}
|
||||
}
|
||||
///// 密文版 HNSW 节点
|
||||
//#[derive(Clone)]
|
||||
//pub struct FheHnswNode {
|
||||
// pub encrypted_point: EncryptedPoint,
|
||||
// pub level: usize,
|
||||
// pub neighbors: Vec<Vec<usize>>, // 明文邻居索引(预处理时确定)
|
||||
//}
|
||||
//
|
||||
///// 密文版 HNSW 图
|
||||
//#[derive(Clone)]
|
||||
//pub struct FheHnswGraph {
|
||||
// pub nodes: Vec<FheHnswNode>,
|
||||
// pub entry_point: Option<usize>,
|
||||
// pub max_level: usize,
|
||||
// pub distances: [Option<FheInt14>; 100],
|
||||
//}
|
||||
//
|
||||
//impl FheHnswGraph {
|
||||
// pub fn new() -> Self {
|
||||
// Self {
|
||||
// nodes: Vec::new(),
|
||||
// entry_point: None,
|
||||
// max_level: 0,
|
||||
// distances: [const { None };100]
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// /// 搜索层函数
|
||||
// ///
|
||||
// /// 核心算法:实现标准HNSW贪心搜索
|
||||
// ///
|
||||
// /// 算法流程:
|
||||
// /// 1. 初始化候选队列(candidates)和结果集(w)
|
||||
// /// 2. 从entry_points开始,计算到query的密文距离
|
||||
// /// 3. 贪心搜索循环:
|
||||
// /// - 从candidates中选择距离最小的点作为当前探索点
|
||||
// /// - 探索该点的所有邻居
|
||||
// /// - 将未访问的邻居加入candidates和w
|
||||
// /// - 维护w的大小不超过ef(移除最远的点)
|
||||
// /// - 实现剪枝:如果当前点比w中最远点还远且w已满,则停止
|
||||
// /// 4. 返回w中的节点索引
|
||||
// ///
|
||||
// /// 关键挑战:
|
||||
// /// - 需要模拟优先队列但只能用密文排序
|
||||
// /// - 剪枝条件难以在密文下判断,需要权衡准确性和性能
|
||||
// /// - 候选队列管理要高效,避免过多密文运算
|
||||
// pub fn search_layer(
|
||||
// &self,
|
||||
// query: &EncryptedQuery,
|
||||
// entry_points: Vec<usize>,
|
||||
// ef: usize,
|
||||
// layer: usize,
|
||||
// zero: &FheInt14,
|
||||
// ) -> Vec<usize> {
|
||||
// let visited: HashSet<usize> = HashSet::new();
|
||||
// let mut candidate =
|
||||
//
|
||||
// // TODO: 设计合适的数据结构
|
||||
// // candidates: 候选队列,存储(节点索引, 密文距离)或类似结构
|
||||
// // w: 结果集,维护当前找到的最好的ef个候选点
|
||||
//
|
||||
// // TODO: Step 1 - 初始化候选点
|
||||
// // 对每个entry_point:
|
||||
// // - 检查节点是否存在且level >= layer
|
||||
// // - 计算到query的密文距离:euclidean_distance(query, &self.nodes[ep].encrypted_point, zero)
|
||||
// // - 将节点加入visited集合
|
||||
// // - 将(节点索引, 距离)加入candidates和w
|
||||
//
|
||||
// println!(
|
||||
// "🔍 Starting search with {} initial candidates",
|
||||
// // TODO: 显示实际初始化的候选点数量
|
||||
// 0
|
||||
// );
|
||||
//
|
||||
// // TODO: Step 2 - 主搜索循环
|
||||
// // while candidates不为空:
|
||||
// // - 从candidates中找到距离最小的点(需要密文排序或其他方法)
|
||||
// // - 将该点从candidates中移除,设为当前探索点current
|
||||
// //
|
||||
// // - 剪枝检查(可选,复杂):
|
||||
// // 如果w.len() >= ef且current的距离 > w中最远点的距离,则break
|
||||
// //
|
||||
// // - 探索current的邻居:
|
||||
// // for neighbor in self.nodes[current].neighbors[layer]:
|
||||
// // - 如果neighbor未访问过:
|
||||
// // - 标记为已访问
|
||||
// // - 计算到query的密文距离
|
||||
// // - 将neighbor加入candidates
|
||||
// // - 管理结果集w:
|
||||
// // if w.len() < ef: 直接加入w
|
||||
// // else: 加入w后排序,移除最远的点,保持w大小为ef
|
||||
//
|
||||
// println!(
|
||||
// "🔍 Found {} total candidates (ef={})",
|
||||
// // TODO: 显示最终找到的候选点数量
|
||||
// 0,
|
||||
// ef
|
||||
// );
|
||||
//
|
||||
// // TODO: Step 3 - 返回结果
|
||||
// // 从w中提取节点索引并返回
|
||||
// // w.into_iter().map(|(node_idx, _)| node_idx).collect()
|
||||
//
|
||||
// // 临时返回空结果,避免编译错误
|
||||
// Vec::new()
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//impl Default for FheHnswGraph {
|
||||
// fn default() -> Self {
|
||||
// Self::new()
|
||||
// }
|
||||
//}
|
||||
//
|
||||
///// 从明文数据构建密文 HNSW 图的辅助结构
|
||||
//#[derive(Clone)]
|
||||
//pub struct PlaintextHnswNode {
|
||||
// pub vector: Vec<f64>,
|
||||
// pub level: usize,
|
||||
// pub neighbors: Vec<Vec<usize>>,
|
||||
//}
|
||||
//
|
||||
///// 从明文数据构建密文 HNSW 图
|
||||
//pub fn build_fhe_hnsw_from_plaintext(
|
||||
// plaintext_nodes: &[PlaintextHnswNode],
|
||||
// plaintext_entry_point: Option<usize>,
|
||||
// plaintext_max_level: usize,
|
||||
// client_key: &ClientKey,
|
||||
//) -> FheHnswGraph {
|
||||
// use crate::logging::print_progress_bar;
|
||||
//
|
||||
// let mut fhe_nodes = Vec::new();
|
||||
// let total_nodes = plaintext_nodes.len();
|
||||
//
|
||||
// println!("🔐 Encrypting {total_nodes} HNSW nodes...");
|
||||
//
|
||||
// for (idx, plain_node) in plaintext_nodes.iter().enumerate() {
|
||||
// print_progress_bar(idx + 1, total_nodes, "Encrypting nodes");
|
||||
// let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key);
|
||||
// let fhe_node = FheHnswNode {
|
||||
// encrypted_point,
|
||||
// level: plain_node.level,
|
||||
// neighbors: plain_node.neighbors.clone(),
|
||||
// };
|
||||
// fhe_nodes.push(fhe_node);
|
||||
// }
|
||||
//
|
||||
//
|
||||
// FheHnswGraph {
|
||||
// nodes: fhe_nodes,
|
||||
// entry_point: plaintext_entry_point,
|
||||
// max_level: plaintext_max_level,
|
||||
// }
|
||||
//}
|
||||
//
|
||||
///// 构建明文HNSW图
|
||||
//pub fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) {
|
||||
// println!("🔨 Building HNSW graph from {} points...", data.len());
|
||||
//
|
||||
// let max_connections = 8; // 最优参数:减少邻居数以降低密文运算量
|
||||
// let max_level = 3; // 最优参数:保持3层结构
|
||||
// let mut nodes = Vec::new();
|
||||
// let mut entry_point = None;
|
||||
// let mut current_max_level = 0;
|
||||
//
|
||||
// // 创建所有节点
|
||||
// for (idx, vector) in data.iter().enumerate() {
|
||||
// let level = select_level_for_node(max_level);
|
||||
// current_max_level = current_max_level.max(level);
|
||||
//
|
||||
// let node = PlaintextHnswNode {
|
||||
// vector: vector.clone(),
|
||||
// level,
|
||||
// neighbors: vec![Vec::new(); level + 1],
|
||||
// };
|
||||
// nodes.push(node);
|
||||
//
|
||||
// if idx == 0 {
|
||||
// entry_point = Some(idx);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 为每个节点建立连接(简化版实现)
|
||||
// for node_id in 0..nodes.len() {
|
||||
// print_progress_bar(node_id + 1, nodes.len(), "Building connections");
|
||||
//
|
||||
// let node_vector = nodes[node_id].vector.clone();
|
||||
// let node_level = nodes[node_id].level;
|
||||
//
|
||||
// // 为每层找到最近的邻居
|
||||
// for layer in 0..=node_level {
|
||||
// let mut distances: Vec<(f64, usize)> = nodes
|
||||
// .iter()
|
||||
// .enumerate()
|
||||
// .filter(|(idx, n)| *idx != node_id && n.level >= layer)
|
||||
// .map(|(idx, n)| {
|
||||
// let dist = euclidean_distance_plaintext(&node_vector, &n.vector);
|
||||
// (dist, idx)
|
||||
// })
|
||||
// .collect();
|
||||
//
|
||||
// distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
//
|
||||
// // 添加最近的邻居(限制连接数)
|
||||
// for &(_, neighbor_id) in distances.iter().take(max_connections) {
|
||||
// nodes[node_id].neighbors[layer].push(neighbor_id);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 更新入口点
|
||||
// if nodes[node_id].level > nodes[entry_point.unwrap()].level {
|
||||
// entry_point = Some(node_id);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// println!(); // 新行
|
||||
// println!(
|
||||
// "✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}",
|
||||
// nodes.len(),
|
||||
// entry_point,
|
||||
// current_max_level
|
||||
// );
|
||||
//
|
||||
// (nodes, entry_point, current_max_level)
|
||||
//}
|
||||
//
|
||||
///// 选择节点的层级
|
||||
//pub fn select_level_for_node(max_level: usize) -> usize {
|
||||
// let mut rng = rand::rng();
|
||||
// let mut level = 0;
|
||||
// while level < max_level && rng.random::<f32>() < 0.6 {
|
||||
// // 最优参数:提高层级概率到0.6
|
||||
// level += 1;
|
||||
// }
|
||||
// level
|
||||
//}
|
||||
//
|
||||
//
|
||||
//
|
||||
///// 明文欧几里得距离计算
|
||||
//fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 {
|
||||
// a.iter()
|
||||
// .zip(b.iter())
|
||||
// .map(|(x, y)| (x - y).powi(2))
|
||||
// .sum::<f64>()
|
||||
// .sqrt()
|
||||
//}
|
||||
|
||||
338
test_fhe_hnsw.py
338
test_fhe_hnsw.py
@@ -1,338 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FHE HNSW并行测试脚本 - 多进程版本
|
||||
测试密文态HNSW实现的稳定性和准确率
|
||||
支持多进程并行执行,充分利用服务器CPU资源
|
||||
运行100次,记录≥90%准确率的成功率和结果长度分布
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import multiprocessing as mp
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
# 正确答案
|
||||
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
|
||||
|
||||
|
||||
def calculate_accuracy(result, correct):
|
||||
"""计算准确率:匹配元素数量 / 总元素数量"""
|
||||
matches = len(set(result) & set(correct))
|
||||
return matches / len(correct) * 100
|
||||
|
||||
|
||||
def run_single_test(test_id, data_bit_width="12", timeout_minutes=30):
|
||||
"""运行单次FHE HNSW测试(供多进程调用)"""
|
||||
# 为每个进程创建独立的输出文件
|
||||
output_file = f"./test_fhe_output_{test_id}_{os.getpid()}.jsonl"
|
||||
|
||||
cmd = [
|
||||
"./enc",
|
||||
"--algorithm",
|
||||
"hnsw",
|
||||
"--data-bit-width",
|
||||
data_bit_width,
|
||||
"--predictions",
|
||||
output_file,
|
||||
]
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=timeout_minutes * 60
|
||||
)
|
||||
|
||||
test_time = time.time() - start_time
|
||||
|
||||
if result.returncode != 0:
|
||||
return {
|
||||
"test_id": test_id,
|
||||
"result": None,
|
||||
"error": f"Command failed: {result.stderr[:200]}",
|
||||
"time_minutes": test_time / 60,
|
||||
}
|
||||
|
||||
# 读取结果
|
||||
if not Path(output_file).exists():
|
||||
return {
|
||||
"test_id": test_id,
|
||||
"result": None,
|
||||
"error": "Output file not found",
|
||||
"time_minutes": test_time / 60,
|
||||
}
|
||||
|
||||
with open(output_file, "r") as f:
|
||||
output = json.load(f)
|
||||
answer = output["answer"]
|
||||
|
||||
# 清理临时文件
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
|
||||
# 计算准确率
|
||||
accuracy = calculate_accuracy(answer, CORRECT_ANSWER)
|
||||
|
||||
return {
|
||||
"test_id": test_id,
|
||||
"result": answer,
|
||||
"length": len(answer),
|
||||
"accuracy": accuracy,
|
||||
"success": accuracy >= 90,
|
||||
"time_minutes": test_time / 60,
|
||||
"error": None,
|
||||
}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
# 清理临时文件
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
return {
|
||||
"test_id": test_id,
|
||||
"result": None,
|
||||
"error": "timeout",
|
||||
"time_minutes": timeout_minutes,
|
||||
}
|
||||
except Exception as e:
|
||||
# 清理临时文件
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
return {
|
||||
"test_id": test_id,
|
||||
"result": None,
|
||||
"error": str(e),
|
||||
"time_minutes": (time.time() - start_time) / 60,
|
||||
}
|
||||
|
||||
|
||||
def print_progress(completed, total, start_time):
|
||||
"""打印进度信息"""
|
||||
if completed == 0:
|
||||
return
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
avg_time = elapsed / completed
|
||||
remaining_time = avg_time * (total - completed)
|
||||
|
||||
print(
|
||||
f"\r🔄 进度: {completed}/{total} ({completed/total*100:.1f}%) "
|
||||
f"| 已用时: {elapsed/3600:.1f}h | 预计剩余: {remaining_time/3600:.1f}h",
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
"""主测试函数"""
|
||||
print("🚀 FHE HNSW 并行批量测试脚本")
|
||||
print(f"🎯 正确答案: {CORRECT_ANSWER}")
|
||||
print("📊 运行100次测试,记录准确率和结果长度分布")
|
||||
print("🔧 支持多进程并行执行")
|
||||
|
||||
# 检查enc二进制文件是否存在
|
||||
if not Path("./enc").exists():
|
||||
print("\n❌ 找不到 ./enc 二进制文件,请先编译项目:")
|
||||
print(" cargo build --release --bin enc")
|
||||
return
|
||||
|
||||
# 获取CPU核心数并设置进程数
|
||||
cpu_cores = mp.cpu_count()
|
||||
# 考虑到FHE运算的内存密集性,使用核心数的一半避免内存不足
|
||||
num_processes = 8
|
||||
timeout_minutes = 45 # 增加超时时间到45分钟
|
||||
|
||||
print(f"💻 使用 {num_processes} 个并行进程")
|
||||
print(f"⏰ 单次测试超时时间: {timeout_minutes} 分钟")
|
||||
print(
|
||||
f"⏱️ 预计总时间: {100 * 15 / num_processes / 60:.1f}-{100 * 25 / num_processes / 60:.1f} 小时"
|
||||
)
|
||||
print()
|
||||
|
||||
print("=" * 80)
|
||||
print("🔬 开始FHE HNSW并行测试 (data-bit-width=12)")
|
||||
print("=" * 80)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 创建进程池并执行测试
|
||||
test_ids = list(range(1, 101)) # 1-100
|
||||
|
||||
with mp.Pool(processes=num_processes) as pool:
|
||||
# 创建异步任务
|
||||
async_results = []
|
||||
for test_id in test_ids:
|
||||
async_result = pool.apply_async(
|
||||
run_single_test, (test_id, "12", timeout_minutes)
|
||||
)
|
||||
async_results.append(async_result)
|
||||
|
||||
# 收集结果并显示进度
|
||||
results = []
|
||||
completed = 0
|
||||
|
||||
print_progress(completed, len(test_ids), start_time)
|
||||
|
||||
for async_result in async_results:
|
||||
try:
|
||||
result = async_result.get(
|
||||
timeout=timeout_minutes * 60 + 60
|
||||
) # 额外1分钟缓冲
|
||||
results.append(result)
|
||||
completed += 1
|
||||
print_progress(completed, len(test_ids), start_time)
|
||||
except mp.TimeoutError:
|
||||
# 进程级别超时
|
||||
results.append(
|
||||
{
|
||||
"test_id": len(results) + 1,
|
||||
"result": None,
|
||||
"error": "process_timeout",
|
||||
"time_minutes": timeout_minutes,
|
||||
}
|
||||
)
|
||||
completed += 1
|
||||
print_progress(completed, len(test_ids), start_time)
|
||||
|
||||
print() # 换行
|
||||
total_elapsed = time.time() - start_time
|
||||
|
||||
# 分析结果
|
||||
print("\n" + "=" * 80)
|
||||
print("📈 测试结果分析")
|
||||
print("=" * 80)
|
||||
|
||||
# 统计变量
|
||||
valid_results = [r for r in results if r["result"] is not None]
|
||||
success_count = sum(1 for r in valid_results if r["success"])
|
||||
length_distribution = defaultdict(int)
|
||||
error_distribution = defaultdict(int)
|
||||
|
||||
# 统计错误类型
|
||||
for r in results:
|
||||
if r["error"]:
|
||||
error_type = r["error"]
|
||||
if "timeout" in error_type.lower():
|
||||
error_distribution["timeout"] += 1
|
||||
elif "failed" in error_type.lower():
|
||||
error_distribution["failed"] += 1
|
||||
else:
|
||||
error_distribution["other"] += 1
|
||||
else:
|
||||
length_distribution[r["length"]] += 1
|
||||
|
||||
# 基本统计
|
||||
total_tests = len(results)
|
||||
valid_tests = len(valid_results)
|
||||
|
||||
if valid_tests == 0:
|
||||
print("❌ 没有有效的测试结果")
|
||||
return
|
||||
|
||||
success_rate = success_count / valid_tests * 100
|
||||
avg_accuracy = sum(r["accuracy"] for r in valid_results) / valid_tests
|
||||
avg_length = sum(r["length"] for r in valid_results) / valid_tests
|
||||
avg_time_per_test = sum(r["time_minutes"] for r in results) / total_tests
|
||||
|
||||
print(f"总测试次数: {total_tests}")
|
||||
print(f"有效测试次数: {valid_tests}")
|
||||
print(f"成功次数 (≥90%准确率): {success_count}")
|
||||
print(f"成功率: {success_rate:.1f}%")
|
||||
print(f"平均准确率: {avg_accuracy:.1f}%")
|
||||
print(f"平均结果长度: {avg_length:.1f}")
|
||||
print(f"平均每次测试时间: {avg_time_per_test:.1f}分钟")
|
||||
print(f"总测试时间: {total_elapsed/3600:.1f}小时")
|
||||
print(f"并行加速比: {100 * avg_time_per_test / 60 / (total_elapsed/3600):.1f}x")
|
||||
|
||||
# 错误统计
|
||||
if error_distribution:
|
||||
print("\n❌ 错误分布:")
|
||||
for error_type, count in error_distribution.items():
|
||||
print(f" {error_type}: {count}次")
|
||||
|
||||
# 结果长度分布
|
||||
if length_distribution:
|
||||
print("\n📊 结果长度分布:")
|
||||
for length in sorted(length_distribution.keys()):
|
||||
count = length_distribution[length]
|
||||
percentage = count / valid_tests * 100
|
||||
bar = "█" * (count // 2) if count > 0 else ""
|
||||
print(f" 长度 {length:2d}: {count:3d}次 ({percentage:5.1f}%) {bar}")
|
||||
|
||||
# 结论
|
||||
print()
|
||||
if success_rate >= 50:
|
||||
print("✅ 测试通过! FHE HNSW实现稳定性良好")
|
||||
print(f" - 成功率: {success_rate:.1f}% (≥50%)")
|
||||
print(f" - 平均准确率: {avg_accuracy:.1f}%")
|
||||
if avg_length >= 9.5:
|
||||
print(f" - 平均结果长度: {avg_length:.1f} (接近10)")
|
||||
else:
|
||||
print(f" - 平均结果长度: {avg_length:.1f} (需要改进)")
|
||||
else:
|
||||
print("❌ 测试未通过,需要进一步优化")
|
||||
print(f" - 成功率: {success_rate:.1f}% (<50%)")
|
||||
print(f" - 平均准确率: {avg_accuracy:.1f}%")
|
||||
|
||||
# 保存详细结果
|
||||
report_file = "fhe_hnsw_parallel_test_report.json"
|
||||
with open(report_file, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"summary": {
|
||||
"total_tests": total_tests,
|
||||
"valid_tests": valid_tests,
|
||||
"success_count": success_count,
|
||||
"success_rate": success_rate,
|
||||
"avg_accuracy": avg_accuracy,
|
||||
"avg_length": avg_length,
|
||||
"avg_time_minutes": avg_time_per_test,
|
||||
"total_time_hours": total_elapsed / 3600,
|
||||
"num_processes": num_processes,
|
||||
"cpu_cores": cpu_cores,
|
||||
"speedup": 100 * avg_time_per_test / 60 / (total_elapsed / 3600),
|
||||
},
|
||||
"length_distribution": dict(length_distribution),
|
||||
"error_distribution": dict(error_distribution),
|
||||
"detailed_results": results,
|
||||
"correct_answer": CORRECT_ANSWER,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
print(f"\n📁 详细报告已保存到: {report_file}")
|
||||
|
||||
# 显示最好和最坏的几个结果
|
||||
if valid_results:
|
||||
print("\n🏆 最高准确率的5个结果:")
|
||||
best_results = sorted(valid_results, key=lambda x: x["accuracy"], reverse=True)[
|
||||
:5
|
||||
]
|
||||
for i, r in enumerate(best_results, 1):
|
||||
print(
|
||||
f" {i}. 测试#{r['test_id']:3d}: 准确率{r['accuracy']:5.1f}% 长度{r['length']:2d} ({r['time_minutes']:4.1f}分钟)"
|
||||
)
|
||||
|
||||
print("\n⚠️ 最低准确率的5个结果:")
|
||||
worst_results = sorted(valid_results, key=lambda x: x["accuracy"])[:5]
|
||||
for i, r in enumerate(worst_results, 1):
|
||||
print(
|
||||
f" {i}. 测试#{r['test_id']:3d}: 准确率{r['accuracy']:5.1f}% 长度{r['length']:2d} ({r['time_minutes']:4.1f}分钟)"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 设置多进程启动方法(Linux上默认是fork,但spawn更安全)
|
||||
mp.set_start_method("spawn", force=True)
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⏹️ 测试被用户中断")
|
||||
# 清理可能的临时文件
|
||||
for f in Path(".").glob("test_fhe_output_*.jsonl"):
|
||||
f.unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
print(f"\n💥 程序异常: {e}")
|
||||
# 清理可能的临时文件
|
||||
for f in Path(".").glob("test_fhe_output_*.jsonl"):
|
||||
f.unlink(missing_ok=True)
|
||||
32
writeup.md
Normal file
32
writeup.md
Normal file
@@ -0,0 +1,32 @@
|
||||
# 0xfa队 writeup
|
||||
|
||||
## 全同态加密算法介绍
|
||||
|
||||
加密算法方面选择了thfe算法,库方面选择了较为成熟的thfe-rs算法。
|
||||
|
||||
### 算法参数
|
||||
|
||||
Message bits: 2位
|
||||
|
||||
Carry bits: 2位
|
||||
|
||||
噪声分布: TUniform (tweaked uniform)
|
||||
|
||||
Bootstrap失败概率: ≤ 2^-128 (CPU后端)
|
||||
|
||||
## knn算法实现细节
|
||||
|
||||
将欧式距离公式拆分:
|
||||
|
||||
> sum((a-b)^2)=sum(a^2) - sum(2a\*b) + sum(b^2)
|
||||
|
||||
减少了密文态的乘法和加法操作
|
||||
|
||||
选择上实现了双调排序,将100个距离结果,用最大值填充至128个结果。
|
||||
然后进行并行排序,最后选择前十个密文。
|
||||
|
||||
两个操作都使用了rayon库做多核并行计算
|
||||
|
||||
## 本地测试结果
|
||||
|
||||
在本地i9-10920X(12核24线程)情况下,运行时间约9min(4min+5min)
|
||||
Reference in New Issue
Block a user