diff --git a/tests/test_pickle_detection.py b/tests/test_pickle_detection.py new file mode 100644 index 0000000..34521e1 --- /dev/null +++ b/tests/test_pickle_detection.py @@ -0,0 +1,56 @@ +import unittest +import pickle +import tempfile +from detection.pickle_detection import pickleScanner, pickleDataDetection +from unittest.mock import patch + + +class TestPickleScanner(unittest.TestCase): + + def setUp(self): + # Create temporary files with valid and malicious data + self.valid_data = {"key": "value"} + self.malicious_data = b"\x80\x03csubprocess\ncheck_output\nq\x00X\x05\x00\x00\x00echo 1q\x01\x85q\x02Rq\x03." + + self.valid_file = tempfile.NamedTemporaryFile(delete=False) + self.valid_file.write(pickle.dumps(self.valid_data)) + self.valid_file.close() + + self.malicious_file = tempfile.NamedTemporaryFile(delete=False) + self.malicious_file.write(self.malicious_data) + self.malicious_file.close() + + def tearDown(self): + # Clean up temporary files + import os + + os.remove(self.valid_file.name) + os.remove(self.malicious_file.name) + + def test_valid_pickle(self): + with open(self.valid_file.name, "rb") as file: + scanner = pickleScanner(file) + print(scanner.maliciousModule) + scanner.load() + output = scanner.output() + self.assertEqual(output["ReduceCount"], 0) + self.assertEqual(output["maliciousModule"], []) + + def test_malicious_pickle(self): + with open(self.malicious_file.name, "rb") as file: + scanner = pickleScanner(file) + scanner.load() + output = scanner.output() + self.assertEqual(output["ReduceCount"], 1) + self.assertIn(("subprocess", "check_output"), output["maliciousModule"]) + + @patch("builtins.print") + def test_pickleDataDetection_no_output_file(self, mock_print): + # test output to stdout if filename is not given + with patch("builtins.print") as mock_print: + pickleDataDetection(self.valid_file.name) + mock_print.assert_called_once() + + +if __name__ == "__main__": + unittest.main()