use anyhow::Result; use clap::Parser; use serde::{Deserialize, Serialize}; use std::fs::File; use std::io::{BufRead, BufReader, Write}; use std::time::{Duration, Instant}; use tfhe::prelude::*; use tfhe::{ConfigBuilder, FheInt32, FheUint8, generate_keys, set_server_key}; #[derive(Parser)] #[command(name = "hfe_knn_enc")] #[command(about = "FHE-based KNN classifier using encryption")] struct Args { #[arg(long)] dataset: String, #[arg(long)] predictions: String, } #[derive(Deserialize)] struct Dataset { query: Vec, data: Vec>, } #[derive(Serialize)] struct Prediction { answer: Vec, } #[derive(Clone)] struct EncryptedPoint { coords: Vec, index: FheUint8, } #[derive(Clone)] struct EncryptedNeighbor { distance: FheInt32, index: FheUint8, } /// 将浮点数值缩放100倍并转换为i32,用于FHE整数运算 /// /// # Arguments /// * `value` - 原始浮点数值 /// /// # Returns /// * 缩放后的i32值 fn scale_to_i32(value: f64) -> i32 { (value * 100.0) as i32 } /// 将数据集中的查询和训练点转换为缩放后的i32向量 /// /// # Arguments /// * `data` - 包含查询和训练数据的数据集 /// /// # Returns /// * 元组:(查询向量, 训练点向量列表) fn convert_data_to_i32(data: &Dataset) -> (Vec, Vec>) { let query: Vec = data.query.iter().map(|&x| scale_to_i32(x)).collect(); let points: Vec> = data .data .iter() .map(|point| point.iter().map(|&x| scale_to_i32(x)).collect()) .collect(); (query, points) } /// 使用客户端密钥加密查询向量 /// /// # Arguments /// * `query` - 待加密的查询向量 /// * `client_key` - TFHE客户端密钥 /// /// # Returns /// * 加密后的查询向量或错误 fn encrypt_query(query: &[i32], client_key: &tfhe::ClientKey) -> Result> { query .iter() .map(|&x| FheInt32::try_encrypt(x, client_key)) .collect::, _>>() .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e)) } /// 加密训练点集合,包含坐标和索引,显示进度条 /// /// # Arguments /// * `points` - 训练点向量列表 /// * `client_key` - TFHE客户端密钥 /// /// # Returns /// * 加密后的训练点列表或错误 fn encrypt_points( points: &[Vec], client_key: &tfhe::ClientKey, ) -> Result> { let total_points = points.len(); let mut encrypted_points = Vec::new(); for (i, point) in points.iter().enumerate() { print_progress_bar(i + 1, total_points, "Encrypting points"); let encrypted_coords: Result, _> = point .iter() .map(|&x| FheInt32::try_encrypt(x, client_key)) .collect(); let encrypted_index = FheUint8::try_encrypt((i + 1) as u8, client_key) .map_err(|e| anyhow::anyhow!("Index encryption failed: {}", e))?; match encrypted_coords { Ok(coords) => encrypted_points.push(EncryptedPoint { coords, index: encrypted_index, }), Err(e) => return Err(anyhow::anyhow!("Encryption failed: {}", e)), } } println!(); // New line after progress bar Ok(encrypted_points) } /// 计算两个加密向量间的欧氏距离平方 /// /// # Arguments /// * `p1` - 第一个加密向量 /// * `p2` - 第二个加密向量 /// * `encrypted_zero` - 预加密的0值,用于优化性能 /// /// # Returns /// * 加密的距离平方值 fn encrypted_euclidean_distance_squared( p1: &[FheInt32], p2: &[FheInt32], encrypted_zero: &FheInt32, ) -> FheInt32 { let mut sum = encrypted_zero.clone(); for (x1, x2) in p1.iter().zip(p2.iter()) { let diff = x1 - x2; let squared = &diff * &diff; sum += squared; } sum } /// 执行加密KNN分类,返回k个最近邻的加密索引 /// /// # Arguments /// * `query` - 加密的查询向量 /// * `points` - 加密的训练点集合 /// * `k` - 返回的最近邻数量 /// * `client_key` - TFHE客户端密钥(用于创建加密常量) /// /// # Returns /// * k个最近邻的加密索引列表 fn encrypted_knn_classify( query: &[FheInt32], points: &[EncryptedPoint], k: usize, client_key: &tfhe::ClientKey, ) -> Vec { println!("🔢 Computing encrypted distances..."); println!("🏭 Pre-encrypting constants for optimization..."); let encrypted_zero = FheInt32::try_encrypt(0i32, client_key).unwrap(); let total_points = points.len(); let mut distances = Vec::with_capacity(total_points); // Pre-allocate capacity for (i, point) in points.iter().enumerate() { print_progress_bar(i + 1, total_points, "Distance calculation"); let dist_start = Instant::now(); let distance = encrypted_euclidean_distance_squared(query, &point.coords, &encrypted_zero); let dist_time = dist_start.elapsed(); if i == 0 { let estimated_total = dist_time.as_secs_f64() * total_points as f64; println!( "\n⏱️ First distance took {:.1}s, estimated total: {:.1} minutes", dist_time.as_secs_f64(), estimated_total / 60.0 ); println!("💾 Memory optimization: reusing encrypted constants"); } distances.push(EncryptedNeighbor { distance, index: point.index.clone(), }); } println!(); // New line after progress bar println!("📊 Sorting {} distances...", distances.len()); bitonic_sort_encrypted(&mut distances, true); distances.truncate(k); distances.iter().map(|n| n.index.clone()).collect() } /// 解密KNN结果,将加密索引转换为明文索引 /// /// # Arguments /// * `encrypted_result` - 加密的索引列表 /// * `client_key` - TFHE客户端密钥 /// /// # Returns /// * 解密后的索引列表 fn decrypt_knn_result(encrypted_result: &[FheUint8], client_key: &tfhe::ClientKey) -> Vec { encrypted_result .iter() .map(|encrypted_index| { let decrypted: u8 = encrypted_index.decrypt(client_key); decrypted as usize }) .collect() } /// 格式化时间持续时长为可读字符串 /// /// # Arguments /// * `duration` - 时间持续时长 /// /// # Returns /// * 格式化的时间字符串(如 "5m 30s" 或 "45s") fn format_duration(duration: Duration) -> String { let total_secs = duration.as_secs(); let mins = total_secs / 60; let secs = total_secs % 60; if mins > 0 { format!("{mins}m {secs}s") } else { format!("{secs}s") } } /// 在终端显示进度条 /// /// # Arguments /// * `current` - 当前进度数量 /// * `total` - 总数量 /// * `operation` - 操作描述 fn print_progress_bar(current: usize, total: usize, operation: &str) { let percentage = (current as f64 / total as f64 * 100.0) as usize; let bar_length = 40; let filled = (current * bar_length / total).min(bar_length); let empty = bar_length - filled; print!( "\r{}: [{}{}] {}/{} {}%", operation, "█".repeat(filled), "░".repeat(empty), current, total, percentage ); std::io::stdout().flush().unwrap(); } /// 对加密邻居数组执行双调排序 /// /// # Arguments /// * `arr` - 待排序的加密邻居数组 /// * `up` - 排序方向,true为升序,false为降序 fn bitonic_sort_encrypted(arr: &mut [EncryptedNeighbor], up: bool) { if arr.len() <= 1 { return; } let mid = arr.len() / 2; bitonic_sort_encrypted(&mut arr[..mid], true); bitonic_sort_encrypted(&mut arr[mid..], false); bitonic_merge_encrypted(arr, up); } /// 双调排序的合并阶段,使用加密条件交换 /// /// # Arguments /// * `arr` - 待合并的加密邻居数组 /// * `up` - 合并方向,true为升序,false为降序 fn bitonic_merge_encrypted(arr: &mut [EncryptedNeighbor], up: bool) { if arr.len() <= 1 { return; } let mid = arr.len() / 2; for i in 0..mid { let should_swap = if up { arr[i].distance.gt(&arr[i + mid].distance) } else { arr[i].distance.lt(&arr[i + mid].distance) }; let (left, right) = arr.split_at_mut(mid); encrypted_conditional_swap(&mut left[i], &mut right[i], &should_swap); } bitonic_merge_encrypted(&mut arr[..mid], up); bitonic_merge_encrypted(&mut arr[mid..], up); } /// 基于加密条件交换两个加密邻居的距离和索引 /// /// # Arguments /// * `a` - 第一个加密邻居 /// * `b` - 第二个加密邻居 /// * `condition` - 加密的交换条件 fn encrypted_conditional_swap( a: &mut EncryptedNeighbor, b: &mut EncryptedNeighbor, condition: &tfhe::FheBool, ) { let new_a_distance = condition.select(&b.distance, &a.distance); let new_b_distance = condition.select(&a.distance, &b.distance); let new_a_index = condition.select(&b.index, &a.index); let new_b_index = condition.select(&a.index, &b.index); a.distance = new_a_distance; b.distance = new_b_distance; a.index = new_a_index; b.index = new_b_index; } fn main() -> Result<()> { let args = Args::parse(); let start_time = Instant::now(); println!("🔧 Setting up TFHE parameters..."); let setup_start = Instant::now(); let config = ConfigBuilder::default().build(); let (client_key, server_key) = generate_keys(config); set_server_key(server_key); println!( "✅ TFHE setup completed in {}", format_duration(setup_start.elapsed()) ); let file = File::open(&args.dataset)?; let reader = BufReader::new(file); let mut results = Vec::new(); let mut query_count = 0; for line in reader.lines() { let line = line?; let dataset: Dataset = serde_json::from_str(&line)?; query_count += 1; println!( "\n📊 Processing query #{} with {} training points", query_count, dataset.data.len() ); let query_start = Instant::now(); // Data conversion let convert_start = Instant::now(); let (query, points) = convert_data_to_i32(&dataset); println!( "✅ Data converted to i32 (100x scaling) in {}", format_duration(convert_start.elapsed()) ); // Query encryption let query_enc_start = Instant::now(); let encrypted_query = encrypt_query(&query, &client_key)?; println!( "✅ Query encrypted in {}", format_duration(query_enc_start.elapsed()) ); // Points encryption with progress let points_enc_start = Instant::now(); let encrypted_points = encrypt_points(&points, &client_key)?; println!( "✅ Points encrypted in {}", format_duration(points_enc_start.elapsed()) ); // KNN computation println!("🔍 Running encrypted KNN search (k=10)..."); println!("⏱️ Estimated time: ~8 minutes for {} points", points.len()); let knn_start = Instant::now(); let encrypted_result = encrypted_knn_classify(&encrypted_query, &encrypted_points, 10, &client_key); let knn_time = knn_start.elapsed(); println!("✅ KNN search completed in {}", format_duration(knn_time)); // Decryption let decrypt_start = Instant::now(); let nearest = decrypt_knn_result(&encrypted_result, &client_key); println!( "✅ Results decrypted in {}", format_duration(decrypt_start.elapsed()) ); let query_total = query_start.elapsed(); println!( "📈 Query #{} completed in {}", query_count, format_duration(query_total) ); results.push(Prediction { answer: nearest }); } // Save results let save_start = Instant::now(); let mut output_file = File::create(&args.predictions)?; for result in results { writeln!(output_file, "{}", serde_json::to_string(&result)?)?; } println!( "✅ Results saved to {} in {}", args.predictions, format_duration(save_start.elapsed()) ); let total_time = start_time.elapsed(); println!("\n🎉 All operations completed!"); println!("📊 Total queries processed: {query_count}"); println!("⏱️ Total execution time: {}", format_duration(total_time)); Ok(()) }