remove cache
This commit is contained in:
176
src/cache.rs
176
src/cache.rs
@@ -1,176 +0,0 @@
|
||||
use crate::data::EncryptedNeighbor;
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use tfhe::{ClientKey, ServerKey};
|
||||
|
||||
const CACHE_DIR: &str = ".cache";
|
||||
const KEYS_FILE: &str = ".cache/keys.json";
|
||||
const DISTANCES_FILE: &str = ".cache/distances.json";
|
||||
|
||||
/// 密钥对的序列化表示
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct KeyCache {
|
||||
pub client_key_bytes: Vec<u8>,
|
||||
pub server_key_bytes: Vec<u8>,
|
||||
}
|
||||
|
||||
/// 距离缓存结构
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct DistanceCache {
|
||||
pub query_hash: String,
|
||||
pub data_hash: String,
|
||||
pub distances: Vec<(Vec<u8>, Vec<u8>)>, // (encrypted_distance_bytes, encrypted_index_bytes)
|
||||
}
|
||||
|
||||
/// 创建缓存目录
|
||||
pub fn ensure_cache_dir() -> Result<()> {
|
||||
if !Path::new(CACHE_DIR).exists() {
|
||||
fs::create_dir(CACHE_DIR)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 保存密钥对到缓存
|
||||
pub fn save_keys(client_key: &ClientKey, server_key: &ServerKey) -> Result<()> {
|
||||
ensure_cache_dir()?;
|
||||
|
||||
println!("💾 Saving keys to cache...");
|
||||
|
||||
// 序列化密钥
|
||||
let client_key_bytes = bincode::serialize(client_key)?;
|
||||
let server_key_bytes = bincode::serialize(server_key)?;
|
||||
|
||||
let key_cache = KeyCache {
|
||||
client_key_bytes,
|
||||
server_key_bytes,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&key_cache)?;
|
||||
fs::write(KEYS_FILE, json)?;
|
||||
|
||||
println!("✅ Keys saved to cache");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 从缓存加载密钥对
|
||||
pub fn load_keys() -> Result<Option<(ClientKey, ServerKey)>> {
|
||||
if !Path::new(KEYS_FILE).exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
println!("🔑 Loading keys from cache...");
|
||||
|
||||
let json = fs::read_to_string(KEYS_FILE)?;
|
||||
let key_cache: KeyCache = serde_json::from_str(&json)?;
|
||||
|
||||
let client_key: ClientKey = bincode::deserialize(&key_cache.client_key_bytes)?;
|
||||
let server_key: ServerKey = bincode::deserialize(&key_cache.server_key_bytes)?;
|
||||
|
||||
println!("✅ Keys loaded from cache");
|
||||
Ok(Some((client_key, server_key)))
|
||||
}
|
||||
|
||||
/// 计算数据的简单哈希值
|
||||
pub fn hash_data(data: &[Vec<f64>]) -> String {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
for row in data {
|
||||
for &value in row {
|
||||
((value * 1000.0).round() as i64).hash(&mut hasher);
|
||||
}
|
||||
}
|
||||
format!("{:x}", hasher.finish())
|
||||
}
|
||||
|
||||
/// 保存距离计算结果到缓存
|
||||
pub fn save_distances(
|
||||
query: &[f64],
|
||||
data: &[Vec<f64>],
|
||||
distances: &[EncryptedNeighbor],
|
||||
) -> Result<()> {
|
||||
ensure_cache_dir()?;
|
||||
|
||||
println!("💾 Saving distances to cache...");
|
||||
|
||||
let query_hash = hash_data(&[query.to_vec()]);
|
||||
let data_hash = hash_data(data);
|
||||
|
||||
// 直接序列化加密的距离和索引
|
||||
let distance_pairs: Vec<(Vec<u8>, Vec<u8>)> = distances
|
||||
.iter()
|
||||
.map(|neighbor| {
|
||||
let distance_bytes = bincode::serialize(&neighbor.distance).unwrap();
|
||||
let index_bytes = bincode::serialize(&neighbor.index).unwrap();
|
||||
(distance_bytes, index_bytes)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let distance_cache = DistanceCache {
|
||||
query_hash,
|
||||
data_hash,
|
||||
distances: distance_pairs,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&distance_cache)?;
|
||||
fs::write(DISTANCES_FILE, json)?;
|
||||
|
||||
println!("✅ Distances saved to cache");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 从缓存加载距离计算结果
|
||||
pub fn load_distances(
|
||||
query: &[f64],
|
||||
data: &[Vec<f64>],
|
||||
) -> Result<Option<Vec<EncryptedNeighbor>>> {
|
||||
if !Path::new(DISTANCES_FILE).exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let query_hash = hash_data(&[query.to_vec()]);
|
||||
let data_hash = hash_data(data);
|
||||
|
||||
let json = fs::read_to_string(DISTANCES_FILE)?;
|
||||
let distance_cache: DistanceCache = serde_json::from_str(&json)?;
|
||||
|
||||
// 检查哈希是否匹配
|
||||
if distance_cache.query_hash != query_hash || distance_cache.data_hash != data_hash {
|
||||
println!("⚠️ Cache mismatch, will recalculate distances");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
println!("🚀 Loading distances from cache...");
|
||||
|
||||
// 直接反序列化加密的距离和索引
|
||||
let encrypted_neighbors: Vec<EncryptedNeighbor> = distance_cache
|
||||
.distances
|
||||
.iter()
|
||||
.map(|(distance_bytes, index_bytes)| EncryptedNeighbor {
|
||||
distance: bincode::deserialize(distance_bytes).unwrap(),
|
||||
index: bincode::deserialize(index_bytes).unwrap(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
println!("✅ Distances loaded from cache");
|
||||
Ok(Some(encrypted_neighbors))
|
||||
}
|
||||
|
||||
/// 清除所有缓存
|
||||
pub fn clear_cache() -> Result<()> {
|
||||
if Path::new(KEYS_FILE).exists() {
|
||||
fs::remove_file(KEYS_FILE)?;
|
||||
}
|
||||
if Path::new(DISTANCES_FILE).exists() {
|
||||
fs::remove_file(DISTANCES_FILE)?;
|
||||
}
|
||||
if Path::new(CACHE_DIR).exists() {
|
||||
fs::remove_dir(CACHE_DIR)?;
|
||||
}
|
||||
println!("🗑️ Cache cleared");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
pub mod data;
|
||||
pub mod algorithms;
|
||||
pub mod logging;
|
||||
pub mod cache;
|
||||
|
||||
// Re-export commonly used types and functions
|
||||
pub use data::*;
|
||||
pub use algorithms::*;
|
||||
pub use logging::*;
|
||||
pub use cache::*;
|
||||
pub use logging::*;
|
||||
Reference in New Issue
Block a user