hfe_knn/test_fhe_hnsw.py

339 lines
11 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
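FHE HNSW parallel test script - multi-process version.
Tests the stability and accuracy of the ciphertext-domain HNSW implementation.
Runs the tests in several processes in parallel to make full use of the server's CPU resources.
Runs 100 times and records the rate of runs reaching >=90% accuracy and the distribution of result lengths.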
"""
import subprocess
import json
import time
import os
import multiprocessing as mp
from pathlib import Path
from collections import defaultdict
# Reference answer that each run's "answer" list is scored against.
# NOTE(review): presumably the ground-truth top-10 neighbour ids for the fixed
# query used by ./enc — confirm against the dataset.
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
def calculate_accuracy(result, correct):
    """Return the percentage of reference items that also appear in ``result``.

    The comparison is order-insensitive and ignores duplicates in ``result``:
    the number of distinct values shared by both inputs is divided by
    ``len(correct)`` and scaled to the range 0-100.

    Args:
        result: Iterable of hashable items produced by a run.
        correct: Sized iterable of hashable reference items.

    Returns:
        The accuracy as a float percentage. ``0.0`` is returned when
        ``correct`` is empty, since nothing can be matched and dividing by
        its length would raise ``ZeroDivisionError``.
    """
    expected_count = len(correct)
    if expected_count == 0:
        return 0.0
    matches = len(set(result) & set(correct))
    return matches / expected_count * 100
def run_single_test(test_id, data_bit_width="12", timeout_minutes=30):
    """Run ``./enc`` once in HNSW mode and score the answer it writes.

    Meant to be called from a worker process. Ordinary exceptions are caught
    and reported through the ``error`` field of the returned dict instead of
    propagating. The temporary predictions file is removed on every exit
    path, including a non-zero exit status of the subprocess.

    Args:
        test_id: Identifier echoed back in the result and embedded in the
            temporary predictions file name.
        data_bit_width: Value passed to ``--data-bit-width``; a string because
            it is placed directly on the command line.
        timeout_minutes: Wall-clock limit for the subprocess, in minutes.

    Returns:
        A dict that always contains ``test_id``, ``result``, ``error`` and
        ``time_minutes``. On success ``result`` is the answer read from the
        predictions file, ``error`` is ``None`` and the keys ``length``,
        ``accuracy`` and ``success`` (accuracy >= 90) are present as well.
        On failure ``result`` is ``None`` and ``error`` describes the problem.
    """

    def failure(message, minutes):
        # Single place that defines the shape of a failed run's record.
        return {
            "test_id": test_id,
            "result": None,
            "error": message,
            "time_minutes": minutes,
        }

    # The pid keeps file names distinct across concurrently running workers.
    output_file = f"./test_fhe_output_{test_id}_{os.getpid()}.jsonl"
    cmd = [
        "./enc",
        "--algorithm",
        "hnsw",
        "--data-bit-width",
        data_bit_width,
        "--predictions",
        output_file,
    ]
    start_time = time.time()
    try:
        completed = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout_minutes * 60
        )
        test_time = time.time() - start_time
        if completed.returncode != 0:
            # Only the first 200 characters of stderr are kept in the report.
            return failure(
                f"Command failed: {completed.stderr[:200]}", test_time / 60
            )
        if not Path(output_file).exists():
            return failure("Output file not found", test_time / 60)
        with open(output_file, "r", encoding="utf-8") as f:
            output = json.load(f)
        answer = output["answer"]
        accuracy = calculate_accuracy(answer, CORRECT_ANSWER)
        return {
            "test_id": test_id,
            "result": answer,
            "length": len(answer),
            "accuracy": accuracy,
            "success": accuracy >= 90,
            "time_minutes": test_time / 60,
            "error": None,
        }
    except subprocess.TimeoutExpired:
        # The full timeout is reported as the duration of a timed-out run.
        return failure("timeout", timeout_minutes)
    except Exception as e:
        # Worker boundary: report the problem instead of raising in the pool.
        return failure(str(e), (time.time() - start_time) / 60)
    finally:
        # Best-effort cleanup on every path; a failure to delete the file
        # must not replace the result that is being returned.
        try:
            Path(output_file).unlink(missing_ok=True)
        except OSError:
            pass
def print_progress(completed, total, start_time):
    """Rewrite the current console line with progress and a time estimate.

    Nothing is printed while ``completed`` is zero, because the per-test
    average that the estimate is based on cannot be computed yet.

    Args:
        completed: Number of tests finished so far.
        total: Number of tests scheduled.
        start_time: ``time.time()`` value taken when the batch started.
    """
    if completed != 0:
        elapsed = time.time() - start_time
        per_test = elapsed / completed
        remaining = per_test * (total - completed)
        line = (
            f"\r🔄 进度: {completed}/{total} ({completed/total*100:.1f}%) "
            f"| 已用时: {elapsed/3600:.1f}h | 预计剩余: {remaining/3600:.1f}h"
        )
        # No trailing newline: the leading "\r" lets the next call overwrite
        # this line in place.
        print(line, end="", flush=True)
def main():
    """Run the batch of FHE HNSW tests in a process pool and report on them.

    The function:
      1. checks that the ``./enc`` binary exists and returns early if not;
      2. schedules 100 calls of ``run_single_test`` on a pool of 8 processes;
      3. collects the results in submission order while printing progress;
      4. prints summary statistics, the error and result-length distributions
         and a pass/fail verdict (pass means a success rate of at least 50%);
      5. writes everything to ``fhe_hnsw_parallel_test_report.json``;
      6. prints the five best and five worst runs by accuracy.

    If no run produced a result, the error distribution is printed and the
    function returns without writing a report.
    """

    def print_error_distribution(distribution):
        # Shared by the normal summary and the "no valid results" early exit.
        if distribution:
            print("\n❌ 错误分布:")
            for error_type, count in distribution.items():
                print(f" {error_type}: {count}")

    def print_ranked(rows):
        # One line per run, numbered from 1.
        for i, r in enumerate(rows, 1):
            print(
                f" {i}. 测试#{r['test_id']:3d}: 准确率{r['accuracy']:5.1f}% 长度{r['length']:2d} ({r['time_minutes']:4.1f}分钟)"
            )

    print("🚀 FHE HNSW 并行批量测试脚本")
    print(f"🎯 正确答案: {CORRECT_ANSWER}")
    print("📊 运行100次测试记录准确率和结果长度分布")
    print("🔧 支持多进程并行执行")
    # Without the binary there is nothing to run.
    if not Path("./enc").exists():
        print("\n❌ 找不到 ./enc 二进制文件,请先编译项目:")
        print(" cargo build --release --bin enc")
        return

    # Number of runs; the banner text above mentions the same figure literally.
    num_tests = 100
    cpu_cores = mp.cpu_count()
    # NOTE(review): the pool size is fixed at 8 regardless of the core count;
    # cpu_cores is only recorded in the report. An earlier comment described
    # using half of the cores to limit memory pressure — confirm the intent.
    num_processes = 8
    timeout_minutes = 45  # per-run limit handed to run_single_test
    print(f"💻 使用 {num_processes} 个并行进程")
    print(f"⏰ 单次测试超时时间: {timeout_minutes} 分钟")
    # Rough estimate assuming 15 to 25 minutes per run.
    print(
        f"⏱️ 预计总时间: {num_tests * 15 / num_processes / 60:.1f}-{num_tests * 25 / num_processes / 60:.1f} 小时"
    )
    print()
    print("=" * 80)
    print("🔬 开始FHE HNSW并行测试 (data-bit-width=12)")
    print("=" * 80)
    start_time = time.time()

    test_ids = list(range(1, num_tests + 1))
    with mp.Pool(processes=num_processes) as pool:
        # Keep each handle next to its test id so that a timeout record can
        # name the run it belongs to.
        pending = [
            (
                test_id,
                pool.apply_async(run_single_test, (test_id, "12", timeout_minutes)),
            )
            for test_id in test_ids
        ]
        results = []
        completed = 0
        print_progress(completed, len(test_ids), start_time)
        for test_id, async_result in pending:
            try:
                # One extra minute on top of the subprocess timeout.
                result = async_result.get(timeout=timeout_minutes * 60 + 60)
            except mp.TimeoutError:
                # The worker did not answer in time; record a placeholder.
                result = {
                    "test_id": test_id,
                    "result": None,
                    "error": "process_timeout",
                    "time_minutes": timeout_minutes,
                }
            results.append(result)
            completed += 1
            print_progress(completed, len(test_ids), start_time)
    print()  # finish the progress line
    total_elapsed = time.time() - start_time

    print("\n" + "=" * 80)
    print("📈 测试结果分析")
    print("=" * 80)
    valid_results = [r for r in results if r["result"] is not None]
    success_count = sum(1 for r in valid_results if r["success"])
    length_distribution = defaultdict(int)
    error_distribution = defaultdict(int)
    for r in results:
        if r["error"]:
            # Coarse buckets by substring; "timeout" is checked first, so a
            # failed command whose stderr mentions a timeout counts as one.
            error_type = r["error"].lower()
            if "timeout" in error_type:
                error_distribution["timeout"] += 1
            elif "failed" in error_type:
                error_distribution["failed"] += 1
            else:
                error_distribution["other"] += 1
        else:
            length_distribution[r["length"]] += 1

    total_tests = len(results)
    valid_tests = len(valid_results)
    if valid_tests == 0:
        print("❌ 没有有效的测试结果")
        # Show why every run failed before giving up.
        print_error_distribution(error_distribution)
        return

    success_rate = success_count / valid_tests * 100
    avg_accuracy = sum(r["accuracy"] for r in valid_results) / valid_tests
    avg_length = sum(r["length"] for r in valid_results) / valid_tests
    avg_time_per_test = sum(r["time_minutes"] for r in results) / total_tests
    # Summed per-run time (hours) divided by wall-clock time (hours).
    speedup = (
        total_tests * avg_time_per_test / 60 / (total_elapsed / 3600)
        if total_elapsed > 0
        else 0.0
    )
    print(f"总测试次数: {total_tests}")
    print(f"有效测试次数: {valid_tests}")
    print(f"成功次数 (≥90%准确率): {success_count}")
    print(f"成功率: {success_rate:.1f}%")
    print(f"平均准确率: {avg_accuracy:.1f}%")
    print(f"平均结果长度: {avg_length:.1f}")
    print(f"平均每次测试时间: {avg_time_per_test:.1f}分钟")
    print(f"总测试时间: {total_elapsed/3600:.1f}小时")
    print(f"并行加速比: {speedup:.1f}x")

    print_error_distribution(error_distribution)

    if length_distribution:
        print("\n📊 结果长度分布:")
        for length in sorted(length_distribution):
            count = length_distribution[length]
            percentage = count / valid_tests * 100
            # One block per two runs; the bar glyph used to be an empty
            # string, which made the histogram invisible.
            bar = "█" * (count // 2)
            print(f" 长度 {length:2d}: {count:3d}次 ({percentage:5.1f}%) {bar}")

    # Verdict: at least half of the valid runs must reach >= 90% accuracy.
    print()
    if success_rate >= 50:
        print("✅ 测试通过! FHE HNSW实现稳定性良好")
        print(f" - 成功率: {success_rate:.1f}% (≥50%)")
        print(f" - 平均准确率: {avg_accuracy:.1f}%")
        if avg_length >= 9.5:
            print(f" - 平均结果长度: {avg_length:.1f} (接近10)")
        else:
            print(f" - 平均结果长度: {avg_length:.1f} (需要改进)")
    else:
        print("❌ 测试未通过,需要进一步优化")
        print(f" - 成功率: {success_rate:.1f}% (<50%)")
        print(f" - 平均准确率: {avg_accuracy:.1f}%")

    # Persist the summary together with every per-run record.
    report_file = "fhe_hnsw_parallel_test_report.json"
    with open(report_file, "w", encoding="utf-8") as f:
        json.dump(
            {
                "summary": {
                    "total_tests": total_tests,
                    "valid_tests": valid_tests,
                    "success_count": success_count,
                    "success_rate": success_rate,
                    "avg_accuracy": avg_accuracy,
                    "avg_length": avg_length,
                    "avg_time_minutes": avg_time_per_test,
                    "total_time_hours": total_elapsed / 3600,
                    "num_processes": num_processes,
                    "cpu_cores": cpu_cores,
                    "speedup": speedup,
                },
                "length_distribution": dict(length_distribution),
                "error_distribution": dict(error_distribution),
                "detailed_results": results,
                "correct_answer": CORRECT_ANSWER,
            },
            f,
            indent=2,
        )
    print(f"\n📁 详细报告已保存到: {report_file}")

    # valid_results is non-empty here (the empty case returned above).
    print("\n🏆 最高准确率的5个结果:")
    print_ranked(
        sorted(valid_results, key=lambda x: x["accuracy"], reverse=True)[:5]
    )
    print("\n⚠️ 最低准确率的5个结果:")
    print_ranked(sorted(valid_results, key=lambda x: x["accuracy"])[:5])
if __name__ == "__main__":
    try:
        # Force the "spawn" start method; per the original author's note,
        # Linux defaults to fork but spawn is considered safer here.
        mp.set_start_method("spawn", force=True)
        main()
    except (KeyboardInterrupt, Exception) as exc:
        if isinstance(exc, KeyboardInterrupt):
            print("\n\n⏹️ 测试被用户中断")
        else:
            print(f"\n💥 程序异常: {exc}")
        # Remove predictions files that workers may have left behind.
        for leftover in Path(".").glob("test_fhe_output_*.jsonl"):
            leftover.unlink(missing_ok=True)