diff --git a/detection/requirements_detection.py b/detection/requirements_detection.py index 5a1c78f..8f2cdea 100644 --- a/detection/requirements_detection.py +++ b/detection/requirements_detection.py @@ -3,6 +3,15 @@ import requests from bs4 import BeautifulSoup from packaging.version import Version, InvalidVersion import sys +from reportlab.lib.pagesizes import letter +from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle +from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer +from colorama import Fore, Style, init +from tqdm import tqdm +import html + + +init(autoreset=True) # 初始化colorama,并在每次打印后自动重置颜色 def fetch_html(url: str) -> str: @@ -55,7 +64,6 @@ def version_in_range(version, range_str: str) -> bool: except InvalidVersion: return False else: - # 如果没有给版本号,默认使用最新版本 if range_str[-2] == ",": return True @@ -77,37 +85,155 @@ def version_in_range(version, range_str: str) -> bool: return True -def check_vulnerabilities(requirements: list, base_url: str, output_file: str): - with open(output_file, "w") as out_file: - for req in requirements: - version = "" - # 如果有版本 - if "==" in req: - package_name, version = req.split("==") - # 没有版本 - else: - package_name, version = req, None - # 拼接URL - url = f"{base_url}{package_name}" - print(f"Fetching data for {package_name} from {url}") - html_content = fetch_html(url) - if html_content: - # 解析hmtl - extracted_data = parse_html(html_content) - if extracted_data: - relevant_vulns = [] - for vuln in extracted_data: - if version_in_range(version, vuln["chip"]): - relevant_vulns.append(vuln) - if relevant_vulns: - out_file.write(f"Vulnerabilities found for {package_name}:\n") - for vuln in relevant_vulns: - out_file.write(f" - {vuln['link']}\n") - out_file.write("\n") - else: - print(f"No relevant data found for {package_name}.") - else: - print(f"Failed to fetch data for {package_name}.") +def check_vulnerabilities(requirements: list, base_url: str) -> str: + results = [] + for req in tqdm(requirements, desc="Checking vulnerabilities", unit="dependency"): + version = "" + if "==" in req: + package_name, version = req.split("==") + else: + package_name, version = req, None + url = f"{base_url}{package_name}" + # print(f"Fetching data for {package_name} from {url}") + html_content = fetch_html(url) + if html_content: + extracted_data = parse_html(html_content) + if extracted_data: + relevant_vulns = [] + for vuln in extracted_data: + if version_in_range(version, vuln["chip"]): + relevant_vulns.append(vuln) + if relevant_vulns: + result = f"Vulnerabilities found for {package_name}:\n" + for vuln in relevant_vulns: + result += f" - {vuln['link']}\n" + results.append(result) + return "\n".join(results) + + +def save_to_file(output_path: str, data: str): + if output_path.endswith(".html"): + save_as_html(output_path, data) + elif output_path.endswith(".pdf"): + save_as_pdf(output_path, data) + elif output_path.endswith(".md"): + save_as_markdown(output_path, data) + else: + save_as_txt(output_path, data) + + +def save_as_html(output_path: str, data: str): + escaped_data = html.escape(data) + html_content = f""" + + + + + + Vulnerability Report + + + +
+
Vulnerability Report
+
{escaped_data}
+
+ + + """ + with open(output_path, "w", encoding="utf-8") as file: + file.write(html_content) + + +def save_as_pdf(output_path: str, data: str): + doc = SimpleDocTemplate(output_path, pagesize=letter) + story = [] + styles = getSampleStyleSheet() + + # Add the title centered + title_style = ParagraphStyle( + "Title", + parent=styles["Title"], + alignment=1, # Center alignment + fontSize=24, + leading=28, + spaceAfter=20, + fontName="Helvetica-Bold", + ) + title = Paragraph("Vulnerability Report", title_style) + story.append(title) + + # Normal body text style + normal_style = ParagraphStyle( + "BodyText", parent=styles["BodyText"], fontSize=12, leading=15, spaceAfter=12 + ) + + # Add the vulnerability details + for line in data.split("\n"): + if line.strip(): # Skip empty lines + story.append(Paragraph(line, normal_style)) + + doc.build(story) + + +def save_as_markdown(output_path: str, data: str): + with open(output_path, "w") as file: + file.write("## Vulnerability Report: \n\n") + file.write(data) + + +def save_as_txt(output_path: str, data: str): + with open(output_path, "w") as file: + file.write("Vulnerability Report: \n\n") + file.write(data) + + +def print_separator(title, char="-", length=50, padding=2): + print(f"{title:^{length + 4*padding}}") # 居中打印标题,两侧各有padding个空格 + print(char * (length + 2 * padding)) # 打印分割线,两侧各有padding个字符的空格 def main(): @@ -124,16 +250,19 @@ def main(): "-o", "--output", help="Output file path with extension, e.g., './output/report.txt'", - required=True, ) args = parser.parse_args() base_url = "https://security.snyk.io/package/pip/" - # 分析项目依赖,包括名称和版本(如果有的话) requirements = load_requirements(args.requirement) - # 传入依赖信息,url前缀,扫描结果输出位置 - check_vulnerabilities(requirements, base_url, args.output) - print("Vulnerability scan complete. Results saved to", args.output) + results = check_vulnerabilities(requirements, base_url) + + if args.output: + save_to_file(args.output, results) + print(f"Vulnerability scan complete. Results saved to {args.output}") + else: + print_separator("\n\nVulnerability Report", "=", 40, 5) + print(results) if __name__ == "__main__":