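# Usage: python requirements_detection.py ../crawler/trans_extracted_data.txt ../requirements.txt <output_format>
# <output_format> is html, md or txt (any other value is written as plain text);
# the report is written to ./results/requirements/results.<output_format>.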

import sys
import os
from packaging import version
from packaging.specifiers import SpecifierSet
import re


class _AnyOfSpecifiers:
    """Logical OR of several specifier sets for one package.

    ``SpecifierSet`` can only express an intersection (every comma-separated
    clause must match), so independent vulnerable ranges reported for the same
    package are kept side by side; a version matches if it falls inside ANY of
    them. Supports the ``in`` operator and ``contains(item, prereleases=...)``
    like a plain ``SpecifierSet``.
    """

    def __init__(self, specifier_sets):
        self._specifier_sets = list(specifier_sets)

    def contains(self, item, prereleases=None):
        """Return True if *item* is matched by at least one of the ranges."""
        return any(
            spec.contains(item, prereleases=prereleases)
            for spec in self._specifier_sets
        )

    def __contains__(self, item):
        return self.contains(item)

    def __str__(self):
        return " | ".join(str(spec) for spec in self._specifier_sets)

    def __repr__(self):
        return f"{type(self).__name__}({self._specifier_sets!r})"


def load_vulnerable_packages(filename):
    """Load vulnerable package ranges from the crawler's extracted-data file.

    The file is split into blocks on a line of dashes; every block containing
    both a ``Package Name: <name>`` line and a ``Version Ranges: <spec>`` line
    contributes one vulnerable range for that package.

    Args:
        filename: path to the vulnerability data file (read as UTF-8).

    Returns:
        dict mapping package name to its vulnerable range. A package with one
        (distinct) range maps to a ``SpecifierSet``; a package listed with
        several different ranges maps to an ``_AnyOfSpecifiers`` matching any
        of them, so no earlier entry is lost.

    Raises:
        packaging.specifiers.InvalidSpecifier: if a range cannot be parsed.
    """
    with open(filename, "r", encoding="utf-8") as file:
        content = file.read()

    ranges_by_package = {}
    for block in content.split("--------------------------------------------------"):
        name_match = re.search(r"Package Name: (.+)", block)
        range_match = re.search(r"Version Ranges: (.+)", block)
        if not (name_match and range_match):
            continue
        package_name = name_match.group(1).strip()
        # Normalise whitespace around the comma-separated clauses.
        version_range = ",".join(
            part.strip() for part in range_match.group(1).strip().split(",")
        )
        spec = SpecifierSet(version_range)
        # Previously a later block for the same package overwrote the earlier
        # one, silently dropping known vulnerable ranges.
        known_ranges = ranges_by_package.setdefault(package_name, [])
        if spec not in known_ranges:
            known_ranges.append(spec)

    return {
        name: specs[0] if len(specs) == 1 else _AnyOfSpecifiers(specs)
        for name, specs in ranges_by_package.items()
    }


def load_requirements(filename):
    """Load exact version pins from a pip requirements file.

    Only lines of the form ``name[extras] == version`` (or ``===``) are
    collected. Comments, environment markers (``; ...``), option lines
    (``-r``, ``--hash``, ...) and non-exact specifiers (``>=``, ``~=``, ...)
    are ignored.

    Args:
        filename: path to the requirements file (read as UTF-8).

    Returns:
        dict mapping package name (as written, extras removed) to the pinned
        version string.
    """
    # Name, optional [extras], "==" or "===", then the version up to the first
    # whitespace, comma, marker separator, comment or line continuation.
    # The old split("==") kept stray spaces/comments/markers in the name or
    # version, and crashed on markers that themselves contain "==".
    pin_pattern = re.compile(
        r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)\s*(?:\[[^\]]*\])?\s*===?\s*([^\s,;#\\]+)"
    )
    requirements = {}
    with open(filename, "r", encoding="utf-8") as file:
        for line in file:
            match = pin_pattern.match(line)
            if match:
                package_name, package_version = match.groups()
                requirements[package_name] = package_version
    return requirements


def output_results(filename, results, format_type):
    """Write the result lines to *filename* in the requested report format.

    Args:
        filename: destination path; missing parent directories are created.
        results: iterable of one-line result strings.
        format_type: ``"html"`` or ``"md"``; any other value writes plain
            text with one result per line.
    """
    import html

    output_dir = os.path.dirname(filename)
    # dirname() is "" for a bare file name and os.makedirs("") raises;
    # exist_ok also avoids the race of a separate exists() check.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(filename, "w", encoding="utf-8") as file:
        if format_type == "html":
            file.write("<html><head><title>Vulnerability Report</title></head><body>\n")
            file.write("<h1>Vulnerability Report</h1>\n")
            for result in results:
                # Package names/versions come from the input files; escape
                # them so they cannot break or inject markup into the report.
                file.write(f"<p>{html.escape(str(result))}</p>\n")
            file.write("</body></html>")
        elif format_type == "md":
            file.write("# Vulnerability Report\n")
            for result in results:
                file.write(f"* {result}\n")
        else:  # default to txt
            for result in results:
                file.write(f"{result}\n")


def check_vulnerabilities(
    requirements, vulnerabilities, output_format, output_dir="./results/requirements"
):
    """Check pinned requirements against known vulnerable ranges and write a report.

    Args:
        requirements: mapping of package name -> pinned version string.
        vulnerabilities: mapping of package name -> range object supporting
            ``contains(version, prereleases=...)`` (e.g. a ``SpecifierSet``).
        output_format: report format passed to ``output_results``; also used
            as the report file extension.
        output_dir: directory for the report file. Defaults to the historical
            ``./results/requirements``; integration tests can pass e.g.
            ``"./"`` instead of editing this function.

    Returns:
        The list of result lines that was written to the report.
    """

    def _normalize(name):
        # PyPI names are case-insensitive and treat runs of "-", "_" and "."
        # as equivalent (PEP 503), so "Django" must match "django".
        return re.sub(r"[-_.]+", "-", name).lower()

    known = {_normalize(name): spec for name, spec in vulnerabilities.items()}

    results = []
    for req_name, req_version in requirements.items():
        # Compare against None: an empty SpecifierSet is falsy but valid.
        spec = known.get(_normalize(req_name))
        if spec is None:
            results.append(f"OK: {req_name} not found in the vulnerability database.")
            continue
        try:
            parsed = version.parse(req_version)
        except version.InvalidVersion:
            # Report instead of aborting the whole scan on one bad pin.
            results.append(
                f"ERROR: {req_name}=={req_version} has an unparsable version "
                "and could not be checked."
            )
            continue
        # SpecifierSet excludes pre-releases by default, which would report a
        # pinned e.g. 2.0.0rc1 as safe even inside a vulnerable range.
        if spec.contains(parsed, prereleases=True):
            results.append(f"WARNING: {req_name}=={req_version} is vulnerable!")
        else:
            results.append(f"OK: {req_name}=={req_version} is not affected.")

    output_results(
        os.path.join(output_dir, "results." + output_format), results, output_format
    )
    return results


def main():
    """Command-line entry point.

    Expects three positional arguments: the vulnerability data file, the
    requirements file and the report format. Prints a usage line and exits
    with status 1 when fewer are supplied; extra arguments are ignored.
    """
    args = sys.argv[1:]
    if len(args) < 3:
        print(
            "Usage: python script.py <vulnerabilities_file> <requirements_file> <output_format>"
        )
        sys.exit(1)

    vuln_path, req_path, report_format = args[0], args[1], args[2]

    # Load order kept: vulnerability data first, then the requirements.
    vulnerable = load_vulnerable_packages(vuln_path)
    pinned = load_requirements(req_path)
    check_vulnerabilities(pinned, vulnerable, report_format)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()