"""Static scanner that flags dangerous imports and REDUCE calls in pickles."""

import collections
import io
import json
import pickle
import pickletools
import sys


class _Unframer:
    """Frame-aware reader matching ``pickle._Unframer`` from CPython.

    While a frame loaded by :meth:`load_frame` is active, reads are served
    from that in-memory frame; otherwise they go straight to the wrapped
    ``file_read`` / ``file_readline`` callables.  Retained for compatibility;
    :class:`pickleScanner` no longer uses it.
    """

    def __init__(self, file_read, file_readline, file_tell=None):
        # file_tell is accepted for signature compatibility but not stored.
        self.file_read = file_read
        self.file_readline = file_readline
        self.current_frame = None

    def readinto(self, buf):
        """Fill ``buf`` from the current frame (or the file) and return the count."""
        if self.current_frame:
            n = self.current_frame.readinto(buf)
            if n == 0 and len(buf) != 0:
                # Frame exhausted: fall back to reading from the file.
                self.current_frame = None
                n = len(buf)
                buf[:] = self.file_read(n)
                return n
            if n < len(buf):
                raise pickle.UnpicklingError("pickle exhausted before end of frame")
            return n
        else:
            n = len(buf)
            buf[:] = self.file_read(n)
            return n

    def read(self, n):
        """Return ``n`` bytes from the current frame (or the file)."""
        if self.current_frame:
            data = self.current_frame.read(n)
            if not data and n != 0:
                self.current_frame = None
                return self.file_read(n)
            if len(data) < n:
                raise pickle.UnpicklingError("pickle exhausted before end of frame")
            return data
        else:
            return self.file_read(n)

    def readline(self):
        """Return one newline-terminated line from the current frame (or the file)."""
        if self.current_frame:
            data = self.current_frame.readline()
            if not data:
                self.current_frame = None
                return self.file_readline()
            if data[-1] != b"\n"[0]:
                raise pickle.UnpicklingError("pickle exhausted before end of frame")
            return data
        else:
            return self.file_readline()

    def load_frame(self, frame_size):
        """Read the next ``frame_size`` bytes from the file as the active frame.

        :raises pickle.UnpicklingError: if the previous frame was not fully consumed
        """
        if self.current_frame and self.current_frame.read() != b"":
            raise pickle.UnpicklingError(
                "beginning of a new frame before end of current frame"
            )
        self.current_frame = io.BytesIO(self.file_read(frame_size))


# Modules whose globals are always reported.  "nt" and "posix" are the
# platform modules backing ``os`` on Windows / Unix, so ``os.system`` is
# pickled as ``nt.system`` or ``posix.system``.
dangerous_modules = ["os", "subprocess", "builtins", "nt", "posix"]
# Callable names reported regardless of the module they are imported from.
dangerous_names = [
    "system",
    "popen",
    "run",
    "call",
    "check_output",
    "check_call",
]

# Opcodes whose genops argument is the string they push onto the stack.
_STRING_OPCODES = frozenset(
    {
        "STRING",
        "BINSTRING",
        "SHORT_BINSTRING",
        "UNICODE",
        "BINUNICODE",
        "SHORT_BINUNICODE",
        "BINUNICODE8",
    }
)
# Opcodes that store the top of the stack into the memo (stack unchanged).
_MEMO_PUT_OPCODES = frozenset({"PUT", "BINPUT", "LONG_BINPUT", "MEMOIZE"})
# Opcodes that push a previously memoized value.
_MEMO_GET_OPCODES = frozenset({"GET", "BINGET", "LONG_BINGET"})
# Opcodes that never touch the stack (protocol header, framing).
_STACK_NEUTRAL_OPCODES = frozenset({"PROTO", "FRAME"})


class pickleScanner:
    """Statically scan a pickle stream for dangerous imports and REDUCE calls.

    The stream is walked opcode by opcode with :func:`pickletools.genops`,
    so the pickle is never executed.  Imports are taken from GLOBAL, INST and
    STACK_GLOBAL opcodes; STACK_GLOBAL operands are resolved from the
    preceding string pushes / memo lookups, and any that cannot be resolved
    statically are reported as ``("<unresolved>", "<unresolved>")``.

    Results accumulate in ``ReduceCount`` (number of REDUCE opcodes) and
    ``maliciousModule`` (list of ``(module, name)`` tuples); see :meth:`output`.
    """

    # Opcode bytes counted in ReduceCount.  Only b"R" (REDUCE) calls a
    # callable; b"r" is LONG_BINPUT, a harmless memo store.
    unsafe_opcodes = {
        b"R",  # REDUCE - call a callable with an argument tuple
    }

    def __init__(
        self, file, *, fix_imports=True, encoding="ASCII", errors="strict", buffers=None
    ):
        """:param file: binary, seekable file object positioned at the pickle data.

        The keyword arguments mirror ``pickle.Unpickler``'s signature; they
        are stored but do not influence scanning.
        """
        self._buffers = iter(buffers) if buffers is not None else None
        self._file_readline = file.readline
        self._file_read = file.read
        self.memo = {}
        self.encoding = encoding
        self.errors = errors
        self.proto = 0
        self.fix_imports = fix_imports
        self.file = file
        self.ReduceCount = 0
        self.maliciousModule = []

    def find_class(self, module, name):
        """Record ``(module, name)`` if it refers to a dangerous global.

        :param module: module name, as ``str`` or UTF-8 ``bytes``
        :param name: attribute name, as ``str`` or UTF-8 ``bytes``
        """
        # Undecodable bytes must not abort the scan, hence errors="replace".
        if isinstance(module, (bytes, bytearray)):
            module = bytes(module).decode("utf-8", "replace")
        if isinstance(name, (bytes, bytearray)):
            name = bytes(name).decode("utf-8", "replace")
        if module in dangerous_modules or name in dangerous_names:
            self.maliciousModule.append((module, name))

    def load(self):
        """Scan every pickle in the stream, from the current position to EOF.

        Several pickles may be concatenated in one file (``pickle.load`` can
        be called repeatedly on it), so scanning continues after each STOP.

        :raises pickle.UnpicklingError: if the data is not a valid pickle.
            Findings from opcodes before the malformed point are kept in
            ``ReduceCount`` / ``maliciousModule``.
        """
        while not self._at_eof():
            self._scan_one_pickle()

    def _at_eof(self):
        """Return True when no bytes remain in ``self.file`` (position restored otherwise)."""
        if not self.file.read(1):
            return True
        self.file.seek(-1, 1)
        return False

    def _scan_one_pickle(self):
        """Scan a single pickle, up to and including its STOP opcode."""
        memo = {}
        # Statically known values of the last two stack pushes (a str, or
        # None when unknown) -- the operands STACK_GLOBAL would consume.
        recent = collections.deque([None, None], maxlen=2)
        try:
            for opcode, arg, _pos in pickletools.genops(self.file):
                op_name = opcode.name
                if op_name in _STACK_NEUTRAL_OPCODES:
                    continue
                if op_name in _MEMO_PUT_OPCODES:
                    # MEMOIZE stores under len(memo), exactly like Unpickler.
                    key = len(memo) if op_name == "MEMOIZE" else arg
                    memo[key] = recent[-1]
                    continue

                if op_name in ("GLOBAL", "INST"):
                    # genops reports the pair as "module name".
                    module, _, name = arg.partition(" ")
                    self.find_class(module, name)
                elif op_name == "STACK_GLOBAL":
                    module, name = recent
                    if isinstance(module, str) and isinstance(name, str):
                        self.find_class(module, name)
                    else:
                        # The import target was computed in a way we cannot
                        # follow statically; report it rather than miss it.
                        self.maliciousModule.append(("<unresolved>", "<unresolved>"))

                if opcode.code.encode("latin-1") in self.unsafe_opcodes:
                    self.ReduceCount += 1

                if op_name in _STRING_OPCODES:
                    recent.append(arg)
                elif op_name in _MEMO_GET_OPCODES:
                    recent.append(memo.get(arg))
                else:
                    recent.append(None)
        except ValueError as err:
            raise pickle.UnpicklingError(f"malformed pickle data: {err}") from err

    def output(self) -> dict:
        """Return the scan results as ``{"ReduceCount": int, "maliciousModule": list}``."""
        return {
            "ReduceCount": self.ReduceCount,
            "maliciousModule": self.maliciousModule,
        }


def pickleDataDetection(filename: str, output_file=None):
    """Scan a pickle file for dangerous imports and REDUCE calls.

    :param filename: path of the pickle file to scan
    :param output_file: optional path; when given, the report is also
        written there as UTF-8 JSON
    :return: dict with keys ``"ReduceCount"`` and ``"maliciousModule"``
    :raises pickle.UnpicklingError: if the file is not valid pickle data
    """
    with open(filename, "rb") as file:
        scanner = pickleScanner(file)
        scanner.load()
    result = scanner.output()
    if output_file is not None:
        with open(output_file, "w", encoding="utf-8") as out:
            json.dump(result, out, ensure_ascii=False, indent=2)
    return result


if __name__ == "__main__":
    target = sys.argv[1] if len(sys.argv) > 1 else "test.pkl"
    print(json.dumps(pickleDataDetection(target), ensure_ascii=False, indent=2))