feat: implement parallel bitonic sort with proper power-of-2 padding (v0.2.0)
- Add rayon dependency for parallel processing - Implement rayon::broadcast for proper TFHE server key distribution across threads - Fix bitonic sort to handle non-power-of-2 input sizes by padding to nearest 2^k - Add parallel execution for recursive calls and comparison operations - Update API to accept encrypted max values for safe padding - Add Send/Sync traits for EncryptedNeighbor to enable thread safety - Modify perform_knn_selection to support bitonic sort requirements - Bump version to 0.2.0 This resolves the bitonic sort correctness issue where non-power-of-2 array sizes caused incorrect results. The algorithm now properly pads from 100 to 128 elements and should produce deterministic, correct results.
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -419,6 +419,7 @@ dependencies = [
|
|||||||
"env_logger",
|
"env_logger",
|
||||||
"log",
|
"log",
|
||||||
"rand",
|
"rand",
|
||||||
|
"rayon",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"tfhe",
|
"tfhe",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "hfe_knn"
|
name = "hfe_knn"
|
||||||
version = "0.1.0"
|
version = "0.2.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
@@ -14,3 +14,4 @@ chrono = { version = "0.4", features = ["serde"] }
|
|||||||
bincode = "2.0"
|
bincode = "2.0"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
env_logger = "0.11"
|
env_logger = "0.11"
|
||||||
|
rayon = "1.10"
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
use crate::data::{EncryptedNeighbor, EncryptedPoint, FheHnswGraph};
|
use crate::data::{EncryptedNeighbor, EncryptedPoint, FheHnswGraph};
|
||||||
use crate::logging::{format_duration, print_progress_bar};
|
use crate::logging::{format_duration, print_progress_bar};
|
||||||
use crate::{EncryptedQuery, SupportedFheInt};
|
use crate::{EncryptedQuery, SupportedFheInt};
|
||||||
|
use rayon::prelude::*;
|
||||||
use std::ops::{Add, AddAssign, Mul, Sub};
|
use std::ops::{Add, AddAssign, Mul, Sub};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tfhe::FheUint8;
|
use tfhe::FheUint8;
|
||||||
use tfhe::prelude::*;
|
use tfhe::prelude::*;
|
||||||
|
|
||||||
|
|
||||||
/// 优化的欧几里得距离计算:使用预计算的平方和
|
/// 优化的欧几里得距离计算:使用预计算的平方和
|
||||||
/// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
|
/// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
|
||||||
/// 速度相较于直接计算距离提升约10%
|
/// 速度相较于直接计算距离提升约10%
|
||||||
@@ -111,6 +113,8 @@ where
|
|||||||
/// * `distances` - 距离数组(会被修改)
|
/// * `distances` - 距离数组(会被修改)
|
||||||
/// * `k` - 返回的最近邻数量
|
/// * `k` - 返回的最近邻数量
|
||||||
/// * `algorithm` - 使用的排序算法
|
/// * `algorithm` - 使用的排序算法
|
||||||
|
/// * `max_distance` - 加密的最大距离值(用于双调排序填充)
|
||||||
|
/// * `max_index` - 加密的最大索引值(用于双调排序填充)
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * k个最近邻的加密索引列表
|
/// * k个最近邻的加密索引列表
|
||||||
@@ -118,9 +122,12 @@ pub fn perform_knn_selection<T: SupportedFheInt>(
|
|||||||
distances: &mut Vec<EncryptedNeighbor<T>>,
|
distances: &mut Vec<EncryptedNeighbor<T>>,
|
||||||
k: usize,
|
k: usize,
|
||||||
algorithm: &str,
|
algorithm: &str,
|
||||||
|
max_distance: Option<&T>,
|
||||||
|
max_index: Option<&FheUint8>,
|
||||||
) -> Vec<FheUint8>
|
) -> Vec<FheUint8>
|
||||||
where
|
where
|
||||||
tfhe::FheBool: tfhe::prelude::IfThenElse<T> + tfhe::prelude::IfThenElse<FheUint8>,
|
tfhe::FheBool: tfhe::prelude::IfThenElse<T> + tfhe::prelude::IfThenElse<FheUint8>,
|
||||||
|
T: Clone,
|
||||||
{
|
{
|
||||||
match algorithm {
|
match algorithm {
|
||||||
"selection" => {
|
"selection" => {
|
||||||
@@ -129,7 +136,12 @@ where
|
|||||||
}
|
}
|
||||||
"bitonic" => {
|
"bitonic" => {
|
||||||
println!("📊 Finding {k} smallest distances using bitonic sort...");
|
println!("📊 Finding {k} smallest distances using bitonic sort...");
|
||||||
encrypted_bitonic_sort(distances, k);
|
if let (Some(max_dist), Some(max_idx)) = (max_distance, max_index) {
|
||||||
|
encrypted_bitonic_sort(distances, k, max_dist, max_idx);
|
||||||
|
} else {
|
||||||
|
println!("⚠️ Bitonic sort requires max_distance and max_index, falling back to selection sort");
|
||||||
|
encrypted_selection_sort(distances, k);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
"heap" => {
|
"heap" => {
|
||||||
println!("📊 Finding {k} smallest distances using min-heap selection...");
|
println!("📊 Finding {k} smallest distances using min-heap selection...");
|
||||||
@@ -174,26 +186,57 @@ pub fn encrypted_selection_sort<T: SupportedFheInt>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 双调排序算法 - 适合并行,固定比较模式
|
/// 双调排序算法 - 适合并行,固定比较模式
|
||||||
|
/// 要求输入长度为2的幂次,自动填充到最近的2^k
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `distances` - 距离数组
|
/// * `distances` - 距离数组
|
||||||
/// * `k` - 需要的最小元素数量
|
/// * `k` - 需要的最小元素数量
|
||||||
|
/// * `max_distance` - 加密的最大距离值,用于填充
|
||||||
|
/// * `max_index` - 加密的最大索引值,用于填充
|
||||||
pub fn encrypted_bitonic_sort<T: SupportedFheInt>(
|
pub fn encrypted_bitonic_sort<T: SupportedFheInt>(
|
||||||
distances: &mut Vec<EncryptedNeighbor<T>>,
|
distances: &mut Vec<EncryptedNeighbor<T>>,
|
||||||
k: usize,
|
k: usize,
|
||||||
|
max_distance: &T,
|
||||||
|
max_index: &FheUint8,
|
||||||
) where
|
) where
|
||||||
tfhe::FheBool: tfhe::prelude::IfThenElse<T> + tfhe::prelude::IfThenElse<FheUint8>,
|
tfhe::FheBool: tfhe::prelude::IfThenElse<T> + tfhe::prelude::IfThenElse<FheUint8>,
|
||||||
|
T: Clone,
|
||||||
{
|
{
|
||||||
println!("🔄 Starting bitonic sort...");
|
println!("🔄 Starting bitonic sort...");
|
||||||
let sort_start = Instant::now();
|
let sort_start = Instant::now();
|
||||||
|
|
||||||
|
let original_len = distances.len();
|
||||||
|
|
||||||
|
// 计算最近的2的幂次
|
||||||
|
let power_of_2 = (original_len as f64).log2().ceil() as u32;
|
||||||
|
let target_len = 2_usize.pow(power_of_2);
|
||||||
|
|
||||||
|
println!("📏 Padding array from {} to {} (2^{})", original_len, target_len, power_of_2);
|
||||||
|
|
||||||
|
// 填充数组到2的幂次长度,使用客户端提供的最大值作为哨兵
|
||||||
|
if target_len > original_len {
|
||||||
|
let padding_elem = EncryptedNeighbor {
|
||||||
|
distance: max_distance.clone(),
|
||||||
|
index: max_index.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
for _ in original_len..target_len {
|
||||||
|
distances.push(padding_elem.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bitonic_sort_encrypted(distances, true);
|
bitonic_sort_encrypted(distances, true);
|
||||||
|
|
||||||
|
// 只保留前k个结果
|
||||||
|
distances.truncate(k);
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"✅ Bitonic sort completed in {}",
|
"✅ Bitonic sort completed in {}",
|
||||||
format_duration(sort_start.elapsed())
|
format_duration(sort_start.elapsed())
|
||||||
);
|
);
|
||||||
distances.truncate(k);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// 堆选择算法 - 适合高维数据的top-k选择
|
/// 堆选择算法 - 适合高维数据的top-k选择
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -278,8 +321,14 @@ fn bitonic_sort_encrypted_recursive<T: SupportedFheInt>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mid = arr.len() / 2;
|
let mid = arr.len() / 2;
|
||||||
bitonic_sort_encrypted_recursive(&mut arr[..mid], true, depth + 1);
|
|
||||||
bitonic_sort_encrypted_recursive(&mut arr[mid..], false, depth + 1);
|
// 并行执行两个递归调用 - server key已通过rayon::broadcast设置
|
||||||
|
let (left, right) = arr.split_at_mut(mid);
|
||||||
|
rayon::join(
|
||||||
|
|| bitonic_sort_encrypted_recursive(left, true, depth + 1),
|
||||||
|
|| bitonic_sort_encrypted_recursive(right, false, depth + 1)
|
||||||
|
);
|
||||||
|
|
||||||
bitonic_merge_encrypted(arr, up);
|
bitonic_merge_encrypted(arr, up);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,19 +346,25 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mid = arr.len() / 2;
|
let mid = arr.len() / 2;
|
||||||
for i in 0..mid {
|
|
||||||
let should_swap = if up {
|
// 并行执行比较和条件交换操作 - server key已通过rayon::broadcast设置
|
||||||
arr[i].distance.gt(&arr[i + mid].distance)
|
let (left, right) = arr.split_at_mut(mid);
|
||||||
} else {
|
left.par_iter_mut()
|
||||||
arr[i].distance.lt(&arr[i + mid].distance)
|
.zip(right.par_iter_mut())
|
||||||
};
|
.for_each(|(left_elem, right_elem)| {
|
||||||
|
let should_swap = if up {
|
||||||
|
left_elem.distance.gt(&right_elem.distance)
|
||||||
|
} else {
|
||||||
|
left_elem.distance.lt(&right_elem.distance)
|
||||||
|
};
|
||||||
|
encrypted_conditional_swap(left_elem, right_elem, &should_swap);
|
||||||
|
});
|
||||||
|
|
||||||
let (left, right) = arr.split_at_mut(mid);
|
// 并行执行两个递归合并调用
|
||||||
encrypted_conditional_swap(&mut left[i], &mut right[i], &should_swap);
|
rayon::join(
|
||||||
}
|
|| bitonic_merge_encrypted(left, up),
|
||||||
|
|| bitonic_merge_encrypted(right, up)
|
||||||
bitonic_merge_encrypted(&mut arr[..mid], up);
|
);
|
||||||
bitonic_merge_encrypted(&mut arr[mid..], up);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 基于加密条件交换两个加密邻居的距离和索引
|
/// 基于加密条件交换两个加密邻居的距离和索引
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use chrono::Local;
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use log::info;
|
use log::info;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
use rayon::prelude::*;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{BufRead, BufReader, Write};
|
use std::io::{BufRead, BufReader, Write};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
@@ -267,6 +268,8 @@ fn main() -> Result<()> {
|
|||||||
let config = ConfigBuilder::default().build();
|
let config = ConfigBuilder::default().build();
|
||||||
let (client_key, server_key) = generate_keys(config);
|
let (client_key, server_key) = generate_keys(config);
|
||||||
|
|
||||||
|
// 使用TFHE官方推荐的方式设置Rayon多线程
|
||||||
|
rayon::broadcast(|_| set_server_key(server_key.clone()));
|
||||||
set_server_key(server_key);
|
set_server_key(server_key);
|
||||||
info!(
|
info!(
|
||||||
"✅ TFHE setup completed in {}",
|
"✅ TFHE setup completed in {}",
|
||||||
@@ -347,7 +350,14 @@ fn process_with_i32(args: &Args, client_key: &tfhe::ClientKey, start_time: Insta
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Perform KNN selection using the specified algorithm
|
// Perform KNN selection using the specified algorithm
|
||||||
perform_knn_selection(&mut distances, k, &args.algorithm)
|
let max_distance = if args.algorithm == "bitonic" {
|
||||||
|
Some(FheInt32::try_encrypt(i32::MAX, client_key).unwrap())
|
||||||
|
} else { None };
|
||||||
|
let max_index = if args.algorithm == "bitonic" {
|
||||||
|
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
|
||||||
|
} else { None };
|
||||||
|
|
||||||
|
perform_knn_selection(&mut distances, k, &args.algorithm, max_distance.as_ref(), max_index.as_ref())
|
||||||
};
|
};
|
||||||
|
|
||||||
// Decrypt the results
|
// Decrypt the results
|
||||||
@@ -436,7 +446,14 @@ fn process_with_u12(args: &Args, client_key: &tfhe::ClientKey, start_time: Insta
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Perform KNN selection using the specified algorithm
|
// Perform KNN selection using the specified algorithm
|
||||||
perform_knn_selection(&mut distances, k, &args.algorithm)
|
let max_distance = if args.algorithm == "bitonic" {
|
||||||
|
Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap())
|
||||||
|
} else { None };
|
||||||
|
let max_index = if args.algorithm == "bitonic" {
|
||||||
|
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
|
||||||
|
} else { None };
|
||||||
|
|
||||||
|
perform_knn_selection(&mut distances, k, &args.algorithm, max_distance.as_ref(), max_index.as_ref())
|
||||||
};
|
};
|
||||||
|
|
||||||
// Decrypt the results
|
// Decrypt the results
|
||||||
|
|||||||
@@ -64,6 +64,11 @@ pub struct EncryptedNeighbor<T: SupportedFheInt> {
|
|||||||
pub index: FheUint8,
|
pub index: FheUint8,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 为EncryptedNeighbor实现Send和Sync trait以支持并行处理
|
||||||
|
// 这是安全的,因为我们使用rayon::broadcast确保每个线程都有正确的server key
|
||||||
|
unsafe impl<T: SupportedFheInt> Send for EncryptedNeighbor<T> {}
|
||||||
|
unsafe impl<T: SupportedFheInt> Sync for EncryptedNeighbor<T> {}
|
||||||
|
|
||||||
/// 泛型数值缩放trait
|
/// 泛型数值缩放trait
|
||||||
pub trait ScaleInt {
|
pub trait ScaleInt {
|
||||||
type Output;
|
type Output;
|
||||||
|
|||||||
Reference in New Issue
Block a user