Initial Rust project setup with dependencies and dataset
- Add Cargo.toml with TFHE, CSV, and Serde dependencies
- Add .gitignore for the Rust target directory
- Include the Iris dataset for machine learning experiments
- Add a plain (unencrypted) KNN implementation binary
- Update LICENSE to MIT and improve the README

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
141
src/bin/plain.rs
Normal file
141
src/bin/plain.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
use anyhow::Result;
|
||||
use rand::rng;
|
||||
use rand::seq::SliceRandom;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// One labeled record from the Iris CSV dataset.
//
// Each feature column is deserialized through `f64_to_u32`, which scales the
// raw centimeter value by 10 and rounds, so the fields hold fixed-point
// integers in tenths of a centimeter (enables pure integer distance math).
#[derive(Debug, Deserialize, Clone)]
struct IrisData {
    // CSV column "SepalLengthCm", scaled ×10.
    #[serde(rename = "SepalLengthCm", deserialize_with = "f64_to_u32")]
    sepal_length: u32,
    // CSV column "SepalWidthCm", scaled ×10.
    #[serde(rename = "SepalWidthCm", deserialize_with = "f64_to_u32")]
    sepal_width: u32,
    // CSV column "PetalLengthCm", scaled ×10.
    #[serde(rename = "PetalLengthCm", deserialize_with = "f64_to_u32")]
    petal_length: u32,
    // CSV column "PetalWidthCm", scaled ×10.
    #[serde(rename = "PetalWidthCm", deserialize_with = "f64_to_u32")]
    petal_width: u32,
    // Class label from CSV column "Species" (e.g. "Iris-setosa").
    #[serde(rename = "Species")]
    species: String,
}
|
||||
|
||||
// An unlabeled query point: the same four fixed-point (×10) features as
// `IrisData`, but without a species — this is what `knn_classify` predicts.
struct IrisDataUnknown {
    // Sepal length, tenths of a cm.
    sepal_length: u32,
    // Sepal width, tenths of a cm.
    sepal_width: u32,
    // Petal length, tenths of a cm.
    petal_length: u32,
    // Petal width, tenths of a cm.
    petal_width: u32,
}
|
||||
|
||||
// Serde field deserializer: read an `f64` feature and convert it to a
// fixed-point `u32` by scaling by 10 and rounding, so one decimal digit of
// precision survives the move to integer arithmetic.
fn f64_to_u32<'de, D>(deserializer: D) -> Result<u32, D::Error>
where
    D: Deserializer<'de>,
{
    let f = f64::deserialize(deserializer)?;
    Ok((f * 10.0).round() as u32) // scale ×10 to keep one decimal place of precision
}
|
||||
|
||||
// Squared Euclidean distance between a query point and a labeled record.
// The square root is intentionally omitted: it is monotonic, so it does not
// change nearest-neighbor ordering, and skipping it keeps the math integral.
fn distance(a: &IrisDataUnknown, b: &IrisData) -> u32 {
    // Squared absolute difference of one fixed-point feature pair.
    let sq_diff = |x: u32, y: u32| -> u32 {
        let d = (x as i32 - y as i32).unsigned_abs();
        d * d
    };

    sq_diff(a.sepal_length, b.sepal_length)
        + sq_diff(a.sepal_width, b.sepal_width)
        + sq_diff(a.petal_length, b.petal_length)
        + sq_diff(a.petal_width, b.petal_width)
}
|
||||
|
||||
fn knn_classify(query: &IrisDataUnknown, labeled_data: &[IrisData], k: usize) -> String {
|
||||
let mut distances: Vec<(u32, &String)> = labeled_data
|
||||
.iter()
|
||||
.map(|data| (distance(query, data), &data.species))
|
||||
.collect();
|
||||
|
||||
distances.sort_by_key(|&(dist, _)| dist);
|
||||
|
||||
let mut species_count = HashMap::new();
|
||||
for &(_, species) in distances.iter().take(k) {
|
||||
*species_count.entry(species.clone()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
species_count
|
||||
.into_iter()
|
||||
.max_by_key(|&(_, count)| count)
|
||||
.map(|(species, _)| species)
|
||||
.unwrap_or_else(|| "Unknown".to_string())
|
||||
}
|
||||
|
||||
fn load_from_csv(filename: &str) -> Result<Vec<IrisData>> {
|
||||
let mut rdr = csv::Reader::from_path(filename)?;
|
||||
let mut data = Vec::new();
|
||||
for result in rdr.deserialize() {
|
||||
let record: IrisData = result?;
|
||||
data.push(record);
|
||||
}
|
||||
Ok(data)
|
||||
}
|
||||
|
||||
fn split_dataset(dataset: &[IrisData], test_ratio: f64) -> (Vec<IrisData>, Vec<IrisData>) {
|
||||
let mut shuffled = dataset.to_vec();
|
||||
shuffled.shuffle(&mut rng());
|
||||
|
||||
let test_size = (dataset.len() as f64 * test_ratio) as usize;
|
||||
let train_size = dataset.len() - test_size;
|
||||
|
||||
let train_data = shuffled[..train_size].to_vec();
|
||||
let test_data = shuffled[train_size..].to_vec();
|
||||
|
||||
(train_data, test_data)
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let filename = "./dataset/Iris.csv";
|
||||
let data = load_from_csv(filename)?;
|
||||
println!("Loaded {} records from {}", data.len(), filename);
|
||||
|
||||
let rounds = 10;
|
||||
let k = 5;
|
||||
let test_ratio = 0.2;
|
||||
let mut total_accuracy = 0.0;
|
||||
|
||||
println!(
|
||||
"Running {} rounds of KNN classification (k={}, test_ratio={})",
|
||||
rounds, k, test_ratio
|
||||
);
|
||||
|
||||
for round in 1..=rounds {
|
||||
let (train_data, test_data) = split_dataset(&data, test_ratio);
|
||||
let mut correct = 0;
|
||||
|
||||
for test_item in &test_data {
|
||||
let query = IrisDataUnknown {
|
||||
sepal_length: test_item.sepal_length,
|
||||
sepal_width: test_item.sepal_width,
|
||||
petal_length: test_item.petal_length,
|
||||
petal_width: test_item.petal_width,
|
||||
};
|
||||
|
||||
let predicted = knn_classify(&query, &train_data, k);
|
||||
if predicted == test_item.species {
|
||||
correct += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let accuracy = correct as f64 / test_data.len() as f64;
|
||||
total_accuracy += accuracy;
|
||||
println!(
|
||||
"Round {}: Accuracy = {:.2}% ({}/{} correct)",
|
||||
round,
|
||||
accuracy * 100.0,
|
||||
correct,
|
||||
test_data.len()
|
||||
);
|
||||
}
|
||||
|
||||
let avg_accuracy = total_accuracy / rounds as f64;
|
||||
println!(
|
||||
"\nAverage accuracy over {} rounds: {:.2}%",
|
||||
rounds,
|
||||
avg_accuracy * 100.0
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user