98 lines
3.9 KiB
Python
98 lines
3.9 KiB
Python
import unittest
|
|
from unittest.mock import patch, Mock, MagicMock
|
|
from detection.requirements_detection import (
|
|
fetch_html,
|
|
parse_html,
|
|
format_results,
|
|
check_vulnerabilities,
|
|
)
|
|
from packaging.version import Version
|
|
from packaging.specifiers import SpecifierSet
|
|
|
|
# The functions under test are imported from detection.requirements_detection above.
|
|
|
|
|
|
# Test class covering web scraping (HTML fetching/parsing) and result reporting.
|
|
class TestWebScrapingAndReporting(unittest.TestCase):
    """Tests for fetch_html, parse_html, format_results and check_vulnerabilities."""

    def setUp(self):
        """Build the sample dependency and vulnerability data used by the
        check_vulnerabilities tests.

        ``requirements`` maps package name -> pinned version string.
        ``vulnerabilities`` maps package name -> SpecifierSet of affected versions.
        package1 (1.0) falls inside its vulnerable range, package2 has no entry
        in the vulnerability data, and package3 is not a requirement at all.
        """
        self.requirements = {"package1": "1.0", "package2": "2.0"}
        self.vulnerabilities = {
            "package1": SpecifierSet(">=1.0,<2.0"),
            "package3": SpecifierSet(">=1.0,<1.5"),
        }

    # --- fetch_html / parse_html / format_results ---

    def test_fetch_html_success(self):
        """fetch_html returns the response body when the request succeeds (HTTP 200)."""
        with patch("requests.get") as mocked_get:
            mocked_get.return_value.status_code = 200
            mocked_get.return_value.text = "success"
            url = "https://security.snyk.io/vuln/pip/"
            result = fetch_html(url)
            self.assertEqual(result, "success")

    def test_fetch_html_failure(self):
        """fetch_html returns None when the request fails (HTTP 404)."""
        with patch("requests.get") as mocked_get:
            # Configure the response object that requests.get returns. Setting any
            # other attribute (e.g. a mistyped ``return_code``) would leave
            # status_code as an unconfigured MagicMock, so the test would not
            # actually exercise the 404 path.
            mocked_get.return_value.status_code = 404
            url = "https://security.snyk.io/vuln/pip/"
            result = fetch_html(url)
            self.assertIsNone(result)

    def test_parse_html(self):
        """parse_html extracts (link text, [span texts]) pairs from the results table."""
        html_content = """
        <table id="sortable-table">
            <tbody>
                <tr><td></td><td><a href="#">Link1</a><span>Span1</span></td></tr>
                <tr><td></td><td><a href="#">Link2</a><span>Span2</span></td></tr>
            </tbody>
        </table>
        """
        expected = [("Link1", ["Span1"]), ("Link2", ["Span2"])]
        result = parse_html(html_content)
        self.assertEqual(result, expected)

    def test_format_results(self):
        """format_results renders each (name, versions) pair followed by a separator line."""
        results = [("Package1", ["1.0", "2.0"]), ("Package2", ["1.5", "2.5"])]
        expected_output = (
            "Package Name: Package1\nVersion Ranges: 1.0, 2.0\n"
            + "--------------------------------------------------\n"
            + "Package Name: Package2\nVersion Ranges: 1.5, 2.5\n"
            + "--------------------------------------------------\n"
        )
        formatted_result = format_results(results)
        self.assertEqual(formatted_result, expected_output)

    # --- check_vulnerabilities (alert reporting) ---

    @patch("builtins.print")  # Replace the built-in print so the output can be inspected.
    def test_check_vulnerabilities_no_output_file(self, mock_print):
        """With no output file, the report is printed to the console."""
        check_vulnerabilities(self.requirements, self.vulnerabilities, None)
        expected_calls = [
            unittest.mock.call(
                "WARNING: package1==1.0 is vulnerable!\nOK: package2 not found in the vulnerability database."
            )
        ]
        mock_print.assert_has_calls(expected_calls, any_order=True)

    # Stacked @patch decorators are applied bottom-up, so the mock arguments
    # arrive in reverse order: os.makedirs first, builtins.open last.
    # The os.* stubs keep the test off the real filesystem (presumably
    # check_vulnerabilities uses them to prepare the output path — not visible here).
    @patch("builtins.open", new_callable=unittest.mock.mock_open)
    @patch("os.path.splitext", return_value=("output", ".txt"))
    @patch("os.path.exists", return_value=False)
    @patch("os.makedirs")
    def test_check_vulnerabilities_with_output_file(
        self, mock_makedirs, mock_exists, mock_splitext, mock_open
    ):
        """With an output file, the report is written to that file (UTF-8)."""
        check_vulnerabilities(self.requirements, self.vulnerabilities, "output.txt")
        mock_open.assert_called_once_with("output.txt", "w", encoding="utf-8")
        # mock_open's return_value is the file handle (also returned by __enter__).
        # Read it directly instead of calling mock_open() again, which would
        # record a second, spurious open() call.
        mock_open.return_value.write.assert_called()
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, collect and run every TestCase in this module.
    unittest.main(module="__main__")
|