feat: 实现明文版 HNSW 算法和最优参数

This commit is contained in:
2025-07-24 18:45:50 +08:00
parent bc92fc704f
commit 2a376b920b
3 changed files with 651 additions and 43 deletions

View File

@@ -15,6 +15,16 @@ struct Args {
dataset: String,
#[arg(long, default_value = "./dataset/answer1.jsonl")]
predictions: String,
#[arg(long, default_value = "8", help = "Max connections per node (M parameter)")]
max_connections: usize,
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
max_level: usize,
#[arg(long, default_value = "0.6", help = "Level selection probability")]
level_prob: f32,
#[arg(long, default_value = "1", help = "ef parameter for upper layer search")]
ef_upper: usize,
#[arg(long, default_value = "10", help = "ef parameter for bottom layer search")]
ef_bottom: usize,
}
#[derive(Deserialize)]
@@ -81,17 +91,17 @@ impl Ord for OrderedFloat {
}
impl HNSWGraph {
fn new() -> Self {
fn new(max_level: usize, max_connections: usize) -> Self {
Self {
nodes: Vec::new(),
entry_point: None,
max_level: 3,
max_connections: 16,
max_level,
max_connections,
}
}
fn insert_node(&mut self, vector: Vec<f64>) -> usize {
let level = self.select_level();
fn insert_node(&mut self, vector: Vec<f64>, level_prob: f32) -> usize {
let level = self.select_level(level_prob);
let node_id = self.nodes.len();
let node = HNSWNode {
@@ -200,10 +210,10 @@ impl HNSWGraph {
results.into_iter().map(|(_, idx)| idx).collect()
}
fn select_level(&self) -> usize {
fn select_level(&self, level_prob: f32) -> usize {
let mut rng = rand::rng();
let mut level = 0;
while level < self.max_level && rng.random::<f32>() < 0.5 {
while level < self.max_level && rng.random::<f32>() < level_prob {
level += 1;
}
level
@@ -276,23 +286,32 @@ fn main() -> Result<()> {
let reader = BufReader::new(file);
let mut results = Vec::new();
println!("🔧 HNSW Parameters:");
println!(" max_connections: {}", args.max_connections);
println!(" max_level: {}", args.max_level);
println!(" level_prob: {}", args.level_prob);
println!(" ef_upper: {}", args.ef_upper);
println!(" ef_bottom: {}", args.ef_bottom);
for line in reader.lines() {
let line = line?;
let dataset: Dataset = serde_json::from_str(&line)?;
let mut hnsw = HNSWGraph::new();
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
for data_point in &dataset.data {
hnsw.insert_node(data_point.clone());
hnsw.insert_node(data_point.clone(), args.level_prob);
}
let nearest = if let Some(entry_point) = hnsw.entry_point {
let mut search_results = vec![entry_point];
// 上层搜索使用ef_upper参数
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
search_results = hnsw.search_layer(&dataset.query, search_results, 1, layer);
search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
}
let final_results = hnsw.search_layer(&dataset.query, search_results, 16, 0);
// 底层搜索使用ef_bottom参数
let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
final_results
.into_iter()
.take(10)