refactor: cleanup test code

This commit is contained in:
2025-04-03 13:45:37 +08:00
parent ef14e4851c
commit f03fc66c1e
10 changed files with 152 additions and 149 deletions

View File

@@ -0,0 +1,47 @@
import sys
import os
import random
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import add, multiply, sm2p256v1
import time
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
def test_rust_vs_python_multiply():
mul_times = random.randint(0, sm2p256v1.N - 1)
# Rust实现
start_time = time.time()
for _ in range(10):
_ = multiply(g, mul_times, 1)
rust_time = time.time() - start_time
print(f"\nRust multiply 执行时间: {rust_time:.6f}")
# Python实现
start_time = time.time()
for _ in range(10):
_ = multiply(g, mul_times, 0)
python_time = time.time() - start_time
print(f"Python multiply 执行时间: {python_time:.6f}")
assert rust_time < python_time, "Rust实现应该比Python更快"
def test_rust_vs_python_add():
# Rust实现
start_time = time.time()
for _ in range(10):
_ = add(g, g, 1)
rust_time = time.time() - start_time
print(f"\nRust add 执行时间: {rust_time:.6f}")
# Python实现
start_time = time.time()
for _ in range(10):
_ = add(g, g, 0)
python_time = time.time() - start_time
print(f"Python add 执行时间: {python_time:.6f}")
assert rust_time < python_time, "Rust实现应该比Python更快"

View File

@@ -1,37 +0,0 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import add, multiply, sm2p256v1
import time
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
start_time = time.time() # 获取开始时间
for i in range(10):
result = multiply(g, 10000, 1) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"rust multiply 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = multiply(g, 10000, 0) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"python multiply 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = add(g, g, 1) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"rust add 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = add(g, g, 0) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"python add 执行时间: {elapsed_time:.6f}")

View File

@@ -0,0 +1,78 @@
import sys
import os
import time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
def test_tpre_performance_with_different_message_sizes():
"""测试不同消息大小下的TPRE性能"""
N = 20
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
for i in range(1, 6):
total_time = 0
# 1. 密钥生成和消息准备
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world" * pow(10, i)
print(f"明文长度:{len(m)}")
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"密钥生成运行时间:{elapsed_time}")
# 2. 加密
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"加密算法运行时间:{elapsed_time}")
# 3. 接收方密钥生成
pk_b, sk_b = GenerateKeyPair()
# 4. 重加密密钥生成
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"重加密密钥生成算法运行时间:{elapsed_time}")
# 5. 重加密
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
elapsed_time = (end_time - start_time) / len(rekeys)
total_time += elapsed_time
print(f"重加密算法运行时间:{elapsed_time}")
# 6. 解密
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
decrypted_m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"解密算法运行时间:{elapsed_time}")
# 验证解密结果是否正确
assert decrypted_m == m, f"解密失败,明文长度: {len(m)}"
print("成功解密")
print(f"算法总运行时间:{total_time}")
print()

View File

@@ -3,7 +3,6 @@ import os
import unittest
from unittest.mock import patch, MagicMock, Mock
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
import node

2
tests/pytest.ini Normal file
View File

@@ -0,0 +1,2 @@
[pytest]
asyncio_default_fixture_loop_scope = function

View File

@@ -1,72 +0,0 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time
N = 20
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
total_time = 0
# 1
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world"
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"密钥生成运行时间:{elapsed_time}")
# 2
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"加密算法运行时间:{elapsed_time}")
# 3
pk_b, sk_b = GenerateKeyPair()
# 5
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"重加密密钥生成算法运行时间:{elapsed_time}")
# 7
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
elapsed_time = (end_time - start_time) / len(rekeys)
total_time += elapsed_time
print(f"重加密算法运行时间:{elapsed_time}")
# 9
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"解密算法运行时间:{elapsed_time}")
print("成功解密:", m)
print(f"算法总运行时间:{total_time}")
print()

View File

@@ -1,5 +1,7 @@
import sys
import os
import time
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
@@ -10,26 +12,25 @@ from tpre import (
MergeCFrag,
DecryptFrags,
)
import time
N = 20
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
for i in range(1, 10):
def test_tpre_full_workflow_speed():
"""测试TPRE完整工作流程的运行速度"""
N = 20
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
total_time = 0
# 1
# 1. 密钥生成
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world" * pow(10, i)
print(f"明文长度:{len(m)}")
m = b"hello world"
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"密钥生成运行时间:{elapsed_time}")
# 2
# 2. 加密
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
@@ -37,10 +38,10 @@ for i in range(1, 10):
total_time += elapsed_time
print(f"加密算法运行时间:{elapsed_time}")
# 3
# 3. 接收方密钥生成
pk_b, sk_b = GenerateKeyPair()
# 5
# 4. 重加密密钥生成
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
@@ -49,10 +50,9 @@ for i in range(1, 10):
total_time += elapsed_time
print(f"重加密密钥生成算法运行时间:{elapsed_time}")
# 7
# 5. 重加密
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
@@ -61,14 +61,15 @@ for i in range(1, 10):
total_time += elapsed_time
print(f"重加密算法运行时间:{elapsed_time}")
# 9
# 6. 合并和解密
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
decrypted_m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"解密算法运行时间:{elapsed_time}")
print("成功解密:")
print(f"算法总运行时间:{total_time}")
print()
# 验证解密结果
assert decrypted_m == m, "解密结果与原始消息不匹配"
print(f"成功解密: {decrypted_m}")

View File

@@ -1,6 +1,5 @@
import sys
import os
import hashlib
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
@@ -84,19 +83,15 @@ class TestGenerateKeyPair(unittest.TestCase):
self.assertGreater(secret_key, 0)
# class TestEncryptDecrypt(unittest.TestCase):
# def setUp(self):
# self.public_key, self.secret_key = GenerateKeyPair()
# self.message = b"Hello, world!"
class TestEncryptDecrypt(unittest.TestCase):
def setUp(self):
self.public_key, self.secret_key = GenerateKeyPair()
self.message = b"Hello, world!"
# def test_encrypt_decrypt(self):
# encrypted_message = Encrypt(self.public_key, self.message)
# # 使用 SHA-256 哈希函数确保密钥为 16 字节
# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest()
# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # 取前 16 字节并转换为整数
# decrypted_message = Decrypt(secret_key_int, encrypted_message)
# self.assertEqual(decrypted_message, self.message)
def test_encrypt_decrypt(self):
encrypted_message = Encrypt(self.public_key, self.message)
decrypted_message = Decrypt(self.secret_key, encrypted_message)
self.assertEqual(decrypted_message, self.message)
class TestGenerateReKey(unittest.TestCase):