feat: implement parallel bitonic sort with proper power-of-2 padding (v0.2.0)

- Add rayon dependency for parallel processing
- Implement rayon::broadcast for proper TFHE server key distribution across threads
- Fix bitonic sort to handle non-power-of-2 input sizes by padding to nearest 2^k
- Add parallel execution for recursive calls and comparison operations
- Update API to accept encrypted max values for safe padding
- Add Send/Sync traits for EncryptedNeighbor to enable thread safety
- Modify perform_knn_selection to support bitonic sort requirements
- Bump version to 0.2.0

This resolves the bitonic sort correctness issue where non-power-of-2 array sizes
caused incorrect results. The algorithm now properly pads from 100 to 128 elements
and should produce deterministic, correct results.
This commit is contained in:
2025-07-27 20:02:13 +08:00
parent 8b47403cc0
commit ef4f0021cf
5 changed files with 98 additions and 19 deletions

1
Cargo.lock generated
View File

@@ -419,6 +419,7 @@ dependencies = [
"env_logger",
"log",
"rand",
"rayon",
"serde",
"serde_json",
"tfhe",

View File

@@ -1,6 +1,6 @@
[package]
name = "hfe_knn"
version = "0.1.0"
version = "0.2.0"
edition = "2024"
[dependencies]
@@ -14,3 +14,4 @@ chrono = { version = "0.4", features = ["serde"] }
bincode = "2.0"
log = "0.4"
env_logger = "0.11"
rayon = "1.10"

View File

@@ -1,11 +1,13 @@
use crate::data::{EncryptedNeighbor, EncryptedPoint, FheHnswGraph};
use crate::logging::{format_duration, print_progress_bar};
use crate::{EncryptedQuery, SupportedFheInt};
use rayon::prelude::*;
use std::ops::{Add, AddAssign, Mul, Sub};
use std::time::Instant;
use tfhe::FheUint8;
use tfhe::prelude::*;
/// 优化的欧几里得距离计算:使用预计算的平方和
/// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
/// 速度相较于直接计算距离提升约10%
@@ -111,6 +113,8 @@ where
/// * `distances` - 距离数组(会被修改)
/// * `k` - 返回的最近邻数量
/// * `algorithm` - 使用的排序算法
/// * `max_distance` - 加密的最大距离值(用于双调排序填充)
/// * `max_index` - 加密的最大索引值(用于双调排序填充)
///
/// # Returns
/// * k个最近邻的加密索引列表
@@ -118,9 +122,12 @@ pub fn perform_knn_selection<T: SupportedFheInt>(
distances: &mut Vec<EncryptedNeighbor<T>>,
k: usize,
algorithm: &str,
max_distance: Option<&T>,
max_index: Option<&FheUint8>,
) -> Vec<FheUint8>
where
tfhe::FheBool: tfhe::prelude::IfThenElse<T> + tfhe::prelude::IfThenElse<FheUint8>,
T: Clone,
{
match algorithm {
"selection" => {
@@ -129,7 +136,12 @@ where
}
"bitonic" => {
println!("📊 Finding {k} smallest distances using bitonic sort...");
encrypted_bitonic_sort(distances, k);
if let (Some(max_dist), Some(max_idx)) = (max_distance, max_index) {
encrypted_bitonic_sort(distances, k, max_dist, max_idx);
} else {
println!("⚠️ Bitonic sort requires max_distance and max_index, falling back to selection sort");
encrypted_selection_sort(distances, k);
}
}
"heap" => {
println!("📊 Finding {k} smallest distances using min-heap selection...");
@@ -174,26 +186,57 @@ pub fn encrypted_selection_sort<T: SupportedFheInt>(
}
/// 双调排序算法 - 适合并行,固定比较模式
/// 要求输入长度为2的幂次自动填充到最近的2^k
///
/// # Arguments
/// * `distances` - 距离数组
/// * `k` - 需要的最小元素数量
/// * `max_distance` - 加密的最大距离值,用于填充
/// * `max_index` - 加密的最大索引值,用于填充
pub fn encrypted_bitonic_sort<T: SupportedFheInt>(
distances: &mut Vec<EncryptedNeighbor<T>>,
k: usize,
max_distance: &T,
max_index: &FheUint8,
) where
tfhe::FheBool: tfhe::prelude::IfThenElse<T> + tfhe::prelude::IfThenElse<FheUint8>,
T: Clone,
{
println!("🔄 Starting bitonic sort...");
let sort_start = Instant::now();
let original_len = distances.len();
// 计算最近的2的幂次
let power_of_2 = (original_len as f64).log2().ceil() as u32;
let target_len = 2_usize.pow(power_of_2);
println!("📏 Padding array from {} to {} (2^{})", original_len, target_len, power_of_2);
// 填充数组到2的幂次长度使用客户端提供的最大值作为哨兵
if target_len > original_len {
let padding_elem = EncryptedNeighbor {
distance: max_distance.clone(),
index: max_index.clone(),
};
for _ in original_len..target_len {
distances.push(padding_elem.clone());
}
}
bitonic_sort_encrypted(distances, true);
// 只保留前k个结果
distances.truncate(k);
println!(
"✅ Bitonic sort completed in {}",
format_duration(sort_start.elapsed())
);
distances.truncate(k);
}
/// 堆选择算法 - 适合高维数据的top-k选择
///
/// # Arguments
@@ -278,8 +321,14 @@ fn bitonic_sort_encrypted_recursive<T: SupportedFheInt>(
}
let mid = arr.len() / 2;
bitonic_sort_encrypted_recursive(&mut arr[..mid], true, depth + 1);
bitonic_sort_encrypted_recursive(&mut arr[mid..], false, depth + 1);
// 并行执行两个递归调用 - server key已通过rayon::broadcast设置
let (left, right) = arr.split_at_mut(mid);
rayon::join(
|| bitonic_sort_encrypted_recursive(left, true, depth + 1),
|| bitonic_sort_encrypted_recursive(right, false, depth + 1)
);
bitonic_merge_encrypted(arr, up);
}
@@ -297,19 +346,25 @@ where
}
let mid = arr.len() / 2;
for i in 0..mid {
let should_swap = if up {
arr[i].distance.gt(&arr[i + mid].distance)
} else {
arr[i].distance.lt(&arr[i + mid].distance)
};
// 并行执行比较和条件交换操作 - server key已通过rayon::broadcast设置
let (left, right) = arr.split_at_mut(mid);
left.par_iter_mut()
.zip(right.par_iter_mut())
.for_each(|(left_elem, right_elem)| {
let should_swap = if up {
left_elem.distance.gt(&right_elem.distance)
} else {
left_elem.distance.lt(&right_elem.distance)
};
encrypted_conditional_swap(left_elem, right_elem, &should_swap);
});
let (left, right) = arr.split_at_mut(mid);
encrypted_conditional_swap(&mut left[i], &mut right[i], &should_swap);
}
bitonic_merge_encrypted(&mut arr[..mid], up);
bitonic_merge_encrypted(&mut arr[mid..], up);
// 并行执行两个递归合并调用
rayon::join(
|| bitonic_merge_encrypted(left, up),
|| bitonic_merge_encrypted(right, up)
);
}
/// 基于加密条件交换两个加密邻居的距离和索引

View File

@@ -3,6 +3,7 @@ use chrono::Local;
use clap::Parser;
use log::info;
use rand::Rng;
use rayon::prelude::*;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::time::Instant;
@@ -267,6 +268,8 @@ fn main() -> Result<()> {
let config = ConfigBuilder::default().build();
let (client_key, server_key) = generate_keys(config);
// 使用TFHE官方推荐的方式设置Rayon多线程
rayon::broadcast(|_| set_server_key(server_key.clone()));
set_server_key(server_key);
info!(
"✅ TFHE setup completed in {}",
@@ -347,7 +350,14 @@ fn process_with_i32(args: &Args, client_key: &tfhe::ClientKey, start_time: Insta
};
// Perform KNN selection using the specified algorithm
perform_knn_selection(&mut distances, k, &args.algorithm)
let max_distance = if args.algorithm == "bitonic" {
Some(FheInt32::try_encrypt(i32::MAX, client_key).unwrap())
} else { None };
let max_index = if args.algorithm == "bitonic" {
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
} else { None };
perform_knn_selection(&mut distances, k, &args.algorithm, max_distance.as_ref(), max_index.as_ref())
};
// Decrypt the results
@@ -436,7 +446,14 @@ fn process_with_u12(args: &Args, client_key: &tfhe::ClientKey, start_time: Insta
};
// Perform KNN selection using the specified algorithm
perform_knn_selection(&mut distances, k, &args.algorithm)
let max_distance = if args.algorithm == "bitonic" {
Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap())
} else { None };
let max_index = if args.algorithm == "bitonic" {
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
} else { None };
perform_knn_selection(&mut distances, k, &args.algorithm, max_distance.as_ref(), max_index.as_ref())
};
// Decrypt the results

View File

@@ -64,6 +64,11 @@ pub struct EncryptedNeighbor<T: SupportedFheInt> {
pub index: FheUint8,
}
// 为EncryptedNeighbor实现Send和Sync trait以支持并行处理
// 这是安全的因为我们使用rayon::broadcast确保每个线程都有正确的server key
unsafe impl<T: SupportedFheInt> Send for EncryptedNeighbor<T> {}
unsafe impl<T: SupportedFheInt> Sync for EncryptedNeighbor<T> {}
/// 泛型数值缩放trait
pub trait ScaleInt {
type Output;