refactor: migrate from FheUint12 to FheInt14 for better range support

- Replace FheUint12 with FheInt14 across all encryption operations
- Update scaling to use i16 instead of u32 for signed integer support
- Fix FheInt14 max value to 8191 (correct range limit)
- Optimize distance computation with parallel processing using rayon
- Simplify bitonic sort implementation by removing depth tracking
- Update plaintext version to match encrypted scaling behavior
- Add Default trait implementation for FheHnswGraph
This commit is contained in:
2025-08-06 16:26:54 +08:00
parent 4e13fd3023
commit 46b3562de0
5 changed files with 158 additions and 156 deletions

View File

@@ -4,7 +4,7 @@ use crate::logging::{format_duration, print_progress_bar};
use rayon::prelude::*;
use std::time::Instant;
use tfhe::prelude::*;
use tfhe::{FheUint8, FheUint12};
use tfhe::{FheInt14, FheUint8};
/// 优化的欧几里得距离计算:使用预计算的平方和
/// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
@@ -19,8 +19,8 @@ use tfhe::{FheUint8, FheUint12};
pub fn euclidean_distance(
query: &EncryptedQuery,
point: &EncryptedPoint,
zero: &FheUint12,
) -> FheUint12 {
zero: &FheInt14,
) -> FheInt14 {
// 计算 Σ((2*aᵢ)*bᵢ)
let mut cross_product_sum = zero.clone();
for (double_a, b) in query.double_coords.iter().zip(&point.coords) {
@@ -45,50 +45,34 @@ pub fn euclidean_distance(
pub fn compute_distances(
query: &EncryptedQuery,
points: &[EncryptedPoint],
zero: &FheUint12,
zero: &FheInt14,
) -> Vec<EncryptedNeighbor> {
println!("🔢 Computing encrypted distances...");
println!("🏭 Pre-encrypting constants for optimization...");
println!("🔢 Computing encrypted distances in parallel...");
let total_points = points.len();
let mut distances = Vec::with_capacity(total_points);
let computation_start = Instant::now();
for (i, point) in points.iter().enumerate() {
print_progress_bar(i + 1, total_points, "Distance calculation");
let dist_start = Instant::now();
// 确保并行计算结果顺序一致性
let mut distances = vec![None; total_points];
distances.par_iter_mut().enumerate().for_each(|(i, slot)| {
let point = &points[i];
let distance = euclidean_distance(query, point, zero);
let dist_time = dist_start.elapsed();
if i == 0 {
let estimated_total = dist_time.as_secs_f64() * total_points as f64;
println!(
"\n⏱️ First distance took {:.1}s, estimated total: {:.1} minutes",
dist_time.as_secs_f64(),
estimated_total / 60.0
);
}
// Show progress every 25 points for long calculations
if (i + 1) % 25 == 0 && i > 0 {
let elapsed = dist_time.as_secs_f64() * (i + 1) as f64;
let remaining_points = total_points - i - 1;
let estimated_remaining = dist_time.as_secs_f64() * remaining_points as f64;
println!(
"\n🕐 Completed {}/{} distances in {:.1}m, estimated {:.1}m remaining",
i + 1,
total_points,
elapsed / 60.0,
estimated_remaining / 60.0
);
}
distances.push(EncryptedNeighbor {
*slot = Some(EncryptedNeighbor {
distance,
index: point.index.clone(),
});
}
println!(); // New line after progress bar
});
let computation_time = computation_start.elapsed();
println!(
"✅ Parallel distance computation completed in {} for {} points",
format_duration(computation_time),
total_points
);
// 提取结果,确保所有计算都完成
let distances: Vec<EncryptedNeighbor> = distances.into_iter().map(|opt| opt.unwrap()).collect();
distances
}
@@ -108,7 +92,7 @@ pub fn perform_knn_selection(
distances: &mut Vec<EncryptedNeighbor>,
k: usize,
algorithm: &str,
max_distance: Option<&FheUint12>,
max_distance: Option<&FheInt14>,
max_index: Option<&FheUint8>,
) -> Vec<FheUint8> {
match algorithm {
@@ -175,7 +159,7 @@ pub fn encrypted_selection_sort(distances: &mut Vec<EncryptedNeighbor>, k: usize
pub fn encrypted_bitonic_sort(
distances: &mut Vec<EncryptedNeighbor>,
k: usize,
max_distance: &FheUint12,
max_distance: &FheInt14,
max_index: &FheUint8,
) {
println!("🔄 Starting bitonic sort...");
@@ -201,7 +185,7 @@ pub fn encrypted_bitonic_sort(
}
}
bitonic_sort_encrypted(distances, true);
bitonic_sort_encrypted(distances, true); // true表示升序最小距离在前
// 只保留前k个结果
distances.truncate(k);
@@ -263,31 +247,22 @@ pub fn encrypted_heap_select(distances: &mut Vec<EncryptedNeighbor>, k: usize) {
/// * `arr` - 待排序的加密邻居数组
/// * `up` - 排序方向true为升序false为降序
fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) {
bitonic_sort_encrypted_recursive(arr, up, 0);
bitonic_sort_encrypted_recursive(arr, up);
}
/// 双调排序的递归实现,带深度跟踪
fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool, depth: usize) {
/// 双调排序的递归实现
fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool) {
if arr.len() <= 1 {
return;
}
// 只在最顶层显示进度
if depth == 0 && arr.len() > 50 {
println!(
"🔀 Bitonic sort depth {}: processing {} elements",
depth,
arr.len()
);
}
let mid = arr.len() / 2;
// 并行执行两个递归调用 - server key已通过rayon::broadcast设置
let (left, right) = arr.split_at_mut(mid);
rayon::join(
|| bitonic_sort_encrypted_recursive(left, true, depth + 1),
|| bitonic_sort_encrypted_recursive(right, false, depth + 1),
|| bitonic_sort_encrypted_recursive(left, true),
|| bitonic_sort_encrypted_recursive(right, false),
);
bitonic_merge_encrypted(arr, up);
@@ -361,7 +336,7 @@ pub fn perform_hnsw_search(
graph: &FheHnswGraph,
query: &EncryptedQuery,
k: usize,
zero: &FheUint12,
zero: &FheInt14,
) -> Vec<FheUint8> {
println!("🚀 Starting HNSW approximate search...");

View File

@@ -7,7 +7,7 @@ use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::time::Instant;
use tfhe::prelude::*;
use tfhe::{ConfigBuilder, FheUint12, generate_keys, set_server_key};
use tfhe::{ConfigBuilder, FheInt14, generate_keys, set_server_key};
// Import from our library modules
use hfe_knn::{
@@ -73,8 +73,8 @@ fn debug_compute_distances(
// Get the plaintext distance for this specific point (before sorting)
let plaintext_distance = plaintext_distances[i].0;
let scaled_distance = u32::scale_value(plaintext_distance) as u16;
let encrypted_distance = FheUint12::try_encrypt(scaled_distance, client_key).unwrap();
let scaled_distance = i16::scale_value(plaintext_distance);
let encrypted_distance = FheInt14::try_encrypt(scaled_distance, client_key).unwrap();
encrypted_distances.push(EncryptedNeighbor {
distance: encrypted_distance,
@@ -249,7 +249,7 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
// 3. 执行HNSW搜索
let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap();
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero)
} else {
// 传统算法路径
@@ -271,13 +271,13 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
client_key,
)
} else {
let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap();
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
};
// Perform KNN selection using the specified algorithm
let max_distance = if args.algorithm == "bitonic" {
Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap())
Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // FheInt14的正确最大值
} else {
None
};

View File

@@ -15,15 +15,27 @@ struct Args {
dataset: String,
#[arg(long, default_value = "./dataset/answer1.jsonl")]
predictions: String,
#[arg(long, default_value = "8", help = "Max connections per node (M parameter)")]
#[arg(
long,
default_value = "8",
help = "Max connections per node (M parameter)"
)]
max_connections: usize,
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
max_level: usize,
#[arg(long, default_value = "0.6", help = "Level selection probability")]
level_prob: f32,
#[arg(long, default_value = "1", help = "ef parameter for upper layer search")]
#[arg(
long,
default_value = "1",
help = "ef parameter for upper layer search"
)]
ef_upper: usize,
#[arg(long, default_value = "10", help = "ef parameter for bottom layer search")]
#[arg(
long,
default_value = "10",
help = "ef parameter for bottom layer search"
)]
ef_bottom: usize,
}
@@ -39,11 +51,18 @@ struct Prediction {
}
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
a.iter()
// 模拟加密版本的缩放和精度损失
let scaled_distance: f64 = a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f64>()
.sqrt()
.map(|(x, y)| {
// 缩放坐标乘以10然后舍弃小数
let scaled_x = (x * 10.0).floor();
let scaled_y = (y * 10.0).floor();
(scaled_x - scaled_y).powi(2)
})
.sum::<f64>();
scaled_distance
}
fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
@@ -52,8 +71,12 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
.enumerate()
.map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
.collect();
println!();
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
for distance in distances.iter().take(10) {
println!("{},{}", distance.0, distance.1);
}
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
}
@@ -257,78 +280,78 @@ impl HNSWGraph {
}
}
// fn main() -> Result<()> {
// let args = Args::parse();
//
// let file = File::open(&args.dataset)?;
// let reader = BufReader::new(file);
//
// let mut results = Vec::new();
//
// for line in reader.lines() {
// let line = line?;
// let dataset: Dataset = serde_json::from_str(&line)?;
//
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
// results.push(Prediction { answer: nearest });
// }
//
// let mut output_file = File::create(&args.predictions)?;
// for result in results {
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
// }
//
// Ok(())
// }
fn main() -> Result<()> {
let args = Args::parse();
let file = File::open(&args.dataset)?;
let reader = BufReader::new(file);
let mut results = Vec::new();
println!("🔧 HNSW Parameters:");
println!(" max_connections: {}", args.max_connections);
println!(" max_level: {}", args.max_level);
println!(" level_prob: {}", args.level_prob);
println!(" ef_upper: {}", args.ef_upper);
println!(" ef_bottom: {}", args.ef_bottom);
let mut results = Vec::new();
for line in reader.lines() {
let line = line?;
let dataset: Dataset = serde_json::from_str(&line)?;
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
for data_point in &dataset.data {
hnsw.insert_node(data_point.clone(), args.level_prob);
}
let nearest = if let Some(entry_point) = hnsw.entry_point {
let mut search_results = vec![entry_point];
// 上层搜索使用ef_upper参数
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
}
// 底层搜索使用ef_bottom参数
let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
final_results
.into_iter()
.take(10)
.map(|idx| idx + 1)
.collect()
} else {
Vec::new()
};
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
results.push(Prediction { answer: nearest });
}
let mut output_file = File::create(&args.predictions)?;
for result in results {
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
println!("{}", serde_json::to_string(&result)?);
}
Ok(())
}
// fn main() -> Result<()> {
// let args = Args::parse();
// let file = File::open(&args.dataset)?;
// let reader = BufReader::new(file);
// let mut results = Vec::new();
//
// println!("🔧 HNSW Parameters:");
// println!(" max_connections: {}", args.max_connections);
// println!(" max_level: {}", args.max_level);
// println!(" level_prob: {}", args.level_prob);
// println!(" ef_upper: {}", args.ef_upper);
// println!(" ef_bottom: {}", args.ef_bottom);
//
// for line in reader.lines() {
// let line = line?;
// let dataset: Dataset = serde_json::from_str(&line)?;
//
// let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
// for data_point in &dataset.data {
// hnsw.insert_node(data_point.clone(), args.level_prob);
// }
//
// let nearest = if let Some(entry_point) = hnsw.entry_point {
// let mut search_results = vec![entry_point];
//
// // 上层搜索使用ef_upper参数
// for layer in (1..=hnsw.nodes[entry_point].level).rev() {
// search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
// }
//
// // 底层搜索使用ef_bottom参数
// let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
// final_results
// .into_iter()
// .take(10)
// .map(|idx| idx + 1)
// .collect()
// } else {
// Vec::new()
// };
//
// results.push(Prediction { answer: nearest });
// }
//
// let mut output_file = File::create(&args.predictions)?;
// for result in results {
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
// println!("{}", serde_json::to_string(&result)?);
// }
//
// Ok(())
// }

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use tfhe::prelude::*;
use tfhe::{ClientKey, FheUint8, FheUint12};
use tfhe::{ClientKey, FheInt14, FheUint8};
/// 数据集结构,包含查询点和训练数据
#[derive(Deserialize)]
@@ -19,23 +19,23 @@ pub struct Prediction {
/// 加密的查询点
#[derive(Clone)]
pub struct EncryptedQuery {
pub coords: Vec<FheUint12>,
pub double_coords: Vec<FheUint12>, // 2 * aᵢ
pub sum_squares: FheUint12, // Σaᵢ²
pub coords: Vec<FheInt14>,
pub double_coords: Vec<FheInt14>, // 2 * aᵢ
pub sum_squares: FheInt14, // Σaᵢ²
}
/// 加密的数据点,包含坐标、预计算值和索引
#[derive(Clone)]
pub struct EncryptedPoint {
pub coords: Vec<FheUint12>,
pub sum_squares: FheUint12, // Σbᵢ²
pub coords: Vec<FheInt14>,
pub sum_squares: FheInt14, // Σbᵢ²
pub index: FheUint8,
}
/// 加密的邻居点,包含距离和索引
#[derive(Clone)]
pub struct EncryptedNeighbor {
pub distance: FheUint12,
pub distance: FheInt14,
pub index: FheUint8,
}
@@ -50,37 +50,35 @@ pub trait ScaleInt {
fn scale_value(value: f64) -> Self::Output;
}
/// u32缩放实现舍弃2位小数限制在4000以内
impl ScaleInt for u32 {
type Output = u32;
fn scale_value(value: f64) -> u32 {
let truncated = value.floor() as u32;
truncated.min(4000)
/// i16缩放实现舍弃2位小数支持负数
impl ScaleInt for i16 {
type Output = i16;
fn scale_value(value: f64) -> i16 {
value.floor() as i16 // 直接取整,支持负数
}
}
/// 将明文查询点转换为加密查询点
pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {
let scaled_coords: Vec<u16> = coords
let scaled_coords: Vec<i16> = coords
.iter()
.map(|&coord| u32::scale_value(coord) as u16)
.map(|&coord| i16::scale_value(coord))
.collect();
let encrypted_coords: Vec<FheUint12> = scaled_coords
let encrypted_coords: Vec<FheInt14> = scaled_coords
.iter()
.map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap())
.map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap())
.collect();
// 预计算 2 * aᵢ
let encrypted_double_coords: Vec<FheUint12> = scaled_coords
let encrypted_double_coords: Vec<FheInt14> = scaled_coords
.iter()
.map(|&coord| FheUint12::try_encrypt(coord * 2, client_key).unwrap())
.map(|&coord| FheInt14::try_encrypt(coord * 2, client_key).unwrap())
.collect();
// 预计算 Σaᵢ²
let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum();
let encrypted_sum_squares =
FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap();
let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum();
let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap();
EncryptedQuery {
coords: encrypted_coords,
@@ -91,21 +89,20 @@ pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {
/// 将明文数据点转换为加密数据点
pub fn encrypt_point(coords: &[f64], index: usize, client_key: &ClientKey) -> EncryptedPoint {
let scaled_coords: Vec<u16> = coords
let scaled_coords: Vec<i16> = coords
.iter()
.map(|&coord| u32::scale_value(coord) as u16)
.map(|&coord| i16::scale_value(coord))
.collect();
let encrypted_coords: Vec<FheUint12> = scaled_coords
let encrypted_coords: Vec<FheInt14> = scaled_coords
.iter()
.map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap())
.map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap())
.collect();
// 预计算 Σbᵢ²
let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum();
let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum();
let encrypted_sum_squares =
FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap();
let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap();
let encrypted_index = FheUint8::try_encrypt(index as u8, client_key).unwrap();
EncryptedPoint {
@@ -183,7 +180,7 @@ impl FheHnswGraph {
entry_points: Vec<usize>,
ef: usize,
layer: usize,
_zero: &FheUint12,
_zero: &FheInt14,
) -> Vec<usize> {
let _visited: HashSet<usize> = HashSet::new();
@@ -244,6 +241,12 @@ impl FheHnswGraph {
}
}
impl Default for FheHnswGraph {
fn default() -> Self {
Self::new()
}
}
/// 从明文数据构建密文 HNSW 图的辅助结构
#[derive(Clone)]
pub struct PlaintextHnswNode {

View File

@@ -1,5 +1,5 @@
use std::time::Duration;
use std::io::{self, Write};
use std::time::Duration;
/// 格式化时间长度为人类可读的字符串
///
@@ -42,4 +42,5 @@ pub fn print_progress_bar(current: usize, total: usize, operation: &str) {
percentage
);
io::stdout().flush().unwrap();
}
}