feature/cn-gpt #21

Merged
dqy merged 10 commits from feature/cn-gpt into main on 2024-05-26 16:59:24 +08:00
2 changed files with 58 additions and 62 deletions
Showing only changes of commit b1bc566c09 - Show all commits

View File

@@ -1,36 +1,35 @@
import os import os
import requests import requests
import signal
import re import re
import json import json
from typing import List, Dict, Any from typing import List, Dict, Any
class TimeoutException(Exception): class TimeoutException(Exception):
"""Custom exception to handle timeouts.""" """自定义异常用于处理超时情况。"""
pass pass
def timeout_handler(signum, frame): def detectGPT(content: str) -> str:
"""Handle the SIGALRM signal by raising a TimeoutException.""" """
raise TimeoutException 检测给定的代码内容中的潜在安全漏洞。
参数:
- content: 要检测的代码字符串。
# 从环境变量中获取API密钥 返回:
API_KEY = os.getenv('BAIDU_API_KEY') - 分类后的漏洞信息的JSON字符串。
SECRET_KEY = os.getenv('BAIDU_SECRET_KEY') """
api_key = os.getenv("BAIDU_API_KEY")
secret_key = os.getenv("BAIDU_SECRET_KEY")
#api_key = "DUBWNIrB6QJLOsLkpnEz2ZZa"
#secret_key = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7"
if not api_key or not secret_key:
raise ValueError("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
#API_KEY = "DUBWNIrB6QJLOsLkpnEz2ZZa" url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k-0329?access_token=" + get_access_token(
#SECRET_KEY = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7" api_key, secret_key)
def detectGPT(content):
# signal.signal(signal.SIGTERM, timeout_handler)
# signal.alarm(10)
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + get_access_token()
# 注意message必须是奇数条
payload = json.dumps({ payload = json.dumps({
"messages": [ "messages": [
{ {
@@ -50,56 +49,55 @@ def detectGPT(content):
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
res_json = requests.request("POST", url, headers=headers, data=payload).json()
try: try:
message_content = res_json.get('result') # 使用get方法获取result避免KeyError异常 response = requests.post(url, headers=headers, data=payload)
response.raise_for_status()
res_json = response.json()
message_content = res_json.get('result')
if message_content is None: if message_content is None:
raise ValueError("API response content is None") raise ValueError("API response content is None")
except requests.RequestException as e:
raise ValueError(f"Request failed: {str(e)}")
except TimeoutException:
raise TimeoutException("The api call timed out")
except Exception as e:
raise ValueError(f"Error: {str(e)}")
# finally:
# signal.alarm(0)
# 提取数据
extracted_data = extract_json_from_text(message_content) extracted_data = extract_json_from_text(message_content)
# 输出提取的 JSON 数据
classified_results = {"high": [], "medium": [], "low": [], "none": []} classified_results = {"high": [], "medium": [], "low": [], "none": []}
for res in extracted_data: for res in extracted_data:
classified_results[res["Risk"]].append( try:
(res["Line"], content.split("\n")[res["Line"] - 1].strip()) line_number = int(res["Line"])
) classified_results[res["Risk"]].append(
#return classified_results (line_number, content.split("\n")[line_number - 1].strip())
result = json.dumps(classified_results, indent=2, ensure_ascii=False) )
classified_results = json.loads(result) except (ValueError, IndexError, KeyError):
return classified_results continue
# 获得访问令牌 return json.dumps(classified_results, indent=2, ensure_ascii=False)
def get_access_token():
def get_access_token(api_key: str, secret_key: str) -> str:
""" """
使用 AKSK 生成鉴权签名Access Token 使用API密钥和秘密生成访问令牌。
:return: access_token或是None(如果错误)
返回:
- access_token字符串。
""" """
url = "https://aip.baidubce.com/oauth/2.0/token" url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY} params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
return str(requests.post(url, params=params).json().get("access_token")) response = requests.post(url, params=params)
response.raise_for_status()
return response.json().get("access_token")
def extract_json_from_text(text: str) -> List[Dict[str, Any]]: def extract_json_from_text(text: str) -> List[Dict[str, Any]]:
""" """
从文本中提取 JSON 数据。 从文本中提取JSON数据。
参数: 参数:
- text: 包含 JSON 数据的字符串文本。 - text: 包含JSON数据的字符串文本。
返回: 返回:
- 包含提取 JSON 数据的字典列表。 - 包含提取JSON数据的字典列表。
""" """
# 使用正则表达式找到 JSON 部分
json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL) json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL)
if not json_match: if not json_match:
print("未找到 JSON 数据") print("未找到 JSON 数据")
@@ -112,4 +110,4 @@ def extract_json_from_text(text: str) -> List[Dict[str, Any]]:
print(f"解码 JSON 时出错: {e}") print(f"解码 JSON 时出错: {e}")
return [] return []
return data return data

View File

@@ -1,49 +1,47 @@
import unittest import unittest
import warnings import warnings
import os import os
import json
from detection.cngptdetection import detectGPT # 导入调用百度 ai 模型的函数 from detection.cngptdetection import detectGPT
class TestBackdoorDetection(unittest.TestCase): class TestBackdoorDetection(unittest.TestCase):
def test_gpt_risk_detection(self): def test_gpt_risk_detection(self):
""" if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning) warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning)
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
"""
content = """import os content = """import os
os.system('rm -rf /') # high risk os.system('rm -rf /') # high risk
exec('print("Hello")') # high risk exec('print("Hello")') # high risk
eval('2 + 2') # high risk eval('2 + 2') # high risk
""" """
results1 = detectGPT(content) results1 = detectGPT(content)
self.assertEqual(len(results1["high"]), 3) classified_results = json.loads(results1)
self.assertEqual(len(classified_results["high"]), 3)
def test_gpt_no_risk_detection(self): def test_gpt_no_risk_detection(self):
"""
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning) warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning)
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
"""
content = """a = 10 content = """a = 10
b = a + 5 b = a + 5
print('This should not be detected as risky.') print('This should not be detected as risky.')
""" """
results2 = detectGPT(content) results2 = detectGPT(content)
self.assertEqual(len(results2["high"]), 0) classified_results = json.loads(results2)
self.assertEqual(len(results2["medium"]), 0) self.assertEqual(len(classified_results["high"]), 0)
self.assertEqual(len(results2["low"]), 0) self.assertEqual(len(classified_results["medium"]), 0)
self.assertEqual(len(classified_results["low"]), 0)
def test_gpt_env_no_set(self): def test_gpt_env_no_set(self):
"""
if os.getenv("BAIDU_API_KEY") is not None or os.getenv("BAIDU_SECRET_KEY") is not None: if os.getenv("BAIDU_API_KEY") is not None or os.getenv("BAIDU_SECRET_KEY") is not None:
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is set") self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is set")
"""
content = "print('test test')" content = "print('test test')"
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
detectGPT(content) detectGPT(content)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()