@@ -1,239 +1,113 @@
import re
import os
import requests
import argparse
import requests
from bs4 import BeautifulSoup
from typing import List , Tuple , Opt ional
from packaging import version
from packaging . specifiers import SpecifierSet
from reportlab . lib . pagesizes import letter
from reportlab . platypus import SimpleDocTemplate , Paragraph , Spacer
from reportlab . lib . styles import getSampleStyleSheet , ParagraphStyle
from packaging . version import Version , InvalidVers ion
import sys
def fetch_html ( url : str ) - > Optional [ str ] :
""" Fetch HTML content from the specified URL.
Args:
url (str): URL to fetch HTML from.
Returns:
Optional[str]: HTML content as a string, or None if fetch fails.
"""
response = requests . get ( url )
if response . status_code == 200 :
def fetch_html ( url : str ) - > str :
try :
response = requests . get ( url )
response . raise_for_status ( )
return response . text
return None
except requests . RequestException as e :
print ( f " Error fetching { url } : { e } " )
return " "
def parse_html ( html : str ) - > List [ Tuple [ str , List [ str ] ] ] :
""" Parse HTML to get content of all ' a ' and ' span ' tags under the second ' td ' of each ' tr ' .
Args:
html (str): HTML content as a string.
Returns:
List[Tuple[str, List[str]]]: A list of tuples containing the text of ' a ' tags and lists of ' span ' texts.
"""
def parse_html ( html : str ) - > list :
soup = BeautifulSoup ( html , " html.parser " )
table = soup . find ( " table " , id = " sortable-table " )
if not table :
return [ ]
rows = table . find_all ( " tr " , class_ = " vue--table__row " )
results = [ ]
if table :
rows = table . find ( " tbody " ) . find_all ( " tr " )
for row in rows :
tds = row . find_all ( " td " )
if len ( tds ) > = 2 :
a_tags = tds [ 1 ] . find_all ( " a " )
span_tags = tds [ 1 ] . find_all ( " span " )
spans = [ span . text . strip ( ) for span in span_tags ]
for a_tag in a_tags :
results . append ( ( a_tag . text . strip ( ) , spans ) )
for row in rows :
info = { }
link = row . find ( " a " )
chip = row . find ( " span " , class_ = " vue--chip__value " )
if link and chip :
info [ " link " ] = link . get_text ( strip = True )
info [ " chip " ] = chip . get_text ( strip = True )
results . append ( info )
return results
def format_results ( results : List [ Tuple [ str , List [ str ] ] ] ) - > str :
""" Format extracted data as a string.
Args :
results (List[Tuple[str, List[str]]]): Extracted data to format.
Returns:
str: Formatted string of the extracted data.
"""
formatted_result = " "
for package_name , version_ranges in results :
formatted_result + = f " Package Name: { package_name } \n "
formatted_result + = " Version Ranges: " + " , " . join ( version_ranges ) + " \n "
formatted_result + = " - " * 50 + " \n "
return formatted_result
def trans_vulnerable_packages ( content ) :
""" 将漏洞版本中的集合形式转换为大于小于的格式
Args:
content (str): 漏洞版本汇总信息.
"""
vulnerabilities = { }
blocks = content . split ( " -------------------------------------------------- " )
range_pattern = re . compile ( r " \ [(.*?), \ s*(.*?) \ ) " )
for block in blocks :
name_match = re . search ( r " Package Name: (.+) " , block )
if name_match :
package_name = name_match . group ( 1 ) . strip ( )
ranges = range_pattern . findall ( block )
specifier_list = [ ]
for start , end in ranges :
if start and end :
specifier_list . append ( f " >= { start } ,< { end } " )
elif start :
specifier_list . append ( f " >= { start } " )
elif end :
specifier_list . append ( f " < { end } " )
if specifier_list :
vulnerabilities [ package_name ] = SpecifierSet ( " , " . join ( specifier_list ) )
return vulnerabilities
def format_vulnerabilities ( vuln_packages ) :
""" 将字典形式的漏洞信息格式化
Args:
vuln_packages (List[Tuple[str, List[str]]]): Extracted data to format.
"""
res = " "
for package , specifiers in vuln_packages . items ( ) :
res + = f " Package Name: { package } \n "
res + = f " Version Ranges: { specifiers } \n "
res + = " - " * 50 + " \n "
return res
def load_requirements ( filename ) :
""" 从文件加载项目的依赖信息 """
with open ( filename , " r " , encoding = " utf-8 " ) as file :
lines = file . readlines ( )
requirements = { }
for line in lines :
if " == " in line :
package_name , package_version = line . strip ( ) . split ( " == " )
requirements [ package_name ] = package_version
def load_requirements ( file_path : str ) - > li st:
requirements = [ ]
try :
with open ( file_path , " r " ) as file :
for line in file :
line = line . strip ( )
if line and not line . startswith ( " # " ) :
requirements . append ( line )
except FileNotFoundError :
print ( f " Error: File { file_path } not found. " )
sys . exit ( 1 )
return requirements
def check_vulnerabilities ( requirements , vulnerabilities , output_file ) :
""" 检查依赖项是否存在已知漏洞,并输出结果 """
results_warning = [ ] # 存储有漏洞的依赖
results_ok = [ ] # 存储没有漏洞的依赖
for req_name , req_version in requirements . items ( ) :
if req_name in vulnerabilities :
spec = vulnerabilities [ req_name ]
if version . parse ( req_version ) in spec :
results_warning . append (
f " WARNING: { req_name } == { req_version } is vulnerable! "
)
else :
results_ok . append ( f " OK: { req_name } == { req_version } is not affected. " )
else :
results_ok . append (
f " OK: { req_name } not found in the vulnerability database. "
)
# 合并结果,先输出所有警告,然后输出所有正常情况
results = results_warning + results_ok
# print(results)
if output_file :
filename , ext = os . path . splitext ( output_file )
output_format = ext [ 1 : ] if ext [ 1 : ] else " txt "
if output_format not in [ " txt " , " md " , " html " , " pdf " ] :
print ( " Warning: Invalid file format specified. Defaulting to TXT format. " )
output_format = " txt " # 确保使用默认格式
output_file = filename + " .txt "
output_results ( output_file , results , output_format )
def version_in_range ( version , range_str : str ) - > bool :
if version is not None :
try :
v = Version ( version )
except InvalidVersion :
return False
else :
print ( " \n " . join ( results ) )
# 如果没有给版本号,默认使用最新版本
if range_str [ - 2 ] == " , " :
return True
ranges = range_str . split ( " , " )
for range_part in ranges :
range_part = range_part . strip ( " []() " )
if range_part :
try :
if range_part . endswith ( " ) " ) :
upper = Version ( range_part [ : - 1 ] )
if v > = upper :
return False
elif range_part . startswith ( " [ " ) :
lower = Version ( range_part [ 1 : ] )
if v < lower :
return False
except InvalidVersion :
return False
return True
def trans _vulnerable_packages_to_dict ( content ) :
""" 将漏洞信息转换为字典格式
Args :
content str: 漏洞信息汇总.
"""
vulnerabilities = { }
blocks = content . split ( " -------------------------------------------------- " )
for block in blocks :
name_match = re . search ( r " Package Name: (.+) " , block )
range_match = re . search ( r " Version Ranges: (.+) " , block )
if name_match and range_match :
package_name = name_match . group ( 1 ) . strip ( )
version_range = range_match . group ( 1 ) . strip ( )
version_range = " , " . join (
[ part . strip ( ) for part in version_range . split ( " , " ) ]
)
vulnerabilities [ package_name ] = SpecifierSet ( version_range )
return vulnerabilities
def output_pdf ( results , file_name ) :
doc = SimpleDocTemplate ( file_name , pagesize = letter )
story = [ ]
styles = getSampleStyleSheet ( )
# Custom styles
title_style = styles [ " Title " ]
title_style . alignment = 1 # Center alignment
warning_style = ParagraphStyl e(
" WarningStyle " , parent = styles [ " BodyText " ] , fontName = " Helvetica-Bold "
)
normal_style = styles [ " BodyText " ]
# Add the title
title = Paragraph ( " Vulnerability Report " , title_style )
story . append ( title )
story . append ( Spacer ( 1 , 20 ) ) # Space after title
# Iterate through results to add entries
for result in results :
if " WARNING: " in result :
# Add warning text in bold
entry = Paragraph (
result . replace ( " WARNING: " , " <b>WARNING:</b> " ) , warning_style
)
else :
# Add normal text
entry = Paragraph ( result , normal_style )
story . append ( entry )
story . append ( Spacer ( 1 , 12 ) ) # Space between entries
doc . build ( story )
def output_results ( filename , results , format_type ) :
""" 根据指定的格式输出结果 """
output_dir = os . path . dirname ( filename )
if not os . path . exists ( output_dir ) :
os . makedirs ( output_dir )
with open ( filename , " w " , encoding = " utf-8 " ) as file :
if format_type == " html " :
file . write ( " <html><head><title>Vulnerability Report</title></head><body> \n " )
file . write ( " <h1>Vulnerability Report</h1> \n " )
for result in results :
file . write ( f " <p> { result } </p> \n " )
file . write ( " </body></html> " )
elif format_type == " md " :
file . write ( " # Vulnerability Report \n " )
for result in results :
file . write ( f " * { result } \n " )
elif format_type == " pdf " :
output_pdf ( results , filename )
else : # 默认为txt
for result in results :
file . write ( f " { result } \n " )
print ( " Results have been saved as " + filename )
def check _vulnerabilities ( requirements : list , base_url : str , output_file : str ) :
with open ( output_file , " w " ) as out_file :
for req in requirements :
version = " "
# 如果有版本
if " == " in req :
package_name , version = req . split ( " == " )
# 没有版本
else :
package_name , version = req , None
# 拼接URL
url = f " { base_url } { package_name } "
print ( f " Fetching data for { package_name } from { url } " )
html_content = fetch_html ( url )
if html_content :
# 解析hmtl
extracted_data = parse_html ( html_content )
if extracted_data :
relevant_vulns = [ ]
for vuln in extracted_data :
if version_in_range ( version , vuln [ " chip " ] ) :
relevant_vulns . append ( vuln )
if relevant_vulns :
out_file . write ( f " Vulnerabilities found for { package_name } : \n " )
for vuln in relevant_vulns :
out_file . write ( f " - { vuln [ ' link ' ] } \n " )
out_file . write ( " \n " )
else :
print ( f " No relevant data found for { package_name } . " )
els e :
print ( f " Failed to fetch data for { package_name } . " )
def main ( ) :
@@ -241,38 +115,25 @@ def main():
description = " Check project dependencies for vulnerabilities. "
)
parser . add_argument (
" requirements_file " , help = " Path to the requirements file of the project "
" -r " ,
" --requirement " ,
help = " Path to the requirements file of the project " ,
required = True ,
)
parser . add_argument (
" -o " ,
" --output " ,
help = " Output file path with extension, e.g., ' ./output/report.txt ' " ,
required = True ,
)
args = parser . parse_args ( )
base_url = " https://security.snyk.io/vuln /pip/ "
page_number = 1
crawler_results = " "
while True :
url = f " { base_url } { page_number } "
print ( f " Fetching data from { url } " )
html_content = fetch_html ( url )
if not html_content :
print ( " No more data found or failed to fetch. " )
break
extracted_data = parse_html ( html_content )
if not extracted_data :
print ( " No relevant data found on page. " )
break
crawler_results + = format_results ( extracted_data )
page_number + = 1
print ( " Results have been stored in memory. \n " )
trans_res = trans_vulnerable_packages ( crawler_results )
trans_res = format_vulnerabilities ( trans_res )
trans_res = trans_vulnerable_packages_to_dict ( trans_res )
requirements = load_requirements ( args . requirements_file )
check_vulnerabilities ( requirements , trans_res , args . output )
base_url = " https://security.snyk.io/package /pip/ "
# 分析项目依赖,包括名称和版本(如果有的话)
requirements = load_requirements ( args . requirement )
# 传入依赖信息, url前缀, 扫描结果输出位置
check_vulnerabilities ( requirements , base_url , args . output )
print ( " Vulnerability scan complete. Results saved to " , args . output )
if __name__ == " __main__ " :