Implement plain KNN classifier and testing infrastructure

- Add plain KNN implementation with JSONL data processing
- Create Docker deployment setup with python:3.13-slim base
- Add comprehensive OJ-style testing system with accuracy validation
- Update README with detailed scoring mechanism explanation
- Add run.sh script following competition manual requirements

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
sangge 2025-07-05 21:11:22 +08:00
parent 136583715e
commit d58adda9ab
8 changed files with 505 additions and 147 deletions

Cargo.lock (generated)

@@ -23,6 +23,56 @@ dependencies = [
"serde",
]
[[package]]
name = "anstream"
version = "0.6.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
"is_terminal_polyfill",
"utf8parse",
]
[[package]]
name = "anstyle"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd"
[[package]]
name = "anstyle-parse"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
dependencies = [
"utf8parse",
]
[[package]]
name = "anstyle-query"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9"
dependencies = [
"windows-sys",
]
[[package]]
name = "anstyle-wincon"
version = "3.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys",
]
[[package]]
name = "anyhow"
version = "1.0.98"
@@ -61,9 +111,9 @@ dependencies = [
[[package]]
name = "bumpalo"
version = "3.18.1"
version = "3.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee"
checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
[[package]]
name = "bytemuck"
@@ -87,6 +137,52 @@ dependencies = [
"inout",
]
[[package]]
name = "clap"
version = "4.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f"
dependencies = [
"clap_builder",
"clap_derive",
]
[[package]]
name = "clap_builder"
version = "4.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "clap_lex"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
[[package]]
name = "colorchoice"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "cpufeatures"
version = "0.2.17"
@@ -131,27 +227,6 @@ dependencies = [
"typenum",
]
[[package]]
name = "csv"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf"
dependencies = [
"csv-core",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "csv-core"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d"
dependencies = [
"memchr",
]
[[package]]
name = "digest"
version = "0.10.7"
@@ -241,9 +316,10 @@ name = "hfe_knn"
version = "0.1.0"
dependencies = [
"anyhow",
"csv",
"clap",
"rand",
"serde",
"serde_json",
"tfhe",
]
@@ -256,6 +332,12 @@ dependencies = [
"generic-array",
]
[[package]]
name = "is_terminal_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.14.0"
@@ -340,6 +422,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "once_cell_polyfill"
version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"
[[package]]
name = "paste"
version = "1.0.15"
@@ -489,6 +577,18 @@ dependencies = [
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "sha3"
version = "0.10.8"
@@ -499,6 +599,12 @@ dependencies = [
"keccak",
]
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.27.1"
@@ -534,9 +640,9 @@ dependencies = [
[[package]]
name = "tfhe"
version = "1.2.0"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cb9a09347f586b9e49336ddad761706dfbfc414c44a8892b5a38666a47736a9"
checksum = "11e74f14e2812ac6a2fe516f2c090431b614d589c24c6e3bca3fe763d4de928e"
dependencies = [
"aligned-vec",
"bincode",
@@ -558,9 +664,9 @@ dependencies = [
[[package]]
name = "tfhe-csprng"
version = "0.5.0"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b1a0b71a91c9e36e9e1e8e6a528a056086b01e193e173472ae59cfaac56fdad"
checksum = "06450766ae375cd305281c5d691a41d8877e40d3cc7510a33244775d62f60ad4"
dependencies = [
"aes",
"libc",
@@ -628,6 +734,12 @@ version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "utf8parse"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "version_check"
version = "0.9.5"
@@ -706,6 +818,79 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"

Cargo.toml

@@ -4,8 +4,9 @@ version = "0.1.0"
edition = "2024"
[dependencies]
tfhe = { version = "~1.2.0", features = ["integer"] }
csv = "1.3"
serde = { version = "1.0", features = ["derive"] }
anyhow = "1.0"
tfhe = { version = "1", features = ["integer"] }
serde = { version = "1", features = ["derive"] }
anyhow = "1"
rand = "0.9"
clap = { version = "4.0", features = ["derive"] }
serde_json = "1"

Dockerfile (new file)

@@ -0,0 +1,22 @@
FROM python:3.13-slim
# Create necessary directories
RUN mkdir -p /home/admin/predict \
/home/admin/data \
/home/admin/workspace/job/logs \
/home/admin/workspace/job/output/predictions \
/home/admin/workspace/job/input
# Copy files to the locations required by the manual
COPY run.sh /home/admin/predict/run.sh
RUN chmod 777 /home/admin/predict/run.sh
# Copy the compiled binary as 'test' executable
COPY target/release/hfe_knn /home/admin/predict/test
RUN chmod +x /home/admin/predict/test
# Copy training data
COPY dataset/train.jsonl /home/admin/data/data.jsonl
# Set default command
CMD ["/home/admin/predict/run.sh"]

README.md

@@ -8,3 +8,25 @@ The key of the KNN algorithm is to use a distance function to compute the distances between samples
For the fully homomorphic encryption part, TFHE-rs is planned; after simplification the workflow requires no interaction and can be simulated entirely within a single program.
See the [competition manual](./manual.md) for scoring details.
## Scoring Mechanism
### Accuracy Calculation
- The algorithm must return the indices of the 10 nearest-neighbor vectors
- Accuracy = number of correct indices / 10
- For example, if 9 of the 10 results are correct, the accuracy is 90%
### Scoring Rules
- **Threshold**: the accuracy must be ≥ 90%, i.e. at least 9 of the 10 results are correct
- **Ranking**: once the threshold is met, entries are ranked by total runtime (faster is better)
- **Elimination**: an accuracy below 90% scores 0 outright
### Definition of "Correct"
- The organizers hold a reference answer: the true 10 nearest neighbors
- The algorithm's output is compared against this reference answer
- Order does not matter; the indices only need to be correct (see the sketch below)
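A minimal sketch of this order-insensitive check (illustration only, not the official grader; the `score` function and its signature are made up for this example):

use std::collections::HashSet;

// Fraction of reference indices that appear in the submitted answer,
// plus whether the 90% threshold is met. Order is ignored.
fn score(reference: &[usize], answer: &[usize]) -> (f64, bool) {
    let truth: HashSet<usize> = reference.iter().copied().collect();
    let got: HashSet<usize> = answer.iter().copied().collect();
    let accuracy = truth.intersection(&got).count() as f64 / reference.len() as f64;
    (accuracy, accuracy >= 0.9)
}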
### Data Format
- Training data: JSONL format; each line contains a `query` vector and a `data` array
- Output format: `{"answer": [index1, index2, ..., index10]}`
- Indices are numbered starting from 1 (see the example below)
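Purely for illustration (the values are made up and the vectors are shortened; a real answer line always lists 10 indices, as in dataset/answer.jsonl below), an input line and a matching output line could look like:

{"query": [1.2, 0.5], "data": [[1.0, 0.4], [3.2, 7.1], [0.9, 0.6]]}
{"answer": [1, 3, 2]}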

dataset/answer.jsonl (new file)

@@ -0,0 +1 @@
{"answer":[93,94,90,27,87,50,47,16,6,40]}

run.sh (new file)

@@ -0,0 +1,22 @@
#!/bin/bash
SCRIPT_DIR=$(dirname "$0")
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
# Choose file paths according to the runtime environment
if [ "$ALIPAY_APP_ENV" = "prod" ]; then
PREDICTIONS_RESULT_FILE="/home/admin/workspace/job/output/predictions/predictions.jsonl"
DATASET_FILE="/home/admin/workspace/job/input/test.jsonl"
else
PREDICTIONS_RESULT_FILE="${PARENT_DIR}/data/predictions.jsonl"
DATASET_FILE="${PARENT_DIR}/data/data.jsonl"
fi
echo $PREDICTIONS_RESULT_FILE
# The content above must **not** be modified!
# Contestants may only replace `test` below with the entry point of their own competition code
chmod +x "${SCRIPT_DIR}/test"
"${SCRIPT_DIR}/test" \
--dataset "$DATASET_FILE" \
--predictions "$PREDICTIONS_RESULT_FILE"


@@ -1,141 +1,71 @@
use anyhow::Result;
use rand::rng;
use rand::seq::SliceRandom;
use serde::{Deserialize, Deserializer};
use std::collections::HashMap;
use clap::Parser;
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
#[derive(Debug, Deserialize, Clone)]
struct IrisData {
#[serde(rename = "SepalLengthCm", deserialize_with = "f64_to_u32")]
sepal_length: u32,
#[serde(rename = "SepalWidthCm", deserialize_with = "f64_to_u32")]
sepal_width: u32,
#[serde(rename = "PetalLengthCm", deserialize_with = "f64_to_u32")]
petal_length: u32,
#[serde(rename = "PetalWidthCm", deserialize_with = "f64_to_u32")]
petal_width: u32,
#[serde(rename = "Species")]
species: String,
#[derive(Parser)]
#[command(name = "hfe_knn")]
#[command(about = "FHE-based KNN classifier")]
struct Args {
#[arg(long)]
dataset: String,
#[arg(long)]
predictions: String,
}
struct IrisDataUnknown {
sepal_length: u32,
sepal_width: u32,
petal_length: u32,
petal_width: u32,
#[derive(Deserialize)]
struct Dataset {
query: Vec<f64>,
data: Vec<Vec<f64>>,
}
fn f64_to_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
where
D: Deserializer<'de>,
{
let f = f64::deserialize(deserializer)?;
Ok((f * 10.0).round() as u32) // scale up by 10 to keep one decimal place of precision
#[derive(Serialize)]
struct Prediction {
answer: Vec<usize>,
}
fn distance(a: &IrisDataUnknown, b: &IrisData) -> u32 {
let diff_sl = (a.sepal_length as i32 - b.sepal_length as i32).unsigned_abs();
let diff_sw = (a.sepal_width as i32 - b.sepal_width as i32).unsigned_abs();
let diff_pl = (a.petal_length as i32 - b.petal_length as i32).unsigned_abs();
let diff_pw = (a.petal_width as i32 - b.petal_width as i32).unsigned_abs();
// Squared euclidean distance (avoiding sqrt for integer math)
diff_sl * diff_sl + diff_sw * diff_sw + diff_pl * diff_pl + diff_pw * diff_pw
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f64>()
.sqrt()
}
fn knn_classify(query: &IrisDataUnknown, labeled_data: &[IrisData], k: usize) -> String {
let mut distances: Vec<(u32, &String)> = labeled_data
fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
let mut distances: Vec<(f64, usize)> = data
.iter()
.map(|data| (distance(query, data), &data.species))
.enumerate()
.map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
.collect();
distances.sort_by_key(|&(dist, _)| dist);
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
let mut species_count = HashMap::new();
for &(_, species) in distances.iter().take(k) {
*species_count.entry(species.clone()).or_insert(0) += 1;
}
species_count
.into_iter()
.max_by_key(|&(_, count)| count)
.map(|(species, _)| species)
.unwrap_or_else(|| "Unknown".to_string())
}
fn load_from_csv(filename: &str) -> Result<Vec<IrisData>> {
let mut rdr = csv::Reader::from_path(filename)?;
let mut data = Vec::new();
for result in rdr.deserialize() {
let record: IrisData = result?;
data.push(record);
}
Ok(data)
}
fn split_dataset(dataset: &[IrisData], test_ratio: f64) -> (Vec<IrisData>, Vec<IrisData>) {
let mut shuffled = dataset.to_vec();
shuffled.shuffle(&mut rng());
let test_size = (dataset.len() as f64 * test_ratio) as usize;
let train_size = dataset.len() - test_size;
let train_data = shuffled[..train_size].to_vec();
let test_data = shuffled[train_size..].to_vec();
(train_data, test_data)
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
}
fn main() -> Result<()> {
let filename = "./dataset/Iris.csv";
let data = load_from_csv(filename)?;
println!("Loaded {} records from {}", data.len(), filename);
let args = Args::parse();
let rounds = 10;
let k = 5;
let test_ratio = 0.2;
let mut total_accuracy = 0.0;
let file = File::open(&args.dataset)?;
let reader = BufReader::new(file);
println!(
"Running {} rounds of KNN classification (k={}, test_ratio={})",
rounds, k, test_ratio
);
let mut results = Vec::new();
for round in 1..=rounds {
let (train_data, test_data) = split_dataset(&data, test_ratio);
let mut correct = 0;
for line in reader.lines() {
let line = line?;
let dataset: Dataset = serde_json::from_str(&line)?;
for test_item in &test_data {
let query = IrisDataUnknown {
sepal_length: test_item.sepal_length,
sepal_width: test_item.sepal_width,
petal_length: test_item.petal_length,
petal_width: test_item.petal_width,
};
let predicted = knn_classify(&query, &train_data, k);
if predicted == test_item.species {
correct += 1;
}
}
let accuracy = correct as f64 / test_data.len() as f64;
total_accuracy += accuracy;
println!(
"Round {}: Accuracy = {:.2}% ({}/{} correct)",
round,
accuracy * 100.0,
correct,
test_data.len()
);
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
results.push(Prediction { answer: nearest });
}
let avg_accuracy = total_accuracy / rounds as f64;
println!(
"\nAverage accuracy over {} rounds: {:.2}%",
rounds,
avg_accuracy * 100.0
);
let mut output_file = File::create(&args.predictions)?;
for result in results {
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
}
Ok(())
}

test_oj.py (new executable file)

@@ -0,0 +1,175 @@
#!/usr/bin/env python3
"""
模拟OJ评分系统的测试脚本
用于测试FHE KNN算法的正确性
"""
import json
import random
import subprocess
import tempfile
import os
import time
from typing import List, Tuple
def generate_test_data(num_data_points: int = 100, dimensions: int = 10) -> dict:
    """Generate test data."""
    print(f"Generating test data: {num_data_points} data points, {dimensions}-dimensional vectors")

    # Generate random data points
    data = []
    for _ in range(num_data_points):
        point = [round(random.uniform(-10, 10), 2) for _ in range(dimensions)]
        data.append(point)

    # Generate the query vector
    query = [round(random.uniform(-10, 10), 2) for _ in range(dimensions)]

    return {
        "query": query,
        "data": data
    }

def write_test_data(test_data: dict, file_path: str):
    """Write the test data to a JSONL file."""
    with open(file_path, 'w') as f:
        json.dump(test_data, f)
def run_program(program_name: str, dataset_path: str, predictions_path: str) -> Tuple[List[int], float]:
    """Run a KNN program and return its answer and elapsed time."""
    print(f"Running program: {program_name}")
    start_time = time.time()  # set early so the except branch can always compute an elapsed time

    try:
        # Compile the program first (compilation is not timed)
        print(f"Compiling {program_name}...")
        compile_cmd = ["cargo", "build", "--release", "--bin", program_name]
        compile_result = subprocess.run(compile_cmd, capture_output=True, text=True, cwd="/home/sangge/code/hfe_knn")
        if compile_result.returncode != 0:
            print(f"Compilation failed: {compile_result.stderr}")
            return [], 0.0

        # Build the run command
        if program_name == "plain":
            cmd = ["./target/release/plain", "--dataset", dataset_path, "--predictions", predictions_path]
        elif program_name == "enc":
            cmd = ["./target/release/enc", "--dataset", dataset_path, "--predictions", predictions_path]
        else:
            raise ValueError(f"Unknown program: {program_name}")

        # Run the program and time it
        print(f"Running {program_name}...")
        start_time = time.time()
        result = subprocess.run(cmd, capture_output=True, text=True, cwd="/home/sangge/code/hfe_knn")
        end_time = time.time()
        elapsed_time = end_time - start_time

        if result.returncode != 0:
            print(f"Program run failed: {result.stderr}")
            return [], elapsed_time

        # Read the result
        with open(predictions_path, 'r') as f:
            line = f.readline().strip()
            prediction = json.loads(line)
            return prediction["answer"], elapsed_time

    except Exception as e:
        elapsed_time = time.time() - start_time
        print(f"Error while running the program: {e}")
        return [], elapsed_time

def calculate_accuracy(correct_answer: List[int], test_answer: List[int]) -> float:
    """Compute the accuracy as the overlap between the two index sets."""
    if not test_answer:
        return 0.0
    correct_count = len(set(correct_answer) & set(test_answer))
    total_count = len(correct_answer)
    return correct_count / total_count
def main():
    print("=" * 50)
    print("FHE KNN test system")
    print("=" * 50)

    # Generate test data
    test_data = generate_test_data(100, 10)

    # Create temporary files
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as test_file:
        test_file_path = test_file.name
    write_test_data(test_data, test_file_path)

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as plain_result:
        plain_result_path = plain_result.name
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as enc_result:
        enc_result_path = enc_result.name

    try:
        # Run the plain program to obtain the reference answer
        print("\n1. Running the plain program to compute the reference answer...")
        correct_answer, plain_time = run_program("plain", test_file_path, plain_result_path)
        if not correct_answer:
            print("❌ plain program failed")
            return
        print(f"✅ plain program finished, elapsed: {plain_time:.2f}s")
        print(f"Reference answer: {correct_answer}")

        # Run the enc program
        print("\n2. Running the enc program for the encrypted computation...")
        test_answer, enc_time = run_program("enc", test_file_path, enc_result_path)
        if not test_answer:
            print("❌ enc program failed")
            return
        print(f"✅ enc program finished, elapsed: {enc_time:.2f}s")
        print(f"Test answer: {test_answer}")

        # Compute the accuracy
        accuracy = calculate_accuracy(correct_answer, test_answer)

        # Report the results
        print("\n" + "=" * 50)
        print("Test results")
        print("=" * 50)
        print(f"Reference answer: {correct_answer}")
        print(f"Test answer: {test_answer}")
        print(f"Accuracy: {accuracy:.1%}")
        print(f"plain time: {plain_time:.2f}s")
        print(f"enc time: {enc_time:.2f}s")
        print(f"Performance ratio: {enc_time/plain_time:.1f}x")

        # Pass/fail check
        if accuracy >= 0.9:
            print("✅ Test passed! (accuracy ≥ 90%)")
        else:
            print("❌ Test failed! (accuracy < 90%)")

        # Show the incorrect answers
        if accuracy < 1.0:
            correct_set = set(correct_answer)
            test_set = set(test_answer)
            wrong_answers = test_set - correct_set
            missed_answers = correct_set - test_set
            print("\nError analysis:")
            if wrong_answers:
                print(f"Wrong answers: {sorted(wrong_answers)}")
            if missed_answers:
                print(f"Missing answers: {sorted(missed_answers)}")
    finally:
        # Clean up temporary files
        for file_path in [test_file_path, plain_result_path, enc_result_path]:
            if os.path.exists(file_path):
                os.unlink(file_path)

if __name__ == "__main__":
    main()