diff --git a/src/bin/enc.rs b/src/bin/enc.rs new file mode 100644 index 0000000..d88f21a --- /dev/null +++ b/src/bin/enc.rs @@ -0,0 +1,431 @@ +use anyhow::Result; +use clap::Parser; +use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::{BufRead, BufReader, Write}; +use std::time::{Duration, Instant}; +use tfhe::prelude::*; +use tfhe::{ConfigBuilder, FheInt32, FheUint8, generate_keys, set_server_key}; + +#[derive(Parser)] +#[command(name = "hfe_knn_enc")] +#[command(about = "FHE-based KNN classifier using encryption")] +struct Args { + #[arg(long)] + dataset: String, + #[arg(long)] + predictions: String, +} + +#[derive(Deserialize)] +struct Dataset { + query: Vec, + data: Vec>, +} + +#[derive(Serialize)] +struct Prediction { + answer: Vec, +} + +#[derive(Clone)] +struct EncryptedPoint { + coords: Vec, + index: FheUint8, +} + +#[derive(Clone)] +struct EncryptedNeighbor { + distance: FheInt32, + index: FheUint8, +} + +/// 将浮点数值缩放100倍并转换为i32,用于FHE整数运算 +/// +/// # Arguments +/// * `value` - 原始浮点数值 +/// +/// # Returns +/// * 缩放后的i32值 +fn scale_to_i32(value: f64) -> i32 { + (value * 100.0) as i32 +} + +/// 将数据集中的查询和训练点转换为缩放后的i32向量 +/// +/// # Arguments +/// * `data` - 包含查询和训练数据的数据集 +/// +/// # Returns +/// * 元组:(查询向量, 训练点向量列表) +fn convert_data_to_i32(data: &Dataset) -> (Vec, Vec>) { + let query: Vec = data.query.iter().map(|&x| scale_to_i32(x)).collect(); + + let points: Vec> = data + .data + .iter() + .map(|point| point.iter().map(|&x| scale_to_i32(x)).collect()) + .collect(); + + (query, points) +} + +/// 使用客户端密钥加密查询向量 +/// +/// # Arguments +/// * `query` - 待加密的查询向量 +/// * `client_key` - TFHE客户端密钥 +/// +/// # Returns +/// * 加密后的查询向量或错误 +fn encrypt_query(query: &[i32], client_key: &tfhe::ClientKey) -> Result> { + query + .iter() + .map(|&x| FheInt32::try_encrypt(x, client_key)) + .collect::, _>>() + .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e)) +} + +/// 加密训练点集合,包含坐标和索引,显示进度条 +/// +/// # Arguments +/// * `points` - 训练点向量列表 +/// * `client_key` - TFHE客户端密钥 +/// +/// # Returns +/// * 加密后的训练点列表或错误 +fn encrypt_points( + points: &[Vec], + client_key: &tfhe::ClientKey, +) -> Result> { + let total_points = points.len(); + let mut encrypted_points = Vec::new(); + + for (i, point) in points.iter().enumerate() { + print_progress_bar(i + 1, total_points, "Encrypting points"); + + let encrypted_coords: Result, _> = point + .iter() + .map(|&x| FheInt32::try_encrypt(x, client_key)) + .collect(); + + let encrypted_index = FheUint8::try_encrypt((i + 1) as u8, client_key) + .map_err(|e| anyhow::anyhow!("Index encryption failed: {}", e))?; + + match encrypted_coords { + Ok(coords) => encrypted_points.push(EncryptedPoint { + coords, + index: encrypted_index, + }), + Err(e) => return Err(anyhow::anyhow!("Encryption failed: {}", e)), + } + } + + println!(); // New line after progress bar + Ok(encrypted_points) +} + +/// 计算两个加密向量间的欧氏距离平方 +/// +/// # Arguments +/// * `p1` - 第一个加密向量 +/// * `p2` - 第二个加密向量 +/// * `encrypted_zero` - 预加密的0值,用于优化性能 +/// +/// # Returns +/// * 加密的距离平方值 +fn encrypted_euclidean_distance_squared( + p1: &[FheInt32], + p2: &[FheInt32], + encrypted_zero: &FheInt32, +) -> FheInt32 { + let mut sum = encrypted_zero.clone(); + + for (x1, x2) in p1.iter().zip(p2.iter()) { + let diff = x1 - x2; + let squared = &diff * &diff; + sum += squared; + } + + sum +} + +/// 执行加密KNN分类,返回k个最近邻的加密索引 +/// +/// # Arguments +/// * `query` - 加密的查询向量 +/// * `points` - 加密的训练点集合 +/// * `k` - 返回的最近邻数量 +/// * `client_key` - TFHE客户端密钥(用于创建加密常量) +/// +/// # Returns +/// * k个最近邻的加密索引列表 +fn encrypted_knn_classify( + query: &[FheInt32], + points: &[EncryptedPoint], + k: usize, + client_key: &tfhe::ClientKey, +) -> Vec { + println!("🔢 Computing encrypted distances..."); + println!("🏭 Pre-encrypting constants for optimization..."); + let encrypted_zero = FheInt32::try_encrypt(0i32, client_key).unwrap(); + + let total_points = points.len(); + let mut distances = Vec::with_capacity(total_points); // Pre-allocate capacity + + for (i, point) in points.iter().enumerate() { + print_progress_bar(i + 1, total_points, "Distance calculation"); + let dist_start = Instant::now(); + let distance = encrypted_euclidean_distance_squared(query, &point.coords, &encrypted_zero); + let dist_time = dist_start.elapsed(); + + if i == 0 { + let estimated_total = dist_time.as_secs_f64() * total_points as f64; + println!( + "\n⏱️ First distance took {:.1}s, estimated total: {:.1} minutes", + dist_time.as_secs_f64(), + estimated_total / 60.0 + ); + println!("💾 Memory optimization: reusing encrypted constants"); + } + + distances.push(EncryptedNeighbor { + distance, + index: point.index.clone(), + }); + } + println!(); // New line after progress bar + + println!("📊 Sorting {} distances...", distances.len()); + bitonic_sort_encrypted(&mut distances, true); + distances.truncate(k); + + distances.iter().map(|n| n.index.clone()).collect() +} + +/// 解密KNN结果,将加密索引转换为明文索引 +/// +/// # Arguments +/// * `encrypted_result` - 加密的索引列表 +/// * `client_key` - TFHE客户端密钥 +/// +/// # Returns +/// * 解密后的索引列表 +fn decrypt_knn_result(encrypted_result: &[FheUint8], client_key: &tfhe::ClientKey) -> Vec { + encrypted_result + .iter() + .map(|encrypted_index| { + let decrypted: u8 = encrypted_index.decrypt(client_key); + decrypted as usize + }) + .collect() +} + +/// 格式化时间持续时长为可读字符串 +/// +/// # Arguments +/// * `duration` - 时间持续时长 +/// +/// # Returns +/// * 格式化的时间字符串(如 "5m 30s" 或 "45s") +fn format_duration(duration: Duration) -> String { + let total_secs = duration.as_secs(); + let mins = total_secs / 60; + let secs = total_secs % 60; + + if mins > 0 { + format!("{mins}m {secs}s") + } else { + format!("{secs}s") + } +} + +/// 在终端显示进度条 +/// +/// # Arguments +/// * `current` - 当前进度数量 +/// * `total` - 总数量 +/// * `operation` - 操作描述 +fn print_progress_bar(current: usize, total: usize, operation: &str) { + let percentage = (current as f64 / total as f64 * 100.0) as usize; + let bar_length = 40; + let filled = (current * bar_length / total).min(bar_length); + let empty = bar_length - filled; + + print!( + "\r{}: [{}{}] {}/{} {}%", + operation, + "█".repeat(filled), + "░".repeat(empty), + current, + total, + percentage + ); + std::io::stdout().flush().unwrap(); +} + +/// 对加密邻居数组执行双调排序 +/// +/// # Arguments +/// * `arr` - 待排序的加密邻居数组 +/// * `up` - 排序方向,true为升序,false为降序 +fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) { + if arr.len() <= 1 { + return; + } + + let mid = arr.len() / 2; + bitonic_sort_encrypted(&mut arr[..mid], true); + bitonic_sort_encrypted(&mut arr[mid..], false); + bitonic_merge_encrypted(arr, up); +} + +/// 双调排序的合并阶段,使用加密条件交换 +/// +/// # Arguments +/// * `arr` - 待合并的加密邻居数组 +/// * `up` - 合并方向,true为升序,false为降序 +fn bitonic_merge_encrypted(arr: &mut [EncryptedNeighbor], up: bool) { + if arr.len() <= 1 { + return; + } + + let mid = arr.len() / 2; + for i in 0..mid { + let should_swap = if up { + arr[i].distance.gt(&arr[i + mid].distance) + } else { + arr[i].distance.lt(&arr[i + mid].distance) + }; + + let (left, right) = arr.split_at_mut(mid); + encrypted_conditional_swap(&mut left[i], &mut right[i], &should_swap); + } + + bitonic_merge_encrypted(&mut arr[..mid], up); + bitonic_merge_encrypted(&mut arr[mid..], up); +} + +/// 基于加密条件交换两个加密邻居的距离和索引 +/// +/// # Arguments +/// * `a` - 第一个加密邻居 +/// * `b` - 第二个加密邻居 +/// * `condition` - 加密的交换条件 +fn encrypted_conditional_swap( + a: &mut EncryptedNeighbor, + b: &mut EncryptedNeighbor, + condition: &tfhe::FheBool, +) { + let new_a_distance = condition.select(&b.distance, &a.distance); + let new_b_distance = condition.select(&a.distance, &b.distance); + let new_a_index = condition.select(&b.index, &a.index); + let new_b_index = condition.select(&a.index, &b.index); + + a.distance = new_a_distance; + b.distance = new_b_distance; + a.index = new_a_index; + b.index = new_b_index; +} + +fn main() -> Result<()> { + let args = Args::parse(); + let start_time = Instant::now(); + + println!("🔧 Setting up TFHE parameters..."); + let setup_start = Instant::now(); + let config = ConfigBuilder::default().build(); + let (client_key, server_key) = generate_keys(config); + set_server_key(server_key); + println!( + "✅ TFHE setup completed in {}", + format_duration(setup_start.elapsed()) + ); + + let file = File::open(&args.dataset)?; + let reader = BufReader::new(file); + let mut results = Vec::new(); + let mut query_count = 0; + + for line in reader.lines() { + let line = line?; + let dataset: Dataset = serde_json::from_str(&line)?; + query_count += 1; + + println!( + "\n📊 Processing query #{} with {} training points", + query_count, + dataset.data.len() + ); + let query_start = Instant::now(); + + // Data conversion + let convert_start = Instant::now(); + let (query, points) = convert_data_to_i32(&dataset); + println!( + "✅ Data converted to i32 (100x scaling) in {}", + format_duration(convert_start.elapsed()) + ); + + // Query encryption + let query_enc_start = Instant::now(); + let encrypted_query = encrypt_query(&query, &client_key)?; + println!( + "✅ Query encrypted in {}", + format_duration(query_enc_start.elapsed()) + ); + + // Points encryption with progress + let points_enc_start = Instant::now(); + let encrypted_points = encrypt_points(&points, &client_key)?; + println!( + "✅ Points encrypted in {}", + format_duration(points_enc_start.elapsed()) + ); + + // KNN computation + println!("🔍 Running encrypted KNN search (k=10)..."); + println!("⏱️ Estimated time: ~8 minutes for {} points", points.len()); + let knn_start = Instant::now(); + let encrypted_result = + encrypted_knn_classify(&encrypted_query, &encrypted_points, 10, &client_key); + let knn_time = knn_start.elapsed(); + println!("✅ KNN search completed in {}", format_duration(knn_time)); + + // Decryption + let decrypt_start = Instant::now(); + let nearest = decrypt_knn_result(&encrypted_result, &client_key); + println!( + "✅ Results decrypted in {}", + format_duration(decrypt_start.elapsed()) + ); + + let query_total = query_start.elapsed(); + println!( + "📈 Query #{} completed in {}", + query_count, + format_duration(query_total) + ); + + results.push(Prediction { answer: nearest }); + } + + // Save results + let save_start = Instant::now(); + let mut output_file = File::create(&args.predictions)?; + for result in results { + writeln!(output_file, "{}", serde_json::to_string(&result)?)?; + } + println!( + "✅ Results saved to {} in {}", + args.predictions, + format_duration(save_start.elapsed()) + ); + + let total_time = start_time.elapsed(); + println!("\n🎉 All operations completed!"); + println!("📊 Total queries processed: {query_count}"); + println!("⏱️ Total execution time: {}", format_duration(total_time)); + + Ok(()) +}