150 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			150 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import threading
 | |
| import time
 | |
| 
 | |
| import requests
 | |
| import re
 | |
| import json
 | |
| from typing import List, Dict, Any
 | |
| 
 | |
| from detection.utils import read_file_content
 | |
| 
 | |
| 
 | |
class TimeoutException(Exception):
    """Custom exception used to signal that an operation timed out."""
| 
 | |
| 
 | |
def detectGPT(content: str, token: str, timeout: float = 120.0) -> Dict[str, List[Any]]:
    """
    Ask the ERNIE chat endpoint to review ``content`` for security vulnerabilities.

    Args:
        content: Source code to analyse, as a single string.
        token: Access token sent as the ``access_token`` query parameter.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A dict with the keys ``"high"``, ``"medium"``, ``"low"`` and ``"none"``.
        Each value is a list of ``(line_number, stripped_source_line)`` tuples
        for the findings the model reported at that risk level.

    Raises:
        ValueError: If the HTTP request fails (including timeouts) or the
            response carries no ``result`` field.
    """
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k-0329"

    # NOTE(review): the code is appended directly after the last instruction
    # sentence with no separator, so the first code line shares a prompt line
    # with it. Left unchanged so the request body stays identical.
    payload = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": (
                    "You are a Python code reviewer. Read the code below and identify any potential "
                    "security vulnerabilities. Classify them by risk level (high, medium, low, none). "
                    'Only report the line number and the risk level.\nYou should output the result as '
                    'json format in one line. For example: [{"Line": {the line number}, "Risk": "{choose from (high,medium,low)}","Reason":"{how it is vulnerable}"}] '
                    "Each of these three fields is required.\nYou are required to only output the json format. "
                    "Do not output any other information." + content
                ),
            }
        ]
    })
    headers = {
        'Content-Type': 'application/json'
    }

    try:
        # Pass the token through ``params`` so it is URL-encoded, and always
        # set a timeout so a stalled connection cannot hang the caller.
        response = requests.post(
            url,
            params={"access_token": token},
            headers=headers,
            data=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        res_json = response.json()
    except requests.RequestException as e:
        raise ValueError(f"Request failed: {str(e)}") from e

    message_content = res_json.get('result') if isinstance(res_json, dict) else None
    if message_content is None:
        # Presumably the service reports failures in an "error_msg" field;
        # include it when present so the cause is visible to the caller.
        detail = res_json.get('error_msg') if isinstance(res_json, dict) else None
        if detail:
            raise ValueError(f"API response content is None: {detail}")
        raise ValueError("API response content is None")

    extracted_data = extract_json_from_text(message_content)

    # Split once instead of once per finding.
    source_lines = content.split("\n")

    classified_results = {"high": [], "medium": [], "low": [], "none": []}
    for res in extracted_data:
        try:
            line_number = int(res["Line"])
            risk = str(res["Risk"]).strip().lower()
            bucket = classified_results[risk]
        except (ValueError, TypeError, KeyError):
            # Malformed finding: not a dict, missing field, non-numeric line
            # number, or unknown risk label. Skip it, keep the rest.
            continue

        # Reject 0 and negative numbers explicitly; they would otherwise index
        # from the end of the file and report the wrong line.
        if not 1 <= line_number <= len(source_lines):
            continue

        bucket.append((line_number, source_lines[line_number - 1].strip()))

    return classified_results
| 
 | |
| 
 | |
def get_access_token(api_key: str, secret_key: str, timeout: float = 30.0) -> str:
    """
    Exchange an API key and secret key for an access token.

    Args:
        api_key: Sent as the ``client_id`` request parameter.
        secret_key: Sent as the ``client_secret`` request parameter.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The ``access_token`` string from the response.

    Raises:
        requests.RequestException: If the request fails or returns an HTTP
            error status.
        ValueError: If the response does not contain an access token.
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key}
    response = requests.post(url, params=params, timeout=timeout)
    response.raise_for_status()

    body = response.json()
    access_token = body.get("access_token") if isinstance(body, dict) else None
    if not access_token:
        # Fail here rather than returning None, which would only surface later
        # as a confusing error when the token is used.
        # Presumably the server describes the failure in "error_description"
        # or "error" -- verify against the API documentation.
        detail = None
        if isinstance(body, dict):
            detail = body.get("error_description") or body.get("error")
        if detail:
            raise ValueError(f"Access token missing from response: {detail}")
        raise ValueError("Access token missing from response")
    return access_token
| 
 | |
| 
 | |
def extract_json_from_text(text: str) -> List[Dict[str, Any]]:
    """
    Extract the first JSON array of objects embedded in free-form text.

    The text is scanned for every position where an array of objects could
    start (``[`` followed by ``{``). A real JSON parser then decodes from that
    position, so nested arrays and ``}]`` sequences inside string values do
    not truncate the payload. If a candidate fails to decode, the next one is
    tried.

    Args:
        text: Text that may contain a JSON array, possibly surrounded by
            other content.

    Returns:
        The dict elements of the first array that decodes successfully, or an
        empty list when no array is found or none can be decoded. Elements
        that are not JSON objects are dropped.
    """
    if not isinstance(text, str):
        print("未找到 JSON 数据")
        return []

    decoder = json.JSONDecoder()
    last_error = None
    for candidate in re.finditer(r'\[\s*{', text):
        try:
            # raw_decode parses one complete JSON value starting at the given
            # index and ignores whatever follows it.
            data, _ = decoder.raw_decode(text, candidate.start())
        except json.JSONDecodeError as e:
            last_error = e
            continue
        # The value starts with "[" so it is a list; keep only the objects.
        return [item for item in data if isinstance(item, dict)]

    if last_error is None:
        print("未找到 JSON 数据")
    else:
        print(f"解码 JSON 时出错: {last_error}")
    return []
| 
 | |
| 
 | |
def GPTdetectFileList(fileList):
    """
    Run the GPT-based vulnerability detection over a list of files.

    Credentials are read from the ``BAIDU_API_KEY`` and ``BAIDU_SECRET_KEY``
    environment variables. One access token is fetched and shared by all
    files. Each file is analysed in its own thread, and thread starts are
    staggered by 0.5 seconds.

    Args:
        fileList: Iterable of file paths. Each entry is converted with
            ``str()`` before being read.

    Returns:
        A dict with the keys ``"high"``, ``"medium"``, ``"low"`` and
        ``"none"``, filled in by the worker threads.

    Raises:
        ValueError: If either credential environment variable is unset or
            empty.
    """
    # Credentials must come from the environment only. Never commit real keys
    # to source control, not even in commented-out code.
    api_key = os.getenv("BAIDU_API_KEY")
    secret_key = os.getenv("BAIDU_SECRET_KEY")
    if not api_key or not secret_key:
        raise ValueError("BAIDU_API_KEY or BAIDU_SECRET_KEY is not set")

    results = {"high": [], "medium": [], "low": [], "none": []}
    files = list(fileList)
    if not files:
        # Nothing to analyse, so skip the token request entirely.
        return results

    token = get_access_token(api_key, secret_key)

    # Read every file before starting any thread, so a read failure aborts the
    # run before any API request is sent.
    threads = []
    for file in files:
        content = read_file_content(str(file))
        threads.append(
            threading.Thread(target=GPTThread, args=(str(file), content, results, token))
        )

    for index, thread in enumerate(threads):
        if index:
            # Stagger the starts; no delay is needed after the last thread.
            time.sleep(0.5)
        thread.start()

    for thread in threads:
        thread.join()
    return results
| 
 | |
| 
 | |
def GPTThread(filename, content, results, token):
    """
    Analyse one file with ``detectGPT`` and merge its findings into ``results``.

    Args:
        filename: Name used to label each finding.
        content: Source code of the file.
        results: Dict of lists keyed by risk level; extended in place.
        token: Access token passed through to ``detectGPT``.

    Findings under the ``"none"`` risk level are not merged. Exceptions raised
    by ``detectGPT`` are not caught here and propagate to the caller.
    """
    res = detectGPT(content, token)
    for risk, findings in res.items():
        if risk == "none":  # Exclude 'none' risk level
            continue
        results[risk].extend(
            (f"{filename}: Line {line_num}", line)
            for line_num, line in findings
        )