Improve test code files
All checks were successful
Test CI / test speed (push) Successful in 10s
All checks were successful
Test CI / test speed (push) Successful in 10s
This commit is contained in:
parent
aa0eb6bfbc
commit
061bd5d2bf
1
FQA.md
1
FQA.md
@ -59,3 +59,4 @@
|
||||
15. **Q:国密SM2 SM3 SM4** 如题
|
||||
|
||||
**A:**
|
||||
----------------------------------------------------------------------
|
||||
|
@ -1 +0,0 @@
|
||||
DEBUG:asyncio:Using selector: EpollSelector
|
@ -1,122 +0,0 @@
|
||||
import sqlite3
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from server import app, validate_ip
|
||||
|
||||
# 创建 TestClient 实例
|
||||
client = TestClient(app)
|
||||
|
||||
# 准备测试数据库数据
|
||||
def setup_db():
|
||||
# 创建数据库并插入测试数据
|
||||
with sqlite3.connect("server.db") as db:
|
||||
db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS nodes (
|
||||
id INTEGER PRIMARY KEY,
|
||||
ip TEXT NOT NULL,
|
||||
last_heartbeat INTEGER NOT NULL
|
||||
)
|
||||
""")
|
||||
db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)")
|
||||
db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)")
|
||||
db.commit()
|
||||
|
||||
# 清空数据库
|
||||
def clear_db():
|
||||
with sqlite3.connect("server.db") as db:
|
||||
db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表
|
||||
db.commit()
|
||||
|
||||
|
||||
# 测试 IP 验证功能
|
||||
def test_validate_ip():
|
||||
assert validate_ip("192.168.0.1") is True
|
||||
assert validate_ip("256.256.256.256") is False
|
||||
assert validate_ip("::1") is True
|
||||
assert validate_ip("invalid_ip") is False
|
||||
|
||||
# 测试首页路由
|
||||
def test_home():
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello, World!"}
|
||||
|
||||
# 测试 show_nodes 路由
|
||||
def test_show_nodes():
|
||||
setup_db()
|
||||
|
||||
response = client.get("/server/show_nodes")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
assert data[0][1] == "192.168.0.1"
|
||||
assert data[1][1] == "192.168.0.2"
|
||||
|
||||
# 测试 get_node 路由
|
||||
def test_get_node():
|
||||
# 确保数据库和表的存在
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.3"
|
||||
invalid_ip = "256.256.256.256"
|
||||
|
||||
# 测试有效的 IP 地址
|
||||
response = client.get(f"/server/get_node?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# 测试无效的 IP 地址
|
||||
response = client.get(f"/server/get_node?ip={invalid_ip}")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# 测试 delete_node 路由
|
||||
def test_delete_node():
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.1"
|
||||
invalid_ip = "192.168.0.255"
|
||||
|
||||
response = client.get(f"/server/delete_node?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
assert "Node with IP 192.168.0.1 deleted successfully." in response.text
|
||||
|
||||
response = client.get(f"/server/delete_node?ip={invalid_ip}")
|
||||
assert response.status_code == 404
|
||||
|
||||
# 测试 heartbeat 路由
|
||||
def test_receive_heartbeat():
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.2"
|
||||
invalid_ip = "256.256.256.256"
|
||||
|
||||
response = client.get(f"/server/heartbeat?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "received"}
|
||||
|
||||
response = client.get(f"/server/heartbeat?ip={invalid_ip}")
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"message": "invalid ip format"}
|
||||
|
||||
# 测试 send_nodes_list 路由
|
||||
def test_send_nodes_list():
|
||||
setup_db()
|
||||
|
||||
response = client.get("/server/send_nodes_list?count=1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0] == "192.168.0.1"
|
||||
|
||||
response = client.get("/server/send_nodes_list?count=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
|
||||
# 运行完测试后清理数据库
|
||||
@pytest.fixture(autouse=True)
|
||||
def run_around_tests():
|
||||
clear_db()
|
||||
yield
|
||||
clear_db()
|
@ -1,58 +0,0 @@
|
||||
from tpre import hash2, hash3, hash4, multiply, g, sm2p256v1
|
||||
import random
|
||||
import unittest
|
||||
|
||||
|
||||
class TestHash2(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.double_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash2(self.double_G)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash2(self.double_G)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestHash3(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.triple_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash3(self.triple_G)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash3(self.triple_G)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestHash4(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.triple_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
self.Zp = random.randint(0, sm2p256v1.N - 1)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash4(self.triple_G, self.Zp)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash4(self.triple_G, self.Zp)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -1,3 +1,7 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import add, multiply, sm2p256v1
|
||||
import time
|
||||
|
@ -1,3 +1,7 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import (
|
||||
GenerateKeyPair,
|
||||
Encrypt,
|
@ -1,3 +1,7 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import *
|
||||
import time
|
||||
|
@ -2,45 +2,53 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
import node
|
||||
|
||||
|
||||
class TestGetLocalIP(unittest.TestCase):
|
||||
|
||||
@patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量
|
||||
@patch.dict("os.environ", {"HOST_IP": "60.204.193.58"}) # 模拟设置 HOST_IP 环境变量
|
||||
def test_get_ip_from_env(self):
|
||||
# 调用被测函数
|
||||
node.get_local_ip()
|
||||
|
||||
# 检查函数是否正确获取到 HOST_IP
|
||||
self.assertEqual(node.ip, '60.204.193.58')
|
||||
self.assertEqual(node.ip, "60.204.193.58")
|
||||
|
||||
@patch('socket.socket') # Mock socket 连接行为
|
||||
@patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量
|
||||
@patch("socket.socket") # Mock socket 连接行为
|
||||
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
|
||||
def test_get_ip_from_socket(self, mock_socket):
|
||||
# 模拟 socket 返回的 IP 地址
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0)
|
||||
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
|
||||
|
||||
# 调用被测函数
|
||||
node.get_local_ip()
|
||||
|
||||
# 确认 socket 被调用过
|
||||
mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80))
|
||||
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
|
||||
mock_socket_instance.close.assert_called_once()
|
||||
|
||||
# 检查是否通过 socket 获取到正确的 IP 地址
|
||||
self.assertEqual(node.ip, '110.41.155.96')
|
||||
self.assertEqual(node.ip, "110.41.155.96")
|
||||
|
||||
|
||||
class TestSendIP(unittest.TestCase):
|
||||
@patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP
|
||||
@patch('requests.get') # Mock requests.get 调用
|
||||
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
|
||||
@patch("requests.get") # Mock requests.get 调用
|
||||
def test_send_ip(self, mock_get):
|
||||
# 设置模拟返回的 HTTP 响应
|
||||
mock_response = Mock()
|
||||
mock_response.text = "node123" # 模拟返回的节点ID
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response # 设置 requests.get() 的返回值为 mock_response
|
||||
mock_get.return_value = (
|
||||
mock_response # 设置 requests.get() 的返回值为 mock_response
|
||||
)
|
||||
|
||||
# 保存原始的全局 id 值
|
||||
original_id = node.id
|
||||
@ -54,13 +62,16 @@ class TestSendIP(unittest.TestCase):
|
||||
|
||||
# 检查 id 是否被正确更新
|
||||
self.assertIs(node.id, mock_response) # 检查 id 是否被修改
|
||||
self.assertEqual(node.id.text, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配
|
||||
self.assertEqual(
|
||||
node.id.text, "node123"
|
||||
) # 检查更新后的 id 是否与 mock_response.text 匹配
|
||||
|
||||
|
||||
class TestNode(unittest.TestCase):
|
||||
|
||||
@patch('node.send_ip')
|
||||
@patch('node.get_local_ip')
|
||||
@patch('node.asyncio.create_task')
|
||||
@patch("node.send_ip")
|
||||
@patch("node.get_local_ip")
|
||||
@patch("node.asyncio.create_task")
|
||||
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
|
||||
# 调用 init 函数
|
||||
node.init()
|
||||
@ -72,5 +83,6 @@ class TestNode(unittest.TestCase):
|
||||
# 确保 create_task 被调用来启动心跳包
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -1,4 +1,4 @@
|
||||
#node_test剩下部分(有问题)
|
||||
# node_test剩下部分(有问题)
|
||||
import os
|
||||
import unittest
|
||||
import pytest
|
||||
@ -8,49 +8,67 @@ import asyncio
|
||||
import httpx
|
||||
import respx
|
||||
from fastapi.testclient import TestClient
|
||||
from node import app, send_heartbeat_internal, Req, send_ip, get_local_ip, init, clear, send_user_des_message
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from node import (
|
||||
app,
|
||||
send_heartbeat_internal,
|
||||
Req,
|
||||
send_ip,
|
||||
get_local_ip,
|
||||
init,
|
||||
clear,
|
||||
send_user_des_message,
|
||||
ip,
|
||||
id,
|
||||
)
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
server_address = "http://60.204.236.38:8000/server"
|
||||
ip = None # 初始化全局变量 ip
|
||||
id = None # 初始化全局变量 id
|
||||
# ip = None # 初始化全局变量 ip
|
||||
# id = None # 初始化全局变量 id
|
||||
|
||||
|
||||
class TestGetLocalIP(unittest.TestCase):
|
||||
|
||||
@patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量
|
||||
os.environ["HOST_IP"] = "60.204.193.58" # 模拟设置 HOST_IP 环境变量
|
||||
|
||||
def test_get_ip_from_env(self):
|
||||
global ip
|
||||
# 调用被测函数
|
||||
get_local_ip()
|
||||
|
||||
# 检查函数是否正确获取到 HOST_IP
|
||||
self.assertEqual(ip, '60.204.193.58')
|
||||
self.assertEqual(ip, "60.204.193.58")
|
||||
|
||||
@patch('socket.socket') # Mock socket 连接行为
|
||||
@patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量
|
||||
@patch("socket.socket") # Mock socket 连接行为
|
||||
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
|
||||
def test_get_ip_from_socket(self, mock_socket):
|
||||
global ip
|
||||
# 模拟 socket 返回的 IP 地址
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0)
|
||||
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
|
||||
|
||||
# 调用被测函数
|
||||
get_local_ip()
|
||||
|
||||
# 确认 socket 被调用过
|
||||
mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80))
|
||||
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
|
||||
mock_socket_instance.close.assert_called_once()
|
||||
|
||||
# 检查是否通过 socket 获取到正确的 IP 地址
|
||||
self.assertEqual(ip, '110.41.155.96')
|
||||
self.assertEqual(ip, "110.41.155.96")
|
||||
|
||||
|
||||
class TestSendIP(unittest.TestCase):
|
||||
@patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP
|
||||
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
|
||||
@respx.mock
|
||||
def test_send_ip(self):
|
||||
global ip, id
|
||||
ip = '60.204.193.58'
|
||||
ip = "60.204.193.58"
|
||||
mock_url = f"{server_address}/get_node?ip={ip}"
|
||||
respx.get(mock_url).mock(return_value=httpx.Response(200, text="node123"))
|
||||
|
||||
@ -58,13 +76,16 @@ class TestSendIP(unittest.TestCase):
|
||||
send_ip()
|
||||
|
||||
# 确保 requests.get 被正确调用
|
||||
self.assertEqual(id, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配
|
||||
self.assertEqual(
|
||||
id, "node123"
|
||||
) # 检查更新后的 id 是否与 mock_response.text 匹配
|
||||
|
||||
|
||||
class TestNode(unittest.TestCase):
|
||||
|
||||
@patch('node.send_ip')
|
||||
@patch('node.get_local_ip')
|
||||
@patch('node.asyncio.create_task')
|
||||
@patch("node.send_ip")
|
||||
@patch("node.get_local_ip")
|
||||
@patch("node.asyncio.create_task")
|
||||
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
|
||||
# 调用 init 函数
|
||||
init()
|
||||
@ -82,11 +103,12 @@ class TestNode(unittest.TestCase):
|
||||
# 检查输出
|
||||
self.assertTrue(True) # 这里只是为了确保函数被调用,没有实际逻辑需要测试
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_send_heartbeat_internal_success():
|
||||
global ip
|
||||
ip = '60.204.193.58'
|
||||
ip = "60.204.193.58"
|
||||
# 模拟心跳请求
|
||||
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
|
||||
return_value=httpx.Response(200)
|
||||
@ -107,18 +129,21 @@ async def test_send_heartbeat_internal_success():
|
||||
assert mock_get.called
|
||||
assert mock_get.call_count > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_send_heartbeat_internal_failure():
|
||||
global ip
|
||||
ip = '60.204.193.58'
|
||||
ip = "60.204.193.58"
|
||||
# 模拟心跳请求以引发异常
|
||||
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
|
||||
side_effect=httpx.RequestError("Central server error")
|
||||
)
|
||||
|
||||
# 模拟 requests.get 以避免实际请求
|
||||
with patch("requests.get", side_effect=httpx.RequestError("Central server error")) as mock_get:
|
||||
with patch(
|
||||
"requests.get", side_effect=httpx.RequestError("Central server error")
|
||||
) as mock_get:
|
||||
# 模拟 asyncio.sleep 以避免实际延迟
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
task = asyncio.create_task(send_heartbeat_internal())
|
||||
@ -132,23 +157,29 @@ async def test_send_heartbeat_internal_failure():
|
||||
assert mock_get.called
|
||||
assert mock_get.call_count > 0
|
||||
|
||||
|
||||
def test_user_src():
|
||||
# 模拟 ReEncrypt 函数
|
||||
with patch("node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")):
|
||||
with patch(
|
||||
"node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")
|
||||
):
|
||||
# 模拟 send_user_des_message 函数
|
||||
with patch("node.send_user_des_message", new_callable=AsyncMock) as mock_send_user_des_message:
|
||||
with patch(
|
||||
"node.send_user_des_message", new_callable=AsyncMock
|
||||
) as mock_send_user_des_message:
|
||||
message = {
|
||||
"source_ip": "60.204.193.58",
|
||||
"dest_ip": "60.204.193.59",
|
||||
"capsule": (("x1", "y1"), ("x2", "y2"), 123),
|
||||
"ct": 456,
|
||||
"rk": ["rk1", "rk2"]
|
||||
"rk": ["rk1", "rk2"],
|
||||
}
|
||||
response = client.post("/user_src", json=message)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"detail": "message received"}
|
||||
mock_send_user_des_message.assert_called_once()
|
||||
|
||||
|
||||
def test_send_user_des_message():
|
||||
with respx.mock:
|
||||
dest_ip = "60.204.193.59"
|
||||
@ -156,9 +187,13 @@ def test_send_user_des_message():
|
||||
respx.post(f"http://{dest_ip}:8002/receive_messages").mock(
|
||||
return_value=httpx.Response(200, json={"status": "success"})
|
||||
)
|
||||
response = requests.post(f"http://{dest_ip}:8002/receive_messages", json={"Tuple": re_message, "ip": "60.204.193.58"})
|
||||
response = requests.post(
|
||||
f"http://{dest_ip}:8002/receive_messages",
|
||||
json={"Tuple": re_message, "ip": "60.204.193.58"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "success"}
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -1,11 +1,10 @@
|
||||
import sqlite3
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
|
||||
import sqlite3
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from server import app, validate_ip
|
||||
|
||||
# 创建 TestClient 实例
|
@ -1,3 +1,7 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import (
|
||||
GenerateKeyPair,
|
||||
Encrypt,
|
@ -1,13 +1,14 @@
|
||||
import os
|
||||
import pytest
|
||||
import sqlite3
|
||||
import respx
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from client import app, init_db, clean_env, get_own_ip
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_and_teardown():
|
||||
# 设置测试环境
|
||||
@ -16,20 +17,20 @@ def setup_and_teardown():
|
||||
# 清理测试环境
|
||||
clean_env()
|
||||
|
||||
|
||||
def test_read_root():
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello, World!"}
|
||||
|
||||
|
||||
def test_receive_messages():
|
||||
message = {
|
||||
"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8),
|
||||
"ip": "127.0.0.1"
|
||||
}
|
||||
message = {"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), "ip": "127.0.0.1"}
|
||||
response = client.post("/receive_messages", json=message)
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("detail") == "Message received"
|
||||
|
||||
|
||||
# @respx.mock
|
||||
# def test_request_message():
|
||||
# request_message = {
|
||||
@ -56,21 +57,20 @@ def test_receive_messages():
|
||||
# assert "threshold" in response.json()
|
||||
# assert "public_key" in response.json()
|
||||
|
||||
|
||||
def test_get_pk():
|
||||
response = client.get("/get_pk")
|
||||
assert response.status_code == 200
|
||||
assert "pkx" in response.json()
|
||||
assert "pky" in response.json()
|
||||
|
||||
|
||||
def test_recieve_pk():
|
||||
pk_data = {
|
||||
"pkx": "123",
|
||||
"pky": "456",
|
||||
"ip": "127.0.0.1"
|
||||
}
|
||||
pk_data = {"pkx": "123", "pky": "456", "ip": "127.0.0.1"}
|
||||
response = client.post("/recieve_pk", json=pk_data)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "save pk in database"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
@ -1,9 +1,24 @@
|
||||
import sys
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
|
||||
from tpre import (
|
||||
hash2, hash3, hash4, multiply, g, sm2p256v1,
|
||||
GenerateKeyPair, Encrypt, Decrypt, GenerateReKey,
|
||||
Encapsulate, ReEncrypt, DecryptFrags
|
||||
hash2,
|
||||
hash3,
|
||||
hash4,
|
||||
multiply,
|
||||
g,
|
||||
sm2p256v1,
|
||||
GenerateKeyPair,
|
||||
Encrypt,
|
||||
Decrypt,
|
||||
GenerateReKey,
|
||||
Encapsulate,
|
||||
ReEncrypt,
|
||||
DecryptFrags,
|
||||
MergeCFrag,
|
||||
)
|
||||
from tpre import MergeCFrag
|
||||
import random
|
||||
import unittest
|
||||
|
||||
@ -59,12 +74,14 @@ class TestHash4(unittest.TestCase):
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
# class TestGenerateKeyPair(unittest.TestCase):
|
||||
# def test_key_pair(self):
|
||||
# public_key, secret_key = GenerateKeyPair()
|
||||
# self.assertIsInstance(public_key, tuple)
|
||||
# self.assertIsInstance(secret_key, int)
|
||||
# self.assertEqual(len(public_key), 2)
|
||||
class TestGenerateKeyPair(unittest.TestCase):
|
||||
def test_key_pair(self):
|
||||
public_key, secret_key = GenerateKeyPair()
|
||||
self.assertIsInstance(public_key, tuple)
|
||||
self.assertEqual(len(public_key), 2)
|
||||
self.assertIsInstance(secret_key, int)
|
||||
self.assertLess(secret_key, sm2p256v1.N)
|
||||
self.assertGreater(secret_key, 0)
|
||||
|
||||
|
||||
# class TestEncryptDecrypt(unittest.TestCase):
|
||||
@ -74,18 +91,22 @@ class TestHash4(unittest.TestCase):
|
||||
|
||||
# def test_encrypt_decrypt(self):
|
||||
# encrypted_message = Encrypt(self.public_key, self.message)
|
||||
# decrypted_message = Decrypt(self.secret_key, encrypted_message)
|
||||
# # 使用 SHA-256 哈希函数确保密钥为 16 字节
|
||||
# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest()
|
||||
# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # 取前 16 字节并转换为整数
|
||||
|
||||
# decrypted_message = Decrypt(secret_key_int, encrypted_message)
|
||||
# self.assertEqual(decrypted_message, self.message)
|
||||
|
||||
|
||||
# class TestGenerateReKey(unittest.TestCase):
|
||||
# def test_generate_rekey(self):
|
||||
# sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
# pk_B, _ = GenerateKeyPair()
|
||||
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
# self.assertIsInstance(rekey, list)
|
||||
# self.assertEqual(len(rekey), 5)
|
||||
class TestGenerateReKey(unittest.TestCase):
|
||||
def test_generate_rekey(self):
|
||||
sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
pk_B, _ = GenerateKeyPair()
|
||||
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
self.assertIsInstance(rekey, list)
|
||||
self.assertEqual(len(rekey), 5)
|
||||
|
||||
|
||||
class TestEncapsulate(unittest.TestCase):
|
||||
@ -97,18 +118,18 @@ class TestEncapsulate(unittest.TestCase):
|
||||
self.assertEqual(len(capsule), 3)
|
||||
|
||||
|
||||
# class TestReEncrypt(unittest.TestCase):
|
||||
# def test_reencrypt(self):
|
||||
# sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
# pk_B, _ = GenerateKeyPair()
|
||||
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
# pk_A, _ = GenerateKeyPair()
|
||||
# message = b"Hello, world!"
|
||||
# encrypted_message = Encrypt(pk_A, message)
|
||||
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
|
||||
# self.assertIsInstance(reencrypted_message, tuple)
|
||||
# self.assertEqual(len(reencrypted_message), 2)
|
||||
class TestReEncrypt(unittest.TestCase):
|
||||
def test_reencrypt(self):
|
||||
sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
pk_B, _ = GenerateKeyPair()
|
||||
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
pk_A, _ = GenerateKeyPair()
|
||||
message = b"Hello, world!"
|
||||
encrypted_message = Encrypt(pk_A, message)
|
||||
reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
|
||||
self.assertIsInstance(reencrypted_message, tuple)
|
||||
self.assertEqual(len(reencrypted_message), 2)
|
||||
|
||||
|
||||
# class TestDecryptFrags(unittest.TestCase):
|
||||
@ -123,9 +144,15 @@ class TestEncapsulate(unittest.TestCase):
|
||||
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
|
||||
# cfrags = [reencrypted_message]
|
||||
# merged_cfrags = MergeCFrag(cfrags)
|
||||
# decrypted_message = DecryptFrags(sk_B, pk_B, pk_A, merged_cfrags)
|
||||
|
||||
# self.assertIsNotNone(merged_cfrags)
|
||||
|
||||
# sk_B_hash = hashlib.sha256(sk_B.to_bytes((sk_B.bit_length() + 7) // 8, 'big')).digest()
|
||||
# sk_B_int = int.from_bytes(sk_B_hash[:16], 'big') # 取前 16 字节并转换为整数
|
||||
|
||||
# decrypted_message = DecryptFrags(sk_B_int, pk_B, pk_A, merged_cfrags)
|
||||
# self.assertEqual(decrypted_message, message)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# unittest.main()
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user