import argparse
import os
import re
from typing import List, Optional, Tuple

import requests
from bs4 import BeautifulSoup
from packaging import version
from packaging.specifiers import SpecifierSet
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer


def fetch_html(url: str) -> Optional[str]:
    """Fetch HTML content from the specified URL.

    Args:
        url (str): URL to fetch HTML from.

    Returns:
        Optional[str]: HTML content as a string, or None if the fetch fails.
    """
    # Bounded timeout and a broad exception guard so one bad page cannot
    # hang or crash the crawl.
    try:
        response = requests.get(url, timeout=30)
    except requests.RequestException:
        return None
    if response.status_code == 200:
        return response.text
    return None


def parse_html(html: str) -> List[Tuple[str, List[str]]]:
    """Parse HTML to get content of all 'a' and 'span' tags under the second 'td' of each 'tr'.

    Args:
        html (str): HTML content as a string.

    Returns:
        List[Tuple[str, List[str]]]: A list of tuples containing the text of 'a' tags and lists of 'span' texts.
    """
    soup = BeautifulSoup(html, "html.parser")
    table = soup.find("table", id="sortable-table")
    results = []
    if table:
        rows = table.find("tbody").find_all("tr")
        for row in rows:
            tds = row.find_all("td")
            if len(tds) >= 2:
                a_tags = tds[1].find_all("a")
                span_tags = tds[1].find_all("span")
                spans = [span.text.strip() for span in span_tags]
                for a_tag in a_tags:
                    results.append((a_tag.text.strip(), spans))
    return results


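# Illustrative sketch of the markup parse_html() expects (not the exact Snyk
# page, which may change): inside <table id="sortable-table">, a row such as
#   <tr><td>...</td><td><a>flask</a> <span>[0.12,1.0)</span></td></tr>
# is parsed into [("flask", ["[0.12,1.0)"])].
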
def format_results(results: List[Tuple[str, List[str]]]) -> str:
    """Format extracted data as a string.

    Args:
        results (List[Tuple[str, List[str]]]): Extracted data to format.

    Returns:
        str: Formatted string of the extracted data.
    """
    formatted_result = ""
    for package_name, version_ranges in results:
        formatted_result += f"Package Name: {package_name}\n"
        formatted_result += "Version Ranges: " + ", ".join(version_ranges) + "\n"
        formatted_result += "-" * 50 + "\n"
    return formatted_result


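# Example (illustrative values): format_results([("flask", ["[0.12,1.0)"])]) yields
#   Package Name: flask
#   Version Ranges: [0.12,1.0)
# followed by a 50-dash separator line.
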
def trans_vulnerable_packages(content):
    """Convert the interval notation in vulnerable version ranges into >=/< specifiers.

    Args:
        content (str): Aggregated vulnerable-version information.

    Returns:
        dict: Mapping of package name to a SpecifierSet of vulnerable versions.
    """
    vulnerabilities = {}
    blocks = content.split("-" * 50)
    range_pattern = re.compile(r"\[(.*?),\s*(.*?)\)")

    for block in blocks:
        name_match = re.search(r"Package Name: (.+)", block)
        if name_match:
            package_name = name_match.group(1).strip()
            ranges = range_pattern.findall(block)
            specifier_list = []
            for start, end in ranges:
                if start and end:
                    specifier_list.append(f">={start},<{end}")
                elif start:
                    specifier_list.append(f">={start}")
                elif end:
                    specifier_list.append(f"<{end}")
            if specifier_list:
                vulnerabilities[package_name] = SpecifierSet(",".join(specifier_list))
    return vulnerabilities


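# Worked example (illustrative values): a block such as
#   Package Name: flask
#   Version Ranges: [0.12, 1.0)
# yields {"flask": SpecifierSet(">=0.12,<1.0")}. Note that packaging joins
# comma-separated clauses with a logical AND, so several ranges for one
# package narrow the set rather than union it.
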
def format_vulnerabilities(vuln_packages):
    """Format the vulnerability information held in a dict.

    Args:
        vuln_packages (dict): Mapping of package name to a SpecifierSet of vulnerable versions.

    Returns:
        str: Formatted string of the vulnerability information.
    """
    res = ""
    for package, specifiers in vuln_packages.items():
        res += f"Package Name: {package}\n"
        res += f"Version Ranges: {specifiers}\n"
        res += "-" * 50 + "\n"
    return res


def load_requirements(filename):
    """Load the project's dependency pins from a requirements file."""
    with open(filename, "r", encoding="utf-8") as file:
        lines = file.readlines()
    requirements = {}
    for line in lines:
        if "==" in line:
            package_name, package_version = line.strip().split("==")
            requirements[package_name] = package_version
    return requirements


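# Example (illustrative values): a requirements file containing
#   requests==2.25.0
#   flask==0.12.2
# loads as {"requests": "2.25.0", "flask": "0.12.2"}; lines without an
# exact "==" pin are skipped.
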
def check_vulnerabilities(requirements, vulnerabilities, output_file):
    """Check the dependencies against known vulnerabilities and output the results."""
    results_warning = []  # dependencies with known vulnerabilities
    results_ok = []  # dependencies without known vulnerabilities

    for req_name, req_version in requirements.items():
        if req_name in vulnerabilities:
            spec = vulnerabilities[req_name]
            if version.parse(req_version) in spec:
                results_warning.append(
                    f"WARNING: {req_name}=={req_version} is vulnerable!"
                )
            else:
                results_ok.append(f"OK: {req_name}=={req_version} is not affected.")
        else:
            results_ok.append(
                f"OK: {req_name} not found in the vulnerability database."
            )

    # Merge the results: all warnings first, then the unaffected entries.
    results = results_warning + results_ok
    if output_file:
        filename, ext = os.path.splitext(output_file)
        output_format = ext[1:] if ext[1:] else "txt"
        if output_format not in ["txt", "md", "html", "pdf"]:
            print("Warning: Invalid file format specified. Defaulting to TXT format.")
            output_format = "txt"  # fall back to the default format
            output_file = filename + ".txt"
        output_results(output_file, results, output_format)
    else:
        print("\n".join(results))


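# The membership test above relies on packaging's specifier semantics, e.g.:
#   version.parse("2.3.0") in SpecifierSet(">=2.0,<2.5")   -> True
#   version.parse("2.5.0") in SpecifierSet(">=2.0,<2.5")   -> False
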
def trans_vulnerable_packages_to_dict(content):
    """Convert the formatted vulnerability summary back into a dict.

    Args:
        content (str): Aggregated vulnerability information.

    Returns:
        dict: Mapping of package name to a SpecifierSet of vulnerable versions.
    """
    vulnerabilities = {}
    blocks = content.split("-" * 50)
    for block in blocks:
        name_match = re.search(r"Package Name: (.+)", block)
        range_match = re.search(r"Version Ranges: (.+)", block)
        if name_match and range_match:
            package_name = name_match.group(1).strip()
            version_range = range_match.group(1).strip()
            version_range = ",".join(
                [part.strip() for part in version_range.split(",")]
            )
            vulnerabilities[package_name] = SpecifierSet(version_range)
    return vulnerabilities


def output_pdf(results, file_name):
    """Render the results as a simple PDF report using ReportLab."""
    doc = SimpleDocTemplate(file_name, pagesize=letter)
    story = []
    styles = getSampleStyleSheet()

    # Custom styles
    title_style = styles["Title"]
    title_style.alignment = 1  # Center alignment

    warning_style = ParagraphStyle(
        "WarningStyle", parent=styles["BodyText"], fontName="Helvetica-Bold"
    )
    normal_style = styles["BodyText"]

    # Add the title
    title = Paragraph("Vulnerability Report", title_style)
    story.append(title)
    story.append(Spacer(1, 20))  # Space after title

    # Iterate through results to add entries
    for result in results:
        if "WARNING:" in result:
            # Add warning text in bold
            entry = Paragraph(
                result.replace("WARNING:", "<b>WARNING:</b>"), warning_style
            )
        else:
            # Add normal text
            entry = Paragraph(result, normal_style)

        story.append(entry)
        story.append(Spacer(1, 12))  # Space between entries

    doc.build(story)


def output_results(filename, results, format_type):
    """Write the results to a file in the specified format."""
    output_dir = os.path.dirname(filename)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if format_type == "pdf":
        # ReportLab opens and writes the file itself, so no text handle is needed.
        output_pdf(results, filename)
    else:
        with open(filename, "w", encoding="utf-8") as file:
            if format_type == "html":
                file.write("<html><head><title>Vulnerability Report</title></head><body>\n")
                file.write("<h1>Vulnerability Report</h1>\n")
                for result in results:
                    file.write(f"<p>{result}</p>\n")
                file.write("</body></html>")
            elif format_type == "md":
                file.write("# Vulnerability Report\n")
                for result in results:
                    file.write(f"* {result}\n")
            else:  # default to txt
                for result in results:
                    file.write(f"{result}\n")

    print("Results have been saved as " + filename)


def main():
    parser = argparse.ArgumentParser(
        description="Check project dependencies for vulnerabilities."
    )
    parser.add_argument(
        "requirements_file", help="Path to the requirements file of the project"
    )
    parser.add_argument(
        "-o",
        "--output",
        help="Output file path with extension, e.g., './output/report.txt'",
    )
    args = parser.parse_args()

    base_url = "https://security.snyk.io/vuln/pip/"
    page_number = 1
    crawler_results = ""
    while True:
        url = f"{base_url}{page_number}"
        print(f"Fetching data from {url}")
        html_content = fetch_html(url)
        if not html_content:
            print("No more data found or failed to fetch.")
            break
        extracted_data = parse_html(html_content)
        if not extracted_data:
            print("No relevant data found on page.")
            break
        crawler_results += format_results(extracted_data)
        page_number += 1
    print("Results have been stored in memory.\n")

    trans_res = trans_vulnerable_packages(crawler_results)
    trans_res = format_vulnerabilities(trans_res)
    trans_res = trans_vulnerable_packages_to_dict(trans_res)
    requirements = load_requirements(args.requirements_file)
    check_vulnerabilities(requirements, trans_res, args.output)


if __name__ == "__main__":
    main()
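
# Typical invocation (the script filename here is illustrative):
#   python vuln_checker.py requirements.txt -o ./output/report.md
# Without -o, the report is printed to stdout.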