feat: 实时爬取漏洞信息

This commit is contained in:
dqy 2024-04-30 21:38:14 +08:00
parent 953a320dd5
commit e28cb2416d

View File

@ -1,7 +1,9 @@
import argparse
import os
import re import re
import sys import os
import requests
import argparse
from bs4 import BeautifulSoup
from typing import List, Tuple, Optional
from packaging import version from packaging import version
from packaging.specifiers import SpecifierSet from packaging.specifiers import SpecifierSet
from reportlab.lib.pagesizes import letter from reportlab.lib.pagesizes import letter
@ -9,10 +11,154 @@ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
def fetch_html(url: str, timeout: float = 10.0) -> Optional[str]:
    """Fetch HTML content from the specified URL.

    Args:
        url (str): URL to fetch HTML from.
        timeout (float): Seconds to wait for the server before giving up.

    Returns:
        Optional[str]: HTML content as a string, or None if the request
        fails or the server answers with a status other than 200.
    """
    try:
        # Without a timeout, requests waits indefinitely on a stalled server.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors and timeouts are failed fetches; the contract
        # reports those as None instead of raising.
        print(f"Failed to fetch {url}: {exc}")
        return None
    if response.status_code == 200:
        return response.text
    return None
def parse_html(html: str) -> List[Tuple[str, List[str]]]:
    """Collect link and span texts from the second cell of each table row.

    Looks up the table with id "sortable-table". For every row that has at
    least two 'td' cells, the text of each 'a' tag in the second cell is
    paired with the texts of all 'span' tags found in that same cell.

    Args:
        html (str): HTML content as a string.

    Returns:
        List[Tuple[str, List[str]]]: One tuple per 'a' tag, holding its
        stripped text and the list of stripped 'span' texts of its cell.
        Empty when the table is absent.
    """
    soup = BeautifulSoup(html, "html.parser")
    table = soup.find("table", id="sortable-table")
    if table is None:
        return []

    # find() returns None when the markup has no <tbody>; search the table
    # itself in that case instead of failing on None.find_all.
    tbody = table.find("tbody")
    container = table if tbody is None else tbody

    results = []
    for row in container.find_all("tr"):
        tds = row.find_all("td")
        if len(tds) < 2:
            continue
        cell = tds[1]
        spans = [span.text.strip() for span in cell.find_all("span")]
        # Every link in the cell shares the same list of span texts.
        results.extend((a_tag.text.strip(), spans) for a_tag in cell.find_all("a"))
    return results
def format_results(results: List[Tuple[str, List[str]]]) -> str:
    """Format extracted data as a string.

    Each entry becomes three lines: "Package Name: <name>",
    "Version Ranges: <ranges joined by ', '>" and a separator of 50 dashes.

    Args:
        results (List[Tuple[str, List[str]]]): Extracted data to format.

    Returns:
        str: Formatted string of the extracted data; empty for no entries.
    """
    separator = "-" * 50
    # Build the text with a single join instead of repeated string +=.
    return "".join(
        f"Package Name: {package_name}\n"
        f"Version Ranges: {', '.join(version_ranges)}\n"
        f"{separator}\n"
        for package_name, version_ranges in results
    )
def trans_vulnerable_packages(content):
    """Convert interval-style version ranges into packaging specifier sets.

    The text is split into blocks on a run of 50 dashes. In every block that
    has a "Package Name: <name>" line, each "[start, end)" interval is
    rewritten as ">=start,<end" (either bound may be empty) and the pieces
    are joined into one SpecifierSet.

    Args:
        content (str): Summary text of vulnerable versions.

    Returns:
        dict: Package name mapped to its SpecifierSet. Packages without any
        usable interval are omitted.
    """
    # Imported locally: InvalidSpecifier is not among the module's top-level
    # imports, and this keeps the block self-contained.
    from packaging.specifiers import InvalidSpecifier, SpecifierSet

    vulnerabilities = {}
    range_pattern = re.compile(r"\[(.*?),\s*(.*?)\)")
    for block in content.split("-" * 50):
        name_match = re.search(r"Package Name: (.+)", block)
        if not name_match:
            continue
        package_name = name_match.group(1).strip()

        specifier_list = []
        for start, end in range_pattern.findall(block):
            start, end = start.strip(), end.strip()
            if start and end:
                specifier = f">={start},<{end}"
            elif start:
                specifier = f">={start}"
            elif end:
                specifier = f"<{end}"
            else:
                continue
            # Validate each interval on its own so that one malformed range
            # is skipped instead of aborting the whole conversion.
            try:
                SpecifierSet(specifier)
            except InvalidSpecifier:
                print(
                    f"Warning: skipping unparseable version range "
                    f"'{specifier}' for {package_name}"
                )
                continue
            specifier_list.append(specifier)

        # NOTE(review): SpecifierSet treats comma-separated clauses as a
        # conjunction, so several disjoint intervals in one block yield a set
        # no version satisfies; a package appearing in more than one block
        # keeps only its last block. Fixing either needs a different return
        # shape - confirm with the callers.
        if specifier_list:
            vulnerabilities[package_name] = SpecifierSet(",".join(specifier_list))
    return vulnerabilities
def format_vulnerabilities(vuln_packages):
    """Render a package-to-specifier mapping as text blocks.

    Each entry becomes three lines: "Package Name: <name>",
    "Version Ranges: <specifiers>" and a separator of 50 dashes.

    Args:
        vuln_packages (dict): Package name mapped to its version specifiers;
            each value is rendered with its string form.

    Returns:
        str: The formatted text; empty for an empty mapping.
    """
    separator = "-" * 50
    # Build the text with a single join instead of repeated string +=.
    return "".join(
        f"Package Name: {package}\nVersion Ranges: {specifiers}\n{separator}\n"
        for package, specifiers in vuln_packages.items()
    )
def load_requirements(filename):
    """Load pinned dependencies from a requirements file.

    Only "name==version" pins are collected. Comments, environment markers
    (text after ';') and anything following the version token are ignored.

    Args:
        filename (str): Path to the requirements file, read as UTF-8.

    Returns:
        dict: Package name mapped to its pinned version string.
    """
    requirements = {}
    with open(filename, "r", encoding="utf-8") as file:
        for raw_line in file:
            # Drop comments and environment markers first; either one may
            # contain its own "==" and would break the pin parsing.
            line = raw_line.split("#", 1)[0].split(";", 1)[0].strip()
            if "==" not in line:
                continue
            package_name, _, remainder = line.partition("==")
            package_name = package_name.strip()
            # Keep only the version token; trailing options are not part of it.
            version_tokens = remainder.split()
            if package_name and version_tokens:
                requirements[package_name] = version_tokens[0]
    return requirements
def check_vulnerabilities(requirements, vulnerabilities, output_file):
    """Check pinned dependencies against known-vulnerable ranges and report.

    Warnings are listed before the unaffected entries. The report is written
    through output_results when an output file is given, otherwise printed.

    Args:
        requirements (dict): Package name mapped to its pinned version string.
        vulnerabilities (dict): Package name mapped to the specifier set of
            vulnerable versions (queried through its ``contains`` method).
        output_file (str | None): Report path. Its extension selects the
            format (txt, md, html or pdf, case-insensitive); a missing or
            unsupported extension falls back to txt. Falsy means print.
    """
    # Imported locally: InvalidVersion is not among the module's top-level imports.
    from packaging.version import InvalidVersion

    results_warning = []  # dependencies that need attention
    results_ok = []  # dependencies with no known vulnerability
    for req_name, req_version in requirements.items():
        if req_name not in vulnerabilities:
            results_ok.append(
                f"OK: {req_name} not found in the vulnerability database."
            )
            continue
        try:
            parsed_version = version.parse(req_version)
        except InvalidVersion:
            # An unparseable pin cannot be checked; surface it instead of
            # crashing the whole report.
            results_warning.append(
                f"WARNING: {req_name}=={req_version} has an unparseable "
                "version and could not be checked."
            )
            continue
        # prereleases=True: the plain `in` test rejects pre-release versions
        # by default, which would report a vulnerable pre-release pin as safe.
        if vulnerabilities[req_name].contains(parsed_version, prereleases=True):
            results_warning.append(
                f"WARNING: {req_name}=={req_version} is vulnerable!"
            )
        else:
            results_ok.append(f"OK: {req_name}=={req_version} is not affected.")

    # Merge the results: all warnings first, then the unaffected entries.
    results = results_warning + results_ok
    if not output_file:
        print("\n".join(results))
        return

    filename, ext = os.path.splitext(output_file)
    # Lower-case so that e.g. "report.PDF" is recognised as pdf.
    output_format = ext[1:].lower() or "txt"
    if output_format not in ("txt", "md", "html", "pdf"):
        print("Warning: Invalid file format specified. Defaulting to TXT format.")
        output_format = "txt"  # make sure the default format is used
        output_file = filename + ".txt"
    output_results(output_file, results, output_format)
def trans_vulnerable_packages_to_dict(content):
"""将漏洞信息转换为字典格式
Args:
content str: 漏洞信息汇总.
"""
vulnerabilities = {} vulnerabilities = {}
blocks = content.split("--------------------------------------------------") blocks = content.split("--------------------------------------------------")
for block in blocks: for block in blocks:
@ -28,18 +174,6 @@ def load_vulnerable_packages(filename):
return vulnerabilities return vulnerabilities
def load_requirements(filename):
    """Load pinned dependencies from a requirements file.

    Only "name==version" pins are collected. Comments, environment markers
    (text after ';') and anything following the version token are ignored.

    Args:
        filename (str): Path to the requirements file, read as UTF-8.

    Returns:
        dict: Package name mapped to its pinned version string.
    """
    requirements = {}
    with open(filename, "r", encoding="utf-8") as file:
        for raw_line in file:
            # Drop comments and environment markers first; either one may
            # contain its own "==" and would break the pin parsing.
            line = raw_line.split("#", 1)[0].split(";", 1)[0].strip()
            if "==" not in line:
                continue
            package_name, _, remainder = line.partition("==")
            package_name = package_name.strip()
            # Keep only the version token; trailing options are not part of it.
            version_tokens = remainder.split()
            if package_name and version_tokens:
                requirements[package_name] = version_tokens[0]
    return requirements
def output_pdf(results, file_name): def output_pdf(results, file_name):
doc = SimpleDocTemplate(file_name, pagesize=letter) doc = SimpleDocTemplate(file_name, pagesize=letter)
story = [] story = []
@ -99,48 +233,13 @@ def output_results(filename, results, format_type):
for result in results: for result in results:
file.write(f"{result}\n") file.write(f"{result}\n")
print("Results have been saved as " + filename)
def check_vulnerabilities(requirements, vulnerabilities, output_file):
    """Check pinned dependencies against known-vulnerable ranges and report.

    Warnings are listed before the unaffected entries. The report is written
    through output_results when an output file is given, otherwise printed.

    Args:
        requirements (dict): Package name mapped to its pinned version string.
        vulnerabilities (dict): Package name mapped to the specifier set of
            vulnerable versions (queried through its ``contains`` method).
        output_file (str | None): Report path. Its extension selects the
            format (txt, md, html or pdf, case-insensitive); a missing or
            unsupported extension falls back to txt. Falsy means print.
    """
    # Imported locally: InvalidVersion is not among the module's top-level imports.
    from packaging.version import InvalidVersion

    results_warning = []  # dependencies that need attention
    results_ok = []  # dependencies with no known vulnerability
    for req_name, req_version in requirements.items():
        if req_name not in vulnerabilities:
            results_ok.append(
                f"OK: {req_name} not found in the vulnerability database."
            )
            continue
        try:
            parsed_version = version.parse(req_version)
        except InvalidVersion:
            # An unparseable pin cannot be checked; surface it instead of
            # crashing the whole report.
            results_warning.append(
                f"WARNING: {req_name}=={req_version} has an unparseable "
                "version and could not be checked."
            )
            continue
        # prereleases=True: the plain `in` test rejects pre-release versions
        # by default, which would report a vulnerable pre-release pin as safe.
        if vulnerabilities[req_name].contains(parsed_version, prereleases=True):
            results_warning.append(
                f"WARNING: {req_name}=={req_version} is vulnerable!"
            )
        else:
            results_ok.append(f"OK: {req_name}=={req_version} is not affected.")

    # Merge the results: all warnings first, then the unaffected entries.
    results = results_warning + results_ok
    if not output_file:
        print("\n".join(results))
        return

    filename, ext = os.path.splitext(output_file)
    # Lower-case so that e.g. "report.PDF" is recognised as pdf.
    output_format = ext[1:].lower() or "txt"
    if output_format not in ("txt", "md", "html", "pdf"):
        print("Warning: Invalid file format specified. Defaulting to TXT format.")
        output_format = "txt"  # make sure the default format is used
        output_file = filename + ".txt"
    output_results(output_file, results, output_format)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Check project dependencies for vulnerabilities." description="Check project dependencies for vulnerabilities."
) )
parser.add_argument(
"vulnerabilities_file", help="Path to the file containing vulnerability data"
)
parser.add_argument( parser.add_argument(
"requirements_file", help="Path to the requirements file of the project" "requirements_file", help="Path to the requirements file of the project"
) )
@ -151,9 +250,29 @@ def main():
) )
args = parser.parse_args() args = parser.parse_args()
vulnerabilities = load_vulnerable_packages(args.vulnerabilities_file) base_url = "https://security.snyk.io/vuln/pip/"
page_number = 1
crawler_results = ""
while True:
url = f"{base_url}{page_number}"
print(f"Fetching data from {url}")
html_content = fetch_html(url)
if not html_content:
print("No more data found or failed to fetch.")
break
extracted_data = parse_html(html_content)
if not extracted_data:
print("No relevant data found on page.")
break
crawler_results += format_results(extracted_data)
page_number += 1
print("Results have been stored in memory.\n")
trans_res = trans_vulnerable_packages(crawler_results)
trans_res = format_vulnerabilities(trans_res)
trans_res = trans_vulnerable_packages_to_dict(trans_res)
requirements = load_requirements(args.requirements_file) requirements = load_requirements(args.requirements_file)
check_vulnerabilities(requirements, vulnerabilities, args.output) check_vulnerabilities(requirements, trans_res, args.output)
if __name__ == "__main__": if __name__ == "__main__":