import unittest
from unittest.mock import patch, Mock, MagicMock, call, mock_open
from detection.requirements_detection import (
    fetch_html,
    parse_html,
    format_results,
    check_vulnerabilities,
)
from packaging.version import Version
from packaging.specifiers import SpecifierSet


class TestWebScrapingAndReporting(unittest.TestCase):
    """Tests for HTML fetching/parsing, result formatting and vulnerability reporting."""

    def setUp(self):
        """Build hypothetical dependency and vulnerability fixtures.

        ``requirements`` maps package name -> version string.
        ``vulnerabilities`` maps package name -> SpecifierSet of affected versions.
        package1 (1.0) falls inside its affected range, package2 has no entry
        in the vulnerability data, and package3 is only in the vulnerability data.
        """
        self.requirements = {"package1": "1.0", "package2": "2.0"}
        self.vulnerabilities = {
            "package1": SpecifierSet(">=1.0,<2.0"),
            "package3": SpecifierSet(">=1.0,<1.5"),
        }

    # --- Web scraping -----------------------------------------------------

    def test_fetch_html_success(self):
        """fetch_html returns the response body when the request succeeds."""
        with patch("requests.get") as mocked_get:
            mocked_get.return_value.status_code = 200
            mocked_get.return_value.text = "success"
            url = "https://security.snyk.io/vuln/pip/"
            result = fetch_html(url)
            self.assertEqual(result, "success")

    def test_fetch_html_failure(self):
        """fetch_html returns None when the request fails (HTTP 404)."""
        with patch("requests.get") as mocked_get:
            # Must be `return_value`: the status code has to live on the
            # response object that `requests.get(...)` returns.
            mocked_get.return_value.status_code = 404
            url = "https://security.snyk.io/vuln/pip/"
            result = fetch_html(url)
            self.assertIsNone(result)

    def test_parse_html(self):
        """parse_html extracts (link text, [span texts]) pairs from the HTML."""
        # NOTE(review): this fixture looks like its HTML tags were stripped
        # (only "Link1Span1" / "Link2Span2" text remains) — restore the
        # original markup that parse_html expects; TODO confirm.
        html_content = "\nLink1Span1\nLink2Span2\n"
        expected = [("Link1", ["Span1"]), ("Link2", ["Span2"])]
        result = parse_html(html_content)
        self.assertEqual(result, expected)

    # --- Result formatting ------------------------------------------------

    def test_format_results(self):
        """format_results renders each parsed entry as a block ending in a separator line."""
        results = [("Package1", ["1.0", "2.0"]), ("Package2", ["1.5", "2.5"])]
        expected_output = (
            "Package Name: Package1\nVersion Ranges: 1.0, 2.0\n"
            + "--------------------------------------------------\n"
            + "Package Name: Package2\nVersion Ranges: 1.5, 2.5\n"
            + "--------------------------------------------------\n"
        )
        formatted_result = format_results(results)
        self.assertEqual(formatted_result, expected_output)

    # --- Vulnerability alerts ---------------------------------------------

    @patch("builtins.print")  # capture console output instead of printing it
    def test_check_vulnerabilities_no_output_file(self, mock_print):
        """Without an output file, results are printed to the console."""
        check_vulnerabilities(self.requirements, self.vulnerabilities, None)
        expected_calls = [
            call(
                "WARNING: package1==1.0 is vulnerable!\nOK: package2 not found in the vulnerability database."
            )
        ]
        mock_print.assert_has_calls(expected_calls, any_order=True)

    @patch("builtins.open", new_callable=mock_open)
    @patch("os.path.splitext", return_value=("output", ".txt"))
    @patch("os.path.exists", return_value=False)
    @patch("os.makedirs")
    def test_check_vulnerabilities_with_output_file(
        self, mock_makedirs, mock_exists, mock_splitext, mocked_open
    ):
        """With an output file, results are written to that file."""
        check_vulnerabilities(self.requirements, self.vulnerabilities, "output.txt")
        mocked_open.assert_called_once_with("output.txt", "w", encoding="utf-8")
        # Inspect the shared file handle without invoking the mock again,
        # which would record an extra open() call.
        handle = mocked_open.return_value
        handle.write.assert_called()


if __name__ == "__main__":
    unittest.main()