feat: 修改依赖检测功能
This commit is contained in:
		| @@ -1,239 +1,113 @@ | ||||
| import re | ||||
| import os | ||||
| import requests | ||||
| import argparse | ||||
| import requests | ||||
| from bs4 import BeautifulSoup | ||||
| from typing import List, Tuple, Optional | ||||
| from packaging import version | ||||
| from packaging.specifiers import SpecifierSet | ||||
| from reportlab.lib.pagesizes import letter | ||||
| from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer | ||||
| from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle | ||||
| from packaging.version import Version, InvalidVersion | ||||
| import sys | ||||
|  | ||||
|  | ||||
| def fetch_html(url: str) -> Optional[str]: | ||||
|     """Fetch HTML content from the specified URL. | ||||
|  | ||||
|     Args: | ||||
|         url (str): URL to fetch HTML from. | ||||
|  | ||||
|     Returns: | ||||
|         Optional[str]: HTML content as a string, or None if fetch fails. | ||||
|     """ | ||||
|     response = requests.get(url) | ||||
|     if response.status_code == 200: | ||||
| def fetch_html(url: str) -> str: | ||||
|     try: | ||||
|         response = requests.get(url) | ||||
|         response.raise_for_status() | ||||
|         return response.text | ||||
|     return None | ||||
|     except requests.RequestException as e: | ||||
|         print(f"Error fetching {url}: {e}") | ||||
|         return "" | ||||
|  | ||||
|  | ||||
| def parse_html(html: str) -> List[Tuple[str, List[str]]]: | ||||
|     """Parse HTML to get content of all 'a' and 'span' tags under the second 'td' of each 'tr'. | ||||
|  | ||||
|     Args: | ||||
|         html (str): HTML content as a string. | ||||
|  | ||||
|     Returns: | ||||
|         List[Tuple[str, List[str]]]: A list of tuples containing the text of 'a' tags and lists of 'span' texts. | ||||
|     """ | ||||
| def parse_html(html: str) -> list: | ||||
|     soup = BeautifulSoup(html, "html.parser") | ||||
|     table = soup.find("table", id="sortable-table") | ||||
|     if not table: | ||||
|         return [] | ||||
|  | ||||
|     rows = table.find_all("tr", class_="vue--table__row") | ||||
|     results = [] | ||||
|     if table: | ||||
|         rows = table.find("tbody").find_all("tr") | ||||
|         for row in rows: | ||||
|             tds = row.find_all("td") | ||||
|             if len(tds) >= 2: | ||||
|                 a_tags = tds[1].find_all("a") | ||||
|                 span_tags = tds[1].find_all("span") | ||||
|                 spans = [span.text.strip() for span in span_tags] | ||||
|                 for a_tag in a_tags: | ||||
|                     results.append((a_tag.text.strip(), spans)) | ||||
|     for row in rows: | ||||
|         info = {} | ||||
|         link = row.find("a") | ||||
|         chip = row.find("span", class_="vue--chip__value") | ||||
|         if link and chip: | ||||
|             info["link"] = link.get_text(strip=True) | ||||
|             info["chip"] = chip.get_text(strip=True) | ||||
|             results.append(info) | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def format_results(results: List[Tuple[str, List[str]]]) -> str: | ||||
|     """Format extracted data as a string. | ||||
|  | ||||
|     Args: | ||||
|         results (List[Tuple[str, List[str]]]): Extracted data to format. | ||||
|  | ||||
|     Returns: | ||||
|         str: Formatted string of the extracted data. | ||||
|     """ | ||||
|     formatted_result = "" | ||||
|     for package_name, version_ranges in results: | ||||
|         formatted_result += f"Package Name: {package_name}\n" | ||||
|         formatted_result += "Version Ranges: " + ", ".join(version_ranges) + "\n" | ||||
|         formatted_result += "-" * 50 + "\n" | ||||
|     return formatted_result | ||||
|  | ||||
|  | ||||
| def trans_vulnerable_packages(content): | ||||
|     """将漏洞版本中的集合形式转换为大于小于的格式 | ||||
|     Args: | ||||
|         content (str): 漏洞版本汇总信息. | ||||
|     """ | ||||
|     vulnerabilities = {} | ||||
|     blocks = content.split("--------------------------------------------------") | ||||
|     range_pattern = re.compile(r"\[(.*?),\s*(.*?)\)") | ||||
|  | ||||
|     for block in blocks: | ||||
|         name_match = re.search(r"Package Name: (.+)", block) | ||||
|         if name_match: | ||||
|             package_name = name_match.group(1).strip() | ||||
|             ranges = range_pattern.findall(block) | ||||
|             specifier_list = [] | ||||
|             for start, end in ranges: | ||||
|                 if start and end: | ||||
|                     specifier_list.append(f">={start},<{end}") | ||||
|                 elif start: | ||||
|                     specifier_list.append(f">={start}") | ||||
|                 elif end: | ||||
|                     specifier_list.append(f"<{end}") | ||||
|             if specifier_list: | ||||
|                 vulnerabilities[package_name] = SpecifierSet(",".join(specifier_list)) | ||||
|     return vulnerabilities | ||||
|  | ||||
|  | ||||
| def format_vulnerabilities(vuln_packages): | ||||
|     """将字典形式的漏洞信息格式化 | ||||
|     Args: | ||||
|         vuln_packages (List[Tuple[str, List[str]]]): Extracted data to format. | ||||
|     """ | ||||
|     res = "" | ||||
|     for package, specifiers in vuln_packages.items(): | ||||
|         res += f"Package Name: {package}\n" | ||||
|         res += f"Version Ranges: {specifiers}\n" | ||||
|         res += "-" * 50 + "\n" | ||||
|     return res | ||||
|  | ||||
|  | ||||
| def load_requirements(filename): | ||||
|     """从文件加载项目的依赖信息""" | ||||
|     with open(filename, "r", encoding="utf-8") as file: | ||||
|         lines = file.readlines() | ||||
|     requirements = {} | ||||
|     for line in lines: | ||||
|         if "==" in line: | ||||
|             package_name, package_version = line.strip().split("==") | ||||
|             requirements[package_name] = package_version | ||||
| def load_requirements(file_path: str) -> list: | ||||
|     requirements = [] | ||||
|     try: | ||||
|         with open(file_path, "r") as file: | ||||
|             for line in file: | ||||
|                 line = line.strip() | ||||
|                 if line and not line.startswith("#"): | ||||
|                     requirements.append(line) | ||||
|     except FileNotFoundError: | ||||
|         print(f"Error: File {file_path} not found.") | ||||
|         sys.exit(1) | ||||
|     return requirements | ||||
|  | ||||
|  | ||||
| def check_vulnerabilities(requirements, vulnerabilities, output_file): | ||||
|     """检查依赖项是否存在已知漏洞,并输出结果""" | ||||
|     results_warning = []  # 存储有漏洞的依赖 | ||||
|     results_ok = []  # 存储没有漏洞的依赖 | ||||
|  | ||||
|     for req_name, req_version in requirements.items(): | ||||
|         if req_name in vulnerabilities: | ||||
|             spec = vulnerabilities[req_name] | ||||
|             if version.parse(req_version) in spec: | ||||
|                 results_warning.append( | ||||
|                     f"WARNING: {req_name}=={req_version} is vulnerable!" | ||||
|                 ) | ||||
|             else: | ||||
|                 results_ok.append(f"OK: {req_name}=={req_version} is not affected.") | ||||
|         else: | ||||
|             results_ok.append( | ||||
|                 f"OK: {req_name} not found in the vulnerability database." | ||||
|             ) | ||||
|  | ||||
|     # 合并结果,先输出所有警告,然后输出所有正常情况 | ||||
|     results = results_warning + results_ok | ||||
|     # print(results) | ||||
|     if output_file: | ||||
|         filename, ext = os.path.splitext(output_file) | ||||
|         output_format = ext[1:] if ext[1:] else "txt" | ||||
|         if output_format not in ["txt", "md", "html", "pdf"]: | ||||
|             print("Warning: Invalid file format specified. Defaulting to TXT format.") | ||||
|             output_format = "txt"  # 确保使用默认格式 | ||||
|             output_file = filename + ".txt" | ||||
|         output_results(output_file, results, output_format) | ||||
| def version_in_range(version, range_str: str) -> bool: | ||||
|     if version is not None: | ||||
|         try: | ||||
|             v = Version(version) | ||||
|         except InvalidVersion: | ||||
|             return False | ||||
|     else: | ||||
|         print("\n".join(results)) | ||||
|         # 如果没有给版本号,默认使用最新版本 | ||||
|         if range_str[-2] == ",": | ||||
|             return True | ||||
|  | ||||
|     ranges = range_str.split(",") | ||||
|     for range_part in ranges: | ||||
|         range_part = range_part.strip("[]()") | ||||
|         if range_part: | ||||
|             try: | ||||
|                 if range_part.endswith(")"): | ||||
|                     upper = Version(range_part[:-1]) | ||||
|                     if v >= upper: | ||||
|                         return False | ||||
|                 elif range_part.startswith("["): | ||||
|                     lower = Version(range_part[1:]) | ||||
|                     if v < lower: | ||||
|                         return False | ||||
|             except InvalidVersion: | ||||
|                 return False | ||||
|     return True | ||||
|  | ||||
|  | ||||
| def trans_vulnerable_packages_to_dict(content): | ||||
|     """将漏洞信息转换为字典格式 | ||||
|     Args: | ||||
|         content str: 漏洞信息汇总. | ||||
|     """ | ||||
|     vulnerabilities = {} | ||||
|     blocks = content.split("--------------------------------------------------") | ||||
|     for block in blocks: | ||||
|         name_match = re.search(r"Package Name: (.+)", block) | ||||
|         range_match = re.search(r"Version Ranges: (.+)", block) | ||||
|         if name_match and range_match: | ||||
|             package_name = name_match.group(1).strip() | ||||
|             version_range = range_match.group(1).strip() | ||||
|             version_range = ",".join( | ||||
|                 [part.strip() for part in version_range.split(",")] | ||||
|             ) | ||||
|             vulnerabilities[package_name] = SpecifierSet(version_range) | ||||
|     return vulnerabilities | ||||
|  | ||||
|  | ||||
| def output_pdf(results, file_name): | ||||
|     doc = SimpleDocTemplate(file_name, pagesize=letter) | ||||
|     story = [] | ||||
|     styles = getSampleStyleSheet() | ||||
|  | ||||
|     # Custom styles | ||||
|     title_style = styles["Title"] | ||||
|     title_style.alignment = 1  # Center alignment | ||||
|  | ||||
|     warning_style = ParagraphStyle( | ||||
|         "WarningStyle", parent=styles["BodyText"], fontName="Helvetica-Bold" | ||||
|     ) | ||||
|     normal_style = styles["BodyText"] | ||||
|  | ||||
|     # Add the title | ||||
|     title = Paragraph("Vulnerability Report", title_style) | ||||
|     story.append(title) | ||||
|     story.append(Spacer(1, 20))  # Space after title | ||||
|  | ||||
|     # Iterate through results to add entries | ||||
|     for result in results: | ||||
|         if "WARNING:" in result: | ||||
|             # Add warning text in bold | ||||
|             entry = Paragraph( | ||||
|                 result.replace("WARNING:", "<b>WARNING:</b>"), warning_style | ||||
|             ) | ||||
|         else: | ||||
|             # Add normal text | ||||
|             entry = Paragraph(result, normal_style) | ||||
|  | ||||
|         story.append(entry) | ||||
|         story.append(Spacer(1, 12))  # Space between entries | ||||
|  | ||||
|     doc.build(story) | ||||
|  | ||||
|  | ||||
| def output_results(filename, results, format_type): | ||||
|     """根据指定的格式输出结果""" | ||||
|     output_dir = os.path.dirname(filename) | ||||
|     if not os.path.exists(output_dir): | ||||
|         os.makedirs(output_dir) | ||||
|  | ||||
|     with open(filename, "w", encoding="utf-8") as file: | ||||
|         if format_type == "html": | ||||
|             file.write("<html><head><title>Vulnerability Report</title></head><body>\n") | ||||
|             file.write("<h1>Vulnerability Report</h1>\n") | ||||
|             for result in results: | ||||
|                 file.write(f"<p>{result}</p>\n") | ||||
|             file.write("</body></html>") | ||||
|         elif format_type == "md": | ||||
|             file.write("# Vulnerability Report\n") | ||||
|             for result in results: | ||||
|                 file.write(f"* {result}\n") | ||||
|         elif format_type == "pdf": | ||||
|             output_pdf(results, filename) | ||||
|         else:  # 默认为txt | ||||
|             for result in results: | ||||
|                 file.write(f"{result}\n") | ||||
|  | ||||
|     print("Results have been saved as " + filename) | ||||
| def check_vulnerabilities(requirements: list, base_url: str, output_file: str): | ||||
|     with open(output_file, "w") as out_file: | ||||
|         for req in requirements: | ||||
|             version = "" | ||||
|             # 如果有版本 | ||||
|             if "==" in req: | ||||
|                 package_name, version = req.split("==") | ||||
|             # 没有版本 | ||||
|             else: | ||||
|                 package_name, version = req, None | ||||
|             # 拼接URL | ||||
|             url = f"{base_url}{package_name}" | ||||
|             print(f"Fetching data for {package_name} from {url}") | ||||
|             html_content = fetch_html(url) | ||||
|             if html_content: | ||||
|                 # 解析hmtl | ||||
|                 extracted_data = parse_html(html_content) | ||||
|                 if extracted_data: | ||||
|                     relevant_vulns = [] | ||||
|                     for vuln in extracted_data: | ||||
|                         if version_in_range(version, vuln["chip"]): | ||||
|                             relevant_vulns.append(vuln) | ||||
|                     if relevant_vulns: | ||||
|                         out_file.write(f"Vulnerabilities found for {package_name}:\n") | ||||
|                         for vuln in relevant_vulns: | ||||
|                             out_file.write(f"  - {vuln['link']}\n") | ||||
|                         out_file.write("\n") | ||||
|                 else: | ||||
|                     print(f"No relevant data found for {package_name}.") | ||||
|             else: | ||||
|                 print(f"Failed to fetch data for {package_name}.") | ||||
|  | ||||
|  | ||||
| def main(): | ||||
| @@ -241,38 +115,25 @@ def main(): | ||||
|         description="Check project dependencies for vulnerabilities." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "requirements_file", help="Path to the requirements file of the project" | ||||
|         "-r", | ||||
|         "--requirement", | ||||
|         help="Path to the requirements file of the project", | ||||
|         required=True, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "-o", | ||||
|         "--output", | ||||
|         help="Output file path with extension, e.g., './output/report.txt'", | ||||
|         required=True, | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     base_url = "https://security.snyk.io/vuln/pip/" | ||||
|     page_number = 1 | ||||
|     crawler_results = "" | ||||
|     while True: | ||||
|         url = f"{base_url}{page_number}" | ||||
|         print(f"Fetching data from {url}") | ||||
|         html_content = fetch_html(url) | ||||
|         if not html_content: | ||||
|             print("No more data found or failed to fetch.") | ||||
|             break | ||||
|         extracted_data = parse_html(html_content) | ||||
|         if not extracted_data: | ||||
|             print("No relevant data found on page.") | ||||
|             break | ||||
|         crawler_results += format_results(extracted_data) | ||||
|         page_number += 1 | ||||
|     print("Results have been stored in memory.\n") | ||||
|  | ||||
|     trans_res = trans_vulnerable_packages(crawler_results) | ||||
|     trans_res = format_vulnerabilities(trans_res) | ||||
|     trans_res = trans_vulnerable_packages_to_dict(trans_res) | ||||
|     requirements = load_requirements(args.requirements_file) | ||||
|     check_vulnerabilities(requirements, trans_res, args.output) | ||||
|     base_url = "https://security.snyk.io/package/pip/" | ||||
|     # 分析项目依赖,包括名称和版本(如果有的话) | ||||
|     requirements = load_requirements(args.requirement) | ||||
|     # 传入依赖信息,url前缀,扫描结果输出位置 | ||||
|     check_vulnerabilities(requirements, base_url, args.output) | ||||
|     print("Vulnerability scan complete. Results saved to", args.output) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
		Reference in New Issue
	
	Block a user