Improve test code files

This commit is contained in:
2024-10-04 22:28:14 +08:00
committed by sangge-redmi
parent d071296126
commit a9a4e28d9a
14 changed files with 191 additions and 282 deletions

37
tests/ecc_speed_test.py Normal file
View File

@@ -0,0 +1,37 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import add, multiply, sm2p256v1
import time
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
start_time = time.time() # 获取开始时间
for i in range(10):
result = multiply(g, 10000, 1) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"rust multiply 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = multiply(g, 10000, 0) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"python multiply 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = add(g, g, 1) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"rust add 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = add(g, g, 0) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"python add 执行时间: {elapsed_time:.6f}")

74
tests/lenth_test.py Normal file
View File

@@ -0,0 +1,74 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time
N = 20
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
for i in range(1, 10):
total_time = 0
# 1
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world" * pow(10, i)
print(f"明文长度:{len(m)}")
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"密钥生成运行时间:{elapsed_time}")
# 2
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"加密算法运行时间:{elapsed_time}")
# 3
pk_b, sk_b = GenerateKeyPair()
# 5
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"重加密密钥生成算法运行时间:{elapsed_time}")
# 7
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
elapsed_time = (end_time - start_time) / len(rekeys)
total_time += elapsed_time
print(f"重加密算法运行时间:{elapsed_time}")
# 9
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"解密算法运行时间:{elapsed_time}")
print("成功解密:")
print(f"算法总运行时间:{total_time}")
print()

70
tests/maxnode_test.py Normal file
View File

@@ -0,0 +1,70 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import *
import time
N = 80
total_time = 0
while total_time < 1:
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
total_time = 0
# 1
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world"
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
# print(f"密钥生成运行时间:{elapsed_time}秒")
# 2
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
# print(f"加密算法运行时间:{elapsed_time}秒")
# 3
pk_b, sk_b = GenerateKeyPair()
# 5
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
# print(f"重加密密钥生成算法运行时间:{elapsed_time}秒")
# 7
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
elapsed_time = (end_time - start_time) / len(rekeys)
total_time += elapsed_time
# print(f"重加密算法运行时间:{elapsed_time}秒")
# 9
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
# print(f"解密算法运行时间:{elapsed_time}秒")
# print("成功解密:", m)
print(f"算法总运行时间:{total_time}")
print()
N += 1

88
tests/node_test.py Normal file
View File

@@ -0,0 +1,88 @@
# 测试 node.py 中的函数
import os
import unittest
from unittest.mock import patch, MagicMock, Mock
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
import node
class TestGetLocalIP(unittest.TestCase):
@patch.dict("os.environ", {"HOST_IP": "60.204.193.58"}) # 模拟设置 HOST_IP 环境变量
def test_get_ip_from_env(self):
# 调用被测函数
node.get_local_ip()
# 检查函数是否正确获取到 HOST_IP
self.assertEqual(node.ip, "60.204.193.58")
@patch("socket.socket") # Mock socket 连接行为
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
def test_get_ip_from_socket(self, mock_socket):
# 模拟 socket 返回的 IP 地址
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
# 调用被测函数
node.get_local_ip()
# 确认 socket 被调用过
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
mock_socket_instance.close.assert_called_once()
# 检查是否通过 socket 获取到正确的 IP 地址
self.assertEqual(node.ip, "110.41.155.96")
class TestSendIP(unittest.TestCase):
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
@patch("requests.get") # Mock requests.get 调用
def test_send_ip(self, mock_get):
# 设置模拟返回的 HTTP 响应
mock_response = Mock()
mock_response.text = "node123" # 模拟返回的节点ID
mock_response.status_code = 200
mock_get.return_value = (
mock_response # 设置 requests.get() 的返回值为 mock_response
)
# 保存原始的全局 id 值
original_id = node.id
# 调用待测函数
node.send_ip()
# 确保 requests.get 被正确调用
expected_url = f"{node.server_address}/get_node?ip={node.ip}"
mock_get.assert_called_once_with(expected_url, timeout=3)
# 检查 id 是否被正确更新
self.assertIs(node.id, mock_response) # 检查 id 是否被修改
self.assertEqual(
node.id.text, "node123"
) # 检查更新后的 id 是否与 mock_response.text 匹配
class TestNode(unittest.TestCase):
@patch("node.send_ip")
@patch("node.get_local_ip")
@patch("node.asyncio.create_task")
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
# 调用 init 函数
node.init()
# 验证 get_local_ip 和 send_ip 被调用
mock_get_local_ip.assert_called_once()
mock_send_ip.assert_called_once()
# 确保 create_task 被调用来启动心跳包
mock_create_task.assert_called_once()
if __name__ == "__main__":
unittest.main()

199
tests/node_test5.py Normal file
View File

@@ -0,0 +1,199 @@
# node_test剩下部分(有问题)
import os
import unittest
import pytest
from unittest.mock import patch, MagicMock, Mock, AsyncMock
import requests
import asyncio
import httpx
import respx
from fastapi.testclient import TestClient
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from node import (
app,
send_heartbeat_internal,
Req,
send_ip,
get_local_ip,
init,
clear,
send_user_des_message,
ip,
id,
)
client = TestClient(app)
server_address = "http://60.204.236.38:8000/server"
# ip = None # 初始化全局变量 ip
# id = None # 初始化全局变量 id
class TestGetLocalIP(unittest.TestCase):
os.environ["HOST_IP"] = "60.204.193.58" # 模拟设置 HOST_IP 环境变量
def test_get_ip_from_env(self):
global ip
# 调用被测函数
get_local_ip()
# 检查函数是否正确获取到 HOST_IP
self.assertEqual(ip, "60.204.193.58")
@patch("socket.socket") # Mock socket 连接行为
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
def test_get_ip_from_socket(self, mock_socket):
global ip
# 模拟 socket 返回的 IP 地址
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
# 调用被测函数
get_local_ip()
# 确认 socket 被调用过
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
mock_socket_instance.close.assert_called_once()
# 检查是否通过 socket 获取到正确的 IP 地址
self.assertEqual(ip, "110.41.155.96")
class TestSendIP(unittest.TestCase):
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
@respx.mock
def test_send_ip(self):
global ip, id
ip = "60.204.193.58"
mock_url = f"{server_address}/get_node?ip={ip}"
respx.get(mock_url).mock(return_value=httpx.Response(200, text="node123"))
# 调用待测函数
send_ip()
# 确保 requests.get 被正确调用
self.assertEqual(
id, "node123"
) # 检查更新后的 id 是否与 mock_response.text 匹配
class TestNode(unittest.TestCase):
@patch("node.send_ip")
@patch("node.get_local_ip")
@patch("node.asyncio.create_task")
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
# 调用 init 函数
init()
# 验证 get_local_ip 和 send_ip 被调用
mock_get_local_ip.assert_called_once()
mock_send_ip.assert_called_once()
# 确保 create_task 被调用来启动心跳包
mock_create_task.assert_called_once()
def test_clear(self):
# 调用 clear 函数
clear()
# 检查输出
self.assertTrue(True) # 这里只是为了确保函数被调用,没有实际逻辑需要测试
@pytest.mark.asyncio
@respx.mock
async def test_send_heartbeat_internal_success():
global ip
ip = "60.204.193.58"
# 模拟心跳请求
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
return_value=httpx.Response(200)
)
# 模拟 requests.get 以避免实际请求
with patch("requests.get", return_value=httpx.Response(200)) as mock_get:
# 模拟 asyncio.sleep 以避免实际延迟
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
task = asyncio.create_task(send_heartbeat_internal())
await asyncio.sleep(0.1) # 允许任务运行一段时间
task.cancel() # 取消任务以停止无限循环
try:
await task # 确保任务被等待
except asyncio.CancelledError:
pass # 捕获取消错误
assert mock_get.called
assert mock_get.call_count > 0
@pytest.mark.asyncio
@respx.mock
async def test_send_heartbeat_internal_failure():
global ip
ip = "60.204.193.58"
# 模拟心跳请求以引发异常
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
side_effect=httpx.RequestError("Central server error")
)
# 模拟 requests.get 以避免实际请求
with patch(
"requests.get", side_effect=httpx.RequestError("Central server error")
) as mock_get:
# 模拟 asyncio.sleep 以避免实际延迟
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
task = asyncio.create_task(send_heartbeat_internal())
await asyncio.sleep(0.1) # 允许任务运行一段时间
task.cancel() # 取消任务以停止无限循环
try:
await task # 确保任务被等待
except asyncio.CancelledError:
pass # 捕获取消错误
assert mock_get.called
assert mock_get.call_count > 0
def test_user_src():
# 模拟 ReEncrypt 函数
with patch(
"node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")
):
# 模拟 send_user_des_message 函数
with patch(
"node.send_user_des_message", new_callable=AsyncMock
) as mock_send_user_des_message:
message = {
"source_ip": "60.204.193.58",
"dest_ip": "60.204.193.59",
"capsule": (("x1", "y1"), ("x2", "y2"), 123),
"ct": 456,
"rk": ["rk1", "rk2"],
}
response = client.post("/user_src", json=message)
assert response.status_code == 200
assert response.json() == {"detail": "message received"}
mock_send_user_des_message.assert_called_once()
def test_send_user_des_message():
with respx.mock:
dest_ip = "60.204.193.59"
re_message = (("a", "b", "c", "d"), 123)
respx.post(f"http://{dest_ip}:8002/receive_messages").mock(
return_value=httpx.Response(200, json={"status": "success"})
)
response = requests.post(
f"http://{dest_ip}:8002/receive_messages",
json={"Tuple": re_message, "ip": "60.204.193.58"},
)
assert response.status_code == 200
assert response.json() == {"status": "success"}
if __name__ == "__main__":
unittest.main()

140
tests/server_test.py Normal file
View File

@@ -0,0 +1,140 @@
import sqlite3
import pytest
from fastapi.testclient import TestClient
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from server import app, validate_ip
# 创建 TestClient 实例
client = TestClient(app)
# 准备测试数据库数据
def setup_db():
# 创建数据库并插入测试数据
with sqlite3.connect("server.db") as db:
db.execute(
"""
CREATE TABLE IF NOT EXISTS nodes (
id INTEGER PRIMARY KEY,
ip TEXT NOT NULL,
last_heartbeat INTEGER NOT NULL
)
"""
)
db.execute(
"INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)"
)
db.execute(
"INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)"
)
db.commit()
# 清空数据库
def clear_db():
with sqlite3.connect("server.db") as db:
db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表
db.commit()
# 测试 IP 验证功能
def test_validate_ip():
assert validate_ip("192.168.0.1") is True
assert validate_ip("256.256.256.256") is False
assert validate_ip("::1") is True
assert validate_ip("invalid_ip") is False
# 测试首页路由
def test_home():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello, World!"}
# 测试 show_nodes 路由
def test_show_nodes():
setup_db()
response = client.get("/server/show_nodes")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0][1] == "192.168.0.1"
assert data[1][1] == "192.168.0.2"
# 测试 get_node 路由
def test_get_node():
# 确保数据库和表的存在
setup_db()
valid_ip = "192.168.0.3"
invalid_ip = "256.256.256.256"
# 测试有效的 IP 地址
response = client.get(f"/server/get_node?ip={valid_ip}")
assert response.status_code == 200
# 测试无效的 IP 地址
response = client.get(f"/server/get_node?ip={invalid_ip}")
assert response.status_code == 400
# 测试 delete_node 路由
def test_delete_node():
setup_db()
valid_ip = "192.168.0.1"
invalid_ip = "192.168.0.255"
response = client.get(f"/server/delete_node?ip={valid_ip}")
assert response.status_code == 200
assert "Node with IP 192.168.0.1 deleted successfully." in response.text
response = client.get(f"/server/delete_node?ip={invalid_ip}")
assert response.status_code == 404
# 测试 heartbeat 路由
def test_receive_heartbeat():
setup_db()
valid_ip = "192.168.0.2"
invalid_ip = "256.256.256.256"
response = client.get(f"/server/heartbeat?ip={valid_ip}")
assert response.status_code == 200
assert response.json() == {"status": "received"}
response = client.get(f"/server/heartbeat?ip={invalid_ip}")
assert response.status_code == 400
assert response.json() == {"message": "invalid ip format"}
# 测试 send_nodes_list 路由
def test_send_nodes_list():
setup_db()
response = client.get("/server/send_nodes_list?count=1")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0] == "192.168.0.1"
response = client.get("/server/send_nodes_list?count=2")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
# 运行完测试后清理数据库
@pytest.fixture(autouse=True)
def run_around_tests():
clear_db()
yield
clear_db()

72
tests/speed_test.py Normal file
View File

@@ -0,0 +1,72 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time
N = 20
T = N // 2
print(f"当前门限值: N = {N}, T = {T}")
total_time = 0
# 1
start_time = time.time()
pk_a, sk_a = GenerateKeyPair()
m = b"hello world"
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"密钥生成运行时间:{elapsed_time}")
# 2
start_time = time.time()
capsule_ct = Encrypt(pk_a, m)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"加密算法运行时间:{elapsed_time}")
# 3
pk_b, sk_b = GenerateKeyPair()
# 5
start_time = time.time()
id_tuple = tuple(range(N))
rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"重加密密钥生成算法运行时间:{elapsed_time}")
# 7
start_time = time.time()
cfrag_cts = []
for rekey in rekeys:
cfrag_ct = ReEncrypt(rekey, capsule_ct)
cfrag_cts.append(cfrag_ct)
end_time = time.time()
elapsed_time = (end_time - start_time) / len(rekeys)
total_time += elapsed_time
print(f"重加密算法运行时间:{elapsed_time}")
# 9
start_time = time.time()
cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time()
elapsed_time = end_time - start_time
total_time += elapsed_time
print(f"解密算法运行时间:{elapsed_time}")
print("成功解密:", m)
print(f"算法总运行时间:{total_time}")
print()

76
tests/test_client.py Normal file
View File

@@ -0,0 +1,76 @@
import os
import pytest
from fastapi.testclient import TestClient
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from client import app, init_db, clean_env, get_own_ip
client = TestClient(app)
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown():
# 设置测试环境
init_db()
yield
# 清理测试环境
clean_env()
def test_read_root():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello, World!"}
def test_receive_messages():
message = {"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), "ip": "127.0.0.1"}
response = client.post("/receive_messages", json=message)
assert response.status_code == 200
assert response.json().get("detail") == "Message received"
# @respx.mock
# def test_request_message():
# request_message = {
# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址
# "message_name": "name"
# }
# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"}))
# response = client.post("/request_message", json=request_message)
# assert response.status_code == 200
# assert "threshold" in response.json()
# assert "public_key" in response.json()
# @respx.mock
# def test_receive_request():
# ip_message = {
# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址
# "message_name": "name",
# "source_ip": "124.70.165.73", # 使用不同的 IP 地址
# "pk": (123, 456)
# }
# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"}))
# response = client.post("/receive_request", json=ip_message)
# assert response.status_code == 200
# assert "threshold" in response.json()
# assert "public_key" in response.json()
def test_get_pk():
response = client.get("/get_pk")
assert response.status_code == 200
assert "pkx" in response.json()
assert "pky" in response.json()
def test_recieve_pk():
pk_data = {"pkx": "123", "pky": "456", "ip": "127.0.0.1"}
response = client.post("/recieve_pk", json=pk_data)
assert response.status_code == 200
assert response.json() == {"message": "save pk in database"}
if __name__ == "__main__":
pytest.main()

158
tests/tpre_test.py Normal file
View File

@@ -0,0 +1,158 @@
import sys
import os
import hashlib
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
hash2,
hash3,
hash4,
multiply,
g,
sm2p256v1,
GenerateKeyPair,
Encrypt,
Decrypt,
GenerateReKey,
Encapsulate,
ReEncrypt,
DecryptFrags,
MergeCFrag,
)
import random
import unittest
class TestHash2(unittest.TestCase):
def setUp(self):
self.double_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash2(self.double_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash2(self.double_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash3(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash3(self.triple_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash3(self.triple_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash4(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
self.Zp = random.randint(0, sm2p256v1.N - 1)
def test_digest_type(self):
digest = hash4(self.triple_G, self.Zp)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash4(self.triple_G, self.Zp)
self.assertLess(digest, sm2p256v1.N)
class TestGenerateKeyPair(unittest.TestCase):
def test_key_pair(self):
public_key, secret_key = GenerateKeyPair()
self.assertIsInstance(public_key, tuple)
self.assertEqual(len(public_key), 2)
self.assertIsInstance(secret_key, int)
self.assertLess(secret_key, sm2p256v1.N)
self.assertGreater(secret_key, 0)
# class TestEncryptDecrypt(unittest.TestCase):
# def setUp(self):
# self.public_key, self.secret_key = GenerateKeyPair()
# self.message = b"Hello, world!"
# def test_encrypt_decrypt(self):
# encrypted_message = Encrypt(self.public_key, self.message)
# # 使用 SHA-256 哈希函数确保密钥为 16 字节
# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest()
# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # 取前 16 字节并转换为整数
# decrypted_message = Decrypt(secret_key_int, encrypted_message)
# self.assertEqual(decrypted_message, self.message)
class TestGenerateReKey(unittest.TestCase):
def test_generate_rekey(self):
sk_A = random.randint(0, sm2p256v1.N - 1)
pk_B, _ = GenerateKeyPair()
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
self.assertIsInstance(rekey, list)
self.assertEqual(len(rekey), 5)
class TestEncapsulate(unittest.TestCase):
def test_encapsulate(self):
pk_A, _ = GenerateKeyPair()
K, capsule = Encapsulate(pk_A)
self.assertIsInstance(K, int)
self.assertIsInstance(capsule, tuple)
self.assertEqual(len(capsule), 3)
class TestReEncrypt(unittest.TestCase):
def test_reencrypt(self):
sk_A = random.randint(0, sm2p256v1.N - 1)
pk_B, _ = GenerateKeyPair()
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
pk_A, _ = GenerateKeyPair()
message = b"Hello, world!"
encrypted_message = Encrypt(pk_A, message)
reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
self.assertIsInstance(reencrypted_message, tuple)
self.assertEqual(len(reencrypted_message), 2)
# class TestDecryptFrags(unittest.TestCase):
# def test_decrypt_frags(self):
# sk_A = random.randint(0, sm2p256v1.N - 1)
# pk_B, sk_B = GenerateKeyPair()
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
# pk_A, _ = GenerateKeyPair()
# message = b"Hello, world!"
# encrypted_message = Encrypt(pk_A, message)
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
# cfrags = [reencrypted_message]
# merged_cfrags = MergeCFrag(cfrags)
# self.assertIsNotNone(merged_cfrags)
# sk_B_hash = hashlib.sha256(sk_B.to_bytes((sk_B.bit_length() + 7) // 8, 'big')).digest()
# sk_B_int = int.from_bytes(sk_B_hash[:16], 'big') # 取前 16 字节并转换为整数
# decrypted_message = DecryptFrags(sk_B_int, pk_B, pk_A, merged_cfrags)
# self.assertEqual(decrypted_message, message)
if __name__ == "__main__":
unittest.main()