#!/usr/bin/env python3
"""
FHE HNSW parallel test script (multiprocess version).

Repeatedly runs the ciphertext-domain (FHE) HNSW implementation through the
``./enc`` binary to measure its stability and accuracy. Runs are executed in a
process pool so several CPU cores are used at once. By default 100 runs are
performed; the script reports the fraction of runs reaching >=90% accuracy,
the distribution of result lengths, and writes a JSON report.
"""

import subprocess
import json
import time
import os
import multiprocessing as mp
from pathlib import Path
from collections import defaultdict

# Expected (ground-truth) result of the HNSW query.
CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]

# Glob pattern matching the per-run temporary prediction files.
_TEMP_OUTPUT_GLOB = "test_fhe_output_*.jsonl"


def calculate_accuracy(result, correct):
    """Return the accuracy of *result* against *correct*, as a percentage.

    Accuracy = (number of distinct elements present in both sequences)
    / len(correct) * 100. Order and duplicates in *result* are ignored.
    An empty *correct* yields 0.0 instead of raising ZeroDivisionError.
    """
    if not correct:
        return 0.0
    matches = len(set(result) & set(correct))
    return matches / len(correct) * 100


def run_single_test(test_id, data_bit_width="12", timeout_minutes=30):
    """Run one FHE HNSW test via ``./enc`` (meant to execute in a pool worker).

    Args:
        test_id: Identifier of this run; also embedded in the temp file name.
        data_bit_width: Value passed to ``--data-bit-width``.
        timeout_minutes: Subprocess timeout, in minutes.

    Returns:
        A dict that always has ``test_id``, ``result``, ``error`` and
        ``time_minutes``. On success ``error`` is None and ``result``,
        ``length``, ``accuracy`` and ``success`` (accuracy >= 90) are filled;
        on failure ``result`` is None and ``error`` describes the problem.
        Exceptions are never propagated, so one bad run cannot kill the pool.
    """
    # One output file per run and per worker process so parallel runs never collide.
    output_file = f"./test_fhe_output_{test_id}_{os.getpid()}.jsonl"
    output_path = Path(output_file)
    cmd = [
        "./enc",
        "--algorithm",
        "hnsw",
        "--data-bit-width",
        data_bit_width,
        "--predictions",
        output_file,
    ]

    start_time = time.time()

    def failure(error, minutes):
        return {
            "test_id": test_id,
            "result": None,
            "error": error,
            "time_minutes": minutes,
        }

    try:
        proc = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout_minutes * 60
        )
        test_minutes = (time.time() - start_time) / 60

        if proc.returncode != 0:
            return failure(f"Command failed: {proc.stderr[:200]}", test_minutes)

        if not output_path.exists():
            return failure("Output file not found", test_minutes)

        # NOTE(review): the file has a .jsonl suffix but is parsed as a single
        # JSON document containing an "answer" key — confirm against ./enc.
        with open(output_path, "r") as f:
            answer = json.load(f)["answer"]

        accuracy = calculate_accuracy(answer, CORRECT_ANSWER)
        return {
            "test_id": test_id,
            "result": answer,
            "length": len(answer),
            "accuracy": accuracy,
            "success": accuracy >= 90,
            "time_minutes": test_minutes,
            "error": None,
        }
    except subprocess.TimeoutExpired:
        return failure("timeout", timeout_minutes)
    except Exception as e:
        # Worker boundary: record the error instead of crashing the pool.
        return failure(str(e), (time.time() - start_time) / 60)
    finally:
        # Always remove the temp file, including after a non-zero exit code.
        output_path.unlink(missing_ok=True)


def print_progress(completed, total, start_time):
    """Print a single, carriage-return-refreshed progress line.

    Shows completed/total, elapsed hours and a remaining-time estimate derived
    from the mean time per completed test. Prints nothing while *completed*
    is 0 (no estimate is possible yet).
    """
    if completed == 0:
        return
    elapsed = time.time() - start_time
    avg_time = elapsed / completed
    remaining_time = avg_time * (total - completed)
    print(
        f"\r🔄 进度: {completed}/{total} ({completed/total*100:.1f}%) "
        f"| 已用时: {elapsed/3600:.1f}h | 预计剩余: {remaining_time/3600:.1f}h",
        end="",
        flush=True,
    )


def _cleanup_temp_files():
    """Delete leftover per-run prediction files from the working directory."""
    for leftover in Path(".").glob(_TEMP_OUTPUT_GLOB):
        leftover.unlink(missing_ok=True)


def _classify_error(error):
    """Bucket an error message into "timeout", "failed" or "other"."""
    lowered = error.lower()
    if "timeout" in lowered:
        return "timeout"
    if "failed" in lowered:
        return "failed"
    return "other"


def _collect_results(test_ids, num_processes, timeout_minutes, data_bit_width, start_time):
    """Run every test id in a process pool and return results in submission order."""
    results = []
    with mp.Pool(processes=num_processes) as pool:
        # Keep each test id next to its async handle so a timeout is attributed correctly.
        pending = [
            (
                test_id,
                pool.apply_async(
                    run_single_test, (test_id, data_bit_width, timeout_minutes)
                ),
            )
            for test_id in test_ids
        ]
        for test_id, async_result in pending:
            try:
                # One extra minute beyond the subprocess timeout as a buffer.
                results.append(async_result.get(timeout=timeout_minutes * 60 + 60))
            except mp.TimeoutError:
                # The worker did not report back in time (process-level timeout).
                results.append(
                    {
                        "test_id": test_id,
                        "result": None,
                        "error": "process_timeout",
                        "time_minutes": timeout_minutes,
                    }
                )
            print_progress(len(results), len(test_ids), start_time)
    return results


def _print_error_distribution(error_distribution):
    """Print the error-bucket counts, if any errors occurred."""
    if error_distribution:
        print("\n❌ 错误分布:")
        for error_type, count in error_distribution.items():
            print(f" {error_type}: {count}次")


def _print_top_results(title, rows):
    """Print a titled, numbered list of per-run results."""
    print(title)
    for i, r in enumerate(rows, 1):
        print(
            f" {i}. 测试#{r['test_id']:3d}: 准确率{r['accuracy']:5.1f}% 长度{r['length']:2d} ({r['time_minutes']:4.1f}分钟)"
        )


def _analyze_and_report(results, total_elapsed, num_processes, cpu_cores):
    """Print summary statistics for *results* and save the JSON report."""
    print("\n" + "=" * 80)
    print("📈 测试结果分析")
    print("=" * 80)

    valid_results = [r for r in results if r["result"] is not None]
    success_count = sum(1 for r in valid_results if r["success"])

    length_distribution = defaultdict(int)
    error_distribution = defaultdict(int)
    for r in results:
        if r["error"]:
            error_distribution[_classify_error(r["error"])] += 1
        else:
            length_distribution[r["length"]] += 1

    total_tests = len(results)
    valid_tests = len(valid_results)

    if valid_tests == 0:
        print("❌ 没有有效的测试结果")
        # Still show why every run failed; otherwise the batch gives no diagnostics.
        _print_error_distribution(error_distribution)
        return

    success_rate = success_count / valid_tests * 100
    avg_accuracy = sum(r["accuracy"] for r in valid_results) / valid_tests
    avg_length = sum(r["length"] for r in valid_results) / valid_tests
    avg_time_per_test = sum(r["time_minutes"] for r in results) / total_tests
    # Speedup = summed per-run time (hours) / wall-clock time (hours),
    # computed from the real number of runs instead of a hard-coded 100.
    sequential_hours = avg_time_per_test * total_tests / 60
    speedup = sequential_hours / (total_elapsed / 3600) if total_elapsed > 0 else 0.0

    print(f"总测试次数: {total_tests}")
    print(f"有效测试次数: {valid_tests}")
    print(f"成功次数 (≥90%准确率): {success_count}")
    print(f"成功率: {success_rate:.1f}%")
    print(f"平均准确率: {avg_accuracy:.1f}%")
    print(f"平均结果长度: {avg_length:.1f}")
    print(f"平均每次测试时间: {avg_time_per_test:.1f}分钟")
    print(f"总测试时间: {total_elapsed/3600:.1f}小时")
    print(f"并行加速比: {speedup:.1f}x")

    _print_error_distribution(error_distribution)

    if length_distribution:
        print("\n📊 结果长度分布:")
        for length in sorted(length_distribution):
            count = length_distribution[length]
            percentage = count / valid_tests * 100
            bar = "█" * (count // 2)
            print(f" 长度 {length:2d}: {count:3d}次 ({percentage:5.1f}%) {bar}")

    # Verdict: pass when at least half of the valid runs reached >=90% accuracy.
    print()
    if success_rate >= 50:
        print("✅ 测试通过! FHE HNSW实现稳定性良好")
        print(f" - 成功率: {success_rate:.1f}% (≥50%)")
        print(f" - 平均准确率: {avg_accuracy:.1f}%")
        if avg_length >= 9.5:
            print(f" - 平均结果长度: {avg_length:.1f} (接近10)")
        else:
            print(f" - 平均结果长度: {avg_length:.1f} (需要改进)")
    else:
        print("❌ 测试未通过,需要进一步优化")
        print(f" - 成功率: {success_rate:.1f}% (<50%)")
        print(f" - 平均准确率: {avg_accuracy:.1f}%")

    report_file = "fhe_hnsw_parallel_test_report.json"
    with open(report_file, "w", encoding="utf-8") as f:
        json.dump(
            {
                "summary": {
                    "total_tests": total_tests,
                    "valid_tests": valid_tests,
                    "success_count": success_count,
                    "success_rate": success_rate,
                    "avg_accuracy": avg_accuracy,
                    "avg_length": avg_length,
                    "avg_time_minutes": avg_time_per_test,
                    "total_time_hours": total_elapsed / 3600,
                    "num_processes": num_processes,
                    "cpu_cores": cpu_cores,
                    "speedup": speedup,
                },
                "length_distribution": dict(length_distribution),
                "error_distribution": dict(error_distribution),
                "detailed_results": results,
                "correct_answer": CORRECT_ANSWER,
            },
            f,
            indent=2,
        )
    print(f"\n📁 详细报告已保存到: {report_file}")

    if valid_results:
        _print_top_results(
            "\n🏆 最高准确率的5个结果:",
            sorted(valid_results, key=lambda x: x["accuracy"], reverse=True)[:5],
        )
        _print_top_results(
            "\n⚠️ 最低准确率的5个结果:",
            sorted(valid_results, key=lambda x: x["accuracy"])[:5],
        )


def main(num_tests=100, num_processes=8, timeout_minutes=45, data_bit_width="12"):
    """Run the FHE HNSW test batch in parallel and report the results.

    Args:
        num_tests: Number of runs (test ids 1..num_tests).
        num_processes: Worker-pool size. FHE runs are memory-intensive, so this
            is a fixed value rather than being derived from the core count.
        timeout_minutes: Per-run subprocess timeout, in minutes.
        data_bit_width: Value forwarded to ``./enc --data-bit-width``.
    """
    print("🚀 FHE HNSW 并行批量测试脚本")
    print(f"🎯 正确答案: {CORRECT_ANSWER}")
    print(f"📊 运行{num_tests}次测试,记录准确率和结果长度分布")
    print("🔧 支持多进程并行执行")

    if not Path("./enc").exists():
        print("\n❌ 找不到 ./enc 二进制文件,请先编译项目:")
        print(" cargo build --release --bin enc")
        return

    # Only recorded in the report; the pool size is num_processes.
    cpu_cores = mp.cpu_count()

    print(f"💻 使用 {num_processes} 个并行进程")
    print(f"⏰ 单次测试超时时间: {timeout_minutes} 分钟")
    # Rough estimate assuming 15-25 minutes per run.
    print(
        f"⏱️ 预计总时间: {num_tests * 15 / num_processes / 60:.1f}-{num_tests * 25 / num_processes / 60:.1f} 小时"
    )
    print()
    print("=" * 80)
    print(f"🔬 开始FHE HNSW并行测试 (data-bit-width={data_bit_width})")
    print("=" * 80)

    start_time = time.time()
    test_ids = list(range(1, num_tests + 1))
    results = _collect_results(
        test_ids, num_processes, timeout_minutes, data_bit_width, start_time
    )
    print()  # terminate the \r-refreshed progress line
    total_elapsed = time.time() - start_time

    _analyze_and_report(results, total_elapsed, num_processes, cpu_cores)


if __name__ == "__main__":
    try:
        # Use "spawn" so workers start from a clean interpreter instead of a fork.
        mp.set_start_method("spawn", force=True)
        main()
    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
        _cleanup_temp_files()
        raise SystemExit(130)
    except Exception as e:
        print(f"\n💥 程序异常: {e}")
        _cleanup_temp_files()
        # Exit non-zero so callers/CI can tell the batch crashed.
        raise SystemExit(1)