- Add plain KNN implementation with JSONL data processing - Create Docker deployment setup with python:3.13-slim base - Add comprehensive OJ-style testing system with accuracy validation - Update README with detailed scoring mechanism explanation - Add run.sh script following competition manual requirements 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
175 lines
5.9 KiB
Python
Executable File
175 lines
5.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
模拟OJ评分系统的测试脚本
|
|
用于测试FHE KNN算法的正确性
|
|
"""
|
|
|
|
import json
|
|
import random
|
|
import subprocess
|
|
import tempfile
|
|
import os
|
|
import time
|
|
from typing import List, Tuple
|
|
|
|
def generate_test_data(num_data_points: int = 100, dimensions: int = 10, *, seed=None) -> dict:
    """Generate a random KNN test case.

    Every coordinate (of every data point and of the query) is drawn
    uniformly from [-10, 10] and rounded to two decimal places.

    Args:
        num_data_points: Number of vectors placed in ``data``.
        dimensions: Length of every vector (data points and query).
        seed: Optional int seed. When given, a private ``random.Random(seed)``
            is used so a failing case can be reproduced exactly. When ``None``
            (the default), the shared module-level ``random`` generator is
            used, so ``random.seed(...)`` set by the caller still applies.

    Returns:
        ``{"query": [...], "data": [[...], ...]}`` where ``data`` holds
        ``num_data_points`` lists of ``dimensions`` floats and ``query`` is a
        single list of ``dimensions`` floats.
    """
    print(f"生成测试数据: {num_data_points}个数据点, {dimensions}维向量")

    rng = random if seed is None else random.Random(seed)

    def _coord() -> float:
        return round(rng.uniform(-10, 10), 2)

    # Data points are drawn before the query so that, for a given generator
    # state, the same sequence of values ends up in the same places.
    data = [[_coord() for _ in range(dimensions)] for _ in range(num_data_points)]
    query = [_coord() for _ in range(dimensions)]

    return {
        "query": query,
        "data": data,
    }
|
|
|
|
def write_test_data(test_data: dict, file_path: str):
    """Serialize *test_data* as one JSON document into *file_path*.

    Any existing file at *file_path* is truncated. Exactly one JSON object is
    written, with no trailing newline, even though callers name the file
    ``.jsonl``.
    """
    with open(file_path, mode='w') as handle:
        handle.write(json.dumps(test_data))
|
|
|
|
def run_program(
    program_name: str,
    dataset_path: str,
    predictions_path: str,
    project_dir: str = "/home/sangge/code/hfe_knn",
) -> Tuple[List[int], float]:
    """Build and run one KNN binary, returning its answer and run time.

    The binary is compiled with ``cargo build --release --bin <program_name>``
    (not timed), then executed as
    ``./target/release/<program_name> --dataset <dataset_path> --predictions <predictions_path>``.
    Only that execution is timed. Both commands run with *project_dir* as
    their working directory, so relative dataset/prediction paths resolve
    against it. Finally the first line of *predictions_path* is parsed as
    JSON and its ``"answer"`` field is returned.

    Args:
        program_name: Cargo binary to run; must be ``"plain"`` or ``"enc"``.
        dataset_path: Input file passed via ``--dataset``.
        predictions_path: Output file passed via ``--predictions`` and read back.
        project_dir: Cargo project root used as the working directory.
            Defaults to the original author's checkout path.

    Returns:
        ``(answer, elapsed_seconds)``. On any failure (build error, non-zero
        exit, missing tool or directory, unreadable output, ...) the answer is
        ``[]``. The elapsed time is the measured time since the timed run
        started, or ``0.0`` if it never started.

    Raises:
        ValueError: if *program_name* is not a known binary.
    """
    if program_name not in ("plain", "enc"):
        raise ValueError(f"未知程序: {program_name}")

    print(f"运行程序: {program_name}")

    start_time = None  # set only once the timed run begins
    try:
        # Compile first so build time is excluded from the measurement.
        print(f"编译{program_name}程序...")
        compile_cmd = ["cargo", "build", "--release", "--bin", program_name]
        compile_result = subprocess.run(compile_cmd, capture_output=True, text=True, cwd=project_dir)
        if compile_result.returncode != 0:
            print(f"编译失败: {compile_result.stderr}")
            return [], 0.0

        cmd = [
            f"./target/release/{program_name}",
            "--dataset", dataset_path,
            "--predictions", predictions_path,
        ]

        print(f"运行{program_name}程序...")
        start_time = time.perf_counter()  # monotonic clock, suited to intervals
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=project_dir)
        elapsed_time = time.perf_counter() - start_time

        if result.returncode != 0:
            print(f"程序运行失败: {result.stderr}")
            return [], elapsed_time

        with open(predictions_path, 'r', encoding='utf-8') as f:
            prediction = json.loads(f.readline().strip())
        return prediction["answer"], elapsed_time

    except Exception as e:
        # Harness boundary: report and return an empty answer instead of
        # aborting. start_time is None when failure happened before the run
        # (e.g. `cargo` not installed or project_dir missing).
        elapsed_time = 0.0 if start_time is None else time.perf_counter() - start_time
        print(f"运行程序时出错: {e}")
        return [], elapsed_time
|
|
|
|
def calculate_accuracy(correct_answer: List[int], test_answer: List[int]) -> float:
    """Return the fraction of expected indices that the tested program found.

    Both lists are treated as sets, and the result is
    ``len(set(correct) & set(test)) / len(set(correct))``.

    Returns 0.0 when *test_answer* is empty. Also returns 0.0 when
    *correct_answer* is empty, since there is nothing to match and the
    division would otherwise fail.

    NOTE(review): this is recall only. Extra indices in *test_answer* are not
    penalized, so an answer listing every index scores 1.0. Confirm this
    matches the official OJ scoring.
    """
    if not test_answer:
        return 0.0

    expected = set(correct_answer)
    if not expected:
        return 0.0

    # Divide by the number of distinct expected items. The numerator is a
    # set size, so counting duplicates here could cap the score below 1.0.
    return len(expected & set(test_answer)) / len(expected)
|
|
|
|
def _new_temp_file(registry: List[str]) -> str:
    """Create an empty ``.jsonl`` temp file, record its path in *registry*, and return it.

    The handle is closed before returning, so the file can be reopened by
    name (by this process or a child) on every platform.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as handle:
        registry.append(handle.name)  # register before close so cleanup always sees it
    return handle.name


def _print_error_analysis(correct_answer: List[int], test_answer: List[int]) -> None:
    """Print indices enc reported but plain did not, and indices enc missed."""
    correct_set = set(correct_answer)
    test_set = set(test_answer)
    wrong_answers = test_set - correct_set
    missed_answers = correct_set - test_set

    print("\n错误分析:")
    if wrong_answers:
        print(f"错误答案: {sorted(wrong_answers)}")
    if missed_answers:
        print(f"遗漏答案: {sorted(missed_answers)}")


def main() -> int:
    """Run the OJ-style comparison of the ``enc`` binary against ``plain``.

    Steps:
    1. Generate a random dataset of 100 points in 10 dimensions.
    2. Run ``plain`` to get the reference answer.
    3. Run ``enc`` to get the answer under test.
    4. Print both answers, the accuracy, the timings and, when accuracy is
       below 100%, the wrong and missed indices.

    All temporary files are removed on exit.

    Returns:
        Process exit status. ``0`` when the test passes (accuracy >= 90%);
        ``1`` when either program fails or accuracy is below 90%.
    """
    print("=" * 50)
    print("FHE KNN 测试系统")
    print("=" * 50)

    test_data = generate_test_data(100, 10)

    temp_paths: List[str] = []  # every temp file created; removed in `finally`
    try:
        test_file_path = _new_temp_file(temp_paths)
        plain_result_path = _new_temp_file(temp_paths)
        enc_result_path = _new_temp_file(temp_paths)
        write_test_data(test_data, test_file_path)

        # Run plain to obtain the reference answer.
        print("\n1. 运行plain程序计算标准答案...")
        correct_answer, plain_time = run_program("plain", test_file_path, plain_result_path)
        if not correct_answer:
            print("❌ plain程序运行失败")
            return 1
        print(f"✅ plain程序完成, 耗时: {plain_time:.2f}秒")
        print(f"标准答案: {correct_answer}")

        # Run enc (the encrypted computation under test).
        print("\n2. 运行enc程序进行密文计算...")
        test_answer, enc_time = run_program("enc", test_file_path, enc_result_path)
        if not test_answer:
            print("❌ enc程序运行失败")
            return 1
        print(f"✅ enc程序完成, 耗时: {enc_time:.2f}秒")
        print(f"测试答案: {test_answer}")

        accuracy = calculate_accuracy(correct_answer, test_answer)

        print("\n" + "=" * 50)
        print("测试结果")
        print("=" * 50)
        print(f"标准答案: {correct_answer}")
        print(f"测试答案: {test_answer}")
        print(f"正确率: {accuracy:.1%}")
        print(f"plain耗时: {plain_time:.2f}秒")
        print(f"enc耗时: {enc_time:.2f}秒")
        # plain_time may be 0.0 on a coarse clock; avoid ZeroDivisionError.
        if plain_time > 0:
            print(f"性能比: {enc_time/plain_time:.1f}x")
        else:
            print("性能比: N/A")

        passed = accuracy >= 0.9
        if passed:
            print("✅ 测试通过! (正确率≥90%)")
        else:
            print("❌ 测试失败! (正确率<90%)")

        # Show mismatches even on a pass, as long as accuracy isn't perfect.
        if accuracy < 1.0:
            _print_error_analysis(correct_answer, test_answer)

        return 0 if passed else 1

    finally:
        for file_path in temp_paths:
            if os.path.exists(file_path):
                os.unlink(file_path)
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the exit status. A None return still
    # exits with status 0.
    raise SystemExit(main())