diff --git a/Cargo.lock b/Cargo.lock index 2850ee4..5e2d345 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -419,6 +419,7 @@ dependencies = [ "env_logger", "log", "rand", + "rayon", "serde", "serde_json", "tfhe", diff --git a/Cargo.toml b/Cargo.toml index e886a99..7c17028 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hfe_knn" -version = "0.1.0" +version = "0.2.0" edition = "2024" [dependencies] @@ -14,3 +14,4 @@ chrono = { version = "0.4", features = ["serde"] } bincode = "2.0" log = "0.4" env_logger = "0.11" +rayon = "1.10" diff --git a/src/algorithms.rs b/src/algorithms.rs index f19b14b..03892d4 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -1,11 +1,13 @@ use crate::data::{EncryptedNeighbor, EncryptedPoint, FheHnswGraph}; use crate::logging::{format_duration, print_progress_bar}; use crate::{EncryptedQuery, SupportedFheInt}; +use rayon::prelude::*; use std::ops::{Add, AddAssign, Mul, Sub}; use std::time::Instant; use tfhe::FheUint8; use tfhe::prelude::*; + /// 优化的欧几里得距离计算:使用预计算的平方和 /// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ) /// 速度相较于直接计算距离提升约10% @@ -111,6 +113,8 @@ where /// * `distances` - 距离数组(会被修改) /// * `k` - 返回的最近邻数量 /// * `algorithm` - 使用的排序算法 +/// * `max_distance` - 加密的最大距离值(用于双调排序填充) +/// * `max_index` - 加密的最大索引值(用于双调排序填充) /// /// # Returns /// * k个最近邻的加密索引列表 @@ -118,9 +122,12 @@ pub fn perform_knn_selection( distances: &mut Vec>, k: usize, algorithm: &str, + max_distance: Option<&T>, + max_index: Option<&FheUint8>, ) -> Vec where tfhe::FheBool: tfhe::prelude::IfThenElse + tfhe::prelude::IfThenElse, + T: Clone, { match algorithm { "selection" => { @@ -129,7 +136,12 @@ where } "bitonic" => { println!("📊 Finding {k} smallest distances using bitonic sort..."); - encrypted_bitonic_sort(distances, k); + if let (Some(max_dist), Some(max_idx)) = (max_distance, max_index) { + encrypted_bitonic_sort(distances, k, max_dist, max_idx); + } else { + println!("⚠️ Bitonic sort requires max_distance and max_index, falling back to selection sort"); + encrypted_selection_sort(distances, k); + } } "heap" => { println!("📊 Finding {k} smallest distances using min-heap selection..."); @@ -174,26 +186,57 @@ pub fn encrypted_selection_sort( } /// 双调排序算法 - 适合并行,固定比较模式 +/// 要求输入长度为2的幂次,自动填充到最近的2^k /// /// # Arguments /// * `distances` - 距离数组 /// * `k` - 需要的最小元素数量 +/// * `max_distance` - 加密的最大距离值,用于填充 +/// * `max_index` - 加密的最大索引值,用于填充 pub fn encrypted_bitonic_sort( distances: &mut Vec>, k: usize, + max_distance: &T, + max_index: &FheUint8, ) where tfhe::FheBool: tfhe::prelude::IfThenElse + tfhe::prelude::IfThenElse, + T: Clone, { println!("🔄 Starting bitonic sort..."); let sort_start = Instant::now(); + + let original_len = distances.len(); + + // 计算最近的2的幂次 + let power_of_2 = (original_len as f64).log2().ceil() as u32; + let target_len = 2_usize.pow(power_of_2); + + println!("📏 Padding array from {} to {} (2^{})", original_len, target_len, power_of_2); + + // 填充数组到2的幂次长度,使用客户端提供的最大值作为哨兵 + if target_len > original_len { + let padding_elem = EncryptedNeighbor { + distance: max_distance.clone(), + index: max_index.clone(), + }; + + for _ in original_len..target_len { + distances.push(padding_elem.clone()); + } + } + bitonic_sort_encrypted(distances, true); + + // 只保留前k个结果 + distances.truncate(k); + println!( "✅ Bitonic sort completed in {}", format_duration(sort_start.elapsed()) ); - distances.truncate(k); } + /// 堆选择算法 - 适合高维数据的top-k选择 /// /// # Arguments @@ -278,8 +321,14 @@ fn bitonic_sort_encrypted_recursive( } let mid = arr.len() / 2; - bitonic_sort_encrypted_recursive(&mut arr[..mid], true, depth + 1); - bitonic_sort_encrypted_recursive(&mut arr[mid..], false, depth + 1); + + // 并行执行两个递归调用 - server key已通过rayon::broadcast设置 + let (left, right) = arr.split_at_mut(mid); + rayon::join( + || bitonic_sort_encrypted_recursive(left, true, depth + 1), + || bitonic_sort_encrypted_recursive(right, false, depth + 1) + ); + bitonic_merge_encrypted(arr, up); } @@ -297,19 +346,25 @@ where } let mid = arr.len() / 2; - for i in 0..mid { - let should_swap = if up { - arr[i].distance.gt(&arr[i + mid].distance) - } else { - arr[i].distance.lt(&arr[i + mid].distance) - }; + + // 并行执行比较和条件交换操作 - server key已通过rayon::broadcast设置 + let (left, right) = arr.split_at_mut(mid); + left.par_iter_mut() + .zip(right.par_iter_mut()) + .for_each(|(left_elem, right_elem)| { + let should_swap = if up { + left_elem.distance.gt(&right_elem.distance) + } else { + left_elem.distance.lt(&right_elem.distance) + }; + encrypted_conditional_swap(left_elem, right_elem, &should_swap); + }); - let (left, right) = arr.split_at_mut(mid); - encrypted_conditional_swap(&mut left[i], &mut right[i], &should_swap); - } - - bitonic_merge_encrypted(&mut arr[..mid], up); - bitonic_merge_encrypted(&mut arr[mid..], up); + // 并行执行两个递归合并调用 + rayon::join( + || bitonic_merge_encrypted(left, up), + || bitonic_merge_encrypted(right, up) + ); } /// 基于加密条件交换两个加密邻居的距离和索引 diff --git a/src/bin/enc.rs b/src/bin/enc.rs index 49e58ae..955d691 100644 --- a/src/bin/enc.rs +++ b/src/bin/enc.rs @@ -3,6 +3,7 @@ use chrono::Local; use clap::Parser; use log::info; use rand::Rng; +use rayon::prelude::*; use std::fs::File; use std::io::{BufRead, BufReader, Write}; use std::time::Instant; @@ -267,6 +268,8 @@ fn main() -> Result<()> { let config = ConfigBuilder::default().build(); let (client_key, server_key) = generate_keys(config); + // 使用TFHE官方推荐的方式设置Rayon多线程 + rayon::broadcast(|_| set_server_key(server_key.clone())); set_server_key(server_key); info!( "✅ TFHE setup completed in {}", @@ -347,7 +350,14 @@ fn process_with_i32(args: &Args, client_key: &tfhe::ClientKey, start_time: Insta }; // Perform KNN selection using the specified algorithm - perform_knn_selection(&mut distances, k, &args.algorithm) + let max_distance = if args.algorithm == "bitonic" { + Some(FheInt32::try_encrypt(i32::MAX, client_key).unwrap()) + } else { None }; + let max_index = if args.algorithm == "bitonic" { + Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap()) + } else { None }; + + perform_knn_selection(&mut distances, k, &args.algorithm, max_distance.as_ref(), max_index.as_ref()) }; // Decrypt the results @@ -436,7 +446,14 @@ fn process_with_u12(args: &Args, client_key: &tfhe::ClientKey, start_time: Insta }; // Perform KNN selection using the specified algorithm - perform_knn_selection(&mut distances, k, &args.algorithm) + let max_distance = if args.algorithm == "bitonic" { + Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap()) + } else { None }; + let max_index = if args.algorithm == "bitonic" { + Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap()) + } else { None }; + + perform_knn_selection(&mut distances, k, &args.algorithm, max_distance.as_ref(), max_index.as_ref()) }; // Decrypt the results diff --git a/src/data.rs b/src/data.rs index 3b548ad..cebd29e 100644 --- a/src/data.rs +++ b/src/data.rs @@ -64,6 +64,11 @@ pub struct EncryptedNeighbor { pub index: FheUint8, } +// 为EncryptedNeighbor实现Send和Sync trait以支持并行处理 +// 这是安全的,因为我们使用rayon::broadcast确保每个线程都有正确的server key +unsafe impl Send for EncryptedNeighbor {} +unsafe impl Sync for EncryptedNeighbor {} + /// 泛型数值缩放trait pub trait ScaleInt { type Output;