feat: remove hnsw
This commit is contained in:
@@ -321,60 +321,3 @@ fn encrypted_conditional_swap(
|
||||
a.index = new_a_index;
|
||||
b.index = new_b_index;
|
||||
}
|
||||
|
||||
///// 执行HNSW近似最近邻搜索
|
||||
/////
|
||||
///// # Arguments
|
||||
///// * `graph` - 预构建的FHE HNSW图结构
|
||||
///// * `query` - 加密的查询点
|
||||
///// * `zero` - 加密的零值
|
||||
/////
|
||||
///// # Returns
|
||||
///// * 10个最近邻的索引列表
|
||||
//pub fn perform_hnsw_search(
|
||||
// graph: &FheHnswGraph,
|
||||
// query: &EncryptedQuery,
|
||||
// zero: &FheInt14,
|
||||
//) -> Vec<usize> {
|
||||
// println!("🚀 Starting HNSW approximate search...");
|
||||
//
|
||||
// if graph.nodes.is_empty() {
|
||||
// println!("❌ Empty HNSW graph");
|
||||
// return Vec::new();
|
||||
// }
|
||||
//
|
||||
// let Some(entry_point) = graph.entry_point else {
|
||||
// println!("❌ No entry point in HNSW graph");
|
||||
// return Vec::new();
|
||||
// };
|
||||
//
|
||||
// println!(
|
||||
// "🔍 HNSW search from entry point {} at level {}",
|
||||
// entry_point, graph.max_level
|
||||
// );
|
||||
//
|
||||
// let mut current_candidates = vec![entry_point];
|
||||
//
|
||||
// // 从最高层逐层搜索到第1层
|
||||
// for layer in (1..=graph.max_level).rev() {
|
||||
// println!("🔍 Searching layer {layer} with ef=1...");
|
||||
// let layer_start = Instant::now();
|
||||
// current_candidates = graph.search_layer(query, current_candidates, 1, layer, zero);
|
||||
// println!(
|
||||
// "✅ Layer {} search completed in {}, {} candidates",
|
||||
// layer,
|
||||
// format_duration(layer_start.elapsed()),
|
||||
// current_candidates.len()
|
||||
// );
|
||||
// }
|
||||
//
|
||||
// // 在第0层进行最终搜索
|
||||
// let final_search_start = Instant::now();
|
||||
// let final_candidates = graph.search_layer(query, current_candidates, 10, 0, zero);
|
||||
// println!(
|
||||
// "✅ Final layer search completed in {}, {} candidates",
|
||||
// format_duration(final_search_start.elapsed()),
|
||||
// final_candidates.len()
|
||||
// );
|
||||
// final_candidates
|
||||
//}
|
||||
|
||||
@@ -26,7 +26,7 @@ struct Args {
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "bitonic",
|
||||
help = "Algorithm: selection, bitonic, heap, hnsw"
|
||||
help = "Algorithm: selection, bitonic, heap"
|
||||
)]
|
||||
algorithm: String,
|
||||
#[arg(long, help = "Enable debug mode (plaintext calculation first)")]
|
||||
@@ -143,24 +143,6 @@ fn process_dataset(args: &Args, client_key: &tfhe::ClientKey, start_time: Instan
|
||||
let query_encrypted = encrypt_query(&dataset.query, client_key);
|
||||
|
||||
let k = 10; // Number of nearest neighbors
|
||||
// if args.algorithm == "hnsw" {
|
||||
// // HNSW 算法路径
|
||||
// println!("🚀 Using HNSW algorithm...");
|
||||
//
|
||||
// // 1. 构建明文HNSW图
|
||||
// let (plaintext_nodes, entry_point, max_level) =
|
||||
// build_plaintext_hnsw_graph(&dataset.data);
|
||||
//
|
||||
// // 2. 转换为加密HNSW图
|
||||
// println!("🔐 Converting to encrypted HNSW graph...");
|
||||
// let fhe_graph =
|
||||
// build_fhe_hnsw_from_plaintext(&plaintext_nodes, entry_point, max_level, client_key);
|
||||
//
|
||||
// // 3. 执行HNSW搜索
|
||||
// let encrypted_zero = FheInt14::try_encrypt(0i16, client_key).unwrap();
|
||||
// let answer = perform_hnsw_search(&fhe_graph, &query_encrypted, &encrypted_zero);
|
||||
// results.push(Prediction { answer });
|
||||
// } else {
|
||||
// 传统算法路径
|
||||
// Encrypt all training points
|
||||
println!("🔐 Encrypting training points...");
|
||||
|
||||
284
src/bin/plain.rs
284
src/bin/plain.rs
@@ -1,9 +1,6 @@
|
||||
#![allow(dead_code)]
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::Ordering;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
|
||||
@@ -15,28 +12,6 @@ struct Args {
|
||||
dataset: String,
|
||||
#[arg(long, default_value = "./dataset/answer1.jsonl")]
|
||||
predictions: String,
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "8",
|
||||
help = "Max connections per node (M parameter)"
|
||||
)]
|
||||
max_connections: usize,
|
||||
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
|
||||
max_level: usize,
|
||||
#[arg(long, default_value = "0.6", help = "Level selection probability")]
|
||||
level_prob: f32,
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "1",
|
||||
help = "ef parameter for upper layer search"
|
||||
)]
|
||||
ef_upper: usize,
|
||||
#[arg(
|
||||
long,
|
||||
default_value = "10",
|
||||
help = "ef parameter for bottom layer search"
|
||||
)]
|
||||
ef_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -82,278 +57,25 @@ fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
|
||||
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
|
||||
}
|
||||
|
||||
/// A single vertex of the plaintext HNSW graph.
///
/// `Debug` is derived so nodes can be printed while diagnosing graph
/// construction and used in assertion messages.
#[derive(Debug, Clone)]
struct HNSWNode {
    /// Coordinates of the data point stored at this vertex.
    vector: Vec<f64>,
    /// Highest layer this vertex participates in (layer 0 is the bottom layer).
    level: usize,
    /// Adjacency lists indexed by layer: `neighbors[l]` holds the ids of the
    /// nodes this vertex is linked to on layer `l`.
    neighbors: Vec<Vec<usize>>,
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct HNSWGraph {
|
||||
nodes: Vec<HNSWNode>,
|
||||
entry_point: Option<usize>,
|
||||
max_level: usize,
|
||||
max_connections: usize,
|
||||
}
|
||||
|
||||
/// Wrapper that gives `f64` a total order so that distances can be stored in
/// ordered collections such as `BinaryHeap`.
///
/// Both ordering and equality follow [`f64::total_cmp`]. Unlike a
/// `partial_cmp(..).unwrap_or(Equal)` fallback, this keeps NaN at a fixed
/// position in the order instead of making it "equal" to every other value,
/// which would violate the transitivity that `Ord` requires and can corrupt
/// heaps and sorts.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f64);

impl PartialEq for OrderedFloat {
    /// Equality is defined through `cmp` so that `Eq` and `Ord` always agree.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    /// Always returns `Some`, delegating to the total order defined by `cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    /// Compares the wrapped values using the IEEE 754 total order.
    fn cmp(&self, other: &Self) -> Ordering {
        self.0.total_cmp(&other.0)
    }
}
|
||||
|
||||
impl HNSWGraph {
|
||||
fn new(max_level: usize, max_connections: usize) -> Self {
|
||||
Self {
|
||||
nodes: Vec::new(),
|
||||
entry_point: None,
|
||||
max_level,
|
||||
max_connections,
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_node(&mut self, vector: Vec<f64>, level_prob: f32) -> usize {
|
||||
let level = self.select_level(level_prob);
|
||||
let node_id = self.nodes.len();
|
||||
|
||||
let node = HNSWNode {
|
||||
vector,
|
||||
level,
|
||||
neighbors: vec![Vec::new(); level + 1],
|
||||
};
|
||||
|
||||
// 先添加节点到向量中
|
||||
self.nodes.push(node);
|
||||
|
||||
if node_id == 0 {
|
||||
self.entry_point = Some(node_id);
|
||||
return node_id;
|
||||
}
|
||||
|
||||
let mut current_candidates = vec![self.entry_point.unwrap()];
|
||||
|
||||
// 从最高层搜索到目标层+1
|
||||
for lc in (level + 1..=self.max_level).rev() {
|
||||
current_candidates =
|
||||
self.search_layer(&self.nodes[node_id].vector, current_candidates, 1, lc);
|
||||
}
|
||||
|
||||
// 从目标层到第0层建立连接
|
||||
for lc in (0..=level).rev() {
|
||||
current_candidates = self.search_layer(
|
||||
&self.nodes[node_id].vector,
|
||||
current_candidates,
|
||||
self.max_connections,
|
||||
lc,
|
||||
);
|
||||
|
||||
for &candidate_id in ¤t_candidates {
|
||||
self.connect_nodes(node_id, candidate_id, lc);
|
||||
}
|
||||
}
|
||||
|
||||
// 更新入口点
|
||||
if level > self.nodes[self.entry_point.unwrap()].level {
|
||||
self.entry_point = Some(node_id);
|
||||
}
|
||||
|
||||
node_id
|
||||
}
|
||||
|
||||
fn search_layer(
|
||||
&self,
|
||||
query: &[f64],
|
||||
entry_points: Vec<usize>,
|
||||
ef: usize,
|
||||
layer: usize,
|
||||
) -> Vec<usize> {
|
||||
let mut visited = std::collections::HashSet::new();
|
||||
let mut candidates = std::collections::BinaryHeap::new();
|
||||
let mut w = std::collections::BinaryHeap::new();
|
||||
|
||||
// 初始化候选点
|
||||
for &ep in &entry_points {
|
||||
if ep < self.nodes.len() && self.nodes[ep].level >= layer {
|
||||
let dist = euclidean_distance(query, &self.nodes[ep].vector);
|
||||
candidates.push(std::cmp::Reverse((OrderedFloat(dist), ep)));
|
||||
w.push((OrderedFloat(dist), ep));
|
||||
visited.insert(ep);
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(std::cmp::Reverse((current_dist, current))) = candidates.pop() {
|
||||
// 如果当前距离已经比最远的结果距离大,停止搜索
|
||||
if let Some(&(farthest_dist, _)) = w.iter().max() {
|
||||
if current_dist > farthest_dist && w.len() >= ef {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 探索当前节点的邻居
|
||||
if current < self.nodes.len() && layer < self.nodes[current].neighbors.len() {
|
||||
for &neighbor in &self.nodes[current].neighbors[layer] {
|
||||
if !visited.contains(&neighbor) && neighbor < self.nodes.len() {
|
||||
visited.insert(neighbor);
|
||||
let dist = euclidean_distance(query, &self.nodes[neighbor].vector);
|
||||
let ordered_dist = OrderedFloat(dist);
|
||||
|
||||
if w.len() < ef {
|
||||
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
|
||||
w.push((ordered_dist, neighbor));
|
||||
} else if let Some(&(farthest_dist, _)) = w.iter().max() {
|
||||
if ordered_dist < farthest_dist {
|
||||
candidates.push(std::cmp::Reverse((ordered_dist, neighbor)));
|
||||
w.push((ordered_dist, neighbor));
|
||||
|
||||
// 移除最远的点
|
||||
if let Some(max_item) = w.iter().max().copied() {
|
||||
w.retain(|&x| x != max_item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 返回按距离排序的结果
|
||||
let mut results: Vec<_> = w.into_iter().collect();
|
||||
results.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
results.into_iter().map(|(_, idx)| idx).collect()
|
||||
}
|
||||
|
||||
fn select_level(&self, level_prob: f32) -> usize {
|
||||
let mut rng = rand::rng();
|
||||
let mut level = 0;
|
||||
while level < self.max_level && rng.random::<f32>() < level_prob {
|
||||
level += 1;
|
||||
}
|
||||
level
|
||||
}
|
||||
|
||||
fn connect_nodes(&mut self, node1: usize, node2: usize, layer: usize) {
|
||||
self.nodes[node1].neighbors[layer].push(node2);
|
||||
self.nodes[node2].neighbors[layer].push(node1);
|
||||
|
||||
if self.nodes[node1].neighbors[layer].len() > self.max_connections {
|
||||
self.prune_node_connections(node1, layer);
|
||||
}
|
||||
if self.nodes[node2].neighbors[layer].len() > self.max_connections {
|
||||
self.prune_node_connections(node2, layer);
|
||||
}
|
||||
}
|
||||
|
||||
fn prune_node_connections(&mut self, node_id: usize, _layer: usize) {
|
||||
let node_vector = self.nodes[node_id].vector.clone();
|
||||
let nodes = self.nodes.clone(); // Create immutable reference before mutable borrow
|
||||
|
||||
// 计算所有邻居的距离并排序
|
||||
let neighbors = &mut self.nodes[node_id].neighbors[_layer];
|
||||
let mut neighbor_distances: Vec<(f64, usize)> = neighbors
|
||||
.iter()
|
||||
.map(|&neighbor_id| {
|
||||
let dist = euclidean_distance(&node_vector, &nodes[neighbor_id].vector);
|
||||
(dist, neighbor_id)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 按距离排序
|
||||
neighbor_distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
|
||||
// 更新邻居列表,只保留距离最近的
|
||||
*neighbors = neighbor_distances
|
||||
.into_iter()
|
||||
.take(self.max_connections)
|
||||
.map(|(_, neighbor_id)| neighbor_id)
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
|
||||
// fn main() -> Result<()> {
|
||||
// let args = Args::parse();
|
||||
//
|
||||
// let file = File::open(&args.dataset)?;
|
||||
// let reader = BufReader::new(file);
|
||||
//
|
||||
// let mut results = Vec::new();
|
||||
//
|
||||
// for line in reader.lines() {
|
||||
// let line = line?;
|
||||
// let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
//
|
||||
// let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||
// results.push(Prediction { answer: nearest });
|
||||
// }
|
||||
//
|
||||
// let mut output_file = File::create(&args.predictions)?;
|
||||
// for result in results {
|
||||
// writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
// }
|
||||
//
|
||||
// Ok(())
|
||||
// }
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let file = File::open(&args.dataset)?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut results = Vec::new();
|
||||
|
||||
println!("🔧 HNSW Parameters:");
|
||||
println!(" max_connections: {}", args.max_connections);
|
||||
println!(" max_level: {}", args.max_level);
|
||||
println!(" level_prob: {}", args.level_prob);
|
||||
println!(" ef_upper: {}", args.ef_upper);
|
||||
println!(" ef_bottom: {}", args.ef_bottom);
|
||||
let mut results = Vec::new();
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
|
||||
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
|
||||
for data_point in &dataset.data {
|
||||
hnsw.insert_node(data_point.clone(), args.level_prob);
|
||||
}
|
||||
|
||||
let nearest = if let Some(entry_point) = hnsw.entry_point {
|
||||
let mut search_results = vec![entry_point];
|
||||
|
||||
// 上层搜索使用ef_upper参数
|
||||
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
||||
search_results =
|
||||
hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
|
||||
}
|
||||
|
||||
// 底层搜索使用ef_bottom参数
|
||||
let final_results =
|
||||
hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
|
||||
final_results
|
||||
.into_iter()
|
||||
.take(10)
|
||||
.map(|idx| idx + 1)
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||
results.push(Prediction { answer: nearest });
|
||||
}
|
||||
|
||||
let mut output_file = File::create(&args.predictions)?;
|
||||
for result in results {
|
||||
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
println!("{}", serde_json::to_string(&result)?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
248
src/data.rs
248
src/data.rs
@@ -128,251 +128,3 @@ pub fn decrypt_indices(encrypted_indices: &[FheUint8], client_key: &ClientKey) -
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
///// 密文版 HNSW 节点
|
||||
//#[derive(Clone)]
|
||||
//pub struct FheHnswNode {
|
||||
// pub encrypted_point: EncryptedPoint,
|
||||
// pub level: usize,
|
||||
// pub neighbors: Vec<Vec<usize>>, // 明文邻居索引(预处理时确定)
|
||||
//}
|
||||
//
|
||||
///// 密文版 HNSW 图
|
||||
//#[derive(Clone)]
|
||||
//pub struct FheHnswGraph {
|
||||
// pub nodes: Vec<FheHnswNode>,
|
||||
// pub entry_point: Option<usize>,
|
||||
// pub max_level: usize,
|
||||
// pub distances: [Option<FheInt14>; 100],
|
||||
//}
|
||||
//
|
||||
//impl FheHnswGraph {
|
||||
// pub fn new() -> Self {
|
||||
// Self {
|
||||
// nodes: Vec::new(),
|
||||
// entry_point: None,
|
||||
// max_level: 0,
|
||||
// distances: [const { None };100]
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// /// 搜索层函数
|
||||
// ///
|
||||
// /// 核心算法:实现标准HNSW贪心搜索
|
||||
// ///
|
||||
// /// 算法流程:
|
||||
// /// 1. 初始化候选队列(candidates)和结果集(w)
|
||||
// /// 2. 从entry_points开始,计算到query的密文距离
|
||||
// /// 3. 贪心搜索循环:
|
||||
// /// - 从candidates中选择距离最小的点作为当前探索点
|
||||
// /// - 探索该点的所有邻居
|
||||
// /// - 将未访问的邻居加入candidates和w
|
||||
// /// - 维护w的大小不超过ef(移除最远的点)
|
||||
// /// - 实现剪枝:如果当前点比w中最远点还远且w已满,则停止
|
||||
// /// 4. 返回w中的节点索引
|
||||
// ///
|
||||
// /// 关键挑战:
|
||||
// /// - 需要模拟优先队列但只能用密文排序
|
||||
// /// - 剪枝条件难以在密文下判断,需要权衡准确性和性能
|
||||
// /// - 候选队列管理要高效,避免过多密文运算
|
||||
// pub fn search_layer(
|
||||
// &self,
|
||||
// query: &EncryptedQuery,
|
||||
// entry_points: Vec<usize>,
|
||||
// ef: usize,
|
||||
// layer: usize,
|
||||
// zero: &FheInt14,
|
||||
// ) -> Vec<usize> {
|
||||
// let visited: HashSet<usize> = HashSet::new();
|
||||
// let mut candidate =
|
||||
//
|
||||
// // TODO: 设计合适的数据结构
|
||||
// // candidates: 候选队列,存储(节点索引, 密文距离)或类似结构
|
||||
// // w: 结果集,维护当前找到的最好的ef个候选点
|
||||
//
|
||||
// // TODO: Step 1 - 初始化候选点
|
||||
// // 对每个entry_point:
|
||||
// // - 检查节点是否存在且level >= layer
|
||||
// // - 计算到query的密文距离:euclidean_distance(query, &self.nodes[ep].encrypted_point, zero)
|
||||
// // - 将节点加入visited集合
|
||||
// // - 将(节点索引, 距离)加入candidates和w
|
||||
//
|
||||
// println!(
|
||||
// "🔍 Starting search with {} initial candidates",
|
||||
// // TODO: 显示实际初始化的候选点数量
|
||||
// 0
|
||||
// );
|
||||
//
|
||||
// // TODO: Step 2 - 主搜索循环
|
||||
// // while candidates不为空:
|
||||
// // - 从candidates中找到距离最小的点(需要密文排序或其他方法)
|
||||
// // - 将该点从candidates中移除,设为当前探索点current
|
||||
// //
|
||||
// // - 剪枝检查(可选,复杂):
|
||||
// // 如果w.len() >= ef且current的距离 > w中最远点的距离,则break
|
||||
// //
|
||||
// // - 探索current的邻居:
|
||||
// // for neighbor in self.nodes[current].neighbors[layer]:
|
||||
// // - 如果neighbor未访问过:
|
||||
// // - 标记为已访问
|
||||
// // - 计算到query的密文距离
|
||||
// // - 将neighbor加入candidates
|
||||
// // - 管理结果集w:
|
||||
// // if w.len() < ef: 直接加入w
|
||||
// // else: 加入w后排序,移除最远的点,保持w大小为ef
|
||||
//
|
||||
// println!(
|
||||
// "🔍 Found {} total candidates (ef={})",
|
||||
// // TODO: 显示最终找到的候选点数量
|
||||
// 0,
|
||||
// ef
|
||||
// );
|
||||
//
|
||||
// // TODO: Step 3 - 返回结果
|
||||
// // 从w中提取节点索引并返回
|
||||
// // w.into_iter().map(|(node_idx, _)| node_idx).collect()
|
||||
//
|
||||
// // 临时返回空结果,避免编译错误
|
||||
// Vec::new()
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//impl Default for FheHnswGraph {
|
||||
// fn default() -> Self {
|
||||
// Self::new()
|
||||
// }
|
||||
//}
|
||||
//
|
||||
///// 从明文数据构建密文 HNSW 图的辅助结构
|
||||
//#[derive(Clone)]
|
||||
//pub struct PlaintextHnswNode {
|
||||
// pub vector: Vec<f64>,
|
||||
// pub level: usize,
|
||||
// pub neighbors: Vec<Vec<usize>>,
|
||||
//}
|
||||
//
|
||||
///// 从明文数据构建密文 HNSW 图
|
||||
//pub fn build_fhe_hnsw_from_plaintext(
|
||||
// plaintext_nodes: &[PlaintextHnswNode],
|
||||
// plaintext_entry_point: Option<usize>,
|
||||
// plaintext_max_level: usize,
|
||||
// client_key: &ClientKey,
|
||||
//) -> FheHnswGraph {
|
||||
// use crate::logging::print_progress_bar;
|
||||
//
|
||||
// let mut fhe_nodes = Vec::new();
|
||||
// let total_nodes = plaintext_nodes.len();
|
||||
//
|
||||
// println!("🔐 Encrypting {total_nodes} HNSW nodes...");
|
||||
//
|
||||
// for (idx, plain_node) in plaintext_nodes.iter().enumerate() {
|
||||
// print_progress_bar(idx + 1, total_nodes, "Encrypting nodes");
|
||||
// let encrypted_point = encrypt_point(&plain_node.vector, idx + 1, client_key);
|
||||
// let fhe_node = FheHnswNode {
|
||||
// encrypted_point,
|
||||
// level: plain_node.level,
|
||||
// neighbors: plain_node.neighbors.clone(),
|
||||
// };
|
||||
// fhe_nodes.push(fhe_node);
|
||||
// }
|
||||
//
|
||||
//
|
||||
// FheHnswGraph {
|
||||
// nodes: fhe_nodes,
|
||||
// entry_point: plaintext_entry_point,
|
||||
// max_level: plaintext_max_level,
|
||||
// }
|
||||
//}
|
||||
//
|
||||
///// 构建明文HNSW图
|
||||
//pub fn build_plaintext_hnsw_graph(data: &[Vec<f64>]) -> (Vec<PlaintextHnswNode>, Option<usize>, usize) {
|
||||
// println!("🔨 Building HNSW graph from {} points...", data.len());
|
||||
//
|
||||
// let max_connections = 8; // 最优参数:减少邻居数以降低密文运算量
|
||||
// let max_level = 3; // 最优参数:保持3层结构
|
||||
// let mut nodes = Vec::new();
|
||||
// let mut entry_point = None;
|
||||
// let mut current_max_level = 0;
|
||||
//
|
||||
// // 创建所有节点
|
||||
// for (idx, vector) in data.iter().enumerate() {
|
||||
// let level = select_level_for_node(max_level);
|
||||
// current_max_level = current_max_level.max(level);
|
||||
//
|
||||
// let node = PlaintextHnswNode {
|
||||
// vector: vector.clone(),
|
||||
// level,
|
||||
// neighbors: vec![Vec::new(); level + 1],
|
||||
// };
|
||||
// nodes.push(node);
|
||||
//
|
||||
// if idx == 0 {
|
||||
// entry_point = Some(idx);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 为每个节点建立连接(简化版实现)
|
||||
// for node_id in 0..nodes.len() {
|
||||
// print_progress_bar(node_id + 1, nodes.len(), "Building connections");
|
||||
//
|
||||
// let node_vector = nodes[node_id].vector.clone();
|
||||
// let node_level = nodes[node_id].level;
|
||||
//
|
||||
// // 为每层找到最近的邻居
|
||||
// for layer in 0..=node_level {
|
||||
// let mut distances: Vec<(f64, usize)> = nodes
|
||||
// .iter()
|
||||
// .enumerate()
|
||||
// .filter(|(idx, n)| *idx != node_id && n.level >= layer)
|
||||
// .map(|(idx, n)| {
|
||||
// let dist = euclidean_distance_plaintext(&node_vector, &n.vector);
|
||||
// (dist, idx)
|
||||
// })
|
||||
// .collect();
|
||||
//
|
||||
// distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
//
|
||||
// // 添加最近的邻居(限制连接数)
|
||||
// for &(_, neighbor_id) in distances.iter().take(max_connections) {
|
||||
// nodes[node_id].neighbors[layer].push(neighbor_id);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // 更新入口点
|
||||
// if nodes[node_id].level > nodes[entry_point.unwrap()].level {
|
||||
// entry_point = Some(node_id);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// println!(); // 新行
|
||||
// println!(
|
||||
// "✅ HNSW graph built with {} nodes, entry point: {:?}, max level: {}",
|
||||
// nodes.len(),
|
||||
// entry_point,
|
||||
// current_max_level
|
||||
// );
|
||||
//
|
||||
// (nodes, entry_point, current_max_level)
|
||||
//}
|
||||
//
|
||||
///// 选择节点的层级
|
||||
//pub fn select_level_for_node(max_level: usize) -> usize {
|
||||
// let mut rng = rand::rng();
|
||||
// let mut level = 0;
|
||||
// while level < max_level && rng.random::<f32>() < 0.6 {
|
||||
// // 最优参数:提高层级概率到0.6
|
||||
// level += 1;
|
||||
// }
|
||||
// level
|
||||
//}
|
||||
//
|
||||
//
|
||||
//
|
||||
///// 明文欧几里得距离计算
|
||||
//fn euclidean_distance_plaintext(a: &[f64], b: &[f64]) -> f64 {
|
||||
// a.iter()
|
||||
// .zip(b.iter())
|
||||
// .map(|(x, y)| (x - y).powi(2))
|
||||
// .sum::<f64>()
|
||||
// .sqrt()
|
||||
//}
|
||||
|
||||
3
src/main.rs
Normal file
3
src/main.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
/// Placeholder entry point for the crate's default binary.
///
/// Not implemented yet: running it panics via `todo!()`.
fn main() {
    todo!()
}
|
||||
Reference in New Issue
Block a user