feature/GPT #12

Merged
sangge merged 18 commits from feature/GPT into main 2024-04-29 18:58:49 +08:00
Showing only changes of commit 698cf1c75c - Show all commits

View File

@ -1,10 +1,8 @@
import os import os
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from reportlab.lib.pagesizes import letter from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.styles import getSampleStyleSheet from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import Paragraph, Spacer, SimpleDocTemplate from reportlab.platypus import Paragraph, Spacer, SimpleDocTemplate
from reportlab.lib import colors
from .Regexdetection import find_dangerous_functions from .Regexdetection import find_dangerous_functions
from .GPTdetection import detectGPT from .GPTdetection import detectGPT
from .utils import * from .utils import *
@ -25,7 +23,7 @@ def generate_text_content(results):
def output_results(results, output_format, output_file=None): def output_results(results, output_format, output_file=None):
if output_file: if output_file:
file_name, file_extension = os.path.splitext(output_file) file_name = os.path.splitext(output_file)
if output_format not in OUTPUT_FORMATS: if output_format not in OUTPUT_FORMATS:
output_format = "txt" output_format = "txt"
output_file = f"{file_name}.txt" output_file = f"{file_name}.txt"
@ -118,12 +116,14 @@ def output_text(results: Dict[str, List[Tuple[int, str]]], file_name=None):
return text_output return text_output
def checkModeAndDetect(mode: str,filePath: str,fileExtension: str): def checkModeAndDetect(mode: str, filePath: str, fileExtension: str):
#TODO:添加更多方式,这里提高代码的复用性和扩展性 # TODO:添加更多方式,这里提高代码的复用性和扩展性
if mode == "regex": if mode == "regex":
return find_dangerous_functions(read_file_content(filePath), fileExtension) return find_dangerous_functions(read_file_content(filePath), fileExtension)
elif mode == "llm": elif mode == "llm":
return detectGPT(read_file_content(filePath)) return detectGPT(read_file_content(filePath))
else:
return find_dangerous_functions(read_file_content(filePath), fileExtension)
def process_path(path: str, output_format: str, mode: str, output_file=None): def process_path(path: str, output_format: str, mode: str, output_file=None):
@ -135,7 +135,7 @@ def process_path(path: str, output_format: str, mode: str, output_file=None):
if file_extension in SUPPORTED_EXTENSIONS: if file_extension in SUPPORTED_EXTENSIONS:
file_path = os.path.join(root, file) file_path = os.path.join(root, file)
file_results = checkModeAndDetect(mode,file_path,file_extension) file_results = checkModeAndDetect(mode, file_path, file_extension)
for key in file_results: for key in file_results:
if key != "none": # Exclude 'none' risk level if key != "none": # Exclude 'none' risk level
results[key].extend( results[key].extend(
@ -147,7 +147,7 @@ def process_path(path: str, output_format: str, mode: str, output_file=None):
elif os.path.isfile(path): elif os.path.isfile(path):
file_extension = os.path.splitext(path)[1] file_extension = os.path.splitext(path)[1]
if file_extension in SUPPORTED_EXTENSIONS: if file_extension in SUPPORTED_EXTENSIONS:
file_results = checkModeAndDetect(mode,path,file_extension) file_results = checkModeAndDetect(mode, path, file_extension)
for key in file_results: for key in file_results:
if key != "none": # Exclude 'none' risk level if key != "none": # Exclude 'none' risk level
results[key].extend( results[key].extend(
@ -172,7 +172,9 @@ def main():
parser = argparse.ArgumentParser(description="Backdoor detection tool.") parser = argparse.ArgumentParser(description="Backdoor detection tool.")
parser.add_argument("path", help="Path to the code to analyze") parser.add_argument("path", help="Path to the code to analyze")
parser.add_argument("-o", "--output", help="Output file path", default=None) parser.add_argument("-o", "--output", help="Output file path", default=None)
parser.add_argument("-m", "--mode", help="Mode of operation:[regex,llm]", default="regex") parser.add_argument(
"-m", "--mode", help="Mode of operation:[regex,llm]", default="regex"
)
args = parser.parse_args() args = parser.parse_args()
output_format = "txt" # Default output format output_format = "txt" # Default output format
output_file = None output_file = None
@ -187,6 +189,7 @@ def main():
"Your input file format was incorrect, the output has been saved as a TXT file." "Your input file format was incorrect, the output has been saved as a TXT file."
) )
output_file = args.output.rsplit(".", 1)[0] + ".txt" output_file = args.output.rsplit(".", 1)[0] + ".txt"
# 如果未指定输出文件,则输出到 stdout否则写入文件
process_path(args.path, output_format, args.mode, output_file) process_path(args.path, output_format, args.mode, output_file)