# Usage: python requirements_detection.py ../crawler/trans_extracted_data.txt ../requirements.txt [txt|md|html]
"""Check exact-pinned requirements against a crawled vulnerability list.

The vulnerability file is split into blocks by a line of dashes; each block
is expected to contain a ``Package Name: ...`` line and a
``Version Ranges: ...`` line whose value is parsed as a ``SpecifierSet``.
"""
import html
import os
import re
import sys
from collections import defaultdict

from packaging import version
from packaging.specifiers import SpecifierSet

# Separator between entries in the vulnerability file.
_BLOCK_SEPARATOR = "--------------------------------------------------"
_NAME_RE = re.compile(r"Package Name: (.+)")
_RANGE_RE = re.compile(r"Version Ranges: (.+)")

# "name[extras] == version" at the start of a requirements line. The version
# stops at whitespace, a comma, an environment marker ";", a comment "#" or a
# line-continuation backslash, so those never leak into the version string.
_PINNED_RE = re.compile(
    r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)\s*(?:\[[^\]]*\])?\s*===?\s*([^\s,;\\#]+)"
)

SUPPORTED_FORMATS = ("txt", "md", "html")
DEFAULT_OUTPUT_TEMPLATE = "../results/requirements/results.{}"


def _normalize_name(name):
    """Return the PEP 503 normalized project name, used only for lookups."""
    return re.sub(r"[-_.]+", "-", name).lower()


def _iter_vulnerability_entries(filename):
    """Yield ``(package_name, SpecifierSet)`` for every complete block in *filename*."""
    with open(filename, "r", encoding="utf-8") as file:
        content = file.read()
    for block in content.split(_BLOCK_SEPARATOR):
        name_match = _NAME_RE.search(block)
        range_match = _RANGE_RE.search(block)
        if not (name_match and range_match):
            continue  # block without both fields: nothing to record
        package_name = name_match.group(1).strip()
        # Tidy whitespace around each comma-separated clause before parsing.
        version_range = ",".join(
            part.strip() for part in range_match.group(1).strip().split(",")
        )
        yield package_name, SpecifierSet(version_range)


def load_vulnerable_packages(filename):
    """Load the vulnerability file as ``{package_name: SpecifierSet}``.

    If a package name appears in several blocks, only the LAST block is
    kept. Use ``load_vulnerability_ranges`` to keep every range.

    Raises:
        OSError: if the file cannot be read.
        packaging.specifiers.InvalidSpecifier: on a malformed range.
    """
    return dict(_iter_vulnerability_entries(filename))


def load_vulnerability_ranges(filename):
    """Load the vulnerability file as ``{package_name: [SpecifierSet, ...]}``.

    Every block is kept, so a package listed in several blocks is matched
    against all of its ranges rather than only the last one.
    """
    ranges = defaultdict(list)
    for package_name, spec in _iter_vulnerability_entries(filename):
        ranges[package_name].append(spec)
    return dict(ranges)


def load_requirements(filename):
    """Load exact ``name==version`` pins from a requirements file.

    Lines that are not exact pins (comments, options such as ``-r``,
    ``>=`` ranges, URLs) are skipped. Extras (``pkg[extra]``), environment
    markers, inline comments and whitespace around ``==`` are dropped.

    Returns:
        ``{name: version_string}`` in file order; a repeated name keeps its
        last pin.
    """
    requirements = {}
    # utf-8-sig so a BOM on the first line does not hide the first pin.
    with open(filename, "r", encoding="utf-8-sig") as file:
        for line in file:
            match = _PINNED_RE.match(line)
            if match:
                requirements[match.group(1)] = match.group(2)
    return requirements


def output_results(filename, results, format_type):
    """Write *results* to *filename* as ``html``, ``md`` or (otherwise) plain text.

    Missing parent directories are created. In HTML output each result is
    escaped, since it contains package names read from input files.
    """
    output_dir = os.path.dirname(filename)
    # dirname() is "" for a bare file name, and makedirs("") would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if format_type == "html":
        lines = [
            "<!DOCTYPE html>\n",
            '<html>\n<head>\n<meta charset="utf-8">\n',
            "<title>Vulnerability Report</title>\n</head>\n<body>\n",
            "<h1>Vulnerability Report</h1>\n<ul>\n",
        ]
        lines.extend(f"<li>{html.escape(result)}</li>\n" for result in results)
        lines.append("</ul>\n</body>\n</html>\n")
    elif format_type == "md":
        lines = ["# Vulnerability Report\n"]
        lines.extend(f"* {result}\n" for result in results)
    else:  # default to txt
        lines = [f"{result}\n" for result in results]
    with open(filename, "w", encoding="utf-8") as file:
        file.writelines(lines)


def _as_spec_list(spec_or_specs):
    """Return a single SpecifierSet or an iterable of them as a list."""
    # SpecifierSet is itself iterable (over its clauses), so test it first.
    if isinstance(spec_or_specs, SpecifierSet):
        return [spec_or_specs]
    return list(spec_or_specs)


def check_vulnerabilities(requirements, vulnerabilities, output_format, output_path=None):
    """Check pinned requirements against known-vulnerable ranges and write a report.

    Args:
        requirements: ``{name: version_string}`` (see ``load_requirements``).
        vulnerabilities: ``{name: SpecifierSet}`` or
            ``{name: [SpecifierSet, ...]}``; a version matching ANY range is
            reported as vulnerable. Names are matched after PEP 503
            normalization, so ``Django`` matches ``django``.
        output_format: ``txt``, ``md`` or ``html``; also the file extension
            of the default report path.
        output_path: report file; defaults to
            ``../results/requirements/results.<output_format>`` (relative to
            the current working directory).

    Returns:
        The list of result lines that was written.
    """
    index = defaultdict(list)
    for package_name, specs in vulnerabilities.items():
        index[_normalize_name(package_name)].extend(_as_spec_list(specs))

    results = []
    for req_name, req_version in requirements.items():
        specs = index.get(_normalize_name(req_name))
        if not specs:
            results.append(f"OK: {req_name} not found in the vulnerability database.")
            continue
        try:
            parsed = version.Version(req_version)
        except version.InvalidVersion:
            results.append(
                f"WARNING: {req_name}=={req_version} has an unparseable version "
                "and could not be checked."
            )
            continue
        # prereleases=True: by default SpecifierSet never matches a
        # pre-release, so a pinned "2.0rc1" would be reported as safe.
        if any(spec.contains(parsed, prereleases=True) for spec in specs):
            results.append(f"WARNING: {req_name}=={req_version} is vulnerable!")
        else:
            results.append(f"OK: {req_name}=={req_version} is not affected.")

    if output_path is None:
        output_path = DEFAULT_OUTPUT_TEMPLATE.format(output_format)
    output_results(output_path, results, output_format)
    return results


def main(argv=None):
    """Command-line entry point; *argv* defaults to ``sys.argv[1:]``."""
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        prog = os.path.basename(sys.argv[0]) if sys.argv and sys.argv[0] else "requirements_detection.py"
        print(
            f"Usage: python {prog} <vulnerabilities_file> <requirements_file> "
            f"[{'|'.join(SUPPORTED_FORMATS)}]",
            file=sys.stderr,
        )
        sys.exit(1)

    vulnerabilities_file, requirements_file = args[0], args[1]
    output_format = args[2].lower() if len(args) > 2 else "txt"
    if output_format not in SUPPORTED_FORMATS:
        print(
            f"Error: unsupported output format {output_format!r}; "
            f"choose one of {', '.join(SUPPORTED_FORMATS)}.",
            file=sys.stderr,
        )
        sys.exit(1)

    try:
        vulnerabilities = load_vulnerability_ranges(vulnerabilities_file)
        requirements = load_requirements(requirements_file)
    except OSError as err:
        print(f"Error: {err}", file=sys.stderr)
        sys.exit(1)
    check_vulnerabilities(requirements, vulnerabilities, output_format)


if __name__ == "__main__":
    main()