diff --git a/src/node_test.log b/src/node_test.log new file mode 100644 index 0000000..fad9afe --- /dev/null +++ b/src/node_test.log @@ -0,0 +1 @@ +DEBUG:asyncio:Using selector: EpollSelector diff --git a/src/node_test1.py b/src/node_test1.py new file mode 100644 index 0000000..66a7a47 --- /dev/null +++ b/src/node_test1.py @@ -0,0 +1,35 @@ +#测试 get_local_ip()函数 +import unittest +from unittest.mock import patch, MagicMock +import node + +class TestGetLocalIP(unittest.TestCase): + + @patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量 + def test_get_ip_from_env(self): + # 调用被测函数 + node.get_local_ip() + + # 检查函数是否正确获取到 HOST_IP + self.assertEqual(node.ip, '60.204.193.58') + + @patch('socket.socket') # Mock socket 连接行为 + @patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量 + def test_get_ip_from_socket(self, mock_socket): + # 模拟 socket 返回的 IP 地址 + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0) + + # 调用被测函数 + node.get_local_ip() + + # 确认 socket 被调用过 + mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80)) + mock_socket_instance.close.assert_called_once() + + # 检查是否通过 socket 获取到正确的 IP 地址 + self.assertEqual(node.ip, '110.41.155.96') + +if __name__ == '__main__': + unittest.main() diff --git a/src/node_test2.py b/src/node_test2.py new file mode 100644 index 0000000..7e01de8 --- /dev/null +++ b/src/node_test2.py @@ -0,0 +1,34 @@ +#测试send_ip()函数 +import os +import unittest +from unittest.mock import patch, Mock +import node # 导入要测试的模块 + +class TestSendIP(unittest.TestCase): + @patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP + @patch('requests.get') # Mock requests.get 调用 + def test_send_ip(self, mock_get): + # 设置模拟返回的 HTTP 响应 + mock_response = Mock() + mock_response.text = "node123" # 模拟返回的节点ID + mock_response.status_code = 200 + mock_get.return_value = mock_response # 设置 requests.get() 的返回值为 mock_response + + # 保存原始的全局 id 值 + original_id = node.id + + # 调用待测函数 + node.send_ip() + + # 确保 requests.get 被正确调用 + expected_url = f"{node.server_address}/get_node?ip={node.ip}" + mock_get.assert_called_once_with(expected_url, timeout=3) + + # 检查 id 是否被正确更新 + self.assertIs(node.id, mock_response) # 检查 id 是否被修改 + self.assertEqual(node.id.text, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配 + +if __name__ == "__main__": + unittest.main() +#node.py中 +#print("中心服务器返回节点ID为: ", id.text)即可看到测试代码返回的节点 diff --git a/src/node_test3.py b/src/node_test3.py new file mode 100644 index 0000000..a92d921 --- /dev/null +++ b/src/node_test3.py @@ -0,0 +1,23 @@ +#测试init()函数 +import unittest +from unittest.mock import patch, AsyncMock +import node + +class TestNode(unittest.TestCase): + + @patch('node.send_ip') + @patch('node.get_local_ip') + @patch('node.asyncio.create_task') + def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip): + # 调用 init 函数 + node.init() + + # 验证 get_local_ip 和 send_ip 被调用 + mock_get_local_ip.assert_called_once() + mock_send_ip.assert_called_once() + + # 确保 create_task 被调用来启动心跳包 + mock_create_task.assert_called_once() + +if __name__ == '__main__': + unittest.main() diff --git a/src/node_test5.py b/src/node_test5.py new file mode 100644 index 0000000..3432bfe --- /dev/null +++ b/src/node_test5.py @@ -0,0 +1,89 @@ +import pytest +import httpx +import respx +import asyncio +from unittest.mock import patch, AsyncMock +from fastapi.testclient import TestClient +from node import app, send_heartbeat_internal, Req + +client = TestClient(app) +server_address = "http://60.204.236.38:8000/server" +ip = "127.0.0.1" + +@pytest.fixture(scope="session") +def anyio_backend(): + return "asyncio" + +@pytest.mark.asyncio +@respx.mock +async def test_send_heartbeat_internal_success(): + # 模拟心跳请求 + heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock( + return_value=httpx.Response(200) + ) + + # 模拟 requests.get 以避免实际请求 + with patch("requests.get", return_value=httpx.Response(200)) as mock_get: + # 模拟 asyncio.sleep 以避免实际延迟 + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + task = asyncio.create_task(send_heartbeat_internal()) + await asyncio.sleep(0.1) # 允许任务运行一段时间 + task.cancel() # 取消任务以停止无限循环 + try: + await task # 确保任务被等待 + except asyncio.CancelledError: + pass # 捕获取消错误 + + assert mock_get.called + assert mock_get.call_count > 0 + +@pytest.mark.asyncio +@respx.mock +async def test_send_heartbeat_internal_failure(): + # 模拟心跳请求以引发异常 + heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock( + side_effect=httpx.RequestError("Central server error") + ) + + # 模拟 requests.get 以避免实际请求 + with patch("requests.get", side_effect=httpx.RequestError("Central server error")) as mock_get: + # 模拟 asyncio.sleep 以避免实际延迟 + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + task = asyncio.create_task(send_heartbeat_internal()) + await asyncio.sleep(0.1) # 允许任务运行一段时间 + task.cancel() # 取消任务以停止无限循环 + try: + await task # 确保任务被等待 + except asyncio.CancelledError: + pass # 捕获取消错误 + + assert mock_get.called + assert mock_get.call_count > 0 + +def test_user_src(): + # 模拟 ReEncrypt 函数 + with patch("node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")): + # 模拟 send_user_des_message 函数 + with patch("node.send_user_des_message", new_callable=AsyncMock) as mock_send_user_des_message: + message = { + "source_ip": "127.0.0.1", + "dest_ip": "127.0.0.2", + "capsule": (("x1", "y1"), ("x2", "y2"), 123), + "ct": 456, + "rk": ["rk1", "rk2"] + } + response = client.post("/user_src", json=message) + assert response.status_code == 200 + assert response.json() == {"detail": "message received"} + mock_send_user_des_message.assert_called_once() + +def test_send_user_des_message(): + with respx.mock: + dest_ip = "127.0.0.2" + re_message = (("a", "b", "c", "d"), 123) + respx.post(f"http://{dest_ip}:8002/receive_messages").mock( + return_value=httpx.Response(200, json={"status": "success"}) + ) + response = client.post(f"http://{dest_ip}:8002/receive_messages", json={"Tuple": re_message, "ip": "127.0.0.1"}) + assert response.status_code == 200 + assert response.json() == {"status": "success"} \ No newline at end of file diff --git a/src/pytest.ini b/src/pytest.ini new file mode 100644 index 0000000..6a7d170 --- /dev/null +++ b/src/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_default_fixture_loop_scope = function \ No newline at end of file diff --git a/src/server_test1.py b/src/server_test1.py new file mode 100644 index 0000000..a78a93f --- /dev/null +++ b/src/server_test1.py @@ -0,0 +1,122 @@ +import sqlite3 +import pytest +from fastapi.testclient import TestClient +from server import app, validate_ip + +# 创建 TestClient 实例 +client = TestClient(app) + +# 准备测试数据库数据 +def setup_db(): + # 创建数据库并插入测试数据 + with sqlite3.connect("server.db") as db: + db.execute(""" + CREATE TABLE IF NOT EXISTS nodes ( + id INTEGER PRIMARY KEY, + ip TEXT NOT NULL, + last_heartbeat INTEGER NOT NULL + ) + """) + db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)") + db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)") + db.commit() + +# 清空数据库 +def clear_db(): + with sqlite3.connect("server.db") as db: + db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表 + db.commit() + + +# 测试 IP 验证功能 +def test_validate_ip(): + assert validate_ip("192.168.0.1") is True + assert validate_ip("256.256.256.256") is False + assert validate_ip("::1") is True + assert validate_ip("invalid_ip") is False + +# 测试首页路由 +def test_home(): + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"message": "Hello, World!"} + +# 测试 show_nodes 路由 +def test_show_nodes(): + setup_db() + + response = client.get("/server/show_nodes") + assert response.status_code == 200 + + data = response.json() + assert len(data) == 2 + assert data[0][1] == "192.168.0.1" + assert data[1][1] == "192.168.0.2" + +# 测试 get_node 路由 +def test_get_node(): + # 确保数据库和表的存在 + setup_db() + + valid_ip = "192.168.0.3" + invalid_ip = "256.256.256.256" + + # 测试有效的 IP 地址 + response = client.get(f"/server/get_node?ip={valid_ip}") + assert response.status_code == 200 + + # 测试无效的 IP 地址 + response = client.get(f"/server/get_node?ip={invalid_ip}") + assert response.status_code == 400 + + +# 测试 delete_node 路由 +def test_delete_node(): + setup_db() + + valid_ip = "192.168.0.1" + invalid_ip = "192.168.0.255" + + response = client.get(f"/server/delete_node?ip={valid_ip}") + assert response.status_code == 200 + assert "Node with IP 192.168.0.1 deleted successfully." in response.text + + response = client.get(f"/server/delete_node?ip={invalid_ip}") + assert response.status_code == 404 + +# 测试 heartbeat 路由 +def test_receive_heartbeat(): + setup_db() + + valid_ip = "192.168.0.2" + invalid_ip = "256.256.256.256" + + response = client.get(f"/server/heartbeat?ip={valid_ip}") + assert response.status_code == 200 + assert response.json() == {"status": "received"} + + response = client.get(f"/server/heartbeat?ip={invalid_ip}") + assert response.status_code == 400 + assert response.json() == {"message": "invalid ip format"} + +# 测试 send_nodes_list 路由 +def test_send_nodes_list(): + setup_db() + + response = client.get("/server/send_nodes_list?count=1") + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0] == "192.168.0.1" + + response = client.get("/server/send_nodes_list?count=2") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + +# 运行完测试后清理数据库 +@pytest.fixture(autouse=True) +def run_around_tests(): + clear_db() + yield + clear_db() diff --git a/src/tpre_test1.py b/src/tpre_test1.py new file mode 100644 index 0000000..fce532a --- /dev/null +++ b/src/tpre_test1.py @@ -0,0 +1,131 @@ +from tpre import ( + hash2, hash3, hash4, multiply, g, sm2p256v1, + GenerateKeyPair, Encrypt, Decrypt, GenerateReKey, + Encapsulate, ReEncrypt, DecryptFrags +) +from tpre import MergeCFrag +import random +import unittest + + +class TestHash2(unittest.TestCase): + def setUp(self): + self.double_G = ( + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + ) + + def test_digest_type(self): + digest = hash2(self.double_G) + self.assertEqual(type(digest), int) + + def test_digest_size(self): + digest = hash2(self.double_G) + self.assertLess(digest, sm2p256v1.N) + + +class TestHash3(unittest.TestCase): + def setUp(self): + self.triple_G = ( + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + ) + + def test_digest_type(self): + digest = hash3(self.triple_G) + self.assertEqual(type(digest), int) + + def test_digest_size(self): + digest = hash3(self.triple_G) + self.assertLess(digest, sm2p256v1.N) + + +class TestHash4(unittest.TestCase): + def setUp(self): + self.triple_G = ( + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + multiply(g, random.randint(0, sm2p256v1.N - 1)), + ) + self.Zp = random.randint(0, sm2p256v1.N - 1) + + def test_digest_type(self): + digest = hash4(self.triple_G, self.Zp) + self.assertEqual(type(digest), int) + + def test_digest_size(self): + digest = hash4(self.triple_G, self.Zp) + self.assertLess(digest, sm2p256v1.N) + + +# class TestGenerateKeyPair(unittest.TestCase): +# def test_key_pair(self): +# public_key, secret_key = GenerateKeyPair() +# self.assertIsInstance(public_key, tuple) +# self.assertIsInstance(secret_key, int) +# self.assertEqual(len(public_key), 2) + + +# class TestEncryptDecrypt(unittest.TestCase): +# def setUp(self): +# self.public_key, self.secret_key = GenerateKeyPair() +# self.message = b"Hello, world!" + +# def test_encrypt_decrypt(self): +# encrypted_message = Encrypt(self.public_key, self.message) +# decrypted_message = Decrypt(self.secret_key, encrypted_message) +# self.assertEqual(decrypted_message, self.message) + + +# class TestGenerateReKey(unittest.TestCase): +# def test_generate_rekey(self): +# sk_A = random.randint(0, sm2p256v1.N - 1) +# pk_B, _ = GenerateKeyPair() +# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) +# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) +# self.assertIsInstance(rekey, list) +# self.assertEqual(len(rekey), 5) + + +class TestEncapsulate(unittest.TestCase): + def test_encapsulate(self): + pk_A, _ = GenerateKeyPair() + K, capsule = Encapsulate(pk_A) + self.assertIsInstance(K, int) + self.assertIsInstance(capsule, tuple) + self.assertEqual(len(capsule), 3) + + +# class TestReEncrypt(unittest.TestCase): +# def test_reencrypt(self): +# sk_A = random.randint(0, sm2p256v1.N - 1) +# pk_B, _ = GenerateKeyPair() +# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) +# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) +# pk_A, _ = GenerateKeyPair() +# message = b"Hello, world!" +# encrypted_message = Encrypt(pk_A, message) +# reencrypted_message = ReEncrypt(rekey[0], encrypted_message) +# self.assertIsInstance(reencrypted_message, tuple) +# self.assertEqual(len(reencrypted_message), 2) + + +# class TestDecryptFrags(unittest.TestCase): +# def test_decrypt_frags(self): +# sk_A = random.randint(0, sm2p256v1.N - 1) +# pk_B, sk_B = GenerateKeyPair() +# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5)) +# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple) +# pk_A, _ = GenerateKeyPair() +# message = b"Hello, world!" +# encrypted_message = Encrypt(pk_A, message) +# reencrypted_message = ReEncrypt(rekey[0], encrypted_message) +# cfrags = [reencrypted_message] +# merged_cfrags = MergeCFrag(cfrags) +# decrypted_message = DecryptFrags(sk_B, pk_B, pk_A, merged_cfrags) +# self.assertEqual(decrypted_message, message) + + +# if __name__ == "__main__": +# unittest.main() \ No newline at end of file