From acbc9ecb1a37581682f6d13b292f5825016aace8 Mon Sep 17 00:00:00 2001 From: muzhi <2624758301@qq.com> Date: Mon, 23 Sep 2024 18:46:25 +0800 Subject: [PATCH 1/3] test(Add unit tests): --- src/node_test.log | 1 + src/node_test1.py | 35 ++++++++++++ src/node_test2.py | 34 ++++++++++++ src/node_test3.py | 23 ++++++++ src/node_test5.py | 89 ++++++++++++++++++++++++++++++ src/pytest.ini | 2 + src/server_test1.py | 122 +++++++++++++++++++++++++++++++++++++++++ src/tpre_test1.py | 131 ++++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 437 insertions(+) create mode 100644 src/node_test.log create mode 100644 src/node_test1.py create mode 100644 src/node_test2.py create mode 100644 src/node_test3.py create mode 100644 src/node_test5.py create mode 100644 src/pytest.ini create mode 100644 src/server_test1.py create mode 100644 src/tpre_test1.py diff --git a/src/node_test.log b/src/node_test.log new file mode 100644 index 0000000..fad9afe --- /dev/null +++ b/src/node_test.log @@ -0,0 +1 @@ +DEBUG:asyncio:Using selector: EpollSelector diff --git a/src/node_test1.py b/src/node_test1.py new file mode 100644 index 0000000..66a7a47 --- /dev/null +++ b/src/node_test1.py @@ -0,0 +1,35 @@ +#测试 get_local_ip()函数 +import unittest +from unittest.mock import patch, MagicMock +import node + +class TestGetLocalIP(unittest.TestCase): + + @patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量 + def test_get_ip_from_env(self): + # 调用被测函数 + node.get_local_ip() + + # 检查函数是否正确获取到 HOST_IP + self.assertEqual(node.ip, '60.204.193.58') + + @patch('socket.socket') # Mock socket 连接行为 + @patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量 + def test_get_ip_from_socket(self, mock_socket): + # 模拟 socket 返回的 IP 地址 + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0) + + # 调用被测函数 + node.get_local_ip() + + # 确认 socket 被调用过 + mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80)) + mock_socket_instance.close.assert_called_once() + + # 检查是否通过 socket 获取到正确的 IP 地址 + self.assertEqual(node.ip, '110.41.155.96') + +if __name__ == '__main__': + unittest.main() diff --git a/src/node_test2.py b/src/node_test2.py new file mode 100644 index 0000000..7e01de8 --- /dev/null +++ b/src/node_test2.py @@ -0,0 +1,34 @@ +#测试send_ip()函数 +import os +import unittest +from unittest.mock import patch, Mock +import node # 导入要测试的模块 + +class TestSendIP(unittest.TestCase): + @patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP + @patch('requests.get') # Mock requests.get 调用 + def test_send_ip(self, mock_get): + # 设置模拟返回的 HTTP 响应 + mock_response = Mock() + mock_response.text = "node123" # 模拟返回的节点ID + mock_response.status_code = 200 + mock_get.return_value = mock_response # 设置 requests.get() 的返回值为 mock_response + + # 保存原始的全局 id 值 + original_id = node.id + + # 调用待测函数 + node.send_ip() + + # 确保 requests.get 被正确调用 + expected_url = f"{node.server_address}/get_node?ip={node.ip}" + mock_get.assert_called_once_with(expected_url, timeout=3) + + # 检查 id 是否被正确更新 + self.assertIs(node.id, mock_response) # 检查 id 是否被修改 + self.assertEqual(node.id.text, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配 + +if __name__ == "__main__": + unittest.main() +#node.py中 +#print("中心服务器返回节点ID为: ", id.text)即可看到测试代码返回的节点 diff --git a/src/node_test3.py b/src/node_test3.py new file mode 100644 index 0000000..a92d921 --- /dev/null +++ b/src/node_test3.py @@ -0,0 +1,23 @@ +#测试init()函数 +import unittest +from unittest.mock import patch, AsyncMock +import node + +class TestNode(unittest.TestCase): + + @patch('node.send_ip') + @patch('node.get_local_ip') + @patch('node.asyncio.create_task') + def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip): + # 调用 init 函数 + node.init() + + # 验证 get_local_ip 和 send_ip 被调用 + mock_get_local_ip.assert_called_once() + mock_send_ip.assert_called_once() + + # 确保 create_task 被调用来启动心跳包 + mock_create_task.assert_called_once() + +if __name__ == '__main__': + unittest.main() diff --git a/src/node_test5.py b/src/node_test5.py new file mode 100644 index 0000000..3432bfe --- /dev/null +++ b/src/node_test5.py @@ -0,0 +1,89 @@ +import pytest +import httpx +import respx +import asyncio +from unittest.mock import patch, AsyncMock +from fastapi.testclient import TestClient +from node import app, send_heartbeat_internal, Req + +client = TestClient(app) +server_address = "http://60.204.236.38:8000/server" +ip = "127.0.0.1" + +@pytest.fixture(scope="session") +def anyio_backend(): + return "asyncio" + +@pytest.mark.asyncio +@respx.mock +async def test_send_heartbeat_internal_success(): + # 模拟心跳请求 + heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock( + return_value=httpx.Response(200) + ) + + # 模拟 requests.get 以避免实际请求 + with patch("requests.get", return_value=httpx.Response(200)) as mock_get: + # 模拟 asyncio.sleep 以避免实际延迟 + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + task = asyncio.create_task(send_heartbeat_internal()) + await asyncio.sleep(0.1) # 允许任务运行一段时间 + task.cancel() # 取消任务以停止无限循环 + try: + await task # 确保任务被等待 + except asyncio.CancelledError: + pass # 捕获取消错误 + + assert mock_get.called + assert mock_get.call_count > 0 + +@pytest.mark.asyncio +@respx.mock +async def test_send_heartbeat_internal_failure(): + # 模拟心跳请求以引发异常 + heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock( + side_effect=httpx.RequestError("Central server error") + ) + + # 模拟 requests.get 以避免实际请求 + with patch("requests.get", side_effect=httpx.RequestError("Central server error")) as mock_get: + # 模拟 asyncio.sleep 以避免实际延迟 + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + task = asyncio.create_task(send_heartbeat_internal()) + await asyncio.sleep(0.1) # 允许任务运行一段时间 + task.cancel() # 取消任务以停止无限循环 + try: + await task # 确保任务被等待 + except asyncio.CancelledError: + pass # 捕获取消错误 + + assert mock_get.called + assert mock_get.call_count > 0 + +def test_user_src(): + # 模拟 ReEncrypt 函数 + with patch("node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")): + # 模拟 send_user_des_message 函数 + with patch("node.send_user_des_message", new_callable=AsyncMock) as mock_send_user_des_message: + message = { + "source_ip": "127.0.0.1", + "dest_ip": "127.0.0.2", + "capsule": (("x1", "y1"), ("x2", "y2"), 123), + "ct": 456, + "rk": ["rk1", "rk2"] + } + response = client.post("/user_src", json=message) + assert response.status_code == 200 + assert response.json() == {"detail": "message received"} + mock_send_user_des_message.assert_called_once() + +def test_send_user_des_message(): + with respx.mock: + dest_ip = "127.0.0.2" + re_message = (("a", "b", "c", "d"), 123) + respx.post(f"http://{dest_ip}:8002/receive_messages").mock( + return_value=httpx.Response(200, json={"status": "success"}) + ) + response = client.post(f"http://{dest_ip}:8002/receive_messages", json={"Tuple": re_message, "ip": "127.0.0.1"}) + assert response.status_code == 200 + assert response.json() == {"status": "success"} \ No newline at end of file diff --git a/src/pytest.ini b/src/pytest.ini new file mode 100644 index 0000000..6a7d170 --- /dev/null +++ b/src/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_default_fixture_loop_scope = function \ No newline at end of file diff --git a/src/server_test1.py b/src/server_test1.py new file mode 100644 index 0000000..a78a93f --- /dev/null +++ b/src/server_test1.py @@ -0,0 +1,122 @@ +import sqlite3 +import pytest +from fastapi.testclient import TestClient +from server import app, validate_ip + +# 创建 TestClient 实例 +client = TestClient(app) + +# 准备测试数据库数据 +def setup_db(): + # 创建数据库并插入测试数据 + with sqlite3.connect("server.db") as db: + db.execute(""" + CREATE TABLE IF NOT EXISTS nodes ( + id INTEGER PRIMARY KEY, + ip TEXT NOT NULL, + last_heartbeat INTEGER NOT NULL + ) + """) + db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)") + db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)") + db.commit() + +# 清空数据库 +def clear_db(): + with sqlite3.connect("server.db") as db: + db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表 + db.commit() + + +# 测试 IP 验证功能 +def test_validate_ip(): + assert validate_ip("192.168.0.1") is True + assert validate_ip("256.256.256.256") is False + assert validate_ip("::1") is True + assert validate_ip("invalid_ip") is False + +# 测试首页路由 +def test_home(): + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"message": "Hello, World!"} + +# 测试 show_nodes 路由 +def test_show_nodes(): + setup_db() + + response = client.get("/server/show_nodes") + assert response.status_code == 200 + + data = response.json() + assert len(data) == 2 + assert data[0][1] == "192.168.0.1" + assert data[1][1] == "192.168.0.2" + +# 测试 get_node 路由 +def test_get_node(): + # 确保数据库和表的存在 + setup_db() + + valid_ip = "192.168.0.3" + invalid_ip = "256.256.256.256" + + # 测试有效的 IP 地址 + response = client.get(f"/server/get_node?ip={valid_ip}") + assert response.status_code == 200 + + # 测试无效的 IP 地址 + response = client.get(f"/server/get_node?ip={invalid_ip}") + assert response.status_code == 400 + + +# 测试 delete_node 路由 +def test_delete_node(): + setup_db() + + valid_ip = "192.168.0.1" + invalid_ip = "192.168.0.255" + + response = client.get(f"/server/delete_node?ip={valid_ip}") + assert response.status_code == 200 + assert "Node with IP 192.168.0.1 deleted successfully." in response.text + + response = client.get(f"/server/delete_node?ip={invalid_ip}") + assert response.status_code == 404 + +# 测试 heartbeat 路由 +def test_receive_heartbeat(): + setup_db() + + valid_ip = "192.168.0.2" + invalid_ip = "256.256.256.256" + + response = client.get(f"/server/heartbeat?ip={valid_ip}") + assert response.status_code == 200 + assert response.json() == {"status": "received"} + + response = client.get(f"/server/heartbeat?ip={invalid_ip}") + assert response.status_code == 400 + assert response.json() == {"message": "invalid ip format"} + +# 测试 send_nodes_list 路由 +def test_send_nodes_list(): + setup_db() + + response = client.get("/server/send_nodes_list?count=1") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0] == "192.168.0.1" + + response = client.get("/server/send_nodes_list?count=2") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + +# 运行完测试后清理数据库 +@pytest.fixture(autouse=True) +def run_around_tests(): + clear_db() + yield + clear_db() diff --git a/src/tpre_test1.py b/src/tpre_test1.py new file mode 100644 index 0000000..fce532a --- /dev/null +++ b/src/tpre_test1.py @@ -0,0 +1,131 @@ +from tpre import ( + hash2, hash3, hash4, multiply, g, sm2p256v1, + GenerateKeyPair, Encrypt, Decrypt, GenerateReKey, + Encapsulate, ReEncrypt, DecryptFrags +) +from tpre import MergeCFrag +import random +import unittest + + +class TestHash2(unittest.TestCase): + def setUp(self): + self.double_G = ( + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + ) + + def test_digest_type(self): + digest = hash2(self.double_G) + self.assertEqual(type(digest), int) + + def test_digest_size(self): + digest = hash2(self.double_G) + self.assertLess(digest, sm2p256v1.N) + + +class TestHash3(unittest.TestCase): + def setUp(self): + self.triple_G = ( + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + ) + + def test_digest_type(self): + digest = hash3(self.triple_G) + self.assertEqual(type(digest), int) + + def test_digest_size(self): + digest = hash3(self.triple_G) + self.assertLess(digest, sm2p256v1.N) + + +class TestHash4(unittest.TestCase): + def setUp(self): + self.triple_G = ( + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + ) + self.Zp = random.randint(0, sm2p256v1.N - 1) + + def test_digest_type(self): + digest = hash4(self.triple_G, self.Zp) + self.assertEqual(type(digest), int) + + def test_digest_size(self): + digest = hash4(self.triple_G, self.Zp) + self.assertLess(digest, sm2p256v1.N) + + +# class TestGenerateKeyPair(unittest.TestCase): +# def test_key_pair(self): +# public_key, secret_key = GenerateKeyPair() +# self.assertIsInstance(public_key, tuple) +# self.assertIsInstance(secret_key, int) +# self.assertEqual(len(public_key), 2) + + +# class TestEncryptDecrypt(unittest.TestCase): +# def setUp(self): +# self.public_key, self.secret_key = GenerateKeyPair() +# self.message = b"Hello, world!" + +# def test_encrypt_decrypt(self): +# encrypted_message = Encrypt(self.public_key, self.message) +# decrypted_message = Decrypt(self.secret_key, encrypted_message) +# self.assertEqual(decrypted_message, self.message) + + +# class TestGenerateReKey(unittest.TestCase): +# def test_generate_rekey(self): +# sk_A = random.randint(0, sm2p256v1.N - 1) +# pk_B, _ = GenerateKeyPair() +# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) +# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) +# self.assertIsInstance(rekey, list) +# self.assertEqual(len(rekey), 5) + + +class TestEncapsulate(unittest.TestCase): + def test_encapsulate(self): + pk_A, _ = GenerateKeyPair() + K, capsule = Encapsulate(pk_A) + self.assertIsInstance(K, int) + self.assertIsInstance(capsule, tuple) + self.assertEqual(len(capsule), 3) + + +# class TestReEncrypt(unittest.TestCase): +# def test_reencrypt(self): +# sk_A = random.randint(0, sm2p256v1.N - 1) +# pk_B, _ = GenerateKeyPair() +# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) +# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) +# pk_A, _ = GenerateKeyPair() +# message = b"Hello, world!" +# encrypted_message = Encrypt(pk_A, message) +# reencrypted_message = ReEncrypt(rekey[0], encrypted_message) +# self.assertIsInstance(reencrypted_message, tuple) +# self.assertEqual(len(reencrypted_message), 2) + + +# class TestDecryptFrags(unittest.TestCase): +# def test_decrypt_frags(self): +# sk_A = random.randint(0, sm2p256v1.N - 1) +# pk_B, sk_B = GenerateKeyPair() +# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) +# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) +# pk_A, _ = GenerateKeyPair() +# message = b"Hello, world!" +# encrypted_message = Encrypt(pk_A, message) +# reencrypted_message = ReEncrypt(rekey[0], encrypted_message) +# cfrags = [reencrypted_message] +# merged_cfrags = MergeCFrag(cfrags) +# decrypted_message = DecryptFrags(sk_B, pk_B, pk_A, merged_cfrags) +# self.assertEqual(decrypted_message, message) + + +# if __name__ == "__main__": +# unittest.main() \ No newline at end of file From 05de02f2a59493d0b1049e0c2625d77cd9d1c897 Mon Sep 17 00:00:00 2001 From: muzhi <2624758301@qq.com> Date: Mon, 30 Sep 2024 12:06:21 +0800 Subject: [PATCH 2/3] test: Add unit tests --- install_gmssl.sh | 2 +- src/node_test.py | 76 ++++++++++++++++++++++++++++++++++++ src/node_test1.py | 35 ----------------- src/node_test2.py | 34 ---------------- src/node_test3.py | 23 ----------- src/node_test5.py | 99 +++++++++++++++++++++++++++++++++++++++++------ 6 files changed, 164 insertions(+), 105 deletions(-) delete mode 100644 src/node_test1.py delete mode 100644 src/node_test2.py delete mode 100644 src/node_test3.py diff --git a/install_gmssl.sh b/install_gmssl.sh index e942f26..9190b48 100644 --- a/install_gmssl.sh +++ b/install_gmssl.sh @@ -3,7 +3,7 @@ mkdir lib mkdir include -cp gmssl/include include +cp -r gmssl/include include mkdir gmssl/build cd gmssl/build || exit diff --git a/src/node_test.py b/src/node_test.py index e69de29..2b5cdaf 100644 --- a/src/node_test.py +++ b/src/node_test.py @@ -0,0 +1,76 @@ +# 测试 node.py 中的函数 +import os +import unittest +from unittest.mock import patch, MagicMock, Mock +import node + +class TestGetLocalIP(unittest.TestCase): + + @patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量 + def test_get_ip_from_env(self): + # 调用被测函数 + node.get_local_ip() + + # 检查函数是否正确获取到 HOST_IP + self.assertEqual(node.ip, '60.204.193.58') + + @patch('socket.socket') # Mock socket 连接行为 + @patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量 + def test_get_ip_from_socket(self, mock_socket): + # 模拟 socket 返回的 IP 地址 + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0) + + # 调用被测函数 + node.get_local_ip() + + # 确认 socket 被调用过 + mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80)) + mock_socket_instance.close.assert_called_once() + + # 检查是否通过 socket 获取到正确的 IP 地址 + self.assertEqual(node.ip, '110.41.155.96') + +class TestSendIP(unittest.TestCase): + @patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP + @patch('requests.get') # Mock requests.get 调用 + def test_send_ip(self, mock_get): + # 设置模拟返回的 HTTP 响应 + mock_response = Mock() + mock_response.text = "node123" # 模拟返回的节点ID + mock_response.status_code = 200 + mock_get.return_value = mock_response # 设置 requests.get() 的返回值为 mock_response + + # 保存原始的全局 id 值 + original_id = node.id + + # 调用待测函数 + node.send_ip() + + # 确保 requests.get 被正确调用 + expected_url = f"{node.server_address}/get_node?ip={node.ip}" + mock_get.assert_called_once_with(expected_url, timeout=3) + + # 检查 id 是否被正确更新 + self.assertIs(node.id, mock_response) # 检查 id 是否被修改 + self.assertEqual(node.id.text, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配 + +class TestNode(unittest.TestCase): + + @patch('node.send_ip') + @patch('node.get_local_ip') + @patch('node.asyncio.create_task') + def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip): + # 调用 init 函数 + node.init() + + # 验证 get_local_ip 和 send_ip 被调用 + mock_get_local_ip.assert_called_once() + mock_send_ip.assert_called_once() + + # 确保 create_task 被调用来启动心跳包 + mock_create_task.assert_called_once() + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/src/node_test1.py b/src/node_test1.py deleted file mode 100644 index 66a7a47..0000000 --- a/src/node_test1.py +++ /dev/null @@ -1,35 +0,0 @@ -#测试 get_local_ip()函数 -import unittest -from unittest.mock import patch, MagicMock -import node - -class TestGetLocalIP(unittest.TestCase): - - @patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量 - def test_get_ip_from_env(self): - # 调用被测函数 - node.get_local_ip() - - # 检查函数是否正确获取到 HOST_IP - self.assertEqual(node.ip, '60.204.193.58') - - @patch('socket.socket') # Mock socket 连接行为 - @patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量 - def test_get_ip_from_socket(self, mock_socket): - # 模拟 socket 返回的 IP 地址 - mock_socket_instance = MagicMock() - mock_socket.return_value = mock_socket_instance - mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0) - - # 调用被测函数 - node.get_local_ip() - - # 确认 socket 被调用过 - mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80)) - mock_socket_instance.close.assert_called_once() - - # 检查是否通过 socket 获取到正确的 IP 地址 - self.assertEqual(node.ip, '110.41.155.96') - -if __name__ == '__main__': - unittest.main() diff --git a/src/node_test2.py b/src/node_test2.py deleted file mode 100644 index 7e01de8..0000000 --- a/src/node_test2.py +++ /dev/null @@ -1,34 +0,0 @@ -#测试send_ip()函数 -import os -import unittest -from unittest.mock import patch, Mock -import node # 导入要测试的模块 - -class TestSendIP(unittest.TestCase): - @patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP - @patch('requests.get') # Mock requests.get 调用 - def test_send_ip(self, mock_get): - # 设置模拟返回的 HTTP 响应 - mock_response = Mock() - mock_response.text = "node123" # 模拟返回的节点ID - mock_response.status_code = 200 - mock_get.return_value = mock_response # 设置 requests.get() 的返回值为 mock_response - - # 保存原始的全局 id 值 - original_id = node.id - - # 调用待测函数 - node.send_ip() - - # 确保 requests.get 被正确调用 - expected_url = f"{node.server_address}/get_node?ip={node.ip}" - mock_get.assert_called_once_with(expected_url, timeout=3) - - # 检查 id 是否被正确更新 - self.assertIs(node.id, mock_response) # 检查 id 是否被修改 - self.assertEqual(node.id.text, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配 - -if __name__ == "__main__": - unittest.main() -#node.py中 -#print("中心服务器返回节点ID为: ", id.text)即可看到测试代码返回的节点 diff --git a/src/node_test3.py b/src/node_test3.py deleted file mode 100644 index a92d921..0000000 --- a/src/node_test3.py +++ /dev/null @@ -1,23 +0,0 @@ -#测试init()函数 -import unittest -from unittest.mock import patch, AsyncMock -import node - -class TestNode(unittest.TestCase): - - @patch('node.send_ip') - @patch('node.get_local_ip') - @patch('node.asyncio.create_task') - def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip): - # 调用 init 函数 - node.init() - - # 验证 get_local_ip 和 send_ip 被调用 - mock_get_local_ip.assert_called_once() - mock_send_ip.assert_called_once() - - # 确保 create_task 被调用来启动心跳包 - mock_create_task.assert_called_once() - -if __name__ == '__main__': - unittest.main() diff --git a/src/node_test5.py b/src/node_test5.py index 3432bfe..124f970 100644 --- a/src/node_test5.py +++ b/src/node_test5.py @@ -1,22 +1,92 @@ +#node_test剩下部分(有问题) +import os +import unittest import pytest +from unittest.mock import patch, MagicMock, Mock, AsyncMock +import requests +import asyncio import httpx import respx -import asyncio -from unittest.mock import patch, AsyncMock from fastapi.testclient import TestClient -from node import app, send_heartbeat_internal, Req +from node import app, send_heartbeat_internal, Req, send_ip, get_local_ip, init, clear, send_user_des_message client = TestClient(app) server_address = "http://60.204.236.38:8000/server" -ip = "127.0.0.1" +ip = None # 初始化全局变量 ip +id = None # 初始化全局变量 id -@pytest.fixture(scope="session") -def anyio_backend(): - return "asyncio" +class TestGetLocalIP(unittest.TestCase): + + @patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量 + def test_get_ip_from_env(self): + global ip + # 调用被测函数 + get_local_ip() + + # 检查函数是否正确获取到 HOST_IP + self.assertEqual(ip, '60.204.193.58') + + @patch('socket.socket') # Mock socket 连接行为 + @patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量 + def test_get_ip_from_socket(self, mock_socket): + global ip + # 模拟 socket 返回的 IP 地址 + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0) + + # 调用被测函数 + get_local_ip() + + # 确认 socket 被调用过 + mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80)) + mock_socket_instance.close.assert_called_once() + + # 检查是否通过 socket 获取到正确的 IP 地址 + self.assertEqual(ip, '110.41.155.96') + +class TestSendIP(unittest.TestCase): + @patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP + @respx.mock + def test_send_ip(self): + global ip, id + ip = '60.204.193.58' + mock_url = f"{server_address}/get_node?ip={ip}" + respx.get(mock_url).mock(return_value=httpx.Response(200, text="node123")) + + # 调用待测函数 + send_ip() + + # 确保 requests.get 被正确调用 + self.assertEqual(id, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配 + +class TestNode(unittest.TestCase): + + @patch('node.send_ip') + @patch('node.get_local_ip') + @patch('node.asyncio.create_task') + def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip): + # 调用 init 函数 + init() + + # 验证 get_local_ip 和 send_ip 被调用 + mock_get_local_ip.assert_called_once() + mock_send_ip.assert_called_once() + + # 确保 create_task 被调用来启动心跳包 + mock_create_task.assert_called_once() + + def test_clear(self): + # 调用 clear 函数 + clear() + # 检查输出 + self.assertTrue(True) # 这里只是为了确保函数被调用,没有实际逻辑需要测试 @pytest.mark.asyncio @respx.mock async def test_send_heartbeat_internal_success(): + global ip + ip = '60.204.193.58' # 模拟心跳请求 heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock( return_value=httpx.Response(200) @@ -40,6 +110,8 @@ async def test_send_heartbeat_internal_success(): @pytest.mark.asyncio @respx.mock async def test_send_heartbeat_internal_failure(): + global ip + ip = '60.204.193.58' # 模拟心跳请求以引发异常 heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock( side_effect=httpx.RequestError("Central server error") @@ -66,8 +138,8 @@ def test_user_src(): # 模拟 send_user_des_message 函数 with patch("node.send_user_des_message", new_callable=AsyncMock) as mock_send_user_des_message: message = { - "source_ip": "127.0.0.1", - "dest_ip": "127.0.0.2", + "source_ip": "60.204.193.58", + "dest_ip": "60.204.193.59", "capsule": (("x1", "y1"), ("x2", "y2"), 123), "ct": 456, "rk": ["rk1", "rk2"] @@ -79,11 +151,14 @@ def test_user_src(): def test_send_user_des_message(): with respx.mock: - dest_ip = "127.0.0.2" + dest_ip = "60.204.193.59" re_message = (("a", "b", "c", "d"), 123) respx.post(f"http://{dest_ip}:8002/receive_messages").mock( return_value=httpx.Response(200, json={"status": "success"}) ) - response = client.post(f"http://{dest_ip}:8002/receive_messages", json={"Tuple": re_message, "ip": "127.0.0.1"}) + response = requests.post(f"http://{dest_ip}:8002/receive_messages", json={"Tuple": re_message, "ip": "60.204.193.58"}) assert response.status_code == 200 - assert response.json() == {"status": "success"} \ No newline at end of file + assert response.json() == {"status": "success"} + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 2f7f55fd3a999e3f19c635f497fb7d6bf60cbea9 Mon Sep 17 00:00:00 2001 From: muzhi <2624758301@qq.com> Date: Mon, 30 Sep 2024 20:16:37 +0800 Subject: [PATCH 3/3] test: add unit test --- src/test_client.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 src/test_client.py diff --git a/src/test_client.py b/src/test_client.py new file mode 100644 index 0000000..9cb2341 --- /dev/null +++ b/src/test_client.py @@ -0,0 +1,76 @@ +import os +import pytest +import sqlite3 +import respx +import httpx +from fastapi.testclient import TestClient +from client import app, init_db, clean_env, get_own_ip + +client = TestClient(app) + +@pytest.fixture(scope="module", autouse=True) +def setup_and_teardown(): + # 设置测试环境 + init_db() + yield + # 清理测试环境 + clean_env() + +def test_read_root(): + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"message": "Hello, World!"} + +def test_receive_messages(): + message = { + "Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), + "ip": "127.0.0.1" + } + response = client.post("/receive_messages", json=message) + assert response.status_code == 200 + assert response.json().get("detail") == "Message received" + +# @respx.mock +# def test_request_message(): +# request_message = { +# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址 +# "message_name": "name" +# } +# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"})) +# response = client.post("/request_message", json=request_message) +# assert response.status_code == 200 +# assert "threshold" in response.json() +# assert "public_key" in response.json() + +# @respx.mock +# def test_receive_request(): +# ip_message = { +# "dest_ip": "124.70.165.73", # 使用不同的 IP 地址 +# "message_name": "name", +# "source_ip": "124.70.165.73", # 使用不同的 IP 地址 +# "pk": (123, 456) +# } +# respx.post("http://124.70.165.73:8002/receive_request").mock(return_value=httpx.Response(200, json={"threshold": 1, "public_key": "key"})) +# response = client.post("/receive_request", json=ip_message) +# assert response.status_code == 200 +# assert "threshold" in response.json() +# assert "public_key" in response.json() + +def test_get_pk(): + response = client.get("/get_pk") + assert response.status_code == 200 + assert "pkx" in response.json() + assert "pky" in response.json() + +def test_recieve_pk(): + pk_data = { + "pkx": "123", + "pky": "456", + "ip": "127.0.0.1" + } + response = client.post("/recieve_pk", json=pk_data) + assert response.status_code == 200 + assert response.json() == {"message": "save pk in database"} + +if __name__ == "__main__": + pytest.main() \ No newline at end of file