feat: 实现明文版hnsw算法和最优参数
This commit is contained in:
315
testhnsw.py
315
testhnsw.py
@@ -1,52 +1,303 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
HNSW参数优化测试脚本 - 控制变量法
|
||||
针对100个10维向量,单次查询的场景进行参数调优
|
||||
每个配置运行100次,记录≥90%准确率的成功率
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import json
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
# 正确答案
|
||||
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
|
||||
|
||||
|
||||
def load_answers(filepath):
|
||||
with open(filepath, "r") as f:
|
||||
data = json.load(f)
|
||||
return data["answer"]
|
||||
def calculate_accuracy(result, correct):
|
||||
"""计算准确率:匹配元素数量 / 总元素数量"""
|
||||
matches = len(set(result) & set(correct))
|
||||
return matches / len(correct) * 100
|
||||
|
||||
|
||||
def run_plain_binary():
|
||||
result = subprocess.run(
|
||||
["cargo", "r", "-r", "--bin", "plain"], capture_output=True, text=True, cwd="."
|
||||
)
|
||||
if result.returncode == 0:
|
||||
# The program outputs the same results as answer1.jsonl
|
||||
return load_answers("dataset/answer1.jsonl")
|
||||
return None
|
||||
def run_hnsw_test(max_connections, max_level, level_prob, ef_upper, ef_bottom):
|
||||
"""运行HNSW测试并返回结果"""
|
||||
cmd = [
|
||||
"cargo",
|
||||
"run",
|
||||
"--bin",
|
||||
"plain",
|
||||
"--",
|
||||
"--max-connections",
|
||||
str(max_connections),
|
||||
"--max-level",
|
||||
str(max_level),
|
||||
"--level-prob",
|
||||
str(level_prob),
|
||||
"--ef-upper",
|
||||
str(ef_upper),
|
||||
"--ef-bottom",
|
||||
str(ef_bottom),
|
||||
"--predictions",
|
||||
"./test_output.jsonl",
|
||||
]
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
|
||||
with open("./test_output.jsonl", "r") as f:
|
||||
output = json.load(f)
|
||||
return output["answer"]
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def compare_answers(predictions, ground_truth):
|
||||
if not predictions or len(predictions) != len(ground_truth):
|
||||
return 0
|
||||
return sum(1 for p, gt in zip(predictions, ground_truth) if p == gt)
|
||||
def test_config_stability(config, runs=100):
|
||||
"""测试配置的稳定性:运行100次,统计≥90%准确率的次数"""
|
||||
success_count = 0
|
||||
accuracies = []
|
||||
|
||||
print(f" 测试中... ", end="", flush=True)
|
||||
|
||||
for i in range(runs):
|
||||
if i % 20 == 0 and i > 0:
|
||||
print(f"{i}/100 ", end="", flush=True)
|
||||
|
||||
result = run_hnsw_test(
|
||||
config["max_connections"],
|
||||
config["max_level"],
|
||||
config["level_prob"],
|
||||
config["ef_upper"],
|
||||
config["ef_bottom"],
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
accuracy = calculate_accuracy(result, CORRECT_ANSWER)
|
||||
accuracies.append(accuracy)
|
||||
if accuracy >= 90:
|
||||
success_count += 1
|
||||
|
||||
success_rate = success_count / len(accuracies) * 100 if accuracies else 0
|
||||
avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
|
||||
|
||||
print(f"完成!")
|
||||
return success_rate, avg_accuracy, len(accuracies)
|
||||
|
||||
|
||||
def main():
|
||||
ground_truth = load_answers("dataset/answer.jsonl")
|
||||
"""主测试函数 - 测试最优参数组合"""
|
||||
print("🚀 HNSW最优参数组合测试")
|
||||
print(f"🎯 正确答案: {CORRECT_ANSWER}")
|
||||
print("📊 测试最优参数组合运行100次,记录≥90%准确率的成功率")
|
||||
print()
|
||||
|
||||
num_runs = 100
|
||||
accuracies = []
|
||||
# 最优参数组合
|
||||
optimal_config = {
|
||||
"max_connections": 8,
|
||||
"max_level": 3,
|
||||
"level_prob": 0.6,
|
||||
"ef_upper": 1,
|
||||
"ef_bottom": 10,
|
||||
}
|
||||
|
||||
for i in range(num_runs):
|
||||
predictions = run_plain_binary()
|
||||
if predictions is not None:
|
||||
accuracy = compare_answers(predictions, ground_truth)
|
||||
accuracies.append(accuracy)
|
||||
|
||||
print(f"\nResults ({len(accuracies)} runs):")
|
||||
print("=" * 80)
|
||||
print("🏆 测试最优参数组合")
|
||||
print("=" * 80)
|
||||
print(
|
||||
f"Min: {min(accuracies)}, Max: {max(accuracies)}, Mean: {sum(accuracies)/len(accuracies):.2f}"
|
||||
f"参数配置: M={optimal_config['max_connections']}, L={optimal_config['max_level']}, "
|
||||
+ f"P={optimal_config['level_prob']}, ef_upper={optimal_config['ef_upper']}, "
|
||||
+ f"ef_bottom={optimal_config['ef_bottom']}"
|
||||
)
|
||||
print()
|
||||
|
||||
print("最优组合 ", end=" ")
|
||||
success_rate, avg_acc, total_runs = test_config_stability(optimal_config)
|
||||
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||
|
||||
print(
|
||||
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||
)
|
||||
|
||||
counter = Counter(accuracies)
|
||||
print("Distribution:")
|
||||
for correct_count in sorted(counter.keys()):
|
||||
print(f" {correct_count} correct: {counter[correct_count]} times")
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("📈 测试结果分析")
|
||||
print("=" * 80)
|
||||
|
||||
if success_rate >= 50:
|
||||
print("✅ 最优参数组合测试通过!")
|
||||
print(f" - 成功率: {success_rate:.1f}% (≥50%)")
|
||||
print(f" - 平均准确率: {avg_acc:.1f}%")
|
||||
else:
|
||||
print("❌ 最优参数组合未达到50%成功率")
|
||||
print(f" - 成功率: {success_rate:.1f}% (<50%)")
|
||||
print(f" - 平均准确率: {avg_acc:.1f}%")
|
||||
|
||||
return
|
||||
|
||||
# 以下是控制变量测试代码 - 已注释掉
|
||||
"""
|
||||
# 控制变量测试 - 每次只改变一个参数
|
||||
test_results = []
|
||||
|
||||
print("=" * 80)
|
||||
print("1测试 max_connections (M) 参数影响")
|
||||
print("=" * 80)
|
||||
|
||||
for m in [4, 8, 12, 16, 20]:
|
||||
config = base_config.copy()
|
||||
config["max_connections"] = m
|
||||
config_name = f"M={m}"
|
||||
|
||||
print(f"{config_name:<15}", end=" ")
|
||||
success_rate, avg_acc, total_runs = test_config_stability(config)
|
||||
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||
|
||||
print(
|
||||
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||
)
|
||||
|
||||
test_results.append(
|
||||
{
|
||||
"param": "max_connections",
|
||||
"value": m,
|
||||
"config": config,
|
||||
"success_rate": success_rate,
|
||||
"avg_accuracy": avg_acc,
|
||||
"pass": success_rate >= 50,
|
||||
}
|
||||
)
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("2测试 max_level (L) 参数影响")
|
||||
print("=" * 80)
|
||||
|
||||
for l in [3, 4, 5, 6, 7]:
|
||||
config = base_config.copy()
|
||||
config["max_level"] = l
|
||||
config_name = f"L={l}"
|
||||
|
||||
print(f"{config_name:<15}", end=" ")
|
||||
success_rate, avg_acc, total_runs = test_config_stability(config)
|
||||
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||
|
||||
print(
|
||||
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||
)
|
||||
|
||||
test_results.append(
|
||||
{
|
||||
"param": "max_level",
|
||||
"value": l,
|
||||
"config": config,
|
||||
"success_rate": success_rate,
|
||||
"avg_accuracy": avg_acc,
|
||||
"pass": success_rate >= 50,
|
||||
}
|
||||
)
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("3测试 level_prob (P) 参数影响")
|
||||
print("=" * 80)
|
||||
|
||||
for p in [0.2, 0.3, 0.4, 0.5, 0.6]:
|
||||
config = base_config.copy()
|
||||
config["level_prob"] = p
|
||||
config_name = f"P={p}"
|
||||
|
||||
print(f"{config_name:<15}", end=" ")
|
||||
success_rate, avg_acc, total_runs = test_config_stability(config)
|
||||
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||
|
||||
print(
|
||||
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||
)
|
||||
|
||||
test_results.append(
|
||||
{
|
||||
"param": "level_prob",
|
||||
"value": p,
|
||||
"config": config,
|
||||
"success_rate": success_rate,
|
||||
"avg_accuracy": avg_acc,
|
||||
"pass": success_rate >= 50,
|
||||
}
|
||||
)
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("4️⃣ 测试 ef_bottom 参数影响")
|
||||
print("=" * 80)
|
||||
|
||||
for ef in [10, 16, 25, 40, 60]:
|
||||
config = base_config.copy()
|
||||
config["ef_bottom"] = ef
|
||||
config_name = f"ef_b={ef}"
|
||||
|
||||
print(f"{config_name:<15}", end=" ")
|
||||
success_rate, avg_acc, total_runs = test_config_stability(config)
|
||||
status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
|
||||
|
||||
print(
|
||||
f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
|
||||
)
|
||||
|
||||
test_results.append(
|
||||
{
|
||||
"param": "ef_bottom",
|
||||
"value": ef,
|
||||
"config": config,
|
||||
"success_rate": success_rate,
|
||||
"avg_accuracy": avg_acc,
|
||||
"pass": success_rate >= 50,
|
||||
}
|
||||
)
|
||||
|
||||
# 总结报告
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("📈 测试总结")
|
||||
print("=" * 80)
|
||||
|
||||
passed_configs = [r for r in test_results if r["pass"]]
|
||||
print(f"总测试配置数: {len(test_results)}")
|
||||
print(
|
||||
f"成功率≥50%的配置: {len(passed_configs)} ({len(passed_configs)/len(test_results)*100:.1f}%)"
|
||||
)
|
||||
print()
|
||||
|
||||
if passed_configs:
|
||||
print("🏆 成功率≥50%的参数配置:")
|
||||
for result in passed_configs:
|
||||
print(
|
||||
f" {result['param']}={result['value']}: 成功率 {result['success_rate']:.1f}%, 平均准确率 {result['avg_accuracy']:.1f}%"
|
||||
)
|
||||
|
||||
# 找出每个参数的最佳值
|
||||
print()
|
||||
print("🎯 各参数最佳值推荐:")
|
||||
for param_name in ["max_connections", "max_level", "level_prob", "ef_bottom"]:
|
||||
param_results = [r for r in test_results if r["param"] == param_name]
|
||||
best_result = max(param_results, key=lambda x: x["success_rate"])
|
||||
print(
|
||||
f" {param_name}: {best_result['value']} (成功率: {best_result['success_rate']:.1f}%)"
|
||||
)
|
||||
else:
|
||||
print("❌ 没有配置的成功率达到50%")
|
||||
print("📊 按成功率排序的前5个配置:")
|
||||
top_configs = sorted(
|
||||
test_results, key=lambda x: x["success_rate"], reverse=True
|
||||
)[:5]
|
||||
for i, result in enumerate(top_configs, 1):
|
||||
print(
|
||||
f" {i}. {result['param']}={result['value']}: 成功率 {result['success_rate']:.1f}%"
|
||||
)
|
||||
|
||||
# 清理临时文件
|
||||
Path("./test_output.jsonl").unlink(missing_ok=True)
|
||||
"""
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user