89 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			89 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# 测试 node.py 中的函数
 | 
						|
import os
 | 
						|
import unittest
 | 
						|
from unittest.mock import patch, MagicMock, Mock
 | 
						|
import sys
 | 
						|
import os
 | 
						|
 | 
						|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../src")))
 | 
						|
import node
 | 
						|
 | 
						|
 | 
						|
class TestGetLocalIP(unittest.TestCase):
 | 
						|
 | 
						|
    @patch.dict("os.environ", {"HOST_IP": "60.204.193.58"})  # 模拟设置 HOST_IP 环境变量
 | 
						|
    def test_get_ip_from_env(self):
 | 
						|
        # 调用被测函数
 | 
						|
        node.get_local_ip()
 | 
						|
 | 
						|
        # 检查函数是否正确获取到 HOST_IP
 | 
						|
        self.assertEqual(node.ip, "60.204.193.58")
 | 
						|
 | 
						|
    @patch("socket.socket")  # Mock socket 连接行为
 | 
						|
    @patch.dict("os.environ", {})  # 模拟没有 HOST_IP 环境变量
 | 
						|
    def test_get_ip_from_socket(self, mock_socket):
 | 
						|
        # 模拟 socket 返回的 IP 地址
 | 
						|
        mock_socket_instance = MagicMock()
 | 
						|
        mock_socket.return_value = mock_socket_instance
 | 
						|
        mock_socket_instance.getsockname.return_value = ("110.41.155.96", 0)
 | 
						|
 | 
						|
        # 调用被测函数
 | 
						|
        node.get_local_ip()
 | 
						|
 | 
						|
        # 确认 socket 被调用过
 | 
						|
        mock_socket_instance.connect.assert_called_with(("8.8.8.8", 80))
 | 
						|
        mock_socket_instance.close.assert_called_once()
 | 
						|
 | 
						|
        # 检查是否通过 socket 获取到正确的 IP 地址
 | 
						|
        self.assertEqual(node.ip, "110.41.155.96")
 | 
						|
 | 
						|
 | 
						|
class TestSendIP(unittest.TestCase):
 | 
						|
    @patch.dict(os.environ, {"HOST_IP": "60.204.193.58"})  # 设置环境变量 HOST_IP
 | 
						|
    @patch("requests.get")  # Mock requests.get 调用
 | 
						|
    def test_send_ip(self, mock_get):
 | 
						|
        # 设置模拟返回的 HTTP 响应
 | 
						|
        mock_response = Mock()
 | 
						|
        mock_response.text = "node123"  # 模拟返回的节点ID
 | 
						|
        mock_response.status_code = 200
 | 
						|
        mock_get.return_value = (
 | 
						|
            mock_response  # 设置 requests.get() 的返回值为 mock_response
 | 
						|
        )
 | 
						|
 | 
						|
        # 保存原始的全局 id 值
 | 
						|
        original_id = node.id
 | 
						|
 | 
						|
        # 调用待测函数
 | 
						|
        node.send_ip()
 | 
						|
 | 
						|
        # 确保 requests.get 被正确调用
 | 
						|
        expected_url = f"{node.server_address}/get_node?ip={node.ip}"
 | 
						|
        mock_get.assert_called_once_with(expected_url, timeout=3)
 | 
						|
 | 
						|
        # 检查 id 是否被正确更新
 | 
						|
        self.assertIs(node.id, mock_response)  # 检查 id 是否被修改
 | 
						|
        self.assertEqual(
 | 
						|
            node.id.text, "node123"
 | 
						|
        )  # 检查更新后的 id 是否与 mock_response.text 匹配
 | 
						|
 | 
						|
 | 
						|
class TestNode(unittest.TestCase):
 | 
						|
 | 
						|
    @patch("node.send_ip")
 | 
						|
    @patch("node.get_local_ip")
 | 
						|
    @patch("node.asyncio.create_task")
 | 
						|
    def test_init(self, mock_create_task, mock_get_local_ip, mock_send_ip):
 | 
						|
        # 调用 init 函数
 | 
						|
        node.init()
 | 
						|
 | 
						|
        # 验证 get_local_ip 和 send_ip 被调用
 | 
						|
        mock_get_local_ip.assert_called_once()
 | 
						|
        mock_send_ip.assert_called_once()
 | 
						|
 | 
						|
        # 确保 create_task 被调用来启动心跳包
 | 
						|
        mock_create_task.assert_called_once()
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
 | 
						|
    unittest.main()
 |