feature/GPT #12

Merged
sangge merged 18 commits from feature/GPT into main 2024-04-29 18:58:49 +08:00
Showing only changes of commit 698cf1c75c - Show all commits

View File

@ -1,10 +1,8 @@
import os
from typing import Dict, List, Tuple
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import Paragraph, Spacer, SimpleDocTemplate
from reportlab.lib import colors
from .Regexdetection import find_dangerous_functions
from .GPTdetection import detectGPT
from .utils import *
@ -25,7 +23,7 @@ def generate_text_content(results):
def output_results(results, output_format, output_file=None):
if output_file:
file_name, file_extension = os.path.splitext(output_file)
file_name = os.path.splitext(output_file)
if output_format not in OUTPUT_FORMATS:
output_format = "txt"
output_file = f"{file_name}.txt"
@ -118,12 +116,14 @@ def output_text(results: Dict[str, List[Tuple[int, str]]], file_name=None):
return text_output
def checkModeAndDetect(mode: str,filePath: str,fileExtension: str):
#TODO:添加更多方式,这里提高代码的复用性和扩展性
def checkModeAndDetect(mode: str, filePath: str, fileExtension: str):
# TODO:添加更多方式,这里提高代码的复用性和扩展性
if mode == "regex":
return find_dangerous_functions(read_file_content(filePath), fileExtension)
elif mode == "llm":
return detectGPT(read_file_content(filePath))
else:
return find_dangerous_functions(read_file_content(filePath), fileExtension)
def process_path(path: str, output_format: str, mode: str, output_file=None):
@ -135,7 +135,7 @@ def process_path(path: str, output_format: str, mode: str, output_file=None):
if file_extension in SUPPORTED_EXTENSIONS:
file_path = os.path.join(root, file)
file_results = checkModeAndDetect(mode,file_path,file_extension)
file_results = checkModeAndDetect(mode, file_path, file_extension)
for key in file_results:
if key != "none": # Exclude 'none' risk level
results[key].extend(
@ -147,7 +147,7 @@ def process_path(path: str, output_format: str, mode: str, output_file=None):
elif os.path.isfile(path):
file_extension = os.path.splitext(path)[1]
if file_extension in SUPPORTED_EXTENSIONS:
file_results = checkModeAndDetect(mode,path,file_extension)
file_results = checkModeAndDetect(mode, path, file_extension)
for key in file_results:
if key != "none": # Exclude 'none' risk level
results[key].extend(
@ -172,7 +172,9 @@ def main():
parser = argparse.ArgumentParser(description="Backdoor detection tool.")
parser.add_argument("path", help="Path to the code to analyze")
parser.add_argument("-o", "--output", help="Output file path", default=None)
parser.add_argument("-m", "--mode", help="Mode of operation:[regex,llm]", default="regex")
parser.add_argument(
"-m", "--mode", help="Mode of operation:[regex,llm]", default="regex"
)
args = parser.parse_args()
output_format = "txt" # Default output format
output_file = None
@ -187,6 +189,7 @@ def main():
"Your input file format was incorrect, the output has been saved as a TXT file."
)
output_file = args.output.rsplit(".", 1)[0] + ".txt"
# 如果未指定输出文件,则输出到 stdout否则写入文件
process_path(args.path, output_format, args.mode, output_file)