Initial Rust project setup with dependencies and dataset

- Add Cargo.toml with TFHE, CSV, and Serde dependencies
- Add .gitignore for Rust target directory
- Include Iris dataset for machine learning experiments
- Add plain KNN implementation binary
- Update LICENSE to MIT and improve README

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-06-29 18:11:40 +08:00
parent e73cc21d01
commit 4c155d8bf4
7 changed files with 1051 additions and 229 deletions

141
src/bin/plain.rs Normal file
View File

@@ -0,0 +1,141 @@
use anyhow::Result;
use rand::rng;
use rand::seq::SliceRandom;
use serde::{Deserialize, Deserializer};
use std::collections::HashMap;
/// One labelled row of the Iris CSV file.
///
/// Each of the four measurement columns is passed through `f64_to_u32`
/// while deserializing, so the fields hold the integer that converter
/// returns for the column value. `species` is the text of the `Species`
/// column, taken as-is.
#[derive(Clone, Debug, Deserialize)]
struct IrisData {
    /// Converted value of the `SepalLengthCm` column.
    #[serde(deserialize_with = "f64_to_u32", rename = "SepalLengthCm")]
    sepal_length: u32,
    /// Converted value of the `SepalWidthCm` column.
    #[serde(deserialize_with = "f64_to_u32", rename = "SepalWidthCm")]
    sepal_width: u32,
    /// Converted value of the `PetalLengthCm` column.
    #[serde(deserialize_with = "f64_to_u32", rename = "PetalLengthCm")]
    petal_length: u32,
    /// Converted value of the `PetalWidthCm` column.
    #[serde(deserialize_with = "f64_to_u32", rename = "PetalWidthCm")]
    petal_width: u32,
    /// Class label read from the `Species` column.
    #[serde(rename = "Species")]
    species: String,
}
/// An unlabelled Iris sample: the four measurements of `IrisData` without
/// a species. It is the query handed to the classifier.
///
/// Every field is a plain `u32`, so the type derives `Copy` and the
/// comparison/debug traits; values can be duplicated, compared in tests and
/// printed with `{:?}` without any hand-written code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct IrisDataUnknown {
    /// Sepal length, in the same integer units as `IrisData::sepal_length`.
    sepal_length: u32,
    /// Sepal width, in the same integer units as `IrisData::sepal_width`.
    sepal_width: u32,
    /// Petal length, in the same integer units as `IrisData::petal_length`.
    petal_length: u32,
    /// Petal width, in the same integer units as `IrisData::petal_width`.
    petal_width: u32,
}
fn f64_to_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
where
D: Deserializer<'de>,
{
let f = f64::deserialize(deserializer)?;
Ok((f * 10.0).round() as u32) // 放大10倍保留一位小数精度
}
/// Squared Euclidean distance between an unlabelled sample and a labelled one.
///
/// The square root is skipped on purpose: it keeps the computation in
/// integer arithmetic and does not change which of two samples is closer.
///
/// Per-feature differences are taken with `u32::abs_diff`, which is exact
/// over the whole `u32` range (casting through `i32` is not). The squares
/// and their sum saturate at `u32::MAX` instead of overflowing, so an
/// extremely distant sample is still reported as "as far as possible".
fn distance(a: &IrisDataUnknown, b: &IrisData) -> u32 {
    // Square of the absolute difference of one feature, clamped to u32::MAX.
    let squared_diff = |x: u32, y: u32| {
        let diff = x.abs_diff(y);
        diff.saturating_mul(diff)
    };

    squared_diff(a.sepal_length, b.sepal_length)
        .saturating_add(squared_diff(a.sepal_width, b.sepal_width))
        .saturating_add(squared_diff(a.petal_length, b.petal_length))
        .saturating_add(squared_diff(a.petal_width, b.petal_width))
}
/// Classifies `query` by majority vote among its `k` nearest labelled samples.
///
/// * `query` - the unlabelled sample to classify.
/// * `labeled_data` - the samples whose species are known.
/// * `k` - how many of the closest samples take part in the vote; if it
///   exceeds `labeled_data.len()` every sample votes.
///
/// Returns the winning species. When several species share the highest vote
/// count, the one owning the closest neighbour wins, so the result never
/// depends on `HashMap` iteration order. Returns `"Unknown"` when nobody
/// votes (`labeled_data` is empty or `k` is 0).
fn knn_classify(query: &IrisDataUnknown, labeled_data: &[IrisData], k: usize) -> String {
    // Labels are only borrowed here; just the winner is copied into a String.
    let mut neighbours: Vec<(u32, &str)> = labeled_data
        .iter()
        .map(|sample| (distance(query, sample), sample.species.as_str()))
        .collect();
    // Stable sort: equally distant samples keep their input order, which
    // makes the ranking (and therefore the tie-break below) reproducible.
    neighbours.sort_by_key(|&(dist, _)| dist);

    // species -> (number of votes, rank of that species' closest neighbour)
    let mut tally: HashMap<&str, (usize, usize)> = HashMap::new();
    for (rank, &(_, species)) in neighbours.iter().take(k).enumerate() {
        tally.entry(species).or_insert((0, rank)).0 += 1;
    }

    tally
        .into_iter()
        // Most votes first; among equals the smallest rank (closest) wins.
        // Ranks are unique per species, so the maximum is unambiguous.
        .max_by_key(|&(_, (votes, closest_rank))| (votes, std::cmp::Reverse(closest_rank)))
        .map(|(species, _)| species.to_owned())
        .unwrap_or_else(|| "Unknown".to_string())
}
/// Reads every record of the CSV file at `filename` into an `IrisData`.
///
/// # Errors
///
/// Fails if the file cannot be opened, or on the first record that does not
/// deserialize into `IrisData`; records after that one are not read.
fn load_from_csv(filename: &str) -> Result<Vec<IrisData>> {
    let mut reader = csv::Reader::from_path(filename)?;
    // Collecting into a `Result` stops at the first failing record.
    let records = reader
        .deserialize::<IrisData>()
        .collect::<std::result::Result<Vec<_>, _>>()?;
    Ok(records)
}
/// Randomly partitions `dataset` into a `(train, test)` pair.
///
/// The records are copied and shuffled; the last
/// `floor(dataset.len() * test_ratio)` of the shuffled copy become the test
/// set and the rest the training set. `dataset` itself is left untouched.
///
/// `test_ratio` is expected to lie in `0.0..=1.0`. Values outside that range
/// are clamped rather than panicking: a negative or NaN ratio yields an
/// empty test set, a ratio above 1.0 puts every record in the test set.
fn split_dataset(dataset: &[IrisData], test_ratio: f64) -> (Vec<IrisData>, Vec<IrisData>) {
    let mut shuffled = dataset.to_vec();
    shuffled.shuffle(&mut rng());

    // Float-to-integer `as` casts saturate (negative/NaN -> 0); `min` handles
    // ratios above 1.0, which would otherwise underflow the subtraction below.
    let test_size = ((shuffled.len() as f64 * test_ratio) as usize).min(shuffled.len());
    let train_size = shuffled.len() - test_size;

    // `split_off` moves the tail out, avoiding a second copy of every record.
    let test_data = shuffled.split_off(train_size);
    (shuffled, test_data)
}
/// Evaluates the plain KNN classifier on the Iris data set.
///
/// Loads `./dataset/Iris.csv`, then for each of 10 rounds draws a fresh
/// random train/test split (20% test), classifies every test record with
/// `k = 5` against the training records and prints the round's accuracy.
/// The mean accuracy over all rounds is printed at the end.
///
/// # Errors
///
/// Propagates any error returned by `load_from_csv`.
fn main() -> Result<()> {
    let filename = "./dataset/Iris.csv";
    let data = load_from_csv(filename)?;
    println!("Loaded {} records from {}", data.len(), filename);

    // Experiment parameters.
    let rounds = 10;
    let k = 5;
    let test_ratio = 0.2;

    println!(
        "Running {} rounds of KNN classification (k={}, test_ratio={})",
        rounds, k, test_ratio
    );

    let mut total_accuracy = 0.0;
    for round in 1..=rounds {
        let (train_data, test_data) = split_dataset(&data, test_ratio);

        // Count the test records whose predicted species matches the label.
        let correct = test_data
            .iter()
            .filter(|sample| {
                // Strip the label so the classifier only sees measurements.
                let query = IrisDataUnknown {
                    sepal_length: sample.sepal_length,
                    sepal_width: sample.sepal_width,
                    petal_length: sample.petal_length,
                    petal_width: sample.petal_width,
                };
                knn_classify(&query, &train_data, k) == sample.species
            })
            .count();

        let accuracy = correct as f64 / test_data.len() as f64;
        total_accuracy += accuracy;
        println!(
            "Round {}: Accuracy = {:.2}% ({}/{} correct)",
            round,
            accuracy * 100.0,
            correct,
            test_data.len()
        );
    }

    let avg_accuracy = total_accuracy / rounds as f64;
    println!(
        "\nAverage accuracy over {} rounds: {:.2}%",
        rounds,
        avg_accuracy * 100.0
    );
    Ok(())
}