update:修改国内gpt调用
This commit is contained in:
parent
f0e2251dc0
commit
b1bc566c09
@ -1,36 +1,35 @@
|
|||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
import signal
|
|
||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class TimeoutException(Exception):
|
class TimeoutException(Exception):
|
||||||
"""Custom exception to handle timeouts."""
|
"""自定义异常用于处理超时情况。"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def timeout_handler(signum, frame):
|
def detectGPT(content: str) -> str:
|
||||||
"""Handle the SIGALRM signal by raising a TimeoutException."""
|
"""
|
||||||
raise TimeoutException
|
检测给定的代码内容中的潜在安全漏洞。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- content: 要检测的代码字符串。
|
||||||
|
|
||||||
# 从环境变量中获取API密钥
|
返回:
|
||||||
API_KEY = os.getenv('BAIDU_API_KEY')
|
- 分类后的漏洞信息的JSON字符串。
|
||||||
SECRET_KEY = os.getenv('BAIDU_SECRET_KEY')
|
"""
|
||||||
|
api_key = os.getenv("BAIDU_API_KEY")
|
||||||
|
secret_key = os.getenv("BAIDU_SECRET_KEY")
|
||||||
|
#api_key = "DUBWNIrB6QJLOsLkpnEz2ZZa"
|
||||||
|
#secret_key = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7"
|
||||||
|
if not api_key or not secret_key:
|
||||||
|
raise ValueError("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
|
||||||
|
|
||||||
#API_KEY = "DUBWNIrB6QJLOsLkpnEz2ZZa"
|
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k-0329?access_token=" + get_access_token(
|
||||||
#SECRET_KEY = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7"
|
api_key, secret_key)
|
||||||
|
|
||||||
|
|
||||||
def detectGPT(content):
|
|
||||||
# signal.signal(signal.SIGTERM, timeout_handler)
|
|
||||||
# signal.alarm(10)
|
|
||||||
|
|
||||||
url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + get_access_token()
|
|
||||||
|
|
||||||
# 注意message必须是奇数条
|
|
||||||
payload = json.dumps({
|
payload = json.dumps({
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
@ -50,56 +49,55 @@ def detectGPT(content):
|
|||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
|
||||||
res_json = requests.request("POST", url, headers=headers, data=payload).json()
|
|
||||||
try:
|
try:
|
||||||
message_content = res_json.get('result') # 使用get方法获取result,避免KeyError异常
|
response = requests.post(url, headers=headers, data=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
res_json = response.json()
|
||||||
|
message_content = res_json.get('result')
|
||||||
if message_content is None:
|
if message_content is None:
|
||||||
raise ValueError("API response content is None")
|
raise ValueError("API response content is None")
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise ValueError(f"Request failed: {str(e)}")
|
||||||
|
|
||||||
except TimeoutException:
|
|
||||||
raise TimeoutException("The api call timed out")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise ValueError(f"Error: {str(e)}")
|
|
||||||
# finally:
|
|
||||||
# signal.alarm(0)
|
|
||||||
|
|
||||||
# 提取数据
|
|
||||||
extracted_data = extract_json_from_text(message_content)
|
extracted_data = extract_json_from_text(message_content)
|
||||||
|
|
||||||
# 输出提取的 JSON 数据
|
|
||||||
classified_results = {"high": [], "medium": [], "low": [], "none": []}
|
classified_results = {"high": [], "medium": [], "low": [], "none": []}
|
||||||
for res in extracted_data:
|
for res in extracted_data:
|
||||||
classified_results[res["Risk"]].append(
|
try:
|
||||||
(res["Line"], content.split("\n")[res["Line"] - 1].strip())
|
line_number = int(res["Line"])
|
||||||
)
|
classified_results[res["Risk"]].append(
|
||||||
#return classified_results
|
(line_number, content.split("\n")[line_number - 1].strip())
|
||||||
result = json.dumps(classified_results, indent=2, ensure_ascii=False)
|
)
|
||||||
classified_results = json.loads(result)
|
except (ValueError, IndexError, KeyError):
|
||||||
return classified_results
|
continue
|
||||||
|
|
||||||
# 获得访问令牌
|
return json.dumps(classified_results, indent=2, ensure_ascii=False)
|
||||||
def get_access_token():
|
|
||||||
|
|
||||||
|
def get_access_token(api_key: str, secret_key: str) -> str:
|
||||||
"""
|
"""
|
||||||
使用 AK,SK 生成鉴权签名(Access Token)
|
使用API密钥和秘密生成访问令牌。
|
||||||
:return: access_token,或是None(如果错误)
|
|
||||||
|
返回:
|
||||||
|
- access_token字符串。
|
||||||
"""
|
"""
|
||||||
url = "https://aip.baidubce.com/oauth/2.0/token"
|
url = "https://aip.baidubce.com/oauth/2.0/token"
|
||||||
params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY}
|
params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
|
||||||
return str(requests.post(url, params=params).json().get("access_token"))
|
response = requests.post(url, params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json().get("access_token")
|
||||||
|
|
||||||
|
|
||||||
def extract_json_from_text(text: str) -> List[Dict[str, Any]]:
|
def extract_json_from_text(text: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从文本中提取 JSON 数据。
|
从文本中提取JSON数据。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- text: 包含 JSON 数据的字符串文本。
|
- text: 包含JSON数据的字符串文本。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
- 包含提取 JSON 数据的字典列表。
|
- 包含提取JSON数据的字典列表。
|
||||||
"""
|
"""
|
||||||
# 使用正则表达式找到 JSON 部分
|
|
||||||
json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL)
|
json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL)
|
||||||
if not json_match:
|
if not json_match:
|
||||||
print("未找到 JSON 数据")
|
print("未找到 JSON 数据")
|
||||||
|
@ -1,49 +1,47 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
from detection.cngptdetection import detectGPT # 导入调用百度 ai 模型的函数
|
from detection.cngptdetection import detectGPT
|
||||||
|
|
||||||
|
|
||||||
class TestBackdoorDetection(unittest.TestCase):
|
class TestBackdoorDetection(unittest.TestCase):
|
||||||
def test_gpt_risk_detection(self):
|
def test_gpt_risk_detection(self):
|
||||||
"""
|
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
|
||||||
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
|
|
||||||
warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning)
|
warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning)
|
||||||
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
|
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
|
||||||
"""
|
|
||||||
content = """import os
|
content = """import os
|
||||||
os.system('rm -rf /') # high risk
|
os.system('rm -rf /') # high risk
|
||||||
exec('print("Hello")') # high risk
|
exec('print("Hello")') # high risk
|
||||||
eval('2 + 2') # high risk
|
eval('2 + 2') # high risk
|
||||||
"""
|
"""
|
||||||
results1 = detectGPT(content)
|
results1 = detectGPT(content)
|
||||||
self.assertEqual(len(results1["high"]), 3)
|
classified_results = json.loads(results1)
|
||||||
|
self.assertEqual(len(classified_results["high"]), 3)
|
||||||
|
|
||||||
def test_gpt_no_risk_detection(self):
|
def test_gpt_no_risk_detection(self):
|
||||||
"""
|
|
||||||
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
|
if os.getenv("BAIDU_API_KEY") is None or os.getenv("BAIDU_SECRET_KEY") is None:
|
||||||
warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning)
|
warnings.warn("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set, test skipped.", UserWarning)
|
||||||
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
|
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")
|
||||||
"""
|
|
||||||
content = """a = 10
|
content = """a = 10
|
||||||
b = a + 5
|
b = a + 5
|
||||||
print('This should not be detected as risky.')
|
print('This should not be detected as risky.')
|
||||||
"""
|
"""
|
||||||
results2 = detectGPT(content)
|
results2 = detectGPT(content)
|
||||||
self.assertEqual(len(results2["high"]), 0)
|
classified_results = json.loads(results2)
|
||||||
self.assertEqual(len(results2["medium"]), 0)
|
self.assertEqual(len(classified_results["high"]), 0)
|
||||||
self.assertEqual(len(results2["low"]), 0)
|
self.assertEqual(len(classified_results["medium"]), 0)
|
||||||
|
self.assertEqual(len(classified_results["low"]), 0)
|
||||||
|
|
||||||
def test_gpt_env_no_set(self):
|
def test_gpt_env_no_set(self):
|
||||||
"""
|
|
||||||
if os.getenv("BAIDU_API_KEY") is not None or os.getenv("BAIDU_SECRET_KEY") is not None:
|
if os.getenv("BAIDU_API_KEY") is not None or os.getenv("BAIDU_SECRET_KEY") is not None:
|
||||||
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is set")
|
self.skipTest("BAIDU_API_KEY or BAIDU_SECRET_KEY is set")
|
||||||
"""
|
|
||||||
content = "print('test test')"
|
content = "print('test test')"
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
detectGPT(content)
|
detectGPT(content)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user