339 lines
11 KiB
Python
Executable File
339 lines
11 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""
|
||
FHE HNSW并行测试脚本 - 多进程版本
|
||
测试密文态HNSW实现的稳定性和准确率
|
||
支持多进程并行执行,充分利用服务器CPU资源
|
||
运行100次,记录≥90%准确率的成功率和结果长度分布
|
||
"""
|
||
|
||
import subprocess
|
||
import json
|
||
import time
|
||
import os
|
||
import multiprocessing as mp
|
||
from pathlib import Path
|
||
from collections import defaultdict
|
||
|
||
# Expected (correct) answer list; each run's "answer" is scored against it
# with calculate_accuracy().
CORRECT_ANSWER = [
    93, 94, 90, 27, 87,
    50, 47, 40, 78, 28,
]
|
||
|
||
|
||
def calculate_accuracy(result, correct):
    """Return how much of *correct* appears in *result*, as a percentage.

    Membership is set-based, so order and duplicates are ignored. The
    denominator is ``len(correct)``, which means an empty *correct* raises
    ZeroDivisionError.
    """
    overlap = set(result).intersection(correct)
    return len(overlap) / len(correct) * 100
|
||
|
||
|
||
def _failed_result(test_id, error, time_minutes):
    """Build the record returned for a run that produced no answer."""
    return {
        "test_id": test_id,
        "result": None,
        "error": error,
        "time_minutes": time_minutes,
    }


def run_single_test(test_id, data_bit_width="12", timeout_minutes=30):
    """Run ``./enc`` once in HNSW mode and score its answer (process-pool worker).

    Each call uses its own predictions file, named from ``test_id`` and this
    process's PID, so concurrent workers do not share an output path. The
    file is deleted on every exit path, including when ``./enc`` fails.

    Args:
        test_id: Identifier copied into the returned record.
        data_bit_width: String passed to ``--data-bit-width``.
        timeout_minutes: Wall-clock limit for the subprocess, in minutes.

    Returns:
        A dict with ``test_id``, ``result``, ``error`` and ``time_minutes``.
        On success ``result`` is the parsed ``answer``, ``error`` is None and
        ``length``, ``accuracy`` and ``success`` (accuracy >= 90) are added.
        On failure ``result`` is None and ``error`` is
        ``"Command failed: <first 200 chars of stderr>"``,
        ``"Output file not found"``, ``"timeout"``, or the text of any other
        exception raised while running or parsing.
    """
    output_file = f"./test_fhe_output_{test_id}_{os.getpid()}.jsonl"
    cmd = [
        "./enc",
        "--algorithm",
        "hnsw",
        "--data-bit-width",
        data_bit_width,
        "--predictions",
        output_file,
    ]

    start_time = time.time()
    try:
        proc = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout_minutes * 60
        )
        elapsed_minutes = (time.time() - start_time) / 60

        if proc.returncode != 0:
            return _failed_result(
                test_id, f"Command failed: {proc.stderr[:200]}", elapsed_minutes
            )

        if not Path(output_file).exists():
            return _failed_result(test_id, "Output file not found", elapsed_minutes)

        # NOTE(review): the file carries a .jsonl suffix but is parsed as ONE
        # JSON document; if ./enc ever writes several lines, json.load raises
        # and the run is reported through the generic exception branch below.
        with open(output_file, "r", encoding="utf-8") as f:
            output = json.load(f)
        answer = output["answer"]

        accuracy = calculate_accuracy(answer, CORRECT_ANSWER)
        return {
            "test_id": test_id,
            "result": answer,
            "length": len(answer),
            "accuracy": accuracy,
            "success": accuracy >= 90,
            "time_minutes": elapsed_minutes,
            "error": None,
        }

    except subprocess.TimeoutExpired:
        # subprocess.run has already killed and reaped the child at this point.
        return _failed_result(test_id, "timeout", timeout_minutes)
    except Exception as e:
        # Launch failures (e.g. missing ./enc), bad JSON, a missing "answer"
        # key, ... are reported as a failed run instead of killing the worker.
        return _failed_result(test_id, str(e), (time.time() - start_time) / 60)
    finally:
        # Single cleanup point: also covers the non-zero-exit path, which
        # previously left a partially written predictions file behind.
        Path(output_file).unlink(missing_ok=True)
|
||
|
||
|
||
def print_progress(completed, total, start_time):
    """Redraw a one-line progress indicator in place (leading ``\\r``, no newline).

    Shows completed/total with a percentage, hours elapsed since *start_time*
    (a ``time.time()`` timestamp), and a linear ETA derived from the mean time
    per completed item. Nothing is printed while *completed* is zero.
    """
    if not completed:
        return

    spent = time.time() - start_time
    eta = spent / completed * (total - completed)
    percent = completed / total * 100
    line = (
        f"\r🔄 进度: {completed}/{total} ({percent:.1f}%) "
        f"| 已用时: {spent / 3600:.1f}h | 预计剩余: {eta / 3600:.1f}h"
    )
    print(line, end="", flush=True)
|
||
|
||
|
||
def _collect_parallel_results(
    test_ids, num_processes, data_bit_width, timeout_minutes, start_time
):
    """Run run_single_test for every id on a process pool; return records in id order."""
    results = []
    with mp.Pool(processes=num_processes) as pool:
        # Submit everything up front; the pool queues tasks beyond its size.
        pending = [
            (
                test_id,
                pool.apply_async(
                    run_single_test, (test_id, data_bit_width, timeout_minutes)
                ),
            )
            for test_id in test_ids
        ]
        for test_id, async_result in pending:
            try:
                # One extra minute of slack on top of the subprocess timeout
                # that run_single_test enforces inside the worker.
                record = async_result.get(timeout=timeout_minutes * 60 + 60)
            except mp.TimeoutError:
                # The worker never reported back; it is terminated when the
                # with-block exits (Pool.__exit__ calls terminate()).
                record = {
                    "test_id": test_id,
                    "result": None,
                    "error": "process_timeout",
                    "time_minutes": timeout_minutes,
                }
            except Exception as exc:
                # run_single_test catches its own errors, so this is a
                # worker-level failure. Record it instead of aborting the
                # whole batch and losing every result collected so far.
                record = {
                    "test_id": test_id,
                    "result": None,
                    "error": f"worker error: {exc}",
                    "time_minutes": 0.0,  # duration unknown; counts as 0 in averages
                }
            results.append(record)
            print_progress(len(results), len(test_ids), start_time)
    return results


def _classify_error(message):
    """Bucket an error message into "timeout", "failed" or "other"."""
    lowered = message.lower()
    if "timeout" in lowered:
        return "timeout"
    if "failed" in lowered:
        return "failed"
    return "other"


def _print_error_distribution(error_distribution):
    """Print failure counts per bucket; prints nothing when there were no failures."""
    if error_distribution:
        print("\n❌ 错误分布:")
        for error_type, count in error_distribution.items():
            print(f" {error_type}: {count}次")


def _print_extremes(valid_results):
    """Print the five highest- and five lowest-accuracy runs (ties keep run order)."""

    def row(rank, r):
        return (
            f" {rank}. 测试#{r['test_id']:3d}: 准确率{r['accuracy']:5.1f}% "
            f"长度{r['length']:2d} ({r['time_minutes']:4.1f}分钟)"
        )

    print("\n🏆 最高准确率的5个结果:")
    best_results = sorted(valid_results, key=lambda x: x["accuracy"], reverse=True)[:5]
    for rank, r in enumerate(best_results, 1):
        print(row(rank, r))

    print("\n⚠️ 最低准确率的5个结果:")
    worst_results = sorted(valid_results, key=lambda x: x["accuracy"])[:5]
    for rank, r in enumerate(worst_results, 1):
        print(row(rank, r))


def main(*, num_tests=100, num_processes=8, timeout_minutes=45, data_bit_width="12"):
    """Run the FHE HNSW test ``num_tests`` times in parallel and report the results.

    Each run is executed and scored by ``run_single_test``. A run counts as a
    success when its accuracy is >= 90%. The function prints a summary, the
    error and result-length distributions, and the five best and worst runs,
    and writes everything to ``fhe_hnsw_parallel_test_report.json``.
    It returns early, without a report, if ``./enc`` is missing or no run
    produced an answer.

    Args:
        num_tests: Number of runs (ids 1..num_tests).
        num_processes: Size of the worker pool.
        timeout_minutes: Per-run subprocess timeout, in minutes.
        data_bit_width: Value forwarded to ``./enc --data-bit-width``.
    """
    print("🚀 FHE HNSW 并行批量测试脚本")
    print(f"🎯 正确答案: {CORRECT_ANSWER}")
    print(f"📊 运行{num_tests}次测试,记录准确率和结果长度分布")
    print("🔧 支持多进程并行执行")

    # The script drives a prebuilt binary in the current directory.
    if not Path("./enc").exists():
        print("\n❌ 找不到 ./enc 二进制文件,请先编译项目:")
        print(" cargo build --release --bin enc")
        return

    cpu_cores = mp.cpu_count()  # recorded in the report only
    # FHE evaluation is memory-intensive, so the pool size is an explicit cap
    # rather than cpu_count(). NOTE(review): the default of 8 is not derived
    # from cpu_cores; confirm it fits the target machine's memory.

    print(f"💻 使用 {num_processes} 个并行进程")
    print(f"⏰ 单次测试超时时间: {timeout_minutes} 分钟")
    # Rough estimate assuming each run takes 15-25 minutes.
    print(
        f"⏱️ 预计总时间: {num_tests * 15 / num_processes / 60:.1f}-{num_tests * 25 / num_processes / 60:.1f} 小时"
    )
    print()

    print("=" * 80)
    print(f"🔬 开始FHE HNSW并行测试 (data-bit-width={data_bit_width})")
    print("=" * 80)

    start_time = time.time()
    test_ids = list(range(1, num_tests + 1))
    results = _collect_parallel_results(
        test_ids, num_processes, data_bit_width, timeout_minutes, start_time
    )
    print()  # terminate the \r progress line
    total_elapsed = time.time() - start_time

    print("\n" + "=" * 80)
    print("📈 测试结果分析")
    print("=" * 80)

    valid_results = [r for r in results if r["result"] is not None]
    success_count = sum(1 for r in valid_results if r["success"])
    length_distribution = defaultdict(int)
    error_distribution = defaultdict(int)
    for r in results:
        # Split on the same criterion as valid_results; keying on error
        # truthiness crashed with KeyError('length') for an empty message.
        if r["result"] is None:
            error_distribution[_classify_error(r["error"] or "")] += 1
        else:
            length_distribution[r["length"]] += 1

    total_tests = len(results)
    valid_tests = len(valid_results)

    if valid_tests == 0:
        print("❌ 没有有效的测试结果")
        # Still show why every run failed; otherwise there is no diagnostic output.
        _print_error_distribution(error_distribution)
        return

    success_rate = success_count / valid_tests * 100
    avg_accuracy = sum(r["accuracy"] for r in valid_results) / valid_tests
    avg_length = sum(r["length"] for r in valid_results) / valid_tests
    avg_time_per_test = sum(r["time_minutes"] for r in results) / total_tests
    # Serial time (runs x mean duration) divided by wall-clock time.
    speedup = total_tests * avg_time_per_test / 60 / (total_elapsed / 3600)

    print(f"总测试次数: {total_tests}")
    print(f"有效测试次数: {valid_tests}")
    print(f"成功次数 (≥90%准确率): {success_count}")
    print(f"成功率: {success_rate:.1f}%")
    print(f"平均准确率: {avg_accuracy:.1f}%")
    print(f"平均结果长度: {avg_length:.1f}")
    print(f"平均每次测试时间: {avg_time_per_test:.1f}分钟")
    print(f"总测试时间: {total_elapsed/3600:.1f}小时")
    print(f"并行加速比: {speedup:.1f}x")

    _print_error_distribution(error_distribution)

    if length_distribution:
        print("\n📊 结果长度分布:")
        for length in sorted(length_distribution):
            count = length_distribution[length]
            percentage = count / valid_tests * 100
            bar = "█" * (count // 2)  # one block per two runs
            print(f" 长度 {length:2d}: {count:3d}次 ({percentage:5.1f}%) {bar}")

    # Verdict: the batch passes when at least half of the valid runs succeed.
    print()
    if success_rate >= 50:
        print("✅ 测试通过! FHE HNSW实现稳定性良好")
        print(f" - 成功率: {success_rate:.1f}% (≥50%)")
        print(f" - 平均准确率: {avg_accuracy:.1f}%")
        if avg_length >= 9.5:
            print(f" - 平均结果长度: {avg_length:.1f} (接近10)")
        else:
            print(f" - 平均结果长度: {avg_length:.1f} (需要改进)")
    else:
        print("❌ 测试未通过,需要进一步优化")
        print(f" - 成功率: {success_rate:.1f}% (<50%)")
        print(f" - 平均准确率: {avg_accuracy:.1f}%")

    report_file = "fhe_hnsw_parallel_test_report.json"
    with open(report_file, "w", encoding="utf-8") as f:
        json.dump(
            {
                "summary": {
                    "total_tests": total_tests,
                    "valid_tests": valid_tests,
                    "success_count": success_count,
                    "success_rate": success_rate,
                    "avg_accuracy": avg_accuracy,
                    "avg_length": avg_length,
                    "avg_time_minutes": avg_time_per_test,
                    "total_time_hours": total_elapsed / 3600,
                    "num_processes": num_processes,
                    "cpu_cores": cpu_cores,
                    "speedup": speedup,
                },
                "length_distribution": dict(length_distribution),
                "error_distribution": dict(error_distribution),
                "detailed_results": results,
                "correct_answer": CORRECT_ANSWER,
            },
            f,
            indent=2,
        )

    print(f"\n📁 详细报告已保存到: {report_file}")

    _print_extremes(valid_results)
|
||
|
||
|
||
def cleanup_temp_outputs():
    """Delete per-run prediction files (test_fhe_output_*.jsonl) left in the cwd."""
    for leftover in Path(".").glob("test_fhe_output_*.jsonl"):
        leftover.unlink(missing_ok=True)


if __name__ == "__main__":
    import sys
    import traceback

    try:
        # Use the spawn start method (fork is the Linux default): workers start
        # from a fresh interpreter rather than a copy of this process.
        mp.set_start_method("spawn", force=True)
        main()
    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
        cleanup_temp_outputs()
        sys.exit(130)  # conventional status for termination by Ctrl-C
    except Exception as e:
        print(f"\n💥 程序异常: {e}")
        # Keep the traceback and exit non-zero so callers/CI see the failure
        # instead of a silent status 0.
        traceback.print_exc()
        cleanup_temp_outputs()
        sys.exit(1)
|