fix: correct run.sh script parameters and disable HNSW code (v0.3.1)
- Fix syntax error in run.sh: remove extra quote and correct --log-path to --log-file
- Comment out HNSW algorithm implementation in enc.rs and algorithms.rs to simplify codebase
- Bump version to 0.3.1 in Cargo.toml
- Remove HNSW implementation guide and test files
- Add comprehensive project writeup documentation
This commit is contained in:
		
							
								
								
									
										2
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2
									
								
								Cargo.lock
									
									
									
										generated
									
									
									
								
							| @@ -410,7 +410,7 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" | ||||
|  | ||||
| [[package]] | ||||
| name = "hfe_knn" | ||||
| version = "0.3.0" | ||||
| version = "0.3.1" | ||||
| dependencies = [ | ||||
|  "anyhow", | ||||
|  "bincode 2.0.1", | ||||
|   | ||||
| @@ -1,6 +1,6 @@ | ||||
| [package] | ||||
| name = "hfe_knn" | ||||
| version = "0.3.0" | ||||
| version = "0.3.1" | ||||
| edition = "2024" | ||||
|  | ||||
| [dependencies] | ||||
|   | ||||
| @@ -1,143 +0,0 @@ | ||||
| # HNSW Search Layer 实现指南 | ||||
|  | ||||
| ## 目标 | ||||
| 实现标准的HNSW贪心搜索算法,但使用密文距离计算,匹配明文版本的逻辑和性能。 | ||||
|  | ||||
| ## 关键数据结构 | ||||
|  | ||||
| ### 输入参数 | ||||
| - `query: &EncryptedQuery<T>` - 加密的查询点 | ||||
| - `entry_points: Vec<usize>` - 入口点的节点索引列表 | ||||
| - `ef: usize` - 搜索时的候选集大小 | ||||
| - `layer: usize` - 当前搜索的层级 | ||||
| - `zero: &T` - 加密的零值(用于距离计算) | ||||
|  | ||||
| ### 内部数据结构建议 | ||||
| ```rust | ||||
| // 候选队列:存储待探索的节点 | ||||
| let mut candidates: Vec<(usize, EncryptedNeighbor<T>)> = Vec::new(); | ||||
| // 结果集:维护当前最好的ef个候选点   | ||||
| let mut w: Vec<(usize, EncryptedNeighbor<T>)> = Vec::new(); | ||||
| // 访问标记 | ||||
| let mut visited: HashSet<usize> = HashSet::new(); | ||||
| ``` | ||||
|  | ||||
| 其中 `EncryptedNeighbor<T>` 结构已定义: | ||||
| ```rust | ||||
| pub struct EncryptedNeighbor<T> { | ||||
|     pub distance: T,        // 密文距离 | ||||
|     pub index: FheUint8,    // 密文索引 | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 实现步骤 | ||||
|  | ||||
| ### Step 1: 初始化候选点 | ||||
| ```rust | ||||
| for &ep in &entry_points { | ||||
|     if ep < self.nodes.len() && self.nodes[ep].level >= layer { | ||||
|         visited.insert(ep); | ||||
|         let distance = euclidean_distance(query, &self.nodes[ep].encrypted_point, zero); | ||||
|         let neighbor = EncryptedNeighbor { | ||||
|             distance, | ||||
|             index: self.nodes[ep].encrypted_point.index.clone(), | ||||
|         }; | ||||
|         candidates.push((ep, neighbor.clone())); | ||||
|         w.push((ep, neighbor)); | ||||
|     } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ### Step 2: 主搜索循环 | ||||
| ```rust | ||||
| while !candidates.is_empty() { | ||||
|     // 2.1 找到距离最小的候选点 | ||||
|     // 提示:需要对candidates中的EncryptedNeighbor按distance排序 | ||||
|     // 可以使用 encrypted_selection_sort 或其他方法 | ||||
|      | ||||
|     // 2.2 移除最小距离的候选点作为当前探索点 | ||||
|     let current = /* 从candidates中移除最小距离点的节点索引 */; | ||||
|      | ||||
|     // 2.3 剪枝检查(可选,但会影响性能) | ||||
|     // 如果w.len() >= ef 且 current的距离 > w中最远点的距离,则break | ||||
|      | ||||
|     // 2.4 探索当前节点的邻居 | ||||
|     for &neighbor_idx in &self.nodes[current].neighbors[layer] { | ||||
|         if !visited.contains(&neighbor_idx) && neighbor_idx < self.nodes.len() { | ||||
|             visited.insert(neighbor_idx); | ||||
|             let distance = euclidean_distance(query, &self.nodes[neighbor_idx].encrypted_point, zero); | ||||
|             let encrypted_neighbor = EncryptedNeighbor { | ||||
|                 distance, | ||||
|                 index: self.nodes[neighbor_idx].encrypted_point.index.clone(), | ||||
|             }; | ||||
|              | ||||
|             // 加入候选队列 | ||||
|             candidates.push((neighbor_idx, encrypted_neighbor.clone())); | ||||
|              | ||||
|             // 管理结果集w | ||||
|             if w.len() < ef { | ||||
|                 w.push((neighbor_idx, encrypted_neighbor)); | ||||
|             } else { | ||||
|                 // 结果集已满,需要替换最远的点 | ||||
|                 w.push((neighbor_idx, encrypted_neighbor)); | ||||
|                 // 排序w,只保留前ef个最近的点 | ||||
|                 // 提示:可以先转换为Vec<EncryptedNeighbor>,排序后重建w | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ### Step 3: 返回结果 | ||||
| ```rust | ||||
| w.into_iter().map(|(node_idx, _)| node_idx).collect() | ||||
| ``` | ||||
|  | ||||
| ## 性能优化建议 | ||||
|  | ||||
| ### 1. 减少密文排序次数 | ||||
| - **问题**:每次排序都很昂贵(~2-3分钟) | ||||
| - **策略**: | ||||
|   - 只在必要时排序(如候选队列管理、结果集维护) | ||||
|   - 考虑批量处理而不是逐个比较 | ||||
|   - 可以适当牺牲一些算法精确性来换取性能 | ||||
|  | ||||
| ### 2. 候选队列管理 | ||||
| - **明文版本**:使用BinaryHeap,O(log n)插入和删除 | ||||
| - **密文版本**:只能用排序,O(n log n) | ||||
| - **优化**:考虑限制候选队列大小,避免无限增长 | ||||
|  | ||||
| ### 3. 剪枝策略 | ||||
| - **理想**:`current_distance > farthest_w_distance && w.len() >= ef` 则停止 | ||||
| - **现实**:密文比较结果无法直接判断 | ||||
| - **权衡**:可以跳过复杂剪枝,让算法更彻底但稍慢 | ||||
|  | ||||
| ## 调试提示 | ||||
|  | ||||
| ### 1. 验证初始化 | ||||
| 确保entry_points正确初始化到candidates和w中 | ||||
|  | ||||
| ### 2. 验证邻居探索 | ||||
| 检查是否正确遍历`self.nodes[current].neighbors[layer]` | ||||
|  | ||||
| ### 3. 验证visited逻辑 | ||||
| 确保不重复访问同一节点 | ||||
|  | ||||
| ### 4. 验证结果集管理 | ||||
| 确保w的大小不超过ef,且包含距离最近的点 | ||||
|  | ||||
| ## 期望性能目标 | ||||
|  | ||||
| - **明文版本**:毫秒级 | ||||
| - **密文版本目标**:15-20分钟(相比当前的100+分钟) | ||||
| - **准确率目标**:80%+(相比当前的30%) | ||||
|  | ||||
| ## 可用的工具函数 | ||||
|  | ||||
| - `euclidean_distance(query, point, zero)` - 计算密文欧几里得距离 | ||||
| - `encrypted_selection_sort(distances, k)` - 密文选择排序,获取前k个最小值 | ||||
| - `EncryptedNeighbor` - 包装距离和索引的结构体 | ||||
|  | ||||
| ## 明文版本参考 | ||||
|  | ||||
| 参考 `src/bin/plain.rs` 中的 `search_layer` 函数实现,理解标准HNSW算法的逻辑流程。 | ||||
							
								
								
									
										2
									
								
								run.sh
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								run.sh
									
									
									
									
									
								
							| @@ -20,4 +20,4 @@ chmod +x "${SCRIPT_DIR}/test" | ||||
| "${SCRIPT_DIR}/test" \ | ||||
| 	--dataset "$DATASET_FILE" \ | ||||
| 	--predictions "$PREDICTIONS_RESULT_FILE" \ | ||||
| 	--log-path "/home/admin/workspace/job/logs/user.log'" | ||||
| 	--log-file "/home/admin/workspace/job/logs/user.log" | ||||
|   | ||||
| @@ -1,5 +1,5 @@ | ||||
| use crate::EncryptedQuery; | ||||
| use crate::data::{EncryptedNeighbor, EncryptedPoint, FheHnswGraph}; | ||||
| use crate::data::{EncryptedNeighbor, EncryptedPoint}; | ||||
| use crate::logging::{format_duration, print_progress_bar}; | ||||
| use rayon::prelude::*; | ||||
| use std::time::Instant; | ||||
| @@ -322,93 +322,59 @@ fn encrypted_conditional_swap( | ||||
|     b.index = new_b_index; | ||||
| } | ||||
|  | ||||
| /// 执行HNSW近似最近邻搜索 | ||||
| /// | ||||
| /// # Arguments | ||||
| /// * `graph` - 预构建的FHE HNSW图结构 | ||||
| /// * `query` - 加密的查询点 | ||||
| /// * `k` - 返回的最近邻数量 | ||||
| /// * `zero` - 加密的零值 | ||||
| /// | ||||
| /// # Returns | ||||
| /// * k个最近邻的加密索引列表 | ||||
| pub fn perform_hnsw_search( | ||||
|     graph: &FheHnswGraph, | ||||
|     query: &EncryptedQuery, | ||||
|     k: usize, | ||||
|     zero: &FheInt14, | ||||
| ) -> Vec<FheUint8> { | ||||
|     println!("🚀 Starting HNSW approximate search..."); | ||||
|  | ||||
|     if graph.nodes.is_empty() { | ||||
|         println!("❌ Empty HNSW graph"); | ||||
|         return Vec::new(); | ||||
|     } | ||||
|  | ||||
|     let Some(entry_point) = graph.entry_point else { | ||||
|         println!("❌ No entry point in HNSW graph"); | ||||
|         return Vec::new(); | ||||
|     }; | ||||
|  | ||||
|     println!( | ||||
|         "🔍 HNSW search from entry point {} at level {}", | ||||
|         entry_point, graph.max_level | ||||
|     ); | ||||
|  | ||||
|     let mut current_candidates = vec![entry_point]; | ||||
|  | ||||
|     // 从最高层逐层搜索到第1层 | ||||
|     for layer in (1..=graph.max_level).rev() { | ||||
|         println!("🔍 Searching layer {layer} with ef=1..."); | ||||
|         let layer_start = Instant::now(); | ||||
|         current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero); | ||||
|         println!( | ||||
|             "✅ Layer {} search completed in {}, {} candidates", | ||||
|             layer, | ||||
|             format_duration(layer_start.elapsed()), | ||||
|             current_candidates.len() | ||||
|         ); | ||||
|     } | ||||
|  | ||||
|     // 在第0层进行最终搜索 | ||||
|     println!("🔍 Final search at layer 0 with ef={}...", 10); | ||||
|     let final_search_start = Instant::now(); | ||||
|     let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero); | ||||
|     println!( | ||||
|         "✅ Final layer search completed in {}, {} candidates", | ||||
|         format_duration(final_search_start.elapsed()), | ||||
|         final_candidates.len() | ||||
|     ); | ||||
|  | ||||
|     // 计算最终候选点的距离并排序 | ||||
|     println!("🔢 Computing distances for final candidates..."); | ||||
|     let distance_start = Instant::now(); | ||||
|     let mut distances = Vec::new(); | ||||
|     for (i, &candidate) in final_candidates.iter().enumerate().take(graph.nodes.len()) { | ||||
|         if candidate < graph.nodes.len() { | ||||
|             if i % 10 == 0 && i > 0 { | ||||
|                 println!( | ||||
|                     "🔢 Processed {}/{} final candidates", | ||||
|                     i, | ||||
|                     final_candidates.len().min(graph.nodes.len()) | ||||
|                 ); | ||||
|             } | ||||
|             let distance = euclidean_distance(query, &graph.nodes[candidate].encrypted_point, zero); | ||||
|             distances.push(EncryptedNeighbor { | ||||
|                 distance, | ||||
|                 index: graph.nodes[candidate].encrypted_point.index.clone(), | ||||
|             }); | ||||
|         } | ||||
|     } | ||||
|     println!( | ||||
|         "✅ Distance computation completed in {}", | ||||
|         format_duration(distance_start.elapsed()) | ||||
|     ); | ||||
|  | ||||
|     // 选择最好的k个 | ||||
|     println!("📊 Selecting top {} from {} candidates", k, distances.len()); | ||||
|     let len = distances.len(); | ||||
|     encrypted_selection_sort(&mut distances, k.min(len)); | ||||
|  | ||||
|     distances.iter().take(k).map(|n| n.index.clone()).collect() | ||||
| } | ||||
| ///// 执行HNSW近似最近邻搜索 | ||||
| ///// | ||||
| ///// # Arguments | ||||
| ///// * `graph` - 预构建的FHE HNSW图结构 | ||||
| ///// * `query` - 加密的查询点 | ||||
| ///// * `zero` - 加密的零值 | ||||
| ///// | ||||
| ///// # Returns | ||||
| ///// * 10个最近邻的索引列表 | ||||
| //pub fn perform_hnsw_search( | ||||
| //    graph: &FheHnswGraph, | ||||
| //    query: &EncryptedQuery, | ||||
| //    zero: &FheInt14, | ||||
| //) -> Vec<usize> { | ||||
| //    println!("🚀 Starting HNSW approximate search..."); | ||||
| // | ||||
| //    if graph.nodes.is_empty() { | ||||
| //        println!("❌ Empty HNSW graph"); | ||||
| //        return Vec::new(); | ||||
| //    } | ||||
| // | ||||
| //    let Some(entry_point) = graph.entry_point else { | ||||
| //        println!("❌ No entry point in HNSW graph"); | ||||
| //        return Vec::new(); | ||||
| //    }; | ||||
| // | ||||
| //    println!( | ||||
| //        "🔍 HNSW search from entry point {} at level {}", | ||||
| //        entry_point, graph.max_level | ||||
| //    ); | ||||
| // | ||||
| //    let mut current_candidates = vec![entry_point]; | ||||
| // | ||||
| //    // 从最高层逐层搜索到第1层 | ||||
| //    for layer in (1..=graph.max_level).rev() { | ||||
| //        println!("🔍 Searching layer {layer} with ef=1..."); | ||||
| //        let layer_start = Instant::now(); | ||||
| //        current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero); | ||||
| //        println!( | ||||
| //            "✅ Layer {} search completed in {}, {} candidates", | ||||
| //            layer, | ||||
| //            format_duration(layer_start.elapsed()), | ||||
| //            current_candidates.len() | ||||
| //        ); | ||||
| //    } | ||||
| // | ||||
| //    // 在第0层进行最终搜索 | ||||
| //    let final_search_start = Instant::now(); | ||||
| //    let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero); | ||||
| //    println!( | ||||
| //        "✅ Final layer search completed in {}, {} candidates", | ||||
| //        format_duration(final_search_start.elapsed()), | ||||
| //        final_candidates.len() | ||||
| //    ); | ||||
| //    final_candidates | ||||
| //} | ||||
|   | ||||
							
								
								
									
										207
									
								
								src/bin/enc.rs
									
									
									
									
									
								
							
							
						
						
									
										207
									
								
								src/bin/enc.rs
									
									
									
									
									
								
							| @@ -2,7 +2,6 @@ use anyhow::Result; | ||||
| use chrono::Local; | ||||
| use clap::Parser; | ||||
| use log::info; | ||||
| use rand::Rng; | ||||
| use std::fs::File; | ||||
| use std::io::{BufRead, BufReader, Write}; | ||||
| use std::time::Instant; | ||||
| @@ -11,9 +10,9 @@ use tfhe::{ConfigBuilder, FheInt14, generate_keys, set_server_key}; | ||||
|  | ||||
| // Import from our library modules | ||||
| use hfe_knn::{ | ||||
|     Dataset, EncryptedNeighbor, EncryptedPoint, PlaintextHnswNode, Prediction, ScaleInt, | ||||
|     build_fhe_hnsw_from_plaintext, compute_distances, decrypt_indices, encrypt_point, | ||||
|     encrypt_query, format_duration, perform_hnsw_search, perform_knn_selection, print_progress_bar, | ||||
|     Dataset, EncryptedNeighbor, EncryptedPoint, Prediction, ScaleInt, compute_distances, | ||||
|     decrypt_indices, encrypt_point, encrypt_query, format_duration, perform_knn_selection, | ||||
|     print_progress_bar, | ||||
| }; | ||||
|  | ||||
| #[derive(Parser)] | ||||
| @@ -87,97 +86,6 @@ fn debug_compute_distances( | ||||
|     encrypted_distances | ||||
| } | ||||
|  | ||||
| /// 构建明文HNSW图 | ||||
| fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) { | ||||
|     println!("🔨 Building HNSW graph from {} points...", data.len()); | ||||
|  | ||||
|     let max_connections = 8; // 最优参数:减少邻居数以降低密文运算量 | ||||
|     let max_level = 3; // 最优参数:保持3层结构 | ||||
|     let mut nodes = Vec::new(); | ||||
|     let mut entry_point = None; | ||||
|     let mut current_max_level = 0; | ||||
|  | ||||
|     // 创建所有节点 | ||||
|     for (idx, vector) in data.iter().enumerate() { | ||||
|         let level = select_level_for_node(max_level); | ||||
|         current_max_level = current_max_level.max(level); | ||||
|  | ||||
|         let node = PlaintextHnswNode { | ||||
|             vector: vector.clone(), | ||||
|             level, | ||||
|             neighbors: vec![Vec::new(); level + 1], | ||||
|         }; | ||||
|         nodes.push(node); | ||||
|  | ||||
|         if idx == 0 { | ||||
|             entry_point = Some(idx); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // 为每个节点建立连接(简化版实现) | ||||
|     for node_id in 0..nodes.len() { | ||||
|         print_progress_bar(node_id + 1, nodes.len(), "Building connections"); | ||||
|  | ||||
|         let node_vector = nodes[node_id].vector.clone(); | ||||
|         let node_level = nodes[node_id].level; | ||||
|  | ||||
|         // 为每层找到最近的邻居 | ||||
|         for layer in 0..=node_level { | ||||
|             let mut distances: Vec<(f64, usize)> = nodes | ||||
|                 .iter() | ||||
|                 .enumerate() | ||||
|                 .filter(|(idx, n)| *idx != node_id && n.level >= layer) | ||||
|                 .map(|(idx, n)| { | ||||
|                     let dist = euclidean_distance_plaintext(&node_vector, &n.vector); | ||||
|                     (dist, idx) | ||||
|                 }) | ||||
|                 .collect(); | ||||
|  | ||||
|             distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); | ||||
|  | ||||
|             // 添加最近的邻居(限制连接数) | ||||
|             for &(_, neighbor_id) in distances.iter().take(max_connections) { | ||||
|                 nodes[node_id].neighbors[layer].push(neighbor_id); | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         // 更新入口点 | ||||
|         if nodes[node_id].level > nodes[entry_point.unwrap()].level { | ||||
|             entry_point = Some(node_id); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     println!(); // 新行 | ||||
|     println!( | ||||
|         "✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}", | ||||
|         nodes.len(), | ||||
|         entry_point, | ||||
|         current_max_level | ||||
|     ); | ||||
|  | ||||
|     (nodes, entry_point, current_max_level) | ||||
| } | ||||
|  | ||||
| /// 选择节点的层级 | ||||
| fn select_level_for_node(max_level: usize) -> usize { | ||||
|     let mut rng = rand::rng(); | ||||
|     let mut level = 0; | ||||
|     while level < max_level && rng.random::<f32>() < 0.6 { | ||||
|         // 最优参数:提高层级概率到0.6 | ||||
|         level += 1; | ||||
|     } | ||||
|     level | ||||
| } | ||||
|  | ||||
| /// 明文欧几里得距离计算 | ||||
| fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 { | ||||
|     a.iter() | ||||
|         .zip(b.iter()) | ||||
|         .map(|(x, y)| (x - y).powi(2)) | ||||
|         .sum::<f64>() | ||||
|         .sqrt() | ||||
| } | ||||
|  | ||||
| fn main() -> Result<()> { | ||||
|     let args = Args::parse(); | ||||
|     let start_time = Instant::now(); | ||||
| @@ -235,67 +143,62 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan | ||||
|         let query_encrypted = encrypt_query(&dataset.query, client_key); | ||||
|  | ||||
|         let k = 10; // Number of nearest neighbors | ||||
|         let encrypted_neighbors = if args.algorithm == "hnsw" { | ||||
|             // HNSW 算法路径 | ||||
|             println!("🚀 Using HNSW algorithm..."); | ||||
|         // if args.algorithm == "hnsw" { | ||||
|         //     // HNSW 算法路径 | ||||
|         //     println!("🚀 Using HNSW algorithm..."); | ||||
|         // | ||||
|         //     // 1. 构建明文HNSW图 | ||||
|         //     let (plaintext_nodes, entry_point, max_level) = | ||||
|         //         build_plaintext_hnsw_graph(&dataset.data); | ||||
|         // | ||||
|         //     // 2. 转换为加密HNSW图 | ||||
|         //     println!("🔐 Converting to encrypted HNSW graph..."); | ||||
|         //     let fhe_graph = | ||||
|         //         build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key); | ||||
|         // | ||||
|         //     // 3. 执行HNSW搜索 | ||||
|         //     let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap(); | ||||
|         //     let answer = perform_hnsw_search(&fhe_graph, &query_encrypted, &encrypted_zero); | ||||
|         //     results.push(Prediction { answer }); | ||||
|         // } else { | ||||
|         // 传统算法路径 | ||||
|         // Encrypt all training points | ||||
|         println!("🔐 Encrypting training points..."); | ||||
|         let points_encrypted: Vec<EncryptedPoint> = dataset | ||||
|             .data | ||||
|             .iter() | ||||
|             .enumerate() | ||||
|             .map(|(idx, coords)| encrypt_point(coords, idx + 1, client_key)) | ||||
|             .collect(); | ||||
|  | ||||
|             // 1. 构建明文HNSW图 | ||||
|             let (plaintext_nodes, entry_point, max_level) = | ||||
|                 build_plaintext_hnsw_graph(&dataset.data); | ||||
|  | ||||
|             // 2. 转换为加密HNSW图 | ||||
|             println!("🔐 Converting to encrypted HNSW graph..."); | ||||
|             let fhe_graph = | ||||
|                 build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key); | ||||
|  | ||||
|             // 3. 执行HNSW搜索 | ||||
|             let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap(); | ||||
|             perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero) | ||||
|         // Compute distances | ||||
|         let mut distances = if args.debug { | ||||
|             debug_compute_distances(&dataset.query, &dataset.data, &points_encrypted, client_key) | ||||
|         } else { | ||||
|             // 传统算法路径 | ||||
|             // Encrypt all training points | ||||
|             println!("🔐 Encrypting training points..."); | ||||
|             let points_encrypted: Vec<EncryptedPoint> = dataset | ||||
|                 .data | ||||
|                 .iter() | ||||
|                 .enumerate() | ||||
|                 .map(|(idx, coords)| encrypt_point(coords, idx + 1, client_key)) | ||||
|                 .collect(); | ||||
|  | ||||
|             // Compute distances | ||||
|             let mut distances = if args.debug { | ||||
|                 debug_compute_distances( | ||||
|                     &dataset.query, | ||||
|                     &dataset.data, | ||||
|                     &points_encrypted, | ||||
|                     client_key, | ||||
|                 ) | ||||
|             } else { | ||||
|                 let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap(); | ||||
|                 compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero) | ||||
|             }; | ||||
|  | ||||
|             // Perform KNN selection using the specified algorithm | ||||
|             let max_distance = if args.algorithm == "bitonic" { | ||||
|                 Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // FheInt14的正确最大值 | ||||
|             } else { | ||||
|                 None | ||||
|             }; | ||||
|             let max_index = if args.algorithm == "bitonic" { | ||||
|                 Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap()) | ||||
|             } else { | ||||
|                 None | ||||
|             }; | ||||
|  | ||||
|             perform_knn_selection( | ||||
|                 &mut distances, | ||||
|                 k, | ||||
|                 &args.algorithm, | ||||
|                 max_distance.as_ref(), | ||||
|                 max_index.as_ref(), | ||||
|             ) | ||||
|             let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap(); | ||||
|             compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero) | ||||
|         }; | ||||
|  | ||||
|         // Perform KNN selection using the specified algorithm | ||||
|         let max_distance = if args.algorithm == "bitonic" { | ||||
|             Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // FheInt14的正确最大值 | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
|         let max_index = if args.algorithm == "bitonic" { | ||||
|             Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap()) | ||||
|         } else { | ||||
|             None | ||||
|         }; | ||||
|  | ||||
|         let encrypted_neighbors = perform_knn_selection( | ||||
|             &mut distances, | ||||
|             k, | ||||
|             &args.algorithm, | ||||
|             max_distance.as_ref(), | ||||
|             max_index.as_ref(), | ||||
|         ); | ||||
|  | ||||
|         // Decrypt the results | ||||
|         println!("🔓 Decrypting results..."); | ||||
|         let decrypted_indices = decrypt_indices(&encrypted_neighbors, client_key); | ||||
|   | ||||
							
								
								
									
										117
									
								
								src/bin/plain.rs
									
									
									
									
									
								
							
							
						
						
									
										117
									
								
								src/bin/plain.rs
									
									
									
									
									
								
							| @@ -52,7 +52,8 @@ struct Prediction { | ||||
|  | ||||
| fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { | ||||
|     // 模拟加密版本的缩放和精度损失 | ||||
|     let scaled_distance: f64 = a.iter() | ||||
|     let scaled_distance: f64 = a | ||||
|         .iter() | ||||
|         .zip(b.iter()) | ||||
|         .map(|(x, y)| { | ||||
|             // 缩放坐标(乘以10,然后舍弃小数) | ||||
| @@ -61,7 +62,7 @@ fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { | ||||
|             (scaled_x - scaled_y).powi(2) | ||||
|         }) | ||||
|         .sum::<f64>(); | ||||
|      | ||||
|  | ||||
|     scaled_distance | ||||
| } | ||||
|  | ||||
| @@ -280,78 +281,80 @@ impl HNSWGraph { | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn main() -> Result<()> { | ||||
|     let args = Args::parse(); | ||||
|  | ||||
|     let file = File::open(&args.dataset)?; | ||||
|     let reader = BufReader::new(file); | ||||
|  | ||||
|     let mut results = Vec::new(); | ||||
|  | ||||
|     for line in reader.lines() { | ||||
|         let line = line?; | ||||
|         let dataset: Dataset = serde_json::from_str(&line)?; | ||||
|  | ||||
|         let nearest = knn_classify(&dataset.query, &dataset.data, 10); | ||||
|         results.push(Prediction { answer: nearest }); | ||||
|     } | ||||
|  | ||||
|     let mut output_file = File::create(&args.predictions)?; | ||||
|     for result in results { | ||||
|         writeln!(output_file, "{}", serde_json::to_string(&result)?)?; | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
| // fn main() -> Result<()> { | ||||
| //     let args = Args::parse(); | ||||
| // | ||||
| //     let file = File::open(&args.dataset)?; | ||||
| //     let reader = BufReader::new(file); | ||||
| //     let mut results = Vec::new(); | ||||
| // | ||||
| //     println!("🔧 HNSW Parameters:"); | ||||
| //     println!("   max_connections: {}", args.max_connections); | ||||
| //     println!("   max_level: {}", args.max_level); | ||||
| //     println!("   level_prob: {}", args.level_prob); | ||||
| //     println!("   ef_upper: {}", args.ef_upper); | ||||
| //     println!("   ef_bottom: {}", args.ef_bottom); | ||||
| //     let mut results = Vec::new(); | ||||
| // | ||||
| //     for line in reader.lines() { | ||||
| //         let line = line?; | ||||
| //         let dataset: Dataset = serde_json::from_str(&line)?; | ||||
| // | ||||
| //         let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections); | ||||
| //         for data_point in &dataset.data { | ||||
| //             hnsw.insert_node(data_point.clone(), args.level_prob); | ||||
| //         } | ||||
| // | ||||
| //         let nearest = if let Some(entry_point) = hnsw.entry_point { | ||||
| //             let mut search_results = vec![entry_point]; | ||||
| // | ||||
| //             // 上层搜索使用ef_upper参数 | ||||
| //             for layer in (1..=hnsw.nodes[entry_point].level).rev() { | ||||
| //                 search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer); | ||||
| //             } | ||||
| // | ||||
| //             // 底层搜索使用ef_bottom参数 | ||||
| //             let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0); | ||||
| //             final_results | ||||
| //                 .into_iter() | ||||
| //                 .take(10) | ||||
| //                 .map(|idx| idx + 1) | ||||
| //                 .collect() | ||||
| //         } else { | ||||
| //             Vec::new() | ||||
| //         }; | ||||
| // | ||||
| //         let nearest = knn_classify(&dataset.query, &dataset.data, 10); | ||||
| //         results.push(Prediction { answer: nearest }); | ||||
| //     } | ||||
| // | ||||
| //     let mut output_file = File::create(&args.predictions)?; | ||||
| //     for result in results { | ||||
| //         writeln!(output_file, "{}", serde_json::to_string(&result)?)?; | ||||
| //         println!("{}", serde_json::to_string(&result)?); | ||||
| //     } | ||||
| // | ||||
| //     Ok(()) | ||||
| // } | ||||
| fn main() -> Result<()> { | ||||
|     let args = Args::parse(); | ||||
|     let file = File::open(&args.dataset)?; | ||||
|     let reader = BufReader::new(file); | ||||
|     let mut results = Vec::new(); | ||||
|  | ||||
|     println!("🔧 HNSW Parameters:"); | ||||
|     println!("   max_connections: {}", args.max_connections); | ||||
|     println!("   max_level: {}", args.max_level); | ||||
|     println!("   level_prob: {}", args.level_prob); | ||||
|     println!("   ef_upper: {}", args.ef_upper); | ||||
|     println!("   ef_bottom: {}", args.ef_bottom); | ||||
|  | ||||
|     for line in reader.lines() { | ||||
|         let line = line?; | ||||
|         let dataset: Dataset = serde_json::from_str(&line)?; | ||||
|  | ||||
|         let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections); | ||||
|         for data_point in &dataset.data { | ||||
|             hnsw.insert_node(data_point.clone(), args.level_prob); | ||||
|         } | ||||
|  | ||||
|         let nearest = if let Some(entry_point) = hnsw.entry_point { | ||||
|             let mut search_results = vec![entry_point]; | ||||
|  | ||||
|             // 上层搜索使用ef_upper参数 | ||||
|             for layer in (1..=hnsw.nodes[entry_point].level).rev() { | ||||
|                 search_results = | ||||
|                     hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer); | ||||
|             } | ||||
|  | ||||
|             // 底层搜索使用ef_bottom参数 | ||||
|             let final_results = | ||||
|                 hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0); | ||||
|             final_results | ||||
|                 .into_iter() | ||||
|                 .take(10) | ||||
|                 .map(|idx| idx + 1) | ||||
|                 .collect() | ||||
|         } else { | ||||
|             Vec::new() | ||||
|         }; | ||||
|  | ||||
|         results.push(Prediction { answer: nearest }); | ||||
|     } | ||||
|  | ||||
|     let mut output_file = File::create(&args.predictions)?; | ||||
|     for result in results { | ||||
|         writeln!(output_file, "{}", serde_json::to_string(&result)?)?; | ||||
|         println!("{}", serde_json::to_string(&result)?); | ||||
|     } | ||||
|  | ||||
|     Ok(()) | ||||
| } | ||||
|   | ||||
							
								
								
									
										410
									
								
								src/data.rs
									
									
									
									
									
								
							
							
						
						
									
										410
									
								
								src/data.rs
									
									
									
									
									
								
							| @@ -1,5 +1,4 @@ | ||||
| use serde::{Deserialize, Serialize}; | ||||
| use std::collections::HashSet; | ||||
| use tfhe::prelude::*; | ||||
| use tfhe::{ClientKey, FheInt14, FheUint8}; | ||||
|  | ||||
| @@ -130,165 +129,250 @@ pub fn decrypt_indices(encrypted_indices: &[FheUint8], client_key: &ClientKey) - | ||||
|         .collect() | ||||
| } | ||||
|  | ||||
| /// 密文版 HNSW 节点 | ||||
| #[derive(Clone)] | ||||
| pub struct FheHnswNode { | ||||
|     pub encrypted_point: EncryptedPoint, | ||||
|     pub level: usize, | ||||
|     pub neighbors: Vec<Vec<usize>>, // 明文邻居索引(预处理时确定) | ||||
| } | ||||
|  | ||||
| /// 密文版 HNSW 图 | ||||
| #[derive(Clone)] | ||||
| pub struct FheHnswGraph { | ||||
|     pub nodes: Vec<FheHnswNode>, | ||||
|     pub entry_point: Option<usize>, | ||||
|     pub max_level: usize, | ||||
| } | ||||
|  | ||||
| impl FheHnswGraph { | ||||
|     pub fn new() -> Self { | ||||
|         Self { | ||||
|             nodes: Vec::new(), | ||||
|             entry_point: None, | ||||
|             max_level: 0, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     /// 密文版搜索层函数 | ||||
|     /// | ||||
|     /// 核心算法:实现标准HNSW贪心搜索,但使用密文距离计算 | ||||
|     /// | ||||
|     /// 算法流程: | ||||
|     /// 1. 初始化候选队列(candidates)和结果集(w) | ||||
|     /// 2. 从entry_points开始,计算到query的密文距离 | ||||
|     /// 3. 贪心搜索循环: | ||||
|     ///    - 从candidates中选择距离最小的点作为当前探索点 | ||||
|     ///    - 探索该点的所有邻居 | ||||
|     ///    - 将未访问的邻居加入candidates和w | ||||
|     ///    - 维护w的大小不超过ef(移除最远的点) | ||||
|     ///    - 实现剪枝:如果当前点比w中最远点还远且w已满,则停止 | ||||
|     /// 4. 返回w中的节点索引 | ||||
|     /// | ||||
|     /// 关键挑战: | ||||
|     /// - 需要模拟优先队列但只能用密文排序 | ||||
|     /// - 剪枝条件难以在密文下判断,需要权衡准确性和性能 | ||||
|     /// - 候选队列管理要高效,避免过多密文运算 | ||||
|     pub fn search_layer( | ||||
|         &self, | ||||
|         _query: &EncryptedQuery, | ||||
|         entry_points: Vec<usize>, | ||||
|         ef: usize, | ||||
|         layer: usize, | ||||
|         _zero: &FheInt14, | ||||
|     ) -> Vec<usize> { | ||||
|         let _visited: HashSet<usize> = HashSet::new(); | ||||
|  | ||||
|         // TODO: 设计合适的数据结构 | ||||
|         // candidates: 候选队列,存储(节点索引, 密文距离)或类似结构 | ||||
|         // w: 结果集,维护当前找到的最好的ef个候选点 | ||||
|  | ||||
|         println!( | ||||
|             "🔍 Initializing search layer {} with {} entry points", | ||||
|             layer, | ||||
|             entry_points.len() | ||||
|         ); | ||||
|  | ||||
|         // TODO: Step 1 - 初始化候选点 | ||||
|         // 对每个entry_point: | ||||
|         //   - 检查节点是否存在且level >= layer | ||||
|         //   - 计算到query的密文距离:euclidean_distance(query, &self.nodes[ep].encrypted_point, zero) | ||||
|         //   - 将节点加入visited集合 | ||||
|         //   - 将(节点索引, 距离)加入candidates和w | ||||
|  | ||||
|         println!( | ||||
|             "🔍 Starting search with {} initial candidates", | ||||
|             // TODO: 显示实际初始化的候选点数量 | ||||
|             0 | ||||
|         ); | ||||
|  | ||||
|         // TODO: Step 2 - 主搜索循环 | ||||
|         // while candidates不为空: | ||||
|         //   - 从candidates中找到距离最小的点(需要密文排序或其他方法) | ||||
|         //   - 将该点从candidates中移除,设为当前探索点current | ||||
|         // | ||||
|         //   - 剪枝检查(可选,复杂): | ||||
|         //     如果w.len() >= ef且current的距离 > w中最远点的距离,则break | ||||
|         // | ||||
|         //   - 探索current的邻居: | ||||
|         //     for neighbor in self.nodes[current].neighbors[layer]: | ||||
|         //       - 如果neighbor未访问过: | ||||
|         //         - 标记为已访问 | ||||
|         //         - 计算到query的密文距离 | ||||
|         //         - 将neighbor加入candidates | ||||
|         //         - 管理结果集w: | ||||
|         //           if w.len() < ef: 直接加入w | ||||
|         //           else: 加入w后排序,移除最远的点,保持w大小为ef | ||||
|  | ||||
|         println!( | ||||
|             "🔍 Found {} total candidates (ef={})", | ||||
|             // TODO: 显示最终找到的候选点数量 | ||||
|             0, | ||||
|             ef | ||||
|         ); | ||||
|  | ||||
|         // TODO: Step 3 - 返回结果 | ||||
|         // 从w中提取节点索引并返回 | ||||
|         // w.into_iter().map(|(node_idx, _)| node_idx).collect() | ||||
|  | ||||
|         // 临时返回空结果,避免编译错误 | ||||
|         Vec::new() | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl Default for FheHnswGraph { | ||||
|     fn default() -> Self { | ||||
|         Self::new() | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// 从明文数据构建密文 HNSW 图的辅助结构 | ||||
| #[derive(Clone)] | ||||
| pub struct PlaintextHnswNode { | ||||
|     pub vector: Vec<f64>, | ||||
|     pub level: usize, | ||||
|     pub neighbors: Vec<Vec<usize>>, | ||||
| } | ||||
|  | ||||
| /// 从明文数据构建密文 HNSW 图 | ||||
| pub fn build_fhe_hnsw_from_plaintext( | ||||
|     plaintext_nodes: &[PlaintextHnswNode], | ||||
|     plaintext_entry_point: Option<usize>, | ||||
|     plaintext_max_level: usize, | ||||
|     client_key: &ClientKey, | ||||
| ) -> FheHnswGraph { | ||||
|     use crate::logging::print_progress_bar; | ||||
|  | ||||
|     let mut fhe_nodes = Vec::new(); | ||||
|     let total_nodes = plaintext_nodes.len(); | ||||
|  | ||||
|     println!("🔐 Encrypting {total_nodes} HNSW nodes..."); | ||||
|  | ||||
|     for (idx, plain_node) in plaintext_nodes.iter().enumerate() { | ||||
|         print_progress_bar(idx + 1, total_nodes, "Encrypting nodes"); | ||||
|         let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key); | ||||
|         let fhe_node = FheHnswNode { | ||||
|             encrypted_point, | ||||
|             level: plain_node.level, | ||||
|             neighbors: plain_node.neighbors.clone(), | ||||
|         }; | ||||
|         fhe_nodes.push(fhe_node); | ||||
|     } | ||||
|  | ||||
|     println!(); // 新行 | ||||
|     println!( | ||||
|         "✅ FHE HNSW graph created with {} encrypted nodes", | ||||
|         fhe_nodes.len() | ||||
|     ); | ||||
|  | ||||
|     FheHnswGraph { | ||||
|         nodes: fhe_nodes, | ||||
|         entry_point: plaintext_entry_point, | ||||
|         max_level: plaintext_max_level, | ||||
|     } | ||||
| } | ||||
| ///// 密文版 HNSW 节点 | ||||
| //#[derive(Clone)] | ||||
| //pub struct FheHnswNode { | ||||
| //    pub encrypted_point: EncryptedPoint, | ||||
| //    pub level: usize, | ||||
| //    pub neighbors: Vec<Vec<usize>>, // 明文邻居索引(预处理时确定) | ||||
| //} | ||||
| // | ||||
| ///// 密文版 HNSW 图 | ||||
| //#[derive(Clone)] | ||||
| //pub struct FheHnswGraph { | ||||
| //    pub nodes: Vec<FheHnswNode>, | ||||
| //    pub entry_point: Option<usize>, | ||||
| //    pub max_level: usize, | ||||
| //    pub distances: [Option<FheInt14>; 100], | ||||
| //} | ||||
| // | ||||
| //impl FheHnswGraph { | ||||
| //    pub fn new() -> Self { | ||||
| //        Self { | ||||
| //            nodes: Vec::new(), | ||||
| //            entry_point: None, | ||||
| //            max_level: 0, | ||||
| //            distances: [const { None };100] | ||||
| //        } | ||||
| //    } | ||||
| // | ||||
| //    /// 搜索层函数 | ||||
| //    /// | ||||
| //    /// 核心算法:实现标准HNSW贪心搜索 | ||||
| //    /// | ||||
| //    /// 算法流程: | ||||
| //    /// 1. 初始化候选队列(candidates)和结果集(w) | ||||
| //    /// 2. 从entry_points开始,计算到query的密文距离 | ||||
| //    /// 3. 贪心搜索循环: | ||||
| //    ///    - 从candidates中选择距离最小的点作为当前探索点 | ||||
| //    ///    - 探索该点的所有邻居 | ||||
| //    ///    - 将未访问的邻居加入candidates和w | ||||
| //    ///    - 维护w的大小不超过ef(移除最远的点) | ||||
| //    ///    - 实现剪枝:如果当前点比w中最远点还远且w已满,则停止 | ||||
| //    /// 4. 返回w中的节点索引 | ||||
| //    /// | ||||
| //    /// 关键挑战: | ||||
| //    /// - 需要模拟优先队列但只能用密文排序 | ||||
| //    /// - 剪枝条件难以在密文下判断,需要权衡准确性和性能 | ||||
| //    /// - 候选队列管理要高效,避免过多密文运算 | ||||
| //    pub fn search_layer( | ||||
| //        &self, | ||||
| //        query: &EncryptedQuery, | ||||
| //        entry_points: Vec<usize>, | ||||
| //        ef: usize, | ||||
| //        layer: usize, | ||||
| //        zero: &FheInt14, | ||||
| //    ) -> Vec<usize> { | ||||
| //        let visited: HashSet<usize> = HashSet::new(); | ||||
| //        let mut candidate = | ||||
| // | ||||
| //        // TODO: 设计合适的数据结构 | ||||
| //        // candidates: 候选队列,存储(节点索引, 密文距离)或类似结构 | ||||
| //        // w: 结果集,维护当前找到的最好的ef个候选点 | ||||
| // | ||||
| //        // TODO: Step 1 - 初始化候选点 | ||||
| //        // 对每个entry_point: | ||||
| //        //   - 检查节点是否存在且level >= layer | ||||
| //        //   - 计算到query的密文距离:euclidean_distance(query, &self.nodes[ep].encrypted_point, zero) | ||||
| //        //   - 将节点加入visited集合 | ||||
| //        //   - 将(节点索引, 距离)加入candidates和w | ||||
| // | ||||
| //        println!( | ||||
| //            "🔍 Starting search with {} initial candidates", | ||||
| //            // TODO: 显示实际初始化的候选点数量 | ||||
| //            0 | ||||
| //        ); | ||||
| // | ||||
| //        // TODO: Step 2 - 主搜索循环 | ||||
| //        // while candidates不为空: | ||||
| //        //   - 从candidates中找到距离最小的点(需要密文排序或其他方法) | ||||
| //        //   - 将该点从candidates中移除,设为当前探索点current | ||||
| //        // | ||||
| //        //   - 剪枝检查(可选,复杂): | ||||
| //        //     如果w.len() >= ef且current的距离 > w中最远点的距离,则break | ||||
| //        // | ||||
| //        //   - 探索current的邻居: | ||||
| //        //     for neighbor in self.nodes[current].neighbors[layer]: | ||||
| //        //       - 如果neighbor未访问过: | ||||
| //        //         - 标记为已访问 | ||||
| //        //         - 计算到query的密文距离 | ||||
| //        //         - 将neighbor加入candidates | ||||
| //        //         - 管理结果集w: | ||||
| //        //           if w.len() < ef: 直接加入w | ||||
| //        //           else: 加入w后排序,移除最远的点,保持w大小为ef | ||||
| // | ||||
| //        println!( | ||||
| //            "🔍 Found {} total candidates (ef={})", | ||||
| //            // TODO: 显示最终找到的候选点数量 | ||||
| //            0, | ||||
| //            ef | ||||
| //        ); | ||||
| // | ||||
| //        // TODO: Step 3 - 返回结果 | ||||
| //        // 从w中提取节点索引并返回 | ||||
| //        // w.into_iter().map(|(node_idx, _)| node_idx).collect() | ||||
| // | ||||
| //        // 临时返回空结果,避免编译错误 | ||||
| //        Vec::new() | ||||
| //    } | ||||
| //} | ||||
| // | ||||
| //impl Default for FheHnswGraph { | ||||
| //    fn default() -> Self { | ||||
| //        Self::new() | ||||
| //    } | ||||
| //} | ||||
| // | ||||
| ///// 从明文数据构建密文 HNSW 图的辅助结构 | ||||
| //#[derive(Clone)] | ||||
| //pub struct PlaintextHnswNode { | ||||
| //    pub vector: Vec<f64>, | ||||
| //    pub level: usize, | ||||
| //    pub neighbors: Vec<Vec<usize>>, | ||||
| //} | ||||
| // | ||||
| ///// 从明文数据构建密文 HNSW 图 | ||||
| //pub fn build_fhe_hnsw_from_plaintext( | ||||
| //    plaintext_nodes: &[PlaintextHnswNode], | ||||
| //    plaintext_entry_point: Option<usize>, | ||||
| //    plaintext_max_level: usize, | ||||
| //    client_key: &ClientKey, | ||||
| //) -> FheHnswGraph { | ||||
| //    use crate::logging::print_progress_bar; | ||||
| // | ||||
| //    let mut fhe_nodes = Vec::new(); | ||||
| //    let total_nodes = plaintext_nodes.len(); | ||||
| // | ||||
| //    println!("🔐 Encrypting {total_nodes} HNSW nodes..."); | ||||
| // | ||||
| //    for (idx, plain_node) in plaintext_nodes.iter().enumerate() { | ||||
| //        print_progress_bar(idx + 1, total_nodes, "Encrypting nodes"); | ||||
| //        let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key); | ||||
| //        let fhe_node = FheHnswNode { | ||||
| //            encrypted_point, | ||||
| //            level: plain_node.level, | ||||
| //            neighbors: plain_node.neighbors.clone(), | ||||
| //        }; | ||||
| //        fhe_nodes.push(fhe_node); | ||||
| //    } | ||||
| // | ||||
| // | ||||
| //    FheHnswGraph { | ||||
| //        nodes: fhe_nodes, | ||||
| //        entry_point: plaintext_entry_point, | ||||
| //        max_level: plaintext_max_level, | ||||
| //    } | ||||
| //} | ||||
| // | ||||
| ///// 构建明文HNSW图 | ||||
| //pub fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) { | ||||
| //    println!("🔨 Building HNSW graph from {} points...", data.len()); | ||||
| // | ||||
| //    let max_connections = 8; // 最优参数:减少邻居数以降低密文运算量 | ||||
| //    let max_level = 3; // 最优参数:保持3层结构 | ||||
| //    let mut nodes = Vec::new(); | ||||
| //    let mut entry_point = None; | ||||
| //    let mut current_max_level = 0; | ||||
| // | ||||
| //    // 创建所有节点 | ||||
| //    for (idx, vector) in data.iter().enumerate() { | ||||
| //        let level = select_level_for_node(max_level); | ||||
| //        current_max_level = current_max_level.max(level); | ||||
| // | ||||
| //        let node = PlaintextHnswNode { | ||||
| //            vector: vector.clone(), | ||||
| //            level, | ||||
| //            neighbors: vec![Vec::new(); level + 1], | ||||
| //        }; | ||||
| //        nodes.push(node); | ||||
| // | ||||
| //        if idx == 0 { | ||||
| //            entry_point = Some(idx); | ||||
| //        } | ||||
| //    } | ||||
| // | ||||
| //    // 为每个节点建立连接(简化版实现) | ||||
| //    for node_id in 0..nodes.len() { | ||||
| //        print_progress_bar(node_id + 1, nodes.len(), "Building connections"); | ||||
| // | ||||
| //        let node_vector = nodes[node_id].vector.clone(); | ||||
| //        let node_level = nodes[node_id].level; | ||||
| // | ||||
| //        // 为每层找到最近的邻居 | ||||
| //        for layer in 0..=node_level { | ||||
| //            let mut distances: Vec<(f64, usize)> = nodes | ||||
| //                .iter() | ||||
| //                .enumerate() | ||||
| //                .filter(|(idx, n)| *idx != node_id && n.level >= layer) | ||||
| //                .map(|(idx, n)| { | ||||
| //                    let dist = euclidean_distance_plaintext(&node_vector, &n.vector); | ||||
| //                    (dist, idx) | ||||
| //                }) | ||||
| //                .collect(); | ||||
| // | ||||
| //            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); | ||||
| // | ||||
| //            // 添加最近的邻居(限制连接数) | ||||
| //            for &(_, neighbor_id) in distances.iter().take(max_connections) { | ||||
| //                nodes[node_id].neighbors[layer].push(neighbor_id); | ||||
| //            } | ||||
| //        } | ||||
| // | ||||
| //        // 更新入口点 | ||||
| //        if nodes[node_id].level > nodes[entry_point.unwrap()].level { | ||||
| //            entry_point = Some(node_id); | ||||
| //        } | ||||
| //    } | ||||
| // | ||||
| //    println!(); // 新行 | ||||
| //    println!( | ||||
| //        "✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}", | ||||
| //        nodes.len(), | ||||
| //        entry_point, | ||||
| //        current_max_level | ||||
| //    ); | ||||
| // | ||||
| //    (nodes, entry_point, current_max_level) | ||||
| //} | ||||
| // | ||||
| ///// 选择节点的层级 | ||||
| //pub fn select_level_for_node(max_level: usize) -> usize { | ||||
| //    let mut rng = rand::rng(); | ||||
| //    let mut level = 0; | ||||
| //    while level < max_level && rng.random::<f32>() < 0.6 { | ||||
| //        // 最优参数:提高层级概率到0.6 | ||||
| //        level += 1; | ||||
| //    } | ||||
| //    level | ||||
| //} | ||||
| // | ||||
| // | ||||
| // | ||||
| ///// 明文欧几里得距离计算 | ||||
| //fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 { | ||||
| //    a.iter() | ||||
| //        .zip(b.iter()) | ||||
| //        .map(|(x, y)| (x - y).powi(2)) | ||||
| //        .sum::<f64>() | ||||
| //        .sqrt() | ||||
| //} | ||||
|   | ||||
							
								
								
									
										338
									
								
								test_fhe_hnsw.py
									
									
									
									
									
								
							
							
						
						
									
										338
									
								
								test_fhe_hnsw.py
									
									
									
									
									
								
							| @@ -1,338 +0,0 @@ | ||||
| #!/usr/bin/env python3 | ||||
| """ | ||||
| FHE HNSW并行测试脚本 - 多进程版本 | ||||
| 测试密文态HNSW实现的稳定性和准确率 | ||||
| 支持多进程并行执行,充分利用服务器CPU资源 | ||||
| 运行100次,记录≥90%准确率的成功率和结果长度分布 | ||||
| """ | ||||
|  | ||||
| import subprocess | ||||
| import json | ||||
| import time | ||||
| import os | ||||
| import multiprocessing as mp | ||||
| from pathlib import Path | ||||
| from collections import defaultdict | ||||
|  | ||||
| # 正确答案 | ||||
| CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28] | ||||
|  | ||||
|  | ||||
| def calculate_accuracy(result, correct): | ||||
|     """计算准确率:匹配元素数量 / 总元素数量""" | ||||
|     matches = len(set(result) & set(correct)) | ||||
|     return matches / len(correct) * 100 | ||||
|  | ||||
|  | ||||
| def run_single_test(test_id, data_bit_width="12", timeout_minutes=30): | ||||
|     """运行单次FHE HNSW测试(供多进程调用)""" | ||||
|     # 为每个进程创建独立的输出文件 | ||||
|     output_file = f"./test_fhe_output_{test_id}_{os.getpid()}.jsonl" | ||||
|  | ||||
|     cmd = [ | ||||
|         "./enc", | ||||
|         "--algorithm", | ||||
|         "hnsw", | ||||
|         "--data-bit-width", | ||||
|         data_bit_width, | ||||
|         "--predictions", | ||||
|         output_file, | ||||
|     ] | ||||
|  | ||||
|     start_time = time.time() | ||||
|  | ||||
|     try: | ||||
|         result = subprocess.run( | ||||
|             cmd, capture_output=True, text=True, timeout=timeout_minutes * 60 | ||||
|         ) | ||||
|  | ||||
|         test_time = time.time() - start_time | ||||
|  | ||||
|         if result.returncode != 0: | ||||
|             return { | ||||
|                 "test_id": test_id, | ||||
|                 "result": None, | ||||
|                 "error": f"Command failed: {result.stderr[:200]}", | ||||
|                 "time_minutes": test_time / 60, | ||||
|             } | ||||
|  | ||||
|         # 读取结果 | ||||
|         if not Path(output_file).exists(): | ||||
|             return { | ||||
|                 "test_id": test_id, | ||||
|                 "result": None, | ||||
|                 "error": "Output file not found", | ||||
|                 "time_minutes": test_time / 60, | ||||
|             } | ||||
|  | ||||
|         with open(output_file, "r") as f: | ||||
|             output = json.load(f) | ||||
|             answer = output["answer"] | ||||
|  | ||||
|         # 清理临时文件 | ||||
|         Path(output_file).unlink(missing_ok=True) | ||||
|  | ||||
|         # 计算准确率 | ||||
|         accuracy = calculate_accuracy(answer, CORRECT_ANSWER) | ||||
|  | ||||
|         return { | ||||
|             "test_id": test_id, | ||||
|             "result": answer, | ||||
|             "length": len(answer), | ||||
|             "accuracy": accuracy, | ||||
|             "success": accuracy >= 90, | ||||
|             "time_minutes": test_time / 60, | ||||
|             "error": None, | ||||
|         } | ||||
|  | ||||
|     except subprocess.TimeoutExpired: | ||||
|         # 清理临时文件 | ||||
|         Path(output_file).unlink(missing_ok=True) | ||||
|         return { | ||||
|             "test_id": test_id, | ||||
|             "result": None, | ||||
|             "error": "timeout", | ||||
|             "time_minutes": timeout_minutes, | ||||
|         } | ||||
|     except Exception as e: | ||||
|         # 清理临时文件 | ||||
|         Path(output_file).unlink(missing_ok=True) | ||||
|         return { | ||||
|             "test_id": test_id, | ||||
|             "result": None, | ||||
|             "error": str(e), | ||||
|             "time_minutes": (time.time() - start_time) / 60, | ||||
|         } | ||||
|  | ||||
|  | ||||
| def print_progress(completed, total, start_time): | ||||
|     """打印进度信息""" | ||||
|     if completed == 0: | ||||
|         return | ||||
|  | ||||
|     elapsed = time.time() - start_time | ||||
|     avg_time = elapsed / completed | ||||
|     remaining_time = avg_time * (total - completed) | ||||
|  | ||||
|     print( | ||||
|         f"\r🔄 进度: {completed}/{total} ({completed/total*100:.1f}%) " | ||||
|         f"| 已用时: {elapsed/3600:.1f}h | 预计剩余: {remaining_time/3600:.1f}h", | ||||
|         end="", | ||||
|         flush=True, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     """主测试函数""" | ||||
|     print("🚀 FHE HNSW 并行批量测试脚本") | ||||
|     print(f"🎯 正确答案: {CORRECT_ANSWER}") | ||||
|     print("📊 运行100次测试,记录准确率和结果长度分布") | ||||
|     print("🔧 支持多进程并行执行") | ||||
|  | ||||
|     # 检查enc二进制文件是否存在 | ||||
|     if not Path("./enc").exists(): | ||||
|         print("\n❌ 找不到 ./enc 二进制文件,请先编译项目:") | ||||
|         print("   cargo build --release --bin enc") | ||||
|         return | ||||
|  | ||||
|     # 获取CPU核心数并设置进程数 | ||||
|     cpu_cores = mp.cpu_count() | ||||
|     # 考虑到FHE运算的内存密集性,使用核心数的一半避免内存不足 | ||||
|     num_processes = 8 | ||||
|     timeout_minutes = 45  # 增加超时时间到45分钟 | ||||
|  | ||||
|     print(f"💻 使用 {num_processes} 个并行进程") | ||||
|     print(f"⏰ 单次测试超时时间: {timeout_minutes} 分钟") | ||||
|     print( | ||||
|         f"⏱️  预计总时间: {100 * 15 / num_processes / 60:.1f}-{100 * 25 / num_processes / 60:.1f} 小时" | ||||
|     ) | ||||
|     print() | ||||
|  | ||||
|     print("=" * 80) | ||||
|     print("🔬 开始FHE HNSW并行测试 (data-bit-width=12)") | ||||
|     print("=" * 80) | ||||
|  | ||||
|     start_time = time.time() | ||||
|  | ||||
|     # 创建进程池并执行测试 | ||||
|     test_ids = list(range(1, 101))  # 1-100 | ||||
|  | ||||
|     with mp.Pool(processes=num_processes) as pool: | ||||
|         # 创建异步任务 | ||||
|         async_results = [] | ||||
|         for test_id in test_ids: | ||||
|             async_result = pool.apply_async( | ||||
|                 run_single_test, (test_id, "12", timeout_minutes) | ||||
|             ) | ||||
|             async_results.append(async_result) | ||||
|  | ||||
|         # 收集结果并显示进度 | ||||
|         results = [] | ||||
|         completed = 0 | ||||
|  | ||||
|         print_progress(completed, len(test_ids), start_time) | ||||
|  | ||||
|         for async_result in async_results: | ||||
|             try: | ||||
|                 result = async_result.get( | ||||
|                     timeout=timeout_minutes * 60 + 60 | ||||
|                 )  # 额外1分钟缓冲 | ||||
|                 results.append(result) | ||||
|                 completed += 1 | ||||
|                 print_progress(completed, len(test_ids), start_time) | ||||
|             except mp.TimeoutError: | ||||
|                 # 进程级别超时 | ||||
|                 results.append( | ||||
|                     { | ||||
|                         "test_id": len(results) + 1, | ||||
|                         "result": None, | ||||
|                         "error": "process_timeout", | ||||
|                         "time_minutes": timeout_minutes, | ||||
|                     } | ||||
|                 ) | ||||
|                 completed += 1 | ||||
|                 print_progress(completed, len(test_ids), start_time) | ||||
|  | ||||
|     print()  # 换行 | ||||
|     total_elapsed = time.time() - start_time | ||||
|  | ||||
|     # 分析结果 | ||||
|     print("\n" + "=" * 80) | ||||
|     print("📈 测试结果分析") | ||||
|     print("=" * 80) | ||||
|  | ||||
|     # 统计变量 | ||||
|     valid_results = [r for r in results if r["result"] is not None] | ||||
|     success_count = sum(1 for r in valid_results if r["success"]) | ||||
|     length_distribution = defaultdict(int) | ||||
|     error_distribution = defaultdict(int) | ||||
|  | ||||
|     # 统计错误类型 | ||||
|     for r in results: | ||||
|         if r["error"]: | ||||
|             error_type = r["error"] | ||||
|             if "timeout" in error_type.lower(): | ||||
|                 error_distribution["timeout"] += 1 | ||||
|             elif "failed" in error_type.lower(): | ||||
|                 error_distribution["failed"] += 1 | ||||
|             else: | ||||
|                 error_distribution["other"] += 1 | ||||
|         else: | ||||
|             length_distribution[r["length"]] += 1 | ||||
|  | ||||
|     # 基本统计 | ||||
|     total_tests = len(results) | ||||
|     valid_tests = len(valid_results) | ||||
|  | ||||
|     if valid_tests == 0: | ||||
|         print("❌ 没有有效的测试结果") | ||||
|         return | ||||
|  | ||||
|     success_rate = success_count / valid_tests * 100 | ||||
|     avg_accuracy = sum(r["accuracy"] for r in valid_results) / valid_tests | ||||
|     avg_length = sum(r["length"] for r in valid_results) / valid_tests | ||||
|     avg_time_per_test = sum(r["time_minutes"] for r in results) / total_tests | ||||
|  | ||||
|     print(f"总测试次数: {total_tests}") | ||||
|     print(f"有效测试次数: {valid_tests}") | ||||
|     print(f"成功次数 (≥90%准确率): {success_count}") | ||||
|     print(f"成功率: {success_rate:.1f}%") | ||||
|     print(f"平均准确率: {avg_accuracy:.1f}%") | ||||
|     print(f"平均结果长度: {avg_length:.1f}") | ||||
|     print(f"平均每次测试时间: {avg_time_per_test:.1f}分钟") | ||||
|     print(f"总测试时间: {total_elapsed/3600:.1f}小时") | ||||
|     print(f"并行加速比: {100 * avg_time_per_test / 60 / (total_elapsed/3600):.1f}x") | ||||
|  | ||||
|     # 错误统计 | ||||
|     if error_distribution: | ||||
|         print("\n❌ 错误分布:") | ||||
|         for error_type, count in error_distribution.items(): | ||||
|             print(f"  {error_type}: {count}次") | ||||
|  | ||||
|     # 结果长度分布 | ||||
|     if length_distribution: | ||||
|         print("\n📊 结果长度分布:") | ||||
|         for length in sorted(length_distribution.keys()): | ||||
|             count = length_distribution[length] | ||||
|             percentage = count / valid_tests * 100 | ||||
|             bar = "█" * (count // 2) if count > 0 else "" | ||||
|             print(f"  长度 {length:2d}: {count:3d}次 ({percentage:5.1f}%) {bar}") | ||||
|  | ||||
|     # 结论 | ||||
|     print() | ||||
|     if success_rate >= 50: | ||||
|         print("✅ 测试通过! FHE HNSW实现稳定性良好") | ||||
|         print(f"   - 成功率: {success_rate:.1f}% (≥50%)") | ||||
|         print(f"   - 平均准确率: {avg_accuracy:.1f}%") | ||||
|         if avg_length >= 9.5: | ||||
|             print(f"   - 平均结果长度: {avg_length:.1f} (接近10)") | ||||
|         else: | ||||
|             print(f"   - 平均结果长度: {avg_length:.1f} (需要改进)") | ||||
|     else: | ||||
|         print("❌ 测试未通过,需要进一步优化") | ||||
|         print(f"   - 成功率: {success_rate:.1f}% (<50%)") | ||||
|         print(f"   - 平均准确率: {avg_accuracy:.1f}%") | ||||
|  | ||||
|     # 保存详细结果 | ||||
|     report_file = "fhe_hnsw_parallel_test_report.json" | ||||
|     with open(report_file, "w") as f: | ||||
|         json.dump( | ||||
|             { | ||||
|                 "summary": { | ||||
|                     "total_tests": total_tests, | ||||
|                     "valid_tests": valid_tests, | ||||
|                     "success_count": success_count, | ||||
|                     "success_rate": success_rate, | ||||
|                     "avg_accuracy": avg_accuracy, | ||||
|                     "avg_length": avg_length, | ||||
|                     "avg_time_minutes": avg_time_per_test, | ||||
|                     "total_time_hours": total_elapsed / 3600, | ||||
|                     "num_processes": num_processes, | ||||
|                     "cpu_cores": cpu_cores, | ||||
|                     "speedup": 100 * avg_time_per_test / 60 / (total_elapsed / 3600), | ||||
|                 }, | ||||
|                 "length_distribution": dict(length_distribution), | ||||
|                 "error_distribution": dict(error_distribution), | ||||
|                 "detailed_results": results, | ||||
|                 "correct_answer": CORRECT_ANSWER, | ||||
|             }, | ||||
|             f, | ||||
|             indent=2, | ||||
|         ) | ||||
|  | ||||
|     print(f"\n📁 详细报告已保存到: {report_file}") | ||||
|  | ||||
|     # 显示最好和最坏的几个结果 | ||||
|     if valid_results: | ||||
|         print("\n🏆 最高准确率的5个结果:") | ||||
|         best_results = sorted(valid_results, key=lambda x: x["accuracy"], reverse=True)[ | ||||
|             :5 | ||||
|         ] | ||||
|         for i, r in enumerate(best_results, 1): | ||||
|             print( | ||||
|                 f"  {i}. 测试#{r['test_id']:3d}: 准确率{r['accuracy']:5.1f}% 长度{r['length']:2d} ({r['time_minutes']:4.1f}分钟)" | ||||
|             ) | ||||
|  | ||||
|         print("\n⚠️  最低准确率的5个结果:") | ||||
|         worst_results = sorted(valid_results, key=lambda x: x["accuracy"])[:5] | ||||
|         for i, r in enumerate(worst_results, 1): | ||||
|             print( | ||||
|                 f"  {i}. 测试#{r['test_id']:3d}: 准确率{r['accuracy']:5.1f}% 长度{r['length']:2d} ({r['time_minutes']:4.1f}分钟)" | ||||
|             ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     try: | ||||
|         # 设置多进程启动方法(Linux上默认是fork,但spawn更安全) | ||||
|         mp.set_start_method("spawn", force=True) | ||||
|         main() | ||||
|     except KeyboardInterrupt: | ||||
|         print("\n\n⏹️  测试被用户中断") | ||||
|         # 清理可能的临时文件 | ||||
|         for f in Path(".").glob("test_fhe_output_*.jsonl"): | ||||
|             f.unlink(missing_ok=True) | ||||
|     except Exception as e: | ||||
|         print(f"\n💥 程序异常: {e}") | ||||
|         # 清理可能的临时文件 | ||||
|         for f in Path(".").glob("test_fhe_output_*.jsonl"): | ||||
|             f.unlink(missing_ok=True) | ||||
							
								
								
									
										32
									
								
								writeup.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								writeup.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| # 0xfa队 writeup | ||||
|  | ||||
| ## 全同态加密算法介绍 | ||||
|  | ||||
| 加密算法方面选择了 TFHE 算法,库方面选择了较为成熟的 tfhe-rs 库。 | ||||
|  | ||||
| ### 算法参数 | ||||
|  | ||||
| Message bits: 2位 | ||||
|  | ||||
| Carry bits: 2位 | ||||
|  | ||||
| 噪声分布: TUniform (tweaked uniform) | ||||
|  | ||||
| Bootstrap失败概率: ≤ 2^-128 (CPU后端) | ||||
|  | ||||
| ## knn算法实现细节 | ||||
|  | ||||
| 将欧式距离公式拆分: | ||||
|  | ||||
| > sum((a-b)^2)=sum(a^2) - sum(2a\*b) + sum(b^2) | ||||
|  | ||||
| 减少了密文态的乘法和加法操作 | ||||
|  | ||||
| Top-k 选择上实现了双调排序:由于双调排序要求输入长度为2的幂,先将100个距离结果用最大值填充至128个, | ||||
| 然后进行并行排序,最后选择前十个密文。 | ||||
|  | ||||
| 两个操作都使用了rayon库做多核并行计算 | ||||
|  | ||||
| ## 本地测试结果 | ||||
|  | ||||
| 在本地i9-10920X(12核24线程)情况下,运行时间约9min(4min+5min) | ||||
		Reference in New Issue
	
	Block a user