feat:为llm常规添加并发,提高效率
This commit is contained in:
		| @@ -1,8 +1,11 @@ | |||||||
| import json | import json | ||||||
| import os | import os | ||||||
|  | import threading | ||||||
|  | import time | ||||||
|  |  | ||||||
| from .utils import * | from .utils import * | ||||||
| import openai | import openai | ||||||
| import signal | # import signal | ||||||
|  |  | ||||||
|  |  | ||||||
| class TimeoutException(Exception): | class TimeoutException(Exception): | ||||||
| @@ -22,8 +25,8 @@ def detectGPT(content: str): | |||||||
|         raise ValueError("env OPENAI_API_KEY no set") |         raise ValueError("env OPENAI_API_KEY no set") | ||||||
|  |  | ||||||
|     # Set alarm timer |     # Set alarm timer | ||||||
|     signal.signal(signal.SIGTERM, timeout_handler) |     # signal.signal(signal.SIGTERM, timeout_handler) | ||||||
|     signal.alarm(10) |     # signal.alarm(10) | ||||||
|  |  | ||||||
|     client = openai.OpenAI(base_url="https://api.xiaoai.plus/v1", api_key=api_key) |     client = openai.OpenAI(base_url="https://api.xiaoai.plus/v1", api_key=api_key) | ||||||
|     text = content |     text = content | ||||||
| @@ -34,7 +37,9 @@ def detectGPT(content: str): | |||||||
|                 "role": "system", |                 "role": "system", | ||||||
|                 "content": "You are a Python code reviewer.Read the code below and identify any potential security vulnerabilities. Classify them by risk level (high, medium, low, none). Only report the line number and the risk level.\nYou should output the result as json format in one line. For example: " |                 "content": "You are a Python code reviewer.Read the code below and identify any potential security vulnerabilities. Classify them by risk level (high, medium, low, none). Only report the line number and the risk level.\nYou should output the result as json format in one line. For example: " | ||||||
|                            '[{"Line": {the line number}, "Risk": "{choose from (high,medium,low)}","Reason":"{how it is vulnable}"}] Each of these three field is required.\n' |                            '[{"Line": {the line number}, "Risk": "{choose from (high,medium,low)}","Reason":"{how it is vulnable}"}] Each of these three field is required.\n' | ||||||
|                            "You are required to only output the json format. Do not output any other information.\n", |                            "You are required to only output the json format. Do not output any other information.请注意:只对有具体危害的代码片段判定为有风险。\n" | ||||||
|  |                            "For examples:\nos.system('ls'),subprocess.call(['ls', '-l']),subprocess.call([\"/bin/sh\",\"-i\"]),eval(code),exec(code) and so on.\n" | ||||||
|  |                            "Please IGNORE the risks that dont matter a lot.", | ||||||
|             }, |             }, | ||||||
|             { |             { | ||||||
|                 "role": "user", |                 "role": "user", | ||||||
| @@ -55,8 +60,8 @@ def detectGPT(content: str): | |||||||
|     except TimeoutException: |     except TimeoutException: | ||||||
|         raise TimeoutException("The api call timed out") |         raise TimeoutException("The api call timed out") | ||||||
|  |  | ||||||
|     finally: |     # finally: | ||||||
|         signal.alarm(0) |     #     signal.alarm(0) | ||||||
|  |  | ||||||
|     classified_results = {"high": [], "medium": [], "low": [], "none": []} |     classified_results = {"high": [], "medium": [], "low": [], "none": []} | ||||||
|     for res in res_json: |     for res in res_json: | ||||||
| @@ -67,3 +72,33 @@ def detectGPT(content: str): | |||||||
|         except IndexError: |         except IndexError: | ||||||
|             pass |             pass | ||||||
|     return classified_results |     return classified_results | ||||||
|  |  | ||||||
|  |  | ||||||
def GPTdetectFileList(fileList, max_workers=8):
    """Run the LLM detector over every file in *fileList* concurrently.

    Each file is read with ``read_file_content`` and handed to ``GPTThread``,
    which merges its findings into one shared result dictionary.

    Args:
        fileList: Iterable of file paths (anything ``str()`` turns into a path).
        max_workers: Upper bound on the number of files scanned at the same
            time. Previously one thread was spawned per file with no limit.

    Returns:
        A dict with the keys ``"high"``, ``"medium"``, ``"low"`` and ``"none"``,
        each mapping to the list of findings the workers appended to it.
    """
    # Local import so this function stays self-contained within the module.
    from concurrent.futures import ThreadPoolExecutor

    results = {"high": [], "medium": [], "low": [], "none": []}
    if not fileList:
        return results

    # A bounded pool keeps the number of live threads (and in-flight API
    # requests) under control no matter how many files are scanned.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        for file in fileList:
            filename = str(file)
            content = read_file_content(filename)
            pool.submit(GPTThread, filename, content, results)
            # Stagger submissions slightly, as the original did between
            # thread starts, so requests are not fired in a single burst.
            time.sleep(0.1)
    # Leaving the ``with`` block waits for every submitted task to finish.
    return results
|  |  | ||||||
|  |  | ||||||
def GPTThread(filename, content, results):
    """Worker: scan one file with ``detectGPT`` and merge findings into *results*.

    Args:
        filename: Path of the file being scanned; used to label each finding
            and any error message.
        content: Text of the file, passed to ``detectGPT``.
        results: Shared dict of risk level -> list; findings for every level
            except ``"none"`` are appended to the matching list.

    Any exception is reported on stdout together with the filename and then
    swallowed, so one failing file does not abort the whole scan.
    """
    try:
        res = detectGPT(content)
        # Label every entry before touching ``results`` so a malformed entry
        # cannot leave only part of this file's findings merged.
        labelled = {
            level: [
                (f"{filename}: Line {line_num}", line)
                for line_num, line in entries
            ]
            for level, entries in res.items()
            if level != "none"  # Exclude 'none' risk level
        }
        for level, entries in labelled.items():
            results[level].extend(entries)
    except Exception as e:
        # Include the filename: with many workers running, a bare message
        # gives no hint which file failed.
        print(f"{filename}: {e}")
|   | |||||||
| @@ -7,7 +7,7 @@ from reportlab.platypus import Paragraph, Spacer, SimpleDocTemplate | |||||||
|  |  | ||||||
| from detection.pickle_detection import pickleDataDetection | from detection.pickle_detection import pickleDataDetection | ||||||
| from .Regexdetection import find_dangerous_functions | from .Regexdetection import find_dangerous_functions | ||||||
| from .GPTdetection import detectGPT | from .GPTdetection import detectGPT,GPTdetectFileList | ||||||
| from .pyc_detection import disassemble_pyc | from .pyc_detection import disassemble_pyc | ||||||
| from .utils import * | from .utils import * | ||||||
| import sys | import sys | ||||||
| @@ -107,6 +107,7 @@ def generate_text_content(results: Dict[str, List[Tuple[int, str]]]) -> str: | |||||||
|     text_output += "=" * 30 + "\n\n" |     text_output += "=" * 30 + "\n\n" | ||||||
|  |  | ||||||
|     for risk_level, entries in results.items(): |     for risk_level, entries in results.items(): | ||||||
|  |         # print(risk_level, entries) | ||||||
|         if risk_level == "pickles": |         if risk_level == "pickles": | ||||||
|             text_output += f"Pickles:\n" |             text_output += f"Pickles:\n" | ||||||
|             for i in entries: |             for i in entries: | ||||||
| @@ -378,29 +379,31 @@ def process_path( | |||||||
|             for file_path in Path(path).rglob("*") |             for file_path in Path(path).rglob("*") | ||||||
|             if file_path.suffix in SUPPORTED_EXTENSIONS |             if file_path.suffix in SUPPORTED_EXTENSIONS | ||||||
|         ] |         ] | ||||||
|  |         if mode == "llm": | ||||||
|  |             results = GPTdetectFileList(all_files) | ||||||
|  |         else: | ||||||
|         # 扫描动画 |         # 扫描动画 | ||||||
|         for file_path in tqdm(all_files, desc="Scanning files", unit="file"): |             for file_path in tqdm(all_files, desc="Scanning files", unit="file"): | ||||||
|             file_extension = file_path.suffix |                 file_extension = file_path.suffix | ||||||
|             if file_extension in [".pkl",".pickle"]: |                 if file_extension in [".pkl",".pickle"]: | ||||||
|                 res = pickleDataDetection(str(file_path), output_file) |                     res = pickleDataDetection(str(file_path), output_file) | ||||||
|                 results["pickles"].append({ |                     results["pickles"].append({ | ||||||
|                     "file": str(file_path), |                         "file": str(file_path), | ||||||
|                     "result": res |                         "result": res | ||||||
|                 }) |                     }) | ||||||
|                 continue |                     continue | ||||||
|             file_results = checkModeAndDetect( |                 file_results = checkModeAndDetect( | ||||||
|                 mode, str(file_path), file_extension, pycdc_addr |                     mode, str(file_path), file_extension, pycdc_addr | ||||||
|             ) |                 ) | ||||||
|             if file_results is not None: |                 if file_results is not None: | ||||||
|                 for key in file_results: |                     for key in file_results: | ||||||
|                     if key != "none":  # Exclude 'none' risk level |                         if key != "none":  # Exclude 'none' risk level | ||||||
|                         results[key].extend( |                             results[key].extend( | ||||||
|                             [ |                                 [ | ||||||
|                                 (f"{file_path}: Line {line_num}", line) |                                     (f"{file_path}: Line {line_num}", line) | ||||||
|                                 for line_num, line in file_results[key] |                                     for line_num, line in file_results[key] | ||||||
|                             ] |                                 ] | ||||||
|                         ) |                             ) | ||||||
|     elif os.path.isfile(path): |     elif os.path.isfile(path): | ||||||
|         file_extension = os.path.splitext(path)[1] |         file_extension = os.path.splitext(path)[1] | ||||||
|         if file_extension in [".pkl", ".pickle"]: |         if file_extension in [".pkl", ".pickle"]: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user