feature/GPT #12
| @@ -2,12 +2,29 @@ import json | ||||
| import os | ||||
| from .utils import * | ||||
| import openai | ||||
| import signal | ||||
|  | ||||
|  | ||||
| class TimeoutException(Exception): | ||||
|     """Custom exception to handle timeouts.""" | ||||
|  | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def timeout_handler(signum, frame): | ||||
|     """Handle the SIGALRM signal by raising a TimeoutException.""" | ||||
|     raise TimeoutException | ||||
|  | ||||
|  | ||||
| def detectGPT(content: str): | ||||
|     api_key = os.getenv("OPENAI_API_KEY") | ||||
|     if api_key is None: | ||||
|         raise ValueError("env OPENAI_API_KEY no set") | ||||
|  | ||||
|     # Set alarm timer | ||||
|     signal.signal(signal.SIGTERM, timeout_handler) | ||||
|     signal.alarm(10) | ||||
|  | ||||
|     client = openai.OpenAI(api_key=api_key) | ||||
|     text = content | ||||
|     # client = openai.OpenAI(api_key="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") #测试用key | ||||
| @@ -31,9 +48,16 @@ def detectGPT(content: str): | ||||
|         if message_content is None: | ||||
|             raise ValueError("API response content is None") | ||||
|         res_json = json.loads(message_content) | ||||
|  | ||||
|     except json.JSONDecodeError: | ||||
|         raise ValueError("Error: Could not parse the response. Please try again.") | ||||
|  | ||||
|     except TimeoutException: | ||||
|         raise TimeoutException("The api call timed out") | ||||
|  | ||||
|     finally: | ||||
|         signal.alarm(0) | ||||
|  | ||||
|     classified_results = {"high": [], "medium": [], "low": [], "none": []} | ||||
|     for res in res_json: | ||||
|         classified_results[res["Risk"]].append( | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| import unittest | ||||
| import warnings | ||||
|  | ||||
| from pydantic import NoneStr | ||||
|  | ||||
| from detection.backdoor_detection import find_dangerous_functions | ||||
| from detection.GPTdetection import detectGPT | ||||
| import os | ||||
| @@ -84,6 +86,8 @@ class TestBackdoorDetection(unittest.TestCase): | ||||
|         self.assertEqual(len(results["low"]), 0) | ||||
|  | ||||
|     def test_gpt_env_no_set(self): | ||||
|         if os.getenv("OPENAI_API_KEY") is not None: | ||||
|             self.skipTest("OPENAI_API_KEY is setted") | ||||
|         content = "print('test test')" | ||||
|         with self.assertRaises(ValueError): | ||||
|             detectGPT(content) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user