From f03fc66c1e0ec26ec3904ad7c3e2fb9069fee21f Mon Sep 17 00:00:00 2001 From: sangge-redmi <2251250136@qq.com> Date: Thu, 3 Apr 2025 13:45:37 +0800 Subject: [PATCH] refactor: cleanup test code --- src/eth_logger_test.py | 8 --- tests/{test_client.py => client_test.py} | 0 tests/ecc_speed_comparison_test.py | 47 +++++++++++++ tests/ecc_speed_test.py | 37 ---------- tests/message_max_length_test.py | 78 +++++++++++++++++++++ tests/node_test.py | 1 - {src => tests}/pytest.ini | 0 tests/speed_test.py | 72 ------------------- tests/{lenth_test.py => tpre_speed_test.py} | 37 +++++----- tests/tpre_test.py | 21 +++--- 10 files changed, 152 insertions(+), 149 deletions(-) delete mode 100644 src/eth_logger_test.py rename tests/{test_client.py => client_test.py} (100%) create mode 100644 tests/ecc_speed_comparison_test.py delete mode 100644 tests/ecc_speed_test.py create mode 100644 tests/message_max_length_test.py rename {src => tests}/pytest.ini (100%) delete mode 100644 tests/speed_test.py rename tests/{lenth_test.py => tpre_speed_test.py} (75%) diff --git a/src/eth_logger_test.py b/src/eth_logger_test.py deleted file mode 100644 index 9cf9d1d..0000000 --- a/src/eth_logger_test.py +++ /dev/null @@ -1,8 +0,0 @@ -from eth_logger import call_eth_logger - -wallet_address = ( - "0xe02666Cb63b3645E7B03C9082a24c4c1D7C9EFf6" # 修改成要使用的钱包地址/私钥 -) -wallet_pk = "ae66ae3711a69079efd3d3e9b55f599ce7514eb29dfe4f9551404d3f361438c6" - -call_eth_logger(wallet_address, wallet_pk, "hello World") diff --git a/tests/test_client.py b/tests/client_test.py similarity index 100% rename from tests/test_client.py rename to tests/client_test.py diff --git a/tests/ecc_speed_comparison_test.py b/tests/ecc_speed_comparison_test.py new file mode 100644 index 0000000..9486d5a --- /dev/null +++ b/tests/ecc_speed_comparison_test.py @@ -0,0 +1,47 @@ +import sys +import os +import random + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) +from tpre import add, multiply, sm2p256v1 +import time + +# 生成元 +g = (sm2p256v1.Gx, sm2p256v1.Gy) + + +def test_rust_vs_python_multiply(): + mul_times = random.randint(0, sm2p256v1.N - 1) + # Rust实现 + start_time = time.time() + for _ in range(10): + _ = multiply(g, mul_times, 1) + rust_time = time.time() - start_time + print(f"\nRust multiply 执行时间: {rust_time:.6f} 秒") + + # Python实现 + start_time = time.time() + for _ in range(10): + _ = multiply(g, mul_times, 0) + python_time = time.time() - start_time + print(f"Python multiply 执行时间: {python_time:.6f} 秒") + + assert rust_time < python_time, "Rust实现应该比Python更快" + + +def test_rust_vs_python_add(): + # Rust实现 + start_time = time.time() + for _ in range(10): + _ = add(g, g, 1) + rust_time = time.time() - start_time + print(f"\nRust add 执行时间: {rust_time:.6f} 秒") + + # Python实现 + start_time = time.time() + for _ in range(10): + _ = add(g, g, 0) + python_time = time.time() - start_time + print(f"Python add 执行时间: {python_time:.6f} 秒") + + assert rust_time < python_time, "Rust实现应该比Python更快" diff --git a/tests/ecc_speed_test.py b/tests/ecc_speed_test.py deleted file mode 100644 index cd8f369..0000000 --- a/tests/ecc_speed_test.py +++ /dev/null @@ -1,37 +0,0 @@ -import sys -import os - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) -from tpre import add, multiply, sm2p256v1 -import time - -# 生成元 -g = (sm2p256v1.Gx, sm2p256v1.Gy) - -start_time = time.time() # 获取开始时间 -for i in range(10): - result = multiply(g, 10000, 1) # 执行函数 -end_time = time.time() # 获取结束时间 -elapsed_time = end_time - start_time # 计算执行时间 -print(f"rust multiply 执行时间: {elapsed_time:.6f} 秒") - -start_time = time.time() # 获取开始时间 -for i in range(10): - result = multiply(g, 10000, 0) # 执行函数 -end_time = time.time() # 获取结束时间 -elapsed_time = end_time - start_time # 计算执行时间 -print(f"python multiply 执行时间: {elapsed_time:.6f} 秒") - -start_time = time.time() # 获取开始时间 -for i in range(10): - result = add(g, g, 1) # 执行函数 -end_time = time.time() # 获取结束时间 -elapsed_time = end_time - start_time # 计算执行时间 -print(f"rust add 执行时间: {elapsed_time:.6f} 秒") - -start_time = time.time() # 获取开始时间 -for i in range(10): - result = add(g, g, 0) # 执行函数 -end_time = time.time() # 获取结束时间 -elapsed_time = end_time - start_time # 计算执行时间 -print(f"python add 执行时间: {elapsed_time:.6f} 秒") \ No newline at end of file diff --git a/tests/message_max_length_test.py b/tests/message_max_length_test.py new file mode 100644 index 0000000..9eee891 --- /dev/null +++ b/tests/message_max_length_test.py @@ -0,0 +1,78 @@ +import sys +import os +import time + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) +from tpre import ( + GenerateKeyPair, + Encrypt, + GenerateReKey, + ReEncrypt, + MergeCFrag, + DecryptFrags, +) + + +def test_tpre_performance_with_different_message_sizes(): + """测试不同消息大小下的TPRE性能""" + N = 20 + T = N // 2 + print(f"当前门限值: N = {N}, T = {T}") + + for i in range(1, 6): + total_time = 0 + # 1. 密钥生成和消息准备 + start_time = time.time() + pk_a, sk_a = GenerateKeyPair() + m = b"hello world" * pow(10, i) + print(f"明文长度:{len(m)}") + end_time = time.time() + elapsed_time = end_time - start_time + total_time += elapsed_time + print(f"密钥生成运行时间:{elapsed_time}秒") + + # 2. 加密 + start_time = time.time() + capsule_ct = Encrypt(pk_a, m) + end_time = time.time() + elapsed_time = end_time - start_time + total_time += elapsed_time + print(f"加密算法运行时间:{elapsed_time}秒") + + # 3. 接收方密钥生成 + pk_b, sk_b = GenerateKeyPair() + + # 4. 重加密密钥生成 + start_time = time.time() + id_tuple = tuple(range(N)) + rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) + end_time = time.time() + elapsed_time = end_time - start_time + total_time += elapsed_time + print(f"重加密密钥生成算法运行时间:{elapsed_time}秒") + + # 5. 重加密 + start_time = time.time() + cfrag_cts = [] + for rekey in rekeys: + cfrag_ct = ReEncrypt(rekey, capsule_ct) + cfrag_cts.append(cfrag_ct) + end_time = time.time() + elapsed_time = (end_time - start_time) / len(rekeys) + total_time += elapsed_time + print(f"重加密算法运行时间:{elapsed_time}秒") + + # 6. 解密 + start_time = time.time() + cfrags = MergeCFrag(cfrag_cts) + decrypted_m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) + end_time = time.time() + elapsed_time = end_time - start_time + total_time += elapsed_time + print(f"解密算法运行时间:{elapsed_time}秒") + + # 验证解密结果是否正确 + assert decrypted_m == m, f"解密失败,明文长度: {len(m)}" + print("成功解密") + print(f"算法总运行时间:{total_time}秒") + print() diff --git a/tests/node_test.py b/tests/node_test.py index 6be7bf6..42548db 100644 --- a/tests/node_test.py +++ b/tests/node_test.py @@ -3,7 +3,6 @@ import os import unittest from unittest.mock import patch, MagicMock, Mock import sys -import os sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) import node diff --git a/src/pytest.ini b/tests/pytest.ini similarity index 100% rename from src/pytest.ini rename to tests/pytest.ini diff --git a/tests/speed_test.py b/tests/speed_test.py deleted file mode 100644 index cbb8eca..0000000 --- a/tests/speed_test.py +++ /dev/null @@ -1,72 +0,0 @@ -import sys -import os - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) -from tpre import ( - GenerateKeyPair, - Encrypt, - GenerateReKey, - ReEncrypt, - MergeCFrag, - DecryptFrags, -) -import time - -N = 20 -T = N // 2 -print(f"当前门限值: N = {N}, T = {T}") - -total_time = 0 - -# 1 -start_time = time.time() -pk_a, sk_a = GenerateKeyPair() -m = b"hello world" -end_time = time.time() -elapsed_time = end_time - start_time -total_time += elapsed_time -print(f"密钥生成运行时间:{elapsed_time}秒") - -# 2 -start_time = time.time() -capsule_ct = Encrypt(pk_a, m) -end_time = time.time() -elapsed_time = end_time - start_time -total_time += elapsed_time -print(f"加密算法运行时间:{elapsed_time}秒") - -# 3 -pk_b, sk_b = GenerateKeyPair() - -# 5 -start_time = time.time() -id_tuple = tuple(range(N)) -rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) -end_time = time.time() -elapsed_time = end_time - start_time -total_time += elapsed_time -print(f"重加密密钥生成算法运行时间:{elapsed_time}秒") - -# 7 -start_time = time.time() -cfrag_cts = [] - -for rekey in rekeys: - cfrag_ct = ReEncrypt(rekey, capsule_ct) - cfrag_cts.append(cfrag_ct) -end_time = time.time() -elapsed_time = (end_time - start_time) / len(rekeys) -total_time += elapsed_time -print(f"重加密算法运行时间:{elapsed_time}秒") - -# 9 -start_time = time.time() -cfrags = MergeCFrag(cfrag_cts) -m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) -end_time = time.time() -elapsed_time = end_time - start_time -total_time += elapsed_time -print(f"解密算法运行时间:{elapsed_time}秒") -print("成功解密:", m) -print(f"算法总运行时间:{total_time}秒") -print() diff --git a/tests/lenth_test.py b/tests/tpre_speed_test.py similarity index 75% rename from tests/lenth_test.py rename to tests/tpre_speed_test.py index c27d23d..7cbb3c0 100644 --- a/tests/lenth_test.py +++ b/tests/tpre_speed_test.py @@ -1,5 +1,7 @@ import sys import os +import time + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from tpre import ( @@ -10,26 +12,25 @@ from tpre import ( MergeCFrag, DecryptFrags, ) -import time -N = 20 -T = N // 2 -print(f"当前门限值: N = {N}, T = {T}") -for i in range(1, 10): +def test_tpre_full_workflow_speed(): + """测试TPRE完整工作流程的运行速度""" + N = 20 + T = N // 2 + print(f"当前门限值: N = {N}, T = {T}") total_time = 0 - # 1 + # 1. 密钥生成 start_time = time.time() pk_a, sk_a = GenerateKeyPair() - m = b"hello world" * pow(10, i) - print(f"明文长度:{len(m)}") + m = b"hello world" end_time = time.time() elapsed_time = end_time - start_time total_time += elapsed_time print(f"密钥生成运行时间:{elapsed_time}秒") - # 2 + # 2. 加密 start_time = time.time() capsule_ct = Encrypt(pk_a, m) end_time = time.time() @@ -37,10 +38,10 @@ for i in range(1, 10): total_time += elapsed_time print(f"加密算法运行时间:{elapsed_time}秒") - # 3 + # 3. 接收方密钥生成 pk_b, sk_b = GenerateKeyPair() - # 5 + # 4. 重加密密钥生成 start_time = time.time() id_tuple = tuple(range(N)) rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) @@ -49,10 +50,9 @@ for i in range(1, 10): total_time += elapsed_time print(f"重加密密钥生成算法运行时间:{elapsed_time}秒") - # 7 + # 5. 重加密 start_time = time.time() cfrag_cts = [] - for rekey in rekeys: cfrag_ct = ReEncrypt(rekey, capsule_ct) cfrag_cts.append(cfrag_ct) @@ -61,14 +61,15 @@ for i in range(1, 10): total_time += elapsed_time print(f"重加密算法运行时间:{elapsed_time}秒") - # 9 + # 6. 合并和解密 start_time = time.time() cfrags = MergeCFrag(cfrag_cts) - m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) + decrypted_m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) end_time = time.time() elapsed_time = end_time - start_time total_time += elapsed_time print(f"解密算法运行时间:{elapsed_time}秒") - print("成功解密:") - print(f"算法总运行时间:{total_time}秒") - print() + + # 验证解密结果 + assert decrypted_m == m, "解密结果与原始消息不匹配" + print(f"成功解密: {decrypted_m}") diff --git a/tests/tpre_test.py b/tests/tpre_test.py index 8a90746..05b1935 100644 --- a/tests/tpre_test.py +++ b/tests/tpre_test.py @@ -1,6 +1,5 @@ import sys import os -import hashlib sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from tpre import ( @@ -84,19 +83,15 @@ class TestGenerateKeyPair(unittest.TestCase): self.assertGreater(secret_key, 0) -# class TestEncryptDecrypt(unittest.TestCase): -# def setUp(self): -# self.public_key, self.secret_key = GenerateKeyPair() -# self.message = b"Hello, world!" +class TestEncryptDecrypt(unittest.TestCase): + def setUp(self): + self.public_key, self.secret_key = GenerateKeyPair() + self.message = b"Hello, world!" -# def test_encrypt_decrypt(self): -# encrypted_message = Encrypt(self.public_key, self.message) -# # 使用 SHA-256 哈希函数确保密钥为 16 字节 -# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest() -# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # 取前 16 字节并转换为整数 - -# decrypted_message = Decrypt(secret_key_int, encrypted_message) -# self.assertEqual(decrypted_message, self.message) + def test_encrypt_decrypt(self): + encrypted_message = Encrypt(self.public_key, self.message) + decrypted_message = Decrypt(self.secret_key, encrypted_message) + self.assertEqual(decrypted_message, self.message) class TestGenerateReKey(unittest.TestCase):