From bc852ec52c803d1ce18455035b1f81f169b31830 Mon Sep 17 00:00:00 2001
From: dqy <1016751306@qq.com>
Date: Thu, 25 Apr 2024 20:10:39 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BE=9D=E8=B5=96=E7=89=88=E6=9C=AC?=
 =?UTF-8?q?=E6=A3=80=E6=B5=8B=E6=B7=BB=E5=8A=A0pdf=E8=BE=93=E5=87=BA?=
 =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=B9=B6=E4=BF=AE=E6=94=B9=E5=91=BD=E4=BB=A4?=
 =?UTF-8?q?=E8=A1=8C=E5=8F=82=E6=95=B0=E6=A0=BC=E5=BC=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 detection/requirements_detection.py | 114 +++++++++++++++++++++-------
 1 file changed, 88 insertions(+), 26 deletions(-)

diff --git a/detection/requirements_detection.py b/detection/requirements_detection.py
index 1350aa0..5d32782 100644
--- a/detection/requirements_detection.py
+++ b/detection/requirements_detection.py
@@ -1,10 +1,12 @@
-# Usage: python requirements_detection.py ../crawler/trans_extracted_data.txt ../requirements.txt <output_format>
-
-import sys
+import argparse
 import os
+import re
+import sys
 from packaging import version
 from packaging.specifiers import SpecifierSet
-import re
+from reportlab.lib.pagesizes import letter
+from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
+from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
 
 
 def load_vulnerable_packages(filename):
@@ -38,6 +40,42 @@ def load_requirements(filename):
     return requirements
 
 
+def output_pdf(results, file_name):
+    doc = SimpleDocTemplate(file_name, pagesize=letter)
+    story = []
+    styles = getSampleStyleSheet()
+
+    # Custom styles
+    title_style = styles["Title"]
+    title_style.alignment = 1  # Center alignment
+
+    warning_style = ParagraphStyle(
+        "WarningStyle", parent=styles["BodyText"], fontName="Helvetica-Bold"
+    )
+    normal_style = styles["BodyText"]
+
+    # Add the title
+    title = Paragraph("Vulnerability Report", title_style)
+    story.append(title)
+    story.append(Spacer(1, 20))  # Space after title
+
+    # Iterate through results to add entries
+    for result in results:
+        if "WARNING:" in result:
+            # Add warning text in bold
+            entry = Paragraph(
+                result.replace("WARNING:", "<b>WARNING:</b>"), warning_style
+            )
+        else:
+            # Add normal text
+            entry = Paragraph(result, normal_style)
+
+        story.append(entry)
+        story.append(Spacer(1, 12))  # Space between entries
+
+    doc.build(story)
+
+
 def output_results(filename, results, format_type):
     """根据指定的格式输出结果"""
     output_dir = os.path.dirname(filename)
@@ -55,43 +93,67 @@ def output_results(filename, results, format_type):
             file.write("# Vulnerability Report\n")
             for result in results:
                 file.write(f"* {result}\n")
-        else:  # default to txt
+        elif format_type == "pdf":
+            output_pdf(results, filename)
+        else:  # 默认为txt
             for result in results:
                 file.write(f"{result}\n")
 
 
-def check_vulnerabilities(requirements, vulnerabilities, output_format):
+def check_vulnerabilities(requirements, vulnerabilities, output_file):
     """检查依赖项是否存在已知漏洞,并输出结果"""
-    results = []
+    results_warning = []  # 存储有漏洞的依赖
+    results_ok = []  # 存储没有漏洞的依赖
+
     for req_name, req_version in requirements.items():
         if req_name in vulnerabilities:
             spec = vulnerabilities[req_name]
             if version.parse(req_version) in spec:
-                results.append(f"WARNING: {req_name}=={req_version} is vulnerable!")
+                results_warning.append(
+                    f"WARNING: {req_name}=={req_version} is vulnerable!"
+                )
             else:
-                results.append(f"OK: {req_name}=={req_version} is not affected.")
+                results_ok.append(f"OK: {req_name}=={req_version} is not affected.")
         else:
-            results.append(f"OK: {req_name} not found in the vulnerability database.")
-    # 集成测试这里应该修改为./
-    output_results(
-        "./results/requirements/results." + output_format, results, output_format
-    )
+            results_ok.append(
+                f"OK: {req_name} not found in the vulnerability database."
+            )
+
+    # 合并结果,先输出所有警告,然后输出所有正常情况
+    results = results_warning + results_ok
+
+    if output_file:
+        filename, ext = os.path.splitext(output_file)
+        output_format = ext[1:] if ext[1:] else "txt"
+        if output_format not in ["txt", "md", "html", "pdf"]:
+            print("Warning: Invalid file format specified. Defaulting to TXT format.")
+            output_format = "txt"  # 确保使用默认格式
+            output_file = filename + ".txt"
+        output_results(output_file, results, output_format)
+    else:
+        print("\n".join(results))
 
 
 def main():
-    if len(sys.argv) < 4:
-        print(
-            "Usage: python script.py <vulnerabilities_file> <requirements_file> <output_format>"
-        )
-        sys.exit(1)
+    parser = argparse.ArgumentParser(
+        description="Check project dependencies for vulnerabilities."
+    )
+    parser.add_argument(
+        "vulnerabilities_file", help="Path to the file containing vulnerability data"
+    )
+    parser.add_argument(
+        "requirements_file", help="Path to the requirements file of the project"
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        help="Output file path with extension, e.g., './output/report.txt'",
+    )
+    args = parser.parse_args()
 
-    vulnerabilities_file = sys.argv[1]
-    requirements_file = sys.argv[2]
-    output_format = sys.argv[3]
-
-    vulnerabilities = load_vulnerable_packages(vulnerabilities_file)
-    requirements = load_requirements(requirements_file)
-    check_vulnerabilities(requirements, vulnerabilities, output_format)
+    vulnerabilities = load_vulnerable_packages(args.vulnerabilities_file)
+    requirements = load_requirements(args.requirements_file)
+    check_vulnerabilities(requirements, vulnerabilities, args.output)
 
 
 if __name__ == "__main__":