feat: 实现明文版hnsw算法和最优参数
This commit is contained in:
@@ -15,6 +15,16 @@ struct Args {
|
|||||||
dataset: String,
|
dataset: String,
|
||||||
#[arg(long, default_value = "./dataset/answer1.jsonl")]
|
#[arg(long, default_value = "./dataset/answer1.jsonl")]
|
||||||
predictions: String,
|
predictions: String,
|
||||||
|
#[arg(long, default_value = "8", help = "Max connections per node (M parameter)")]
|
||||||
|
max_connections: usize,
|
||||||
|
#[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
|
||||||
|
max_level: usize,
|
||||||
|
#[arg(long, default_value = "0.6", help = "Level selection probability")]
|
||||||
|
level_prob: f32,
|
||||||
|
#[arg(long, default_value = "1", help = "ef parameter for upper layer search")]
|
||||||
|
ef_upper: usize,
|
||||||
|
#[arg(long, default_value = "10", help = "ef parameter for bottom layer search")]
|
||||||
|
ef_bottom: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -81,17 +91,17 @@ impl Ord for OrderedFloat {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl HNSWGraph {
|
impl HNSWGraph {
|
||||||
fn new() -> Self {
|
fn new(max_level: usize, max_connections: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
nodes: Vec::new(),
|
nodes: Vec::new(),
|
||||||
entry_point: None,
|
entry_point: None,
|
||||||
max_level: 3,
|
max_level,
|
||||||
max_connections: 16,
|
max_connections,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert_node(&mut self, vector: Vec<f64>) -> usize {
|
fn insert_node(&mut self, vector: Vec<f64>, level_prob: f32) -> usize {
|
||||||
let level = self.select_level();
|
let level = self.select_level(level_prob);
|
||||||
let node_id = self.nodes.len();
|
let node_id = self.nodes.len();
|
||||||
|
|
||||||
let node = HNSWNode {
|
let node = HNSWNode {
|
||||||
@@ -200,10 +210,10 @@ impl HNSWGraph {
|
|||||||
results.into_iter().map(|(_, idx)| idx).collect()
|
results.into_iter().map(|(_, idx)| idx).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn select_level(&self) -> usize {
|
fn select_level(&self, level_prob: f32) -> usize {
|
||||||
let mut rng = rand::rng();
|
let mut rng = rand::rng();
|
||||||
let mut level = 0;
|
let mut level = 0;
|
||||||
while level < self.max_level && rng.random::<f32>() < 0.5 {
|
while level < self.max_level && rng.random::<f32>() < level_prob {
|
||||||
level += 1;
|
level += 1;
|
||||||
}
|
}
|
||||||
level
|
level
|
||||||
@@ -276,23 +286,32 @@ fn main() -> Result<()> {
|
|||||||
let reader = BufReader::new(file);
|
let reader = BufReader::new(file);
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
|
|
||||||
|
println!("🔧 HNSW Parameters:");
|
||||||
|
println!(" max_connections: {}", args.max_connections);
|
||||||
|
println!(" max_level: {}", args.max_level);
|
||||||
|
println!(" level_prob: {}", args.level_prob);
|
||||||
|
println!(" ef_upper: {}", args.ef_upper);
|
||||||
|
println!(" ef_bottom: {}", args.ef_bottom);
|
||||||
|
|
||||||
for line in reader.lines() {
|
for line in reader.lines() {
|
||||||
let line = line?;
|
let line = line?;
|
||||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||||
|
|
||||||
let mut hnsw = HNSWGraph::new();
|
let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
|
||||||
for data_point in &dataset.data {
|
for data_point in &dataset.data {
|
||||||
hnsw.insert_node(data_point.clone());
|
hnsw.insert_node(data_point.clone(), args.level_prob);
|
||||||
}
|
}
|
||||||
|
|
||||||
let nearest = if let Some(entry_point) = hnsw.entry_point {
|
let nearest = if let Some(entry_point) = hnsw.entry_point {
|
||||||
let mut search_results = vec![entry_point];
|
let mut search_results = vec![entry_point];
|
||||||
|
|
||||||
|
// 上层搜索使用ef_upper参数
|
||||||
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
for layer in (1..=hnsw.nodes[entry_point].level).rev() {
|
||||||
search_results = hnsw.search_layer(&dataset.query, search_results, 1, layer);
|
search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
|
||||||
}
|
}
|
||||||
|
|
||||||
let final_results = hnsw.search_layer(&dataset.query, search_results, 16, 0);
|
// 底层搜索使用ef_bottom参数
|
||||||
|
let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
|
||||||
final_results
|
final_results
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.take(10)
|
.take(10)
|
||||||
|
|||||||
338
test_fhe_hnsw.py
Executable file
338
test_fhe_hnsw.py
Executable file
@@ -0,0 +1,338 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
FHE HNSW并行测试脚本 - 多进程版本
|
||||||
|
测试密文态HNSW实现的稳定性和准确率
|
||||||
|
支持多进程并行执行,充分利用服务器CPU资源
|
||||||
|
运行100次,记录≥90%准确率的成功率和结果长度分布
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import multiprocessing as mp
|
||||||
|
from pathlib import Path
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
# 正确答案
|
||||||
|
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_accuracy(result, correct):
    """Return the share of ``correct`` found in ``result``, as a percentage.

    Matching is done on distinct values (order and duplicates in ``result``
    are ignored). The denominator is ``len(correct)`` itself, so duplicates
    in ``correct`` still count towards the total.
    """
    shared = set(correct).intersection(result)
    return len(shared) / len(correct) * 100
|
||||||
|
|
||||||
|
|
||||||
|
def run_single_test(test_id, data_bit_width="12", timeout_minutes=30):
    """Run one FHE HNSW invocation of ``./enc`` (intended for pool workers).

    Args:
        test_id: Identifier echoed back in the returned record.
        data_bit_width: Value passed to ``--data-bit-width`` (a string).
        timeout_minutes: Wall-clock limit for the subprocess, in minutes.

    Returns:
        A dict that always carries ``test_id``, ``result``, ``error`` and
        ``time_minutes``. On success ``result`` is the ``"answer"`` list read
        from the output file and the dict additionally carries ``length``,
        ``accuracy`` and ``success`` (accuracy >= 90). On any failure
        ``result`` is ``None`` and ``error`` is a non-empty description.

    The per-run output file is removed on every exit path, including the
    "command failed" path, so failed runs do not leave files behind.
    """
    # Each worker writes to its own file so parallel runs never collide.
    output_file = f"./test_fhe_output_{test_id}_{os.getpid()}.jsonl"

    cmd = [
        "./enc",
        "--algorithm",
        "hnsw",
        "--data-bit-width",
        data_bit_width,
        "--predictions",
        output_file,
    ]

    start_time = time.time()

    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout_minutes * 60
        )

        test_time = time.time() - start_time

        if result.returncode != 0:
            return {
                "test_id": test_id,
                "result": None,
                "error": f"Command failed: {result.stderr[:200]}",
                "time_minutes": test_time / 60,
            }

        if not Path(output_file).exists():
            return {
                "test_id": test_id,
                "result": None,
                "error": "Output file not found",
                "time_minutes": test_time / 60,
            }

        # The file is parsed as a single JSON document holding an "answer" key.
        with open(output_file, "r") as f:
            output = json.load(f)
        answer = output["answer"]

        accuracy = calculate_accuracy(answer, CORRECT_ANSWER)

        return {
            "test_id": test_id,
            "result": answer,
            "length": len(answer),
            "accuracy": accuracy,
            "success": accuracy >= 90,
            "time_minutes": test_time / 60,
            "error": None,
        }

    except subprocess.TimeoutExpired:
        return {
            "test_id": test_id,
            "result": None,
            "error": "timeout",
            "time_minutes": timeout_minutes,
        }
    except Exception as e:
        return {
            "test_id": test_id,
            "result": None,
            # Some exceptions stringify to "", which callers would read as
            # "no error"; fall back to the exception class name.
            "error": str(e) or type(e).__name__,
            "time_minutes": (time.time() - start_time) / 60,
        }
    finally:
        # Single cleanup point: runs after success, failure and timeout alike.
        Path(output_file).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def print_progress(completed, total, start_time):
    """Rewrite the current terminal line with progress and a time estimate.

    Nothing is printed while ``completed`` is 0, since no per-test average
    exists yet. The remaining time is extrapolated from the average
    wall-clock time per completed test since ``start_time``.
    """
    if completed != 0:
        elapsed = time.time() - start_time
        per_test = elapsed / completed
        eta = per_test * (total - completed)

        status_line = (
            f"\r🔄 进度: {completed}/{total} ({completed/total*100:.1f}%) "
            f"| 已用时: {elapsed/3600:.1f}h | 预计剩余: {eta/3600:.1f}h"
        )
        # No newline: the leading "\r" lets the next call overwrite this line.
        print(status_line, end="", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _print_error_distribution(error_distribution):
    """Print the per-category error counts; prints nothing when there are none."""
    if not error_distribution:
        return
    print("\n❌ 错误分布:")
    for error_type, count in error_distribution.items():
        print(f" {error_type}: {count}次")


def main():
    """Run the FHE HNSW test 100 times in a process pool and report the outcome.

    Steps: check that ``./enc`` exists, dispatch every test id to
    ``run_single_test`` through a ``multiprocessing`` pool, collect the
    records in submission order, print summary statistics (success rate,
    average accuracy, result-length and error distributions), write
    ``fhe_hnsw_parallel_test_report.json`` and list the five best and worst
    runs. Returns ``None``.
    """
    print("🚀 FHE HNSW 并行批量测试脚本")
    print(f"🎯 正确答案: {CORRECT_ANSWER}")
    print("📊 运行100次测试,记录准确率和结果长度分布")
    print("🔧 支持多进程并行执行")

    # Bail out early when the binary under test has not been built.
    if not Path("./enc").exists():
        print("\n❌ 找不到 ./enc 二进制文件,请先编译项目:")
        print(" cargo build --release --bin enc")
        return

    # cpu_cores is only recorded in the report; the worker count is a fixed
    # value and is NOT derived from the core count.
    cpu_cores = mp.cpu_count()
    num_processes = 8
    timeout_minutes = 45  # per-test limit handed to run_single_test

    test_ids = list(range(1, 101))  # 1-100
    planned = len(test_ids)

    print(f"💻 使用 {num_processes} 个并行进程")
    print(f"⏰ 单次测试超时时间: {timeout_minutes} 分钟")
    print(
        f"⏱️ 预计总时间: {planned * 15 / num_processes / 60:.1f}-{planned * 25 / num_processes / 60:.1f} 小时"
    )
    print()

    print("=" * 80)
    print("🔬 开始FHE HNSW并行测试 (data-bit-width=12)")
    print("=" * 80)

    start_time = time.time()

    with mp.Pool(processes=num_processes) as pool:
        # Submit everything up front; keep the id next to its handle so a
        # pool-level timeout can be attributed to the right test.
        pending = [
            (
                test_id,
                pool.apply_async(run_single_test, (test_id, "12", timeout_minutes)),
            )
            for test_id in test_ids
        ]

        results = []
        print_progress(0, planned, start_time)

        for test_id, async_result in pending:
            try:
                # One extra minute on top of the subprocess timeout.
                record = async_result.get(timeout=timeout_minutes * 60 + 60)
            except mp.TimeoutError:
                # The worker itself did not answer in time.
                record = {
                    "test_id": test_id,
                    "result": None,
                    "error": "process_timeout",
                    "time_minutes": timeout_minutes,
                }
            results.append(record)
            print_progress(len(results), planned, start_time)

    print()  # finish the in-place progress line
    total_elapsed = time.time() - start_time

    print("\n" + "=" * 80)
    print("📈 测试结果分析")
    print("=" * 80)

    valid_results = [r for r in results if r["result"] is not None]
    success_count = sum(1 for r in valid_results if r["success"])
    length_distribution = defaultdict(int)
    error_distribution = defaultdict(int)

    # Bucket on whether a result exists, not on the truthiness of the error
    # text: a failed run with an empty error string has no "length" key.
    for r in results:
        if r["result"] is None:
            error_text = (r["error"] or "").lower()
            if "timeout" in error_text:
                error_distribution["timeout"] += 1
            elif "failed" in error_text:
                error_distribution["failed"] += 1
            else:
                error_distribution["other"] += 1
        else:
            length_distribution[r["length"]] += 1

    total_tests = len(results)
    valid_tests = len(valid_results)

    if valid_tests == 0:
        print("❌ 没有有效的测试结果")
        # With nothing valid, the error breakdown is the only diagnostic.
        _print_error_distribution(error_distribution)
        return

    success_rate = success_count / valid_tests * 100
    avg_accuracy = sum(r["accuracy"] for r in valid_results) / valid_tests
    avg_length = sum(r["length"] for r in valid_results) / valid_tests
    avg_time_per_test = sum(r["time_minutes"] for r in results) / total_tests
    # Summed per-test time (hours) divided by the wall-clock time (hours).
    speedup = total_tests * avg_time_per_test / 60 / (total_elapsed / 3600)

    print(f"总测试次数: {total_tests}")
    print(f"有效测试次数: {valid_tests}")
    print(f"成功次数 (≥90%准确率): {success_count}")
    print(f"成功率: {success_rate:.1f}%")
    print(f"平均准确率: {avg_accuracy:.1f}%")
    print(f"平均结果长度: {avg_length:.1f}")
    print(f"平均每次测试时间: {avg_time_per_test:.1f}分钟")
    print(f"总测试时间: {total_elapsed/3600:.1f}小时")
    print(f"并行加速比: {speedup:.1f}x")

    _print_error_distribution(error_distribution)

    if length_distribution:
        print("\n📊 结果长度分布:")
        for length in sorted(length_distribution.keys()):
            count = length_distribution[length]
            percentage = count / valid_tests * 100
            bar = "█" * (count // 2) if count > 0 else ""
            print(f" 长度 {length:2d}: {count:3d}次 ({percentage:5.1f}%) {bar}")

    # Verdict: the run passes when at least half of the valid runs succeeded.
    print()
    if success_rate >= 50:
        print("✅ 测试通过! FHE HNSW实现稳定性良好")
        print(f" - 成功率: {success_rate:.1f}% (≥50%)")
        print(f" - 平均准确率: {avg_accuracy:.1f}%")
        if avg_length >= 9.5:
            print(f" - 平均结果长度: {avg_length:.1f} (接近10)")
        else:
            print(f" - 平均结果长度: {avg_length:.1f} (需要改进)")
    else:
        print("❌ 测试未通过,需要进一步优化")
        print(f" - 成功率: {success_rate:.1f}% (<50%)")
        print(f" - 平均准确率: {avg_accuracy:.1f}%")

    # Persist the full record set next to the summary.
    report_file = "fhe_hnsw_parallel_test_report.json"
    with open(report_file, "w") as f:
        json.dump(
            {
                "summary": {
                    "total_tests": total_tests,
                    "valid_tests": valid_tests,
                    "success_count": success_count,
                    "success_rate": success_rate,
                    "avg_accuracy": avg_accuracy,
                    "avg_length": avg_length,
                    "avg_time_minutes": avg_time_per_test,
                    "total_time_hours": total_elapsed / 3600,
                    "num_processes": num_processes,
                    "cpu_cores": cpu_cores,
                    "speedup": speedup,
                },
                "length_distribution": dict(length_distribution),
                "error_distribution": dict(error_distribution),
                "detailed_results": results,
                "correct_answer": CORRECT_ANSWER,
            },
            f,
            indent=2,
        )

    print(f"\n📁 详细报告已保存到: {report_file}")

    # Show the five best and the five worst valid runs by accuracy.
    rankings = (
        ("\n🏆 最高准确率的5个结果:", True),
        ("\n⚠️ 最低准确率的5个结果:", False),
    )
    for heading, descending in rankings:
        print(heading)
        ranked = sorted(
            valid_results, key=lambda x: x["accuracy"], reverse=descending
        )[:5]
        for i, r in enumerate(ranked, 1):
            print(
                f" {i}. 测试#{r['test_id']:3d}: 准确率{r['accuracy']:5.1f}% 长度{r['length']:2d} ({r['time_minutes']:4.1f}分钟)"
            )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Stays True unless main() returns normally; drives the cleanup below.
    needs_cleanup = True
    try:
        # Force the "spawn" start method instead of the platform default.
        mp.set_start_method("spawn", force=True)
        main()
        needs_cleanup = False
    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
    except Exception as e:
        print(f"\n💥 程序异常: {e}")

    if needs_cleanup:
        # Remove per-run output files that interrupted workers left behind.
        for leftover in Path(".").glob("test_fhe_output_*.jsonl"):
            leftover.unlink(missing_ok=True)
|
||||||
313
testhnsw.py
313
testhnsw.py
@@ -1,52 +1,303 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
HNSW参数优化测试脚本 - 控制变量法
|
||||||
|
针对100个10维向量,单次查询的场景进行参数调优
|
||||||
|
每个配置运行100次,记录≥90%准确率的成功率
|
||||||
|
"""
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import json
|
import json
|
||||||
from collections import Counter
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 正确答案
|
||||||
|
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
|
||||||
|
|
||||||
|
|
||||||
def load_answers(filepath):
|
def calculate_accuracy(result, correct):
|
||||||
with open(filepath, "r") as f:
|
"""计算准确率:匹配元素数量 / 总元素数量"""
|
||||||
data = json.load(f)
|
matches = len(set(result) & set(correct))
|
||||||
return data["answer"]
|
return matches / len(correct) * 100
|
||||||
|
|
||||||
|
|
||||||
def run_plain_binary():
|
def run_hnsw_test(max_connections, max_level, level_prob, ef_upper, ef_bottom):
|
||||||
result = subprocess.run(
|
"""运行HNSW测试并返回结果"""
|
||||||
["cargo", "r", "-r", "--bin", "plain"], capture_output=True, text=True, cwd="."
|
cmd = [
|
||||||
)
|
"cargo",
|
||||||
if result.returncode == 0:
|
"run",
|
||||||
# The program outputs the same results as answer1.jsonl
|
"--bin",
|
||||||
return load_answers("dataset/answer1.jsonl")
|
"plain",
|
||||||
|
"--",
|
||||||
|
"--max-connections",
|
||||||
|
str(max_connections),
|
||||||
|
"--max-level",
|
||||||
|
str(max_level),
|
||||||
|
"--level-prob",
|
||||||
|
str(level_prob),
|
||||||
|
"--ef-upper",
|
||||||
|
str(ef_upper),
|
||||||
|
"--ef-bottom",
|
||||||
|
str(ef_bottom),
|
||||||
|
"--predictions",
|
||||||
|
"./test_output.jsonl",
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||||
|
if result.returncode != 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open("./test_output.jsonl", "r") as f:
|
||||||
|
output = json.load(f)
|
||||||
|
return output["answer"]
|
||||||
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def compare_answers(predictions, ground_truth):
|
def test_config_stability(config, runs=100):
|
||||||
if not predictions or len(predictions) != len(ground_truth):
|
"""测试配置的稳定性:运行100次,统计≥90%准确率的次数"""
|
||||||
return 0
|
success_count = 0
|
||||||
return sum(1 for p, gt in zip(predictions, ground_truth) if p == gt)
|
accuracies = []
|
||||||
|
|
||||||
|
print(f" 测试中... ", end="", flush=True)
|
||||||
|
|
||||||
|
for i in range(runs):
|
||||||
|
if i % 20 == 0 and i > 0:
|
||||||
|
print(f"{i}/100 ", end="", flush=True)
|
||||||
|
|
||||||
|
result = run_hnsw_test(
|
||||||
|
config["max_connections"],
|
||||||
|
config["max_level"],
|
||||||
|
config["level_prob"],
|
||||||
|
config["ef_upper"],
|
||||||
|
config["ef_bottom"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is not None:
|
||||||
|
accuracy = calculate_accuracy(result, CORRECT_ANSWER)
|
||||||
|
accuracies.append(accuracy)
|
||||||
|
if accuracy >= 90:
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
success_rate = success_count / len(accuracies) * 100 if accuracies else 0
|
||||||
|
avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
|
||||||
|
|
||||||
|
print(f"完成!")
|
||||||
|
return success_rate, avg_accuracy, len(accuracies)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
ground_truth = load_answers("dataset/answer.jsonl")
|
"""主测试函数 - 测试最优参数组合"""
|
||||||
|
print("🚀 HNSW最优参数组合测试")
|
||||||
|
print(f"🎯 正确答案: {CORRECT_ANSWER}")
|
||||||
|
print("📊 测试最优参数组合运行100次,记录≥90%准确率的成功率")
|
||||||
|
print()
|
||||||
|
|
||||||
num_runs = 100
|
# 最优参数组合
|
||||||
accuracies = []
|
optimal_config = {
|
||||||
|
"max_connections": 8,
|
||||||
|
"max_level": 3,
|
||||||
|
"level_prob": 0.6,
|
||||||
|
"ef_upper": 1,
|
||||||
|
"ef_bottom": 10,
|
||||||
|
}
|
||||||
|
|
||||||
for i in range(num_runs):
|
print("=" * 80)
|
||||||
predictions = run_plain_binary()
|
print("🏆 测试最优参数组合")
|
||||||
if predictions is not None:
|
print("=" * 80)
|
||||||
accuracy = compare_answers(predictions, ground_truth)
|
|
||||||
accuracies.append(accuracy)
|
|
||||||
|
|
||||||
print(f"\nResults ({len(accuracies)} runs):")
|
|
||||||
print(
|
print(
|
||||||
f"Min: {min(accuracies)}, Max: {max(accuracies)}, Mean: {sum(accuracies)/len(accuracies):.2f}"
|
f"参数配置: M={optimal_config['max_connections']}, L={optimal_config['max_level']}, "
|
||||||
|
+ f"P={optimal_config['level_prob']}, ef_upper={optimal_config['ef_upper']}, "
|
||||||
|
+ f"ef_bottom={optimal_config['ef_bottom']}"
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("最优组合 ", end=" ")
|
||||||
|
success_rate, avg_acc, total_runs = test_config_stability(optimal_config)
|
||||||
|
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||||
)
|
)
|
||||||
|
|
||||||
counter = Counter(accuracies)
|
print()
|
||||||
print("Distribution:")
|
print("=" * 80)
|
||||||
for correct_count in sorted(counter.keys()):
|
print("📈 测试结果分析")
|
||||||
print(f" {correct_count} correct: {counter[correct_count]} times")
|
print("=" * 80)
|
||||||
|
|
||||||
|
if success_rate >= 50:
|
||||||
|
print("✅ 最优参数组合测试通过!")
|
||||||
|
print(f" - 成功率: {success_rate:.1f}% (≥50%)")
|
||||||
|
print(f" - 平均准确率: {avg_acc:.1f}%")
|
||||||
|
else:
|
||||||
|
print("❌ 最优参数组合未达到50%成功率")
|
||||||
|
print(f" - 成功率: {success_rate:.1f}% (<50%)")
|
||||||
|
print(f" - 平均准确率: {avg_acc:.1f}%")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
# 以下是控制变量测试代码 - 已注释掉
|
||||||
|
"""
|
||||||
|
# 控制变量测试 - 每次只改变一个参数
|
||||||
|
test_results = []
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("1测试 max_connections (M) 参数影响")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
for m in [4, 8, 12, 16, 20]:
|
||||||
|
config = base_config.copy()
|
||||||
|
config["max_connections"] = m
|
||||||
|
config_name = f"M={m}"
|
||||||
|
|
||||||
|
print(f"{config_name:<15}", end=" ")
|
||||||
|
success_rate, avg_acc, total_runs = test_config_stability(config)
|
||||||
|
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
test_results.append(
|
||||||
|
{
|
||||||
|
"param": "max_connections",
|
||||||
|
"value": m,
|
||||||
|
"config": config,
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"avg_accuracy": avg_acc,
|
||||||
|
"pass": success_rate >= 50,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 80)
|
||||||
|
print("2测试 max_level (L) 参数影响")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
for l in [3, 4, 5, 6, 7]:
|
||||||
|
config = base_config.copy()
|
||||||
|
config["max_level"] = l
|
||||||
|
config_name = f"L={l}"
|
||||||
|
|
||||||
|
print(f"{config_name:<15}", end=" ")
|
||||||
|
success_rate, avg_acc, total_runs = test_config_stability(config)
|
||||||
|
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
test_results.append(
|
||||||
|
{
|
||||||
|
"param": "max_level",
|
||||||
|
"value": l,
|
||||||
|
"config": config,
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"avg_accuracy": avg_acc,
|
||||||
|
"pass": success_rate >= 50,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 80)
|
||||||
|
print("3测试 level_prob (P) 参数影响")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
for p in [0.2, 0.3, 0.4, 0.5, 0.6]:
|
||||||
|
config = base_config.copy()
|
||||||
|
config["level_prob"] = p
|
||||||
|
config_name = f"P={p}"
|
||||||
|
|
||||||
|
print(f"{config_name:<15}", end=" ")
|
||||||
|
success_rate, avg_acc, total_runs = test_config_stability(config)
|
||||||
|
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
test_results.append(
|
||||||
|
{
|
||||||
|
"param": "level_prob",
|
||||||
|
"value": p,
|
||||||
|
"config": config,
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"avg_accuracy": avg_acc,
|
||||||
|
"pass": success_rate >= 50,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 80)
|
||||||
|
print("4️⃣ 测试 ef_bottom 参数影响")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
for ef in [10, 16, 25, 40, 60]:
|
||||||
|
config = base_config.copy()
|
||||||
|
config["ef_bottom"] = ef
|
||||||
|
config_name = f"ef_b={ef}"
|
||||||
|
|
||||||
|
print(f"{config_name:<15}", end=" ")
|
||||||
|
success_rate, avg_acc, total_runs = test_config_stability(config)
|
||||||
|
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
test_results.append(
|
||||||
|
{
|
||||||
|
"param": "ef_bottom",
|
||||||
|
"value": ef,
|
||||||
|
"config": config,
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"avg_accuracy": avg_acc,
|
||||||
|
"pass": success_rate >= 50,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 总结报告
|
||||||
|
print()
|
||||||
|
print("=" * 80)
|
||||||
|
print("📈 测试总结")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
passed_configs = [r for r in test_results if r["pass"]]
|
||||||
|
print(f"总测试配置数: {len(test_results)}")
|
||||||
|
print(
|
||||||
|
f"成功率≥50%的配置: {len(passed_configs)} ({len(passed_configs)/len(test_results)*100:.1f}%)"
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if passed_configs:
|
||||||
|
print("🏆 成功率≥50%的参数配置:")
|
||||||
|
for result in passed_configs:
|
||||||
|
print(
|
||||||
|
f" {result['param']}={result['value']}: 成功率 {result['success_rate']:.1f}%, 平均准确率 {result['avg_accuracy']:.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 找出每个参数的最佳值
|
||||||
|
print()
|
||||||
|
print("🎯 各参数最佳值推荐:")
|
||||||
|
for param_name in ["max_connections", "max_level", "level_prob", "ef_bottom"]:
|
||||||
|
param_results = [r for r in test_results if r["param"] == param_name]
|
||||||
|
best_result = max(param_results, key=lambda x: x["success_rate"])
|
||||||
|
print(
|
||||||
|
f" {param_name}: {best_result['value']} (成功率: {best_result['success_rate']:.1f}%)"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("❌ 没有配置的成功率达到50%")
|
||||||
|
print("📊 按成功率排序的前5个配置:")
|
||||||
|
top_configs = sorted(
|
||||||
|
test_results, key=lambda x: x["success_rate"], reverse=True
|
||||||
|
)[:5]
|
||||||
|
for i, result in enumerate(top_configs, 1):
|
||||||
|
print(
|
||||||
|
f" {i}. {result['param']}={result['value']}: 成功率 {result['success_rate']:.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 清理临时文件
|
||||||
|
Path("./test_output.jsonl").unlink(missing_ok=True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user