diff --git a/Cargo.lock b/Cargo.lock index 7271db8..d807486 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,56 @@ dependencies = [ "serde", ] +[[package]] +name = "anstream" +version = "0.6.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys", +] + [[package]] name = "anyhow" version = "1.0.98" @@ -61,9 +111,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.18.1" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytemuck" @@ -87,6 +137,52 @@ dependencies = [ "inout", ] +[[package]] +name = "clap" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + [[package]] name = "cpufeatures" version = "0.2.17" @@ -131,27 +227,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "csv" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" -dependencies = [ - "csv-core", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "csv-core" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" -dependencies = [ - "memchr", -] - [[package]] name = "digest" version = "0.10.7" @@ -241,9 +316,10 @@ name = "hfe_knn" version = "0.1.0" dependencies = [ "anyhow", - "csv", + "clap", "rand", "serde", + "serde_json", "tfhe", ] @@ -256,6 +332,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.14.0" @@ -340,6 +422,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" + [[package]] name = "paste" version = "1.0.15" @@ -489,6 +577,18 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.140" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "sha3" version = "0.10.8" @@ -499,6 +599,12 @@ dependencies = [ "keccak", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "strum" version = "0.27.1" @@ -534,9 +640,9 @@ dependencies = [ [[package]] name = "tfhe" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cb9a09347f586b9e49336ddad761706dfbfc414c44a8892b5a38666a47736a9" +checksum = "11e74f14e2812ac6a2fe516f2c090431b614d589c24c6e3bca3fe763d4de928e" dependencies = [ "aligned-vec", "bincode", @@ -558,9 +664,9 @@ dependencies = [ [[package]] name = "tfhe-csprng" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b1a0b71a91c9e36e9e1e8e6a528a056086b01e193e173472ae59cfaac56fdad" +checksum = "06450766ae375cd305281c5d691a41d8877e40d3cc7510a33244775d62f60ad4" dependencies = [ "aes", "libc", @@ -628,6 +734,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "version_check" version = "0.9.5" @@ -706,6 +818,79 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "wit-bindgen-rt" version = "0.39.0" diff --git a/Cargo.toml b/Cargo.toml index 077ebb4..586dca0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,8 +4,9 @@ version = "0.1.0" edition = "2024" [dependencies] -tfhe = { version = "~1.2.0", features = ["integer"] } -csv = "1.3" -serde = { version = "1.0", features = ["derive"] } -anyhow = "1.0" +tfhe = { version = "1", features = ["integer"] } +serde = { version = "1", features = ["derive"] } +anyhow = "1" rand = "0.9" +clap = { version = "4.0", features = ["derive"] } +serde_json = "1" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..4308290 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,22 @@ +FROM python:3.13-slim + +# Create necessary directories +RUN mkdir -p /home/admin/predict \ + /home/admin/data \ + /home/admin/workspace/job/logs \ + /home/admin/workspace/job/output/predictions \ + /home/admin/workspace/job/input + +# Copy files to required locations according to manual +COPY run.sh /home/admin/predict/run.sh +RUN chmod 777 /home/admin/predict/run.sh + +# Copy the compiled binary as 'test' executable +COPY target/release/hfe_knn /home/admin/predict/test +RUN chmod +x /home/admin/predict/test + +# Copy training data +COPY dataset/train.jsonl /home/admin/data/data.jsonl + +# Set default command +CMD ["/home/admin/predict/run.sh"] diff --git a/README.md b/README.md index 72ba67b..034baa1 100644 --- a/README.md +++ b/README.md @@ -8,3 +8,25 @@ knn算法关键是使用一个距离函数来计算样本之间的距离,常 全同态加密计划使用TFHE-rs, 简化后流程无须交互,仅在单个程序内模拟即可。 评分详情见[参赛手册](./manual.md) + +## 评分机制 + +### 正确率计算 +- 算法需要返回10个最近邻向量的索引 +- 正确率 = 正确的索引数量 / 10 +- 例如:10个结果中有9个正确,正确率 = 90% + +### 评分规则 +- **门槛要求**:正确率必须≥90%(即10个结果中至少9个正确) +- **排名依据**:达到门槛后,按总耗时排名(越快越好) +- **淘汰机制**:正确率<90%直接得0分 + +### "正确"的定义 +- 比赛方有标准答案(真实的10个最近邻) +- 算法结果与标准答案比较 +- 顺序不重要,只要索引正确即可 + +### 数据格式 +- 训练数据:JSONL格式,每行包含一个query向量和data数组 +- 输出格式:`{"answer": [索引1, 索引2, ..., 索引10]}` +- 索引从1开始编号 diff --git a/dataset/answer.jsonl b/dataset/answer.jsonl new file mode 100644 index 0000000..f306b59 --- /dev/null +++ b/dataset/answer.jsonl @@ -0,0 +1 @@ +{"answer":[93,94,90,27,87,50,47,16,6,40]} diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..9874ee7 --- /dev/null +++ b/run.sh @@ -0,0 +1,22 @@ +#!/bin/bash +SCRIPT_DIR=$(dirname "$0") +PARENT_DIR="$(dirname "$SCRIPT_DIR")" + +# 根据运行环境选择文件路径 +if [ "$ALIPAY_APP_ENV" = "prod" ]; then + PREDICTIONS_RESULT_FILE="/home/admin/workspace/job/output/predictions/predictions.jsonl" + DATASET_FILE="/home/admin/workspace/job/input/test.jsonl" +else + PREDICTIONS_RESULT_FILE="${PARENT_DIR}/data/predictions.jsonl" + DATASET_FILE="${PARENT_DIR}/data/data.jsonl" +fi + +echo $PREDICTIONS_RESULT_FILE + +#以上内容**不可**修改! +#选手仅可修改下面的test为自己的实际比赛代码入口 +chmod +x "${SCRIPT_DIR}/test" + +"${SCRIPT_DIR}/test" \ + --dataset "$DATASET_FILE" \ + --predictions "$PREDICTIONS_RESULT_FILE" \ No newline at end of file diff --git a/src/bin/plain.rs b/src/bin/plain.rs index 63d3514..9bb62bd 100644 --- a/src/bin/plain.rs +++ b/src/bin/plain.rs @@ -1,141 +1,71 @@ use anyhow::Result; -use rand::rng; -use rand::seq::SliceRandom; -use serde::{Deserialize, Deserializer}; -use std::collections::HashMap; +use clap::Parser; +use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::{BufRead, BufReader, Write}; -#[derive(Debug, Deserialize, Clone)] -struct IrisData { - #[serde(rename = "SepalLengthCm", deserialize_with = "f64_to_u32")] - sepal_length: u32, - #[serde(rename = "SepalWidthCm", deserialize_with = "f64_to_u32")] - sepal_width: u32, - #[serde(rename = "PetalLengthCm", deserialize_with = "f64_to_u32")] - petal_length: u32, - #[serde(rename = "PetalWidthCm", deserialize_with = "f64_to_u32")] - petal_width: u32, - #[serde(rename = "Species")] - species: String, +#[derive(Parser)] +#[command(name = "hfe_knn")] +#[command(about = "FHE-based KNN classifier")] +struct Args { + #[arg(long)] + dataset: String, + #[arg(long)] + predictions: String, } -struct IrisDataUnknown { - sepal_length: u32, - sepal_width: u32, - petal_length: u32, - petal_width: u32, +#[derive(Deserialize)] +struct Dataset { + query: Vec, + data: Vec>, } -fn f64_to_u32<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - let f = f64::deserialize(deserializer)?; - Ok((f * 10.0).round() as u32) // 放大10倍保留一位小数精度 +#[derive(Serialize)] +struct Prediction { + answer: Vec, } -fn distance(a: &IrisDataUnknown, b: &IrisData) -> u32 { - let diff_sl = (a.sepal_length as i32 - b.sepal_length as i32).unsigned_abs(); - let diff_sw = (a.sepal_width as i32 - b.sepal_width as i32).unsigned_abs(); - let diff_pl = (a.petal_length as i32 - b.petal_length as i32).unsigned_abs(); - let diff_pw = (a.petal_width as i32 - b.petal_width as i32).unsigned_abs(); - - // Squared euclidean distance (avoiding sqrt for integer math) - diff_sl * diff_sl + diff_sw * diff_sw + diff_pl * diff_pl + diff_pw * diff_pw +fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 { + a.iter() + .zip(b.iter()) + .map(|(x, y)| (x - y).powi(2)) + .sum::() + .sqrt() } -fn knn_classify(query: &IrisDataUnknown, labeled_data: &[IrisData], k: usize) -> String { - let mut distances: Vec<(u32, &String)> = labeled_data +fn knn_classify(query: &[f64], data: &[Vec], k: usize) -> Vec { + let mut distances: Vec<(f64, usize)> = data .iter() - .map(|data| (distance(query, data), &data.species)) + .enumerate() + .map(|(idx, point)| (euclidean_distance(query, point), idx + 1)) .collect(); - distances.sort_by_key(|&(dist, _)| dist); + distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); - let mut species_count = HashMap::new(); - for &(_, species) in distances.iter().take(k) { - *species_count.entry(species.clone()).or_insert(0) += 1; - } - - species_count - .into_iter() - .max_by_key(|&(_, count)| count) - .map(|(species, _)| species) - .unwrap_or_else(|| "Unknown".to_string()) -} - -fn load_from_csv(filename: &str) -> Result> { - let mut rdr = csv::Reader::from_path(filename)?; - let mut data = Vec::new(); - for result in rdr.deserialize() { - let record: IrisData = result?; - data.push(record); - } - Ok(data) -} - -fn split_dataset(dataset: &[IrisData], test_ratio: f64) -> (Vec, Vec) { - let mut shuffled = dataset.to_vec(); - shuffled.shuffle(&mut rng()); - - let test_size = (dataset.len() as f64 * test_ratio) as usize; - let train_size = dataset.len() - test_size; - - let train_data = shuffled[..train_size].to_vec(); - let test_data = shuffled[train_size..].to_vec(); - - (train_data, test_data) + distances.into_iter().take(k).map(|(_, idx)| idx).collect() } fn main() -> Result<()> { - let filename = "./dataset/Iris.csv"; - let data = load_from_csv(filename)?; - println!("Loaded {} records from {}", data.len(), filename); + let args = Args::parse(); - let rounds = 10; - let k = 5; - let test_ratio = 0.2; - let mut total_accuracy = 0.0; + let file = File::open(&args.dataset)?; + let reader = BufReader::new(file); - println!( - "Running {} rounds of KNN classification (k={}, test_ratio={})", - rounds, k, test_ratio - ); + let mut results = Vec::new(); - for round in 1..=rounds { - let (train_data, test_data) = split_dataset(&data, test_ratio); - let mut correct = 0; + for line in reader.lines() { + let line = line?; + let dataset: Dataset = serde_json::from_str(&line)?; - for test_item in &test_data { - let query = IrisDataUnknown { - sepal_length: test_item.sepal_length, - sepal_width: test_item.sepal_width, - petal_length: test_item.petal_length, - petal_width: test_item.petal_width, - }; - - let predicted = knn_classify(&query, &train_data, k); - if predicted == test_item.species { - correct += 1; - } - } - - let accuracy = correct as f64 / test_data.len() as f64; - total_accuracy += accuracy; - println!( - "Round {}: Accuracy = {:.2}% ({}/{} correct)", - round, - accuracy * 100.0, - correct, - test_data.len() - ); + let nearest = knn_classify(&dataset.query, &dataset.data, 10); + results.push(Prediction { answer: nearest }); } - let avg_accuracy = total_accuracy / rounds as f64; - println!( - "\nAverage accuracy over {} rounds: {:.2}%", - rounds, - avg_accuracy * 100.0 - ); + let mut output_file = File::create(&args.predictions)?; + for result in results { + writeln!(output_file, "{}", serde_json::to_string(&result)?)?; + } Ok(()) } + diff --git a/test_oj.py b/test_oj.py new file mode 100755 index 0000000..35a3970 --- /dev/null +++ b/test_oj.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +""" +模拟OJ评分系统的测试脚本 +用于测试FHE KNN算法的正确性 +""" + +import json +import random +import subprocess +import tempfile +import os +import time +from typing import List, Tuple + +def generate_test_data(num_data_points: int = 100, dimensions: int = 10) -> dict: + """生成测试数据""" + print(f"生成测试数据: {num_data_points}个数据点, {dimensions}维向量") + + # 生成随机数据点 + data = [] + for _ in range(num_data_points): + point = [round(random.uniform(-10, 10), 2) for _ in range(dimensions)] + data.append(point) + + # 生成查询向量 + query = [round(random.uniform(-10, 10), 2) for _ in range(dimensions)] + + return { + "query": query, + "data": data + } + +def write_test_data(test_data: dict, file_path: str): + """将测试数据写入JSONL文件""" + with open(file_path, 'w') as f: + json.dump(test_data, f) + +def run_program(program_name: str, dataset_path: str, predictions_path: str) -> Tuple[List[int], float]: + """运行KNN程序并返回结果和耗时""" + print(f"运行程序: {program_name}") + + try: + # 先编译程序(不计时) + print(f"编译{program_name}程序...") + compile_cmd = ["cargo", "build", "--release", "--bin", program_name] + compile_result = subprocess.run(compile_cmd, capture_output=True, text=True, cwd="/home/sangge/code/hfe_knn") + + if compile_result.returncode != 0: + print(f"编译失败: {compile_result.stderr}") + return [], 0.0 + + # 构建运行命令 + if program_name == "plain": + cmd = ["./target/release/plain", "--dataset", dataset_path, "--predictions", predictions_path] + elif program_name == "enc": + cmd = ["./target/release/enc", "--dataset", dataset_path, "--predictions", predictions_path] + else: + raise ValueError(f"未知程序: {program_name}") + + # 计时运行程序 + print(f"运行{program_name}程序...") + start_time = time.time() + result = subprocess.run(cmd, capture_output=True, text=True, cwd="/home/sangge/code/hfe_knn") + end_time = time.time() + elapsed_time = end_time - start_time + + if result.returncode != 0: + print(f"程序运行失败: {result.stderr}") + return [], elapsed_time + + # 读取结果 + with open(predictions_path, 'r') as f: + line = f.readline().strip() + prediction = json.loads(line) + return prediction["answer"], elapsed_time + + except Exception as e: + end_time = time.time() + elapsed_time = end_time - start_time + print(f"运行程序时出错: {e}") + return [], elapsed_time + +def calculate_accuracy(correct_answer: List[int], test_answer: List[int]) -> float: + """计算正确率""" + if not test_answer: + return 0.0 + + correct_count = len(set(correct_answer) & set(test_answer)) + total_count = len(correct_answer) + + return correct_count / total_count + +def main(): + print("=" * 50) + print("FHE KNN 测试系统") + print("=" * 50) + + # 生成测试数据 + test_data = generate_test_data(100, 10) + + # 创建临时文件 + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as test_file: + test_file_path = test_file.name + write_test_data(test_data, test_file_path) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as plain_result: + plain_result_path = plain_result.name + + with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as enc_result: + enc_result_path = enc_result.name + + try: + # 运行plain程序获取标准答案 + print("\\n1. 运行plain程序计算标准答案...") + correct_answer, plain_time = run_program("plain", test_file_path, plain_result_path) + + if not correct_answer: + print("❌ plain程序运行失败") + return + + print(f"✅ plain程序完成, 耗时: {plain_time:.2f}秒") + print(f"标准答案: {correct_answer}") + + # 运行enc程序 + print("\\n2. 运行enc程序进行密文计算...") + test_answer, enc_time = run_program("enc", test_file_path, enc_result_path) + + if not test_answer: + print("❌ enc程序运行失败") + return + + print(f"✅ enc程序完成, 耗时: {enc_time:.2f}秒") + print(f"测试答案: {test_answer}") + + # 计算正确率 + accuracy = calculate_accuracy(correct_answer, test_answer) + + # 输出结果 + print("\\n" + "=" * 50) + print("测试结果") + print("=" * 50) + print(f"标准答案: {correct_answer}") + print(f"测试答案: {test_answer}") + print(f"正确率: {accuracy:.1%}") + print(f"plain耗时: {plain_time:.2f}秒") + print(f"enc耗时: {enc_time:.2f}秒") + print(f"性能比: {enc_time/plain_time:.1f}x") + + # 判断是否通过 + if accuracy >= 0.9: + print("✅ 测试通过! (正确率≥90%)") + else: + print("❌ 测试失败! (正确率<90%)") + + # 显示错误的答案 + if accuracy < 1.0: + correct_set = set(correct_answer) + test_set = set(test_answer) + wrong_answers = test_set - correct_set + missed_answers = correct_set - test_set + + print("\\n错误分析:") + if wrong_answers: + print(f"错误答案: {sorted(wrong_answers)}") + if missed_answers: + print(f"遗漏答案: {sorted(missed_answers)}") + + finally: + # 清理临时文件 + for file_path in [test_file_path, plain_result_path, enc_result_path]: + if os.path.exists(file_path): + os.unlink(file_path) + +if __name__ == "__main__": + main() \ No newline at end of file