feat: add concurrency to regular LLM detection to improve efficiency

This commit is contained in:
tritium0041 2024-06-04 21:47:17 +08:00
parent a2651b499e
commit e9b1e82492
2 changed files with 67 additions and 29 deletions

View File

@ -1,8 +1,11 @@
import json
import os
import threading
import time
from .utils import *
import openai
import signal
# import signal
class TimeoutException(Exception):
@ -22,8 +25,8 @@ def detectGPT(content: str):
raise ValueError("env OPENAI_API_KEY no set")
# Set alarm timer
signal.signal(signal.SIGTERM, timeout_handler)
signal.alarm(10)
# signal.signal(signal.SIGTERM, timeout_handler)
# signal.alarm(10)
client = openai.OpenAI(base_url="https://api.xiaoai.plus/v1", api_key=api_key)
text = content
@ -34,7 +37,9 @@ def detectGPT(content: str):
"role": "system",
"content": "You are a Python code reviewer.Read the code below and identify any potential security vulnerabilities. Classify them by risk level (high, medium, low, none). Only report the line number and the risk level.\nYou should output the result as json format in one line. For example: "
'[{"Line": {the line number}, "Risk": "{choose from (high,medium,low)}","Reason":"{how it is vulnable}"}] Each of these three field is required.\n'
"You are required to only output the json format. Do not output any other information.\n",
"You are required to only output the json format. Do not output any other information.请注意:只对有具体危害的代码片段判定为有风险。\n"
"For examples:\nos.system('ls'),subprocess.call(['ls', '-l']),subprocess.call([\"/bin/sh\",\"-i\"]),eval(code),exec(code) and so on.\n"
"Please IGNORE the risks that dont matter a lot.",
},
{
"role": "user",
@ -55,8 +60,8 @@ def detectGPT(content: str):
except TimeoutException:
raise TimeoutException("The api call timed out")
finally:
signal.alarm(0)
# finally:
# signal.alarm(0)
classified_results = {"high": [], "medium": [], "low": [], "none": []}
for res in res_json:
@ -67,3 +72,33 @@ def detectGPT(content: str):
except IndexError:
pass
return classified_results
def GPTdetectFileList(fileList):
    """Run the LLM-based detector over every file in *fileList* concurrently.

    Each file is read up-front and scanned by ``GPTThread``, which merges its
    findings into a shared results dict keyed by risk level.

    Args:
        fileList: iterable of paths (str or Path) to scan.

    Returns:
        dict mapping "high"/"medium"/"low"/"none" to lists of
        (location-label, offending-line) tuples accumulated by the workers.
    """
    # Local import keeps the module's top-level dependencies unchanged.
    from concurrent.futures import ThreadPoolExecutor

    results = {"high": [], "medium": [], "low": [], "none": []}
    # A bounded pool replaces the previous one-thread-per-file fan-out with a
    # sleep-based stagger: it caps simultaneous API calls and OS threads even
    # for very large file lists, while still overlapping the network waits.
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = [
            executor.submit(GPTThread, str(file), read_file_content(str(file)), results)
            for file in fileList
        ]
        for future in futures:
            # GPTThread handles its own exceptions, but surface anything
            # unexpected (e.g. a failure inside the worker machinery).
            future.result()
    return results
def GPTThread(filename, content, results):
    """Worker: scan one file's content with detectGPT and merge the findings.

    Args:
        filename: path label for the file being scanned (used in the report).
        content: the file's source text.
        results: shared dict of risk level -> list of (location, line) tuples.
            ``list.extend`` is atomic under the GIL, so concurrent workers can
            append without an explicit lock.
    """
    try:
        res = detectGPT(content)
    except Exception as e:
        # Best-effort: one failed API call must not abort the whole scan,
        # but the failure (and which file it hit) should still be visible.
        print(f"{filename}: {e}")
        return
    for key in res:
        if key != "none":  # 'none' findings are not worth reporting
            # Label each finding with the actual file path (previously the
            # literal "(unknown)" was emitted, discarding the filename
            # parameter and diverging from the sequential scan's format).
            results[key].extend(
                (f"{filename}: Line {line_num}", line)
                for line_num, line in res[key]
            )

View File

@ -7,7 +7,7 @@ from reportlab.platypus import Paragraph, Spacer, SimpleDocTemplate
from detection.pickle_detection import pickleDataDetection
from .Regexdetection import find_dangerous_functions
from .GPTdetection import detectGPT
from .GPTdetection import detectGPT,GPTdetectFileList
from .pyc_detection import disassemble_pyc
from .utils import *
import sys
@ -107,6 +107,7 @@ def generate_text_content(results: Dict[str, List[Tuple[int, str]]]) -> str:
text_output += "=" * 30 + "\n\n"
for risk_level, entries in results.items():
# print(risk_level, entries)
if risk_level == "pickles":
text_output += f"Pickles:\n"
for i in entries:
@ -378,29 +379,31 @@ def process_path(
for file_path in Path(path).rglob("*")
if file_path.suffix in SUPPORTED_EXTENSIONS
]
if mode == "llm":
results = GPTdetectFileList(all_files)
else:
# 扫描动画
for file_path in tqdm(all_files, desc="Scanning files", unit="file"):
file_extension = file_path.suffix
if file_extension in [".pkl",".pickle"]:
res = pickleDataDetection(str(file_path), output_file)
results["pickles"].append({
"file": str(file_path),
"result": res
})
continue
file_results = checkModeAndDetect(
mode, str(file_path), file_extension, pycdc_addr
)
if file_results is not None:
for key in file_results:
if key != "none": # Exclude 'none' risk level
results[key].extend(
[
(f"{file_path}: Line {line_num}", line)
for line_num, line in file_results[key]
]
)
for file_path in tqdm(all_files, desc="Scanning files", unit="file"):
file_extension = file_path.suffix
if file_extension in [".pkl",".pickle"]:
res = pickleDataDetection(str(file_path), output_file)
results["pickles"].append({
"file": str(file_path),
"result": res
})
continue
file_results = checkModeAndDetect(
mode, str(file_path), file_extension, pycdc_addr
)
if file_results is not None:
for key in file_results:
if key != "none": # Exclude 'none' risk level
results[key].extend(
[
(f"{file_path}: Line {line_num}", line)
for line_num, line in file_results[key]
]
)
elif os.path.isfile(path):
file_extension = os.path.splitext(path)[1]
if file_extension in [".pkl", ".pickle"]: