refactor: migrate from FheUint12 to FheInt14 for better range support
- Replace FheUint12 with FheInt14 across all encryption operations - Update scaling to use i16 instead of u32 for signed integer support - Fix FheInt14 max value to 8191 (correct range limit) - Optimize distance computation with parallel processing using rayon - Simplify bitonic sort implementation by removing depth tracking - Update plaintext version to match encrypted scaling behavior - Add Default trait implementation for FheHnswGraph
This commit is contained in:
@@ -4,7 +4,7 @@ use crate::logging::{format_duration, print_progress_bar};
|
|||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tfhe::prelude::*;
|
use tfhe::prelude::*;
|
||||||
use tfhe::{FheUint8, FheUint12};
|
use tfhe::{FheInt14, FheUint8};
|
||||||
|
|
||||||
/// 优化的欧几里得距离计算:使用预计算的平方和
|
/// 优化的欧几里得距离计算:使用预计算的平方和
|
||||||
/// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
|
/// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
|
||||||
@@ -19,8 +19,8 @@ use tfhe::{FheUint8, FheUint12};
|
|||||||
pub fn euclidean_distance(
|
pub fn euclidean_distance(
|
||||||
query: &EncryptedQuery,
|
query: &EncryptedQuery,
|
||||||
point: &EncryptedPoint,
|
point: &EncryptedPoint,
|
||||||
zero: &FheUint12,
|
zero: &FheInt14,
|
||||||
) -> FheUint12 {
|
) -> FheInt14 {
|
||||||
// 计算 Σ((2*aᵢ)*bᵢ)
|
// 计算 Σ((2*aᵢ)*bᵢ)
|
||||||
let mut cross_product_sum = zero.clone();
|
let mut cross_product_sum = zero.clone();
|
||||||
for (double_a, b) in query.double_coords.iter().zip(&point.coords) {
|
for (double_a, b) in query.double_coords.iter().zip(&point.coords) {
|
||||||
@@ -45,50 +45,34 @@ pub fn euclidean_distance(
|
|||||||
pub fn compute_distances(
|
pub fn compute_distances(
|
||||||
query: &EncryptedQuery,
|
query: &EncryptedQuery,
|
||||||
points: &[EncryptedPoint],
|
points: &[EncryptedPoint],
|
||||||
zero: &FheUint12,
|
zero: &FheInt14,
|
||||||
) -> Vec<EncryptedNeighbor> {
|
) -> Vec<EncryptedNeighbor> {
|
||||||
println!("🔢 Computing encrypted distances...");
|
println!("🔢 Computing encrypted distances in parallel...");
|
||||||
println!("🏭 Pre-encrypting constants for optimization...");
|
|
||||||
|
|
||||||
let total_points = points.len();
|
let total_points = points.len();
|
||||||
let mut distances = Vec::with_capacity(total_points);
|
let computation_start = Instant::now();
|
||||||
|
|
||||||
for (i, point) in points.iter().enumerate() {
|
// 确保并行计算结果顺序一致性
|
||||||
print_progress_bar(i + 1, total_points, "Distance calculation");
|
let mut distances = vec![None; total_points];
|
||||||
let dist_start = Instant::now();
|
|
||||||
|
|
||||||
|
distances.par_iter_mut().enumerate().for_each(|(i, slot)| {
|
||||||
|
let point = &points[i];
|
||||||
let distance = euclidean_distance(query, point, zero);
|
let distance = euclidean_distance(query, point, zero);
|
||||||
let dist_time = dist_start.elapsed();
|
*slot = Some(EncryptedNeighbor {
|
||||||
|
|
||||||
if i == 0 {
|
|
||||||
let estimated_total = dist_time.as_secs_f64() * total_points as f64;
|
|
||||||
println!(
|
|
||||||
"\n⏱️ First distance took {:.1}s, estimated total: {:.1} minutes",
|
|
||||||
dist_time.as_secs_f64(),
|
|
||||||
estimated_total / 60.0
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show progress every 25 points for long calculations
|
|
||||||
if (i + 1) % 25 == 0 && i > 0 {
|
|
||||||
let elapsed = dist_time.as_secs_f64() * (i + 1) as f64;
|
|
||||||
let remaining_points = total_points - i - 1;
|
|
||||||
let estimated_remaining = dist_time.as_secs_f64() * remaining_points as f64;
|
|
||||||
println!(
|
|
||||||
"\n🕐 Completed {}/{} distances in {:.1}m, estimated {:.1}m remaining",
|
|
||||||
i + 1,
|
|
||||||
total_points,
|
|
||||||
elapsed / 60.0,
|
|
||||||
estimated_remaining / 60.0
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
distances.push(EncryptedNeighbor {
|
|
||||||
distance,
|
distance,
|
||||||
index: point.index.clone(),
|
index: point.index.clone(),
|
||||||
});
|
});
|
||||||
}
|
});
|
||||||
println!(); // New line after progress bar
|
|
||||||
|
let computation_time = computation_start.elapsed();
|
||||||
|
println!(
|
||||||
|
"✅ Parallel distance computation completed in {} for {} points",
|
||||||
|
format_duration(computation_time),
|
||||||
|
total_points
|
||||||
|
);
|
||||||
|
|
||||||
|
// 提取结果,确保所有计算都完成
|
||||||
|
let distances: Vec<EncryptedNeighbor> = distances.into_iter().map(|opt| opt.unwrap()).collect();
|
||||||
|
|
||||||
distances
|
distances
|
||||||
}
|
}
|
||||||
@@ -108,7 +92,7 @@ pub fn perform_knn_selection(
|
|||||||
distances: &mut Vec<EncryptedNeighbor>,
|
distances: &mut Vec<EncryptedNeighbor>,
|
||||||
k: usize,
|
k: usize,
|
||||||
algorithm: &str,
|
algorithm: &str,
|
||||||
max_distance: Option<&FheUint12>,
|
max_distance: Option<&FheInt14>,
|
||||||
max_index: Option<&FheUint8>,
|
max_index: Option<&FheUint8>,
|
||||||
) -> Vec<FheUint8> {
|
) -> Vec<FheUint8> {
|
||||||
match algorithm {
|
match algorithm {
|
||||||
@@ -175,7 +159,7 @@ pub fn encrypted_selection_sort(distances: &mut Vec<EncryptedNeighbor>, k: usize
|
|||||||
pub fn encrypted_bitonic_sort(
|
pub fn encrypted_bitonic_sort(
|
||||||
distances: &mut Vec<EncryptedNeighbor>,
|
distances: &mut Vec<EncryptedNeighbor>,
|
||||||
k: usize,
|
k: usize,
|
||||||
max_distance: &FheUint12,
|
max_distance: &FheInt14,
|
||||||
max_index: &FheUint8,
|
max_index: &FheUint8,
|
||||||
) {
|
) {
|
||||||
println!("🔄 Starting bitonic sort...");
|
println!("🔄 Starting bitonic sort...");
|
||||||
@@ -201,7 +185,7 @@ pub fn encrypted_bitonic_sort(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bitonic_sort_encrypted(distances, true);
|
bitonic_sort_encrypted(distances, true); // true表示升序,最小距离在前
|
||||||
|
|
||||||
// 只保留前k个结果
|
// 只保留前k个结果
|
||||||
distances.truncate(k);
|
distances.truncate(k);
|
||||||
@@ -263,31 +247,22 @@ pub fn encrypted_heap_select(distances: &mut Vec<EncryptedNeighbor>, k: usize) {
|
|||||||
/// * `arr` - 待排序的加密邻居数组
|
/// * `arr` - 待排序的加密邻居数组
|
||||||
/// * `up` - 排序方向,true为升序,false为降序
|
/// * `up` - 排序方向,true为升序,false为降序
|
||||||
fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) {
|
fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) {
|
||||||
bitonic_sort_encrypted_recursive(arr, up, 0);
|
bitonic_sort_encrypted_recursive(arr, up);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 双调排序的递归实现,带深度跟踪
|
/// 双调排序的递归实现
|
||||||
fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool, depth: usize) {
|
fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool) {
|
||||||
if arr.len() <= 1 {
|
if arr.len() <= 1 {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只在最顶层显示进度
|
|
||||||
if depth == 0 && arr.len() > 50 {
|
|
||||||
println!(
|
|
||||||
"🔀 Bitonic sort depth {}: processing {} elements",
|
|
||||||
depth,
|
|
||||||
arr.len()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mid = arr.len() / 2;
|
let mid = arr.len() / 2;
|
||||||
|
|
||||||
// 并行执行两个递归调用 - server key已通过rayon::broadcast设置
|
// 并行执行两个递归调用 - server key已通过rayon::broadcast设置
|
||||||
let (left, right) = arr.split_at_mut(mid);
|
let (left, right) = arr.split_at_mut(mid);
|
||||||
rayon::join(
|
rayon::join(
|
||||||
|| bitonic_sort_encrypted_recursive(left, true, depth + 1),
|
|| bitonic_sort_encrypted_recursive(left, true),
|
||||||
|| bitonic_sort_encrypted_recursive(right, false, depth + 1),
|
|| bitonic_sort_encrypted_recursive(right, false),
|
||||||
);
|
);
|
||||||
|
|
||||||
bitonic_merge_encrypted(arr, up);
|
bitonic_merge_encrypted(arr, up);
|
||||||
@@ -361,7 +336,7 @@ pub fn perform_hnsw_search(
|
|||||||
graph: &FheHnswGraph,
|
graph: &FheHnswGraph,
|
||||||
query: &EncryptedQuery,
|
query: &EncryptedQuery,
|
||||||
k: usize,
|
k: usize,
|
||||||
zero: &FheUint12,
|
zero: &FheInt14,
|
||||||
) -> Vec<FheUint8> {
|
) -> Vec<FheUint8> {
|
||||||
println!("🚀 Starting HNSW approximate search...");
|
println!("🚀 Starting HNSW approximate search...");
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ use std::fs::File;
|
|||||||
use std::io::{BufRead, BufReader, Write};
|
use std::io::{BufRead, BufReader, Write};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tfhe::prelude::*;
|
use tfhe::prelude::*;
|
||||||
use tfhe::{ConfigBuilder, FheUint12, generate_keys, set_server_key};
|
use tfhe::{ConfigBuilder, FheInt14, generate_keys, set_server_key};
|
||||||
|
|
||||||
// Import from our library modules
|
// Import from our library modules
|
||||||
use hfe_knn::{
|
use hfe_knn::{
|
||||||
@@ -73,8 +73,8 @@ fn debug_compute_distances(
|
|||||||
|
|
||||||
// Get the plaintext distance for this specific point (before sorting)
|
// Get the plaintext distance for this specific point (before sorting)
|
||||||
let plaintext_distance = plaintext_distances[i].0;
|
let plaintext_distance = plaintext_distances[i].0;
|
||||||
let scaled_distance = u32::scale_value(plaintext_distance) as u16;
|
let scaled_distance = i16::scale_value(plaintext_distance);
|
||||||
let encrypted_distance = FheUint12::try_encrypt(scaled_distance, client_key).unwrap();
|
let encrypted_distance = FheInt14::try_encrypt(scaled_distance, client_key).unwrap();
|
||||||
|
|
||||||
encrypted_distances.push(EncryptedNeighbor {
|
encrypted_distances.push(EncryptedNeighbor {
|
||||||
distance: encrypted_distance,
|
distance: encrypted_distance,
|
||||||
@@ -249,7 +249,7 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
|
|||||||
build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
|
build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
|
||||||
|
|
||||||
// 3. 执行HNSW搜索
|
// 3. 执行HNSW搜索
|
||||||
let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap();
|
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
|
||||||
perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero)
|
perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero)
|
||||||
} else {
|
} else {
|
||||||
// 传统算法路径
|
// 传统算法路径
|
||||||
@@ -271,13 +271,13 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
|
|||||||
client_key,
|
client_key,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap();
|
let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
|
||||||
compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
|
compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Perform KNN selection using the specified algorithm
|
// Perform KNN selection using the specified algorithm
|
||||||
let max_distance = if args.algorithm == "bitonic" {
|
let max_distance = if args.algorithm == "bitonic" {
|
||||||
Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap())
|
Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // FheInt14的正确最大值
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|||||||
147
src/bin/plain.rs
147
src/bin/plain.rs
@@ -15,15 +15,27 @@ struct Args {
|
|||||||
dataset: String,
|
dataset: String,
|
||||||
#[arg(long, default_value = "./dataset/answer1.jsonl")]
|
#[arg(long, default_value = "./dataset/answer1.jsonl")]
|
||||||
predictions: String,
|
predictions: String,
|
||||||
#[arg(long, default_value = "8", help = "Max connections per node (M parameter)")]
|
#[arg(
|
||||||
|
long,
|
||||||
|
default_value = "8",
|
||||||
|
help = "Max connections per node (M parameter)"
|
||||||
|
)]
|
||||||
max_connections: usize,
|
max_connections: usize,
|
||||||
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
|
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
|
||||||
max_level: usize,
|
max_level: usize,
|
||||||
#[arg(long, default_value = "0.6", help = "Level selection probability")]
|
#[arg(long, default_value = "0.6", help = "Level selection probability")]
|
||||||
level_prob: f32,
|
level_prob: f32,
|
||||||
#[arg(long, default_value = "1", help = "ef parameter for upper layer search")]
|
#[arg(
|
||||||
|
long,
|
||||||
|
default_value = "1",
|
||||||
|
help = "ef parameter for upper layer search"
|
||||||
|
)]
|
||||||
ef_upper: usize,
|
ef_upper: usize,
|
||||||
#[arg(long, default_value = "10", help = "ef parameter for bottom layer search")]
|
#[arg(
|
||||||
|
long,
|
||||||
|
default_value = "10",
|
||||||
|
help = "ef parameter for bottom layer search"
|
||||||
|
)]
|
||||||
ef_bottom: usize,
|
ef_bottom: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,11 +51,18 @@ struct Prediction {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
|
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
|
||||||
a.iter()
|
// 模拟加密版本的缩放和精度损失
|
||||||
|
let scaled_distance: f64 = a.iter()
|
||||||
.zip(b.iter())
|
.zip(b.iter())
|
||||||
.map(|(x, y)| (x - y).powi(2))
|
.map(|(x, y)| {
|
||||||
.sum::<f64>()
|
// 缩放坐标(乘以10,然后舍弃小数)
|
||||||
.sqrt()
|
let scaled_x = (x * 10.0).floor();
|
||||||
|
let scaled_y = (y * 10.0).floor();
|
||||||
|
(scaled_x - scaled_y).powi(2)
|
||||||
|
})
|
||||||
|
.sum::<f64>();
|
||||||
|
|
||||||
|
scaled_distance
|
||||||
}
|
}
|
||||||
|
|
||||||
fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
|
fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
|
||||||
@@ -52,8 +71,12 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
|
|||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
|
.map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
|
||||||
.collect();
|
.collect();
|
||||||
|
println!();
|
||||||
|
|
||||||
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||||
|
for distance in distances.iter().take(10) {
|
||||||
|
println!("{},{}", distance.0, distance.1);
|
||||||
|
}
|
||||||
|
|
||||||
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
|
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
|
||||||
}
|
}
|
||||||
@@ -257,78 +280,78 @@ impl HNSWGraph {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fn main() -> Result<()> {
|
|
||||||
// let args = Args::parse();
|
|
||||||
//
|
|
||||||
// let file = File::open(&args.dataset)?;
|
|
||||||
// let reader = BufReader::new(file);
|
|
||||||
//
|
|
||||||
// let mut results = Vec::new();
|
|
||||||
//
|
|
||||||
// for line in reader.lines() {
|
|
||||||
// let line = line?;
|
|
||||||
// let dataset: Dataset = serde_json::from_str(&line)?;
|
|
||||||
//
|
|
||||||
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
|
||||||
// results.push(Prediction { answer: nearest });
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// let mut output_file = File::create(&args.predictions)?;
|
|
||||||
// for result in results {
|
|
||||||
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// Ok(())
|
|
||||||
// }
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let file = File::open(&args.dataset)?;
|
let file = File::open(&args.dataset)?;
|
||||||
let reader = BufReader::new(file);
|
let reader = BufReader::new(file);
|
||||||
let mut results = Vec::new();
|
|
||||||
|
|
||||||
println!("🔧 HNSW Parameters:");
|
let mut results = Vec::new();
|
||||||
println!(" max_connections: {}", args.max_connections);
|
|
||||||
println!(" max_level: {}", args.max_level);
|
|
||||||
println!(" level_prob: {}", args.level_prob);
|
|
||||||
println!(" ef_upper: {}", args.ef_upper);
|
|
||||||
println!(" ef_bottom: {}", args.ef_bottom);
|
|
||||||
|
|
||||||
for line in reader.lines() {
|
for line in reader.lines() {
|
||||||
let line = line?;
|
let line = line?;
|
||||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||||
|
|
||||||
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
|
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||||
for data_point in &dataset.data {
|
|
||||||
hnsw.insert_node(data_point.clone(), args.level_prob);
|
|
||||||
}
|
|
||||||
|
|
||||||
let nearest = if let Some(entry_point) = hnsw.entry_point {
|
|
||||||
let mut search_results = vec![entry_point];
|
|
||||||
|
|
||||||
// 上层搜索使用ef_upper参数
|
|
||||||
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
|
||||||
search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 底层搜索使用ef_bottom参数
|
|
||||||
let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
|
|
||||||
final_results
|
|
||||||
.into_iter()
|
|
||||||
.take(10)
|
|
||||||
.map(|idx| idx + 1)
|
|
||||||
.collect()
|
|
||||||
} else {
|
|
||||||
Vec::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
results.push(Prediction { answer: nearest });
|
results.push(Prediction { answer: nearest });
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut output_file = File::create(&args.predictions)?;
|
let mut output_file = File::create(&args.predictions)?;
|
||||||
for result in results {
|
for result in results {
|
||||||
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||||
println!("{}", serde_json::to_string(&result)?);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
// fn main() -> Result<()> {
|
||||||
|
// let args = Args::parse();
|
||||||
|
// let file = File::open(&args.dataset)?;
|
||||||
|
// let reader = BufReader::new(file);
|
||||||
|
// let mut results = Vec::new();
|
||||||
|
//
|
||||||
|
// println!("🔧 HNSW Parameters:");
|
||||||
|
// println!(" max_connections: {}", args.max_connections);
|
||||||
|
// println!(" max_level: {}", args.max_level);
|
||||||
|
// println!(" level_prob: {}", args.level_prob);
|
||||||
|
// println!(" ef_upper: {}", args.ef_upper);
|
||||||
|
// println!(" ef_bottom: {}", args.ef_bottom);
|
||||||
|
//
|
||||||
|
// for line in reader.lines() {
|
||||||
|
// let line = line?;
|
||||||
|
// let dataset: Dataset = serde_json::from_str(&line)?;
|
||||||
|
//
|
||||||
|
// let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
|
||||||
|
// for data_point in &dataset.data {
|
||||||
|
// hnsw.insert_node(data_point.clone(), args.level_prob);
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// let nearest = if let Some(entry_point) = hnsw.entry_point {
|
||||||
|
// let mut search_results = vec![entry_point];
|
||||||
|
//
|
||||||
|
// // 上层搜索使用ef_upper参数
|
||||||
|
// for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
||||||
|
// search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // 底层搜索使用ef_bottom参数
|
||||||
|
// let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
|
||||||
|
// final_results
|
||||||
|
// .into_iter()
|
||||||
|
// .take(10)
|
||||||
|
// .map(|idx| idx + 1)
|
||||||
|
// .collect()
|
||||||
|
// } else {
|
||||||
|
// Vec::new()
|
||||||
|
// };
|
||||||
|
//
|
||||||
|
// results.push(Prediction { answer: nearest });
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// let mut output_file = File::create(&args.predictions)?;
|
||||||
|
// for result in results {
|
||||||
|
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||||
|
// println!("{}", serde_json::to_string(&result)?);
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Ok(())
|
||||||
|
// }
|
||||||
|
|||||||
63
src/data.rs
63
src/data.rs
@@ -1,7 +1,7 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use tfhe::prelude::*;
|
use tfhe::prelude::*;
|
||||||
use tfhe::{ClientKey, FheUint8, FheUint12};
|
use tfhe::{ClientKey, FheInt14, FheUint8};
|
||||||
|
|
||||||
/// 数据集结构,包含查询点和训练数据
|
/// 数据集结构,包含查询点和训练数据
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -19,23 +19,23 @@ pub struct Prediction {
|
|||||||
/// 加密的查询点
|
/// 加密的查询点
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct EncryptedQuery {
|
pub struct EncryptedQuery {
|
||||||
pub coords: Vec<FheUint12>,
|
pub coords: Vec<FheInt14>,
|
||||||
pub double_coords: Vec<FheUint12>, // 2 * aᵢ
|
pub double_coords: Vec<FheInt14>, // 2 * aᵢ
|
||||||
pub sum_squares: FheUint12, // Σaᵢ²
|
pub sum_squares: FheInt14, // Σaᵢ²
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 加密的数据点,包含坐标、预计算值和索引
|
/// 加密的数据点,包含坐标、预计算值和索引
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct EncryptedPoint {
|
pub struct EncryptedPoint {
|
||||||
pub coords: Vec<FheUint12>,
|
pub coords: Vec<FheInt14>,
|
||||||
pub sum_squares: FheUint12, // Σbᵢ²
|
pub sum_squares: FheInt14, // Σbᵢ²
|
||||||
pub index: FheUint8,
|
pub index: FheUint8,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 加密的邻居点,包含距离和索引
|
/// 加密的邻居点,包含距离和索引
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct EncryptedNeighbor {
|
pub struct EncryptedNeighbor {
|
||||||
pub distance: FheUint12,
|
pub distance: FheInt14,
|
||||||
pub index: FheUint8,
|
pub index: FheUint8,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,37 +50,35 @@ pub trait ScaleInt {
|
|||||||
fn scale_value(value: f64) -> Self::Output;
|
fn scale_value(value: f64) -> Self::Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// u32缩放实现:舍弃2位小数,限制在4000以内
|
/// i16缩放实现:舍弃2位小数,支持负数
|
||||||
impl ScaleInt for u32 {
|
impl ScaleInt for i16 {
|
||||||
type Output = u32;
|
type Output = i16;
|
||||||
fn scale_value(value: f64) -> u32 {
|
fn scale_value(value: f64) -> i16 {
|
||||||
let truncated = value.floor() as u32;
|
value.floor() as i16 // 直接取整,支持负数
|
||||||
truncated.min(4000)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 将明文查询点转换为加密查询点
|
/// 将明文查询点转换为加密查询点
|
||||||
pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {
|
pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {
|
||||||
let scaled_coords: Vec<u16> = coords
|
let scaled_coords: Vec<i16> = coords
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&coord| u32::scale_value(coord) as u16)
|
.map(|&coord| i16::scale_value(coord))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let encrypted_coords: Vec<FheUint12> = scaled_coords
|
let encrypted_coords: Vec<FheInt14> = scaled_coords
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap())
|
.map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// 预计算 2 * aᵢ
|
// 预计算 2 * aᵢ
|
||||||
let encrypted_double_coords: Vec<FheUint12> = scaled_coords
|
let encrypted_double_coords: Vec<FheInt14> = scaled_coords
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&coord| FheUint12::try_encrypt(coord * 2, client_key).unwrap())
|
.map(|&coord| FheInt14::try_encrypt(coord * 2, client_key).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// 预计算 Σaᵢ²
|
// 预计算 Σaᵢ²
|
||||||
let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum();
|
let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum();
|
||||||
let encrypted_sum_squares =
|
let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap();
|
||||||
FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap();
|
|
||||||
|
|
||||||
EncryptedQuery {
|
EncryptedQuery {
|
||||||
coords: encrypted_coords,
|
coords: encrypted_coords,
|
||||||
@@ -91,21 +89,20 @@ pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {
|
|||||||
|
|
||||||
/// 将明文数据点转换为加密数据点
|
/// 将明文数据点转换为加密数据点
|
||||||
pub fn encrypt_point(coords: &[f64], index: usize, client_key: &ClientKey) -> EncryptedPoint {
|
pub fn encrypt_point(coords: &[f64], index: usize, client_key: &ClientKey) -> EncryptedPoint {
|
||||||
let scaled_coords: Vec<u16> = coords
|
let scaled_coords: Vec<i16> = coords
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&coord| u32::scale_value(coord) as u16)
|
.map(|&coord| i16::scale_value(coord))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let encrypted_coords: Vec<FheUint12> = scaled_coords
|
let encrypted_coords: Vec<FheInt14> = scaled_coords
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap())
|
.map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// 预计算 Σbᵢ²
|
// 预计算 Σbᵢ²
|
||||||
let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum();
|
let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum();
|
||||||
|
|
||||||
let encrypted_sum_squares =
|
let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap();
|
||||||
FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap();
|
|
||||||
let encrypted_index = FheUint8::try_encrypt(index as u8, client_key).unwrap();
|
let encrypted_index = FheUint8::try_encrypt(index as u8, client_key).unwrap();
|
||||||
|
|
||||||
EncryptedPoint {
|
EncryptedPoint {
|
||||||
@@ -183,7 +180,7 @@ impl FheHnswGraph {
|
|||||||
entry_points: Vec<usize>,
|
entry_points: Vec<usize>,
|
||||||
ef: usize,
|
ef: usize,
|
||||||
layer: usize,
|
layer: usize,
|
||||||
_zero: &FheUint12,
|
_zero: &FheInt14,
|
||||||
) -> Vec<usize> {
|
) -> Vec<usize> {
|
||||||
let _visited: HashSet<usize> = HashSet::new();
|
let _visited: HashSet<usize> = HashSet::new();
|
||||||
|
|
||||||
@@ -244,6 +241,12 @@ impl FheHnswGraph {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for FheHnswGraph {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// 从明文数据构建密文 HNSW 图的辅助结构
|
/// 从明文数据构建密文 HNSW 图的辅助结构
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct PlaintextHnswNode {
|
pub struct PlaintextHnswNode {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::time::Duration;
|
|
||||||
use std::io::{self, Write};
|
use std::io::{self, Write};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
/// 格式化时间长度为人类可读的字符串
|
/// 格式化时间长度为人类可读的字符串
|
||||||
///
|
///
|
||||||
@@ -43,3 +43,4 @@ pub fn print_progress_bar(current: usize, total: usize, operation: &str) {
|
|||||||
);
|
);
|
||||||
io::stdout().flush().unwrap();
|
io::stdout().flush().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user