Implement plain KNN classifier and testing infrastructure
- Add plain KNN implementation with JSONL data processing
- Create Docker deployment setup with python:3.13-slim base
- Add comprehensive OJ-style testing system with accuracy validation
- Update README with detailed scoring-mechanism explanation
- Add run.sh script following competition manual requirements

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
160
src/bin/plain.rs
160
src/bin/plain.rs
@@ -1,141 +1,71 @@
|
||||
use anyhow::Result;
|
||||
use rand::rng;
|
||||
use rand::seq::SliceRandom;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::collections::HashMap;
|
||||
use clap::Parser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
struct IrisData {
|
||||
#[serde(rename = "SepalLengthCm", deserialize_with = "f64_to_u32")]
|
||||
sepal_length: u32,
|
||||
#[serde(rename = "SepalWidthCm", deserialize_with = "f64_to_u32")]
|
||||
sepal_width: u32,
|
||||
#[serde(rename = "PetalLengthCm", deserialize_with = "f64_to_u32")]
|
||||
petal_length: u32,
|
||||
#[serde(rename = "PetalWidthCm", deserialize_with = "f64_to_u32")]
|
||||
petal_width: u32,
|
||||
#[serde(rename = "Species")]
|
||||
species: String,
|
||||
#[derive(Parser)]
|
||||
#[command(name = "hfe_knn")]
|
||||
#[command(about = "FHE-based KNN classifier")]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
dataset: String,
|
||||
#[arg(long)]
|
||||
predictions: String,
|
||||
}
|
||||
|
||||
struct IrisDataUnknown {
|
||||
sepal_length: u32,
|
||||
sepal_width: u32,
|
||||
petal_length: u32,
|
||||
petal_width: u32,
|
||||
#[derive(Deserialize)]
|
||||
struct Dataset {
|
||||
query: Vec<f64>,
|
||||
data: Vec<Vec<f64>>,
|
||||
}
|
||||
|
||||
fn f64_to_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let f = f64::deserialize(deserializer)?;
|
||||
Ok((f * 10.0).round() as u32) // 放大10倍保留一位小数精度
|
||||
#[derive(Serialize)]
|
||||
struct Prediction {
|
||||
answer: Vec<usize>,
|
||||
}
|
||||
|
||||
fn distance(a: &IrisDataUnknown, b: &IrisData) -> u32 {
|
||||
let diff_sl = (a.sepal_length as i32 - b.sepal_length as i32).unsigned_abs();
|
||||
let diff_sw = (a.sepal_width as i32 - b.sepal_width as i32).unsigned_abs();
|
||||
let diff_pl = (a.petal_length as i32 - b.petal_length as i32).unsigned_abs();
|
||||
let diff_pw = (a.petal_width as i32 - b.petal_width as i32).unsigned_abs();
|
||||
|
||||
// Squared euclidean distance (avoiding sqrt for integer math)
|
||||
diff_sl * diff_sl + diff_sw * diff_sw + diff_pl * diff_pl + diff_pw * diff_pw
|
||||
/// Euclidean (L2) distance between two feature vectors.
///
/// The slices are paired element-wise; if their lengths differ, the
/// surplus elements of the longer one are ignored (zip truncates).
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let mut sum_sq = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        let d = x - y;
        sum_sq += d * d; // same result as (x - y).powi(2)
    }
    sum_sq.sqrt()
}
|
||||
|
||||
fn knn_classify(query: &IrisDataUnknown, labeled_data: &[IrisData], k: usize) -> String {
|
||||
let mut distances: Vec<(u32, &String)> = labeled_data
|
||||
fn knn_classify(query: &[f64], data: &[Vec<f64>], k: usize) -> Vec<usize> {
|
||||
let mut distances: Vec<(f64, usize)> = data
|
||||
.iter()
|
||||
.map(|data| (distance(query, data), &data.species))
|
||||
.enumerate()
|
||||
.map(|(idx, point)| (euclidean_distance(query, point), idx + 1))
|
||||
.collect();
|
||||
|
||||
distances.sort_by_key(|&(dist, _)| dist);
|
||||
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
|
||||
|
||||
let mut species_count = HashMap::new();
|
||||
for &(_, species) in distances.iter().take(k) {
|
||||
*species_count.entry(species.clone()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
species_count
|
||||
.into_iter()
|
||||
.max_by_key(|&(_, count)| count)
|
||||
.map(|(species, _)| species)
|
||||
.unwrap_or_else(|| "Unknown".to_string())
|
||||
}
|
||||
|
||||
fn load_from_csv(filename: &str) -> Result<Vec<IrisData>> {
|
||||
let mut rdr = csv::Reader::from_path(filename)?;
|
||||
let mut data = Vec::new();
|
||||
for result in rdr.deserialize() {
|
||||
let record: IrisData = result?;
|
||||
data.push(record);
|
||||
}
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn split_dataset(dataset: &[IrisData], test_ratio: f64) -> (Vec<IrisData>, Vec<IrisData>) {
|
||||
let mut shuffled = dataset.to_vec();
|
||||
shuffled.shuffle(&mut rng());
|
||||
|
||||
let test_size = (dataset.len() as f64 * test_ratio) as usize;
|
||||
let train_size = dataset.len() - test_size;
|
||||
|
||||
let train_data = shuffled[..train_size].to_vec();
|
||||
let test_data = shuffled[train_size..].to_vec();
|
||||
|
||||
(train_data, test_data)
|
||||
distances.into_iter().take(k).map(|(_, idx)| idx).collect()
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let filename = "./dataset/Iris.csv";
|
||||
let data = load_from_csv(filename)?;
|
||||
println!("Loaded {} records from {}", data.len(), filename);
|
||||
let args = Args::parse();
|
||||
|
||||
let rounds = 10;
|
||||
let k = 5;
|
||||
let test_ratio = 0.2;
|
||||
let mut total_accuracy = 0.0;
|
||||
let file = File::open(&args.dataset)?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
println!(
|
||||
"Running {} rounds of KNN classification (k={}, test_ratio={})",
|
||||
rounds, k, test_ratio
|
||||
);
|
||||
let mut results = Vec::new();
|
||||
|
||||
for round in 1..=rounds {
|
||||
let (train_data, test_data) = split_dataset(&data, test_ratio);
|
||||
let mut correct = 0;
|
||||
for line in reader.lines() {
|
||||
let line = line?;
|
||||
let dataset: Dataset = serde_json::from_str(&line)?;
|
||||
|
||||
for test_item in &test_data {
|
||||
let query = IrisDataUnknown {
|
||||
sepal_length: test_item.sepal_length,
|
||||
sepal_width: test_item.sepal_width,
|
||||
petal_length: test_item.petal_length,
|
||||
petal_width: test_item.petal_width,
|
||||
};
|
||||
|
||||
let predicted = knn_classify(&query, &train_data, k);
|
||||
if predicted == test_item.species {
|
||||
correct += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let accuracy = correct as f64 / test_data.len() as f64;
|
||||
total_accuracy += accuracy;
|
||||
println!(
|
||||
"Round {}: Accuracy = {:.2}% ({}/{} correct)",
|
||||
round,
|
||||
accuracy * 100.0,
|
||||
correct,
|
||||
test_data.len()
|
||||
);
|
||||
let nearest = knn_classify(&dataset.query, &dataset.data, 10);
|
||||
results.push(Prediction { answer: nearest });
|
||||
}
|
||||
|
||||
let avg_accuracy = total_accuracy / rounds as f64;
|
||||
println!(
|
||||
"\nAverage accuracy over {} rounds: {:.2}%",
|
||||
rounds,
|
||||
avg_accuracy * 100.0
|
||||
);
|
||||
let mut output_file = File::create(&args.predictions)?;
|
||||
for result in results {
|
||||
writeln!(output_file, "{}", serde_json::to_string(&result)?)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user