#!/usr/bin/env python3
"""
Test script that simulates an online-judge (OJ) scoring run.

It generates a random KNN dataset, runs the ``plain`` binary to obtain the
reference answer, runs the ``enc`` binary on the same dataset, and reports
how much of the reference answer the ``enc`` output reproduces.
"""

import json
import random
import subprocess
import tempfile
import os
import time
from typing import List, Tuple

# Directory of the Rust project; cargo and the built binaries run from here.
DEFAULT_PROJECT_DIR = "/home/sangge/code/hfe_knn"

# Binaries this script knows how to invoke.
KNOWN_PROGRAMS = ("plain", "enc")

# Minimum accuracy for the run to be reported as passed.
PASS_THRESHOLD = 0.9


def _random_vector(dimensions: int) -> List[float]:
    """Return ``dimensions`` random values in [-10, 10], rounded to 2 places."""
    return [round(random.uniform(-10, 10), 2) for _ in range(dimensions)]


def generate_test_data(num_data_points: int = 100, dimensions: int = 10) -> dict:
    """Generate a random dataset and a random query vector.

    Args:
        num_data_points: Number of data vectors to generate.
        dimensions: Length of every vector (data points and query).

    Returns:
        A dict with keys ``"query"`` (one vector) and ``"data"`` (a list of
        ``num_data_points`` vectors).
    """
    print(f"生成测试数据: {num_data_points}个数据点, {dimensions}维向量")

    # Data points are drawn before the query so that a seeded ``random``
    # yields the same values as earlier versions of this script.
    data = [_random_vector(dimensions) for _ in range(num_data_points)]
    query = _random_vector(dimensions)

    return {"query": query, "data": data}


def write_test_data(test_data: dict, file_path: str) -> None:
    """Serialize ``test_data`` as a single JSON document into ``file_path``."""
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(test_data, f)


def _build_run_command(
    program_name: str, dataset_path: str, predictions_path: str
) -> List[str]:
    """Return the argv used to launch a known release binary.

    Raises:
        ValueError: If ``program_name`` is not one of ``KNOWN_PROGRAMS``.
    """
    if program_name not in KNOWN_PROGRAMS:
        raise ValueError(f"未知程序: {program_name}")
    return [
        f"./target/release/{program_name}",
        "--dataset",
        dataset_path,
        "--predictions",
        predictions_path,
    ]


def run_program(
    program_name: str,
    dataset_path: str,
    predictions_path: str,
    project_dir: str = DEFAULT_PROJECT_DIR,
) -> Tuple[List[int], float]:
    """Build and run one KNN binary, returning its answer and run time.

    The binary is first built with ``cargo build --release`` (not timed),
    then executed (timed). The answer is read from the ``"answer"`` field of
    the first line of ``predictions_path``.

    Args:
        program_name: Binary to run; ``"plain"`` or ``"enc"``.
        dataset_path: Path passed to the binary via ``--dataset``.
        predictions_path: Path passed via ``--predictions`` and read back.
        project_dir: Working directory for cargo and the binary.

    Returns:
        ``(answer, elapsed_seconds)``. On any failure the answer is ``[]``;
        the elapsed time is ``0.0`` if the failure happened before the timed
        run started.
    """
    print(f"运行程序: {program_name}")

    # Stays None until the timed run begins, so the error handler can tell
    # "failed before timing" apart from "failed during/after the run".
    start_time = None

    try:
        # Compile first (not timed).
        print(f"编译{program_name}程序...")
        compile_cmd = ["cargo", "build", "--release", "--bin", program_name]
        compile_result = subprocess.run(
            compile_cmd, capture_output=True, text=True, cwd=project_dir
        )
        if compile_result.returncode != 0:
            print(f"编译失败: {compile_result.stderr}")
            return [], 0.0

        cmd = _build_run_command(program_name, dataset_path, predictions_path)

        # Timed execution of the binary itself.
        print(f"运行{program_name}程序...")
        start_time = time.perf_counter()
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=project_dir)
        elapsed_time = time.perf_counter() - start_time

        if result.returncode != 0:
            print(f"程序运行失败: {result.stderr}")
            return [], elapsed_time

        # Only the first line of the predictions file is consulted.
        with open(predictions_path, "r", encoding="utf-8") as f:
            prediction = json.loads(f.readline().strip())

        return prediction["answer"], elapsed_time

    except Exception as e:
        # Deliberately broad: this is a best-effort harness and any failure
        # (missing cargo, missing directory, bad JSON, ...) is reported as an
        # empty answer instead of crashing the whole test run.
        if start_time is None:
            elapsed_time = 0.0
        else:
            elapsed_time = time.perf_counter() - start_time
        print(f"运行程序时出错: {e}")
        return [], elapsed_time


def calculate_accuracy(correct_answer: List[int], test_answer: List[int]) -> float:
    """Return the fraction of ``correct_answer`` that ``test_answer`` contains.

    Order is ignored; the answers are compared as sets. Returns ``0.0`` when
    either list is empty (an empty reference would otherwise divide by zero).
    """
    if not test_answer or not correct_answer:
        return 0.0

    correct_count = len(set(correct_answer) & set(test_answer))
    total_count = len(correct_answer)

    return correct_count / total_count


def _reserve_temp_path(suffix: str = ".jsonl") -> str:
    """Create an empty temporary file that outlives its handle; return its path."""
    with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False) as handle:
        return handle.name


def main():
    """Run the full plain-vs-enc comparison and print a report."""
    print("=" * 50)
    print("FHE KNN 测试系统")
    print("=" * 50)

    test_data = generate_test_data(100, 10)

    # Every temp path is registered here as soon as it exists, so the
    # ``finally`` clause removes it even if a later step fails.
    temp_paths: List[str] = []

    try:
        for _ in range(3):
            temp_paths.append(_reserve_temp_path())
        test_file_path, plain_result_path, enc_result_path = temp_paths

        write_test_data(test_data, test_file_path)

        # Run the plain program to obtain the reference answer.
        print("\n1. 运行plain程序计算标准答案...")
        correct_answer, plain_time = run_program(
            "plain", test_file_path, plain_result_path
        )

        if not correct_answer:
            print("❌ plain程序运行失败")
            return

        print(f"✅ plain程序完成, 耗时: {plain_time:.2f}秒")
        print(f"标准答案: {correct_answer}")

        # Run the enc program on the same dataset.
        print("\n2. 运行enc程序进行密文计算...")
        test_answer, enc_time = run_program("enc", test_file_path, enc_result_path)

        if not test_answer:
            print("❌ enc程序运行失败")
            return

        print(f"✅ enc程序完成, 耗时: {enc_time:.2f}秒")
        print(f"测试答案: {test_answer}")

        accuracy = calculate_accuracy(correct_answer, test_answer)

        # Report.
        print("\n" + "=" * 50)
        print("测试结果")
        print("=" * 50)
        print(f"标准答案: {correct_answer}")
        print(f"测试答案: {test_answer}")
        print(f"正确率: {accuracy:.1%}")
        print(f"plain耗时: {plain_time:.2f}秒")
        print(f"enc耗时: {enc_time:.2f}秒")
        # A zero reference time would make the ratio undefined.
        if plain_time > 0:
            print(f"性能比: {enc_time/plain_time:.1f}x")
        else:
            print("性能比: N/A")

        # Pass/fail verdict.
        if accuracy >= PASS_THRESHOLD:
            print("✅ 测试通过! (正确率≥90%)")
        else:
            print("❌ 测试失败! (正确率<90%)")

        # Show which answers differ.
        if accuracy < 1.0:
            correct_set = set(correct_answer)
            test_set = set(test_answer)

            wrong_answers = test_set - correct_set
            missed_answers = correct_set - test_set

            print("\n错误分析:")
            if wrong_answers:
                print(f"错误答案: {sorted(wrong_answers)}")
            if missed_answers:
                print(f"遗漏答案: {sorted(missed_answers)}")

    finally:
        # Remove the temporary files; one that is already gone is fine.
        for file_path in temp_paths:
            try:
                os.unlink(file_path)
            except FileNotFoundError:
                pass


if __name__ == "__main__":
    main()