diff --git a/detection/backdoor_detection.py b/detection/backdoor_detection.py new file mode 100644 index 0000000..b4786d6 --- /dev/null +++ b/detection/backdoor_detection.py @@ -0,0 +1,204 @@ +import os +from typing import Dict, List, Tuple +from reportlab.lib.pagesizes import letter +from reportlab.lib.styles import getSampleStyleSheet +from reportlab.platypus import Paragraph, Spacer, SimpleDocTemplate + +from detection.pickle_detection import pickleDataDetection +from .Regexdetection import find_dangerous_functions +from .GPTdetection import detectGPT +from .utils import * +import sys + +SUPPORTED_EXTENSIONS = {".py", ".js", ".cpp"} +OUTPUT_FORMATS = ["html", "md", "txt", "pdf"] + + +def generate_text_content(results): + text_output = "Security Analysis Report\n" + for risk_level, entries in results.items(): + if entries and risk_level != "none": + text_output += f"{risk_level.capitalize()} Risk:\n" + for line_num, line in entries: + text_output += f" Line {line_num}: {line}\n" + return text_output + + +def output_results(results, output_format, output_file=None): + if output_file: + file_name = os.path.splitext(output_file) + if output_format not in OUTPUT_FORMATS: + output_format = "txt" + output_file = f"{file_name}.txt" + results_dir = os.path.dirname(output_file) + if not os.path.exists(results_dir): + os.makedirs(results_dir) + if output_format == "pdf": + output_pdf(results, output_file) + elif output_format == "html": + output_html(results, output_file) + elif output_format == "md": + output_markdown(results, output_file) + else: # Default to txt + output_text(results, output_file) + else: + # If no output file is specified, default to text output to the terminal. + txt_output = generate_text_content(results) + print(txt_output) + + +def output_pdf(results: Dict[str, List[Tuple[int, str]]], file_name): + doc = SimpleDocTemplate(file_name, pagesize=letter) + story = [] + styles = getSampleStyleSheet() + + # Add the title centered + title_style = styles["Title"] + title_style.alignment = 1 # Center alignment + title = Paragraph("Security Analysis Report", title_style) + story.append(title) + story.append(Spacer(1, 20)) # Space after title + + # Add risk levels and entries + normal_style = styles["BodyText"] + for risk_level, entries in results.items(): + if risk_level != "none": + story.append( + Paragraph(f"{risk_level.capitalize()} Risk:", styles["Heading2"]) + ) + for line_num, line in entries: + entry = Paragraph(f"Line {line_num}: {line}", normal_style) + story.append(entry) + story.append(Spacer(1, 12)) # Space between sections + + doc.build(story) + + +def output_html(results: Dict[str, List[Tuple[int, str]]], file_name=None): + html_output = "Security Analysis Report" + html_output += "

Security Analysis Report

" + for risk_level, entries in results.items(): + if risk_level != "none": + html_output += f"

{risk_level.capitalize()} Risk

" + html_output += "" + if file_name: + with open(file_name, "w") as file: + file.write(html_output) + else: + return html_output + + +def output_markdown(results: Dict[str, List[Tuple[int, str]]], file_name=None): + md_output = "# Security Analysis Report\n" + for risk_level, entries in results.items(): + if risk_level != "none": + md_output += f"## {risk_level.capitalize()} Risk\n" + for line_num, line in entries: + md_output += f"- {line_num}: {line}\n" + if file_name: + with open(file_name, "w") as file: + file.write(md_output) + else: + return md_output + + +def output_text(results: Dict[str, List[Tuple[int, str]]], file_name=None): + text_output = "Security Analysis Report\n" + for risk_level, entries in results.items(): + if risk_level != "none": + text_output += f"{risk_level.capitalize()} Risk:\n" + for line_num, line in entries: + text_output += f" {line_num}: {line}\n" + if file_name: + with open(file_name, "w") as file: + file.write(text_output) + else: + return text_output + + +def checkModeAndDetect(mode: str, filePath: str, fileExtension: str): + # TODO:添加更多方式,这里提高代码的复用性和扩展性 + if mode == "regex": + return find_dangerous_functions(read_file_content(filePath), fileExtension) + elif mode == "llm": + return detectGPT(read_file_content(filePath)) + else: + return find_dangerous_functions(read_file_content(filePath), fileExtension) + + +def process_path(path: str, output_format: str, mode: str, output_file=None): + results = {"high": [], "medium": [], "low": [], "none": []} + if os.path.isdir(path): + for root, dirs, files in os.walk(path): + for file in files: + file_extension = os.path.splitext(file)[1] + if file_extension in SUPPORTED_EXTENSIONS: + file_path = os.path.join(root, file) + + file_results = checkModeAndDetect(mode, file_path, file_extension) + for key in file_results: + if key != "none": # Exclude 'none' risk level + results[key].extend( + [ + (f"{file_path}: Line {line_num}", line) + for line_num, line in file_results[key] + ] + ) + elif os.path.isfile(path): + file_extension = os.path.splitext(path)[1] + if file_extension in SUPPORTED_EXTENSIONS: + file_results = checkModeAndDetect(mode, path, file_extension) + for key in file_results: + if key != "none": # Exclude 'none' risk level + results[key].extend( + [ + (f"{path}: Line {line_num}", line) + for line_num, line in file_results[key] + ] + ) + else: + print("Unsupported file type.") + return + else: + print("Invalid path.") + sys.exit(1) + + output_results(results, output_format, output_file) + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Backdoor detection tool.") + parser.add_argument("path", help="Path to the code to analyze") + parser.add_argument("-o", "--output", help="Output file path", default=None) + parser.add_argument( + "-m", "--mode", help="Mode of operation:[regex,llm]", default="regex" + ) + parser.add_argument("-p", "--pickle", help="analyze the pickle file", default=None) + args = parser.parse_args() + output_format = "txt" # Default output format + output_file = None + if args.output: + _, ext = os.path.splitext(args.output) + ext = ext.lower() + if ext in [".html", ".md", ".txt", ".pdf"]: + output_format = ext.replace(".", "") + output_file = args.output + else: + print( + "Your input file format was incorrect, the output has been saved as a TXT file." + ) + output_file = args.output.rsplit(".", 1)[0] + ".txt" + # 如果未指定输出文件,则输出到 stdout;否则写入文件 + if args.pickle: + pickleDataDetection(args.pickle, output_file) + else: + process_path(args.path, output_format, args.mode, output_file) + + +if __name__ == "__main__": + main() diff --git a/detection/pickle_detection.py b/detection/pickle_detection.py new file mode 100644 index 0000000..cfbd258 --- /dev/null +++ b/detection/pickle_detection.py @@ -0,0 +1,153 @@ +import io +import json +import pickle + + +class _Unframer: + + def __init__(self, file_read, file_readline, file_tell=None): + self.file_read = file_read + self.file_readline = file_readline + self.current_frame = None + + def readinto(self, buf): + if self.current_frame: + n = self.current_frame.readinto(buf) + if n == 0 and len(buf) != 0: + self.current_frame = None + n = len(buf) + buf[:] = self.file_read(n) + return n + if n < len(buf): + raise pickle.UnpicklingError("pickle exhausted before end of frame") + return n + else: + n = len(buf) + buf[:] = self.file_read(n) + return n + + def read(self, n): + if self.current_frame: + data = self.current_frame.read(n) + if not data and n != 0: + self.current_frame = None + return self.file_read(n) + if len(data) < n: + raise pickle.UnpicklingError("pickle exhausted before end of frame") + return data + else: + return self.file_read(n) + + def readline(self): + if self.current_frame: + data = self.current_frame.readline() + if not data: + self.current_frame = None + return self.file_readline() + if data[-1] != b"\n"[0]: + raise pickle.UnpicklingError("pickle exhausted before end of frame") + return data + else: + return self.file_readline() + + def load_frame(self, frame_size): + if self.current_frame and self.current_frame.read() != b"": + raise pickle.UnpicklingError( + "beginning of a new frame before end of current frame" + ) + self.current_frame = io.BytesIO(self.file_read(frame_size)) + + +dangerous_modules = ["os", "subprocess", "builtins", "nt"] +dangerous_names = [ + "system", + "popen", + "run", + "call", + "check_output", + "check_call", +] + + +class pickleScanner: + + def __init__( + self, file, *, fix_imports=True, encoding="ASCII", errors="strict", buffers=None + ): + self._buffers = iter(buffers) if buffers is not None else None + self._file_readline = file.readline + self._file_read = file.read + self.memo = {} + self.encoding = encoding + self.errors = errors + self.proto = 0 + self.fix_imports = fix_imports + self.file = file + self.ReduceCount = 0 + self.maliciousModule = [] + + def find_class(self, module, name): + if module.decode() in dangerous_modules or name.decode() in dangerous_names: + # self.maliciousCount += 1 + self.maliciousModule.append((module.decode(), name.decode())) + + def load(self): + self._unframer = _Unframer(self._file_read, self._file_readline) + self.read = self._unframer.read + self.readinto = self._unframer.readinto + self.readline = self._unframer.readline + self.seek = self.file.seek + self.metastack = [] + self.stack = [] + self.append = self.stack.append + self.proto = 0 + # 扫描所有的opcodes + opcode = self.read(1) + while opcode: + if opcode == b"c": + self.seek(-2, 1) + codeN1 = self.read(1) + if ( + 65 <= ord(codeN1) <= 90 + or 97 <= ord(codeN1) <= 122 + or ord(codeN1) == 0 + ): + self.read(1) + else: + self.read(1) + module = self.readline()[:-1] + name = self.readline()[:-1] + self.find_class(module, name) + elif opcode in self.unsafe_opcodes: + self.ReduceCount += 1 + opcode = self.read(1) + + unsafe_opcodes = { + b"r", # REDUCE - call a callable with arguments + b"R", # REDUCE - same as 'r', but for args tuple + } + + def output(self) -> dict: + return { + "ReduceCount": self.ReduceCount, + "maliciousModule": self.maliciousModule, + } + + +def pickleDataDetection(filename: str, output_file=None): + """ + :param file: pickle file path + """ + with open(filename, "rb") as file: + pickscan = pickleScanner(file) + pickscan.load() + res = pickscan.output() + if output_file: + with open(output_file, "w") as file: + json.dump(res, file, indent=4) + else: + print(json.dumps(res)) + + +if __name__ == "__main__": + pickleDataDetection("test.pkl") diff --git a/tests/test_pickle_detection.py b/tests/test_pickle_detection.py new file mode 100644 index 0000000..34521e1 --- /dev/null +++ b/tests/test_pickle_detection.py @@ -0,0 +1,56 @@ +import unittest +import pickle +import tempfile +from detection.pickle_detection import pickleScanner, pickleDataDetection +from unittest.mock import patch + + +class TestPickleScanner(unittest.TestCase): + + def setUp(self): + # Create temporary files with valid and malicious data + self.valid_data = {"key": "value"} + self.malicious_data = b"\x80\x03csubprocess\ncheck_output\nq\x00X\x05\x00\x00\x00echo 1q\x01\x85q\x02Rq\x03." + + self.valid_file = tempfile.NamedTemporaryFile(delete=False) + self.valid_file.write(pickle.dumps(self.valid_data)) + self.valid_file.close() + + self.malicious_file = tempfile.NamedTemporaryFile(delete=False) + self.malicious_file.write(self.malicious_data) + self.malicious_file.close() + + def tearDown(self): + # Clean up temporary files + import os + + os.remove(self.valid_file.name) + os.remove(self.malicious_file.name) + + def test_valid_pickle(self): + with open(self.valid_file.name, "rb") as file: + scanner = pickleScanner(file) + print(scanner.maliciousModule) + scanner.load() + output = scanner.output() + self.assertEqual(output["ReduceCount"], 0) + self.assertEqual(output["maliciousModule"], []) + + def test_malicious_pickle(self): + with open(self.malicious_file.name, "rb") as file: + scanner = pickleScanner(file) + scanner.load() + output = scanner.output() + self.assertEqual(output["ReduceCount"], 1) + self.assertIn(("subprocess", "check_output"), output["maliciousModule"]) + + @patch("builtins.print") + def test_pickleDataDetection_no_output_file(self, mock_print): + # test output to stdout if filename is not given + with patch("builtins.print") as mock_print: + pickleDataDetection(self.valid_file.name) + mock_print.assert_called_once() + + +if __name__ == "__main__": + unittest.main()