#!/usr/bin/env python3 """ HNSW参数优化测试脚本 - 控制变量法 针对100个10维向量,单次查询的场景进行参数调优 每个配置运行100次,记录≥90%准确率的成功率 """ import subprocess import json from pathlib import Path # 正确答案 CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28] def calculate_accuracy(result, correct): """计算准确率:匹配元素数量 / 总元素数量""" matches = len(set(result) & set(correct)) return matches / len(correct) * 100 def run_hnsw_test(max_connections, max_level, level_prob, ef_upper, ef_bottom): """运行HNSW测试并返回结果""" cmd = [ "cargo", "run", "--bin", "plain", "--", "--max-connections", str(max_connections), "--max-level", str(max_level), "--level-prob", str(level_prob), "--ef-upper", str(ef_upper), "--ef-bottom", str(ef_bottom), "--predictions", "./test_output.jsonl", ] try: result = subprocess.run(cmd, capture_output=True, text=True, timeout=30) if result.returncode != 0: return None with open("./test_output.jsonl", "r") as f: output = json.load(f) return output["answer"] except Exception as e: return None def test_config_stability(config, runs=100): """测试配置的稳定性:运行100次,统计≥90%准确率的次数""" success_count = 0 accuracies = [] print(f" 测试中... ", end="", flush=True) for i in range(runs): if i % 20 == 0 and i > 0: print(f"{i}/100 ", end="", flush=True) result = run_hnsw_test( config["max_connections"], config["max_level"], config["level_prob"], config["ef_upper"], config["ef_bottom"], ) if result is not None: accuracy = calculate_accuracy(result, CORRECT_ANSWER) accuracies.append(accuracy) if accuracy >= 90: success_count += 1 success_rate = success_count / len(accuracies) * 100 if accuracies else 0 avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0 print(f"完成!") return success_rate, avg_accuracy, len(accuracies) def main(): """主测试函数 - 测试最优参数组合""" print("🚀 HNSW最优参数组合测试") print(f"🎯 正确答案: {CORRECT_ANSWER}") print("📊 测试最优参数组合运行100次,记录≥90%准确率的成功率") print() # 最优参数组合 optimal_config = { "max_connections": 8, "max_level": 3, "level_prob": 0.6, "ef_upper": 1, "ef_bottom": 10, } print("=" * 80) print("🏆 测试最优参数组合") print("=" * 80) print( f"参数配置: M={optimal_config['max_connections']}, L={optimal_config['max_level']}, " + f"P={optimal_config['level_prob']}, ef_upper={optimal_config['ef_upper']}, " + f"ef_bottom={optimal_config['ef_bottom']}" ) print() print("最优组合 ", end=" ") success_rate, avg_acc, total_runs = test_config_stability(optimal_config) status = "✅ PASS" if success_rate >= 50 else "❌ FAIL" print( f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}" ) print() print("=" * 80) print("📈 测试结果分析") print("=" * 80) if success_rate >= 50: print("✅ 最优参数组合测试通过!") print(f" - 成功率: {success_rate:.1f}% (≥50%)") print(f" - 平均准确率: {avg_acc:.1f}%") else: print("❌ 最优参数组合未达到50%成功率") print(f" - 成功率: {success_rate:.1f}% (<50%)") print(f" - 平均准确率: {avg_acc:.1f}%") return # 以下是控制变量测试代码 - 已注释掉 """ # 控制变量测试 - 每次只改变一个参数 test_results = [] print("=" * 80) print("1测试 max_connections (M) 参数影响") print("=" * 80) for m in [4, 8, 12, 16, 20]: config = base_config.copy() config["max_connections"] = m config_name = f"M={m}" print(f"{config_name:<15}", end=" ") success_rate, avg_acc, total_runs = test_config_stability(config) status = "✅ PASS" if success_rate >= 50 else "❌ FAIL" print( f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}" ) test_results.append( { "param": "max_connections", "value": m, "config": config, "success_rate": success_rate, "avg_accuracy": avg_acc, "pass": success_rate >= 50, } ) print() print("=" * 80) print("2测试 max_level (L) 参数影响") print("=" * 80) for l in [3, 4, 5, 6, 7]: config = base_config.copy() config["max_level"] = l config_name = f"L={l}" print(f"{config_name:<15}", end=" ") success_rate, avg_acc, total_runs = test_config_stability(config) status = "✅ PASS" if success_rate >= 50 else "❌ FAIL" print( f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}" ) test_results.append( { "param": "max_level", "value": l, "config": config, "success_rate": success_rate, "avg_accuracy": avg_acc, "pass": success_rate >= 50, } ) print() print("=" * 80) print("3测试 level_prob (P) 参数影响") print("=" * 80) for p in [0.2, 0.3, 0.4, 0.5, 0.6]: config = base_config.copy() config["level_prob"] = p config_name = f"P={p}" print(f"{config_name:<15}", end=" ") success_rate, avg_acc, total_runs = test_config_stability(config) status = "✅ PASS" if success_rate >= 50 else "❌ FAIL" print( f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}" ) test_results.append( { "param": "level_prob", "value": p, "config": config, "success_rate": success_rate, "avg_accuracy": avg_acc, "pass": success_rate >= 50, } ) print() print("=" * 80) print("4️⃣ 测试 ef_bottom 参数影响") print("=" * 80) for ef in [10, 16, 25, 40, 60]: config = base_config.copy() config["ef_bottom"] = ef config_name = f"ef_b={ef}" print(f"{config_name:<15}", end=" ") success_rate, avg_acc, total_runs = test_config_stability(config) status = "✅ PASS" if success_rate >= 50 else "❌ FAIL" print( f"成功率: {success_rate:5.1f}% ({total_runs}次) 平均准确率: {avg_acc:5.1f}% {status}" ) test_results.append( { "param": "ef_bottom", "value": ef, "config": config, "success_rate": success_rate, "avg_accuracy": avg_acc, "pass": success_rate >= 50, } ) # 总结报告 print() print("=" * 80) print("📈 测试总结") print("=" * 80) passed_configs = [r for r in test_results if r["pass"]] print(f"总测试配置数: {len(test_results)}") print( f"成功率≥50%的配置: {len(passed_configs)} ({len(passed_configs)/len(test_results)*100:.1f}%)" ) print() if passed_configs: print("🏆 成功率≥50%的参数配置:") for result in passed_configs: print( f" {result['param']}={result['value']}: 成功率 {result['success_rate']:.1f}%, 平均准确率 {result['avg_accuracy']:.1f}%" ) # 找出每个参数的最佳值 print() print("🎯 各参数最佳值推荐:") for param_name in ["max_connections", "max_level", "level_prob", "ef_bottom"]: param_results = [r for r in test_results if r["param"] == param_name] best_result = max(param_results, key=lambda x: x["success_rate"]) print( f" {param_name}: {best_result['value']} (成功率: {best_result['success_rate']:.1f}%)" ) else: print("❌ 没有配置的成功率达到50%") print("📊 按成功率排序的前5个配置:") top_configs = sorted( test_results, key=lambda x: x["success_rate"], reverse=True )[:5] for i, result in enumerate(top_configs, 1): print( f" {i}. {result['param']}={result['value']}: 成功率 {result['success_rate']:.1f}%" ) # 清理临时文件 Path("./test_output.jsonl").unlink(missing_ok=True) """ if __name__ == "__main__": main()