From b1bc566c09b710c5ab6644043397bf36eb7a550c Mon Sep 17 00:00:00 2001 From: ccyj <2384899431@qq.com> Date: Fri, 24 May 2024 20:27:18 +0800 Subject: [PATCH] =?UTF-8?q?update=EF=BC=9A=E4=BF=AE=E6=94=B9=E5=9B=BD?= =?UTF-8?q?=E5=86=85gpt=E8=B0=83=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- detection/cngptdetection.py | 94 +++++++++++++++++----------------- tests/test_CN_GPT_detection.py | 26 +++++----- 2 files changed, 58 insertions(+), 62 deletions(-) diff --git a/detection/cngptdetection.py b/detection/cngptdetection.py index 4e9e891..20a8a79 100644 --- a/detection/cngptdetection.py +++ b/detection/cngptdetection.py @@ -1,36 +1,35 @@ import os import requests -import signal import re import json from typing import List, Dict, Any class TimeoutException(Exception): - """Custom exception to handle timeouts.""" + """自定义异常用于处理超时情况。""" pass -def timeout_handler(signum, frame): - """Handle the SIGALRM signal by raising a TimeoutException.""" - raise TimeoutException +def detectGPT(content: str) -> str: + """ + 检测给定的代码内容中的潜在安全漏洞。 + 参数: + - content: 要检测的代码字符串。 -# 从环境变量中获取API密钥 -API_KEY = os.getenv('BAIDU_API_KEY') -SECRET_KEY = os.getenv('BAIDU_SECRET_KEY') + 返回: + - 分类后的漏洞信息的JSON字符串。 + """ + api_key = os.getenv("BAIDU_API_KEY") + secret_key = os.getenv("BAIDU_SECRET_KEY") + #api_key = "DUBWNIrB6QJLOsLkpnEz2ZZa" + #secret_key = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7" + if not api_key or not secret_key: + raise ValueError("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") -#API_KEY = "DUBWNIrB6QJLOsLkpnEz2ZZa" -#SECRET_KEY = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7" + url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k-0329?access_token=" + get_access_token( + api_key, secret_key) - -def detectGPT(content): - # signal.signal(signal.SIGTERM, timeout_handler) - # signal.alarm(10) - - url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + get_access_token() - - # 注意message必须是奇数条 payload = json.dumps({ "messages": [ { @@ -50,56 +49,55 @@ def detectGPT(content): 'Content-Type': 'application/json' } - res_json = requests.request("POST", url, headers=headers, data=payload).json() try: - message_content = res_json.get('result') # 使用get方法获取result,避免KeyError异常 + response = requests.post(url, headers=headers, data=payload) + response.raise_for_status() + res_json = response.json() + message_content = res_json.get('result') if message_content is None: raise ValueError("API response content is None") + except requests.RequestException as e: + raise ValueError(f"Request failed: {str(e)}") - except TimeoutException: - raise TimeoutException("The api call timed out") - - except Exception as e: - raise ValueError(f"Error: {str(e)}") - # finally: - # signal.alarm(0) - - # 提取数据 extracted_data = extract_json_from_text(message_content) - # 输出提取的 JSON 数据 classified_results = {"high": [], "medium": [], "low": [], "none": []} for res in extracted_data: - classified_results[res["Risk"]].append( - (res["Line"], content.split("\n")[res["Line"] - 1].strip()) - ) - #return classified_results - result = json.dumps(classified_results, indent=2, ensure_ascii=False) - classified_results = json.loads(result) - return classified_results + try: + line_number = int(res["Line"]) + classified_results[res["Risk"]].append( + (line_number, content.split("\n")[line_number - 1].strip()) + ) + except (ValueError, IndexError, KeyError): + continue -# 获得访问令牌 -def get_access_token(): + return json.dumps(classified_results, indent=2, ensure_ascii=False) + + +def get_access_token(api_key: str, secret_key: str) -> str: """ - 使用 AK,SK 生成鉴权签名(Access Token) - :return: access_token,或是None(如果错误) + 使用API密钥和秘密生成访问令牌。 + + 返回: + - access_token字符串。 """ url = "https://aip.baidubce.com/oauth/2.0/token" - params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY} - return str(requests.post(url, params=params).json().get("access_token")) + params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} + response = requests.post(url, params=params) + response.raise_for_status() + return response.json().get("access_token") def extract_json_from_text(text: str) -> List[Dict[str, Any]]: """ - 从文本中提取 JSON 数据。 + 从文本中提取JSON数据。 参数: - - text: 包含 JSON 数据的字符串文本。 + - text: 包含JSON数据的字符串文本。 返回: - - 包含提取 JSON 数据的字典列表。 + - 包含提取JSON数据的字典列表。 """ - # 使用正则表达式找到 JSON 部分 json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL) if not json_match: print("未找到 JSON 数据") @@ -112,4 +110,4 @@ def extract_json_from_text(text: str) -> List[Dict[str, Any]]: print(f"解码 JSON 时出错: {e}") return [] - return data + return data \ No newline at end of file diff --git a/tests/test_CN_GPT_detection.py b/tests/test_CN_GPT_detection.py index 39e08ac..6f0cdd2 100644 --- a/tests/test_CN_GPT_detection.py +++ b/tests/test_CN_GPT_detection.py @@ -1,49 +1,47 @@ import unittest import warnings import os +import json -from detection.cngptdetection import detectGPT # 导入调用百度 ai 模型的函数 - +from detection.cngptdetection import detectGPT class TestBackdoorDetection(unittest.TestCase): def test_gpt_risk_detection(self): - """ - if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: + if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning) self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") - """ + content = """import os os.system('rm -rf /') # high risk exec('print("Hello")') # high risk eval('2 + 2') # high risk """ results1 = detectGPT(content) - self.assertEqual(len(results1["high"]), 3) + classified_results = json.loads(results1) + self.assertEqual(len(classified_results["high"]), 3) def test_gpt_no_risk_detection(self): - """ if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning) self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") - """ + content = """a = 10 b = a + 5 print('This should not be detected as risky.') """ results2 = detectGPT(content) - self.assertEqual(len(results2["high"]), 0) - self.assertEqual(len(results2["medium"]), 0) - self.assertEqual(len(results2["low"]), 0) + classified_results = json.loads(results2) + self.assertEqual(len(classified_results["high"]), 0) + self.assertEqual(len(classified_results["medium"]), 0) + self.assertEqual(len(classified_results["low"]), 0) def test_gpt_env_no_set(self): - """ if os.getenv("BAIDU_API_KEY") is not None or os.getenv("BAIDU_SECRET_KEY") is not None: self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is set") - """ + content = "print('test test')" with self.assertRaises(ValueError): detectGPT(content) - if __name__ == "__main__": unittest.main()