Implement plain KNN classifier and testing infrastructure

- Add plain KNN implementation with JSONL data processing (example input/output below)
- Create Docker deployment setup with python:3.13-slim base
- Add comprehensive OJ-style testing system with accuracy validation
- Update README with detailed scoring mechanism explanation
- Add run.sh script following competition manual requirements
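
The JSONL interface the new binary consumes and produces is sketched here. The field names (`query`, `data`, `answer`) come from the `Dataset`/`Prediction` structs in the diff below; the concrete values are made up for illustration. Each line of the `--dataset` file is one independent request:

```json
{"query": [5.1, 3.5, 1.4, 0.2], "data": [[5.0, 3.4, 1.5, 0.2], [6.3, 3.3, 6.0, 2.5], [4.9, 3.0, 1.4, 0.2]]}
```

For that line the classifier appends one line to the `--predictions` file, listing the 1-based indices of the (up to 10) nearest rows of `data`, closest first:

```json
{"answer": [1, 3, 2]}
```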

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-07-05 21:11:22 +08:00
parent 136583715e
commit d58adda9ab
8 changed files with 505 additions and 147 deletions


@@ -1,141 +1,71 @@
 use anyhow::Result;
-use rand::rng;
-use rand::seq::SliceRandom;
-use serde::{Deserialize, Deserializer};
-use std::collections::HashMap;
+use clap::Parser;
+use serde::{Deserialize, Serialize};
+use std::fs::File;
+use std::io::{BufRead, BufReader, Write};

-#[derive(Debug, Deserialize, Clone)]
-struct IrisData {
-    #[serde(rename = "SepalLengthCm", deserialize_with = "f64_to_u32")]
-    sepal_length: u32,
-    #[serde(rename = "SepalWidthCm", deserialize_with = "f64_to_u32")]
-    sepal_width: u32,
-    #[serde(rename = "PetalLengthCm", deserialize_with = "f64_to_u32")]
-    petal_length: u32,
-    #[serde(rename = "PetalWidthCm", deserialize_with = "f64_to_u32")]
-    petal_width: u32,
-    #[serde(rename = "Species")]
-    species: String,
+#[derive(Parser)]
+#[command(name = "hfe_knn")]
+#[command(about = "FHE-based KNN classifier")]
+struct Args {
+    #[arg(long)]
+    dataset: String,
+    #[arg(long)]
+    predictions: String,
 }

-struct IrisDataUnknown {
-    sepal_length: u32,
-    sepal_width: u32,
-    petal_length: u32,
-    petal_width: u32,
+#[derive(Deserialize)]
+struct Dataset {
+    query: Vec<f64>,
+    data: Vec<Vec<f64>>,
 }

-fn f64_to_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
-where
-    D: Deserializer<'de>,
-{
-    let f = f64::deserialize(deserializer)?;
-    Ok((f * 10.0).round() as u32) // scale by 10 to keep one decimal place of precision
+#[derive(Serialize)]
+struct Prediction {
+    answer: Vec<usize>,
 }

-fn distance(a: &IrisDataUnknown, b: &IrisData) -> u32 {
-    let diff_sl = (a.sepal_length as i32 - b.sepal_length as i32).unsigned_abs();
-    let diff_sw = (a.sepal_width as i32 - b.sepal_width as i32).unsigned_abs();
-    let diff_pl = (a.petal_length as i32 - b.petal_length as i32).unsigned_abs();
-    let diff_pw = (a.petal_width as i32 - b.petal_width as i32).unsigned_abs();
-    // Squared euclidean distance (avoiding sqrt for integer math)
-    diff_sl * diff_sl + diff_sw * diff_sw + diff_pl * diff_pl + diff_pw * diff_pw
+fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
+    a.iter()
+        .zip(b.iter())
+        .map(|(x, y)| (x - y).powi(2))
+        .sum::<f64>()
+        .sqrt()
 }

-fn knn_classify(query: &IrisDataUnknown, labeled_data: &[IrisData], k: usize) -> String {
-    let mut distances: Vec<(u32, &String)> = labeled_data
+fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
+    let mut distances: Vec<(f64, usize)> = data
         .iter()
-        .map(|data| (distance(query, data), &data.species))
+        .enumerate()
+        .map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
         .collect();

-    distances.sort_by_key(|&(dist, _)| dist);
+    distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

-    let mut species_count = HashMap::new();
-    for &(_, species) in distances.iter().take(k) {
-        *species_count.entry(species.clone()).or_insert(0) += 1;
-    }
-    species_count
-        .into_iter()
-        .max_by_key(|&(_, count)| count)
-        .map(|(species, _)| species)
-        .unwrap_or_else(|| "Unknown".to_string())
-}
-
-fn load_from_csv(filename: &str) -> Result<Vec<IrisData>> {
-    let mut rdr = csv::Reader::from_path(filename)?;
-    let mut data = Vec::new();
-    for result in rdr.deserialize() {
-        let record: IrisData = result?;
-        data.push(record);
-    }
-    Ok(data)
-}
-
-fn split_dataset(dataset: &[IrisData], test_ratio: f64) -> (Vec<IrisData>, Vec<IrisData>) {
-    let mut shuffled = dataset.to_vec();
-    shuffled.shuffle(&mut rng());
-    let test_size = (dataset.len() as f64 * test_ratio) as usize;
-    let train_size = dataset.len() - test_size;
-    let train_data = shuffled[..train_size].to_vec();
-    let test_data = shuffled[train_size..].to_vec();
-    (train_data, test_data)
+    distances.into_iter().take(k).map(|(_, idx)| idx).collect()
 }

 fn main() -> Result<()> {
-    let filename = "./dataset/Iris.csv";
-    let data = load_from_csv(filename)?;
-    println!("Loaded {} records from {}", data.len(), filename);
+    let args = Args::parse();

-    let rounds = 10;
-    let k = 5;
-    let test_ratio = 0.2;
-    let mut total_accuracy = 0.0;
+    let file = File::open(&args.dataset)?;
+    let reader = BufReader::new(file);

-    println!(
-        "Running {} rounds of KNN classification (k={}, test_ratio={})",
-        rounds, k, test_ratio
-    );
+    let mut results = Vec::new();

-    for round in 1..=rounds {
-        let (train_data, test_data) = split_dataset(&data, test_ratio);
-        let mut correct = 0;
+    for line in reader.lines() {
+        let line = line?;
+        let dataset: Dataset = serde_json::from_str(&line)?;

-        for test_item in &test_data {
-            let query = IrisDataUnknown {
-                sepal_length: test_item.sepal_length,
-                sepal_width: test_item.sepal_width,
-                petal_length: test_item.petal_length,
-                petal_width: test_item.petal_width,
-            };
-            let predicted = knn_classify(&query, &train_data, k);
-            if predicted == test_item.species {
-                correct += 1;
-            }
-        }
-        let accuracy = correct as f64 / test_data.len() as f64;
-        total_accuracy += accuracy;
-        println!(
-            "Round {}: Accuracy = {:.2}% ({}/{} correct)",
-            round,
-            accuracy * 100.0,
-            correct,
-            test_data.len()
-        );
+        let nearest = knn_classify(&dataset.query, &dataset.data, 10);
+        results.push(Prediction { answer: nearest });
     }

-    let avg_accuracy = total_accuracy / rounds as f64;
-    println!(
-        "\nAverage accuracy over {} rounds: {:.2}%",
-        rounds,
-        avg_accuracy * 100.0
-    );
+    let mut output_file = File::create(&args.predictions)?;
+    for result in results {
+        writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
+    }

     Ok(())
 }
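
As a sanity check on the new `knn_classify`, a minimal test sketch is shown below. It is hypothetical: it assumes the test module sits in the same file as the code above, and the sample points are invented.

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn returns_one_based_indices_sorted_by_distance() {
        // The query at the origin is closest to row 1, then row 3, then row 2.
        let data = vec![vec![1.0, 1.0], vec![10.0, 10.0], vec![2.0, 2.0]];
        let query = vec![0.0, 0.0];

        // k larger than the dataset simply returns every index, nearest first.
        assert_eq!(knn_classify(&query, &data, 10), vec![1, 3, 2]);

        // k = 1 keeps only the single nearest row (indices are 1-based).
        assert_eq!(knn_classify(&query, &data, 1), vec![1]);
    }
}
```

Invocation follows the clap flags declared in `Args`, e.g. something like `hfe_knn --dataset input.jsonl --predictions output.jsonl` (file names illustrative).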