hfe_knn/test_oj.py
sangge d58adda9ab Implement plain KNN classifier and testing infrastructure
- Add plain KNN implementation with JSONL data processing
- Create Docker deployment setup with python:3.13-slim base
- Add comprehensive OJ-style testing system with accuracy validation
- Update README with detailed scoring mechanism explanation
- Add run.sh script following competition manual requirements

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-05 21:11:22 +08:00

175 lines
5.9 KiB
Python
Executable File

#!/usr/bin/env python3
"""
模拟OJ评分系统的测试脚本
用于测试FHE KNN算法的正确性
"""
import json
import random
import subprocess
import tempfile
import os
import time
from typing import List, Tuple
def generate_test_data(num_data_points: int = 100, dimensions: int = 10) -> dict:
    """Build one random KNN test case.

    Prints a summary line, then draws ``num_data_points`` data vectors
    followed by a single query vector. Every coordinate is sampled uniformly
    from [-10, 10] and rounded to two decimal places.

    Args:
        num_data_points: Number of data vectors to generate.
        dimensions: Number of coordinates in each vector.

    Returns:
        A dict with the keys ``"query"`` (one vector) and ``"data"`` (a list
        of vectors), inserted in that order.
    """
    print(f"生成测试数据: {num_data_points}个数据点, {dimensions}维向量")

    def random_vector() -> list:
        # One vector of `dimensions` coordinates, each rounded to 2 decimals.
        return [round(random.uniform(-10, 10), 2) for _ in range(dimensions)]

    # The data points are drawn before the query; keeping this order keeps the
    # sequence of RNG draws unchanged.
    points = [random_vector() for _ in range(num_data_points)]
    query_vector = random_vector()

    return {"query": query_vector, "data": points}
def write_test_data(test_data: dict, file_path: str):
    """Serialize ``test_data`` as a single JSON document into ``file_path``.

    The file is opened in text write mode, so any existing content is
    replaced. The dict is emitted by one ``json.dump`` call; no trailing
    newline is appended.

    Args:
        test_data: The test case to serialize.
        file_path: Destination path of the JSON(L) file.
    """
    with open(file_path, mode="w") as out_file:
        json.dump(test_data, fp=out_file)
def run_program(
    program_name: str,
    dataset_path: str,
    predictions_path: str,
    project_dir: str = "/home/sangge/code/hfe_knn",
) -> Tuple[List[int], float]:
    """Build and run one of the KNN binaries; return its answer and run time.

    The binary is first built with ``cargo build --release`` (not timed) and
    then executed with ``--dataset`` and ``--predictions`` arguments. Only the
    execution of the binary is timed. The answer is taken from the
    ``"answer"`` field of the first line of ``predictions_path``.

    Args:
        program_name: Binary to run; must be ``"plain"`` or ``"enc"``.
        dataset_path: Path handed to the binary via ``--dataset``.
        predictions_path: Path handed to the binary via ``--predictions`` and
            read back once the binary has finished.
        project_dir: Directory used as the working directory for both the
            build and the run. The default is the path this function
            previously had hard-coded.

    Returns:
        ``(answer, elapsed_seconds)``. On any failure ``answer`` is ``[]``.
        ``elapsed_seconds`` is ``0.0`` when the binary was never started
        (unknown name, build failure, or an error before the timed run).
    """
    print(f"运行程序: {program_name}")

    # Initialised before the try block so the error path can always report a
    # duration, even when the failure happens before the timed run starts.
    elapsed_time = 0.0
    try:
        # Reject unknown names up front so cargo is never invoked for them.
        if program_name not in ("plain", "enc"):
            raise ValueError(f"未知程序: {program_name}")

        # Build first; compilation time is deliberately excluded from timing.
        print(f"编译{program_name}程序...")
        compile_cmd = ["cargo", "build", "--release", "--bin", program_name]
        compile_result = subprocess.run(
            compile_cmd, capture_output=True, text=True, cwd=project_dir
        )
        if compile_result.returncode != 0:
            print(f"编译失败: {compile_result.stderr}")
            return [], 0.0

        cmd = [
            f"./target/release/{program_name}",
            "--dataset", dataset_path,
            "--predictions", predictions_path,
        ]

        print(f"运行{program_name}程序...")
        start_time = time.perf_counter()
        try:
            result = subprocess.run(
                cmd, capture_output=True, text=True, cwd=project_dir
            )
        finally:
            # Record the duration even if launching the binary raised.
            elapsed_time = time.perf_counter() - start_time

        if result.returncode != 0:
            print(f"程序运行失败: {result.stderr}")
            return [], elapsed_time

        # Only the first line of the predictions file is consulted.
        with open(predictions_path, "r", encoding="utf-8") as f:
            prediction = json.loads(f.readline().strip())
        return prediction["answer"], elapsed_time
    except Exception as e:
        # Boundary handler: callers treat an empty answer as "run failed",
        # so every error is reported here instead of propagating.
        print(f"运行程序时出错: {e}")
        return [], elapsed_time
def calculate_accuracy(correct_answer: List[int], test_answer: List[int]) -> float:
    """Return the fraction of reference answers that the tested answer found.

    Both lists are compared as sets, so ordering is ignored. The denominator
    is the length of ``correct_answer``.

    Args:
        correct_answer: Reference answer.
        test_answer: Answer under test.

    Returns:
        A value in [0.0, 1.0] for duplicate-free inputs; ``0.0`` when either
        list is empty.
    """
    # An empty submission matches nothing, and an empty reference would
    # otherwise cause a division by zero below.
    if not correct_answer or not test_answer:
        return 0.0

    matched = set(correct_answer) & set(test_answer)
    return len(matched) / len(correct_answer)
def _report_mismatches(correct_answer, test_answer):
    """Print the entries that differ between the reference and tested answers."""
    correct_set = set(correct_answer)
    test_set = set(test_answer)
    wrong_answers = test_set - correct_set
    missed_answers = correct_set - test_set
    print("\n错误分析:")
    if wrong_answers:
        print(f"错误答案: {sorted(wrong_answers)}")
    if missed_answers:
        print(f"遗漏答案: {sorted(missed_answers)}")


def main():
    """Run the plain and enc programs on one random case and compare results.

    Generates a 100-point, 10-dimensional test case, runs ``plain`` to obtain
    the reference answer, runs ``enc`` on the same dataset, and prints the
    accuracy, both run times and their ratio. The run counts as passed when
    the accuracy is at least 90%. Returns early (after printing a failure
    line) if either program yields an empty answer.
    """
    print("=" * 50)
    print("FHE KNN 测试系统")
    print("=" * 50)

    test_data = generate_test_data(100, 10)

    # All scratch files live in one temporary directory that is removed when
    # the ``with`` block exits, including on errors and early returns.
    with tempfile.TemporaryDirectory() as work_dir:
        test_file_path = os.path.join(work_dir, "dataset.jsonl")
        plain_result_path = os.path.join(work_dir, "plain_predictions.jsonl")
        enc_result_path = os.path.join(work_dir, "enc_predictions.jsonl")

        write_test_data(test_data, test_file_path)
        # Pre-create empty prediction files so each output path already
        # exists when the corresponding program starts.
        for result_path in (plain_result_path, enc_result_path):
            with open(result_path, "w"):
                pass

        # Reference answer from the plaintext implementation.
        print("\n1. 运行plain程序计算标准答案...")
        correct_answer, plain_time = run_program("plain", test_file_path, plain_result_path)
        if not correct_answer:
            print("❌ plain程序运行失败")
            return
        print(f"✅ plain程序完成, 耗时: {plain_time:.2f}")
        print(f"标准答案: {correct_answer}")

        # Answer under test from the encrypted implementation.
        print("\n2. 运行enc程序进行密文计算...")
        test_answer, enc_time = run_program("enc", test_file_path, enc_result_path)
        if not test_answer:
            print("❌ enc程序运行失败")
            return
        print(f"✅ enc程序完成, 耗时: {enc_time:.2f}")
        print(f"测试答案: {test_answer}")

        accuracy = calculate_accuracy(correct_answer, test_answer)

        print("\n" + "=" * 50)
        print("测试结果")
        print("=" * 50)
        print(f"标准答案: {correct_answer}")
        print(f"测试答案: {test_answer}")
        print(f"正确率: {accuracy:.1%}")
        print(f"plain耗时: {plain_time:.2f}")
        print(f"enc耗时: {enc_time:.2f}")
        # A zero reference time would make the ratio undefined.
        if plain_time > 0:
            print(f"性能比: {enc_time/plain_time:.1f}x")
        else:
            print("性能比: N/A")

        if accuracy >= 0.9:
            print("✅ 测试通过! (正确率≥90%)")
        else:
            print("❌ 测试失败! (正确率<90%)")

        # Show which entries differ whenever the match is not perfect.
        if accuracy < 1.0:
            _report_mismatches(correct_answer, test_answer)
# Run the test harness only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()