Improve test code files
All checks were successful
Test CI / test speed (push) Successful in 10s

This commit is contained in:
muzhi 2024-10-04 22:28:14 +08:00
parent aa0eb6bfbc
commit 061bd5d2bf
15 changed files with 191 additions and 282 deletions

1
FQA.md
View File

@ -59,3 +59,4 @@
15. **Q:国密SM2 SM3 SM4** 如题
**A:**
----------------------------------------------------------------------

BIN
server.db

Binary file not shown.

View File

View File

@ -1 +0,0 @@
DEBUG:asyncio:Using selector: EpollSelector

View File

@ -1,122 +0,0 @@
import sqlite3
import pytest
from fastapi.testclient import TestClient
from server import app, validate_ip
# 创建 TestClient 实例
client = TestClient(app)
# 准备测试数据库数据
def setup_db():
# 创建数据库并插入测试数据
with sqlite3.connect("server.db") as db:
db.execute("""
CREATE TABLE IF NOT EXISTS nodes (
id INTEGER PRIMARY KEY,
ip TEXT NOT NULL,
last_heartbeat INTEGER NOT NULL
)
""")
db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.1', 1234567890)")
db.execute("INSERT INTO nodes (ip, last_heartbeat) VALUES ('192.168.0.2', 1234567890)")
db.commit()
# 清空数据库
def clear_db():
with sqlite3.connect("server.db") as db:
db.execute("DROP TABLE IF EXISTS nodes") # 删除旧表
db.commit()
# 测试 IP 验证功能
def test_validate_ip():
assert validate_ip("192.168.0.1") is True
assert validate_ip("256.256.256.256") is False
assert validate_ip("::1") is True
assert validate_ip("invalid_ip") is False
# 测试首页路由
def test_home():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello, World!"}
# 测试 show_nodes 路由
def test_show_nodes():
setup_db()
response = client.get("/server/show_nodes")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0][1] == "192.168.0.1"
assert data[1][1] == "192.168.0.2"
# 测试 get_node 路由
def test_get_node():
# 确保数据库和表的存在
setup_db()
valid_ip = "192.168.0.3"
invalid_ip = "256.256.256.256"
# 测试有效的 IP 地址
response = client.get(f"/server/get_node?ip={valid_ip}")
assert response.status_code == 200
# 测试无效的 IP 地址
response = client.get(f"/server/get_node?ip={invalid_ip}")
assert response.status_code == 400
# 测试 delete_node 路由
def test_delete_node():
setup_db()
valid_ip = "192.168.0.1"
invalid_ip = "192.168.0.255"
response = client.get(f"/server/delete_node?ip={valid_ip}")
assert response.status_code == 200
assert "Node with IP 192.168.0.1 deleted successfully." in response.text
response = client.get(f"/server/delete_node?ip={invalid_ip}")
assert response.status_code == 404
# 测试 heartbeat 路由
def test_receive_heartbeat():
setup_db()
valid_ip = "192.168.0.2"
invalid_ip = "256.256.256.256"
response = client.get(f"/server/heartbeat?ip={valid_ip}")
assert response.status_code == 200
assert response.json() == {"status": "received"}
response = client.get(f"/server/heartbeat?ip={invalid_ip}")
assert response.status_code == 400
assert response.json() == {"message": "invalid ip format"}
# 测试 send_nodes_list 路由
def test_send_nodes_list():
setup_db()
response = client.get("/server/send_nodes_list?count=1")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0] == "192.168.0.1"
response = client.get("/server/send_nodes_list?count=2")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
# 运行完测试后清理数据库
@pytest.fixture(autouse=True)
def run_around_tests():
clear_db()
yield
clear_db()

View File

@ -1,58 +0,0 @@
from tpre import hash2, hash3, hash4, multiply, g, sm2p256v1
import random
import unittest
class TestHash2(unittest.TestCase):
def setUp(self):
self.double_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash2(self.double_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash2(self.double_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash3(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
def test_digest_type(self):
digest = hash3(self.triple_G)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash3(self.triple_G)
self.assertLess(digest, sm2p256v1.N)
class TestHash4(unittest.TestCase):
def setUp(self):
self.triple_G = (
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
multiply(g, random.randint(0, sm2p256v1.N - 1)),
)
self.Zp = random.randint(0, sm2p256v1.N - 1)
def test_digest_type(self):
digest = hash4(self.triple_G, self.Zp)
self.assertEqual(type(digest), int)
def test_digest_size(self):
digest = hash4(self.triple_G, self.Zp)
self.assertLess(digest, sm2p256v1.N)
if __name__ == "__main__":
unittest.main()

View File

@ -1,3 +1,7 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import add, multiply, sm2p256v1
import time

View File

@ -1,3 +1,7 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,

View File

@ -1,3 +1,7 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import *
import time

View File

@ -2,45 +2,53 @@
import os
import unittest
from unittest.mock import patch, MagicMock, Mock
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
import node
class TestGetLocalIP(unittest.TestCase):
@patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量
@patch.dict("os.environ", {"HOST_IP": "60.204.193.58"}) # 模拟设置 HOST_IP 环境变量
def test_get_ip_from_env(self):
# 调用被测函数
node.get_local_ip()
# 检查函数是否正确获取到 HOST_IP
self.assertEqual(node.ip, '60.204.193.58')
self.assertEqual(node.ip, "60.204.193.58")
@patch('socket.socket') # Mock socket 连接行为
@patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量
@patch("socket.socket") # Mock socket 连接行为
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
def test_get_ip_from_socket(self, mock_socket):
# 模拟 socket 返回的 IP 地址
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0)
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
# 调用被测函数
node.get_local_ip()
# 确认 socket 被调用过
mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80))
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
mock_socket_instance.close.assert_called_once()
# 检查是否通过 socket 获取到正确的 IP 地址
self.assertEqual(node.ip, '110.41.155.96')
self.assertEqual(node.ip, "110.41.155.96")
class TestSendIP(unittest.TestCase):
@patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP
@patch('requests.get') # Mock requests.get 调用
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
@patch("requests.get") # Mock requests.get 调用
def test_send_ip(self, mock_get):
# 设置模拟返回的 HTTP 响应
mock_response = Mock()
mock_response.text = "node123" # 模拟返回的节点ID
mock_response.status_code = 200
mock_get.return_value = mock_response # 设置 requests.get() 的返回值为 mock_response
mock_get.return_value = (
mock_response # 设置 requests.get() 的返回值为 mock_response
)
# 保存原始的全局 id 值
original_id = node.id
@ -54,13 +62,16 @@ class TestSendIP(unittest.TestCase):
# 检查 id 是否被正确更新
self.assertIs(node.id, mock_response) # 检查 id 是否被修改
self.assertEqual(node.id.text, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配
self.assertEqual(
node.id.text, "node123"
) # 检查更新后的 id 是否与 mock_response.text 匹配
class TestNode(unittest.TestCase):
@patch('node.send_ip')
@patch('node.get_local_ip')
@patch('node.asyncio.create_task')
@patch("node.send_ip")
@patch("node.get_local_ip")
@patch("node.asyncio.create_task")
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
# 调用 init 函数
node.init()
@ -72,5 +83,6 @@ class TestNode(unittest.TestCase):
# 确保 create_task 被调用来启动心跳包
mock_create_task.assert_called_once()
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -1,4 +1,4 @@
#node_test剩下部分(有问题)
# node_test剩下部分(有问题)
import os
import unittest
import pytest
@ -8,49 +8,67 @@ import asyncio
import httpx
import respx
from fastapi.testclient import TestClient
from node import app, send_heartbeat_internal, Req, send_ip, get_local_ip, init, clear, send_user_des_message
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from node import (
app,
send_heartbeat_internal,
Req,
send_ip,
get_local_ip,
init,
clear,
send_user_des_message,
ip,
id,
)
client = TestClient(app)
server_address = "http://60.204.236.38:8000/server"
ip = None # 初始化全局变量 ip
id = None # 初始化全局变量 id
# ip = None # 初始化全局变量 ip
# id = None # 初始化全局变量 id
class TestGetLocalIP(unittest.TestCase):
@patch.dict('os.environ', {'HOST_IP': '60.204.193.58'}) # 模拟设置 HOST_IP 环境变量
os.environ["HOST_IP"] = "60.204.193.58" # 模拟设置 HOST_IP 环境变量
def test_get_ip_from_env(self):
global ip
# 调用被测函数
get_local_ip()
# 检查函数是否正确获取到 HOST_IP
self.assertEqual(ip, '60.204.193.58')
self.assertEqual(ip, "60.204.193.58")
@patch('socket.socket') # Mock socket 连接行为
@patch.dict('os.environ', {}) # 模拟没有 HOST_IP 环境变量
@patch("socket.socket") # Mock socket 连接行为
@patch.dict("os.environ", {}) # 模拟没有 HOST_IP 环境变量
def test_get_ip_from_socket(self, mock_socket):
global ip
# 模拟 socket 返回的 IP 地址
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.getsockname.return_value = ('110.41.155.96', 0)
mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
# 调用被测函数
get_local_ip()
# 确认 socket 被调用过
mock_socket_instance.connect.assert_called_with(('8.8.8.8', 80))
mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
mock_socket_instance.close.assert_called_once()
# 检查是否通过 socket 获取到正确的 IP 地址
self.assertEqual(ip, '110.41.155.96')
self.assertEqual(ip, "110.41.155.96")
class TestSendIP(unittest.TestCase):
@patch.dict(os.environ, {'HOST_IP': '60.204.193.58'}) # 设置环境变量 HOST_IP
@patch.dict(os.environ, {"HOST_IP": "60.204.193.58"}) # 设置环境变量 HOST_IP
@respx.mock
def test_send_ip(self):
global ip, id
ip = '60.204.193.58'
ip = "60.204.193.58"
mock_url = f"{server_address}/get_node?ip={ip}"
respx.get(mock_url).mock(return_value=httpx.Response(200, text="node123"))
@ -58,13 +76,16 @@ class TestSendIP(unittest.TestCase):
send_ip()
# 确保 requests.get 被正确调用
self.assertEqual(id, "node123") # 检查更新后的 id 是否与 mock_response.text 匹配
self.assertEqual(
id, "node123"
) # 检查更新后的 id 是否与 mock_response.text 匹配
class TestNode(unittest.TestCase):
@patch('node.send_ip')
@patch('node.get_local_ip')
@patch('node.asyncio.create_task')
@patch("node.send_ip")
@patch("node.get_local_ip")
@patch("node.asyncio.create_task")
def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
# 调用 init 函数
init()
@ -82,11 +103,12 @@ class TestNode(unittest.TestCase):
# 检查输出
self.assertTrue(True) # 这里只是为了确保函数被调用,没有实际逻辑需要测试
@pytest.mark.asyncio
@respx.mock
async def test_send_heartbeat_internal_success():
global ip
ip = '60.204.193.58'
ip = "60.204.193.58"
# 模拟心跳请求
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
return_value=httpx.Response(200)
@ -107,18 +129,21 @@ async def test_send_heartbeat_internal_success():
assert mock_get.called
assert mock_get.call_count > 0
@pytest.mark.asyncio
@respx.mock
async def test_send_heartbeat_internal_failure():
global ip
ip = '60.204.193.58'
ip = "60.204.193.58"
# 模拟心跳请求以引发异常
heartbeat_route = respx.get(f"{server_address}/heartbeat?ip={ip}").mock(
side_effect=httpx.RequestError("Central server error")
)
# 模拟 requests.get 以避免实际请求
with patch("requests.get", side_effect=httpx.RequestError("Central server error")) as mock_get:
with patch(
"requests.get", side_effect=httpx.RequestError("Central server error")
) as mock_get:
# 模拟 asyncio.sleep 以避免实际延迟
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
task = asyncio.create_task(send_heartbeat_internal())
@ -132,23 +157,29 @@ async def test_send_heartbeat_internal_failure():
assert mock_get.called
assert mock_get.call_count > 0
def test_user_src():
# 模拟 ReEncrypt 函数
with patch("node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")):
with patch(
"node.ReEncrypt", return_value=(("a", "b", "c", "d"), b"encrypted_data")
):
# 模拟 send_user_des_message 函数
with patch("node.send_user_des_message", new_callable=AsyncMock) as mock_send_user_des_message:
with patch(
"node.send_user_des_message", new_callable=AsyncMock
) as mock_send_user_des_message:
message = {
"source_ip": "60.204.193.58",
"dest_ip": "60.204.193.59",
"capsule": (("x1", "y1"), ("x2", "y2"), 123),
"ct": 456,
"rk": ["rk1", "rk2"]
"rk": ["rk1", "rk2"],
}
response = client.post("/user_src", json=message)
assert response.status_code == 200
assert response.json() == {"detail": "message received"}
mock_send_user_des_message.assert_called_once()
def test_send_user_des_message():
with respx.mock:
dest_ip = "60.204.193.59"
@ -156,9 +187,13 @@ def test_send_user_des_message():
respx.post(f"http://{dest_ip}:8002/receive_messages").mock(
return_value=httpx.Response(200, json={"status": "success"})
)
response = requests.post(f"http://{dest_ip}:8002/receive_messages", json={"Tuple": re_message, "ip": "60.204.193.58"})
response = requests.post(
f"http://{dest_ip}:8002/receive_messages",
json={"Tuple": re_message, "ip": "60.204.193.58"},
)
assert response.status_code == 200
assert response.json() == {"status": "success"}
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -1,11 +1,10 @@
import sqlite3
import pytest
from fastapi.testclient import TestClient
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
import sqlite3
import pytest
from fastapi.testclient import TestClient
from server import app, validate_ip
# 创建 TestClient 实例

View File

@ -1,3 +1,7 @@
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
GenerateKeyPair,
Encrypt,

View File

@ -1,13 +1,14 @@
import os
import pytest
import sqlite3
import respx
import httpx
from fastapi.testclient import TestClient
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from client import app, init_db, clean_env, get_own_ip
client = TestClient(app)
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown():
# 设置测试环境
@ -16,20 +17,20 @@ def setup_and_teardown():
# 清理测试环境
clean_env()
def test_read_root():
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"message": "Hello, World!"}
def test_receive_messages():
message = {
"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8),
"ip": "127.0.0.1"
}
message = {"Tuple": (((1, 2), (3, 4), 5, (6, 7)), 8), "ip": "127.0.0.1"}
response = client.post("/receive_messages", json=message)
assert response.status_code == 200
assert response.json().get("detail") == "Message received"
# @respx.mock
# def test_request_message():
# request_message = {
@ -56,21 +57,20 @@ def test_receive_messages():
# assert "threshold" in response.json()
# assert "public_key" in response.json()
def test_get_pk():
response = client.get("/get_pk")
assert response.status_code == 200
assert "pkx" in response.json()
assert "pky" in response.json()
def test_recieve_pk():
pk_data = {
"pkx": "123",
"pky": "456",
"ip": "127.0.0.1"
}
pk_data = {"pkx": "123", "pky": "456", "ip": "127.0.0.1"}
response = client.post("/recieve_pk", json=pk_data)
assert response.status_code == 200
assert response.json() == {"message": "save pk in database"}
if __name__ == "__main__":
pytest.main()

View File

@ -1,9 +1,24 @@
import sys
import os
import hashlib
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
from tpre import (
hash2, hash3, hash4, multiply, g, sm2p256v1,
GenerateKeyPair, Encrypt, Decrypt, GenerateReKey,
Encapsulate, ReEncrypt, DecryptFrags
hash2,
hash3,
hash4,
multiply,
g,
sm2p256v1,
GenerateKeyPair,
Encrypt,
Decrypt,
GenerateReKey,
Encapsulate,
ReEncrypt,
DecryptFrags,
MergeCFrag,
)
from tpre import MergeCFrag
import random
import unittest
@ -59,12 +74,14 @@ class TestHash4(unittest.TestCase):
self.assertLess(digest, sm2p256v1.N)
# class TestGenerateKeyPair(unittest.TestCase):
# def test_key_pair(self):
# public_key, secret_key = GenerateKeyPair()
# self.assertIsInstance(public_key, tuple)
# self.assertIsInstance(secret_key, int)
# self.assertEqual(len(public_key), 2)
class TestGenerateKeyPair(unittest.TestCase):
def test_key_pair(self):
public_key, secret_key = GenerateKeyPair()
self.assertIsInstance(public_key, tuple)
self.assertEqual(len(public_key), 2)
self.assertIsInstance(secret_key, int)
self.assertLess(secret_key, sm2p256v1.N)
self.assertGreater(secret_key, 0)
# class TestEncryptDecrypt(unittest.TestCase):
@ -74,18 +91,22 @@ class TestHash4(unittest.TestCase):
# def test_encrypt_decrypt(self):
# encrypted_message = Encrypt(self.public_key, self.message)
# decrypted_message = Decrypt(self.secret_key, encrypted_message)
# # 使用 SHA-256 哈希函数确保密钥为 16 字节
# secret_key_hash = hashlib.sha256(self.secret_key.to_bytes((self.secret_key.bit_length() + 7) // 8, 'big')).digest()
# secret_key_int = int.from_bytes(secret_key_hash[:16], 'big') # 取前 16 字节并转换为整数
# decrypted_message = Decrypt(secret_key_int, encrypted_message)
# self.assertEqual(decrypted_message, self.message)
# class TestGenerateReKey(unittest.TestCase):
# def test_generate_rekey(self):
# sk_A = random.randint(0, sm2p256v1.N - 1)
# pk_B, _ = GenerateKeyPair()
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
# self.assertIsInstance(rekey, list)
# self.assertEqual(len(rekey), 5)
class TestGenerateReKey(unittest.TestCase):
def test_generate_rekey(self):
sk_A = random.randint(0, sm2p256v1.N - 1)
pk_B, _ = GenerateKeyPair()
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
self.assertIsInstance(rekey, list)
self.assertEqual(len(rekey), 5)
class TestEncapsulate(unittest.TestCase):
@ -97,18 +118,18 @@ class TestEncapsulate(unittest.TestCase):
self.assertEqual(len(capsule), 3)
# class TestReEncrypt(unittest.TestCase):
# def test_reencrypt(self):
# sk_A = random.randint(0, sm2p256v1.N - 1)
# pk_B, _ = GenerateKeyPair()
# id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
# rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
# pk_A, _ = GenerateKeyPair()
# message = b"Hello, world!"
# encrypted_message = Encrypt(pk_A, message)
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
# self.assertIsInstance(reencrypted_message, tuple)
# self.assertEqual(len(reencrypted_message), 2)
class TestReEncrypt(unittest.TestCase):
def test_reencrypt(self):
sk_A = random.randint(0, sm2p256v1.N - 1)
pk_B, _ = GenerateKeyPair()
id_tuple = tuple(random.randint(0, sm2p256v1.N - 1) for _ in range(5))
rekey = GenerateReKey(sk_A, pk_B, 5, 3, id_tuple)
pk_A, _ = GenerateKeyPair()
message = b"Hello, world!"
encrypted_message = Encrypt(pk_A, message)
reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
self.assertIsInstance(reencrypted_message, tuple)
self.assertEqual(len(reencrypted_message), 2)
# class TestDecryptFrags(unittest.TestCase):
@ -123,9 +144,15 @@ class TestEncapsulate(unittest.TestCase):
# reencrypted_message = ReEncrypt(rekey[0], encrypted_message)
# cfrags = [reencrypted_message]
# merged_cfrags = MergeCFrag(cfrags)
# decrypted_message = DecryptFrags(sk_B, pk_B, pk_A, merged_cfrags)
# self.assertIsNotNone(merged_cfrags)
# sk_B_hash = hashlib.sha256(sk_B.to_bytes((sk_B.bit_length() + 7) // 8, 'big')).digest()
# sk_B_int = int.from_bytes(sk_B_hash[:16], 'big') # 取前 16 字节并转换为整数
# decrypted_message = DecryptFrags(sk_B_int, pk_B, pk_A, merged_cfrags)
# self.assertEqual(decrypted_message, message)
# if __name__ == "__main__":
# unittest.main()
if __name__ == "__main__":
unittest.main()