Implement plain KNN classifier and testing infrastructure
- Add plain KNN implementation with JSONL data processing - Create Docker deployment setup with python:3.13-slim base - Add comprehensive OJ-style testing system with accuracy validation - Update README with detailed scoring mechanism explanation - Add run.sh script following competition manual requirements 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
175
test_oj.py
Executable file
175
test_oj.py
Executable file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
模拟OJ评分系统的测试脚本
|
||||
用于测试FHE KNN算法的正确性
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
def generate_test_data(num_data_points: int = 100, dimensions: int = 10) -> dict:
|
||||
"""生成测试数据"""
|
||||
print(f"生成测试数据: {num_data_points}个数据点, {dimensions}维向量")
|
||||
|
||||
# 生成随机数据点
|
||||
data = []
|
||||
for _ in range(num_data_points):
|
||||
point = [round(random.uniform(-10, 10), 2) for _ in range(dimensions)]
|
||||
data.append(point)
|
||||
|
||||
# 生成查询向量
|
||||
query = [round(random.uniform(-10, 10), 2) for _ in range(dimensions)]
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"data": data
|
||||
}
|
||||
|
||||
def write_test_data(test_data: dict, file_path: str):
|
||||
"""将测试数据写入JSONL文件"""
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(test_data, f)
|
||||
|
||||
def run_program(program_name: str, dataset_path: str, predictions_path: str) -> Tuple[List[int], float]:
|
||||
"""运行KNN程序并返回结果和耗时"""
|
||||
print(f"运行程序: {program_name}")
|
||||
|
||||
try:
|
||||
# 先编译程序(不计时)
|
||||
print(f"编译{program_name}程序...")
|
||||
compile_cmd = ["cargo", "build", "--release", "--bin", program_name]
|
||||
compile_result = subprocess.run(compile_cmd, capture_output=True, text=True, cwd="/home/sangge/code/hfe_knn")
|
||||
|
||||
if compile_result.returncode != 0:
|
||||
print(f"编译失败: {compile_result.stderr}")
|
||||
return [], 0.0
|
||||
|
||||
# 构建运行命令
|
||||
if program_name == "plain":
|
||||
cmd = ["./target/release/plain", "--dataset", dataset_path, "--predictions", predictions_path]
|
||||
elif program_name == "enc":
|
||||
cmd = ["./target/release/enc", "--dataset", dataset_path, "--predictions", predictions_path]
|
||||
else:
|
||||
raise ValueError(f"未知程序: {program_name}")
|
||||
|
||||
# 计时运行程序
|
||||
print(f"运行{program_name}程序...")
|
||||
start_time = time.time()
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, cwd="/home/sangge/code/hfe_knn")
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"程序运行失败: {result.stderr}")
|
||||
return [], elapsed_time
|
||||
|
||||
# 读取结果
|
||||
with open(predictions_path, 'r') as f:
|
||||
line = f.readline().strip()
|
||||
prediction = json.loads(line)
|
||||
return prediction["answer"], elapsed_time
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"运行程序时出错: {e}")
|
||||
return [], elapsed_time
|
||||
|
||||
def calculate_accuracy(correct_answer: List[int], test_answer: List[int]) -> float:
|
||||
"""计算正确率"""
|
||||
if not test_answer:
|
||||
return 0.0
|
||||
|
||||
correct_count = len(set(correct_answer) & set(test_answer))
|
||||
total_count = len(correct_answer)
|
||||
|
||||
return correct_count / total_count
|
||||
|
||||
def main():
|
||||
print("=" * 50)
|
||||
print("FHE KNN 测试系统")
|
||||
print("=" * 50)
|
||||
|
||||
# 生成测试数据
|
||||
test_data = generate_test_data(100, 10)
|
||||
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as test_file:
|
||||
test_file_path = test_file.name
|
||||
write_test_data(test_data, test_file_path)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as plain_result:
|
||||
plain_result_path = plain_result.name
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as enc_result:
|
||||
enc_result_path = enc_result.name
|
||||
|
||||
try:
|
||||
# 运行plain程序获取标准答案
|
||||
print("\\n1. 运行plain程序计算标准答案...")
|
||||
correct_answer, plain_time = run_program("plain", test_file_path, plain_result_path)
|
||||
|
||||
if not correct_answer:
|
||||
print("❌ plain程序运行失败")
|
||||
return
|
||||
|
||||
print(f"✅ plain程序完成, 耗时: {plain_time:.2f}秒")
|
||||
print(f"标准答案: {correct_answer}")
|
||||
|
||||
# 运行enc程序
|
||||
print("\\n2. 运行enc程序进行密文计算...")
|
||||
test_answer, enc_time = run_program("enc", test_file_path, enc_result_path)
|
||||
|
||||
if not test_answer:
|
||||
print("❌ enc程序运行失败")
|
||||
return
|
||||
|
||||
print(f"✅ enc程序完成, 耗时: {enc_time:.2f}秒")
|
||||
print(f"测试答案: {test_answer}")
|
||||
|
||||
# 计算正确率
|
||||
accuracy = calculate_accuracy(correct_answer, test_answer)
|
||||
|
||||
# 输出结果
|
||||
print("\\n" + "=" * 50)
|
||||
print("测试结果")
|
||||
print("=" * 50)
|
||||
print(f"标准答案: {correct_answer}")
|
||||
print(f"测试答案: {test_answer}")
|
||||
print(f"正确率: {accuracy:.1%}")
|
||||
print(f"plain耗时: {plain_time:.2f}秒")
|
||||
print(f"enc耗时: {enc_time:.2f}秒")
|
||||
print(f"性能比: {enc_time/plain_time:.1f}x")
|
||||
|
||||
# 判断是否通过
|
||||
if accuracy >= 0.9:
|
||||
print("✅ 测试通过! (正确率≥90%)")
|
||||
else:
|
||||
print("❌ 测试失败! (正确率<90%)")
|
||||
|
||||
# 显示错误的答案
|
||||
if accuracy < 1.0:
|
||||
correct_set = set(correct_answer)
|
||||
test_set = set(test_answer)
|
||||
wrong_answers = test_set - correct_set
|
||||
missed_answers = correct_set - test_set
|
||||
|
||||
print("\\n错误分析:")
|
||||
if wrong_answers:
|
||||
print(f"错误答案: {sorted(wrong_answers)}")
|
||||
if missed_answers:
|
||||
print(f"遗漏答案: {sorted(missed_answers)}")
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
for file_path in [test_file_path, plain_result_path, enc_result_path]:
|
||||
if os.path.exists(file_path):
|
||||
os.unlink(file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user