feat: 实现明文版hnsw算法和最优参数
This commit is contained in:
@@ -15,6 +15,16 @@ struct Args {
|
||||
dataset: String,
|
||||
#[arg(long, default_value = "./dataset/answer1.jsonl")]
|
||||
predictions: String,
|
||||
#[arg(long, default_value = "8", help = "Max connections per node (M parameter)")]
|
||||
max_connections: usize,
|
||||
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
|
||||
max_level: usize,
|
||||
#[arg(long, default_value = "0.6", help = "Level selection probability")]
|
||||
level_prob: f32,
|
||||
#[arg(long, default_value = "1", help = "ef parameter for upper layer search")]
|
||||
ef_upper: usize,
|
||||
#[arg(long, default_value = "10", help = "ef parameter for bottom layer search")]
|
||||
ef_bottom: usize,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -81,17 +91,17 @@ impl Ord for OrderedFloat {
|
||||
}
|
||||
|
||||
impl HNSWGraph {
|
||||
fn new() -> Self {
|
||||
fn new(max_level: usize, max_connections: usize) -> Self {
|
||||
Self {
|
||||
nodes: Vec::new(),
|
||||
entry_point: None,
|
||||
max_level: 3,
|
||||
max_connections: 16,
|
||||
max_level,
|
||||
max_connections,
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_node(&mut self, vector: Vec<f64>) -> usize {
|
||||
let level = self.select_level();
|
||||
fn insert_node(&mut self, vector: Vec<f64>, level_prob: f32) -> usize {
|
||||
let level = self.select_level(level_prob);
|
||||
let node_id = self.nodes.len();
|
||||
|
||||
let node = HNSWNode {
|
||||
@@ -200,10 +210,10 @@ impl HNSWGraph {
|
||||
results.into_iter().map(|(_, idx)| idx).collect()
|
||||
}
|
||||
|
||||
fn select_level(&self) -> usize {
|
||||
fn select_level(&self, level_prob: f32) -> usize {
|
||||
let mut rng = rand::rng();
|
||||
let mut level = 0;
|
||||
while level < self.max_level && rng.random::<f32>() < 0.5 {
|
||||
while level < self.max_level && rng.random::<f32>() < level_prob {
|
||||
level += 1;
|
||||
}
|
||||
level
|
||||
@@ -276,23 +286,32 @@ fn main() -> Result<()> {
|
||||
let reader = BufReader::new(file);
|
||||
let mut results = Vec::new();
|
||||
|
||||
println!("🔧 HNSW Parameters:");
|
||||
println!(" max_connections: {}", args.max_connections);
|
||||
println!(" max_level: {}", args.max_level);
|
||||
println!(" level_prob: {}", args.level_prob);
|
||||
println!(" ef_upper: {}", args.ef_upper);
|
||||
println!(" ef_bottom: {}", args.ef_bottom);
|
||||
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
|
||||
let mut hnsw = HNSWGraph::new();
|
||||
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
|
||||
for data_point in &dataset.data {
|
||||
hnsw.insert_node(data_point.clone());
|
||||
hnsw.insert_node(data_point.clone(), args.level_prob);
|
||||
}
|
||||
|
||||
let nearest = if let Some(entry_point) = hnsw.entry_point {
|
||||
let mut search_results = vec![entry_point];
|
||||
|
||||
// 上层搜索使用ef_upper参数
|
||||
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
||||
search_results = hnsw.search_layer(&dataset.query, search_results, 1, layer);
|
||||
search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
|
||||
}
|
||||
|
||||
let final_results = hnsw.search_layer(&dataset.query, search_results, 16, 0);
|
||||
// 底层搜索使用ef_bottom参数
|
||||
let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
|
||||
final_results
|
||||
.into_iter()
|
||||
.take(10)
|
||||
|
||||
Reference in New Issue
Block a user