diff --git a/detection/requirements_detection.py b/detection/requirements_detection.py index 1350aa0..5d32782 100644 --- a/detection/requirements_detection.py +++ b/detection/requirements_detection.py @@ -1,10 +1,12 @@ -# Usage: python requirements_detection.py ../crawler/trans_extracted_data.txt ../requirements.txt - -import sys +import argparse import os +import re +import sys from packaging import version from packaging.specifiers import SpecifierSet -import re +from reportlab.lib.pagesizes import letter +from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer +from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle def load_vulnerable_packages(filename): @@ -38,6 +40,42 @@ def load_requirements(filename): return requirements +def output_pdf(results, file_name): + doc = SimpleDocTemplate(file_name, pagesize=letter) + story = [] + styles = getSampleStyleSheet() + + # Custom styles + title_style = styles["Title"] + title_style.alignment = 1 # Center alignment + + warning_style = ParagraphStyle( + "WarningStyle", parent=styles["BodyText"], fontName="Helvetica-Bold" + ) + normal_style = styles["BodyText"] + + # Add the title + title = Paragraph("Vulnerability Report", title_style) + story.append(title) + story.append(Spacer(1, 20)) # Space after title + + # Iterate through results to add entries + for result in results: + if "WARNING:" in result: + # Add warning text in bold + entry = Paragraph( + result.replace("WARNING:", "WARNING:"), warning_style + ) + else: + # Add normal text + entry = Paragraph(result, normal_style) + + story.append(entry) + story.append(Spacer(1, 12)) # Space between entries + + doc.build(story) + + def output_results(filename, results, format_type): """根据指定的格式输出结果""" output_dir = os.path.dirname(filename) @@ -55,43 +93,67 @@ def output_results(filename, results, format_type): file.write("# Vulnerability Report\n") for result in results: file.write(f"* {result}\n") - else: # default to txt + elif format_type == "pdf": + output_pdf(results, filename) + else: # 默认为txt for result in results: file.write(f"{result}\n") -def check_vulnerabilities(requirements, vulnerabilities, output_format): +def check_vulnerabilities(requirements, vulnerabilities, output_file): """检查依赖项是否存在已知漏洞,并输出结果""" - results = [] + results_warning = [] # 存储有漏洞的依赖 + results_ok = [] # 存储没有漏洞的依赖 + for req_name, req_version in requirements.items(): if req_name in vulnerabilities: spec = vulnerabilities[req_name] if version.parse(req_version) in spec: - results.append(f"WARNING: {req_name}=={req_version} is vulnerable!") + results_warning.append( + f"WARNING: {req_name}=={req_version} is vulnerable!" + ) else: - results.append(f"OK: {req_name}=={req_version} is not affected.") + results_ok.append(f"OK: {req_name}=={req_version} is not affected.") else: - results.append(f"OK: {req_name} not found in the vulnerability database.") - # 集成测试这里应该修改为./ - output_results( - "./results/requirements/results." + output_format, results, output_format - ) + results_ok.append( + f"OK: {req_name} not found in the vulnerability database." + ) + + # 合并结果,先输出所有警告,然后输出所有正常情况 + results = results_warning + results_ok + + if output_file: + filename, ext = os.path.splitext(output_file) + output_format = ext[1:] if ext[1:] else "txt" + if output_format not in ["txt", "md", "html", "pdf"]: + print("Warning: Invalid file format specified. Defaulting to TXT format.") + output_format = "txt" # 确保使用默认格式 + output_file = filename + ".txt" + output_results(output_file, results, output_format) + else: + print("\n".join(results)) def main(): - if len(sys.argv) < 4: - print( - "Usage: python script.py " - ) - sys.exit(1) + parser = argparse.ArgumentParser( + description="Check project dependencies for vulnerabilities." + ) + parser.add_argument( + "vulnerabilities_file", help="Path to the file containing vulnerability data" + ) + parser.add_argument( + "requirements_file", help="Path to the requirements file of the project" + ) + parser.add_argument( + "-o", + "--output", + help="Output file path with extension, e.g., './output/report.txt'", + ) + args = parser.parse_args() - vulnerabilities_file = sys.argv[1] - requirements_file = sys.argv[2] - output_format = sys.argv[3] - - vulnerabilities = load_vulnerable_packages(vulnerabilities_file) - requirements = load_requirements(requirements_file) - check_vulnerabilities(requirements, vulnerabilities, output_format) + vulnerabilities = load_vulnerable_packages(args.vulnerabilities_file) + requirements = load_requirements(args.requirements_file) + check_vulnerabilities(requirements, vulnerabilities, args.output) if __name__ == "__main__":