Implement plain KNN classifier and testing infrastructure
- Add plain KNN implementation with JSONL data processing - Create Docker deployment setup with python:3.13-slim base - Add comprehensive OJ-style testing system with accuracy validation - Update README with detailed scoring mechanism explanation - Add run.sh script following competition manual requirements 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
136583715e
commit
d58adda9ab
241
Cargo.lock
generated
241
Cargo.lock
generated
@ -23,6 +23,56 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"anstyle-parse",
|
||||
"anstyle-query",
|
||||
"anstyle-wincon",
|
||||
"colorchoice",
|
||||
"is_terminal_polyfill",
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd"
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-parse"
|
||||
version = "0.2.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
|
||||
dependencies = [
|
||||
"utf8parse",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-query"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9"
|
||||
dependencies = [
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-wincon"
|
||||
version = "3.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"once_cell_polyfill",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.98"
|
||||
@ -61,9 +111,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.18.1"
|
||||
version = "3.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee"
|
||||
checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
@ -87,6 +137,52 @@ dependencies = [
|
||||
"inout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"clap_lex",
|
||||
"strsim",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.5.40"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.7.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675"
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
version = "0.2.17"
|
||||
@ -131,27 +227,6 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "csv"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf"
|
||||
dependencies = [
|
||||
"csv-core",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "csv-core"
|
||||
version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
@ -241,9 +316,10 @@ name = "hfe_knn"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"csv",
|
||||
"clap",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tfhe",
|
||||
]
|
||||
|
||||
@ -256,6 +332,12 @@ dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "is_terminal_polyfill"
|
||||
version = "1.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.14.0"
|
||||
@ -340,6 +422,12 @@ version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "once_cell_polyfill"
|
||||
version = "1.70.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad"
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.15"
|
||||
@ -489,6 +577,18 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.140"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha3"
|
||||
version = "0.10.8"
|
||||
@ -499,6 +599,12 @@ dependencies = [
|
||||
"keccak",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
||||
|
||||
[[package]]
|
||||
name = "strum"
|
||||
version = "0.27.1"
|
||||
@ -534,9 +640,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tfhe"
|
||||
version = "1.2.0"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cb9a09347f586b9e49336ddad761706dfbfc414c44a8892b5a38666a47736a9"
|
||||
checksum = "11e74f14e2812ac6a2fe516f2c090431b614d589c24c6e3bca3fe763d4de928e"
|
||||
dependencies = [
|
||||
"aligned-vec",
|
||||
"bincode",
|
||||
@ -558,9 +664,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tfhe-csprng"
|
||||
version = "0.5.0"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b1a0b71a91c9e36e9e1e8e6a528a056086b01e193e173472ae59cfaac56fdad"
|
||||
checksum = "06450766ae375cd305281c5d691a41d8877e40d3cc7510a33244775d62f60ad4"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"libc",
|
||||
@ -628,6 +734,12 @@ version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||
|
||||
[[package]]
|
||||
name = "utf8parse"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.5"
|
||||
@ -706,6 +818,79 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.59.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_gnullvm",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rt"
|
||||
version = "0.39.0"
|
||||
|
@ -4,8 +4,9 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
tfhe = { version = "~1.2.0", features = ["integer"] }
|
||||
csv = "1.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
anyhow = "1.0"
|
||||
tfhe = { version = "1", features = ["integer"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
anyhow = "1"
|
||||
rand = "0.9"
|
||||
clap = { version = "4.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
22
Dockerfile
Normal file
22
Dockerfile
Normal file
@ -0,0 +1,22 @@
|
||||
FROM python:3.13-slim
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /home/admin/predict \
|
||||
/home/admin/data \
|
||||
/home/admin/workspace/job/logs \
|
||||
/home/admin/workspace/job/output/predictions \
|
||||
/home/admin/workspace/job/input
|
||||
|
||||
# Copy files to required locations according to manual
|
||||
COPY run.sh /home/admin/predict/run.sh
|
||||
RUN chmod 777 /home/admin/predict/run.sh
|
||||
|
||||
# Copy the compiled binary as 'test' executable
|
||||
COPY target/release/hfe_knn /home/admin/predict/test
|
||||
RUN chmod +x /home/admin/predict/test
|
||||
|
||||
# Copy training data
|
||||
COPY dataset/train.jsonl /home/admin/data/data.jsonl
|
||||
|
||||
# Set default command
|
||||
CMD ["/home/admin/predict/run.sh"]
|
22
README.md
22
README.md
@ -8,3 +8,25 @@ knn算法关键是使用一个距离函数来计算样本之间的距离,常
|
||||
|
||||
全同态加密计划使用TFHE-rs, 简化后流程无须交互,仅在单个程序内模拟即可。
|
||||
评分详情见[参赛手册](./manual.md)
|
||||
|
||||
## 评分机制
|
||||
|
||||
### 正确率计算
|
||||
- 算法需要返回10个最近邻向量的索引
|
||||
- 正确率 = 正确的索引数量 / 10
|
||||
- 例如:10个结果中有9个正确,正确率 = 90%
|
||||
|
||||
### 评分规则
|
||||
- **门槛要求**:正确率必须≥90%(即10个结果中至少9个正确)
|
||||
- **排名依据**:达到门槛后,按总耗时排名(越快越好)
|
||||
- **淘汰机制**:正确率<90%直接得0分
|
||||
|
||||
### "正确"的定义
|
||||
- 比赛方有标准答案(真实的10个最近邻)
|
||||
- 算法结果与标准答案比较
|
||||
- 顺序不重要,只要索引正确即可
|
||||
|
||||
### 数据格式
|
||||
- 训练数据:JSONL格式,每行包含一个query向量和data数组
|
||||
- 输出格式:`{"answer": [索引1, 索引2, ..., 索引10]}`
|
||||
- 索引从1开始编号
|
||||
|
1
dataset/answer.jsonl
Normal file
1
dataset/answer.jsonl
Normal file
@ -0,0 +1 @@
|
||||
{"answer":[93,94,90,27,87,50,47,16,6,40]}
|
22
run.sh
Normal file
22
run.sh
Normal file
@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
SCRIPT_DIR=$(dirname "$0")
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
# 根据运行环境选择文件路径
|
||||
if [ "$ALIPAY_APP_ENV" = "prod" ]; then
|
||||
PREDICTIONS_RESULT_FILE="/home/admin/workspace/job/output/predictions/predictions.jsonl"
|
||||
DATASET_FILE="/home/admin/workspace/job/input/test.jsonl"
|
||||
else
|
||||
PREDICTIONS_RESULT_FILE="${PARENT_DIR}/data/predictions.jsonl"
|
||||
DATASET_FILE="${PARENT_DIR}/data/data.jsonl"
|
||||
fi
|
||||
|
||||
echo $PREDICTIONS_RESULT_FILE
|
||||
|
||||
#以上内容**不可**修改!
|
||||
#选手仅可修改下面的test为自己的实际比赛代码入口
|
||||
chmod +x "${SCRIPT_DIR}/test"
|
||||
|
||||
"${SCRIPT_DIR}/test" \
|
||||
--dataset "$DATASET_FILE" \
|
||||
--predictions "$PREDICTIONS_RESULT_FILE"
|
160
src/bin/plain.rs
160
src/bin/plain.rs
@ -1,141 +1,71 @@
|
||||
use anyhow::Result;
|
||||
use rand::rng;
|
||||
use rand::seq::SliceRandom;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::collections::HashMap;
|
||||
use clap::Parser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
struct IrisData {
|
||||
#[serde(rename = "SepalLengthCm", deserialize_with = "f64_to_u32")]
|
||||
sepal_length: u32,
|
||||
#[serde(rename = "SepalWidthCm", deserialize_with = "f64_to_u32")]
|
||||
sepal_width: u32,
|
||||
#[serde(rename = "PetalLengthCm", deserialize_with = "f64_to_u32")]
|
||||
petal_length: u32,
|
||||
#[serde(rename = "PetalWidthCm", deserialize_with = "f64_to_u32")]
|
||||
petal_width: u32,
|
||||
#[serde(rename = "Species")]
|
||||
species: String,
|
||||
#[derive(Parser)]
|
||||
#[command(name = "hfe_knn")]
|
||||
#[command(about = "FHE-based KNN classifier")]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
dataset: String,
|
||||
#[arg(long)]
|
||||
predictions: String,
|
||||
}
|
||||
|
||||
struct IrisDataUnknown {
|
||||
sepal_length: u32,
|
||||
sepal_width: u32,
|
||||
petal_length: u32,
|
||||
petal_width: u32,
|
||||
#[derive(Deserialize)]
|
||||
struct Dataset {
|
||||
query: Vec<f64>,
|
||||
data: Vec<Vec<f64>>,
|
||||
}
|
||||
|
||||
fn f64_to_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let f = f64::deserialize(deserializer)?;
|
||||
Ok((f * 10.0).round() as u32) // 放大10倍保留一位小数精度
|
||||
#[derive(Serialize)]
|
||||
struct Prediction {
|
||||
answer: Vec<usize>,
|
||||
}
|
||||
|
||||
fn distance(a: &IrisDataUnknown, b: &IrisData) -> u32 {
|
||||
let diff_sl = (a.sepal_length as i32 - b.sepal_length as i32).unsigned_abs();
|
||||
let diff_sw = (a.sepal_width as i32 - b.sepal_width as i32).unsigned_abs();
|
||||
let diff_pl = (a.petal_length as i32 - b.petal_length as i32).unsigned_abs();
|
||||
let diff_pw = (a.petal_width as i32 - b.petal_width as i32).unsigned_abs();
|
||||
|
||||
// Squared euclidean distance (avoiding sqrt for integer math)
|
||||
diff_sl * diff_sl + diff_sw * diff_sw + diff_pl * diff_pl + diff_pw * diff_pw
|
||||
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| (x - y).powi(2))
|
||||
.sum::<f64>()
|
||||
.sqrt()
|
||||
}
|
||||
|
||||
fn knn_classify(query: &IrisDataUnknown, labeled_data: &[IrisData], k: usize) -> String {
|
||||
let mut distances: Vec<(u32, &String)> = labeled_data
|
||||
fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
|
||||
let mut distances: Vec<(f64, usize)> = data
|
||||
.iter()
|
||||
.map(|data| (distance(query, data), &data.species))
|
||||
.enumerate()
|
||||
.map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
|
||||
.collect();
|
||||
|
||||
distances.sort_by_key(|&(dist, _)| dist);
|
||||
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
|
||||
let mut species_count = HashMap::new();
|
||||
for &(_, species) in distances.iter().take(k) {
|
||||
*species_count.entry(species.clone()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
species_count
|
||||
.into_iter()
|
||||
.max_by_key(|&(_, count)| count)
|
||||
.map(|(species, _)| species)
|
||||
.unwrap_or_else(|| "Unknown".to_string())
|
||||
}
|
||||
|
||||
fn load_from_csv(filename: &str) -> Result<Vec<IrisData>> {
|
||||
let mut rdr = csv::Reader::from_path(filename)?;
|
||||
let mut data = Vec::new();
|
||||
for result in rdr.deserialize() {
|
||||
let record: IrisData = result?;
|
||||
data.push(record);
|
||||
}
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn split_dataset(dataset: &[IrisData], test_ratio: f64) -> (Vec<IrisData>, Vec<IrisData>) {
|
||||
let mut shuffled = dataset.to_vec();
|
||||
shuffled.shuffle(&mut rng());
|
||||
|
||||
let test_size = (dataset.len() as f64 * test_ratio) as usize;
|
||||
let train_size = dataset.len() - test_size;
|
||||
|
||||
let train_data = shuffled[..train_size].to_vec();
|
||||
let test_data = shuffled[train_size..].to_vec();
|
||||
|
||||
(train_data, test_data)
|
||||
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let filename = "./dataset/Iris.csv";
|
||||
let data = load_from_csv(filename)?;
|
||||
println!("Loaded {} records from {}", data.len(), filename);
|
||||
let args = Args::parse();
|
||||
|
||||
let rounds = 10;
|
||||
let k = 5;
|
||||
let test_ratio = 0.2;
|
||||
let mut total_accuracy = 0.0;
|
||||
let file = File::open(&args.dataset)?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
println!(
|
||||
"Running {} rounds of KNN classification (k={}, test_ratio={})",
|
||||
rounds, k, test_ratio
|
||||
);
|
||||
let mut results = Vec::new();
|
||||
|
||||
for round in 1..=rounds {
|
||||
let (train_data, test_data) = split_dataset(&data, test_ratio);
|
||||
let mut correct = 0;
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
|
||||
for test_item in &test_data {
|
||||
let query = IrisDataUnknown {
|
||||
sepal_length: test_item.sepal_length,
|
||||
sepal_width: test_item.sepal_width,
|
||||
petal_length: test_item.petal_length,
|
||||
petal_width: test_item.petal_width,
|
||||
};
|
||||
|
||||
let predicted = knn_classify(&query, &train_data, k);
|
||||
if predicted == test_item.species {
|
||||
correct += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let accuracy = correct as f64 / test_data.len() as f64;
|
||||
total_accuracy += accuracy;
|
||||
println!(
|
||||
"Round {}: Accuracy = {:.2}% ({}/{} correct)",
|
||||
round,
|
||||
accuracy * 100.0,
|
||||
correct,
|
||||
test_data.len()
|
||||
);
|
||||
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||
results.push(Prediction { answer: nearest });
|
||||
}
|
||||
|
||||
let avg_accuracy = total_accuracy / rounds as f64;
|
||||
println!(
|
||||
"\nAverage accuracy over {} rounds: {:.2}%",
|
||||
rounds,
|
||||
avg_accuracy * 100.0
|
||||
);
|
||||
let mut output_file = File::create(&args.predictions)?;
|
||||
for result in results {
|
||||
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
175
test_oj.py
Executable file
175
test_oj.py
Executable file
@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
模拟OJ评分系统的测试脚本
|
||||
用于测试FHE KNN算法的正确性
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
def generate_test_data(num_data_points: int = 100, dimensions: int = 10) -> dict:
    """Build one random KNN test case.

    Args:
        num_data_points: Number of candidate vectors to generate.
        dimensions: Length of every generated vector (query included).

    Returns:
        A dict with two keys: ``"query"`` (one vector) and ``"data"``
        (a list of ``num_data_points`` vectors). Every coordinate is drawn
        uniformly from [-10, 10] and rounded to two decimals.
    """
    print(f"生成测试数据: {num_data_points}个数据点, {dimensions}维向量")

    def _random_vector() -> list:
        # One vector of `dimensions` rounded uniform samples.
        return [round(random.uniform(-10, 10), 2) for _ in range(dimensions)]

    # The data points are drawn first and the query last, so a seeded RNG
    # yields the same case regardless of how this function is written.
    points = [_random_vector() for _ in range(num_data_points)]
    query_vector = _random_vector()

    return {"query": query_vector, "data": points}
|
||||
|
||||
def write_test_data(test_data: dict, file_path: str):
    """Write a single test case to ``file_path`` as one JSONL record.

    Args:
        test_data: JSON-serialisable test case (as built by the generator).
        file_path: Destination path; an existing file is overwritten.

    The record is terminated with a newline so the file is well-formed
    JSON Lines and can be safely appended to or read line by line.
    """
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(test_data, f)
        # JSONL records are newline-terminated.
        f.write("\n")
|
||||
|
||||
def run_program(program_name: str, dataset_path: str, predictions_path: str,
                project_dir: str = "/home/sangge/code/hfe_knn") -> Tuple[List[int], float]:
    """Build and run one KNN binary, returning its answer and wall time.

    Args:
        program_name: Cargo binary to run; only "plain" and "enc" are accepted.
        dataset_path: JSONL file passed to the binary via ``--dataset``.
        predictions_path: File the binary writes via ``--predictions``.
        project_dir: Cargo project root used as the working directory for
            both the build and the run. Defaults to the original hard-coded
            location so existing callers are unaffected.

    Returns:
        ``(answer, elapsed_seconds)``. ``answer`` is the ``"answer"`` list
        from the first line of the predictions file, or ``[]`` on any
        failure. The elapsed time covers only the program run, not the
        build, and is ``0.0`` when the failure happened before the run
        started.
    """
    print(f"运行程序: {program_name}")

    # Set before the try block so the exception handler can always tell
    # whether timing had started (the build or validation may fail first).
    start_time = None

    try:
        # Build first; compilation is deliberately excluded from the timing.
        print(f"编译{program_name}程序...")
        compile_cmd = ["cargo", "build", "--release", "--bin", program_name]
        compile_result = subprocess.run(compile_cmd, capture_output=True, text=True, cwd=project_dir)

        if compile_result.returncode != 0:
            print(f"编译失败: {compile_result.stderr}")
            return [], 0.0

        # Assemble the run command for the known binaries.
        if program_name not in ("plain", "enc"):
            raise ValueError(f"未知程序: {program_name}")
        cmd = [f"./target/release/{program_name}", "--dataset", dataset_path, "--predictions", predictions_path]

        # Timed run; perf_counter is monotonic, so the interval cannot be
        # skewed by system clock adjustments.
        print(f"运行{program_name}程序...")
        start_time = time.perf_counter()
        result = subprocess.run(cmd, capture_output=True, text=True, cwd=project_dir)
        elapsed_time = time.perf_counter() - start_time

        if result.returncode != 0:
            print(f"程序运行失败: {result.stderr}")
            return [], elapsed_time

        # Only the first prediction line is read: the harness feeds a
        # single query per dataset file.
        with open(predictions_path, 'r') as f:
            line = f.readline().strip()
        prediction = json.loads(line)
        return prediction["answer"], elapsed_time

    except Exception as e:
        # Top-level boundary for the harness: report and return a failure
        # result instead of crashing the whole test run.
        elapsed_time = 0.0 if start_time is None else time.perf_counter() - start_time
        print(f"运行程序时出错: {e}")
        return [], elapsed_time
|
||||
|
||||
def calculate_accuracy(correct_answer: List[int], test_answer: List[int]) -> float:
    """Return the fraction of reference indices present in the tested answer.

    Order is ignored: the score is the size of the set intersection divided
    by the number of reference indices.

    Args:
        correct_answer: Reference nearest-neighbour indices.
        test_answer: Indices produced by the program under test.

    Returns:
        A value in [0.0, 1.0]; ``0.0`` when either list is empty (an empty
        reference would otherwise cause a division by zero).
    """
    if not correct_answer or not test_answer:
        return 0.0

    matched = len(set(correct_answer) & set(test_answer))
    return matched / len(correct_answer)
|
||||
|
||||
def main():
    """Run the local OJ-style check: plain result vs. encrypted result.

    Generates one random test case, runs the ``plain`` binary to obtain the
    reference answer, runs the ``enc`` binary on the same input, then prints
    both answers, the accuracy, the timings and a pass/fail verdict using
    the 90% threshold. Temporary files are always removed at the end.
    """
    print("=" * 50)
    print("FHE KNN 测试系统")
    print("=" * 50)

    # Generate the test case.
    test_data = generate_test_data(100, 10)

    # Reserve three temp paths: the dataset plus one predictions file per
    # program. delete=False because the paths are handed to subprocesses.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as test_file:
        test_file_path = test_file.name

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as plain_result:
        plain_result_path = plain_result.name

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as enc_result:
        enc_result_path = enc_result.name

    try:
        # Written after the temp handle is closed (so the file is not open
        # twice) and inside the try so cleanup runs even if writing fails.
        write_test_data(test_data, test_file_path)

        # Run the plain program to obtain the reference answer.
        print("\n1. 运行plain程序计算标准答案...")
        correct_answer, plain_time = run_program("plain", test_file_path, plain_result_path)

        if not correct_answer:
            print("❌ plain程序运行失败")
            return

        print(f"✅ plain程序完成, 耗时: {plain_time:.2f}秒")
        print(f"标准答案: {correct_answer}")

        # Run the encrypted program on the same input.
        print("\n2. 运行enc程序进行密文计算...")
        test_answer, enc_time = run_program("enc", test_file_path, enc_result_path)

        if not test_answer:
            print("❌ enc程序运行失败")
            return

        print(f"✅ enc程序完成, 耗时: {enc_time:.2f}秒")
        print(f"测试答案: {test_answer}")

        # Score the encrypted answer against the reference.
        accuracy = calculate_accuracy(correct_answer, test_answer)

        # Report.
        print("\n" + "=" * 50)
        print("测试结果")
        print("=" * 50)
        print(f"标准答案: {correct_answer}")
        print(f"测试答案: {test_answer}")
        print(f"正确率: {accuracy:.1%}")
        print(f"plain耗时: {plain_time:.2f}秒")
        print(f"enc耗时: {enc_time:.2f}秒")
        # Guard the ratio: a zero plain time would divide by zero.
        if plain_time > 0:
            print(f"性能比: {enc_time/plain_time:.1f}x")
        else:
            print("性能比: N/A")

        # Pass/fail verdict at the 90% threshold.
        if accuracy >= 0.9:
            print("✅ 测试通过! (正确率≥90%)")
        else:
            print("❌ 测试失败! (正确率<90%)")

        # Show which indices were wrong or missing.
        if accuracy < 1.0:
            correct_set = set(correct_answer)
            test_set = set(test_answer)
            wrong_answers = test_set - correct_set
            missed_answers = correct_set - test_set

            print("\n错误分析:")
            if wrong_answers:
                print(f"错误答案: {sorted(wrong_answers)}")
            if missed_answers:
                print(f"遗漏答案: {sorted(missed_answers)}")

    finally:
        # Remove the temporary files.
        for file_path in [test_file_path, plain_result_path, enc_result_path]:
            if os.path.exists(file_path):
                os.unlink(file_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user