feat: remove hnsw

This commit is contained in:
2025-08-27 13:13:20 +08:00
parent 7fae6b23b7
commit 25675228f4
7 changed files with 71 additions and 909 deletions

View File

@@ -321,60 +321,3 @@ fn encrypted_conditional_swap(
a.index = new_a_index;
b.index = new_b_index;
}
///// 执行HNSW近似最近邻搜索
/////
///// # Arguments
///// * `graph` - 预构建的FHE HNSW图结构
///// * `query` - 加密的查询点
///// * `zero` - 加密的零值
/////
///// # Returns
///// * 10个最近邻的索引列表
//pub fn perform_hnsw_search(
// graph: &FheHnswGraph,
// query: &EncryptedQuery,
// zero: &FheInt14,
//) -> Vec<usize> {
// println!("🚀 Starting HNSW approximate search...");
//
// if graph.nodes.is_empty() {
// println!("❌ Empty HNSW graph");
// return Vec::new();
// }
//
// let Some(entry_point) = graph.entry_point else {
// println!("❌ No entry point in HNSW graph");
// return Vec::new();
// };
//
// println!(
// "🔍 HNSW search from entry point {} at level {}",
// entry_point, graph.max_level
// );
//
// let mut current_candidates = vec![entry_point];
//
// // 从最高层逐层搜索到第1层
// for layer in (1..=graph.max_level).rev() {
// println!("🔍 Searching layer {layer} with ef=1...");
// let layer_start = Instant::now();
// current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero);
// println!(
// "✅ Layer {} search completed in {}, {} candidates",
// layer,
// format_duration(layer_start.elapsed()),
// current_candidates.len()
// );
// }
//
// // 在第0层进行最终搜索
// let final_search_start = Instant::now();
// let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero);
// println!(
// "✅ Final layer search completed in {}, {} candidates",
// format_duration(final_search_start.elapsed()),
// final_candidates.len()
// );
// final_candidates
//}

View File

@@ -26,7 +26,7 @@ struct Args {
#[arg(
long,
default_value = "bitonic",
help = "Algorithm: selection, bitonic, heap, hnsw"
help = "Algorithm: selection, bitonic, heap"
)]
algorithm: String,
#[arg(long, help = "Enable debug mode (plaintext calculation first)")]
@@ -143,24 +143,6 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
let query_encrypted = encrypt_query(&dataset.query, client_key);
let k = 10; // Number of nearest neighbors
// if args.algorithm == "hnsw" {
// // HNSW 算法路径
// println!("🚀 Using HNSW algorithm...");
//
// // 1. 构建明文HNSW图
// let (plaintext_nodes, entry_point, max_level) =
// build_plaintext_hnsw_graph(&dataset.data);
//
// // 2. 转换为加密HNSW图
// println!("🔐 Converting to encrypted HNSW graph...");
// let fhe_graph =
// build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
//
// // 3. 执行HNSW搜索
// let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
// let answer = perform_hnsw_search(&fhe_graph, &query_encrypted, &encrypted_zero);
// results.push(Prediction { answer });
// } else {
// 传统算法路径
// Encrypt all training points
println!("🔐 Encrypting training points...");

View File

@@ -1,9 +1,6 @@
#![allow(dead_code)]
use anyhow::Result;
use clap::Parser;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
@@ -15,28 +12,6 @@ struct Args {
dataset: String,
#[arg(long, default_value = "./dataset/answer1.jsonl")]
predictions: String,
#[arg(
long,
default_value = "8",
help = "Max connections per node (M parameter)"
)]
max_connections: usize,
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
max_level: usize,
#[arg(long, default_value = "0.6", help = "Level selection probability")]
level_prob: f32,
#[arg(
long,
default_value = "1",
help = "ef parameter for upper layer search"
)]
ef_upper: usize,
#[arg(
long,
default_value = "10",
help = "ef parameter for bottom layer search"
)]
ef_bottom: usize,
}
#[derive(Deserialize)]
@@ -82,278 +57,25 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
}
#[derive(Clone)]
struct HNSWNode {
vector: Vec<f64>,
level: usize,
neighbors: Vec<Vec<usize>>,
}
#[derive(Clone)]
struct HNSWGraph {
nodes: Vec<HNSWNode>,
entry_point: Option<usize>,
max_level: usize,
max_connections: usize,
}
#[derive(Debug, Clone, Copy, PartialEq)]
struct OrderedFloat(f64);
impl Eq for OrderedFloat {}
impl PartialOrd for OrderedFloat {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrderedFloat {
fn cmp(&self, other: &Self) -> Ordering {
self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal)
}
}
impl HNSWGraph {
fn new(max_level: usize, max_connections: usize) -> Self {
Self {
nodes: Vec::new(),
entry_point: None,
max_level,
max_connections,
}
}
fn insert_node(&mut self, vector: Vec<f64>, level_prob: f32) -> usize {
let level = self.select_level(level_prob);
let node_id = self.nodes.len();
let node = HNSWNode {
vector,
level,
neighbors: vec![Vec::new(); level + 1],
};
// 先添加节点到向量中
self.nodes.push(node);
if node_id == 0 {
self.entry_point = Some(node_id);
return node_id;
}
let mut current_candidates = vec![self.entry_point.unwrap()];
// 从最高层搜索到目标层+1
for lc in (level + 1..=self.max_level).rev() {
current_candidates =
self.search_layer(&self.nodes[node_id].vector, current_candidates, 1, lc);
}
// 从目标层到第0层建立连接
for lc in (0..=level).rev() {
current_candidates = self.search_layer(
&self.nodes[node_id].vector,
current_candidates,
self.max_connections,
lc,
);
for &candidate_id in &current_candidates {
self.connect_nodes(node_id, candidate_id, lc);
}
}
// 更新入口点
if level > self.nodes[self.entry_point.unwrap()].level {
self.entry_point = Some(node_id);
}
node_id
}
fn search_layer(
&self,
query: &[f64],
entry_points: Vec<usize>,
ef: usize,
layer: usize,
) -> Vec<usize> {
let mut visited = std::collections::HashSet::new();
let mut candidates = std::collections::BinaryHeap::new();
let mut w = std::collections::BinaryHeap::new();
// 初始化候选点
for &ep in &entry_points {
if ep < self.nodes.len() && self.nodes[ep].level >= layer {
let dist = euclidean_distance(query, &self.nodes[ep].vector);
candidates.push(std::cmp::Reverse((OrderedFloat(dist), ep)));
w.push((OrderedFloat(dist), ep));
visited.insert(ep);
}
}
while let Some(std::cmp::Reverse((current_dist, current))) = candidates.pop() {
// 如果当前距离已经比最远的结果距离大,停止搜索
if let Some(&(farthest_dist, _)) = w.iter().max() {
if current_dist > farthest_dist && w.len() >= ef {
break;
}
}
// 探索当前节点的邻居
if current < self.nodes.len() && layer < self.nodes[current].neighbors.len() {
for &neighbor in &self.nodes[current].neighbors[layer] {
if !visited.contains(&neighbor) && neighbor < self.nodes.len() {
visited.insert(neighbor);
let dist = euclidean_distance(query, &self.nodes[neighbor].vector);
let ordered_dist = OrderedFloat(dist);
if w.len() < ef {
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
w.push((ordered_dist, neighbor));
} else if let Some(&(farthest_dist, _)) = w.iter().max() {
if ordered_dist < farthest_dist {
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
w.push((ordered_dist, neighbor));
// 移除最远的点
if let Some(max_item) = w.iter().max().copied() {
w.retain(|&x| x != max_item);
}
}
}
}
}
}
}
// 返回按距离排序的结果
let mut results: Vec<_> = w.into_iter().collect();
results.sort_by(|a, b| a.0.cmp(&b.0));
results.into_iter().map(|(_, idx)| idx).collect()
}
fn select_level(&self, level_prob: f32) -> usize {
let mut rng = rand::rng();
let mut level = 0;
while level < self.max_level && rng.random::<f32>() < level_prob {
level += 1;
}
level
}
fn connect_nodes(&mut self, node1: usize, node2: usize, layer: usize) {
self.nodes[node1].neighbors[layer].push(node2);
self.nodes[node2].neighbors[layer].push(node1);
if self.nodes[node1].neighbors[layer].len() > self.max_connections {
self.prune_node_connections(node1, layer);
}
if self.nodes[node2].neighbors[layer].len() > self.max_connections {
self.prune_node_connections(node2, layer);
}
}
fn prune_node_connections(&mut self, node_id: usize, _layer: usize) {
let node_vector = self.nodes[node_id].vector.clone();
let nodes = self.nodes.clone(); // Create immutable reference before mutable borrow
// 计算所有邻居的距离并排序
let neighbors = &mut self.nodes[node_id].neighbors[_layer];
let mut neighbor_distances: Vec<(f64, usize)> = neighbors
.iter()
.map(|&neighbor_id| {
let dist = euclidean_distance(&node_vector, &nodes[neighbor_id].vector);
(dist, neighbor_id)
})
.collect();
// 按距离排序
neighbor_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
// 更新邻居列表,只保留距离最近的
*neighbors = neighbor_distances
.into_iter()
.take(self.max_connections)
.map(|(_, neighbor_id)| neighbor_id)
.collect();
}
}
// fn main() -> Result<()> {
// let args = Args::parse();
//
// let file = File::open(&args.dataset)?;
// let reader = BufReader::new(file);
//
// let mut results = Vec::new();
//
// for line in reader.lines() {
// let line = line?;
// let dataset: Dataset = serde_json::from_str(&line)?;
//
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
// results.push(Prediction { answer: nearest });
// }
//
// let mut output_file = File::create(&args.predictions)?;
// for result in results {
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
// }
//
// Ok(())
// }
fn main() -> Result<()> {
let args = Args::parse();
let file = File::open(&args.dataset)?;
let reader = BufReader::new(file);
let mut results = Vec::new();
println!("🔧 HNSW Parameters:");
println!(" max_connections: {}", args.max_connections);
println!(" max_level: {}", args.max_level);
println!(" level_prob: {}", args.level_prob);
println!(" ef_upper: {}", args.ef_upper);
println!(" ef_bottom: {}", args.ef_bottom);
let mut results = Vec::new();
for line in reader.lines() {
let line = line?;
let dataset: Dataset = serde_json::from_str(&line)?;
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
for data_point in &dataset.data {
hnsw.insert_node(data_point.clone(), args.level_prob);
}
let nearest = if let Some(entry_point) = hnsw.entry_point {
let mut search_results = vec![entry_point];
// 上层搜索使用ef_upper参数
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
search_results =
hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
}
// 底层搜索使用ef_bottom参数
let final_results =
hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
final_results
.into_iter()
.take(10)
.map(|idx| idx + 1)
.collect()
} else {
Vec::new()
};
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
results.push(Prediction { answer: nearest });
}
let mut output_file = File::create(&args.predictions)?;
for result in results {
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
println!("{}", serde_json::to_string(&result)?);
}
Ok(())

View File

@@ -128,251 +128,3 @@ pub fn decrypt_indices(encrypted_indices: &[FheUint8], client_key: &ClientKey) -
})
.collect()
}
///// 密文版 HNSW 节点
//#[derive(Clone)]
//pub struct FheHnswNode {
// pub encrypted_point: EncryptedPoint,
// pub level: usize,
// pub neighbors: Vec<Vec<usize>>, // 明文邻居索引(预处理时确定)
//}
//
///// 密文版 HNSW 图
//#[derive(Clone)]
//pub struct FheHnswGraph {
// pub nodes: Vec<FheHnswNode>,
// pub entry_point: Option<usize>,
// pub max_level: usize,
// pub distances: [Option<FheInt14>; 100],
//}
//
//impl FheHnswGraph {
// pub fn new() -> Self {
// Self {
// nodes: Vec::new(),
// entry_point: None,
// max_level: 0,
// distances: [const { None };100]
// }
// }
//
// /// 搜索层函数
// ///
// /// 核心算法实现标准HNSW贪心搜索
// ///
// /// 算法流程:
// /// 1. 初始化候选队列(candidates)和结果集(w)
// /// 2. 从entry_points开始计算到query的密文距离
// /// 3. 贪心搜索循环:
// /// - 从candidates中选择距离最小的点作为当前探索点
// /// - 探索该点的所有邻居
// /// - 将未访问的邻居加入candidates和w
// /// - 维护w的大小不超过ef移除最远的点
// /// - 实现剪枝如果当前点比w中最远点还远且w已满则停止
// /// 4. 返回w中的节点索引
// ///
// /// 关键挑战:
// /// - 需要模拟优先队列但只能用密文排序
// /// - 剪枝条件难以在密文下判断,需要权衡准确性和性能
// /// - 候选队列管理要高效,避免过多密文运算
// pub fn search_layer(
// &self,
// query: &EncryptedQuery,
// entry_points: Vec<usize>,
// ef: usize,
// layer: usize,
// zero: &FheInt14,
// ) -> Vec<usize> {
// let visited: HashSet<usize> = HashSet::new();
// let mut candidate =
//
// // TODO: 设计合适的数据结构
// // candidates: 候选队列,存储(节点索引, 密文距离)或类似结构
// // w: 结果集维护当前找到的最好的ef个候选点
//
// // TODO: Step 1 - 初始化候选点
// // 对每个entry_point:
// // - 检查节点是否存在且level >= layer
// // - 计算到query的密文距离euclidean_distance(query, &self.nodes[ep].encrypted_point, zero)
// // - 将节点加入visited集合
// // - 将(节点索引, 距离)加入candidates和w
//
// println!(
// "🔍 Starting search with {} initial candidates",
// // TODO: 显示实际初始化的候选点数量
// 0
// );
//
// // TODO: Step 2 - 主搜索循环
// // while candidates不为空:
// // - 从candidates中找到距离最小的点需要密文排序或其他方法
// // - 将该点从candidates中移除设为当前探索点current
// //
// // - 剪枝检查(可选,复杂):
// // 如果w.len() >= ef且current的距离 > w中最远点的距离则break
// //
// // - 探索current的邻居
// // for neighbor in self.nodes[current].neighbors[layer]:
// // - 如果neighbor未访问过
// // - 标记为已访问
// // - 计算到query的密文距离
// // - 将neighbor加入candidates
// // - 管理结果集w
// // if w.len() < ef: 直接加入w
// // else: 加入w后排序移除最远的点保持w大小为ef
//
// println!(
// "🔍 Found {} total candidates (ef={})",
// // TODO: 显示最终找到的候选点数量
// 0,
// ef
// );
//
// // TODO: Step 3 - 返回结果
// // 从w中提取节点索引并返回
// // w.into_iter().map(|(node_idx, _)| node_idx).collect()
//
// // 临时返回空结果,避免编译错误
// Vec::new()
// }
//}
//
//impl Default for FheHnswGraph {
// fn default() -> Self {
// Self::new()
// }
//}
//
///// 从明文数据构建密文 HNSW 图的辅助结构
//#[derive(Clone)]
//pub struct PlaintextHnswNode {
// pub vector: Vec<f64>,
// pub level: usize,
// pub neighbors: Vec<Vec<usize>>,
//}
//
///// 从明文数据构建密文 HNSW 图
//pub fn build_fhe_hnsw_from_plaintext(
// plaintext_nodes: &[PlaintextHnswNode],
// plaintext_entry_point: Option<usize>,
// plaintext_max_level: usize,
// client_key: &ClientKey,
//) -> FheHnswGraph {
// use crate::logging::print_progress_bar;
//
// let mut fhe_nodes = Vec::new();
// let total_nodes = plaintext_nodes.len();
//
// println!("🔐 Encrypting {total_nodes} HNSW nodes...");
//
// for (idx, plain_node) in plaintext_nodes.iter().enumerate() {
// print_progress_bar(idx + 1, total_nodes, "Encrypting nodes");
// let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key);
// let fhe_node = FheHnswNode {
// encrypted_point,
// level: plain_node.level,
// neighbors: plain_node.neighbors.clone(),
// };
// fhe_nodes.push(fhe_node);
// }
//
//
// FheHnswGraph {
// nodes: fhe_nodes,
// entry_point: plaintext_entry_point,
// max_level: plaintext_max_level,
// }
//}
//
///// 构建明文HNSW图
//pub fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) {
// println!("🔨 Building HNSW graph from {} points...", data.len());
//
// let max_connections = 8; // 最优参数:减少邻居数以降低密文运算量
// let max_level = 3; // 最优参数保持3层结构
// let mut nodes = Vec::new();
// let mut entry_point = None;
// let mut current_max_level = 0;
//
// // 创建所有节点
// for (idx, vector) in data.iter().enumerate() {
// let level = select_level_for_node(max_level);
// current_max_level = current_max_level.max(level);
//
// let node = PlaintextHnswNode {
// vector: vector.clone(),
// level,
// neighbors: vec![Vec::new(); level + 1],
// };
// nodes.push(node);
//
// if idx == 0 {
// entry_point = Some(idx);
// }
// }
//
// // 为每个节点建立连接(简化版实现)
// for node_id in 0..nodes.len() {
// print_progress_bar(node_id + 1, nodes.len(), "Building connections");
//
// let node_vector = nodes[node_id].vector.clone();
// let node_level = nodes[node_id].level;
//
// // 为每层找到最近的邻居
// for layer in 0..=node_level {
// let mut distances: Vec<(f64, usize)> = nodes
// .iter()
// .enumerate()
// .filter(|(idx, n)| *idx != node_id && n.level >= layer)
// .map(|(idx, n)| {
// let dist = euclidean_distance_plaintext(&node_vector, &n.vector);
// (dist, idx)
// })
// .collect();
//
// distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
//
// // 添加最近的邻居(限制连接数)
// for &(_, neighbor_id) in distances.iter().take(max_connections) {
// nodes[node_id].neighbors[layer].push(neighbor_id);
// }
// }
//
// // 更新入口点
// if nodes[node_id].level > nodes[entry_point.unwrap()].level {
// entry_point = Some(node_id);
// }
// }
//
// println!(); // 新行
// println!(
// "✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}",
// nodes.len(),
// entry_point,
// current_max_level
// );
//
// (nodes, entry_point, current_max_level)
//}
//
///// 选择节点的层级
//pub fn select_level_for_node(max_level: usize) -> usize {
// let mut rng = rand::rng();
// let mut level = 0;
// while level < max_level && rng.random::<f32>() < 0.6 {
// // 最优参数提高层级概率到0.6
// level += 1;
// }
// level
//}
//
//
//
///// 明文欧几里得距离计算
//fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 {
// a.iter()
// .zip(b.iter())
// .map(|(x, y)| (x - y).powi(2))
// .sum::<f64>()
// .sqrt()
//}

3
src/main.rs Normal file
View File

@@ -0,0 +1,3 @@
fn main() {
todo!();
}