diff --git a/detection/__main__.py b/detection/__main__.py index 9586c4d..2f9b158 100644 --- a/detection/__main__.py +++ b/detection/__main__.py @@ -445,7 +445,7 @@ def main(): "-p", "--pycdc", help="Path to pycdc.exe to decompile", - default=os.getenv("pycdc"), + default=os.getenv("PATH"), ) args = parser.parse_args() output_format = "txt" # Default output format diff --git a/detection/requirements_detection.py b/detection/requirements_detection.py index 5404292..5a1c78f 100644 --- a/detection/requirements_detection.py +++ b/detection/requirements_detection.py @@ -1,239 +1,113 @@ -import re -import os -import requests import argparse +import requests from bs4 import BeautifulSoup -from typing import List, Tuple, Optional -from packaging import version -from packaging.specifiers import SpecifierSet -from reportlab.lib.pagesizes import letter -from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer -from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle +from packaging.version import Version, InvalidVersion +import sys -def fetch_html(url: str) -> Optional[str]: - """Fetch HTML content from the specified URL. - - Args: - url (str): URL to fetch HTML from. - - Returns: - Optional[str]: HTML content as a string, or None if fetch fails. - """ - response = requests.get(url) - if response.status_code == 200: +def fetch_html(url: str) -> str: + try: + response = requests.get(url) + response.raise_for_status() return response.text - return None + except requests.RequestException as e: + print(f"Error fetching {url}: {e}") + return "" -def parse_html(html: str) -> List[Tuple[str, List[str]]]: - """Parse HTML to get content of all 'a' and 'span' tags under the second 'td' of each 'tr'. - - Args: - html (str): HTML content as a string. - - Returns: - List[Tuple[str, List[str]]]: A list of tuples containing the text of 'a' tags and lists of 'span' texts. - """ +def parse_html(html: str) -> list: soup = BeautifulSoup(html, "html.parser") table = soup.find("table", id="sortable-table") + if not table: + return [] + + rows = table.find_all("tr", class_="vue--table__row") results = [] - if table: - rows = table.find("tbody").find_all("tr") - for row in rows: - tds = row.find_all("td") - if len(tds) >= 2: - a_tags = tds[1].find_all("a") - span_tags = tds[1].find_all("span") - spans = [span.text.strip() for span in span_tags] - for a_tag in a_tags: - results.append((a_tag.text.strip(), spans)) + for row in rows: + info = {} + link = row.find("a") + chip = row.find("span", class_="vue--chip__value") + if link and chip: + info["link"] = link.get_text(strip=True) + info["chip"] = chip.get_text(strip=True) + results.append(info) return results -def format_results(results: List[Tuple[str, List[str]]]) -> str: - """Format extracted data as a string. - - Args: - results (List[Tuple[str, List[str]]]): Extracted data to format. - - Returns: - str: Formatted string of the extracted data. - """ - formatted_result = "" - for package_name, version_ranges in results: - formatted_result += f"Package Name: {package_name}\n" - formatted_result += "Version Ranges: " + ", ".join(version_ranges) + "\n" - formatted_result += "-" * 50 + "\n" - return formatted_result - - -def trans_vulnerable_packages(content): - """将漏洞版本中的集合形式转换为大于小于的格式 - Args: - content (str): 漏洞版本汇总信息. - """ - vulnerabilities = {} - blocks = content.split("--------------------------------------------------") - range_pattern = re.compile(r"\[(.*?),\s*(.*?)\)") - - for block in blocks: - name_match = re.search(r"Package Name: (.+)", block) - if name_match: - package_name = name_match.group(1).strip() - ranges = range_pattern.findall(block) - specifier_list = [] - for start, end in ranges: - if start and end: - specifier_list.append(f">={start},<{end}") - elif start: - specifier_list.append(f">={start}") - elif end: - specifier_list.append(f"<{end}") - if specifier_list: - vulnerabilities[package_name] = SpecifierSet(",".join(specifier_list)) - return vulnerabilities - - -def format_vulnerabilities(vuln_packages): - """将字典形式的漏洞信息格式化 - Args: - vuln_packages (List[Tuple[str, List[str]]]): Extracted data to format. - """ - res = "" - for package, specifiers in vuln_packages.items(): - res += f"Package Name: {package}\n" - res += f"Version Ranges: {specifiers}\n" - res += "-" * 50 + "\n" - return res - - -def load_requirements(filename): - """从文件加载项目的依赖信息""" - with open(filename, "r", encoding="utf-8") as file: - lines = file.readlines() - requirements = {} - for line in lines: - if "==" in line: - package_name, package_version = line.strip().split("==") - requirements[package_name] = package_version +def load_requirements(file_path: str) -> list: + requirements = [] + try: + with open(file_path, "r") as file: + for line in file: + line = line.strip() + if line and not line.startswith("#"): + requirements.append(line) + except FileNotFoundError: + print(f"Error: File {file_path} not found.") + sys.exit(1) return requirements -def check_vulnerabilities(requirements, vulnerabilities, output_file): - """检查依赖项是否存在已知漏洞,并输出结果""" - results_warning = [] # 存储有漏洞的依赖 - results_ok = [] # 存储没有漏洞的依赖 - - for req_name, req_version in requirements.items(): - if req_name in vulnerabilities: - spec = vulnerabilities[req_name] - if version.parse(req_version) in spec: - results_warning.append( - f"WARNING: {req_name}=={req_version} is vulnerable!" - ) - else: - results_ok.append(f"OK: {req_name}=={req_version} is not affected.") - else: - results_ok.append( - f"OK: {req_name} not found in the vulnerability database." - ) - - # 合并结果,先输出所有警告,然后输出所有正常情况 - results = results_warning + results_ok - # print(results) - if output_file: - filename, ext = os.path.splitext(output_file) - output_format = ext[1:] if ext[1:] else "txt" - if output_format not in ["txt", "md", "html", "pdf"]: - print("Warning: Invalid file format specified. Defaulting to TXT format.") - output_format = "txt" # 确保使用默认格式 - output_file = filename + ".txt" - output_results(output_file, results, output_format) +def version_in_range(version, range_str: str) -> bool: + if version is not None: + try: + v = Version(version) + except InvalidVersion: + return False else: - print("\n".join(results)) + # 如果没有给版本号,默认使用最新版本 + if range_str[-2] == ",": + return True + + ranges = range_str.split(",") + for range_part in ranges: + range_part = range_part.strip("[]()") + if range_part: + try: + if range_part.endswith(")"): + upper = Version(range_part[:-1]) + if v >= upper: + return False + elif range_part.startswith("["): + lower = Version(range_part[1:]) + if v < lower: + return False + except InvalidVersion: + return False + return True -def trans_vulnerable_packages_to_dict(content): - """将漏洞信息转换为字典格式 - Args: - content str: 漏洞信息汇总. - """ - vulnerabilities = {} - blocks = content.split("--------------------------------------------------") - for block in blocks: - name_match = re.search(r"Package Name: (.+)", block) - range_match = re.search(r"Version Ranges: (.+)", block) - if name_match and range_match: - package_name = name_match.group(1).strip() - version_range = range_match.group(1).strip() - version_range = ",".join( - [part.strip() for part in version_range.split(",")] - ) - vulnerabilities[package_name] = SpecifierSet(version_range) - return vulnerabilities - - -def output_pdf(results, file_name): - doc = SimpleDocTemplate(file_name, pagesize=letter) - story = [] - styles = getSampleStyleSheet() - - # Custom styles - title_style = styles["Title"] - title_style.alignment = 1 # Center alignment - - warning_style = ParagraphStyle( - "WarningStyle", parent=styles["BodyText"], fontName="Helvetica-Bold" - ) - normal_style = styles["BodyText"] - - # Add the title - title = Paragraph("Vulnerability Report", title_style) - story.append(title) - story.append(Spacer(1, 20)) # Space after title - - # Iterate through results to add entries - for result in results: - if "WARNING:" in result: - # Add warning text in bold - entry = Paragraph( - result.replace("WARNING:", "WARNING:"), warning_style - ) - else: - # Add normal text - entry = Paragraph(result, normal_style) - - story.append(entry) - story.append(Spacer(1, 12)) # Space between entries - - doc.build(story) - - -def output_results(filename, results, format_type): - """根据指定的格式输出结果""" - output_dir = os.path.dirname(filename) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - with open(filename, "w", encoding="utf-8") as file: - if format_type == "html": - file.write("
{result}
\n") - file.write("") - elif format_type == "md": - file.write("# Vulnerability Report\n") - for result in results: - file.write(f"* {result}\n") - elif format_type == "pdf": - output_pdf(results, filename) - else: # 默认为txt - for result in results: - file.write(f"{result}\n") - - print("Results have been saved as " + filename) +def check_vulnerabilities(requirements: list, base_url: str, output_file: str): + with open(output_file, "w") as out_file: + for req in requirements: + version = "" + # 如果有版本 + if "==" in req: + package_name, version = req.split("==") + # 没有版本 + else: + package_name, version = req, None + # 拼接URL + url = f"{base_url}{package_name}" + print(f"Fetching data for {package_name} from {url}") + html_content = fetch_html(url) + if html_content: + # 解析hmtl + extracted_data = parse_html(html_content) + if extracted_data: + relevant_vulns = [] + for vuln in extracted_data: + if version_in_range(version, vuln["chip"]): + relevant_vulns.append(vuln) + if relevant_vulns: + out_file.write(f"Vulnerabilities found for {package_name}:\n") + for vuln in relevant_vulns: + out_file.write(f" - {vuln['link']}\n") + out_file.write("\n") + else: + print(f"No relevant data found for {package_name}.") + else: + print(f"Failed to fetch data for {package_name}.") def main(): @@ -241,38 +115,25 @@ def main(): description="Check project dependencies for vulnerabilities." ) parser.add_argument( - "requirements_file", help="Path to the requirements file of the project" + "-r", + "--requirement", + help="Path to the requirements file of the project", + required=True, ) parser.add_argument( "-o", "--output", help="Output file path with extension, e.g., './output/report.txt'", + required=True, ) args = parser.parse_args() - base_url = "https://security.snyk.io/vuln/pip/" - page_number = 1 - crawler_results = "" - while True: - url = f"{base_url}{page_number}" - print(f"Fetching data from {url}") - html_content = fetch_html(url) - if not html_content: - print("No more data found or failed to fetch.") - break - extracted_data = parse_html(html_content) - if not extracted_data: - print("No relevant data found on page.") - break - crawler_results += format_results(extracted_data) - page_number += 1 - print("Results have been stored in memory.\n") - - trans_res = trans_vulnerable_packages(crawler_results) - trans_res = format_vulnerabilities(trans_res) - trans_res = trans_vulnerable_packages_to_dict(trans_res) - requirements = load_requirements(args.requirements_file) - check_vulnerabilities(requirements, trans_res, args.output) + base_url = "https://security.snyk.io/package/pip/" + # 分析项目依赖,包括名称和版本(如果有的话) + requirements = load_requirements(args.requirement) + # 传入依赖信息,url前缀,扫描结果输出位置 + check_vulnerabilities(requirements, base_url, args.output) + print("Vulnerability scan complete. Results saved to", args.output) if __name__ == "__main__": diff --git a/tests/test_final_tests.py b/tests/test_final_tests.py index 7adad0c..caad14b 100644 --- a/tests/test_final_tests.py +++ b/tests/test_final_tests.py @@ -3,6 +3,7 @@ import unittest import shutil import os import threading +import re from detection.utils import read_file_content from .final_tests_util import ( @@ -11,6 +12,7 @@ from .final_tests_util import ( inject_pickle_backdoor, inject_random_backdoor, inject_pyc_backdoor, + backdoors, ) from detection.Regexdetection import find_dangerous_functions from detection.GPTdetection import detectGPT @@ -102,11 +104,22 @@ class TestFinalTests(unittest.TestCase): # test injected code with open(self.path + "output.txt", "r") as f: lines = f.readlines() - injected_detectedNum = 0 + injected_detected_num = 0 + injected_correct_num = 0 + pattern = r"\w+\.py: Line \d+: (.+)" for line in lines: if "py:" in line: - injected_detectedNum += 1 - injected_accurency = injected_detectedNum / self.injectedNum + injected_detected_num += 1 + match = re.search(pattern, line) + command = "" + if match: + command = match.group(1) + for backdoor in backdoors: + if command in backdoor: + injected_correct_num += 1 + break + + injected_accurency = injected_detected_num / self.py_files_num print(f"injected files accurency: {injected_accurency}") try: GPTresult = GPTdetectFileList(possibly_dangerous_file) @@ -122,21 +135,28 @@ class TestFinalTests(unittest.TestCase): # test pickle files with open(self.path + "output.txt", "r") as f: lines = f.readlines() - pickle_detectedNum = 0 + pickle_detected_num = 0 + pickle_correct_num = 0 for line in lines: if "pickle" in line: - pickle_detectedNum += 1 - pickle_accurency = pickle_detectedNum / self.pickle_files_num + pickle_detected_num += 1 + if re.search(r"backdoor\d*\.pickle", line): + pickle_correct_num += 1 + + pickle_accurency = pickle_detected_num / self.pickle_true_num print(f"pickle files accurency: {pickle_accurency}") # test pyc files with open(self.path + "output.txt", "r") as f: lines = f.readlines() - pyc_detectedNum = 0 + pyc_detected_num = 0 + pyc_correct_num = 0 for line in lines: if "pyc" in line: - pyc_detectedNum += 1 - pyc_accurency = pyc_detectedNum / self.pyc_files_num + pyc_detected_num += 1 + if re.search(r"backdoor\d*\.pyc", line): + pyc_correct_num += 1 + pyc_accurency = pyc_detected_num / self.pyc_true_num print(f"pyc files accurency: {pyc_accurency}")