hfe_knn/testhnsw.py

305 lines
9.0 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
HNSW参数优化测试脚本 - 控制变量法
针对100个10维向量单次查询的场景进行参数调优
每个配置运行100次，统计准确率≥90%的运行所占比例（成功率）
"""
import subprocess
import json
from pathlib import Path
# Ground-truth IDs for the single benchmark query; every run's returned IDs
# are scored against this list by calculate_accuracy.
CORRECT_ANSWER = [
    93, 94, 90, 27, 87,
    50, 47, 40, 78, 28,
]
def calculate_accuracy(result, correct):
    """Return the percentage of ``correct`` IDs that also appear in ``result``.

    Order is ignored and duplicates in ``result`` count once, so this is the
    recall of ``result`` against ``correct`` expressed in percent.

    Args:
        result: IDs returned by the search.
        correct: Ground-truth IDs; must be non-empty.

    Returns:
        ``matches * 100 / len(correct)`` as a float.

    Raises:
        ValueError: If ``correct`` is empty.
    """
    total = len(correct)
    if total == 0:
        raise ValueError("correct must contain at least one element")
    matches = len(set(result) & set(correct))
    # Multiply before dividing: ``matches / total * 100`` can land just below
    # the true percentage (57 / 100 * 100 == 56.99999999999999), which makes
    # threshold checks such as ``accuracy >= 57`` fail spuriously.
    return matches * 100 / total
def run_hnsw_test(
    max_connections,
    max_level,
    level_prob,
    ef_upper,
    ef_bottom,
    predictions_path="./test_output.jsonl",
    timeout=30,
):
    """Run the HNSW ``plain`` binary once and return the IDs it predicted.

    Invokes ``cargo run --bin plain`` with the given parameters, then reads the
    JSON document the binary wrote to ``predictions_path`` and returns its
    ``"answer"`` field.

    Args:
        max_connections: Passed as ``--max-connections``.
        max_level: Passed as ``--max-level``.
        level_prob: Passed as ``--level-prob``.
        ef_upper: Passed as ``--ef-upper``.
        ef_bottom: Passed as ``--ef-bottom``.
        predictions_path: File the binary is told to write to
            (``--predictions``).
        timeout: Seconds to wait for ``cargo run`` (which may include a build
            step) before giving up.

    Returns:
        The ``"answer"`` value from the predictions file (presumably a list of
        vector IDs), or ``None`` if the process could not be started, timed
        out, exited non-zero, or its output could not be read or parsed.
    """
    output_path = Path(predictions_path)
    cmd = [
        "cargo", "run", "--bin", "plain", "--",
        "--max-connections", str(max_connections),
        "--max-level", str(max_level),
        "--level-prob", str(level_prob),
        "--ef-upper", str(ef_upper),
        "--ef-bottom", str(ef_bottom),
        "--predictions", str(predictions_path),
    ]
    try:
        # The same path is reused for every run. Delete it first so a run that
        # exits 0 without rewriting it (or, if the binary appends, a growing
        # file) is never scored on stale or mixed predictions.
        output_path.unlink(missing_ok=True)
        completed = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout
        )
        if completed.returncode != 0:
            return None
        with output_path.open("r", encoding="utf-8") as f:
            output = json.load(f)
        return output["answer"]
    except (subprocess.SubprocessError, OSError, ValueError, KeyError, TypeError):
        # Best-effort: a failed run is reported as None and skipped by the
        # caller. SubprocessError covers TimeoutExpired; OSError covers a
        # missing ``cargo`` or unreadable file; ValueError covers malformed
        # JSON; KeyError/TypeError cover JSON without an "answer" entry.
        return None
def test_config_stability(config, runs=100, threshold=90):
    """Run one HNSW configuration repeatedly and summarise its accuracy.

    Each run calls ``run_hnsw_test`` with the parameters in ``config`` and
    scores the returned IDs against ``CORRECT_ANSWER``.

    Args:
        config: Dict with keys ``max_connections``, ``max_level``,
            ``level_prob``, ``ef_upper`` and ``ef_bottom``.
        runs: Number of times to invoke the binary.
        threshold: Minimum accuracy (percent) for a run to count as a success.

    Returns:
        ``(success_rate, avg_accuracy, valid_runs)``. ``valid_runs`` is the
        number of runs that produced a result. Runs for which
        ``run_hnsw_test`` returned ``None`` are excluded from both rates, and
        both rates are 0 when no run produced a result.
    """
    success_count = 0
    accuracies = []
    print(" 测试中... ", end="", flush=True)
    for i in range(runs):
        # Progress marker every 20 runs, relative to the requested run count.
        if i % 20 == 0 and i > 0:
            print(f"{i}/{runs} ", end="", flush=True)
        result = run_hnsw_test(
            config["max_connections"],
            config["max_level"],
            config["level_prob"],
            config["ef_upper"],
            config["ef_bottom"],
        )
        if result is None:
            continue  # failed run: not counted (see Returns)
        accuracy = calculate_accuracy(result, CORRECT_ANSWER)
        accuracies.append(accuracy)
        if accuracy >= threshold:
            success_count += 1
    valid_runs = len(accuracies)
    # Multiply before dividing to avoid percentages like 56.99999999999999.
    success_rate = success_count * 100 / valid_runs if valid_runs else 0
    avg_accuracy = sum(accuracies) / valid_runs if valid_runs else 0
    print("完成!")
    return success_rate, avg_accuracy, valid_runs
def main():
    """Benchmark the tuned HNSW configuration and print a PASS/FAIL verdict.

    Runs ``test_config_stability`` on the hard-coded optimal parameter set,
    prints its success rate and mean accuracy, and reports PASS when the
    success rate is at least ``pass_rate`` percent. The scratch predictions
    file ``./test_output.jsonl`` is deleted afterwards, even if benchmarking
    raises.

    A one-parameter-at-a-time sweep around a base configuration is available
    in ``_run_parameter_sweep``; it is not executed here.
    """
    pass_rate = 50  # minimum success rate (%) required for PASS

    print("🚀 HNSW最优参数组合测试")
    print(f"🎯 正确答案: {CORRECT_ANSWER}")
    print("📊 测试最优参数组合运行100次记录≥90%准确率的成功率")
    print()

    # Optimal parameter combination.
    optimal_config = {
        "max_connections": 8,
        "max_level": 3,
        "level_prob": 0.6,
        "ef_upper": 1,
        "ef_bottom": 10,
    }

    print("=" * 80)
    print("🏆 测试最优参数组合")
    print("=" * 80)
    print(
        f"参数配置: M={optimal_config['max_connections']}, L={optimal_config['max_level']}, "
        f"P={optimal_config['level_prob']}, ef_upper={optimal_config['ef_upper']}, "
        f"ef_bottom={optimal_config['ef_bottom']}"
    )
    print()
    print("最优组合 ", end=" ")
    try:
        success_rate, avg_acc, total_runs = test_config_stability(optimal_config)
    finally:
        # Predictions file passed to the HNSW binary on every run; don't leave
        # it behind in the working directory.
        Path("./test_output.jsonl").unlink(missing_ok=True)

    passed = success_rate >= pass_rate
    status = "✅ PASS" if passed else "❌ FAIL"
    print(
        f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}"
    )
    print()
    print("=" * 80)
    print("📈 测试结果分析")
    print("=" * 80)
    if passed:
        print("✅ 最优参数组合测试通过!")
        print(f" - 成功率: {success_rate:.1f}% (≥{pass_rate}%)")
    else:
        print(f"❌ 最优参数组合未达到{pass_rate}%成功率")
        print(f" - 成功率: {success_rate:.1f}% (<{pass_rate}%)")
    print(f" - 平均准确率: {avg_acc:.1f}%")


def _run_parameter_sweep(base_config, pass_rate=50):
    """Controlled-variable sweep: vary one HNSW parameter at a time.

    Starting from ``base_config``, each parameter in the sweep table is set to
    each of its candidate values while the others keep their base values, and
    every resulting configuration is benchmarked with
    ``test_config_stability``. A summary is printed at the end. Not called by
    ``main``; run it manually when re-tuning.

    Args:
        base_config: Parameter dict with the same keys as ``optimal_config``
            in ``main``.
        pass_rate: Minimum success rate (%) for a configuration to PASS.

    Returns:
        List of per-configuration result dicts with keys ``param``,
        ``value``, ``config``, ``success_rate``, ``avg_accuracy`` and ``pass``.
    """
    # (parameter key, section header, label prefix, candidate values)
    sweeps = [
        ("max_connections", "1. 测试 max_connections (M) 参数影响", "M", [4, 8, 12, 16, 20]),
        ("max_level", "2. 测试 max_level (L) 参数影响", "L", [3, 4, 5, 6, 7]),
        ("level_prob", "3. 测试 level_prob (P) 参数影响", "P", [0.2, 0.3, 0.4, 0.5, 0.6]),
        ("ef_bottom", "4. 测试 ef_bottom 参数影响", "ef_b", [10, 16, 25, 40, 60]),
    ]
    test_results = []
    try:
        for param, header, label, values in sweeps:
            print("=" * 80)
            print(header)
            print("=" * 80)
            for value in values:
                config = base_config.copy()
                config[param] = value
                config_name = f"{label}={value}"
                print(f"{config_name:<15}", end=" ")
                success_rate, avg_acc, total_runs = test_config_stability(config)
                passed = success_rate >= pass_rate
                status = "✅ PASS" if passed else "❌ FAIL"
                print(
                    f"成功率: {success_rate:5.1f}% ({total_runs}次) "
                    f"平均准确率: {avg_acc:5.1f}% {status}"
                )
                test_results.append(
                    {
                        "param": param,
                        "value": value,
                        "config": config,
                        "success_rate": success_rate,
                        "avg_accuracy": avg_acc,
                        "pass": passed,
                    }
                )
            print()
    finally:
        # Remove the scratch predictions file used by every benchmark run.
        Path("./test_output.jsonl").unlink(missing_ok=True)
    _print_sweep_summary(test_results, pass_rate)
    return test_results


def _print_sweep_summary(test_results, pass_rate=50):
    """Print pass counts and the best value per parameter (or a top-5 list)."""
    print("=" * 80)
    print("📈 测试总结")
    print("=" * 80)
    passed_configs = [r for r in test_results if r["pass"]]
    total = len(test_results)
    passed_pct = len(passed_configs) * 100 / total if total else 0.0
    print(f"总测试配置数: {total}")
    print(f"成功率≥{pass_rate}%的配置: {len(passed_configs)} ({passed_pct:.1f}%)")
    print()
    if passed_configs:
        print(f"🏆 成功率≥{pass_rate}%的参数配置:")
        for r in passed_configs:
            print(
                f" {r['param']}={r['value']}: 成功率 {r['success_rate']:.1f}%, "
                f"平均准确率 {r['avg_accuracy']:.1f}%"
            )
        print()
        print("🎯 各参数最佳值推荐:")
        # dict.fromkeys keeps parameters in sweep order without duplicates.
        for param in dict.fromkeys(r["param"] for r in test_results):
            best = max(
                (r for r in test_results if r["param"] == param),
                key=lambda r: r["success_rate"],
            )
            print(f" {param}: {best['value']} (成功率: {best['success_rate']:.1f}%)")
    else:
        print(f"❌ 没有配置的成功率达到{pass_rate}%")
        print("📊 按成功率排序的前5个配置:")
        ranked = sorted(test_results, key=lambda r: r["success_rate"], reverse=True)
        for rank, r in enumerate(ranked[:5], 1):
            print(f" {rank}. {r['param']}={r['value']}: 成功率 {r['success_rate']:.1f}%")
# Run the benchmark only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()