From 57145589651e1ad49bc56507fac287c49d355d58 Mon Sep 17 00:00:00 2001 From: dqy <1016751306@qq.com> Date: Mon, 22 Apr 2024 17:08:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=E6=BC=8F=E6=B4=9E=E6=A3=80=E6=B5=8B=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- detection/requirements_detection.py | 97 +++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 detection/requirements_detection.py diff --git a/detection/requirements_detection.py b/detection/requirements_detection.py new file mode 100644 index 0000000..2daba0d --- /dev/null +++ b/detection/requirements_detection.py @@ -0,0 +1,97 @@ +# Usage: python requirements_detection.py ../crawler/trans_extracted_data.txt ../requirements.txt + +import sys +import os +from packaging import version +from packaging.specifiers import SpecifierSet +import re + + +def load_vulnerable_packages(filename): + """从文件加载有漏洞的包信息""" + with open(filename, "r", encoding="utf-8") as file: + content = file.read() + vulnerabilities = {} + blocks = content.split("--------------------------------------------------") + for block in blocks: + name_match = re.search(r"Package Name: (.+)", block) + range_match = re.search(r"Version Ranges: (.+)", block) + if name_match and range_match: + package_name = name_match.group(1).strip() + version_range = range_match.group(1).strip() + version_range = ",".join( + [part.strip() for part in version_range.split(",")] + ) + vulnerabilities[package_name] = SpecifierSet(version_range) + return vulnerabilities + + +def load_requirements(filename): + """从文件加载项目的依赖信息""" + with open(filename, "r", encoding="utf-8") as file: + lines = file.readlines() + requirements = {} + for line in lines: + if "==" in line: + package_name, package_version = line.strip().split("==") + requirements[package_name] = package_version + return requirements + + +def output_results(filename, results, format_type): + """根据指定的格式输出结果""" + output_dir = os.path.dirname(filename) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(filename, "w", encoding="utf-8") as file: + if format_type == "html": + file.write("Vulnerability Report\n") + file.write("

Vulnerability Report

\n") + for result in results: + file.write(f"

{result}

\n") + file.write("") + elif format_type == "md": + file.write("# Vulnerability Report\n") + for result in results: + file.write(f"* {result}\n") + else: # default to txt + for result in results: + file.write(f"{result}\n") + + +def check_vulnerabilities(requirements, vulnerabilities, output_format): + """检查依赖项是否存在已知漏洞,并输出结果""" + results = [] + for req_name, req_version in requirements.items(): + if req_name in vulnerabilities: + spec = vulnerabilities[req_name] + if version.parse(req_version) in spec: + results.append(f"WARNING: {req_name}=={req_version} is vulnerable!") + else: + results.append(f"OK: {req_name}=={req_version} is not affected.") + else: + results.append(f"OK: {req_name} not found in the vulnerability database.") + output_results( + "../results/requirements/results." + output_format, results, output_format + ) + + +def main(): + if len(sys.argv) < 4: + print( + "Usage: python script.py " + ) + sys.exit(1) + + vulnerabilities_file = sys.argv[1] + requirements_file = sys.argv[2] + output_format = sys.argv[3] + + vulnerabilities = load_vulnerable_packages(vulnerabilities_file) + requirements = load_requirements(requirements_file) + check_vulnerabilities(requirements, vulnerabilities, output_format) + + +if __name__ == "__main__": + main()