diff --git a/detection/GPTdetection.py b/detection/GPTdetection.py index 4d8afee..983e847 100644 --- a/detection/GPTdetection.py +++ b/detection/GPTdetection.py @@ -2,12 +2,29 @@ import json import os from .utils import * import openai +import signal + + +class TimeoutException(Exception): + """Custom exception to handle timeouts.""" + + pass + + +def timeout_handler(signum, frame): + """Handle the SIGALRM signal by raising a TimeoutException.""" + raise TimeoutException def detectGPT(content: str): api_key = os.getenv("OPENAI_API_KEY") if api_key is None: raise ValueError("env OPENAI_API_KEY no set") + + # Set alarm timer + signal.signal(signal.SIGTERM, timeout_handler) + signal.alarm(10) + client = openai.OpenAI(api_key=api_key) text = content # client = openai.OpenAI(api_key="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") #测试用key @@ -31,9 +48,16 @@ def detectGPT(content: str): if message_content is None: raise ValueError("API response content is None") res_json = json.loads(message_content) + except json.JSONDecodeError: raise ValueError("Error: Could not parse the response. Please try again.") + except TimeoutException: + raise TimeoutException("The api call timed out") + + finally: + signal.alarm(0) + classified_results = {"high": [], "medium": [], "low": [], "none": []} for res in res_json: classified_results[res["Risk"]].append( diff --git a/tests/test_backdoor_detection.py b/tests/test_backdoor_detection.py index 6e2fe60..7bbb0d4 100644 --- a/tests/test_backdoor_detection.py +++ b/tests/test_backdoor_detection.py @@ -1,6 +1,8 @@ import unittest import warnings +from pydantic import NoneStr + from detection.backdoor_detection import find_dangerous_functions from detection.GPTdetection import detectGPT import os @@ -84,6 +86,8 @@ class TestBackdoorDetection(unittest.TestCase): self.assertEqual(len(results["low"]), 0) def test_gpt_env_no_set(self): + if os.getenv("OPENAI_API_KEY") is not None: + self.skipTest("OPENAI_API_KEY is setted") content = "print('test test')" with self.assertRaises(ValueError): detectGPT(content)