feature/国内GPT-文心一言

This commit is contained in:
2024-05-16 21:15:22 +08:00
parent 9d6f054478
commit dd45c467a3
2 changed files with 110 additions and 79 deletions

View File

@@ -1,9 +1,10 @@
import json import os
import requests import requests
import signal import signal
from typing import Dict, List, Tuple # 用于类型提示的模块,使用了 Dict, List, Tuple 进行类型注解。 import re
import json
from typing import List, Dict, Any
# 参考文档https://blog.csdn.net/weixin_73654895/article/details/133799269
class TimeoutException(Exception): class TimeoutException(Exception):
"""Custom exception to handle timeouts.""" """Custom exception to handle timeouts."""
@@ -15,83 +16,99 @@ def timeout_handler(signum, frame):
raise TimeoutException raise TimeoutException
def get_baidu_access_token(api_key: str, secret_key: str) -> str: # 从环境变量中获取API密钥
""" API_KEY = os.getenv('BAIDU_API_KEY')
Retrieve the access token from Baidu API using API key and Secret key. SECRET_KEY = os.getenv('BAIDU_SECRET_KEY')
Args: #API_KEY = "DUBWNIrB6QJLOsLkpnEz2ZZa"
api_key (str): The API key for Baidu API. #SECRET_KEY = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7"
secret_key (str): The Secret key for Baidu API.
Returns:
str: The access token.
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
response = requests.post(url, params=params)
response_data = response.json()
if 'access_token' not in response_data:
raise ValueError("Error: Could not retrieve access token.")
return str(response_data["access_token"])
def cndetectGPT(content: str) -> Dict[str, List[Tuple[int, str]]]: def detectGPT(content):
""" # signal.signal(signal.SIGTERM, timeout_handler)
Detect potential security vulnerabilities in the provided code content using Baidu's AI model. # signal.alarm(10)
Args: url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + get_access_token()
content (str): The code content to be analyzed.
Returns:
Dict[str, List[Tuple[int, str]]]: Classified results of detected vulnerabilities.
"""
API_KEY = "DUBWNIrB6QJLOsLkpnEz2ZZa"
SECRET_KEY = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7"
# Set alarm timer
signal.signal(signal.SIGTERM, timeout_handler)
signal.alarm(10)
try:
access_token = get_baidu_access_token(API_KEY, SECRET_KEY)
url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={access_token}"
# 注意message必须是奇数条
payload = json.dumps({ payload = json.dumps({
"messages": [ "messages": [
{
"role": "system",
"content": "You are a Python code reviewer. Read the code below and identify any potential security vulnerabilities. Classify them by risk level (high, medium, low, none). Only report the line number and the risk level.\nYou should output the result as json format in one line. For example: "
'[{"Line": {the line number}, "Risk": "{choose from (high,medium,low)}","Reason":"{how it is vulnerable}"}] Each of these three fields is required.\n'
"You are required to only output the json format. Do not output any other information.\n"
},
{ {
"role": "user", "role": "user",
"content": content "content": (
"You are a Python code reviewer. Read the code below and identify any potential "
"security vulnerabilities. Classify them by risk level (high, medium, low, none). "
'Only report the line number and the risk level.\nYou should output the result as '
'json format in one line. For example: [{"Line": {the line number}, "Risk": "{choose from (high,medium,low)}","Reason":"{how it is vulnerable}"}] '
"Each of these three fields is required.\nYou are required to only output the json format. "
"Do not output any other information." + content
)
} }
] ]
}) })
headers = { headers = {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
response = requests.post(url, headers=headers, data=payload) res_json = requests.request("POST", url, headers=headers, data=payload).json()
response_data = response.json() try:
message_content = response_data.get('result', None) message_content = res_json.get('result') # 使用get方法获取result避免KeyError异常
if message_content is None: if message_content is None:
raise ValueError("API response content is None") raise ValueError("API response content is None")
res_json = json.loads(message_content)
except json.JSONDecodeError:
raise ValueError("Error: Could not parse the response. Please try again.")
except TimeoutException: except TimeoutException:
raise TimeoutException("The API call timed out") raise TimeoutException("The api call timed out")
finally:
signal.alarm(0)
except Exception as e:
raise ValueError(f"Error: {str(e)}")
# finally:
# signal.alarm(0)
# 提取数据
extracted_data = extract_json_from_text(message_content)
# 输出提取的 JSON 数据
classified_results = {"high": [], "medium": [], "low": [], "none": []} classified_results = {"high": [], "medium": [], "low": [], "none": []}
for res in res_json: for res in extracted_data:
classified_results[res["Risk"]].append( classified_results[res["Risk"]].append(
(res["Line"], content.split("\n")[res["Line"] - 1].strip()) (res["Line"], content.split("\n")[res["Line"] - 1].strip())
) )
return classified_results #return classified_results
result = json.dumps(classified_results, indent=2, ensure_ascii=False)
return result
# 获得访问令牌
def get_access_token():
"""
使用 AKSK 生成鉴权签名Access Token
:return: access_token或是None(如果错误)
"""
url = "https://aip.baidubce.com/oauth/2.0/token"
params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY}
return str(requests.post(url, params=params).json().get("access_token"))
def extract_json_from_text(text: str) -> List[Dict[str, Any]]:
"""
从文本中提取 JSON 数据。
参数:
- text: 包含 JSON 数据的字符串文本。
返回:
- 包含提取 JSON 数据的字典列表。
"""
# 使用正则表达式找到 JSON 部分
json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL)
if not json_match:
print("未找到 JSON 数据")
return []
json_string = json_match.group(0)
try:
data = json.loads(json_string)
except json.JSONDecodeError as e:
print(f"解码 JSON 时出错: {e}")
return []
return data

View File

@@ -1,34 +1,48 @@
import unittest import unittest
import warnings import warnings
import os import os
from detection.cngptdetection import cndetectGPT
from detection.cngptdetection import detectGPT # 导入调用百度 ai 模型的函数
class TestBackdoorDetection(unittest.TestCase): class TestBackdoorDetection(unittest.TestCase):
def test_gpt_risk_detection(self): def test_gpt_risk_detection(self):
"""
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning)
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
"""
content = """import os content = """import os
os.system('rm -rf /') # high risk os.system('rm -rf /') # high risk
exec('print("Hello")') # high risk exec('print("Hello")') # high risk
eval('2 + 2') # high risk eval('2 + 2') # high risk
""" """
results = cndetectGPT(content) results1 = detectGPT(content)
self.assertEqual(len(results["high"]), 3) self.assertEqual(len(results1["high"]), 3)
def test_gpt_no_risk_detection(self): def test_gpt_no_risk_detection(self):
"""
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning)
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
"""
content = """a = 10 content = """a = 10
b = a + 5 b = a + 5
print('This should not be detected as risky.') print('This should not be detected as risky.')
""" """
results = cndetectGPT(content) results2 = detectGPT(content)
self.assertEqual(len(results["high"]), 0) self.assertEqual(len(results2["high"]), 0)
self.assertEqual(len(results["medium"]), 0) self.assertEqual(len(results2["medium"]), 0)
self.assertEqual(len(results["low"]), 0) self.assertEqual(len(results2["low"]), 0)
def test_gpt_env_no_set(self): def test_gpt_env_no_set(self):
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None: """
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") if os.getenv("BAIDU_API_KEY") is not None or os.getenv("BAIDU_SECRET_KEY") is not None:
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is set")
"""
content = "print('test test')" content = "print('test test')"
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
cndetectGPT(content) detectGPT(content)
if __name__ == "__main__": if __name__ == "__main__":