feat: implement parallel bitonic sort with proper power-of-2 padding (v0.2.0)
- Add rayon dependency for parallel processing
- Use rayon::broadcast for proper TFHE server key distribution across threads
- Fix bitonic sort to handle non-power-of-2 input sizes by padding to the nearest 2^k
- Add parallel execution for recursive calls and comparison operations
- Update API to accept encrypted max values for safe padding
- Add Send/Sync traits for EncryptedNeighbor to enable thread safety
- Modify perform_knn_selection to support bitonic sort requirements
- Bump version to 0.2.0

This resolves the bitonic sort correctness issue where non-power-of-2 array sizes caused incorrect results. The algorithm now pads the input to the next power of two (e.g. from 100 to 128 elements) and should produce deterministic, correct results.
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
use crate::data::{EncryptedNeighbor, EncryptedPoint, FheHnswGraph};
|
||||
use crate::logging::{format_duration, print_progress_bar};
|
||||
use crate::{EncryptedQuery, SupportedFheInt};
|
||||
use rayon::prelude::*;
|
||||
use std::ops::{Add, AddAssign, Mul, Sub};
|
||||
use std::time::Instant;
|
||||
use tfhe::FheUint8;
|
||||
use tfhe::prelude::*;
|
||||
|
||||
|
||||
/// 优化的欧几里得距离计算:使用预计算的平方和
|
||||
/// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
|
||||
/// 速度相较于直接计算距离提升约10%
|
||||
@@ -111,6 +113,8 @@ where
|
||||
/// * `distances` - 距离数组(会被修改)
|
||||
/// * `k` - 返回的最近邻数量
|
||||
/// * `algorithm` - 使用的排序算法
|
||||
/// * `max_distance` - 加密的最大距离值(用于双调排序填充)
|
||||
/// * `max_index` - 加密的最大索引值(用于双调排序填充)
|
||||
///
|
||||
/// # Returns
|
||||
/// * k个最近邻的加密索引列表
|
||||
@@ -118,9 +122,12 @@ pub fn perform_knn_selection<T: SupportedFheInt>(
|
||||
distances: &mut Vec<EncryptedNeighbor<T>>,
|
||||
k: usize,
|
||||
algorithm: &str,
|
||||
max_distance: Option<&T>,
|
||||
max_index: Option<&FheUint8>,
|
||||
) -> Vec<FheUint8>
|
||||
where
|
||||
tfhe::FheBool: tfhe::prelude::IfThenElse<T> + tfhe::prelude::IfThenElse<FheUint8>,
|
||||
T: Clone,
|
||||
{
|
||||
match algorithm {
|
||||
"selection" => {
|
||||
@@ -129,7 +136,12 @@ where
|
||||
}
|
||||
"bitonic" => {
|
||||
println!("📊 Finding {k} smallest distances using bitonic sort...");
|
||||
encrypted_bitonic_sort(distances, k);
|
||||
if let (Some(max_dist), Some(max_idx)) = (max_distance, max_index) {
|
||||
encrypted_bitonic_sort(distances, k, max_dist, max_idx);
|
||||
} else {
|
||||
println!("⚠️ Bitonic sort requires max_distance and max_index, falling back to selection sort");
|
||||
encrypted_selection_sort(distances, k);
|
||||
}
|
||||
}
|
||||
"heap" => {
|
||||
println!("📊 Finding {k} smallest distances using min-heap selection...");
|
||||
@@ -174,26 +186,57 @@ pub fn encrypted_selection_sort<T: SupportedFheInt>(
|
||||
}
|
||||
|
||||
/// 双调排序算法 - 适合并行,固定比较模式
|
||||
/// 要求输入长度为2的幂次,自动填充到最近的2^k
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `distances` - 距离数组
|
||||
/// * `k` - 需要的最小元素数量
|
||||
/// * `max_distance` - 加密的最大距离值,用于填充
|
||||
/// * `max_index` - 加密的最大索引值,用于填充
|
||||
pub fn encrypted_bitonic_sort<T: SupportedFheInt>(
|
||||
distances: &mut Vec<EncryptedNeighbor<T>>,
|
||||
k: usize,
|
||||
max_distance: &T,
|
||||
max_index: &FheUint8,
|
||||
) where
|
||||
tfhe::FheBool: tfhe::prelude::IfThenElse<T> + tfhe::prelude::IfThenElse<FheUint8>,
|
||||
T: Clone,
|
||||
{
|
||||
println!("🔄 Starting bitonic sort...");
|
||||
let sort_start = Instant::now();
|
||||
|
||||
let original_len = distances.len();
|
||||
|
||||
// 计算最近的2的幂次
|
||||
let power_of_2 = (original_len as f64).log2().ceil() as u32;
|
||||
let target_len = 2_usize.pow(power_of_2);
|
||||
|
||||
println!("📏 Padding array from {} to {} (2^{})", original_len, target_len, power_of_2);
|
||||
|
||||
// 填充数组到2的幂次长度,使用客户端提供的最大值作为哨兵
|
||||
if target_len > original_len {
|
||||
let padding_elem = EncryptedNeighbor {
|
||||
distance: max_distance.clone(),
|
||||
index: max_index.clone(),
|
||||
};
|
||||
|
||||
for _ in original_len..target_len {
|
||||
distances.push(padding_elem.clone());
|
||||
}
|
||||
}
|
||||
|
||||
bitonic_sort_encrypted(distances, true);
|
||||
|
||||
// 只保留前k个结果
|
||||
distances.truncate(k);
|
||||
|
||||
println!(
|
||||
"✅ Bitonic sort completed in {}",
|
||||
format_duration(sort_start.elapsed())
|
||||
);
|
||||
distances.truncate(k);
|
||||
}
|
||||
|
||||
|
||||
/// 堆选择算法 - 适合高维数据的top-k选择
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -278,8 +321,14 @@ fn bitonic_sort_encrypted_recursive<T: SupportedFheInt>(
|
||||
}
|
||||
|
||||
let mid = arr.len() / 2;
|
||||
bitonic_sort_encrypted_recursive(&mut arr[..mid], true, depth + 1);
|
||||
bitonic_sort_encrypted_recursive(&mut arr[mid..], false, depth + 1);
|
||||
|
||||
// 并行执行两个递归调用 - server key已通过rayon::broadcast设置
|
||||
let (left, right) = arr.split_at_mut(mid);
|
||||
rayon::join(
|
||||
|| bitonic_sort_encrypted_recursive(left, true, depth + 1),
|
||||
|| bitonic_sort_encrypted_recursive(right, false, depth + 1)
|
||||
);
|
||||
|
||||
bitonic_merge_encrypted(arr, up);
|
||||
}
|
||||
|
||||
@@ -297,19 +346,25 @@ where
|
||||
}
|
||||
|
||||
let mid = arr.len() / 2;
|
||||
for i in 0..mid {
|
||||
let should_swap = if up {
|
||||
arr[i].distance.gt(&arr[i + mid].distance)
|
||||
} else {
|
||||
arr[i].distance.lt(&arr[i + mid].distance)
|
||||
};
|
||||
|
||||
// 并行执行比较和条件交换操作 - server key已通过rayon::broadcast设置
|
||||
let (left, right) = arr.split_at_mut(mid);
|
||||
left.par_iter_mut()
|
||||
.zip(right.par_iter_mut())
|
||||
.for_each(|(left_elem, right_elem)| {
|
||||
let should_swap = if up {
|
||||
left_elem.distance.gt(&right_elem.distance)
|
||||
} else {
|
||||
left_elem.distance.lt(&right_elem.distance)
|
||||
};
|
||||
encrypted_conditional_swap(left_elem, right_elem, &should_swap);
|
||||
});
|
||||
|
||||
let (left, right) = arr.split_at_mut(mid);
|
||||
encrypted_conditional_swap(&mut left[i], &mut right[i], &should_swap);
|
||||
}
|
||||
|
||||
bitonic_merge_encrypted(&mut arr[..mid], up);
|
||||
bitonic_merge_encrypted(&mut arr[mid..], up);
|
||||
// 并行执行两个递归合并调用
|
||||
rayon::join(
|
||||
|| bitonic_merge_encrypted(left, up),
|
||||
|| bitonic_merge_encrypted(right, up)
|
||||
);
|
||||
}
|
||||
|
||||
/// 基于加密条件交换两个加密邻居的距离和索引
|
||||
|
||||
@@ -3,6 +3,7 @@ use chrono::Local;
|
||||
use clap::Parser;
|
||||
use log::info;
|
||||
use rand::Rng;
|
||||
use rayon::prelude::*;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::time::Instant;
|
||||
@@ -267,6 +268,8 @@ fn main() -> Result<()> {
|
||||
let config = ConfigBuilder::default().build();
|
||||
let (client_key, server_key) = generate_keys(config);
|
||||
|
||||
// 使用TFHE官方推荐的方式设置Rayon多线程
|
||||
rayon::broadcast(|_| set_server_key(server_key.clone()));
|
||||
set_server_key(server_key);
|
||||
info!(
|
||||
"✅ TFHE setup completed in {}",
|
||||
@@ -347,7 +350,14 @@ fn process_with_i32(args: &Args, client_key: &tfhe::ClientKey, start_time: Insta
|
||||
};
|
||||
|
||||
// Perform KNN selection using the specified algorithm
|
||||
perform_knn_selection(&mut distances, k, &args.algorithm)
|
||||
let max_distance = if args.algorithm == "bitonic" {
|
||||
Some(FheInt32::try_encrypt(i32::MAX, client_key).unwrap())
|
||||
} else { None };
|
||||
let max_index = if args.algorithm == "bitonic" {
|
||||
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
|
||||
} else { None };
|
||||
|
||||
perform_knn_selection(&mut distances, k, &args.algorithm, max_distance.as_ref(), max_index.as_ref())
|
||||
};
|
||||
|
||||
// Decrypt the results
|
||||
@@ -436,7 +446,14 @@ fn process_with_u12(args: &Args, client_key: &tfhe::ClientKey, start_time: Insta
|
||||
};
|
||||
|
||||
// Perform KNN selection using the specified algorithm
|
||||
perform_knn_selection(&mut distances, k, &args.algorithm)
|
||||
let max_distance = if args.algorithm == "bitonic" {
|
||||
Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap())
|
||||
} else { None };
|
||||
let max_index = if args.algorithm == "bitonic" {
|
||||
Some(tfhe::FheUint8::try_encrypt(255u8, client_key).unwrap())
|
||||
} else { None };
|
||||
|
||||
perform_knn_selection(&mut distances, k, &args.algorithm, max_distance.as_ref(), max_index.as_ref())
|
||||
};
|
||||
|
||||
// Decrypt the results
|
||||
|
||||
@@ -64,6 +64,11 @@ pub struct EncryptedNeighbor<T: SupportedFheInt> {
|
||||
pub index: FheUint8,
|
||||
}
|
||||
|
||||
// Implement Send and Sync for EncryptedNeighbor so it can be used in parallel (rayon) processing.
|
||||
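// SAFETY(review): these impls are sound only if `T` and `FheUint8` are themselves Send + Sync. Installing the server key on every thread via rayon::broadcast concerns key availability, not data-race freedom, and does not justify them on its own — TODO confirm, or drop the unsafe impls and rely on auto traits / explicit `T: Send + Sync` bounds.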
|
||||
unsafe impl<T: SupportedFheInt> Send for EncryptedNeighbor<T> {}
|
||||
unsafe impl<T: SupportedFheInt> Sync for EncryptedNeighbor<T> {}
|
||||
|
||||
/// 泛型数值缩放trait
|
||||
pub trait ScaleInt {
|
||||
type Output;
|
||||
|
||||
Reference in New Issue
Block a user