feat: 实时爬取漏洞信息
This commit is contained in:
parent
953a320dd5
commit
e28cb2416d
@ -1,7 +1,9 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import os
|
||||
import requests
|
||||
import argparse
|
||||
from bs4 import BeautifulSoup
|
||||
from typing import List, Tuple, Optional
|
||||
from packaging import version
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from reportlab.lib.pagesizes import letter
|
||||
@ -9,10 +11,154 @@ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
|
||||
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
||||
|
||||
|
||||
def load_vulnerable_packages(filename):
|
||||
"""从文件加载有漏洞的包信息"""
|
||||
def fetch_html(url: str, timeout: float = 10.0) -> Optional[str]:
    """Fetch the HTML document served at ``url``.

    Args:
        url (str): URL to fetch HTML from.
        timeout (float): Seconds to wait for the server before giving up.

    Returns:
        Optional[str]: The response body as text when the server answers
        with HTTP 200; None for any other status code or when the request
        itself fails (connection error, timeout, malformed URL, ...).
    """
    try:
        # Without an explicit timeout the call can block indefinitely on a
        # server that accepts the connection but never answers.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Honour the documented contract: a failed fetch is reported as
        # None instead of propagating a traceback to the caller.
        return None
    if response.status_code == 200:
        return response.text
    return None
|
||||
|
||||
|
||||
def parse_html(html: str) -> List[Tuple[str, List[str]]]:
    """Collect link and span texts from the second cell of every table row.

    Only the table with ``id="sortable-table"`` is inspected. For each row
    that has at least two 'td' cells, the second cell is searched for 'a'
    and 'span' tags.

    Args:
        html (str): HTML content as a string.

    Returns:
        List[Tuple[str, List[str]]]: One tuple per 'a' tag found, pairing
        the stripped link text with the stripped texts of all 'span' tags
        in the same cell. Empty when the table is absent.
    """
    soup = BeautifulSoup(html, "html.parser")
    table = soup.find("table", id="sortable-table")
    if table is None:
        return []

    # The markup may omit an explicit <tbody>; search the table itself in
    # that case instead of failing on ``None.find_all``.
    body = table.find("tbody")
    if body is None:
        body = table

    results = []
    for row in body.find_all("tr"):
        cells = row.find_all("td")
        if len(cells) < 2:
            continue
        target_cell = cells[1]
        span_texts = [span.text.strip() for span in target_cell.find_all("span")]
        for link in target_cell.find_all("a"):
            # Each tuple gets its own list so that mutating one result
            # cannot silently change its siblings from the same cell.
            results.append((link.text.strip(), list(span_texts)))
    return results
|
||||
|
||||
|
||||
def format_results(results: List[Tuple[str, List[str]]]) -> str:
    """Render extracted package data as a plain-text report.

    Each entry becomes three lines: the package name, the comma-separated
    version ranges, and a separator made of 50 dashes.

    Args:
        results (List[Tuple[str, List[str]]]): Pairs of package name and
            the list of version-range strings belonging to it.

    Returns:
        str: The concatenated report; an empty string for empty input.
    """
    separator = "-" * 50 + "\n"
    pieces = []
    for package_name, version_ranges in results:
        pieces.append(f"Package Name: {package_name}\n")
        pieces.append("Version Ranges: " + ", ".join(version_ranges) + "\n")
        pieces.append(separator)
    # A single join avoids rebuilding the growing string on every entry.
    return "".join(pieces)
|
||||
|
||||
|
||||
def trans_vulnerable_packages(content):
    """Convert interval-style version ranges into specifier sets.

    ``content`` is split into blocks on a line of 50 dashes. In every block
    that contains a "Package Name: <name>" line, each ``[start, end)``
    interval is rewritten as ``>=start,<end`` (either bound may be empty).

    Args:
        content (str): Summary text of vulnerable version ranges.

    Returns:
        dict: Maps package name to a ``SpecifierSet`` built from all usable
        intervals in that package's block. Blocks without a package name or
        without any usable interval are omitted.
    """
    from packaging.specifiers import InvalidSpecifier

    name_pattern = re.compile(r"Package Name: (.+)")
    range_pattern = re.compile(r"\[(.*?),\s*(.*?)\)")
    vulnerabilities = {}

    for block in content.split("-" * 50):
        name_match = name_pattern.search(block)
        if not name_match:
            continue
        package_name = name_match.group(1).strip()

        specifier_list = []
        for start, end in range_pattern.findall(block):
            # Bounds may carry stray whitespace from the scraped markup.
            start, end = start.strip(), end.strip()
            if start and end:
                candidate = f">={start},<{end}"
            elif start:
                candidate = f">={start}"
            elif end:
                candidate = f"<{end}"
            else:
                continue
            try:
                # Validate each range on its own so one malformed entry
                # does not abort the whole conversion.
                SpecifierSet(candidate)
            except InvalidSpecifier:
                print(
                    f"Warning: skipping unparsable version range "
                    f"'{candidate}' for {package_name}.",
                    file=sys.stderr,
                )
                continue
            specifier_list.append(candidate)

        if specifier_list:
            # NOTE(review): a comma-joined SpecifierSet requires ALL clauses
            # to hold, so several disjoint ranges for one package combine
            # into a set that may match nothing. Kept as-is because callers
            # expect a single SpecifierSet per package - confirm intent.
            # NOTE(review): a package name appearing in several blocks
            # keeps only the last block's ranges.
            vulnerabilities[package_name] = SpecifierSet(",".join(specifier_list))
    return vulnerabilities
|
||||
|
||||
|
||||
def format_vulnerabilities(vuln_packages):
    """Render a package-to-specifier mapping as a plain-text report.

    Each entry becomes three lines: the package name, the string form of
    its specifiers, and a separator made of 50 dashes.

    Args:
        vuln_packages (dict): Maps package name to its version specifiers;
            the value is formatted with ``str()``.

    Returns:
        str: The concatenated report; an empty string for an empty mapping.
    """
    separator = "-" * 50 + "\n"
    pieces = []
    for package, specifiers in vuln_packages.items():
        pieces.append(f"Package Name: {package}\n")
        pieces.append(f"Version Ranges: {specifiers}\n")
        pieces.append(separator)
    # A single join avoids rebuilding the growing string on every entry.
    return "".join(pieces)
|
||||
|
||||
|
||||
def load_requirements(filename):
    """Load the project's pinned dependencies from a requirements file.

    Only entries pinned with ``==`` are returned. Comments (``#``) and
    environment markers (everything after ``;``) are ignored.

    Args:
        filename (str): Path to the requirements file (read as UTF-8).

    Returns:
        dict: Maps package name to its pinned version string.
    """
    requirements = {}
    with open(filename, "r", encoding="utf-8") as file:
        # Iterate the handle once; reading the whole file first would leave
        # nothing for a subsequent line-by-line pass.
        for raw_line in file:
            # Strip comments and environment markers before looking for the
            # pin, so a "==" inside either is not mistaken for one.
            line = raw_line.split("#", 1)[0].split(";", 1)[0].strip()
            if "==" not in line:
                continue
            package_name, _, remainder = line.partition("==")
            package_name = package_name.strip()
            # Keep only the first token so trailing options (for example
            # "--hash=...") do not end up inside the version string.
            version_tokens = remainder.split()
            if not package_name or not version_tokens:
                continue
            requirements[package_name] = version_tokens[0]
    return requirements
|
||||
|
||||
|
||||
def check_vulnerabilities(requirements, vulnerabilities, output_file):
    """Check pinned dependencies against known vulnerable ranges and report.

    Warnings are listed before the unaffected entries. The report is
    printed when ``output_file`` is falsy; otherwise it is handed to
    ``output_results`` using the format implied by the file extension
    (txt, md, html or pdf, falling back to txt).

    Args:
        requirements (dict): Maps package name to pinned version string.
        vulnerabilities (dict): Maps package name to an object supporting
            ``parsed_version in spec`` (for example a ``SpecifierSet``).
        output_file (str | None): Destination path, or a falsy value to
            print to standard output.
    """

    def _normalize(name):
        # Package names compare case-insensitively with runs of "-", "_"
        # and "." treated as equal (PEP 503 normalisation).
        return re.sub(r"[-_.]+", "-", name).lower()

    normalized_vulnerabilities = {
        _normalize(name): spec for name, spec in vulnerabilities.items()
    }

    results_warning = []  # dependencies inside a vulnerable range
    results_ok = []  # dependencies that are not affected

    for req_name, req_version in requirements.items():
        # Prefer the exact key; fall back to the normalised name so that
        # "Django" in requirements still matches "django" in the database.
        spec = vulnerabilities.get(req_name)
        if spec is None:
            spec = normalized_vulnerabilities.get(_normalize(req_name))
        if spec is None:
            results_ok.append(
                f"OK: {req_name} not found in the vulnerability database."
            )
            continue

        try:
            parsed_version = version.parse(req_version)
        except version.InvalidVersion:
            # One malformed pin must not abort the whole scan; surface it
            # among the warnings so it gets looked at.
            results_warning.append(
                f"WARNING: {req_name}=={req_version} could not be checked "
                f"(unparsable version)."
            )
            continue

        if parsed_version in spec:
            results_warning.append(
                f"WARNING: {req_name}=={req_version} is vulnerable!"
            )
        else:
            results_ok.append(f"OK: {req_name}=={req_version} is not affected.")

    # Merge the results: all warnings first, then the unaffected entries.
    results = results_warning + results_ok

    if not output_file:
        print("\n".join(results))
        return

    filename, ext = os.path.splitext(output_file)
    output_format = ext[1:] if ext[1:] else "txt"
    if output_format not in ["txt", "md", "html", "pdf"]:
        print("Warning: Invalid file format specified. Defaulting to TXT format.")
        output_format = "txt"  # make sure the default format is used
        output_file = filename + ".txt"
    output_results(output_file, results, output_format)
|
||||
|
||||
|
||||
def trans_vulnerable_packages_to_dict(content):
|
||||
"""将漏洞信息转换为字典格式
|
||||
Args:
|
||||
content str: 漏洞信息汇总.
|
||||
"""
|
||||
vulnerabilities = {}
|
||||
blocks = content.split("--------------------------------------------------")
|
||||
for block in blocks:
|
||||
@ -28,18 +174,6 @@ def load_vulnerable_packages(filename):
|
||||
return vulnerabilities
|
||||
|
||||
|
||||
def load_requirements(filename):
    """Load the project's pinned dependencies from a requirements file.

    Only entries pinned with ``==`` are returned. Comments (``#``) and
    environment markers (everything after ``;``) are ignored.

    Args:
        filename (str): Path to the requirements file (read as UTF-8).

    Returns:
        dict: Maps package name to its pinned version string.
    """
    requirements = {}
    with open(filename, "r", encoding="utf-8") as file:
        for raw_line in file:
            # Strip comments and environment markers before looking for the
            # pin, so a "==" inside either is not mistaken for one.
            line = raw_line.split("#", 1)[0].split(";", 1)[0].strip()
            if "==" not in line:
                continue
            # partition() never raises, unlike unpacking split("==") on a
            # line that happens to contain "==" more than once.
            package_name, _, remainder = line.partition("==")
            package_name = package_name.strip()
            # Keep only the first token so trailing options (for example
            # "--hash=...") do not end up inside the version string.
            version_tokens = remainder.split()
            if not package_name or not version_tokens:
                continue
            requirements[package_name] = version_tokens[0]
    return requirements
|
||||
|
||||
|
||||
def output_pdf(results, file_name):
|
||||
doc = SimpleDocTemplate(file_name, pagesize=letter)
|
||||
story = []
|
||||
@ -99,48 +233,13 @@ def output_results(filename, results, format_type):
|
||||
for result in results:
|
||||
file.write(f"{result}\n")
|
||||
|
||||
|
||||
def check_vulnerabilities(requirements, vulnerabilities, output_file):
    """Check pinned dependencies against known vulnerable ranges and report.

    Warnings are listed before the unaffected entries. The report is
    printed when ``output_file`` is falsy; otherwise it is handed to
    ``output_results`` using the format implied by the file extension
    (txt, md, html or pdf, falling back to txt).

    Args:
        requirements (dict): Maps package name to pinned version string.
        vulnerabilities (dict): Maps package name to an object supporting
            ``parsed_version in spec`` (for example a ``SpecifierSet``).
        output_file (str | None): Destination path, or a falsy value to
            print to standard output.
    """

    def _normalize(name):
        # Package names compare case-insensitively with runs of "-", "_"
        # and "." treated as equal (PEP 503 normalisation).
        return re.sub(r"[-_.]+", "-", name).lower()

    normalized_vulnerabilities = {
        _normalize(name): spec for name, spec in vulnerabilities.items()
    }

    results_warning = []  # dependencies inside a vulnerable range
    results_ok = []  # dependencies that are not affected

    for req_name, req_version in requirements.items():
        # Prefer the exact key; fall back to the normalised name so that
        # "Django" in requirements still matches "django" in the database.
        spec = vulnerabilities.get(req_name)
        if spec is None:
            spec = normalized_vulnerabilities.get(_normalize(req_name))
        if spec is None:
            results_ok.append(
                f"OK: {req_name} not found in the vulnerability database."
            )
            continue

        try:
            parsed_version = version.parse(req_version)
        except version.InvalidVersion:
            # One malformed pin must not abort the whole scan; surface it
            # among the warnings so it gets looked at.
            results_warning.append(
                f"WARNING: {req_name}=={req_version} could not be checked "
                f"(unparsable version)."
            )
            continue

        if parsed_version in spec:
            results_warning.append(
                f"WARNING: {req_name}=={req_version} is vulnerable!"
            )
        else:
            results_ok.append(f"OK: {req_name}=={req_version} is not affected.")

    # Merge the results: all warnings first, then the unaffected entries.
    results = results_warning + results_ok

    if not output_file:
        print("\n".join(results))
        return

    filename, ext = os.path.splitext(output_file)
    output_format = ext[1:] if ext[1:] else "txt"
    if output_format not in ["txt", "md", "html", "pdf"]:
        print("Warning: Invalid file format specified. Defaulting to TXT format.")
        output_format = "txt"  # make sure the default format is used
        output_file = filename + ".txt"
    output_results(output_file, results, output_format)
|
||||
print("Results have been saved as " + filename)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Check project dependencies for vulnerabilities."
|
||||
)
|
||||
parser.add_argument(
|
||||
"vulnerabilities_file", help="Path to the file containing vulnerability data"
|
||||
)
|
||||
parser.add_argument(
|
||||
"requirements_file", help="Path to the requirements file of the project"
|
||||
)
|
||||
@ -151,9 +250,29 @@ def main():
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
vulnerabilities = load_vulnerable_packages(args.vulnerabilities_file)
|
||||
base_url = "https://security.snyk.io/vuln/pip/"
|
||||
page_number = 1
|
||||
crawler_results = ""
|
||||
while True:
|
||||
url = f"{base_url}{page_number}"
|
||||
print(f"Fetching data from {url}")
|
||||
html_content = fetch_html(url)
|
||||
if not html_content:
|
||||
print("No more data found or failed to fetch.")
|
||||
break
|
||||
extracted_data = parse_html(html_content)
|
||||
if not extracted_data:
|
||||
print("No relevant data found on page.")
|
||||
break
|
||||
crawler_results += format_results(extracted_data)
|
||||
page_number += 1
|
||||
print("Results have been stored in memory.\n")
|
||||
|
||||
trans_res = trans_vulnerable_packages(crawler_results)
|
||||
trans_res = format_vulnerabilities(trans_res)
|
||||
trans_res = trans_vulnerable_packages_to_dict(trans_res)
|
||||
requirements = load_requirements(args.requirements_file)
|
||||
check_vulnerabilities(requirements, vulnerabilities, args.output)
|
||||
check_vulnerabilities(requirements, trans_res, args.output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user