From aa0eb6bfbc9d263c544c55d1d1f3442105e4bc2e Mon Sep 17 00:00:00 2001 From: muzhi <2624758301@qq.com> Date: Wed, 2 Oct 2024 10:16:31 +0800 Subject: [PATCH 1/5] Remove unnecessary files and configurations --- server.db | Bin 0 -> 8192 bytes src/server_test.py | 156 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 135 insertions(+), 21 deletions(-) create mode 100644 server.db diff --git a/server.db b/server.db new file mode 100644 index 0000000000000000000000000000000000000000..7fc2400b9149e642f0466c9e39a3ebf434b110b2 GIT binary patch literal 8192 zcmeIuu?c`c429u$VkL-j6WjzgF5mz*Vq+_!U?GCsP>$&dT3@aFfkz-o5x%}_S3kRD zKeo}kTBVFpK?RA$HCl0R#|0009ILKmY**5I_Kd f&k3BL` Date: Fri, 4 Oct 2024 22:28:14 +0800 Subject: [PATCH 2/5] Improve test code files --- FQA.md | 1 + server.db | Bin 8192 -> 8192 bytes src/client_test.py | 0 src/node_test.log | 1 - src/server_test1.py | 122 ------------------------ src/tpre_test.py | 58 ----------- {src => tests}/ecc_speed_test.py | 6 +- {src => tests}/lenth_test.py | 4 + {src => tests}/maxnode_test.py | 4 + {src => tests}/node_test.py | 52 ++++++---- {src => tests}/node_test5.py | 93 ++++++++++++------ {src => tests}/server_test.py | 7 +- {src => tests}/speed_test.py | 4 + {src => tests}/test_client.py | 26 ++--- src/tpre_test1.py => tests/tpre_test.py | 95 +++++++++++------- 15 files changed, 191 insertions(+), 282 deletions(-) delete mode 100644 src/client_test.py delete mode 100644 src/node_test.log delete mode 100644 src/server_test1.py delete mode 100644 src/tpre_test.py rename {src => tests}/ecc_speed_test.py (87%) rename {src => tests}/lenth_test.py (94%) rename {src => tests}/maxnode_test.py (94%) rename {src => tests}/node_test.py (63%) rename {src => tests}/node_test5.py (71%) rename {src => tests}/server_test.py (99%) rename {src => tests}/speed_test.py (93%) rename {src => tests}/test_client.py (89%) rename src/tpre_test1.py => tests/tpre_test.py (54%) diff --git a/FQA.md b/FQA.md index 54e36c2..0ec47be 100644 --- a/FQA.md +++ b/FQA.md @@ -59,3 +59,4 @@ 15. **Q:国密SM2 SM3 SM4** 如题 **A:** +---------------------------------------------------------------------- diff --git a/server.db b/server.db index 7fc2400b9149e642f0466c9e39a3ebf434b110b2..fd9eb9d46f2fd5fb19e4695ad8545bb6e9a913ed 100644 GIT binary patch delta 33 hcmZp0XmFSyEvUl4z`z8>j6hmsqK+}6%Ep8R@&H#^1z!LF delta 33 hcmZp0XmFSyEhxmmz`z8>j6j-eqK+}6(8h!X@&HuF1u*~s diff --git a/src/client_test.py b/src/client_test.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/node_test.log b/src/node_test.log deleted file mode 100644 index fad9afe..0000000 --- a/src/node_test.log +++ /dev/null @@ -1 +0,0 @@ -DEBUG:asyncio:Using selector: EpollSelector diff --git a/src/server_test1.py b/src/server_test1.py deleted file mode 100644 index a78a93f..0000000 --- a/src/server_test1.py +++ /dev/null @@ -1,122 +0,0 @@ -import sqlite3 -import pytest -from fastapi.testclient import TestClient -from server import app, validate_ip - -# 创建 TestClient 实例 -client = TestClient(app) - -# 准备测试数据库数据 -def setup_db(): - # 创建数据库并插入测试数据 - with sqlite3.connect("server.db") as db: - db.execute(""" - CREATE TABLE IF NOT EXISTS nodes ( - id INTEGER PRIMARY KEY, - ip TEXT NOT NULL, - last_heartbeat INTEGER NOT NULL - ) - """) - db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)") - db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)") - db.commit() - -# 清空数据库 -def clear_db(): - with sqlite3.connect("server.db") as db: - db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表 - db.commit() - - -# 测试 IP 验证功能 -def test_validate_ip(): - assert validate_ip("192.168.0.1") is True - assert validate_ip("256.256.256.256") is False - assert validate_ip("::1") is True - assert validate_ip("invalid_ip") is False - -# 测试首页路由 -def test_home(): - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"message": "Hello, World!"} - -# 测试 show_nodes 路由 -def test_show_nodes(): - setup_db() - - response = client.get("/server/show_nodes") - assert response.status_code == 200 - - data = response.json() - assert len(data) == 2 - assert data[0][1] == "192.168.0.1" - assert data[1][1] == "192.168.0.2" - -# 测试 get_node 路由 -def test_get_node(): - # 确保数据库和表的存在 - setup_db() - - valid_ip = "192.168.0.3" - invalid_ip = "256.256.256.256" - - # 测试有效的 IP 地址 - response = client.get(f"/server/get_node?ip={valid_ip}") - assert response.status_code == 200 - - # 测试无效的 IP 地址 - response = client.get(f"/server/get_node?ip={invalid_ip}") - assert response.status_code == 400 - - -# 测试 delete_node 路由 -def test_delete_node(): - setup_db() - - valid_ip = "192.168.0.1" - invalid_ip = "192.168.0.255" - - response = client.get(f"/server/delete_node?ip={valid_ip}") - assert response.status_code == 200 - assert "Node with IP 192.168.0.1 deleted successfully." in response.text - - response = client.get(f"/server/delete_node?ip={invalid_ip}") - assert response.status_code == 404 - -# 测试 heartbeat 路由 -def test_receive_heartbeat(): - setup_db() - - valid_ip = "192.168.0.2" - invalid_ip = "256.256.256.256" - - response = client.get(f"/server/heartbeat?ip={valid_ip}") - assert response.status_code == 200 - assert response.json() == {"status": "received"} - - response = client.get(f"/server/heartbeat?ip={invalid_ip}") - assert response.status_code == 400 - assert response.json() == {"message": "invalid ip format"} - -# 测试 send_nodes_list 路由 -def test_send_nodes_list(): - setup_db() - - response = client.get("/server/send_nodes_list?count=1") - assert response.status_code == 200 - data = response.json() - assert len(data) == 1 - assert data[0] == "192.168.0.1" - - response = client.get("/server/send_nodes_list?count=2") - assert response.status_code == 200 - data = response.json() - assert len(data) == 2 - -# 运行完测试后清理数据库 -@pytest.fixture(autouse=True) -def run_around_tests(): - clear_db() - yield - clear_db() diff --git a/src/tpre_test.py b/src/tpre_test.py deleted file mode 100644 index 954e575..0000000 --- a/src/tpre_test.py +++ /dev/null @@ -1,58 +0,0 @@ -from tpre import hash2, hash3, hash4, multiply, g, sm2p256v1 -import random -import unittest - - -class TestHash2(unittest.TestCase): - def setUp(self): - self.double_G = ( - multiply(g, random.randint(0, sm2p256v1.N - 1)), - multiply(g, random.randint(0, sm2p256v1.N - 1)), - ) - - def test_digest_type(self): - digest = hash2(self.double_G) - self.assertEqual(type(digest), int) - - def test_digest_size(self): - digest = hash2(self.double_G) - self.assertLess(digest, sm2p256v1.N) - - -class TestHash3(unittest.TestCase): - def setUp(self): - self.triple_G = ( - multiply(g, random.randint(0, sm2p256v1.N - 1)), - multiply(g, random.randint(0, sm2p256v1.N - 1)), - multiply(g, random.randint(0, sm2p256v1.N - 1)), - ) - - def test_digest_type(self): - digest = hash3(self.triple_G) - self.assertEqual(type(digest), int) - - def test_digest_size(self): - digest = hash3(self.triple_G) - self.assertLess(digest, sm2p256v1.N) - - -class TestHash4(unittest.TestCase): - def setUp(self): - self.triple_G = ( - multiply(g, random.randint(0, sm2p256v1.N - 1)), - multiply(g, random.randint(0, sm2p256v1.N - 1)), - multiply(g, random.randint(0, sm2p256v1.N - 1)), - ) - self.Zp = random.randint(0, sm2p256v1.N - 1) - - def test_digest_type(self): - digest = hash4(self.triple_G, self.Zp) - self.assertEqual(type(digest), int) - - def test_digest_size(self): - digest = hash4(self.triple_G, self.Zp) - self.assertLess(digest, sm2p256v1.N) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/ecc_speed_test.py b/tests/ecc_speed_test.py similarity index 87% rename from src/ecc_speed_test.py rename to tests/ecc_speed_test.py index 03d8187..cd8f369 100644 --- a/src/ecc_speed_test.py +++ b/tests/ecc_speed_test.py @@ -1,3 +1,7 @@ +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from tpre import add, multiply, sm2p256v1 import time @@ -30,4 +34,4 @@ for i in range(10): result = add(g, g, 0) # 执行函数 end_time = time.time() # 获取结束时间 elapsed_time = end_time - start_time # 计算执行时间 -print(f"python add 执行时间: {elapsed_time:.6f} 秒") +print(f"python add 执行时间: {elapsed_time:.6f} 秒") \ No newline at end of file diff --git a/src/lenth_test.py b/tests/lenth_test.py similarity index 94% rename from src/lenth_test.py rename to tests/lenth_test.py index 8e48481..c27d23d 100644 --- a/src/lenth_test.py +++ b/tests/lenth_test.py @@ -1,3 +1,7 @@ +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from tpre import ( GenerateKeyPair, Encrypt, diff --git a/src/maxnode_test.py b/tests/maxnode_test.py similarity index 94% rename from src/maxnode_test.py rename to tests/maxnode_test.py index ee87c2b..b0fb346 100644 --- a/src/maxnode_test.py +++ b/tests/maxnode_test.py @@ -1,3 +1,7 @@ +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from tpre import * import time diff --git a/src/node_test.py b/tests/node_test.py similarity index 63% rename from src/node_test.py rename to tests/node_test.py index 2b5cdaf..6be7bf6 100644 --- a/src/node_test.py +++ b/tests/node_test.py @@ -2,45 +2,53 @@ import os import unittest from unittest.mock import patch, MagicMock, Mock +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) import node + class TestGetLocalIP(unittest.TestCase): - @patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量 + @patch.dict("os.environ", {"HOST_IP": "60.204.193.58"}) # 模拟设置 HOST_IP 环境变量 def test_get_ip_from_env(self): # 调用被测函数 node.get_local_ip() - - # 检查函数是否正确获取到 HOST_IP - self.assertEqual(node.ip, '60.204.193.58') - @patch('socket.socket') # Mock socket 连接行为 - @patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量 + # 检查函数是否正确获取到 HOST_IP + self.assertEqual(node.ip, "60.204.193.58") + + @patch("socket.socket") # Mock socket 连接行为 + @patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量 def test_get_ip_from_socket(self, mock_socket): # 模拟 socket 返回的 IP 地址 mock_socket_instance = MagicMock() mock_socket.return_value = mock_socket_instance - mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0) + mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0) # 调用被测函数 node.get_local_ip() # 确认 socket 被调用过 - mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80)) + mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80)) mock_socket_instance.close.assert_called_once() # 检查是否通过 socket 获取到正确的 IP 地址 - self.assertEqual(node.ip, '110.41.155.96') + self.assertEqual(node.ip, "110.41.155.96") + class TestSendIP(unittest.TestCase): - @patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP - @patch('requests.get') # Mock requests.get 调用 + @patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP + @patch("requests.get") # Mock requests.get 调用 def test_send_ip(self, mock_get): # 设置模拟返回的 HTTP 响应 mock_response = Mock() mock_response.text = "node123" # 模拟返回的节点ID mock_response.status_code = 200 - mock_get.return_value = mock_response # 设置 requests.get() 的返回值为 mock_response + mock_get.return_value = ( + mock_response # 设置 requests.get() 的返回值为 mock_response + ) # 保存原始的全局 id 值 original_id = node.id @@ -54,17 +62,20 @@ class TestSendIP(unittest.TestCase): # 检查 id 是否被正确更新 self.assertIs(node.id, mock_response) # 检查 id 是否被修改 - self.assertEqual(node.id.text, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配 + self.assertEqual( + node.id.text, "node123" + ) # 检查更新后的 id 是否与 mock_response.text 匹配 + class TestNode(unittest.TestCase): - - @patch('node.send_ip') - @patch('node.get_local_ip') - @patch('node.asyncio.create_task') + + @patch("node.send_ip") + @patch("node.get_local_ip") + @patch("node.asyncio.create_task") def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip): # 调用 init 函数 node.init() - + # 验证 get_local_ip 和 send_ip 被调用 mock_get_local_ip.assert_called_once() mock_send_ip.assert_called_once() @@ -72,5 +83,6 @@ class TestNode(unittest.TestCase): # 确保 create_task 被调用来启动心跳包 mock_create_task.assert_called_once() -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/src/node_test5.py b/tests/node_test5.py similarity index 71% rename from src/node_test5.py rename to tests/node_test5.py index 124f970..12aae1b 100644 --- a/src/node_test5.py +++ b/tests/node_test5.py @@ -1,4 +1,4 @@ -#node_test剩下部分(有问题) +# node_test剩下部分(有问题) import os import unittest import pytest @@ -8,49 +8,67 @@ import asyncio import httpx import respx from fastapi.testclient import TestClient -from node import app, send_heartbeat_internal, Req, send_ip, get_local_ip, init, clear, send_user_des_message +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) +from node import ( + app, + send_heartbeat_internal, + Req, + send_ip, + get_local_ip, + init, + clear, + send_user_des_message, + ip, + id, +) + client = TestClient(app) server_address = "http://60.204.236.38:8000/server" -ip = None # 初始化全局变量 ip -id = None # 初始化全局变量 id +# ip = None # 初始化全局变量 ip +# id = None # 初始化全局变量 id + class TestGetLocalIP(unittest.TestCase): - @patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量 + os.environ["HOST_IP"] = "60.204.193.58" # 模拟设置 HOST_IP 环境变量 + def test_get_ip_from_env(self): global ip # 调用被测函数 get_local_ip() - # 检查函数是否正确获取到 HOST_IP - self.assertEqual(ip, '60.204.193.58') + self.assertEqual(ip, "60.204.193.58") - @patch('socket.socket') # Mock socket 连接行为 - @patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量 + @patch("socket.socket") # Mock socket 连接行为 + @patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量 def test_get_ip_from_socket(self, mock_socket): global ip # 模拟 socket 返回的 IP 地址 mock_socket_instance = MagicMock() mock_socket.return_value = mock_socket_instance - mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0) + mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0) # 调用被测函数 get_local_ip() # 确认 socket 被调用过 - mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80)) + mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80)) mock_socket_instance.close.assert_called_once() # 检查是否通过 socket 获取到正确的 IP 地址 - self.assertEqual(ip, '110.41.155.96') + self.assertEqual(ip, "110.41.155.96") + class TestSendIP(unittest.TestCase): - @patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP + @patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP @respx.mock def test_send_ip(self): global ip, id - ip = '60.204.193.58' + ip = "60.204.193.58" mock_url = f"{server_address}/get_node?ip={ip}" respx.get(mock_url).mock(return_value=httpx.Response(200, text="node123")) @@ -58,17 +76,20 @@ class TestSendIP(unittest.TestCase): send_ip() # 确保 requests.get 被正确调用 - self.assertEqual(id, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配 + self.assertEqual( + id, "node123" + ) # 检查更新后的 id 是否与 mock_response.text 匹配 + class TestNode(unittest.TestCase): - - @patch('node.send_ip') - @patch('node.get_local_ip') - @patch('node.asyncio.create_task') + + @patch("node.send_ip") + @patch("node.get_local_ip") + @patch("node.asyncio.create_task") def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip): # 调用 init 函数 init() - + # 验证 get_local_ip 和 send_ip 被调用 mock_get_local_ip.assert_called_once() mock_send_ip.assert_called_once() @@ -82,11 +103,12 @@ class TestNode(unittest.TestCase): # 检查输出 self.assertTrue(True) # 这里只是为了确保函数被调用,没有实际逻辑需要测试 + @pytest.mark.asyncio @respx.mock async def test_send_heartbeat_internal_success(): global ip - ip = '60.204.193.58' + ip = "60.204.193.58" # 模拟心跳请求 heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock( return_value=httpx.Response(200) @@ -107,18 +129,21 @@ async def test_send_heartbeat_internal_success(): assert mock_get.called assert mock_get.call_count > 0 + @pytest.mark.asyncio @respx.mock async def test_send_heartbeat_internal_failure(): global ip - ip = '60.204.193.58' + ip = "60.204.193.58" # 模拟心跳请求以引发异常 heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock( side_effect=httpx.RequestError("Central server error") ) # 模拟 requests.get 以避免实际请求 - with patch("requests.get", side_effect=httpx.RequestError("Central server error")) as mock_get: + with patch( + "requests.get", side_effect=httpx.RequestError("Central server error") + ) as mock_get: # 模拟 asyncio.sleep 以避免实际延迟 with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: task = asyncio.create_task(send_heartbeat_internal()) @@ -132,23 +157,29 @@ async def test_send_heartbeat_internal_failure(): assert mock_get.called assert mock_get.call_count > 0 + def test_user_src(): # 模拟 ReEncrypt 函数 - with patch("node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")): + with patch( + "node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data") + ): # 模拟 send_user_des_message 函数 - with patch("node.send_user_des_message", new_callable=AsyncMock) as mock_send_user_des_message: + with patch( + "node.send_user_des_message", new_callable=AsyncMock + ) as mock_send_user_des_message: message = { "source_ip": "60.204.193.58", "dest_ip": "60.204.193.59", "capsule": (("x1", "y1"), ("x2", "y2"), 123), "ct": 456, - "rk": ["rk1", "rk2"] + "rk": ["rk1", "rk2"], } response = client.post("/user_src", json=message) assert response.status_code == 200 assert response.json() == {"detail": "message received"} mock_send_user_des_message.assert_called_once() + def test_send_user_des_message(): with respx.mock: dest_ip = "60.204.193.59" @@ -156,9 +187,13 @@ def test_send_user_des_message(): respx.post(f"http://{dest_ip}:8002/receive_messages").mock( return_value=httpx.Response(200, json={"status": "success"}) ) - response = requests.post(f"http://{dest_ip}:8002/receive_messages", json={"Tuple": re_message, "ip": "60.204.193.58"}) + response = requests.post( + f"http://{dest_ip}:8002/receive_messages", + json={"Tuple": re_message, "ip": "60.204.193.58"}, + ) assert response.status_code == 200 assert response.json() == {"status": "success"} -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/src/server_test.py b/tests/server_test.py similarity index 99% rename from src/server_test.py rename to tests/server_test.py index 9aab8e2..85a1b03 100644 --- a/src/server_test.py +++ b/tests/server_test.py @@ -1,11 +1,10 @@ +import sqlite3 +import pytest +from fastapi.testclient import TestClient import sys import os sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) - -import sqlite3 -import pytest -from fastapi.testclient import TestClient from server import app, validate_ip # 创建 TestClient 实例 diff --git a/src/speed_test.py b/tests/speed_test.py similarity index 93% rename from src/speed_test.py rename to tests/speed_test.py index d322f75..cbb8eca 100644 --- a/src/speed_test.py +++ b/tests/speed_test.py @@ -1,3 +1,7 @@ +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from tpre import ( GenerateKeyPair, Encrypt, diff --git a/src/test_client.py b/tests/test_client.py similarity index 89% rename from src/test_client.py rename to tests/test_client.py index 9cb2341..2e08067 100644 --- a/src/test_client.py +++ b/tests/test_client.py @@ -1,13 +1,14 @@ import os import pytest -import sqlite3 -import respx -import httpx from fastapi.testclient import TestClient +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from client import app, init_db, clean_env, get_own_ip client = TestClient(app) + @pytest.fixture(scope="module", autouse=True) def setup_and_teardown(): # 设置测试环境 @@ -16,20 +17,20 @@ def setup_and_teardown(): # 清理测试环境 clean_env() + def test_read_root(): response = client.get("/") assert response.status_code == 200 assert response.json() == {"message": "Hello, World!"} + def test_receive_messages(): - message = { - "Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), - "ip": "127.0.0.1" - } + message = {"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), "ip": "127.0.0.1"} response = client.post("/receive_messages", json=message) assert response.status_code == 200 assert response.json().get("detail") == "Message received" + # @respx.mock # def test_request_message(): # request_message = { @@ -56,21 +57,20 @@ def test_receive_messages(): # assert "threshold" in response.json() # assert "public_key" in response.json() + def test_get_pk(): response = client.get("/get_pk") assert response.status_code == 200 assert "pkx" in response.json() assert "pky" in response.json() + def test_recieve_pk(): - pk_data = { - "pkx": "123", - "pky": "456", - "ip": "127.0.0.1" - } + pk_data = {"pkx": "123", "pky": "456", "ip": "127.0.0.1"} response = client.post("/recieve_pk", json=pk_data) assert response.status_code == 200 assert response.json() == {"message": "save pk in database"} + if __name__ == "__main__": - pytest.main() \ No newline at end of file + pytest.main() diff --git a/src/tpre_test1.py b/tests/tpre_test.py similarity index 54% rename from src/tpre_test1.py rename to tests/tpre_test.py index fce532a..8a90746 100644 --- a/src/tpre_test1.py +++ b/tests/tpre_test.py @@ -1,9 +1,24 @@ +import sys +import os +import hashlib + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from tpre import ( - hash2, hash3, hash4, multiply, g, sm2p256v1, - GenerateKeyPair, Encrypt, Decrypt, GenerateReKey, - Encapsulate, ReEncrypt, DecryptFrags + hash2, + hash3, + hash4, + multiply, + g, + sm2p256v1, + GenerateKeyPair, + Encrypt, + Decrypt, + GenerateReKey, + Encapsulate, + ReEncrypt, + DecryptFrags, + MergeCFrag, ) -from tpre import MergeCFrag import random import unittest @@ -59,12 +74,14 @@ class TestHash4(unittest.TestCase): self.assertLess(digest, sm2p256v1.N) -# class TestGenerateKeyPair(unittest.TestCase): -# def test_key_pair(self): -# public_key, secret_key = GenerateKeyPair() -# self.assertIsInstance(public_key, tuple) -# self.assertIsInstance(secret_key, int) -# self.assertEqual(len(public_key), 2) +class TestGenerateKeyPair(unittest.TestCase): + def test_key_pair(self): + public_key, secret_key = GenerateKeyPair() + self.assertIsInstance(public_key, tuple) + self.assertEqual(len(public_key), 2) + self.assertIsInstance(secret_key, int) + self.assertLess(secret_key, sm2p256v1.N) + self.assertGreater(secret_key, 0) # class TestEncryptDecrypt(unittest.TestCase): @@ -74,18 +91,22 @@ class TestHash4(unittest.TestCase): # def test_encrypt_decrypt(self): # encrypted_message = Encrypt(self.public_key, self.message) -# decrypted_message = Decrypt(self.secret_key, encrypted_message) +# # 使用 SHA-256 哈希函数确保密钥为 16 字节 +# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest() +# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # 取前 16 字节并转换为整数 + +# decrypted_message = Decrypt(secret_key_int, encrypted_message) # self.assertEqual(decrypted_message, self.message) -# class TestGenerateReKey(unittest.TestCase): -# def test_generate_rekey(self): -# sk_A = random.randint(0, sm2p256v1.N - 1) -# pk_B, _ = GenerateKeyPair() -# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) -# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) -# self.assertIsInstance(rekey, list) -# self.assertEqual(len(rekey), 5) +class TestGenerateReKey(unittest.TestCase): + def test_generate_rekey(self): + sk_A = random.randint(0, sm2p256v1.N - 1) + pk_B, _ = GenerateKeyPair() + id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) + rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) + self.assertIsInstance(rekey, list) + self.assertEqual(len(rekey), 5) class TestEncapsulate(unittest.TestCase): @@ -97,18 +118,18 @@ class TestEncapsulate(unittest.TestCase): self.assertEqual(len(capsule), 3) -# class TestReEncrypt(unittest.TestCase): -# def test_reencrypt(self): -# sk_A = random.randint(0, sm2p256v1.N - 1) -# pk_B, _ = GenerateKeyPair() -# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) -# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) -# pk_A, _ = GenerateKeyPair() -# message = b"Hello, world!" -# encrypted_message = Encrypt(pk_A, message) -# reencrypted_message = ReEncrypt(rekey[0], encrypted_message) -# self.assertIsInstance(reencrypted_message, tuple) -# self.assertEqual(len(reencrypted_message), 2) +class TestReEncrypt(unittest.TestCase): + def test_reencrypt(self): + sk_A = random.randint(0, sm2p256v1.N - 1) + pk_B, _ = GenerateKeyPair() + id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) + rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) + pk_A, _ = GenerateKeyPair() + message = b"Hello, world!" + encrypted_message = Encrypt(pk_A, message) + reencrypted_message = ReEncrypt(rekey[0], encrypted_message) + self.assertIsInstance(reencrypted_message, tuple) + self.assertEqual(len(reencrypted_message), 2) # class TestDecryptFrags(unittest.TestCase): @@ -123,9 +144,15 @@ class TestEncapsulate(unittest.TestCase): # reencrypted_message = ReEncrypt(rekey[0], encrypted_message) # cfrags = [reencrypted_message] # merged_cfrags = MergeCFrag(cfrags) -# decrypted_message = DecryptFrags(sk_B, pk_B, pk_A, merged_cfrags) + +# self.assertIsNotNone(merged_cfrags) + +# sk_B_hash = hashlib.sha256(sk_B.to_bytes((sk_B.bit_length() + 7) // 8, 'big')).digest() +# sk_B_int = int.from_bytes(sk_B_hash[:16], 'big') # 取前 16 字节并转换为整数 + +# decrypted_message = DecryptFrags(sk_B_int, pk_B, pk_A, merged_cfrags) # self.assertEqual(decrypted_message, message) -# if __name__ == "__main__": -# unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() From 5b85db242748956ad6f0a6d18ccab35bb0ca858c Mon Sep 17 00:00:00 2001 From: Tritium Date: Tue, 8 Oct 2024 20:25:04 +0800 Subject: [PATCH 3/5] =?UTF-8?q?fix:serverIP=E4=BB=8E=E7=8E=AF=E5=A2=83?= =?UTF-8?q?=E5=8F=98=E9=87=8F=E8=8E=B7=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/client.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/client.py b/src/client.py index e355d4b..1d05156 100644 --- a/src/client.py +++ b/src/client.py @@ -89,13 +89,8 @@ def init_db(): # load config from config file def init_config(): - import configparser - global server_address - config = configparser.ConfigParser() - config.read("client.ini") - - server_address = config["settings"]["server_address"] + server_address = os.environ.get("server_address") # execute on exit @@ -469,7 +464,7 @@ async def recieve_pk(pk: pk_model): pk = (0, 0) sk = 0 -server_address = str +server_address = os.environ.get("server_address") node_response = False message = bytes local_ip = get_own_ip() From b26bd92328ea6888591901fdf0371aac45ebb868 Mon Sep 17 00:00:00 2001 From: Tritium Date: Tue, 8 Oct 2024 20:25:36 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20src/node.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/node.py b/src/node.py index e7a2c08..db01496 100644 --- a/src/node.py +++ b/src/node.py @@ -25,7 +25,7 @@ async def lifespan(_: FastAPI): message_list = [] app = FastAPI(lifespan=lifespan) -server_address = "http://60.204.236.38:8000/server" +server_address = os.environ.get("server_address") id = 0 ip = "" client_ip_src = "" # 发送信息用户的ip From 15e35405f088b9821998c8c3457dc58ab22c3c50 Mon Sep 17 00:00:00 2001 From: Tritium Date: Tue, 8 Oct 2024 20:30:35 +0800 Subject: [PATCH 5/5] =?UTF-8?q?feat=EF=BC=9A=E6=B7=BB=E5=8A=A0docker-compo?= =?UTF-8?q?se?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docker-compose.yml | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 docker-compose.yml diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..3bc281f --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,35 @@ +version: "3" +services: + server: + image: git.mamahaha.work/sangge/tpre:base + volumes: + - ./src:/app + environment: + - server_address=http://server:8000 + entrypoint: + - nohup + - python + - server.py + - & + node: + image: git.mamahaha.work/sangge/tpre:base + volumes: + - ./src:/app + environment: + - server_address=http://server:8000 + entrypoint: + - nohup + - python + - node.py + - & + client: + image: git.mamahaha.work/sangge/tpre:base + volumes: + - ./src:/app + environment: + - server_address=http://server:8000 + entrypoint: + - nohup + - python + - client.py + - & \ No newline at end of file