From 2a376b920b90fd5960c1914e5b464ef63bfb5fe2 Mon Sep 17 00:00:00 2001
From: sangge <2251250136@qq.com>
Date: Thu, 24 Jul 2025 18:45:50 +0800
Subject: [PATCH] feat: implement plaintext HNSW algorithm and optimal parameters
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/bin/plain.rs |  41 ++++--
 test_fhe_hnsw.py | 338 +++++++++++++++++++++++++++++++++++++++++++++++
 testhnsw.py      | 315 ++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 651 insertions(+), 43 deletions(-)
 create mode 100755 test_fhe_hnsw.py

diff --git a/src/bin/plain.rs b/src/bin/plain.rs
index d2f76c2..c331999 100644
--- a/src/bin/plain.rs
+++ b/src/bin/plain.rs
@@ -15,6 +15,16 @@ struct Args {
     dataset: String,
     #[arg(long, default_value = "./dataset/answer1.jsonl")]
     predictions: String,
+    #[arg(long, default_value = "8", help = "Max connections per node (M parameter)")]
+    max_connections: usize,
+    #[arg(long, default_value = "3", help = "Max levels in HNSW graph")]
+    max_level: usize,
+    #[arg(long, default_value = "0.6", help = "Level selection probability")]
+    level_prob: f32,
+    #[arg(long, default_value = "1", help = "ef parameter for upper layer search")]
+    ef_upper: usize,
+    #[arg(long, default_value = "10", help = "ef parameter for bottom layer search")]
+    ef_bottom: usize,
 }
 
 #[derive(Deserialize)]
@@ -81,17 +91,17 @@ impl Ord for OrderedFloat {
 }
 
 impl HNSWGraph {
-    fn new() -> Self {
+    fn new(max_level: usize, max_connections: usize) -> Self {
         Self {
             nodes: Vec::new(),
             entry_point: None,
-            max_level: 3,
-            max_connections: 16,
+            max_level,
+            max_connections,
         }
     }
 
-    fn insert_node(&mut self, vector: Vec) -> usize {
-        let level = self.select_level();
+    fn insert_node(&mut self, vector: Vec, level_prob: f32) -> usize {
+        let level = self.select_level(level_prob);
         let node_id = self.nodes.len();
 
         let node = HNSWNode {
@@ -200,10 +210,10 @@ impl HNSWGraph {
         results.into_iter().map(|(_, idx)| idx).collect()
     }
 
-    fn select_level(&self) -> usize {
+    fn select_level(&self, level_prob: f32) -> usize {
         let mut rng = rand::rng();
         let mut level = 0;
-        while level < self.max_level && rng.random::<f32>() < 0.5 {
+        while level < self.max_level && rng.random::<f32>() < level_prob {
             level += 1;
         }
         level
@@ -276,23 +286,32 @@ fn main() -> Result<()> {
     let reader = BufReader::new(file);
     let mut results = Vec::new();
 
+    println!("🔧 HNSW Parameters:");
+    println!("   max_connections: {}", args.max_connections);
+    println!("   max_level: {}", args.max_level);
+    println!("   level_prob: {}", args.level_prob);
+    println!("   ef_upper: {}", args.ef_upper);
+    println!("   ef_bottom: {}", args.ef_bottom);
+
     for line in reader.lines() {
         let line = line?;
         let dataset: Dataset = serde_json::from_str(&line)?;
 
-        let mut hnsw = HNSWGraph::new();
+        let mut hnsw = HNSWGraph::new(args.max_level, args.max_connections);
         for data_point in &dataset.data {
-            hnsw.insert_node(data_point.clone());
+            hnsw.insert_node(data_point.clone(), args.level_prob);
         }
 
         let nearest = if let Some(entry_point) = hnsw.entry_point {
            let mut search_results = vec![entry_point];
+            // Upper layers are searched with the ef_upper parameter
             for layer in (1..=hnsw.nodes[entry_point].level).rev() {
-                search_results = hnsw.search_layer(&dataset.query, search_results, 1, layer);
+                search_results = hnsw.search_layer(&dataset.query, search_results, args.ef_upper, layer);
             }
 
-            let final_results = hnsw.search_layer(&dataset.query, search_results, 16, 0);
+            // The bottom layer is searched with the ef_bottom parameter
+            let final_results = hnsw.search_layer(&dataset.query, search_results, args.ef_bottom, 0);
             final_results
                 .into_iter()
                 .take(10)
diff --git a/test_fhe_hnsw.py b/test_fhe_hnsw.py
new file mode 100755
index 0000000..b99d544
--- /dev/null
+++ b/test_fhe_hnsw.py
@@ -0,0 +1,338 @@
+#!/usr/bin/env python3
+"""
+FHE HNSW parallel test script - multi-process version
+Tests the stability and accuracy of the ciphertext-domain HNSW implementation.
+Runs tests in parallel across multiple processes to make full use of the server's CPU cores.
+Executes 100 runs and records the rate of runs reaching ≥90% accuracy and the result-length distribution.
+"""
+
+import subprocess
+import json
+import time
+import os
+import multiprocessing as mp
+from pathlib import Path
+from collections import defaultdict
+
+# Ground-truth answer
+CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
+
+
+def calculate_accuracy(result, correct):
+    """Accuracy = number of matching elements / number of expected elements."""
+    matches = len(set(result) & set(correct))
+    return matches / len(correct) * 100
+
+
+def run_single_test(test_id, data_bit_width="12", timeout_minutes=30):
+    """Run a single FHE HNSW test (invoked from worker processes)."""
+    # Give each process its own output file
+    output_file = f"./test_fhe_output_{test_id}_{os.getpid()}.jsonl"
+
+    cmd = [
+        "./enc",
+        "--algorithm",
+        "hnsw",
+        "--data-bit-width",
+        data_bit_width,
+        "--predictions",
+        output_file,
+    ]
+
+    start_time = time.time()
+
+    try:
+        result = subprocess.run(
+            cmd, capture_output=True, text=True, timeout=timeout_minutes * 60
+        )
+
+        test_time = time.time() - start_time
+
+        if result.returncode != 0:
+            return {
+                "test_id": test_id,
+                "result": None,
+                "error": f"Command failed: {result.stderr[:200]}",
+                "time_minutes": test_time / 60,
+            }
+
+        # Read the result
+        if not Path(output_file).exists():
+            return {
+                "test_id": test_id,
+                "result": None,
+                "error": "Output file not found",
+                "time_minutes": test_time / 60,
+            }
+
+        with open(output_file, "r") as f:
+            output = json.load(f)
+            answer = output["answer"]
+
+        # Clean up the temporary file
+        Path(output_file).unlink(missing_ok=True)
+
+        # Compute accuracy
+        accuracy = calculate_accuracy(answer, CORRECT_ANSWER)
+
+        return {
+            "test_id": test_id,
+            "result": answer,
+            "length": len(answer),
+            "accuracy": accuracy,
+            "success": accuracy >= 90,
+            "time_minutes": test_time / 60,
+            "error": None,
+        }
+
+    except subprocess.TimeoutExpired:
+        # Clean up the temporary file
+        Path(output_file).unlink(missing_ok=True)
+        return {
+            "test_id": test_id,
+            "result": None,
+            "error": "timeout",
+            "time_minutes": timeout_minutes,
+        }
+    except Exception as e:
+        # Clean up the temporary file
+        Path(output_file).unlink(missing_ok=True)
+        return {
+            "test_id": test_id,
+            "result": None,
+            "error": str(e),
+            "time_minutes": (time.time() - start_time) / 60,
+        }
+
+
+def print_progress(completed, total, start_time):
+    """Print progress information."""
+    if completed == 0:
+        return
+
+    elapsed = time.time() - start_time
+    avg_time = elapsed / completed
+    remaining_time = avg_time * (total - completed)
+
+    print(
+        f"\r🔄 Progress: {completed}/{total} ({completed/total*100:.1f}%) "
+        f"| elapsed: {elapsed/3600:.1f}h | est. remaining: {remaining_time/3600:.1f}h",
+        end="",
+        flush=True,
+    )
+
+
+def main():
+    """Main test driver."""
+    print("🚀 FHE HNSW parallel batch test script")
+    print(f"🎯 Ground-truth answer: {CORRECT_ANSWER}")
+    print("📊 Running 100 tests; recording accuracy and result-length distribution")
+    print("🔧 Tests are executed in parallel across multiple processes")
+
+    # Make sure the enc binary exists
+    if not Path("./enc").exists():
+        print("\n❌ ./enc binary not found; build the project first:")
+        print("   cargo build --release --bin enc")
+        return
+
+    # Record the CPU core count and choose the number of worker processes
+    cpu_cores = mp.cpu_count()
+    # FHE runs are memory-intensive, so cap the worker count instead of using every core
+    num_processes = 8
+    timeout_minutes = 45  # allow up to 45 minutes per run
+
+    print(f"💻 Using {num_processes} parallel processes")
+    print(f"⏰ Per-test timeout: {timeout_minutes} minutes")
+    print(
+        f"⏱️  Estimated total time: {100 * 15 / num_processes / 60:.1f}-{100 * 25 / num_processes / 60:.1f} hours"
+    )  # assumes 15-25 minutes per run
+    print()
+
+    print("=" * 80)
+    print("🔬 Starting FHE HNSW parallel tests (data-bit-width=12)")
+    print("=" * 80)
+
+    start_time = time.time()
+
+    # Create the process pool and dispatch the tests
+    test_ids = list(range(1, 101))  # test ids 1-100
+
+    with mp.Pool(processes=num_processes) as pool:
+        # Submit asynchronous tasks
+        async_results = []
+        for test_id in test_ids:
+            async_result = pool.apply_async(
+                run_single_test, (test_id, "12", timeout_minutes)
+            )
+            async_results.append(async_result)
+
+        # Collect results and report progress
+        results = []
+        completed = 0
+
+        print_progress(completed, len(test_ids), start_time)
+
+        for async_result in async_results:
+            try:
+                result = async_result.get(
+                    timeout=timeout_minutes * 60 + 60
+                )  # one extra minute of slack
+                results.append(result)
+                completed += 1
+                print_progress(completed, len(test_ids), start_time)
+            except mp.TimeoutError:
+                # Pool-level timeout for this worker
+                results.append(
+                    {
+                        "test_id": len(results) + 1,
+                        "result": None,
+                        "error": "process_timeout",
+                        "time_minutes": timeout_minutes,
+                    }
+                )
+                completed += 1
+                print_progress(completed, len(test_ids), start_time)
+
+    print()  # newline after the progress bar
+    total_elapsed = time.time() - start_time
+
+    # Analyse the results
+    print("\n" + "=" * 80)
+    print("📈 Test result analysis")
+    print("=" * 80)
+
+    # Aggregate statistics
+    valid_results = [r for r in results if r["result"] is not None]
+    success_count = sum(1 for r in valid_results if r["success"])
+    length_distribution = defaultdict(int)
+    error_distribution = defaultdict(int)
+
+    # Tally error types
+    for r in results:
+        if r["error"]:
+            error_type = r["error"]
+            if "timeout" in error_type.lower():
+                error_distribution["timeout"] += 1
+            elif "failed" in error_type.lower():
+                error_distribution["failed"] += 1
+            else:
+                error_distribution["other"] += 1
+        else:
+            length_distribution[r["length"]] += 1
+
+    # Basic statistics
+    total_tests = len(results)
+    valid_tests = len(valid_results)
+
+    if valid_tests == 0:
+        print("❌ No valid test results")
+        return
+
+    success_rate = success_count / valid_tests * 100
+    avg_accuracy = sum(r["accuracy"] for r in valid_results) / valid_tests
+    avg_length = sum(r["length"] for r in valid_results) / valid_tests
+    avg_time_per_test = sum(r["time_minutes"] for r in results) / total_tests
+
+    print(f"Total tests: {total_tests}")
+    print(f"Valid tests: {valid_tests}")
+    print(f"Successful runs (≥90% accuracy): {success_count}")
+    print(f"Success rate: {success_rate:.1f}%")
+    print(f"Average accuracy: {avg_accuracy:.1f}%")
+    print(f"Average result length: {avg_length:.1f}")
+    print(f"Average time per test: {avg_time_per_test:.1f} minutes")
+    print(f"Total wall-clock time: {total_elapsed/3600:.1f} hours")
+    print(f"Parallel speedup: {100 * avg_time_per_test / 60 / (total_elapsed/3600):.1f}x")
+
+    # Error statistics
+    if error_distribution:
+        print("\n❌ Error distribution:")
+        for error_type, count in error_distribution.items():
+            print(f"  {error_type}: {count} runs")
+
+    # Result-length distribution
+    if length_distribution:
+        print("\n📊 Result-length distribution:")
+        for length in sorted(length_distribution.keys()):
+            count = length_distribution[length]
+            percentage = count / valid_tests * 100
+            bar = "█" * (count // 2) if count > 0 else ""
+            print(f"  length {length:2d}: {count:3d} runs ({percentage:5.1f}%) {bar}")
+
+    # Verdict
+    print()
+    if success_rate >= 50:
+        print("✅ Test passed! The FHE HNSW implementation is stable")
+        print(f"   - success rate: {success_rate:.1f}% (≥50%)")
+        print(f"   - average accuracy: {avg_accuracy:.1f}%")
+        if avg_length >= 9.5:
+            print(f"   - average result length: {avg_length:.1f} (close to 10)")
+        else:
+            print(f"   - average result length: {avg_length:.1f} (needs improvement)")
+    else:
+        print("❌ Test failed; further tuning is needed")
+        print(f"   - success rate: {success_rate:.1f}% (<50%)")
+        print(f"   - average accuracy: {avg_accuracy:.1f}%")
+
+    # Save the detailed report
+    report_file = "fhe_hnsw_parallel_test_report.json"
+    with open(report_file, "w") as f:
+        json.dump(
+            {
+                "summary": {
+                    "total_tests": total_tests,
+                    "valid_tests": valid_tests,
+                    "success_count": success_count,
+                    "success_rate": success_rate,
+                    "avg_accuracy": avg_accuracy,
+                    "avg_length": avg_length,
+                    "avg_time_minutes": avg_time_per_test,
+                    "total_time_hours": total_elapsed / 3600,
+                    "num_processes": num_processes,
+                    "cpu_cores": cpu_cores,
+                    "speedup": 100 * avg_time_per_test / 60 / (total_elapsed / 3600),
+                },
+                "length_distribution": dict(length_distribution),
+                "error_distribution": dict(error_distribution),
+                "detailed_results": results,
+                "correct_answer": CORRECT_ANSWER,
+            },
+            f,
+            indent=2,
+        )
+
+    print(f"\n📁 Detailed report saved to: {report_file}")
+
+    # Show the best and worst results
+    if valid_results:
+        print("\n🏆 Top 5 results by accuracy:")
+        best_results = sorted(valid_results, key=lambda x: x["accuracy"], reverse=True)[
+            :5
+        ]
+        for i, r in enumerate(best_results, 1):
+            print(
+                f"  {i}. test #{r['test_id']:3d}: accuracy {r['accuracy']:5.1f}%  length {r['length']:2d}  ({r['time_minutes']:4.1f} min)"
+            )
+
+        print("\n⚠️ Bottom 5 results by accuracy:")
+        worst_results = sorted(valid_results, key=lambda x: x["accuracy"])[:5]
+        for i, r in enumerate(worst_results, 1):
+            print(
+                f"  {i}. test #{r['test_id']:3d}: accuracy {r['accuracy']:5.1f}%  length {r['length']:2d}  ({r['time_minutes']:4.1f} min)"
+            )
+
+
+if __name__ == "__main__":
+    try:
+        # Set the multiprocessing start method (Linux defaults to fork, but spawn is safer here)
+        mp.set_start_method("spawn", force=True)
+        main()
+    except KeyboardInterrupt:
+        print("\n\n⏹️ Test interrupted by user")
+        # Clean up any leftover temporary files
+        for f in Path(".").glob("test_fhe_output_*.jsonl"):
+            f.unlink(missing_ok=True)
+    except Exception as e:
+        print(f"\n💥 Unexpected error: {e}")
+        # Clean up any leftover temporary files
+        for f in Path(".").glob("test_fhe_output_*.jsonl"):
+            f.unlink(missing_ok=True)
diff --git a/testhnsw.py b/testhnsw.py
index 83a9aa5..8d68fdc 100644
--- a/testhnsw.py
+++ b/testhnsw.py
@@ -1,52 +1,303 @@
 #!/usr/bin/env python3
+"""
+HNSW parameter tuning script - one variable at a time
+Tunes the parameters for a single query over 100 ten-dimensional vectors.
+Each configuration is run 100 times, recording the rate of runs reaching ≥90% accuracy.
+"""
+
 import subprocess
 import json
-from collections import Counter
+from pathlib import Path
+
+# Ground-truth answer
+CORRECT_ANSWER = [93, 94, 90, 27, 87, 50, 47, 40, 78, 28]
 
 
-def load_answers(filepath):
-    with open(filepath, "r") as f:
-        data = json.load(f)
-    return data["answer"]
+def calculate_accuracy(result, correct):
+    """Accuracy = number of matching elements / number of expected elements."""
+    matches = len(set(result) & set(correct))
+    return matches / len(correct) * 100
 
 
-def run_plain_binary():
-    result = subprocess.run(
-        ["cargo", "r", "-r", "--bin", "plain"], capture_output=True, text=True, cwd="."
-    )
-    if result.returncode == 0:
-        # The program outputs the same results as answer1.jsonl
-        return load_answers("dataset/answer1.jsonl")
-    return None
+def run_hnsw_test(max_connections, max_level, level_prob, ef_upper, ef_bottom):
+    """Run one plaintext HNSW test and return the predicted answer."""
+    cmd = [
+        "cargo",
+        "run",
+        "--bin",
+        "plain",
+        "--",
+        "--max-connections",
+        str(max_connections),
+        "--max-level",
+        str(max_level),
+        "--level-prob",
+        str(level_prob),
+        "--ef-upper",
+        str(ef_upper),
+        "--ef-bottom",
+        str(ef_bottom),
+        "--predictions",
+        "./test_output.jsonl",
+    ]
+
+    try:
+        result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
+        if result.returncode != 0:
+            return None
+
+        with open("./test_output.jsonl", "r") as f:
+            output = json.load(f)
+            return output["answer"]
+    except Exception:
+        return None
 
 
-def compare_answers(predictions, ground_truth):
-    if not predictions or len(predictions) != len(ground_truth):
-        return 0
-    return sum(1 for p, gt in zip(predictions, ground_truth) if p == gt)
+def test_config_stability(config, runs=100):
+    """Measure a configuration's stability: run it `runs` times and count runs with ≥90% accuracy."""
+    success_count = 0
+    accuracies = []
+
+    print("  running... ", end="", flush=True)
+
+    for i in range(runs):
+        if i % 20 == 0 and i > 0:
+            print(f"{i}/{runs} ", end="", flush=True)
+
+        result = run_hnsw_test(
+            config["max_connections"],
+            config["max_level"],
+            config["level_prob"],
+            config["ef_upper"],
+            config["ef_bottom"],
+        )
+
+        if result is not None:
+            accuracy = calculate_accuracy(result, CORRECT_ANSWER)
+            accuracies.append(accuracy)
+            if accuracy >= 90:
+                success_count += 1
+
+    success_rate = success_count / len(accuracies) * 100 if accuracies else 0
+    avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
+
+    print("done!")
+    return success_rate, avg_accuracy, len(accuracies)
 
 
 def main():
-    ground_truth = load_answers("dataset/answer.jsonl")
+    """Main test driver: evaluate the optimal parameter combination."""
+    print("🚀 HNSW optimal parameter combination test")
+    print(f"🎯 Ground-truth answer: {CORRECT_ANSWER}")
+    print("📊 Running the optimal configuration 100 times; recording the rate of runs with ≥90% accuracy")
+    print()
 
-    num_runs = 100
-    accuracies = []
+    # Optimal parameter combination
+    optimal_config = {
+        "max_connections": 8,
+        "max_level": 3,
+        "level_prob": 0.6,
+        "ef_upper": 1,
+        "ef_bottom": 10,
+    }
 
-    for i in range(num_runs):
-        predictions = run_plain_binary()
-        if predictions is not None:
-            accuracy = compare_answers(predictions, ground_truth)
-            accuracies.append(accuracy)
-
-    print(f"\nResults ({len(accuracies)} runs):")
+    print("=" * 80)
+    print("🏆 Testing the optimal parameter combination")
+    print("=" * 80)
     print(
-        f"Min: {min(accuracies)}, Max: {max(accuracies)}, Mean: {sum(accuracies)/len(accuracies):.2f}"
+        f"Configuration: M={optimal_config['max_connections']}, L={optimal_config['max_level']}, "
+        + f"P={optimal_config['level_prob']}, ef_upper={optimal_config['ef_upper']}, "
+        + f"ef_bottom={optimal_config['ef_bottom']}"
     )
+    print()
+
+    print("optimal config ", end=" ")
+    success_rate, avg_acc, total_runs = test_config_stability(optimal_config)
+    status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
+
+    print(
+        f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}% {status}"
+    )
 
-    counter = Counter(accuracies)
-    print("Distribution:")
-    for correct_count in sorted(counter.keys()):
-        print(f"  {correct_count} correct: {counter[correct_count]} times")
+    print()
+    print("=" * 80)
+    print("📈 Test result analysis")
+    print("=" * 80)
+
+    if success_rate >= 50:
+        print("✅ Optimal parameter combination passed!")
+        print(f"   - success rate: {success_rate:.1f}% (≥50%)")
+        print(f"   - average accuracy: {avg_acc:.1f}%")
+    else:
+        print("❌ Optimal parameter combination did not reach a 50% success rate")
+        print(f"   - success rate: {success_rate:.1f}% (<50%)")
+        print(f"   - average accuracy: {avg_acc:.1f}%")
+
+    return
+
+    # The one-variable-at-a-time sweep below is disabled (kept for reference); it assumes a base_config dict of defaults.
+    """
+    # Controlled-variable sweep: change one parameter at a time
+    test_results = []
+
+    print("=" * 80)
+    print("1️⃣ Effect of the max_connections (M) parameter")
+    print("=" * 80)
+
+    for m in [4, 8, 12, 16, 20]:
+        config = base_config.copy()
+        config["max_connections"] = m
+        config_name = f"M={m}"
+
+        print(f"{config_name:<15}", end=" ")
+        success_rate, avg_acc, total_runs = test_config_stability(config)
+        status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
+
+        print(
+            f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}% {status}"
+        )
+
+        test_results.append(
+            {
+                "param": "max_connections",
+                "value": m,
+                "config": config,
+                "success_rate": success_rate,
+                "avg_accuracy": avg_acc,
+                "pass": success_rate >= 50,
+            }
+        )
+
+    print()
+    print("=" * 80)
+    print("2️⃣ Effect of the max_level (L) parameter")
+    print("=" * 80)
+
+    for l in [3, 4, 5, 6, 7]:
+        config = base_config.copy()
+        config["max_level"] = l
+        config_name = f"L={l}"
+
+        print(f"{config_name:<15}", end=" ")
+        success_rate, avg_acc, total_runs = test_config_stability(config)
+        status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
+
+        print(
+            f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}% {status}"
+        )
+
+        test_results.append(
+            {
+                "param": "max_level",
+                "value": l,
+                "config": config,
+                "success_rate": success_rate,
+                "avg_accuracy": avg_acc,
+                "pass": success_rate >= 50,
+            }
+        )
+
+    print()
+    print("=" * 80)
+    print("3️⃣ Effect of the level_prob (P) parameter")
+    print("=" * 80)
+
+    for p in [0.2, 0.3, 0.4, 0.5, 0.6]:
+        config = base_config.copy()
+        config["level_prob"] = p
+        config_name = f"P={p}"
+
+        print(f"{config_name:<15}", end=" ")
+        success_rate, avg_acc, total_runs = test_config_stability(config)
+        status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
+
+        print(
+            f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}% {status}"
+        )
+
+        test_results.append(
+            {
+                "param": "level_prob",
+                "value": p,
+                "config": config,
+                "success_rate": success_rate,
+                "avg_accuracy": avg_acc,
+                "pass": success_rate >= 50,
+            }
+        )
+
+    print()
+    print("=" * 80)
+    print("4️⃣ Effect of the ef_bottom parameter")
+    print("=" * 80)
+
+    for ef in [10, 16, 25, 40, 60]:
+        config = base_config.copy()
+        config["ef_bottom"] = ef
+        config_name = f"ef_b={ef}"
+
+        print(f"{config_name:<15}", end=" ")
+        success_rate, avg_acc, total_runs = test_config_stability(config)
+        status = "✅ PASS" if success_rate >= 50 else "❌ FAIL"
+
+        print(
+            f"success rate: {success_rate:5.1f}% ({total_runs} runs)  average accuracy: {avg_acc:5.1f}% {status}"
+        )
+
+        test_results.append(
+            {
+                "param": "ef_bottom",
+                "value": ef,
+                "config": config,
+                "success_rate": success_rate,
+                "avg_accuracy": avg_acc,
+                "pass": success_rate >= 50,
+            }
+        )
+
+    # Summary report
+    print()
+    print("=" * 80)
+    print("📈 Test summary")
+    print("=" * 80)
+
+    passed_configs = [r for r in test_results if r["pass"]]
+    print(f"Total configurations tested: {len(test_results)}")
+    print(
+        f"Configurations with ≥50% success rate: {len(passed_configs)} ({len(passed_configs)/len(test_results)*100:.1f}%)"
+    )
+    print()
+
+    if passed_configs:
+        print("🏆 Configurations with ≥50% success rate:")
+        for result in passed_configs:
+            print(
+                f"  {result['param']}={result['value']}: success rate {result['success_rate']:.1f}%, average accuracy {result['avg_accuracy']:.1f}%"
+            )
+
+        # Pick the best value for each parameter
+        print()
+        print("🎯 Recommended value for each parameter:")
+        for param_name in ["max_connections", "max_level", "level_prob", "ef_bottom"]:
+            param_results = [r for r in test_results if r["param"] == param_name]
+            best_result = max(param_results, key=lambda x: x["success_rate"])
+            print(
+                f"  {param_name}: {best_result['value']} (success rate: {best_result['success_rate']:.1f}%)"
+            )
+    else:
+        print("❌ No configuration reached a 50% success rate")
+        print("📊 Top 5 configurations by success rate:")
+        top_configs = sorted(
+            test_results, key=lambda x: x["success_rate"], reverse=True
+        )[:5]
+        for i, result in enumerate(top_configs, 1):
+            print(
+                f"  {i}. {result['param']}={result['value']}: success rate {result['success_rate']:.1f}%"
+            )
+
+    # Clean up the temporary file
+    Path("./test_output.jsonl").unlink(missing_ok=True)
+    """
 
 
 if __name__ == "__main__":