refactor: migrate from FheUint12 to FheInt14 for better range support
- Replace FheUint12 with FheInt14 across all encryption operations
- Update scaling to use i16 instead of u32 for signed integer support
- Fix FheInt14 max value to 8191 (correct range limit)
- Optimize distance computation with parallel processing using rayon
- Simplify bitonic sort implementation by removing depth tracking
- Update plaintext version to match encrypted scaling behavior
- Add Default trait implementation for FheHnswGraph
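For context on the range change: FheUint12 can only hold 0..=4095, while FheInt14 is a signed 14-bit ciphertext covering -8192..=8191, which is why 8191 shows up below as the sort sentinel and why negative scaled values no longer need clamping. A rough standalone round-trip, assuming a tfhe version that exposes FheInt14 and the high-level API this repo already uses (not part of the diff):

```rust
use tfhe::prelude::*;
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheInt14};

fn main() {
    let config = ConfigBuilder::default().build();
    let (client_key, server_key) = generate_keys(config);
    set_server_key(server_key);

    // Signed 14-bit range: -8192..=8191, so 8191 is the largest usable sentinel.
    let a = FheInt14::try_encrypt(-42i16, &client_key).unwrap();
    let b = FheInt14::try_encrypt(8191i16, &client_key).unwrap();
    let sum = &a + &b;

    let decrypted: i16 = sum.decrypt(&client_key);
    assert_eq!(decrypted, 8149);
}
```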
@@ -4,7 +4,7 @@ use crate::logging::{format_duration, print_progress_bar};
 use rayon::prelude::*;
 use std::time::Instant;
 use tfhe::prelude::*;
-use tfhe::{FheUint8, FheUint12};
+use tfhe::{FheInt14, FheUint8};

 /// Optimized Euclidean distance computation: uses precomputed sums of squares
 /// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
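The identity in the doc comment is just the expanded square: (aᵢ - bᵢ)² = aᵢ² - 2aᵢbᵢ + bᵢ², so summing over i gives ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ). Σaᵢ² and the doubled coordinates 2*aᵢ are precomputed on the query side (sum_squares, double_coords) and Σbᵢ² on the data side, which leaves only the cross products to be multiplied homomorphically for each query/point pair.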
@@ -19,8 +19,8 @@ use tfhe::{FheUint8, FheUint12};
 pub fn euclidean_distance(
     query: &EncryptedQuery,
     point: &EncryptedPoint,
-    zero: &FheUint12,
-) -> FheUint12 {
+    zero: &FheInt14,
+) -> FheInt14 {
     // Compute Σ((2*aᵢ)*bᵢ)
     let mut cross_product_sum = zero.clone();
     for (double_a, b) in query.double_coords.iter().zip(&point.coords) {
@@ -45,50 +45,34 @@ pub fn euclidean_distance(
 pub fn compute_distances(
     query: &EncryptedQuery,
     points: &[EncryptedPoint],
-    zero: &FheUint12,
+    zero: &FheInt14,
 ) -> Vec<EncryptedNeighbor> {
-    println!("🔢 Computing encrypted distances...");
-    println!("🏭 Pre-encrypting constants for optimization...");
+    println!("🔢 Computing encrypted distances in parallel...");

     let total_points = points.len();
-    let mut distances = Vec::with_capacity(total_points);
     let computation_start = Instant::now();

-    for (i, point) in points.iter().enumerate() {
-        print_progress_bar(i + 1, total_points, "Distance calculation");
-        let dist_start = Instant::now();
+    // Ensure the parallel results keep a consistent, input-aligned order
+    let mut distances = vec![None; total_points];

+    distances.par_iter_mut().enumerate().for_each(|(i, slot)| {
+        let point = &points[i];
         let distance = euclidean_distance(query, point, zero);
-        let dist_time = dist_start.elapsed();

-        if i == 0 {
-            let estimated_total = dist_time.as_secs_f64() * total_points as f64;
-            println!(
-                "\n⏱️ First distance took {:.1}s, estimated total: {:.1} minutes",
-                dist_time.as_secs_f64(),
-                estimated_total / 60.0
-            );
-        }

-        // Show progress every 25 points for long calculations
-        if (i + 1) % 25 == 0 && i > 0 {
-            let elapsed = dist_time.as_secs_f64() * (i + 1) as f64;
-            let remaining_points = total_points - i - 1;
-            let estimated_remaining = dist_time.as_secs_f64() * remaining_points as f64;
-            println!(
-                "\n🕐 Completed {}/{} distances in {:.1}m, estimated {:.1}m remaining",
-                i + 1,
-                total_points,
-                elapsed / 60.0,
-                estimated_remaining / 60.0
-            );
-        }

-        distances.push(EncryptedNeighbor {
+        *slot = Some(EncryptedNeighbor {
             distance,
             index: point.index.clone(),
         });
-    }
-    println!(); // New line after progress bar
+    });

     let computation_time = computation_start.elapsed();
     println!(
+        "✅ Parallel distance computation completed in {} for {} points",
         format_duration(computation_time),
         total_points
     );

+    // Extract the results; by now every computation is guaranteed to have finished
+    let distances: Vec<EncryptedNeighbor> = distances.into_iter().map(|opt| opt.unwrap()).collect();

     distances
 }
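The rewritten compute_distances keeps the output index-aligned with the input by giving every point a pre-allocated Option slot and filling the slots through par_iter_mut. A minimal standalone sketch of the same pattern, with plain integers standing in for ciphertexts and a squaring as a placeholder for euclidean_distance (rayon as in the diff):

```rust
use rayon::prelude::*;

fn main() {
    let points = vec![3, 1, 4, 1, 5, 9];

    // One pre-allocated slot per input keeps the output in input order,
    // no matter which thread finishes first.
    let mut results: Vec<Option<i64>> = vec![None; points.len()];
    results.par_iter_mut().enumerate().for_each(|(i, slot)| {
        let expensive = (points[i] as i64) * (points[i] as i64); // placeholder for euclidean_distance
        *slot = Some(expensive);
    });

    // Every slot was written exactly once, so unwrap() cannot panic here.
    let results: Vec<i64> = results.into_iter().map(|opt| opt.unwrap()).collect();
    assert_eq!(results, vec![9, 1, 16, 1, 25, 81]);
}
```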
@@ -108,7 +92,7 @@ pub fn perform_knn_selection(
     distances: &mut Vec<EncryptedNeighbor>,
     k: usize,
     algorithm: &str,
-    max_distance: Option<&FheUint12>,
+    max_distance: Option<&FheInt14>,
     max_index: Option<&FheUint8>,
 ) -> Vec<FheUint8> {
     match algorithm {
@@ -175,7 +159,7 @@ pub fn encrypted_selection_sort(distances: &mut Vec<EncryptedNeighbor>, k: usize
 pub fn encrypted_bitonic_sort(
     distances: &mut Vec<EncryptedNeighbor>,
     k: usize,
-    max_distance: &FheUint12,
+    max_distance: &FheInt14,
     max_index: &FheUint8,
 ) {
     println!("🔄 Starting bitonic sort...");
@@ -201,7 +185,7 @@ pub fn encrypted_bitonic_sort(
         }
     }

-    bitonic_sort_encrypted(distances, true);
+    bitonic_sort_encrypted(distances, true); // true = ascending, smallest distances first

     // Keep only the first k results
     distances.truncate(k);
@@ -263,31 +247,22 @@ pub fn encrypted_heap_select(distances: &mut Vec<EncryptedNeighbor>, k: usize) {
 /// * `arr` - The encrypted neighbor array to sort
 /// * `up` - Sort direction: true for ascending, false for descending
 fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) {
-    bitonic_sort_encrypted_recursive(arr, up, 0);
+    bitonic_sort_encrypted_recursive(arr, up);
 }

-/// Recursive implementation of bitonic sort, with depth tracking
-fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool, depth: usize) {
+/// Recursive implementation of bitonic sort
+fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool) {
     if arr.len() <= 1 {
         return;
     }

-    // Only show progress at the topmost level
-    if depth == 0 && arr.len() > 50 {
-        println!(
-            "🔀 Bitonic sort depth {}: processing {} elements",
-            depth,
-            arr.len()
-        );
-    }

     let mid = arr.len() / 2;

     // Run the two recursive calls in parallel - the server key is already set via rayon::broadcast
     let (left, right) = arr.split_at_mut(mid);
     rayon::join(
-        || bitonic_sort_encrypted_recursive(left, true, depth + 1),
-        || bitonic_sort_encrypted_recursive(right, false, depth + 1),
+        || bitonic_sort_encrypted_recursive(left, true),
+        || bitonic_sort_encrypted_recursive(right, false),
     );

     bitonic_merge_encrypted(arr, up);
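With the depth parameter gone, what remains is the textbook bitonic recursion: sort one half ascending, the other descending, then merge, with rayon::join running the two disjoint halves in parallel. A plaintext sketch of the same structure (power-of-two length assumed, as bitonic sort requires; an ordinary swap stands in for whatever encrypted compare-and-swap bitonic_merge_encrypted performs):

```rust
fn bitonic_sort(arr: &mut [i32], up: bool) {
    if arr.len() <= 1 {
        return;
    }
    let mid = arr.len() / 2;
    let (left, right) = arr.split_at_mut(mid);
    // The two halves are disjoint slices, so they can be sorted in parallel.
    rayon::join(|| bitonic_sort(left, true), || bitonic_sort(right, false));
    bitonic_merge(arr, up);
}

fn bitonic_merge(arr: &mut [i32], up: bool) {
    if arr.len() <= 1 {
        return;
    }
    let mid = arr.len() / 2;
    for i in 0..mid {
        // Plain compare-and-swap; an encrypted version cannot branch on
        // ciphertext values and has to do this step obliviously.
        if (arr[i] > arr[i + mid]) == up {
            arr.swap(i, i + mid);
        }
    }
    let (left, right) = arr.split_at_mut(mid);
    bitonic_merge(left, up);
    bitonic_merge(right, up);
}

fn main() {
    let mut v = [5, 3, 8, 1, 9, 2, 7, 4]; // length must be a power of two
    bitonic_sort(&mut v, true);
    assert_eq!(v, [1, 2, 3, 4, 5, 7, 8, 9]);
}
```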
@@ -361,7 +336,7 @@ pub fn perform_hnsw_search(
     graph: &FheHnswGraph,
     query: &EncryptedQuery,
     k: usize,
-    zero: &FheUint12,
+    zero: &FheInt14,
 ) -> Vec<FheUint8> {
     println!("🚀 Starting HNSW approximate search...");

@@ -7,7 +7,7 @@ use std::fs::File;
 use std::io::{BufRead, BufReader, Write};
 use std::time::Instant;
 use tfhe::prelude::*;
-use tfhe::{ConfigBuilder, FheUint12, generate_keys, set_server_key};
+use tfhe::{ConfigBuilder, FheInt14, generate_keys, set_server_key};

 // Import from our library modules
 use hfe_knn::{
@@ -73,8 +73,8 @@ fn debug_compute_distances(

         // Get the plaintext distance for this specific point (before sorting)
         let plaintext_distance = plaintext_distances[i].0;
-        let scaled_distance = u32::scale_value(plaintext_distance) as u16;
-        let encrypted_distance = FheUint12::try_encrypt(scaled_distance, client_key).unwrap();
+        let scaled_distance = i16::scale_value(plaintext_distance);
+        let encrypted_distance = FheInt14::try_encrypt(scaled_distance, client_key).unwrap();

         encrypted_distances.push(EncryptedNeighbor {
             distance: encrypted_distance,
@@ -249,7 +249,7 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
             build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);

         // 3. Run the HNSW search
-        let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap();
+        let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
         perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero)
     } else {
         // Traditional algorithm path
@@ -271,13 +271,13 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
             client_key,
         )
     } else {
-        let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap();
+        let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
         compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
     };

     // Perform KNN selection using the specified algorithm
     let max_distance = if args.algorithm == "bitonic" {
-        Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap())
+        Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // the correct maximum value for FheInt14
     } else {
         None
     };
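The 8191 sentinel is simply the top of a signed 14-bit range (2^13 - 1); the u16::MAX value the old code encrypted fits in neither FheUint12 (max 4095) nor FheInt14. A quick plain-integer check, not from the diff:

```rust
fn main() {
    let bits = 14;
    let max: i32 = (1 << (bits - 1)) - 1; // 2^13 - 1
    let min: i32 = -(1 << (bits - 1)); // -2^13
    assert_eq!((min, max), (-8192, 8191));

    // u16::MAX (65535) is far outside both the old and the new ciphertext range.
    assert!(i64::from(u16::MAX) > i64::from(max));
}
```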

src/bin/plain.rs | 147

@@ -15,15 +15,27 @@ struct Args {
     dataset: String,
     #[arg(long, default_value = "./dataset/answer1.jsonl")]
     predictions: String,
-    #[arg(long, default_value = "8", help = "Max connections per node (M parameter)")]
+    #[arg(
+        long,
+        default_value = "8",
+        help = "Max connections per node (M parameter)"
+    )]
     max_connections: usize,
     #[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
     max_level: usize,
     #[arg(long, default_value = "0.6", help = "Level selection probability")]
     level_prob: f32,
-    #[arg(long, default_value = "1", help = "ef parameter for upper layer search")]
+    #[arg(
+        long,
+        default_value = "1",
+        help = "ef parameter for upper layer search"
+    )]
     ef_upper: usize,
-    #[arg(long, default_value = "10", help = "ef parameter for bottom layer search")]
+    #[arg(
+        long,
+        default_value = "10",
+        help = "ef parameter for bottom layer search"
+    )]
     ef_bottom: usize,
 }

@@ -39,11 +51,18 @@ struct Prediction {
 }

 fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
-    a.iter()
+    // Mimic the scaling and precision loss of the encrypted version
+    let scaled_distance: f64 = a.iter()
         .zip(b.iter())
-        .map(|(x, y)| (x - y).powi(2))
-        .sum::<f64>()
-        .sqrt()
+        .map(|(x, y)| {
+            // Scale the coordinates (multiply by 10, then drop the fractional part)
+            let scaled_x = (x * 10.0).floor();
+            let scaled_y = (y * 10.0).floor();
+            (scaled_x - scaled_y).powi(2)
+        })
+        .sum::<f64>();

+    scaled_distance
 }

 fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
@@ -52,8 +71,12 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
         .enumerate()
         .map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
         .collect();
+    println!();

     distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
+    for distance in distances.iter().take(10) {
+        println!("{},{}", distance.0, distance.1);
+    }

     distances.into_iter().take(k).map(|(_, idx)| idx).collect()
 }
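The reworked plain euclidean_distance now scales each coordinate by 10, floors it, and sums squared differences without taking a square root, so the plaintext baseline exhibits the same precision loss as the encrypted path. A small worked example of that behaviour (the helper name here is illustrative, not from the diff):

```rust
// Mirrors the new plain.rs distance: scale by 10, floor, sum of squared differences, no sqrt.
fn scaled_sq_distance(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let sx = (x * 10.0).floor();
            let sy = (y * 10.0).floor();
            (sx - sy).powi(2)
        })
        .sum()
}

fn main() {
    // 1.26 and 1.21 both floor to 12 after scaling, so the first coordinate
    // contributes 0, exactly the precision loss the encrypted pipeline has.
    let d = scaled_sq_distance(&[1.26, 2.0], &[1.21, 3.5]);
    // Second coordinate: 20 vs 35, so (20 - 35)^2 = 225.
    assert_eq!(d, 225.0);
}
```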
@@ -257,78 +280,78 @@ impl HNSWGraph {
     }
 }

-// fn main() -> Result<()> {
-//     let args = Args::parse();
-//
-//     let file = File::open(&args.dataset)?;
-//     let reader = BufReader::new(file);
-//
-//     let mut results = Vec::new();
-//
-//     for line in reader.lines() {
-//         let line = line?;
-//         let dataset: Dataset = serde_json::from_str(&line)?;
-//
-//         let nearest = knn_classify(&dataset.query, &dataset.data, 10);
-//         results.push(Prediction { answer: nearest });
-//     }
-//
-//     let mut output_file = File::create(&args.predictions)?;
-//     for result in results {
-//         writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
-//     }
-//
-//     Ok(())
-// }
 fn main() -> Result<()> {
     let args = Args::parse();

     let file = File::open(&args.dataset)?;
     let reader = BufReader::new(file);
+    let mut results = Vec::new();

-    println!("🔧 HNSW Parameters:");
-    println!(" max_connections: {}", args.max_connections);
-    println!(" max_level: {}", args.max_level);
-    println!(" level_prob: {}", args.level_prob);
-    println!(" ef_upper: {}", args.ef_upper);
-    println!(" ef_bottom: {}", args.ef_bottom);
-    let mut results = Vec::new();

     for line in reader.lines() {
         let line = line?;
         let dataset: Dataset = serde_json::from_str(&line)?;

-        let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
-        for data_point in &dataset.data {
-            hnsw.insert_node(data_point.clone(), args.level_prob);
-        }

-        let nearest = if let Some(entry_point) = hnsw.entry_point {
-            let mut search_results = vec![entry_point];

-            // Upper-layer search uses the ef_upper parameter
-            for layer in (1..=hnsw.nodes[entry_point].level).rev() {
-                search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
-            }

-            // Bottom-layer search uses the ef_bottom parameter
-            let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
-            final_results
-                .into_iter()
-                .take(10)
-                .map(|idx| idx + 1)
-                .collect()
-        } else {
-            Vec::new()
-        };

+        let nearest = knn_classify(&dataset.query, &dataset.data, 10);
         results.push(Prediction { answer: nearest });
     }

     let mut output_file = File::create(&args.predictions)?;
     for result in results {
         writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
+        println!("{}", serde_json::to_string(&result)?);
     }

     Ok(())
 }
+// fn main() -> Result<()> {
+//     let args = Args::parse();
+//     let file = File::open(&args.dataset)?;
+//     let reader = BufReader::new(file);
+//     let mut results = Vec::new();
+//
+//     println!("🔧 HNSW Parameters:");
+//     println!(" max_connections: {}", args.max_connections);
+//     println!(" max_level: {}", args.max_level);
+//     println!(" level_prob: {}", args.level_prob);
+//     println!(" ef_upper: {}", args.ef_upper);
+//     println!(" ef_bottom: {}", args.ef_bottom);
+//
+//     for line in reader.lines() {
+//         let line = line?;
+//         let dataset: Dataset = serde_json::from_str(&line)?;
+//
+//         let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
+//         for data_point in &dataset.data {
+//             hnsw.insert_node(data_point.clone(), args.level_prob);
+//         }
+//
+//         let nearest = if let Some(entry_point) = hnsw.entry_point {
+//             let mut search_results = vec![entry_point];
+//
+//             // Upper-layer search uses the ef_upper parameter
+//             for layer in (1..=hnsw.nodes[entry_point].level).rev() {
+//                 search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
+//             }
+//
+//             // Bottom-layer search uses the ef_bottom parameter
+//             let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
+//             final_results
+//                 .into_iter()
+//                 .take(10)
+//                 .map(|idx| idx + 1)
+//                 .collect()
+//         } else {
+//             Vec::new()
+//         };
+//
+//         results.push(Prediction { answer: nearest });
+//     }
+//
+//     let mut output_file = File::create(&args.predictions)?;
+//     for result in results {
+//         writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
+//         println!("{}", serde_json::to_string(&result)?);
+//     }
+//
+//     Ok(())
+// }

src/data.rs | 63

@@ -1,7 +1,7 @@
 use serde::{Deserialize, Serialize};
 use std::collections::HashSet;
 use tfhe::prelude::*;
-use tfhe::{ClientKey, FheUint8, FheUint12};
+use tfhe::{ClientKey, FheInt14, FheUint8};

 /// Dataset structure containing the query point and the training data
 #[derive(Deserialize)]
@@ -19,23 +19,23 @@ pub struct Prediction {
 /// Encrypted query point
 #[derive(Clone)]
 pub struct EncryptedQuery {
-    pub coords: Vec<FheUint12>,
-    pub double_coords: Vec<FheUint12>, // 2 * aᵢ
-    pub sum_squares: FheUint12, // Σaᵢ²
+    pub coords: Vec<FheInt14>,
+    pub double_coords: Vec<FheInt14>, // 2 * aᵢ
+    pub sum_squares: FheInt14, // Σaᵢ²
 }

 /// Encrypted data point: coordinates, precomputed values, and index
 #[derive(Clone)]
 pub struct EncryptedPoint {
-    pub coords: Vec<FheUint12>,
-    pub sum_squares: FheUint12, // Σbᵢ²
+    pub coords: Vec<FheInt14>,
+    pub sum_squares: FheInt14, // Σbᵢ²
     pub index: FheUint8,
 }

 /// Encrypted neighbor: distance and index
 #[derive(Clone)]
 pub struct EncryptedNeighbor {
-    pub distance: FheUint12,
+    pub distance: FheInt14,
     pub index: FheUint8,
 }

@@ -50,37 +50,35 @@ pub trait ScaleInt {
     fn scale_value(value: f64) -> Self::Output;
 }

-/// u32 scaling implementation: discards the two decimal places, capped at 4000
-impl ScaleInt for u32 {
-    type Output = u32;
-    fn scale_value(value: f64) -> u32 {
-        let truncated = value.floor() as u32;
-        truncated.min(4000)
+/// i16 scaling implementation: discards the two decimal places, supports negative values
+impl ScaleInt for i16 {
+    type Output = i16;
+    fn scale_value(value: f64) -> i16 {
+        value.floor() as i16 // floor directly, negative values supported
     }
 }

 /// Convert a plaintext query point into an encrypted query point
 pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {
-    let scaled_coords: Vec<u16> = coords
+    let scaled_coords: Vec<i16> = coords
         .iter()
-        .map(|&coord| u32::scale_value(coord) as u16)
+        .map(|&coord| i16::scale_value(coord))
         .collect();

-    let encrypted_coords: Vec<FheUint12> = scaled_coords
+    let encrypted_coords: Vec<FheInt14> = scaled_coords
         .iter()
-        .map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap())
+        .map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap())
        .collect();

     // Precompute 2 * aᵢ
-    let encrypted_double_coords: Vec<FheUint12> = scaled_coords
+    let encrypted_double_coords: Vec<FheInt14> = scaled_coords
         .iter()
-        .map(|&coord| FheUint12::try_encrypt(coord * 2, client_key).unwrap())
+        .map(|&coord| FheInt14::try_encrypt(coord * 2, client_key).unwrap())
         .collect();

     // Precompute Σaᵢ²
-    let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum();
-    let encrypted_sum_squares =
-        FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap();
+    let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum();
+    let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap();

     EncryptedQuery {
         coords: encrypted_coords,
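The new ScaleInt impl for i16 just floors and casts, so negative coordinates keep their sign instead of being floored and clamped to 4000 like the old u32 version. A standalone illustration of that behaviour (the trait is re-declared here only so the snippet compiles on its own):

```rust
trait ScaleInt {
    type Output;
    fn scale_value(value: f64) -> Self::Output;
}

// Mirrors the impl added in src/data.rs: floor, then cast, keeping the sign.
impl ScaleInt for i16 {
    type Output = i16;
    fn scale_value(value: f64) -> i16 {
        value.floor() as i16
    }
}

fn main() {
    assert_eq!(i16::scale_value(12.9), 12);
    assert_eq!(i16::scale_value(-3.1), -4); // floor, not truncation toward zero
    assert_eq!(i16::scale_value(0.0), 0);
}
```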
@@ -91,21 +89,20 @@ pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {

 /// Convert a plaintext data point into an encrypted data point
 pub fn encrypt_point(coords: &[f64], index: usize, client_key: &ClientKey) -> EncryptedPoint {
-    let scaled_coords: Vec<u16> = coords
+    let scaled_coords: Vec<i16> = coords
         .iter()
-        .map(|&coord| u32::scale_value(coord) as u16)
+        .map(|&coord| i16::scale_value(coord))
         .collect();

-    let encrypted_coords: Vec<FheUint12> = scaled_coords
+    let encrypted_coords: Vec<FheInt14> = scaled_coords
         .iter()
-        .map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap())
+        .map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap())
         .collect();

     // Precompute Σbᵢ²
-    let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum();
+    let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum();

-    let encrypted_sum_squares =
-        FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap();
+    let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap();
     let encrypted_index = FheUint8::try_encrypt(index as u8, client_key).unwrap();

     EncryptedPoint {
@@ -183,7 +180,7 @@ impl FheHnswGraph {
         entry_points: Vec<usize>,
         ef: usize,
         layer: usize,
-        _zero: &FheUint12,
+        _zero: &FheInt14,
     ) -> Vec<usize> {
         let _visited: HashSet<usize> = HashSet::new();

@@ -244,6 +241,12 @@ impl FheHnswGraph {
     }
 }

+impl Default for FheHnswGraph {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 /// Helper structure for building the encrypted HNSW graph from plaintext data
 #[derive(Clone)]
 pub struct PlaintextHnswNode {
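The added Default impl simply forwards to new(), which is the usual way to satisfy clippy's new_without_default lint and to let callers write FheHnswGraph::default(). A sketch of the same pattern on a hypothetical stand-in struct (FheHnswGraph's own fields are not shown in this diff):

```rust
// Hypothetical stand-in with the same shape as the impl added for FheHnswGraph.
struct Graph {
    nodes: Vec<usize>,
}

impl Graph {
    fn new() -> Self {
        Graph { nodes: Vec::new() }
    }
}

impl Default for Graph {
    fn default() -> Self {
        Self::new()
    }
}

fn main() {
    // Callers (and #[derive(Default)] on containing structs) can now use ::default().
    let g = Graph::default();
    assert!(g.nodes.is_empty());
}
```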
@@ -1,5 +1,5 @@
-use std::time::Duration;
+use std::io::{self, Write};
+use std::time::Duration;

 /// Format a duration as a human-readable string
 ///
@@ -43,3 +43,4 @@ pub fn print_progress_bar(current: usize, total: usize, operation: &str) {
     );
+    io::stdout().flush().unwrap();
 }
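The flush added here matters because the progress bar is drawn without a trailing newline (elsewhere in the diff a separate println!() supplies the newline afterwards) and Rust's stdout is line-buffered, so the bar text can otherwise sit in the buffer. A tiny sketch of the \r-style pattern, separate from the repo's print_progress_bar:

```rust
use std::io::{self, Write};
use std::thread;
use std::time::Duration;

fn main() {
    for i in 1..=20 {
        // \r returns to the start of the line; without flush() the text may
        // stay in the buffer until a newline is eventually printed.
        print!("\rprogress: {:>3}%", i * 5);
        io::stdout().flush().unwrap();
        thread::sleep(Duration::from_millis(50));
    }
    println!();
}
```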