diff --git a/src/algorithms.rs b/src/algorithms.rs index 07db1de..df92e9a 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -4,7 +4,7 @@ use crate::logging::{format_duration, print_progress_bar}; use rayon::prelude::*; use std::time::Instant; use tfhe::prelude::*; -use tfhe::{FheUint8, FheUint12}; +use tfhe::{FheInt14, FheUint8}; /// 优化的欧几里得距离计算:使用预计算的平方和 /// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ) @@ -19,8 +19,8 @@ use tfhe::{FheUint8, FheUint12}; pub fn euclidean_distance( query: &EncryptedQuery, point: &EncryptedPoint, - zero: &FheUint12, -) -> FheUint12 { + zero: &FheInt14, +) -> FheInt14 { // 计算 Σ((2*aᵢ)*bᵢ) let mut cross_product_sum = zero.clone(); for (double_a, b) in query.double_coords.iter().zip(&point.coords) { @@ -45,50 +45,34 @@ pub fn euclidean_distance( pub fn compute_distances( query: &EncryptedQuery, points: &[EncryptedPoint], - zero: &FheUint12, + zero: &FheInt14, ) -> Vec { - println!("🔢 Computing encrypted distances..."); - println!("🏭 Pre-encrypting constants for optimization..."); + println!("🔢 Computing encrypted distances in parallel..."); let total_points = points.len(); - let mut distances = Vec::with_capacity(total_points); + let computation_start = Instant::now(); - for (i, point) in points.iter().enumerate() { - print_progress_bar(i + 1, total_points, "Distance calculation"); - let dist_start = Instant::now(); + // 确保并行计算结果顺序一致性 + let mut distances = vec![None; total_points]; + distances.par_iter_mut().enumerate().for_each(|(i, slot)| { + let point = &points[i]; let distance = euclidean_distance(query, point, zero); - let dist_time = dist_start.elapsed(); - - if i == 0 { - let estimated_total = dist_time.as_secs_f64() * total_points as f64; - println!( - "\n⏱️ First distance took {:.1}s, estimated total: {:.1} minutes", - dist_time.as_secs_f64(), - estimated_total / 60.0 - ); - } - - // Show progress every 25 points for long calculations - if (i + 1) % 25 == 0 && i > 0 { - let elapsed = dist_time.as_secs_f64() * (i + 1) as f64; - let remaining_points = total_points - i - 1; - let estimated_remaining = dist_time.as_secs_f64() * remaining_points as f64; - println!( - "\n🕐 Completed {}/{} distances in {:.1}m, estimated {:.1}m remaining", - i + 1, - total_points, - elapsed / 60.0, - estimated_remaining / 60.0 - ); - } - - distances.push(EncryptedNeighbor { + *slot = Some(EncryptedNeighbor { distance, index: point.index.clone(), }); - } - println!(); // New line after progress bar + }); + + let computation_time = computation_start.elapsed(); + println!( + "✅ Parallel distance computation completed in {} for {} points", + format_duration(computation_time), + total_points + ); + + // 提取结果,确保所有计算都完成 + let distances: Vec = distances.into_iter().map(|opt| opt.unwrap()).collect(); distances } @@ -108,7 +92,7 @@ pub fn perform_knn_selection( distances: &mut Vec, k: usize, algorithm: &str, - max_distance: Option<&FheUint12>, + max_distance: Option<&FheInt14>, max_index: Option<&FheUint8>, ) -> Vec { match algorithm { @@ -175,7 +159,7 @@ pub fn encrypted_selection_sort(distances: &mut Vec, k: usize pub fn encrypted_bitonic_sort( distances: &mut Vec, k: usize, - max_distance: &FheUint12, + max_distance: &FheInt14, max_index: &FheUint8, ) { println!("🔄 Starting bitonic sort..."); @@ -201,7 +185,7 @@ pub fn encrypted_bitonic_sort( } } - bitonic_sort_encrypted(distances, true); + bitonic_sort_encrypted(distances, true); // true表示升序,最小距离在前 // 只保留前k个结果 distances.truncate(k); @@ -263,31 +247,22 @@ pub fn encrypted_heap_select(distances: &mut Vec, k: usize) { /// * `arr` - 待排序的加密邻居数组 /// * `up` - 排序方向,true为升序,false为降序 fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) { - bitonic_sort_encrypted_recursive(arr, up, 0); + bitonic_sort_encrypted_recursive(arr, up); } -/// 双调排序的递归实现,带深度跟踪 -fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool, depth: usize) { +/// 双调排序的递归实现 +fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool) { if arr.len() <= 1 { return; } - // 只在最顶层显示进度 - if depth == 0 && arr.len() > 50 { - println!( - "🔀 Bitonic sort depth {}: processing {} elements", - depth, - arr.len() - ); - } - let mid = arr.len() / 2; // 并行执行两个递归调用 - server key已通过rayon::broadcast设置 let (left, right) = arr.split_at_mut(mid); rayon::join( - || bitonic_sort_encrypted_recursive(left, true, depth + 1), - || bitonic_sort_encrypted_recursive(right, false, depth + 1), + || bitonic_sort_encrypted_recursive(left, true), + || bitonic_sort_encrypted_recursive(right, false), ); bitonic_merge_encrypted(arr, up); @@ -361,7 +336,7 @@ pub fn perform_hnsw_search( graph: &FheHnswGraph, query: &EncryptedQuery, k: usize, - zero: &FheUint12, + zero: &FheInt14, ) -> Vec { println!("🚀 Starting HNSW approximate search..."); diff --git a/src/bin/enc.rs b/src/bin/enc.rs index 61af2d7..a7316a6 100644 --- a/src/bin/enc.rs +++ b/src/bin/enc.rs @@ -7,7 +7,7 @@ use std::fs::File; use std::io::{BufRead, BufReader, Write}; use std::time::Instant; use tfhe::prelude::*; -use tfhe::{ConfigBuilder, FheUint12, generate_keys, set_server_key}; +use tfhe::{ConfigBuilder, FheInt14, generate_keys, set_server_key}; // Import from our library modules use hfe_knn::{ @@ -73,8 +73,8 @@ fn debug_compute_distances( // Get the plaintext distance for this specific point (before sorting) let plaintext_distance = plaintext_distances[i].0; - let scaled_distance = u32::scale_value(plaintext_distance) as u16; - let encrypted_distance = FheUint12::try_encrypt(scaled_distance, client_key).unwrap(); + let scaled_distance = i16::scale_value(plaintext_distance); + let encrypted_distance = FheInt14::try_encrypt(scaled_distance, client_key).unwrap(); encrypted_distances.push(EncryptedNeighbor { distance: encrypted_distance, @@ -249,7 +249,7 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key); // 3. 执行HNSW搜索 - let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap(); + let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap(); perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero) } else { // 传统算法路径 @@ -271,13 +271,13 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan client_key, ) } else { - let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap(); + let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap(); compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero) }; // Perform KNN selection using the specified algorithm let max_distance = if args.algorithm == "bitonic" { - Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap()) + Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // FheInt14的正确最大值 } else { None }; diff --git a/src/bin/plain.rs b/src/bin/plain.rs index c331999..5d05f29 100644 --- a/src/bin/plain.rs +++ b/src/bin/plain.rs @@ -15,15 +15,27 @@ struct Args { dataset: String, #[arg(long, default_value = "./dataset/answer1.jsonl")] predictions: String, - #[arg(long, default_value = "8", help = "Max connections per node (M parameter)")] + #[arg( + long, + default_value = "8", + help = "Max connections per node (M parameter)" + )] max_connections: usize, #[arg(long, default_value = "3", help = "Max levels in HNSW graph")] max_level: usize, #[arg(long, default_value = "0.6", help = "Level selection probability")] level_prob: f32, - #[arg(long, default_value = "1", help = "ef parameter for upper layer search")] + #[arg( + long, + default_value = "1", + help = "ef parameter for upper layer search" + )] ef_upper: usize, - #[arg(long, default_value = "10", help = "ef parameter for bottom layer search")] + #[arg( + long, + default_value = "10", + help = "ef parameter for bottom layer search" + )] ef_bottom: usize, } @@ -39,11 +51,18 @@ struct Prediction { } fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { - a.iter() + // 模拟加密版本的缩放和精度损失 + let scaled_distance: f64 = a.iter() .zip(b.iter()) - .map(|(x, y)| (x - y).powi(2)) - .sum::() - .sqrt() + .map(|(x, y)| { + // 缩放坐标(乘以10,然后舍弃小数) + let scaled_x = (x * 10.0).floor(); + let scaled_y = (y * 10.0).floor(); + (scaled_x - scaled_y).powi(2) + }) + .sum::(); + + scaled_distance } fn knn_classify(query: &[f64], data: &[Vec], k: usize) -> Vec { @@ -52,8 +71,12 @@ fn knn_classify(query: &[f64], data: &[Vec], k: usize) -> Vec { .enumerate() .map(|(idx, point)| (euclidean_distance(query, point), idx + 1)) .collect(); + println!(); distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + for distance in distances.iter().take(10) { + println!("{},{}", distance.0, distance.1); + } distances.into_iter().take(k).map(|(_, idx)| idx).collect() } @@ -257,78 +280,78 @@ impl HNSWGraph { } } -// fn main() -> Result<()> { -// let args = Args::parse(); -// -// let file = File::open(&args.dataset)?; -// let reader = BufReader::new(file); -// -// let mut results = Vec::new(); -// -// for line in reader.lines() { -// let line = line?; -// let dataset: Dataset = serde_json::from_str(&line)?; -// -// let nearest = knn_classify(&dataset.query, &dataset.data, 10); -// results.push(Prediction { answer: nearest }); -// } -// -// let mut output_file = File::create(&args.predictions)?; -// for result in results { -// writeln!(output_file, "{}", serde_json::to_string(&result)?)?; -// } -// -// Ok(()) -// } fn main() -> Result<()> { let args = Args::parse(); + let file = File::open(&args.dataset)?; let reader = BufReader::new(file); - let mut results = Vec::new(); - println!("🔧 HNSW Parameters:"); - println!(" max_connections: {}", args.max_connections); - println!(" max_level: {}", args.max_level); - println!(" level_prob: {}", args.level_prob); - println!(" ef_upper: {}", args.ef_upper); - println!(" ef_bottom: {}", args.ef_bottom); + let mut results = Vec::new(); for line in reader.lines() { let line = line?; let dataset: Dataset = serde_json::from_str(&line)?; - let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections); - for data_point in &dataset.data { - hnsw.insert_node(data_point.clone(), args.level_prob); - } - - let nearest = if let Some(entry_point) = hnsw.entry_point { - let mut search_results = vec![entry_point]; - - // 上层搜索使用ef_upper参数 - for layer in (1..=hnsw.nodes[entry_point].level).rev() { - search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer); - } - - // 底层搜索使用ef_bottom参数 - let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0); - final_results - .into_iter() - .take(10) - .map(|idx| idx + 1) - .collect() - } else { - Vec::new() - }; - + let nearest = knn_classify(&dataset.query, &dataset.data, 10); results.push(Prediction { answer: nearest }); } let mut output_file = File::create(&args.predictions)?; for result in results { writeln!(output_file, "{}", serde_json::to_string(&result)?)?; - println!("{}", serde_json::to_string(&result)?); } Ok(()) } +// fn main() -> Result<()> { +// let args = Args::parse(); +// let file = File::open(&args.dataset)?; +// let reader = BufReader::new(file); +// let mut results = Vec::new(); +// +// println!("🔧 HNSW Parameters:"); +// println!(" max_connections: {}", args.max_connections); +// println!(" max_level: {}", args.max_level); +// println!(" level_prob: {}", args.level_prob); +// println!(" ef_upper: {}", args.ef_upper); +// println!(" ef_bottom: {}", args.ef_bottom); +// +// for line in reader.lines() { +// let line = line?; +// let dataset: Dataset = serde_json::from_str(&line)?; +// +// let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections); +// for data_point in &dataset.data { +// hnsw.insert_node(data_point.clone(), args.level_prob); +// } +// +// let nearest = if let Some(entry_point) = hnsw.entry_point { +// let mut search_results = vec![entry_point]; +// +// // 上层搜索使用ef_upper参数 +// for layer in (1..=hnsw.nodes[entry_point].level).rev() { +// search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer); +// } +// +// // 底层搜索使用ef_bottom参数 +// let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0); +// final_results +// .into_iter() +// .take(10) +// .map(|idx| idx + 1) +// .collect() +// } else { +// Vec::new() +// }; +// +// results.push(Prediction { answer: nearest }); +// } +// +// let mut output_file = File::create(&args.predictions)?; +// for result in results { +// writeln!(output_file, "{}", serde_json::to_string(&result)?)?; +// println!("{}", serde_json::to_string(&result)?); +// } +// +// Ok(()) +// } diff --git a/src/data.rs b/src/data.rs index 4dd737d..5b47930 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; use tfhe::prelude::*; -use tfhe::{ClientKey, FheUint8, FheUint12}; +use tfhe::{ClientKey, FheInt14, FheUint8}; /// 数据集结构,包含查询点和训练数据 #[derive(Deserialize)] @@ -19,23 +19,23 @@ pub struct Prediction { /// 加密的查询点 #[derive(Clone)] pub struct EncryptedQuery { - pub coords: Vec, - pub double_coords: Vec, // 2 * aᵢ - pub sum_squares: FheUint12, // Σaᵢ² + pub coords: Vec, + pub double_coords: Vec, // 2 * aᵢ + pub sum_squares: FheInt14, // Σaᵢ² } /// 加密的数据点,包含坐标、预计算值和索引 #[derive(Clone)] pub struct EncryptedPoint { - pub coords: Vec, - pub sum_squares: FheUint12, // Σbᵢ² + pub coords: Vec, + pub sum_squares: FheInt14, // Σbᵢ² pub index: FheUint8, } /// 加密的邻居点,包含距离和索引 #[derive(Clone)] pub struct EncryptedNeighbor { - pub distance: FheUint12, + pub distance: FheInt14, pub index: FheUint8, } @@ -50,37 +50,35 @@ pub trait ScaleInt { fn scale_value(value: f64) -> Self::Output; } -/// u32缩放实现:舍弃2位小数,限制在4000以内 -impl ScaleInt for u32 { - type Output = u32; - fn scale_value(value: f64) -> u32 { - let truncated = value.floor() as u32; - truncated.min(4000) +/// i16缩放实现:舍弃2位小数,支持负数 +impl ScaleInt for i16 { + type Output = i16; + fn scale_value(value: f64) -> i16 { + value.floor() as i16 // 直接取整,支持负数 } } /// 将明文查询点转换为加密查询点 pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery { - let scaled_coords: Vec = coords + let scaled_coords: Vec = coords .iter() - .map(|&coord| u32::scale_value(coord) as u16) + .map(|&coord| i16::scale_value(coord)) .collect(); - let encrypted_coords: Vec = scaled_coords + let encrypted_coords: Vec = scaled_coords .iter() - .map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap()) + .map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap()) .collect(); // 预计算 2 * aᵢ - let encrypted_double_coords: Vec = scaled_coords + let encrypted_double_coords: Vec = scaled_coords .iter() - .map(|&coord| FheUint12::try_encrypt(coord * 2, client_key).unwrap()) + .map(|&coord| FheInt14::try_encrypt(coord * 2, client_key).unwrap()) .collect(); // 预计算 Σaᵢ² - let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum(); - let encrypted_sum_squares = - FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap(); + let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum(); + let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap(); EncryptedQuery { coords: encrypted_coords, @@ -91,21 +89,20 @@ pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery { /// 将明文数据点转换为加密数据点 pub fn encrypt_point(coords: &[f64], index: usize, client_key: &ClientKey) -> EncryptedPoint { - let scaled_coords: Vec = coords + let scaled_coords: Vec = coords .iter() - .map(|&coord| u32::scale_value(coord) as u16) + .map(|&coord| i16::scale_value(coord)) .collect(); - let encrypted_coords: Vec = scaled_coords + let encrypted_coords: Vec = scaled_coords .iter() - .map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap()) + .map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap()) .collect(); // 预计算 Σbᵢ² - let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum(); + let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum(); - let encrypted_sum_squares = - FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap(); + let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap(); let encrypted_index = FheUint8::try_encrypt(index as u8, client_key).unwrap(); EncryptedPoint { @@ -183,7 +180,7 @@ impl FheHnswGraph { entry_points: Vec, ef: usize, layer: usize, - _zero: &FheUint12, + _zero: &FheInt14, ) -> Vec { let _visited: HashSet = HashSet::new(); @@ -244,6 +241,12 @@ impl FheHnswGraph { } } +impl Default for FheHnswGraph { + fn default() -> Self { + Self::new() + } +} + /// 从明文数据构建密文 HNSW 图的辅助结构 #[derive(Clone)] pub struct PlaintextHnswNode { diff --git a/src/logging.rs b/src/logging.rs index 69c1511..5982cb8 100644 --- a/src/logging.rs +++ b/src/logging.rs @@ -1,5 +1,5 @@ -use std::time::Duration; use std::io::{self, Write}; +use std::time::Duration; /// 格式化时间长度为人类可读的字符串 /// @@ -42,4 +42,5 @@ pub fn print_progress_bar(current: usize, total: usize, operation: &str) { percentage ); io::stdout().flush().unwrap(); -} \ No newline at end of file +} +