update:修改国内gpt调用
This commit is contained in:
		| @@ -1,36 +1,35 @@ | ||||
| import os | ||||
| import requests | ||||
| import signal | ||||
| import re | ||||
| import json | ||||
| from typing import List, Dict, Any | ||||
|  | ||||
|  | ||||
| class TimeoutException(Exception): | ||||
|     """Custom exception to handle timeouts.""" | ||||
|     """自定义异常用于处理超时情况。""" | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def timeout_handler(signum, frame): | ||||
|     """Handle the SIGALRM signal by raising a TimeoutException.""" | ||||
|     raise TimeoutException | ||||
| def detectGPT(content: str) -> str: | ||||
|     """ | ||||
|     检测给定的代码内容中的潜在安全漏洞。 | ||||
|  | ||||
|     参数: | ||||
|     - content: 要检测的代码字符串。 | ||||
|  | ||||
| # 从环境变量中获取API密钥 | ||||
| API_KEY = os.getenv('BAIDU_API_KEY') | ||||
| SECRET_KEY = os.getenv('BAIDU_SECRET_KEY') | ||||
|     返回: | ||||
|     - 分类后的漏洞信息的JSON字符串。 | ||||
|     """ | ||||
|     api_key = os.getenv("BAIDU_API_KEY") | ||||
|     secret_key = os.getenv("BAIDU_SECRET_KEY") | ||||
|     #api_key = "DUBWNIrB6QJLOsLkpnEz2ZZa" | ||||
|     #secret_key = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7" | ||||
|     if not api_key or not secret_key: | ||||
|         raise ValueError("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set") | ||||
|  | ||||
| #API_KEY = "DUBWNIrB6QJLOsLkpnEz2ZZa" | ||||
| #SECRET_KEY = "9WK4HIV2n9r1ePPirqD4EQ6Ea33rH1m7" | ||||
|     url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k-0329?access_token=" + get_access_token( | ||||
|         api_key, secret_key) | ||||
|  | ||||
|  | ||||
| def detectGPT(content): | ||||
|     # signal.signal(signal.SIGTERM, timeout_handler) | ||||
|     # signal.alarm(10) | ||||
|  | ||||
|     url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" + get_access_token() | ||||
|  | ||||
|     # 注意message必须是奇数条 | ||||
|     payload = json.dumps({ | ||||
|         "messages": [ | ||||
|             { | ||||
| @@ -50,56 +49,55 @@ def detectGPT(content): | ||||
|         'Content-Type': 'application/json' | ||||
|     } | ||||
|  | ||||
|     res_json = requests.request("POST", url, headers=headers, data=payload).json() | ||||
|     try: | ||||
|         message_content = res_json.get('result')  # 使用get方法获取result,避免KeyError异常 | ||||
|         response = requests.post(url, headers=headers, data=payload) | ||||
|         response.raise_for_status() | ||||
|         res_json = response.json() | ||||
|         message_content = res_json.get('result') | ||||
|         if message_content is None: | ||||
|             raise ValueError("API response content is None") | ||||
|     except requests.RequestException as e: | ||||
|         raise ValueError(f"Request failed: {str(e)}") | ||||
|  | ||||
|     except TimeoutException: | ||||
|         raise TimeoutException("The api call timed out") | ||||
|  | ||||
|     except Exception as e: | ||||
|         raise ValueError(f"Error: {str(e)}") | ||||
|     # finally: | ||||
|     #   signal.alarm(0) | ||||
|  | ||||
|     # 提取数据 | ||||
|     extracted_data = extract_json_from_text(message_content) | ||||
|  | ||||
|     # 输出提取的 JSON 数据 | ||||
|     classified_results = {"high": [], "medium": [], "low": [], "none": []} | ||||
|     for res in extracted_data: | ||||
|         classified_results[res["Risk"]].append( | ||||
|             (res["Line"], content.split("\n")[res["Line"] - 1].strip()) | ||||
|         ) | ||||
|     #return classified_results | ||||
|     result = json.dumps(classified_results, indent=2, ensure_ascii=False) | ||||
|     classified_results = json.loads(result) | ||||
|     return classified_results | ||||
|         try: | ||||
|             line_number = int(res["Line"]) | ||||
|             classified_results[res["Risk"]].append( | ||||
|                 (line_number, content.split("\n")[line_number - 1].strip()) | ||||
|             ) | ||||
|         except (ValueError, IndexError, KeyError): | ||||
|             continue | ||||
|  | ||||
| # 获得访问令牌 | ||||
| def get_access_token(): | ||||
|     return json.dumps(classified_results, indent=2, ensure_ascii=False) | ||||
|  | ||||
|  | ||||
| def get_access_token(api_key: str, secret_key: str) -> str: | ||||
|     """ | ||||
|     使用 AK,SK 生成鉴权签名(Access Token) | ||||
|     :return: access_token,或是None(如果错误) | ||||
|     使用API密钥和秘密生成访问令牌。 | ||||
|  | ||||
|     返回: | ||||
|     - access_token字符串。 | ||||
|     """ | ||||
|     url = "https://aip.baidubce.com/oauth/2.0/token" | ||||
|     params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY} | ||||
|     return str(requests.post(url, params=params).json().get("access_token")) | ||||
|     params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} | ||||
|     response = requests.post(url, params=params) | ||||
|     response.raise_for_status() | ||||
|     return response.json().get("access_token") | ||||
|  | ||||
|  | ||||
| def extract_json_from_text(text: str) -> List[Dict[str, Any]]: | ||||
|     """ | ||||
|     从文本中提取 JSON 数据。 | ||||
|     从文本中提取JSON数据。 | ||||
|  | ||||
|     参数: | ||||
|     - text: 包含 JSON 数据的字符串文本。 | ||||
|     - text: 包含JSON数据的字符串文本。 | ||||
|  | ||||
|     返回: | ||||
|     - 包含提取 JSON 数据的字典列表。 | ||||
|     - 包含提取JSON数据的字典列表。 | ||||
|     """ | ||||
|     # 使用正则表达式找到 JSON 部分 | ||||
|     json_match = re.search(r'\[\s*{.*?}\s*\]', text, re.DOTALL) | ||||
|     if not json_match: | ||||
|         print("未找到 JSON 数据") | ||||
| @@ -112,4 +110,4 @@ def extract_json_from_text(text: str) -> List[Dict[str, Any]]: | ||||
|         print(f"解码 JSON 时出错: {e}") | ||||
|         return [] | ||||
|  | ||||
|     return data | ||||
|     return data | ||||
		Reference in New Issue
	
	Block a user