fix: 修复一些错误
附带完成了一些格式化
This commit is contained in:
		@@ -1,10 +1,8 @@
 | 
				
			|||||||
import os
 | 
					import os
 | 
				
			||||||
from typing import Dict, List, Tuple
 | 
					from typing import Dict, List, Tuple
 | 
				
			||||||
from reportlab.lib.pagesizes import letter
 | 
					from reportlab.lib.pagesizes import letter
 | 
				
			||||||
from reportlab.pdfgen import canvas
 | 
					 | 
				
			||||||
from reportlab.lib.styles import getSampleStyleSheet
 | 
					from reportlab.lib.styles import getSampleStyleSheet
 | 
				
			||||||
from reportlab.platypus import Paragraph, Spacer, SimpleDocTemplate
 | 
					from reportlab.platypus import Paragraph, Spacer, SimpleDocTemplate
 | 
				
			||||||
from reportlab.lib import colors
 | 
					 | 
				
			||||||
from .Regexdetection import find_dangerous_functions
 | 
					from .Regexdetection import find_dangerous_functions
 | 
				
			||||||
from .GPTdetection import detectGPT
 | 
					from .GPTdetection import detectGPT
 | 
				
			||||||
from .utils import *
 | 
					from .utils import *
 | 
				
			||||||
@@ -25,7 +23,7 @@ def generate_text_content(results):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def output_results(results, output_format, output_file=None):
 | 
					def output_results(results, output_format, output_file=None):
 | 
				
			||||||
    if output_file:
 | 
					    if output_file:
 | 
				
			||||||
        file_name, file_extension = os.path.splitext(output_file)
 | 
					        file_name = os.path.splitext(output_file)
 | 
				
			||||||
        if output_format not in OUTPUT_FORMATS:
 | 
					        if output_format not in OUTPUT_FORMATS:
 | 
				
			||||||
            output_format = "txt"
 | 
					            output_format = "txt"
 | 
				
			||||||
            output_file = f"{file_name}.txt"
 | 
					            output_file = f"{file_name}.txt"
 | 
				
			||||||
@@ -118,12 +116,14 @@ def output_text(results: Dict[str, List[Tuple[int, str]]], file_name=None):
 | 
				
			|||||||
        return text_output
 | 
					        return text_output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def checkModeAndDetect(mode: str,filePath: str,fileExtension: str):
 | 
					def checkModeAndDetect(mode: str, filePath: str, fileExtension: str):
 | 
				
			||||||
    #TODO:添加更多方式,这里提高代码的复用性和扩展性
 | 
					    # TODO:添加更多方式,这里提高代码的复用性和扩展性
 | 
				
			||||||
    if mode == "regex":
 | 
					    if mode == "regex":
 | 
				
			||||||
        return find_dangerous_functions(read_file_content(filePath), fileExtension)
 | 
					        return find_dangerous_functions(read_file_content(filePath), fileExtension)
 | 
				
			||||||
    elif mode == "llm":
 | 
					    elif mode == "llm":
 | 
				
			||||||
        return detectGPT(read_file_content(filePath))
 | 
					        return detectGPT(read_file_content(filePath))
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return find_dangerous_functions(read_file_content(filePath), fileExtension)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def process_path(path: str, output_format: str, mode: str, output_file=None):
 | 
					def process_path(path: str, output_format: str, mode: str, output_file=None):
 | 
				
			||||||
@@ -135,7 +135,7 @@ def process_path(path: str, output_format: str, mode: str, output_file=None):
 | 
				
			|||||||
                if file_extension in SUPPORTED_EXTENSIONS:
 | 
					                if file_extension in SUPPORTED_EXTENSIONS:
 | 
				
			||||||
                    file_path = os.path.join(root, file)
 | 
					                    file_path = os.path.join(root, file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    file_results = checkModeAndDetect(mode,file_path,file_extension)
 | 
					                    file_results = checkModeAndDetect(mode, file_path, file_extension)
 | 
				
			||||||
                    for key in file_results:
 | 
					                    for key in file_results:
 | 
				
			||||||
                        if key != "none":  # Exclude 'none' risk level
 | 
					                        if key != "none":  # Exclude 'none' risk level
 | 
				
			||||||
                            results[key].extend(
 | 
					                            results[key].extend(
 | 
				
			||||||
@@ -147,7 +147,7 @@ def process_path(path: str, output_format: str, mode: str, output_file=None):
 | 
				
			|||||||
    elif os.path.isfile(path):
 | 
					    elif os.path.isfile(path):
 | 
				
			||||||
        file_extension = os.path.splitext(path)[1]
 | 
					        file_extension = os.path.splitext(path)[1]
 | 
				
			||||||
        if file_extension in SUPPORTED_EXTENSIONS:
 | 
					        if file_extension in SUPPORTED_EXTENSIONS:
 | 
				
			||||||
            file_results = checkModeAndDetect(mode,path,file_extension)
 | 
					            file_results = checkModeAndDetect(mode, path, file_extension)
 | 
				
			||||||
            for key in file_results:
 | 
					            for key in file_results:
 | 
				
			||||||
                if key != "none":  # Exclude 'none' risk level
 | 
					                if key != "none":  # Exclude 'none' risk level
 | 
				
			||||||
                    results[key].extend(
 | 
					                    results[key].extend(
 | 
				
			||||||
@@ -172,7 +172,9 @@ def main():
 | 
				
			|||||||
    parser = argparse.ArgumentParser(description="Backdoor detection tool.")
 | 
					    parser = argparse.ArgumentParser(description="Backdoor detection tool.")
 | 
				
			||||||
    parser.add_argument("path", help="Path to the code to analyze")
 | 
					    parser.add_argument("path", help="Path to the code to analyze")
 | 
				
			||||||
    parser.add_argument("-o", "--output", help="Output file path", default=None)
 | 
					    parser.add_argument("-o", "--output", help="Output file path", default=None)
 | 
				
			||||||
    parser.add_argument("-m", "--mode", help="Mode of operation:[regex,llm]", default="regex")
 | 
					    parser.add_argument(
 | 
				
			||||||
 | 
					        "-m", "--mode", help="Mode of operation:[regex,llm]", default="regex"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
    output_format = "txt"  # Default output format
 | 
					    output_format = "txt"  # Default output format
 | 
				
			||||||
    output_file = None
 | 
					    output_file = None
 | 
				
			||||||
@@ -187,6 +189,7 @@ def main():
 | 
				
			|||||||
                "Your input file format was incorrect, the output has been saved as a TXT file."
 | 
					                "Your input file format was incorrect, the output has been saved as a TXT file."
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            output_file = args.output.rsplit(".", 1)[0] + ".txt"
 | 
					            output_file = args.output.rsplit(".", 1)[0] + ".txt"
 | 
				
			||||||
 | 
					    # 如果未指定输出文件,则输出到 stdout;否则写入文件
 | 
				
			||||||
    process_path(args.path, output_format, args.mode, output_file)
 | 
					    process_path(args.path, output_format, args.mode, output_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user