feature/cn-gpt #21
| @@ -1,36 +1,35 @@ | |||||||
| import os | import os | ||||||
| import requests | import requests | ||||||
| import signal |  | ||||||
| import re | import re | ||||||
| import json | import json | ||||||
| from typing import List, Dict, Any | from typing import List, Dict, Any | ||||||
|  |  | ||||||
|  |  | ||||||
| class TimeoutException(Exception): | class TimeoutException(Exception): | ||||||
|     """Custom exception to handle timeouts.""" |     """自定义异常用于处理超时情况。""" | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
| def timeout_handler(signum, frame): | def detectGPT(content: str) -> str: | ||||||
|     """Handle the SIGALRM signal by raising a TimeoutException.""" |     """ | ||||||
|     raise TimeoutException |     检测给定的代码内容中的潜在安全漏洞。 | ||||||
|  |  | ||||||
|  |     参数: | ||||||
|  |     - content: 要检测的代码字符串。 | ||||||
|  |  | ||||||
| # 从环境变量中获取API密钥 |     返回: | ||||||
| API_KEY = os.getenv('BAIDU_API_KEY') |     - 分类后的漏洞信息的JSON字符串。 | ||||||
| SECRET_KEY = os.getenv('BAIDU_SECRET_KEY') |     """ | ||||||
|  |     api_key = os.getenv("BAIDU_API_KEY") | ||||||
|  |     secret_key = os.getenv("BAIDU_SECRET_KEY") | ||||||
|  |     #api_key = "DUBWNIrB6QJLOsLkpnEz2ZZa" | ||||||
|  |     #secret_key = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7" | ||||||
|  |     if not api_key or not secret_key: | ||||||
|  |         raise ValueError("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") | ||||||
|  |  | ||||||
| #API_KEY = "DUBWNIrB6QJLOsLkpnEz2ZZa" |     url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k-0329?access_token=" + get_access_token( | ||||||
| #SECRET_KEY = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7" |         api_key, secret_key) | ||||||
|  |  | ||||||
|  |  | ||||||
| def detectGPT(content): |  | ||||||
|     # signal.signal(signal.SIGTERM, timeout_handler) |  | ||||||
|     # signal.alarm(10) |  | ||||||
|  |  | ||||||
|     url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + get_access_token() |  | ||||||
|  |  | ||||||
|     # 注意message必须是奇数条 |  | ||||||
|     payload = json.dumps({ |     payload = json.dumps({ | ||||||
|         "messages": [ |         "messages": [ | ||||||
|             { |             { | ||||||
| @@ -50,56 +49,55 @@ def detectGPT(content): | |||||||
|         'Content-Type': 'application/json' |         'Content-Type': 'application/json' | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     res_json = requests.request("POST", url, headers=headers, data=payload).json() |  | ||||||
|     try: |     try: | ||||||
|         message_content = res_json.get('result')  # 使用get方法获取result,避免KeyError异常 |         response = requests.post(url, headers=headers, data=payload) | ||||||
|  |         response.raise_for_status() | ||||||
|  |         res_json = response.json() | ||||||
|  |         message_content = res_json.get('result') | ||||||
|         if message_content is None: |         if message_content is None: | ||||||
|             raise ValueError("API response content is None") |             raise ValueError("API response content is None") | ||||||
|  |     except requests.RequestException as e: | ||||||
|  |         raise ValueError(f"Request failed: {str(e)}") | ||||||
|  |  | ||||||
|     except TimeoutException: |  | ||||||
|         raise TimeoutException("The api call timed out") |  | ||||||
|  |  | ||||||
|     except Exception as e: |  | ||||||
|         raise ValueError(f"Error: {str(e)}") |  | ||||||
|     # finally: |  | ||||||
|     #   signal.alarm(0) |  | ||||||
|  |  | ||||||
|     # 提取数据 |  | ||||||
|     extracted_data = extract_json_from_text(message_content) |     extracted_data = extract_json_from_text(message_content) | ||||||
|  |  | ||||||
|     # 输出提取的 JSON 数据 |  | ||||||
|     classified_results = {"high": [], "medium": [], "low": [], "none": []} |     classified_results = {"high": [], "medium": [], "low": [], "none": []} | ||||||
|     for res in extracted_data: |     for res in extracted_data: | ||||||
|         classified_results[res["Risk"]].append( |         try: | ||||||
|             (res["Line"], content.split("\n")[res["Line"] - 1].strip()) |             line_number = int(res["Line"]) | ||||||
|         ) |             classified_results[res["Risk"]].append( | ||||||
|     #return classified_results |                 (line_number, content.split("\n")[line_number - 1].strip()) | ||||||
|     result = json.dumps(classified_results, indent=2, ensure_ascii=False) |             ) | ||||||
|     classified_results = json.loads(result) |         except (ValueError, IndexError, KeyError): | ||||||
|     return classified_results |             continue | ||||||
|  |  | ||||||
| # 获得访问令牌 |     return json.dumps(classified_results, indent=2, ensure_ascii=False) | ||||||
| def get_access_token(): |  | ||||||
|  |  | ||||||
|  | def get_access_token(api_key: str, secret_key: str) -> str: | ||||||
|     """ |     """ | ||||||
|     使用 AK,SK 生成鉴权签名(Access Token) |     使用API密钥和秘密生成访问令牌。 | ||||||
|     :return: access_token,或是None(如果错误) |  | ||||||
|  |     返回: | ||||||
|  |     - access_token字符串。 | ||||||
|     """ |     """ | ||||||
|     url = "https://aip.baidubce.com/oauth/2.0/token" |     url = "https://aip.baidubce.com/oauth/2.0/token" | ||||||
|     params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY} |     params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} | ||||||
|     return str(requests.post(url, params=params).json().get("access_token")) |     response = requests.post(url, params=params) | ||||||
|  |     response.raise_for_status() | ||||||
|  |     return response.json().get("access_token") | ||||||
|  |  | ||||||
|  |  | ||||||
| def extract_json_from_text(text: str) -> List[Dict[str, Any]]: | def extract_json_from_text(text: str) -> List[Dict[str, Any]]: | ||||||
|     """ |     """ | ||||||
|     从文本中提取 JSON 数据。 |     从文本中提取JSON数据。 | ||||||
|  |  | ||||||
|     参数: |     参数: | ||||||
|     - text: 包含 JSON 数据的字符串文本。 |     - text: 包含JSON数据的字符串文本。 | ||||||
|  |  | ||||||
|     返回: |     返回: | ||||||
|     - 包含提取 JSON 数据的字典列表。 |     - 包含提取JSON数据的字典列表。 | ||||||
|     """ |     """ | ||||||
|     # 使用正则表达式找到 JSON 部分 |  | ||||||
|     json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL) |     json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL) | ||||||
|     if not json_match: |     if not json_match: | ||||||
|         print("未找到 JSON 数据") |         print("未找到 JSON 数据") | ||||||
| @@ -112,4 +110,4 @@ def extract_json_from_text(text: str) -> List[Dict[str, Any]]: | |||||||
|         print(f"解码 JSON 时出错: {e}") |         print(f"解码 JSON 时出错: {e}") | ||||||
|         return [] |         return [] | ||||||
|  |  | ||||||
|     return data |     return data | ||||||
| @@ -1,49 +1,47 @@ | |||||||
| import unittest | import unittest | ||||||
| import warnings | import warnings | ||||||
| import os | import os | ||||||
|  | import json | ||||||
|  |  | ||||||
| from detection.cngptdetection import detectGPT  # 导入调用百度 ai 模型的函数 | from detection.cngptdetection import detectGPT | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestBackdoorDetection(unittest.TestCase): | class TestBackdoorDetection(unittest.TestCase): | ||||||
|     def test_gpt_risk_detection(self): |     def test_gpt_risk_detection(self): | ||||||
|         """ |         if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: | ||||||
|        if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: |  | ||||||
|             warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning) |             warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning) | ||||||
|             self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") |             self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") | ||||||
|         """ |  | ||||||
|         content = """import os |         content = """import os | ||||||
|         os.system('rm -rf /')   # high risk |         os.system('rm -rf /')   # high risk | ||||||
|         exec('print("Hello")')  # high risk |         exec('print("Hello")')  # high risk | ||||||
|         eval('2 + 2')   # high risk |         eval('2 + 2')   # high risk | ||||||
|         """ |         """ | ||||||
|         results1 = detectGPT(content) |         results1 = detectGPT(content) | ||||||
|         self.assertEqual(len(results1["high"]), 3) |         classified_results = json.loads(results1) | ||||||
|  |         self.assertEqual(len(classified_results["high"]), 3) | ||||||
|  |  | ||||||
|     def test_gpt_no_risk_detection(self): |     def test_gpt_no_risk_detection(self): | ||||||
|         """ |  | ||||||
|         if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: |         if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: | ||||||
|             warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning) |             warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning) | ||||||
|             self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") |             self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") | ||||||
|         """ |  | ||||||
|         content = """a = 10 |         content = """a = 10 | ||||||
|         b = a + 5 |         b = a + 5 | ||||||
|         print('This should not be detected as risky.') |         print('This should not be detected as risky.') | ||||||
|         """ |         """ | ||||||
|         results2 = detectGPT(content) |         results2 = detectGPT(content) | ||||||
|         self.assertEqual(len(results2["high"]), 0) |         classified_results = json.loads(results2) | ||||||
|         self.assertEqual(len(results2["medium"]), 0) |         self.assertEqual(len(classified_results["high"]), 0) | ||||||
|         self.assertEqual(len(results2["low"]), 0) |         self.assertEqual(len(classified_results["medium"]), 0) | ||||||
|  |         self.assertEqual(len(classified_results["low"]), 0) | ||||||
|  |  | ||||||
|     def test_gpt_env_no_set(self): |     def test_gpt_env_no_set(self): | ||||||
|         """ |  | ||||||
|         if os.getenv("BAIDU_API_KEY") is not None or os.getenv("BAIDU_SECRET_KEY") is not None: |         if os.getenv("BAIDU_API_KEY") is not None or os.getenv("BAIDU_SECRET_KEY") is not None: | ||||||
|             self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is set") |             self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is set") | ||||||
|         """ |  | ||||||
|         content = "print('test test')" |         content = "print('test test')" | ||||||
|         with self.assertRaises(ValueError): |         with self.assertRaises(ValueError): | ||||||
|             detectGPT(content) |             detectGPT(content) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user