refactor: migrate from FheUint12 to FheInt14 for better range support

- Replace FheUint12 with FheInt14 across all encryption operations
- Update scaling to use i16 instead of u32 for signed integer support
- Fix FheInt14 max value to 8191 (correct range limit)
- Optimize distance computation with parallel processing using rayon
- Simplify bitonic sort implementation by removing depth tracking
- Update plaintext version to match encrypted scaling behavior
- Add Default trait implementation for FheHnswGraph
This commit is contained in:
2025-08-06 16:26:54 +08:00
parent 4e13fd3023
commit 46b3562de0
5 changed files with 158 additions and 156 deletions

View File

@@ -4,7 +4,7 @@ use crate::logging::{format_duration, print_progress_bar};
use rayon::prelude::*; use rayon::prelude::*;
use std::time::Instant; use std::time::Instant;
use tfhe::prelude::*; use tfhe::prelude::*;
use tfhe::{FheUint8, FheUint12}; use tfhe::{FheInt14, FheUint8};
/// 优化的欧几里得距离计算:使用预计算的平方和 /// 优化的欧几里得距离计算:使用预计算的平方和
/// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ) /// ||a-b||² = Σaᵢ² + Σbᵢ² - Σ((2*aᵢ)*bᵢ)
@@ -19,8 +19,8 @@ use tfhe::{FheUint8, FheUint12};
pub fn euclidean_distance( pub fn euclidean_distance(
query: &EncryptedQuery, query: &EncryptedQuery,
point: &EncryptedPoint, point: &EncryptedPoint,
zero: &FheUint12, zero: &FheInt14,
) -> FheUint12 { ) -> FheInt14 {
// 计算 Σ((2*aᵢ)*bᵢ) // 计算 Σ((2*aᵢ)*bᵢ)
let mut cross_product_sum = zero.clone(); let mut cross_product_sum = zero.clone();
for (double_a, b) in query.double_coords.iter().zip(&point.coords) { for (double_a, b) in query.double_coords.iter().zip(&point.coords) {
@@ -45,50 +45,34 @@ pub fn euclidean_distance(
pub fn compute_distances( pub fn compute_distances(
query: &EncryptedQuery, query: &EncryptedQuery,
points: &[EncryptedPoint], points: &[EncryptedPoint],
zero: &FheUint12, zero: &FheInt14,
) -> Vec<EncryptedNeighbor> { ) -> Vec<EncryptedNeighbor> {
println!("🔢 Computing encrypted distances..."); println!("🔢 Computing encrypted distances in parallel...");
println!("🏭 Pre-encrypting constants for optimization...");
let total_points = points.len(); let total_points = points.len();
let mut distances = Vec::with_capacity(total_points); let computation_start = Instant::now();
for (i, point) in points.iter().enumerate() { // 确保并行计算结果顺序一致性
print_progress_bar(i + 1, total_points, "Distance calculation"); let mut distances = vec![None; total_points];
let dist_start = Instant::now();
distances.par_iter_mut().enumerate().for_each(|(i, slot)| {
let point = &points[i];
let distance = euclidean_distance(query, point, zero); let distance = euclidean_distance(query, point, zero);
let dist_time = dist_start.elapsed(); *slot = Some(EncryptedNeighbor {
if i == 0 {
let estimated_total = dist_time.as_secs_f64() * total_points as f64;
println!(
"\n⏱️ First distance took {:.1}s, estimated total: {:.1} minutes",
dist_time.as_secs_f64(),
estimated_total / 60.0
);
}
// Show progress every 25 points for long calculations
if (i + 1) % 25 == 0 && i > 0 {
let elapsed = dist_time.as_secs_f64() * (i + 1) as f64;
let remaining_points = total_points - i - 1;
let estimated_remaining = dist_time.as_secs_f64() * remaining_points as f64;
println!(
"\n🕐 Completed {}/{} distances in {:.1}m, estimated {:.1}m remaining",
i + 1,
total_points,
elapsed / 60.0,
estimated_remaining / 60.0
);
}
distances.push(EncryptedNeighbor {
distance, distance,
index: point.index.clone(), index: point.index.clone(),
}); });
} });
println!(); // New line after progress bar
let computation_time = computation_start.elapsed();
println!(
"✅ Parallel distance computation completed in {} for {} points",
format_duration(computation_time),
total_points
);
// 提取结果,确保所有计算都完成
let distances: Vec<EncryptedNeighbor> = distances.into_iter().map(|opt| opt.unwrap()).collect();
distances distances
} }
@@ -108,7 +92,7 @@ pub fn perform_knn_selection(
distances: &mut Vec<EncryptedNeighbor>, distances: &mut Vec<EncryptedNeighbor>,
k: usize, k: usize,
algorithm: &str, algorithm: &str,
max_distance: Option<&FheUint12>, max_distance: Option<&FheInt14>,
max_index: Option<&FheUint8>, max_index: Option<&FheUint8>,
) -> Vec<FheUint8> { ) -> Vec<FheUint8> {
match algorithm { match algorithm {
@@ -175,7 +159,7 @@ pub fn encrypted_selection_sort(distances: &mut Vec<EncryptedNeighbor>, k: usize
pub fn encrypted_bitonic_sort( pub fn encrypted_bitonic_sort(
distances: &mut Vec<EncryptedNeighbor>, distances: &mut Vec<EncryptedNeighbor>,
k: usize, k: usize,
max_distance: &FheUint12, max_distance: &FheInt14,
max_index: &FheUint8, max_index: &FheUint8,
) { ) {
println!("🔄 Starting bitonic sort..."); println!("🔄 Starting bitonic sort...");
@@ -201,7 +185,7 @@ pub fn encrypted_bitonic_sort(
} }
} }
bitonic_sort_encrypted(distances, true); bitonic_sort_encrypted(distances, true); // true表示升序最小距离在前
// 只保留前k个结果 // 只保留前k个结果
distances.truncate(k); distances.truncate(k);
@@ -263,31 +247,22 @@ pub fn encrypted_heap_select(distances: &mut Vec<EncryptedNeighbor>, k: usize) {
/// * `arr` - 待排序的加密邻居数组 /// * `arr` - 待排序的加密邻居数组
/// * `up` - 排序方向true为升序false为降序 /// * `up` - 排序方向true为升序false为降序
fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) { fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) {
bitonic_sort_encrypted_recursive(arr, up, 0); bitonic_sort_encrypted_recursive(arr, up);
} }
/// 双调排序的递归实现,带深度跟踪 /// 双调排序的递归实现
fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool, depth: usize) { fn bitonic_sort_encrypted_recursive(arr: &mut [EncryptedNeighbor], up: bool) {
if arr.len() <= 1 { if arr.len() <= 1 {
return; return;
} }
// 只在最顶层显示进度
if depth == 0 && arr.len() > 50 {
println!(
"🔀 Bitonic sort depth {}: processing {} elements",
depth,
arr.len()
);
}
let mid = arr.len() / 2; let mid = arr.len() / 2;
// 并行执行两个递归调用 - server key已通过rayon::broadcast设置 // 并行执行两个递归调用 - server key已通过rayon::broadcast设置
let (left, right) = arr.split_at_mut(mid); let (left, right) = arr.split_at_mut(mid);
rayon::join( rayon::join(
|| bitonic_sort_encrypted_recursive(left, true, depth + 1), || bitonic_sort_encrypted_recursive(left, true),
|| bitonic_sort_encrypted_recursive(right, false, depth + 1), || bitonic_sort_encrypted_recursive(right, false),
); );
bitonic_merge_encrypted(arr, up); bitonic_merge_encrypted(arr, up);
@@ -361,7 +336,7 @@ pub fn perform_hnsw_search(
graph: &FheHnswGraph, graph: &FheHnswGraph,
query: &EncryptedQuery, query: &EncryptedQuery,
k: usize, k: usize,
zero: &FheUint12, zero: &FheInt14,
) -> Vec<FheUint8> { ) -> Vec<FheUint8> {
println!("🚀 Starting HNSW approximate search..."); println!("🚀 Starting HNSW approximate search...");

View File

@@ -7,7 +7,7 @@ use std::fs::File;
use std::io::{BufRead, BufReader, Write}; use std::io::{BufRead, BufReader, Write};
use std::time::Instant; use std::time::Instant;
use tfhe::prelude::*; use tfhe::prelude::*;
use tfhe::{ConfigBuilder, FheUint12, generate_keys, set_server_key}; use tfhe::{ConfigBuilder, FheInt14, generate_keys, set_server_key};
// Import from our library modules // Import from our library modules
use hfe_knn::{ use hfe_knn::{
@@ -73,8 +73,8 @@ fn debug_compute_distances(
// Get the plaintext distance for this specific point (before sorting) // Get the plaintext distance for this specific point (before sorting)
let plaintext_distance = plaintext_distances[i].0; let plaintext_distance = plaintext_distances[i].0;
let scaled_distance = u32::scale_value(plaintext_distance) as u16; let scaled_distance = i16::scale_value(plaintext_distance);
let encrypted_distance = FheUint12::try_encrypt(scaled_distance, client_key).unwrap(); let encrypted_distance = FheInt14::try_encrypt(scaled_distance, client_key).unwrap();
encrypted_distances.push(EncryptedNeighbor { encrypted_distances.push(EncryptedNeighbor {
distance: encrypted_distance, distance: encrypted_distance,
@@ -249,7 +249,7 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key); build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
// 3. 执行HNSW搜索 // 3. 执行HNSW搜索
let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap(); let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero) perform_hnsw_search(&fhe_graph, &query_encrypted, k, &encrypted_zero)
} else { } else {
// 传统算法路径 // 传统算法路径
@@ -271,13 +271,13 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
client_key, client_key,
) )
} else { } else {
let encrypted_zero = FheUint12::try_encrypt(0u16, client_key).unwrap(); let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero) compute_distances(&query_encrypted, &points_encrypted, &encrypted_zero)
}; };
// Perform KNN selection using the specified algorithm // Perform KNN selection using the specified algorithm
let max_distance = if args.algorithm == "bitonic" { let max_distance = if args.algorithm == "bitonic" {
Some(tfhe::FheUint12::try_encrypt(u16::MAX, client_key).unwrap()) Some(FheInt14::try_encrypt(8191i16, client_key).unwrap()) // FheInt14的正确最大值
} else { } else {
None None
}; };

View File

@@ -15,15 +15,27 @@ struct Args {
dataset: String, dataset: String,
#[arg(long, default_value = "./dataset/answer1.jsonl")] #[arg(long, default_value = "./dataset/answer1.jsonl")]
predictions: String, predictions: String,
#[arg(long, default_value = "8", help = "Max connections per node (M parameter)")] #[arg(
long,
default_value = "8",
help = "Max connections per node (M parameter)"
)]
max_connections: usize, max_connections: usize,
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")] #[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
max_level: usize, max_level: usize,
#[arg(long, default_value = "0.6", help = "Level selection probability")] #[arg(long, default_value = "0.6", help = "Level selection probability")]
level_prob: f32, level_prob: f32,
#[arg(long, default_value = "1", help = "ef parameter for upper layer search")] #[arg(
long,
default_value = "1",
help = "ef parameter for upper layer search"
)]
ef_upper: usize, ef_upper: usize,
#[arg(long, default_value = "10", help = "ef parameter for bottom layer search")] #[arg(
long,
default_value = "10",
help = "ef parameter for bottom layer search"
)]
ef_bottom: usize, ef_bottom: usize,
} }
@@ -39,11 +51,18 @@ struct Prediction {
} }
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
a.iter() // 模拟加密版本的缩放和精度损失
let scaled_distance: f64 = a.iter()
.zip(b.iter()) .zip(b.iter())
.map(|(x, y)| (x - y).powi(2)) .map(|(x, y)| {
.sum::<f64>() // 缩放坐标乘以10然后舍弃小数
.sqrt() let scaled_x = (x * 10.0).floor();
let scaled_y = (y * 10.0).floor();
(scaled_x - scaled_y).powi(2)
})
.sum::<f64>();
scaled_distance
} }
fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> { fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
@@ -52,8 +71,12 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
.enumerate() .enumerate()
.map(|(idx, point)| (euclidean_distance(query, point), idx + 1)) .map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
.collect(); .collect();
println!();
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
for distance in distances.iter().take(10) {
println!("{},{}", distance.0, distance.1);
}
distances.into_iter().take(k).map(|(_, idx)| idx).collect() distances.into_iter().take(k).map(|(_, idx)| idx).collect()
} }
@@ -257,78 +280,78 @@ impl HNSWGraph {
} }
} }
// fn main() -> Result<()> {
// let args = Args::parse();
//
// let file = File::open(&args.dataset)?;
// let reader = BufReader::new(file);
//
// let mut results = Vec::new();
//
// for line in reader.lines() {
// let line = line?;
// let dataset: Dataset = serde_json::from_str(&line)?;
//
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
// results.push(Prediction { answer: nearest });
// }
//
// let mut output_file = File::create(&args.predictions)?;
// for result in results {
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
// }
//
// Ok(())
// }
fn main() -> Result<()> { fn main() -> Result<()> {
let args = Args::parse(); let args = Args::parse();
let file = File::open(&args.dataset)?; let file = File::open(&args.dataset)?;
let reader = BufReader::new(file); let reader = BufReader::new(file);
let mut results = Vec::new();
println!("🔧 HNSW Parameters:"); let mut results = Vec::new();
println!(" max_connections: {}", args.max_connections);
println!(" max_level: {}", args.max_level);
println!(" level_prob: {}", args.level_prob);
println!(" ef_upper: {}", args.ef_upper);
println!(" ef_bottom: {}", args.ef_bottom);
for line in reader.lines() { for line in reader.lines() {
let line = line?; let line = line?;
let dataset: Dataset = serde_json::from_str(&line)?; let dataset: Dataset = serde_json::from_str(&line)?;
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections); let nearest = knn_classify(&dataset.query, &dataset.data, 10);
for data_point in &dataset.data {
hnsw.insert_node(data_point.clone(), args.level_prob);
}
let nearest = if let Some(entry_point) = hnsw.entry_point {
let mut search_results = vec![entry_point];
// 上层搜索使用ef_upper参数
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
}
// 底层搜索使用ef_bottom参数
let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
final_results
.into_iter()
.take(10)
.map(|idx| idx + 1)
.collect()
} else {
Vec::new()
};
results.push(Prediction { answer: nearest }); results.push(Prediction { answer: nearest });
} }
let mut output_file = File::create(&args.predictions)?; let mut output_file = File::create(&args.predictions)?;
for result in results { for result in results {
writeln!(output_file, "{}", serde_json::to_string(&result)?)?; writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
println!("{}", serde_json::to_string(&result)?);
} }
Ok(()) Ok(())
} }
// fn main() -> Result<()> {
// let args = Args::parse();
// let file = File::open(&args.dataset)?;
// let reader = BufReader::new(file);
// let mut results = Vec::new();
//
// println!("🔧 HNSW Parameters:");
// println!(" max_connections: {}", args.max_connections);
// println!(" max_level: {}", args.max_level);
// println!(" level_prob: {}", args.level_prob);
// println!(" ef_upper: {}", args.ef_upper);
// println!(" ef_bottom: {}", args.ef_bottom);
//
// for line in reader.lines() {
// let line = line?;
// let dataset: Dataset = serde_json::from_str(&line)?;
//
// let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
// for data_point in &dataset.data {
// hnsw.insert_node(data_point.clone(), args.level_prob);
// }
//
// let nearest = if let Some(entry_point) = hnsw.entry_point {
// let mut search_results = vec![entry_point];
//
// // 上层搜索使用ef_upper参数
// for layer in (1..=hnsw.nodes[entry_point].level).rev() {
// search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
// }
//
// // 底层搜索使用ef_bottom参数
// let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
// final_results
// .into_iter()
// .take(10)
// .map(|idx| idx + 1)
// .collect()
// } else {
// Vec::new()
// };
//
// results.push(Prediction { answer: nearest });
// }
//
// let mut output_file = File::create(&args.predictions)?;
// for result in results {
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
// println!("{}", serde_json::to_string(&result)?);
// }
//
// Ok(())
// }

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashSet; use std::collections::HashSet;
use tfhe::prelude::*; use tfhe::prelude::*;
use tfhe::{ClientKey, FheUint8, FheUint12}; use tfhe::{ClientKey, FheInt14, FheUint8};
/// 数据集结构,包含查询点和训练数据 /// 数据集结构,包含查询点和训练数据
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -19,23 +19,23 @@ pub struct Prediction {
/// 加密的查询点 /// 加密的查询点
#[derive(Clone)] #[derive(Clone)]
pub struct EncryptedQuery { pub struct EncryptedQuery {
pub coords: Vec<FheUint12>, pub coords: Vec<FheInt14>,
pub double_coords: Vec<FheUint12>, // 2 * aᵢ pub double_coords: Vec<FheInt14>, // 2 * aᵢ
pub sum_squares: FheUint12, // Σaᵢ² pub sum_squares: FheInt14, // Σaᵢ²
} }
/// 加密的数据点,包含坐标、预计算值和索引 /// 加密的数据点,包含坐标、预计算值和索引
#[derive(Clone)] #[derive(Clone)]
pub struct EncryptedPoint { pub struct EncryptedPoint {
pub coords: Vec<FheUint12>, pub coords: Vec<FheInt14>,
pub sum_squares: FheUint12, // Σbᵢ² pub sum_squares: FheInt14, // Σbᵢ²
pub index: FheUint8, pub index: FheUint8,
} }
/// 加密的邻居点,包含距离和索引 /// 加密的邻居点,包含距离和索引
#[derive(Clone)] #[derive(Clone)]
pub struct EncryptedNeighbor { pub struct EncryptedNeighbor {
pub distance: FheUint12, pub distance: FheInt14,
pub index: FheUint8, pub index: FheUint8,
} }
@@ -50,37 +50,35 @@ pub trait ScaleInt {
fn scale_value(value: f64) -> Self::Output; fn scale_value(value: f64) -> Self::Output;
} }
/// u32缩放实现舍弃2位小数限制在4000以内 /// i16缩放实现舍弃2位小数支持负数
impl ScaleInt for u32 { impl ScaleInt for i16 {
type Output = u32; type Output = i16;
fn scale_value(value: f64) -> u32 { fn scale_value(value: f64) -> i16 {
let truncated = value.floor() as u32; value.floor() as i16 // 直接取整,支持负数
truncated.min(4000)
} }
} }
/// 将明文查询点转换为加密查询点 /// 将明文查询点转换为加密查询点
pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery { pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {
let scaled_coords: Vec<u16> = coords let scaled_coords: Vec<i16> = coords
.iter() .iter()
.map(|&coord| u32::scale_value(coord) as u16) .map(|&coord| i16::scale_value(coord))
.collect(); .collect();
let encrypted_coords: Vec<FheUint12> = scaled_coords let encrypted_coords: Vec<FheInt14> = scaled_coords
.iter() .iter()
.map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap()) .map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap())
.collect(); .collect();
// 预计算 2 * aᵢ // 预计算 2 * aᵢ
let encrypted_double_coords: Vec<FheUint12> = scaled_coords let encrypted_double_coords: Vec<FheInt14> = scaled_coords
.iter() .iter()
.map(|&coord| FheUint12::try_encrypt(coord * 2, client_key).unwrap()) .map(|&coord| FheInt14::try_encrypt(coord * 2, client_key).unwrap())
.collect(); .collect();
// 预计算 Σaᵢ² // 预计算 Σaᵢ²
let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum(); let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum();
let encrypted_sum_squares = let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap();
FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap();
EncryptedQuery { EncryptedQuery {
coords: encrypted_coords, coords: encrypted_coords,
@@ -91,21 +89,20 @@ pub fn encrypt_query(coords: &[f64], client_key: &ClientKey) -> EncryptedQuery {
/// 将明文数据点转换为加密数据点 /// 将明文数据点转换为加密数据点
pub fn encrypt_point(coords: &[f64], index: usize, client_key: &ClientKey) -> EncryptedPoint { pub fn encrypt_point(coords: &[f64], index: usize, client_key: &ClientKey) -> EncryptedPoint {
let scaled_coords: Vec<u16> = coords let scaled_coords: Vec<i16> = coords
.iter() .iter()
.map(|&coord| u32::scale_value(coord) as u16) .map(|&coord| i16::scale_value(coord))
.collect(); .collect();
let encrypted_coords: Vec<FheUint12> = scaled_coords let encrypted_coords: Vec<FheInt14> = scaled_coords
.iter() .iter()
.map(|&coord| FheUint12::try_encrypt(coord, client_key).unwrap()) .map(|&coord| FheInt14::try_encrypt(coord, client_key).unwrap())
.collect(); .collect();
// 预计算 Σbᵢ² // 预计算 Σbᵢ²
let sum_squares: u32 = scaled_coords.iter().map(|&x| (x as u32) * (x as u32)).sum(); let sum_squares: i32 = scaled_coords.iter().map(|&x| (x as i32) * (x as i32)).sum();
let encrypted_sum_squares = let encrypted_sum_squares = FheInt14::try_encrypt(sum_squares as i16, client_key).unwrap();
FheUint12::try_encrypt((sum_squares.min(4095)) as u16, client_key).unwrap();
let encrypted_index = FheUint8::try_encrypt(index as u8, client_key).unwrap(); let encrypted_index = FheUint8::try_encrypt(index as u8, client_key).unwrap();
EncryptedPoint { EncryptedPoint {
@@ -183,7 +180,7 @@ impl FheHnswGraph {
entry_points: Vec<usize>, entry_points: Vec<usize>,
ef: usize, ef: usize,
layer: usize, layer: usize,
_zero: &FheUint12, _zero: &FheInt14,
) -> Vec<usize> { ) -> Vec<usize> {
let _visited: HashSet<usize> = HashSet::new(); let _visited: HashSet<usize> = HashSet::new();
@@ -244,6 +241,12 @@ impl FheHnswGraph {
} }
} }
impl Default for FheHnswGraph {
fn default() -> Self {
Self::new()
}
}
/// 从明文数据构建密文 HNSW 图的辅助结构 /// 从明文数据构建密文 HNSW 图的辅助结构
#[derive(Clone)] #[derive(Clone)]
pub struct PlaintextHnswNode { pub struct PlaintextHnswNode {

View File

@@ -1,5 +1,5 @@
use std::time::Duration;
use std::io::{self, Write}; use std::io::{self, Write};
use std::time::Duration;
/// 格式化时间长度为人类可读的字符串 /// 格式化时间长度为人类可读的字符串
/// ///
@@ -42,4 +42,5 @@ pub fn print_progress_bar(current: usize, total: usize, operation: &str) {
percentage percentage
); );
io::stdout().flush().unwrap(); io::stdout().flush().unwrap();
} }