test(Add unit tests):
This commit is contained in:
parent
e8e7c59579
commit
acbc9ecb1a
1
src/node_test.log
Normal file
1
src/node_test.log
Normal file
@ -0,0 +1 @@
|
||||
DEBUG:asyncio:Using selector: EpollSelector
|
35
src/node_test1.py
Normal file
35
src/node_test1.py
Normal file
@ -0,0 +1,35 @@
|
||||
#测试 get_local_ip()函数
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import node
|
||||
|
||||
class TestGetLocalIP(unittest.TestCase):
|
||||
|
||||
@patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量
|
||||
def test_get_ip_from_env(self):
|
||||
# 调用被测函数
|
||||
node.get_local_ip()
|
||||
|
||||
# 检查函数是否正确获取到 HOST_IP
|
||||
self.assertEqual(node.ip, '60.204.193.58')
|
||||
|
||||
@patch('socket.socket') # Mock socket 连接行为
|
||||
@patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量
|
||||
def test_get_ip_from_socket(self, mock_socket):
|
||||
# 模拟 socket 返回的 IP 地址
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0)
|
||||
|
||||
# 调用被测函数
|
||||
node.get_local_ip()
|
||||
|
||||
# 确认 socket 被调用过
|
||||
mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80))
|
||||
mock_socket_instance.close.assert_called_once()
|
||||
|
||||
# 检查是否通过 socket 获取到正确的 IP 地址
|
||||
self.assertEqual(node.ip, '110.41.155.96')
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
34
src/node_test2.py
Normal file
34
src/node_test2.py
Normal file
@ -0,0 +1,34 @@
|
||||
#测试send_ip()函数
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch, Mock
|
||||
import node # 导入要测试的模块
|
||||
|
||||
class TestSendIP(unittest.TestCase):
|
||||
@patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP
|
||||
@patch('requests.get') # Mock requests.get 调用
|
||||
def test_send_ip(self, mock_get):
|
||||
# 设置模拟返回的 HTTP 响应
|
||||
mock_response = Mock()
|
||||
mock_response.text = "node123" # 模拟返回的节点ID
|
||||
mock_response.status_code = 200
|
||||
mock_get.return_value = mock_response # 设置 requests.get() 的返回值为 mock_response
|
||||
|
||||
# 保存原始的全局 id 值
|
||||
original_id = node.id
|
||||
|
||||
# 调用待测函数
|
||||
node.send_ip()
|
||||
|
||||
# 确保 requests.get 被正确调用
|
||||
expected_url = f"{node.server_address}/get_node?ip={node.ip}"
|
||||
mock_get.assert_called_once_with(expected_url, timeout=3)
|
||||
|
||||
# 检查 id 是否被正确更新
|
||||
self.assertIs(node.id, mock_response) # 检查 id 是否被修改
|
||||
self.assertEqual(node.id.text, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
#node.py中
|
||||
#print("中心服务器返回节点ID为: ", id.text)即可看到测试代码返回的节点
|
23
src/node_test3.py
Normal file
23
src/node_test3.py
Normal file
@ -0,0 +1,23 @@
|
||||
#测试init()函数
|
||||
import unittest
|
||||
from unittest.mock import patch, AsyncMock
|
||||
import node
|
||||
|
||||
class TestNode(unittest.TestCase):
|
||||
|
||||
@patch('node.send_ip')
|
||||
@patch('node.get_local_ip')
|
||||
@patch('node.asyncio.create_task')
|
||||
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
|
||||
# 调用 init 函数
|
||||
node.init()
|
||||
|
||||
# 验证 get_local_ip 和 send_ip 被调用
|
||||
mock_get_local_ip.assert_called_once()
|
||||
mock_send_ip.assert_called_once()
|
||||
|
||||
# 确保 create_task 被调用来启动心跳包
|
||||
mock_create_task.assert_called_once()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
89
src/node_test5.py
Normal file
89
src/node_test5.py
Normal file
@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
import httpx
|
||||
import respx
|
||||
import asyncio
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from fastapi.testclient import TestClient
|
||||
from node import app, send_heartbeat_internal, Req
|
||||
|
||||
client = TestClient(app)
|
||||
server_address = "http://60.204.236.38:8000/server"
|
||||
ip = "127.0.0.1"
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def anyio_backend():
|
||||
return "asyncio"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_send_heartbeat_internal_success():
|
||||
# 模拟心跳请求
|
||||
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
|
||||
return_value=httpx.Response(200)
|
||||
)
|
||||
|
||||
# 模拟 requests.get 以避免实际请求
|
||||
with patch("requests.get", return_value=httpx.Response(200)) as mock_get:
|
||||
# 模拟 asyncio.sleep 以避免实际延迟
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
task = asyncio.create_task(send_heartbeat_internal())
|
||||
await asyncio.sleep(0.1) # 允许任务运行一段时间
|
||||
task.cancel() # 取消任务以停止无限循环
|
||||
try:
|
||||
await task # 确保任务被等待
|
||||
except asyncio.CancelledError:
|
||||
pass # 捕获取消错误
|
||||
|
||||
assert mock_get.called
|
||||
assert mock_get.call_count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@respx.mock
|
||||
async def test_send_heartbeat_internal_failure():
|
||||
# 模拟心跳请求以引发异常
|
||||
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
|
||||
side_effect=httpx.RequestError("Central server error")
|
||||
)
|
||||
|
||||
# 模拟 requests.get 以避免实际请求
|
||||
with patch("requests.get", side_effect=httpx.RequestError("Central server error")) as mock_get:
|
||||
# 模拟 asyncio.sleep 以避免实际延迟
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
task = asyncio.create_task(send_heartbeat_internal())
|
||||
await asyncio.sleep(0.1) # 允许任务运行一段时间
|
||||
task.cancel() # 取消任务以停止无限循环
|
||||
try:
|
||||
await task # 确保任务被等待
|
||||
except asyncio.CancelledError:
|
||||
pass # 捕获取消错误
|
||||
|
||||
assert mock_get.called
|
||||
assert mock_get.call_count > 0
|
||||
|
||||
def test_user_src():
|
||||
# 模拟 ReEncrypt 函数
|
||||
with patch("node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")):
|
||||
# 模拟 send_user_des_message 函数
|
||||
with patch("node.send_user_des_message", new_callable=AsyncMock) as mock_send_user_des_message:
|
||||
message = {
|
||||
"source_ip": "127.0.0.1",
|
||||
"dest_ip": "127.0.0.2",
|
||||
"capsule": (("x1", "y1"), ("x2", "y2"), 123),
|
||||
"ct": 456,
|
||||
"rk": ["rk1", "rk2"]
|
||||
}
|
||||
response = client.post("/user_src", json=message)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"detail": "message received"}
|
||||
mock_send_user_des_message.assert_called_once()
|
||||
|
||||
def test_send_user_des_message():
|
||||
with respx.mock:
|
||||
dest_ip = "127.0.0.2"
|
||||
re_message = (("a", "b", "c", "d"), 123)
|
||||
respx.post(f"http://{dest_ip}:8002/receive_messages").mock(
|
||||
return_value=httpx.Response(200, json={"status": "success"})
|
||||
)
|
||||
response = client.post(f"http://{dest_ip}:8002/receive_messages", json={"Tuple": re_message, "ip": "127.0.0.1"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "success"}
|
2
src/pytest.ini
Normal file
2
src/pytest.ini
Normal file
@ -0,0 +1,2 @@
|
||||
[pytest]
|
||||
asyncio_default_fixture_loop_scope = function
|
122
src/server_test1.py
Normal file
122
src/server_test1.py
Normal file
@ -0,0 +1,122 @@
|
||||
import sqlite3
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from server import app, validate_ip
|
||||
|
||||
# 创建 TestClient 实例
|
||||
client = TestClient(app)
|
||||
|
||||
# 准备测试数据库数据
|
||||
def setup_db():
|
||||
# 创建数据库并插入测试数据
|
||||
with sqlite3.connect("server.db") as db:
|
||||
db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS nodes (
|
||||
id INTEGER PRIMARY KEY,
|
||||
ip TEXT NOT NULL,
|
||||
last_heartbeat INTEGER NOT NULL
|
||||
)
|
||||
""")
|
||||
db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)")
|
||||
db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)")
|
||||
db.commit()
|
||||
|
||||
# 清空数据库
|
||||
def clear_db():
|
||||
with sqlite3.connect("server.db") as db:
|
||||
db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表
|
||||
db.commit()
|
||||
|
||||
|
||||
# 测试 IP 验证功能
|
||||
def test_validate_ip():
|
||||
assert validate_ip("192.168.0.1") is True
|
||||
assert validate_ip("256.256.256.256") is False
|
||||
assert validate_ip("::1") is True
|
||||
assert validate_ip("invalid_ip") is False
|
||||
|
||||
# 测试首页路由
|
||||
def test_home():
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Hello, World!"}
|
||||
|
||||
# 测试 show_nodes 路由
|
||||
def test_show_nodes():
|
||||
setup_db()
|
||||
|
||||
response = client.get("/server/show_nodes")
|
||||
assert response.status_code == 200
|
||||
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
assert data[0][1] == "192.168.0.1"
|
||||
assert data[1][1] == "192.168.0.2"
|
||||
|
||||
# 测试 get_node 路由
|
||||
def test_get_node():
|
||||
# 确保数据库和表的存在
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.3"
|
||||
invalid_ip = "256.256.256.256"
|
||||
|
||||
# 测试有效的 IP 地址
|
||||
response = client.get(f"/server/get_node?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
|
||||
# 测试无效的 IP 地址
|
||||
response = client.get(f"/server/get_node?ip={invalid_ip}")
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# 测试 delete_node 路由
|
||||
def test_delete_node():
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.1"
|
||||
invalid_ip = "192.168.0.255"
|
||||
|
||||
response = client.get(f"/server/delete_node?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
assert "Node with IP 192.168.0.1 deleted successfully." in response.text
|
||||
|
||||
response = client.get(f"/server/delete_node?ip={invalid_ip}")
|
||||
assert response.status_code == 404
|
||||
|
||||
# 测试 heartbeat 路由
|
||||
def test_receive_heartbeat():
|
||||
setup_db()
|
||||
|
||||
valid_ip = "192.168.0.2"
|
||||
invalid_ip = "256.256.256.256"
|
||||
|
||||
response = client.get(f"/server/heartbeat?ip={valid_ip}")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "received"}
|
||||
|
||||
response = client.get(f"/server/heartbeat?ip={invalid_ip}")
|
||||
assert response.status_code == 400
|
||||
assert response.json() == {"message": "invalid ip format"}
|
||||
|
||||
# 测试 send_nodes_list 路由
|
||||
def test_send_nodes_list():
|
||||
setup_db()
|
||||
|
||||
response = client.get("/server/send_nodes_list?count=1")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0] == "192.168.0.1"
|
||||
|
||||
response = client.get("/server/send_nodes_list?count=2")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 2
|
||||
|
||||
# 运行完测试后清理数据库
|
||||
@pytest.fixture(autouse=True)
|
||||
def run_around_tests():
|
||||
clear_db()
|
||||
yield
|
||||
clear_db()
|
131
src/tpre_test1.py
Normal file
131
src/tpre_test1.py
Normal file
@ -0,0 +1,131 @@
|
||||
from tpre import (
|
||||
hash2, hash3, hash4, multiply, g, sm2p256v1,
|
||||
GenerateKeyPair, Encrypt, Decrypt, GenerateReKey,
|
||||
Encapsulate, ReEncrypt, DecryptFrags
|
||||
)
|
||||
from tpre import MergeCFrag
|
||||
import random
|
||||
import unittest
|
||||
|
||||
|
||||
class TestHash2(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.double_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash2(self.double_G)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash2(self.double_G)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestHash3(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.triple_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash3(self.triple_G)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash3(self.triple_G)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
class TestHash4(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.triple_G = (
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
multiply(g, random.randint(0, sm2p256v1.N - 1)),
|
||||
)
|
||||
self.Zp = random.randint(0, sm2p256v1.N - 1)
|
||||
|
||||
def test_digest_type(self):
|
||||
digest = hash4(self.triple_G, self.Zp)
|
||||
self.assertEqual(type(digest), int)
|
||||
|
||||
def test_digest_size(self):
|
||||
digest = hash4(self.triple_G, self.Zp)
|
||||
self.assertLess(digest, sm2p256v1.N)
|
||||
|
||||
|
||||
# class TestGenerateKeyPair(unittest.TestCase):
|
||||
# def test_key_pair(self):
|
||||
# public_key, secret_key = GenerateKeyPair()
|
||||
# self.assertIsInstance(public_key, tuple)
|
||||
# self.assertIsInstance(secret_key, int)
|
||||
# self.assertEqual(len(public_key), 2)
|
||||
|
||||
|
||||
# class TestEncryptDecrypt(unittest.TestCase):
|
||||
# def setUp(self):
|
||||
# self.public_key, self.secret_key = GenerateKeyPair()
|
||||
# self.message = b"Hello, world!"
|
||||
|
||||
# def test_encrypt_decrypt(self):
|
||||
# encrypted_message = Encrypt(self.public_key, self.message)
|
||||
# decrypted_message = Decrypt(self.secret_key, encrypted_message)
|
||||
# self.assertEqual(decrypted_message, self.message)
|
||||
|
||||
|
||||
# class TestGenerateReKey(unittest.TestCase):
|
||||
# def test_generate_rekey(self):
|
||||
# sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
# pk_B, _ = GenerateKeyPair()
|
||||
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
# self.assertIsInstance(rekey, list)
|
||||
# self.assertEqual(len(rekey), 5)
|
||||
|
||||
|
||||
class TestEncapsulate(unittest.TestCase):
|
||||
def test_encapsulate(self):
|
||||
pk_A, _ = GenerateKeyPair()
|
||||
K, capsule = Encapsulate(pk_A)
|
||||
self.assertIsInstance(K, int)
|
||||
self.assertIsInstance(capsule, tuple)
|
||||
self.assertEqual(len(capsule), 3)
|
||||
|
||||
|
||||
# class TestReEncrypt(unittest.TestCase):
|
||||
# def test_reencrypt(self):
|
||||
# sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
# pk_B, _ = GenerateKeyPair()
|
||||
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
# pk_A, _ = GenerateKeyPair()
|
||||
# message = b"Hello, world!"
|
||||
# encrypted_message = Encrypt(pk_A, message)
|
||||
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
|
||||
# self.assertIsInstance(reencrypted_message, tuple)
|
||||
# self.assertEqual(len(reencrypted_message), 2)
|
||||
|
||||
|
||||
# class TestDecryptFrags(unittest.TestCase):
|
||||
# def test_decrypt_frags(self):
|
||||
# sk_A = random.randint(0, sm2p256v1.N - 1)
|
||||
# pk_B, sk_B = GenerateKeyPair()
|
||||
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
|
||||
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
|
||||
# pk_A, _ = GenerateKeyPair()
|
||||
# message = b"Hello, world!"
|
||||
# encrypted_message = Encrypt(pk_A, message)
|
||||
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
|
||||
# cfrags = [reencrypted_message]
|
||||
# merged_cfrags = MergeCFrag(cfrags)
|
||||
# decrypted_message = DecryptFrags(sk_B, pk_B, pk_A, merged_cfrags)
|
||||
# self.assertEqual(decrypted_message, message)
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# unittest.main()
|
Loading…
x
Reference in New Issue
Block a user