diff --git a/.gitea/workflows/ci.yaml b/.gitea/workflows/ci.yaml new file mode 100644 index 0000000..bcd41a8 --- /dev/null +++ b/.gitea/workflows/ci.yaml @@ -0,0 +1,46 @@ +name: Deploy App + +on: + push: + branches: + - main + +jobs: + deploy: + name: Deploy to Web Server + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + + - name: copy file via ssh password + uses: cross-the-world/scp-pipeline@master + with: + host: "110.41.155.96" + user: ${{ secrets.USERNAME }} + pass: ${{ secrets.PASSWORD }} + local: "src/*" + remote: /root/mimajingsai/src/ + + - name: copy file via ssh password + uses: cross-the-world/scp-pipeline@master + with: + host: "110.41.130.197" + user: ${{ secrets.USERNAME }} + pass: ${{ secrets.PASSWORD }} + local: "src/*" + remote: /root/mimajingsai/src/ + + - name: copy file via ssh password + uses: cross-the-world/scp-pipeline@master + with: + host: "110.41.21.35" + user: ${{ secrets.USERNAME }} + pass: ${{ secrets.PASSWORD }} + local: "src/*" + remote: /root/mimajingsai/src/ + + + diff --git a/README_en.md b/README_en.md index a0d9bbd..2ec2728 100644 --- a/README_en.md +++ b/README_en.md @@ -52,6 +52,7 @@ pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple ### Use base image and build yourself ```bash +docker build . -f basedockerfile -t git.mamahaha.work/sangge/tpre:base docker pull git.mamahaha.work/sangge/tpre:base docker build . -t your_image_name docker run your_image_name diff --git a/doc/README_app_en.md b/doc/README_app_en.md index 174da99..d24eab1 100644 --- a/doc/README_app_en.md +++ b/doc/README_app_en.md @@ -4,4 +4,24 @@ /request_node get method -pr +pr + + +docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=110.41.130.197 git.mamahaha.work/sangge/tpre:base bash + + +tpre3: docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=60.204.233.103 git.mamahaha.work/sangge/tpre:base bash + + +110.41.155.96 tpre1 +110.41.130.197 tpre2 +110.41.21.35 tpre3 + +python client_cli.py 110.41.21.35 aaa + + +apt update && apt install docker.io mosh -y + +60.204.236.38 tpre1 +1.94.42.18 tpre2 +60.204.233.103 tpre3 diff --git a/requirements.txt b/requirements.txt index 9551179..8b65c2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ gmssl-python fastapi -uvicorn \ No newline at end of file +uvicorn +requests \ No newline at end of file diff --git a/src/client.ini b/src/client.ini index 0725142..02be9e6 100644 --- a/src/client.ini +++ b/src/client.ini @@ -1,3 +1,3 @@ [settings] -server_address = 10.20.127.226:8000 +server_address = 60.204.236.38:8000 version = 1.0 diff --git a/src/client.py b/src/client.py index 59a91cd..cb3437b 100644 --- a/src/client.py +++ b/src/client.py @@ -9,6 +9,11 @@ from pydantic import BaseModel import socket import random import time +import base64 +import json +import pickle +from fastapi.responses import JSONResponse +import asyncio @asynccontextmanager @@ -28,8 +33,7 @@ def init(): # load config from config file init_config() - - # get_node_list(6, server_address) # type: ignore + get_node_list(2, server_address) # type: ignore def init_db(): @@ -39,7 +43,7 @@ def init_db(): """ CREATE TABLE IF NOT EXISTS message ( id INTEGER PRIMARY KEY, - capsule TEXT, + capsule BLOB, ct TEXT, senderip TEXT ); @@ -62,7 +66,8 @@ def init_db(): CREATE TABLE IF NOT EXISTS senderinfo ( id INTEGER PRIMARY KEY, ip TEXT, - publickey TEXT, + pkx TEXT, + pky TEXT, threshold INTEGER ) """ @@ -84,6 +89,13 @@ def init_config(): # execute on exit def clean_env(): + global message, node_response + message = b"" + node_response = False + with sqlite3.connect("client.db") as db: + db.execute("DELETE FROM node") + db.execute("DELETE FROM message") + db.commit() print("Exit app") @@ -94,9 +106,10 @@ async def read_root(): class C(BaseModel): - Tuple: Tuple[capsule, int] + Tuple: Tuple[Tuple[Tuple[int, int], Tuple[int, int], int, Tuple[int, int]], int] ip: str + # receive messages from nodes @app.post("/receive_messages") async def receive_messages(message: C): @@ -108,32 +121,30 @@ async def receive_messages(message: C): return: status_code """ - C_tuple = message.Tuple - ip = message.ip - if not C_tuple or not ip: + if not message.Tuple or not message.ip: raise HTTPException(status_code=400, detail="Invalid input data") - C_capsule = C_tuple[0] - C_ct = C_tuple[1] + C_capsule, C_ct = message.Tuple + ip = message.ip - if not Checkcapsule(C_capsule): - raise HTTPException(status_code=400, detail="Invalid capsule") + # Serialization + bin_C_capsule = pickle.dumps(C_capsule) # insert record into database - with sqlite3.connect("message.db") as db: + with sqlite3.connect("client.db") as db: try: db.execute( """ INSERT INTO message - (capsule_column, ct_column, ip_column) + (capsule, ct, senderip) VALUES (?, ?, ?) """, - (C_capsule, C_ct, ip), + (bin_C_capsule, str(C_ct), ip), ) db.commit() - await check_merge(C_ct, ip) + check_merge(C_ct, ip) return HTTPException(status_code=200, detail="Message received") except Exception as e: print(f"Error occurred: {e}") @@ -142,37 +153,66 @@ async def receive_messages(message: C): # check record count -async def check_merge(ct: int, ip: str): +def check_merge(ct: int, ip: str): global sk, pk, node_response, message + """ + CREATE TABLE IF NOT EXISTS senderinfo ( + id INTEGER PRIMARY KEY, + ip TEXT, + pkx TEXT, + pky TEXT, + threshold INTEGER + ) + """ with sqlite3.connect("client.db") as db: - # Check if the combination of ct_column and ip_column appears more than once. + # Check if the combination of ct_column and ip_column appears more than once. cursor = db.execute( """ SELECT capsule, ct FROM message WHERE ct = ? AND senderip = ? """, - (ct, ip), + (str(ct), ip), ) # [(capsule, ct), ...] cfrag_cts = cursor.fetchall() - # get T + # get _sender_pk cursor = db.execute( """ - SELECT publickey, threshold + SELECT pkx, pky FROM senderinfo - WHERE senderip = ? + WHERE ip = ? """, - (ip), + (ip,), ) result = cursor.fetchall() - pk_sender, T = result[0] # result[0] = (pk, threshold) - + try: + pkx, pky = result[0] # result[0] = (pkx, pky) + pk_sender = (int(pkx), int(pky)) + except: + pk_sender, T = 0, -1 + + T = 2 if len(cfrag_cts) >= T: - cfrags = mergecfrag(cfrag_cts) - message = DecryptFrags(sk, pk, pk_sender, cfrags) # type: ignore + # Deserialization + temp_cfrag_cts = [] + for i in cfrag_cts: + capsule = pickle.loads(i[0]) + temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(32))) + + cfrags = mergecfrag(temp_cfrag_cts) + + print("sk:", type(sk)) + print("pk:", type(pk)) + print("pk_sender:", type(pk_sender)) + print("cfrags:", type(cfrags)) + message = DecryptFrags(sk, pk, pk_sender, cfrags) + + print("merge success", message) node_response = True + + print("merge:", node_response) # send message to node @@ -181,29 +221,33 @@ async def send_messages( ): global pk, sk id_list = [] + # calculate id of nodes for node_ip in node_ips: + node_ip = node_ip[0] ip_parts = node_ip.split(".") id = 0 for i in range(4): id += int(ip_parts[i]) << (24 - (8 * i)) id_list.append(id) - # generate rk rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list)) # type: ignore - - capsule_ct = Encrypt(pk, message) # type: ignore + + capsule, ct = Encrypt(pk, message) # type: ignore + # capsule_ct = (capsule, int.from_bytes(ct)) for i in range(len(node_ips)): - url = "http://" + node_ips[i] + ":8001" + "/user_src?message" - + url = "http://" + node_ips[i][0] + ":8001" + "/user_src" payload = { "source_ip": local_ip, "dest_ip": dest_ip, - "capsule_ct": capsule_ct, + "capsule": capsule, + "ct": int.from_bytes(ct), "rk": rk_list[i], } + print(json.dumps(payload)) response = requests.post(url, json=payload) + if response.status_code == 200: print(f"send to {node_ips[i]} successful") return 0 @@ -229,8 +273,8 @@ async def request_message(i_m: Request_Message): # dest_ip = dest_ip.split(":")[0] message_name = i_m.message_name source_ip = get_own_ip() - dest_port = "8003" - url = "http://" + dest_ip + ":" + dest_port + "/recieve_request?i_m" + dest_port = "8002" + url = "http://" + dest_ip + ":" + dest_port + "/receive_request" payload = { "dest_ip": dest_ip, "message_name": message_name, @@ -238,33 +282,17 @@ async def request_message(i_m: Request_Message): "pk": pk, } try: - response = requests.post(url, json=payload) + response = requests.post(url, json=payload, timeout=1) + # print("menxian and pk", response.text) - except: + except requests.Timeout: print("can't post") - return {"message": "can't post"} + # content = {"message": "post timeout", "error": str(e)} + # return JSONResponse(content, status_code=400) - try: - if response.status_code == 200: - data = response.json() - public_key = int(data["public_key"]) - threshold = int(data["threshold"]) - with sqlite3.connect("client.db") as db: - db.execute( - """ - INSERT INTO senderinfo - (public_key, threshold) - VALUES - (?, ?) - """, - (public_key, threshold), - ) - except: - print("Database error") - return {"message": "Database Error"} - - # wait 10s to recieve message from nodes + # wait 3s to receive message from nodes for _ in range(10): + print("wait:", node_response) if node_response: data = message @@ -273,20 +301,22 @@ async def request_message(i_m: Request_Message): node_response = False # return message to frontend - return {"message": data} - time.sleep(1) - return {"message": "recieve timeout"} + return {"message": str(data)} + await asyncio.sleep(0.2) + content = {"message": "receive timeout"} + return JSONResponse(content, status_code=400) -# recieve request from others -@app.post("/recieve_request") -async def recieve_request(i_m: IP_Message): +# receive request from others +@app.post("/receive_request") +async def receive_request(i_m: IP_Message): global pk source_ip = get_own_ip() if source_ip != i_m.dest_ip: return HTTPException(status_code=400, detail="Wrong ip") dest_ip = i_m.source_ip - threshold = random.randrange(1, 6) + # threshold = random.randrange(1, 2) + threshold = 2 own_public_key = pk pk_B = i_m.pk @@ -300,26 +330,26 @@ async def recieve_request(i_m: IP_Message): (threshold,), ) node_ips = cursor.fetchall() - + # message name message = b"hello world" + random.randbytes(8) - + # send message to nodes - await send_messages(tuple(node_ips), message, dest_ip, pk_B, threshold) + await send_messages(tuple(node_ips), message, dest_ip, pk_B, threshold) response = {"threshold": threshold, "public_key": own_public_key} + print("###############RESPONSE = ", response) return response def get_own_ip() -> str: - hostname = socket.gethostname() - ip = socket.gethostbyname(hostname) + ip = os.environ.get("HOST_IP", "IP not set") return ip # get node list from central server def get_node_list(count: int, server_addr: str): url = "http://" + server_addr + "/server/send_nodes_list?count=" + str(count) - response = requests.get(url) + response = requests.get(url, timeout=3) # Checking the response if response.status_code == 200: print("Success get node list") @@ -342,8 +372,48 @@ def get_node_list(count: int, server_addr: str): print("Failed:", response.status_code, response.text) -pk = point -sk = int +# send pk to others +@app.get("/get_pk") +async def get_pk(): + global pk, sk + print(sk) + return {"pkx": str(pk[0]), "pky": str(pk[1])} + + +class pk_model(BaseModel): + pkx: str + pky: str + ip: str + + +# recieve pk from frontend +@app.post("/recieve_pk") +async def recieve_pk(pk: pk_model): + pkx = pk.pkx + pky = pk.pky + dest_ip = pk.ip + try: + threshold = 2 + with sqlite3.connect("client.db") as db: + db.execute( + """ + INSERT INTO senderinfo + (ip, pkx, pky, threshold) + VALUES + (?, ?, ?, ?) + """, + (str(dest_ip), pkx, pky, threshold), + ) + except Exception as e: + # raise error + print("Database error") + content = {"message": "Database Error", "error": str(e)} + return JSONResponse(content, status_code=400) + return {"message": "save pk in database"} + + +pk = (0, 0) +sk = 0 server_address = str node_response = False message = bytes @@ -352,4 +422,4 @@ local_ip = get_own_ip() if __name__ == "__main__": import uvicorn # pylint: disable=e0401 - uvicorn.run("client:app", host="0.0.0.0", port=8003, reload=True) + uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True, log_level="debug") diff --git a/src/client_cli.py b/src/client_cli.py index 19fef3d..e87a495 100644 --- a/src/client_cli.py +++ b/src/client_cli.py @@ -1,23 +1,41 @@ import argparse import requests +import json + def send_post_request(ip_addr, message_name): - url = f"http://localhost:20234/request_message/?i_m" - data = { - "dest_ip": ip_addr, - "message_name": message_name - } + url = f"http://localhost:8002/request_message" + data = {"dest_ip": ip_addr, "message_name": message_name} response = requests.post(url, json=data) return response.text + +def get_pk(ip_addr): + url = f"http://" + ip_addr + ":8002/get_pk" + response = requests.get(url, timeout=1) + print(response.text) + json_pk = json.loads(response.text) + payload = {"pkx": json_pk["pkx"], "pky": json_pk["pky"], "ip": ip_addr} + response = requests.post("http://localhost:8002/recieve_pk", json=payload) + + return response.text + + + def main(): parser = argparse.ArgumentParser(description="Send POST request to a specified IP.") parser.add_argument("ip_addr", help="IP address to send request to.") parser.add_argument("message_name", help="Message name to send.") args = parser.parse_args() - response = send_post_request(args.ip_addr, args.message_name) + + response = get_pk(args.ip_addr) print(response) + + response = send_post_request(args.ip_addr, args.message_name) + + print(response) + if __name__ == "__main__": main() diff --git a/src/client_demo.py b/src/client_demo.py new file mode 100644 index 0000000..8ca0a3b --- /dev/null +++ b/src/client_demo.py @@ -0,0 +1,4 @@ +from tpre import * + +# local {"pkx":"110913495319893280527511520027612816833094668640322629943553195742251267532611","pky":"42442813417048462506373786007682778510807282038950736216326706485290996455738"} +# pkb (110913495319893280527511520027612816833094668640322629943553195742251267532611,42442813417048462506373786007682778510807282038950736216326706485290996455738 diff --git a/src/demo.py b/src/demo.py index 7c077b5..0f0646a 100644 --- a/src/demo.py +++ b/src/demo.py @@ -1,51 +1,60 @@ from tpre import * import time -# 1 -start_time = time.time() -pk_a, sk_a = GenerateKeyPair() -m = b"hello world" -end_time = time.time() -elapsed_time = end_time - start_time -print(f"代码块1运行时间:{elapsed_time}秒") +for N in range(4,21,4): + # N = 10 + # T = 5 + T = N // 2 + print(f"当前门限值: N = {N}, T = {T}") + + start_total_time = time.time() + # 1 + start_time = time.time() + pk_a, sk_a = GenerateKeyPair() + m = b"hello world" + end_time = time.time() + elapsed_time = end_time - start_time + print(f"密钥生成运行时间:{elapsed_time}秒") -# 2 -start_time = time.time() -capsule_ct = Encrypt(pk_a, m) -end_time = time.time() -elapsed_time = end_time - start_time -print(f"代码块2运行时间:{elapsed_time}秒") + # 2 + start_time = time.time() + capsule_ct = Encrypt(pk_a, m) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"加密算法运行时间:{elapsed_time}秒") -# 3 -pk_b, sk_b = GenerateKeyPair() + # 3 + pk_b, sk_b = GenerateKeyPair() -N = 10 -T = 5 + + # 5 + start_time = time.time() + id_tuple = tuple(range(N)) + rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"重加密密钥生成算法运行时间:{elapsed_time}秒") -# 5 -start_time = time.time() -id_tuple = tuple(range(N)) -rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) -end_time = time.time() -elapsed_time = end_time - start_time -print(f"代码块5运行时间:{elapsed_time}秒") + # 7 + start_time = time.time() + cfrag_cts = [] -# 7 -start_time = time.time() -cfrag_cts = [] + for rekey in rekeys: + cfrag_ct = ReEncrypt(rekey, capsule_ct) + cfrag_cts.append(cfrag_ct) + end_time = time.time() + elapsed_time = (end_time - start_time) / len(rekeys) + print(f"重加密算法运行时间:{elapsed_time}秒") -for rekey in rekeys: - cfrag_ct = ReEncrypt(rekey, capsule_ct) - cfrag_cts.append(cfrag_ct) -end_time = time.time() -elapsed_time = end_time - start_time -print(f"代码块7运行时间:{elapsed_time}秒") - -# 9 -start_time = time.time() -cfrags = mergecfrag(cfrag_cts) -m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) -end_time = time.time() -elapsed_time = end_time - start_time -print(f"代码块9运行时间:{elapsed_time}秒") -print(m) + # 9 + start_time = time.time() + cfrags = mergecfrag(cfrag_cts) + m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) + end_time = time.time() + elapsed_time = end_time - start_time + end_total_time = time.time() + total_time = end_total_time - start_total_time + print(f"解密算法运行时间:{elapsed_time}秒") + print("成功解密:", m) + print(f"算法总运行时间:{total_time}秒") + print() diff --git a/src/node.py b/src/node.py index 21aad78..0388e57 100644 --- a/src/node.py +++ b/src/node.py @@ -5,6 +5,9 @@ import socket import asyncio from pydantic import BaseModel from tpre import * +import os +from typing import Any, Tuple +import base64 @asynccontextmanager @@ -15,9 +18,9 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) -server_address = "http://10.20.14.232:8000/server" +server_address = "http://60.204.236.38:8000/server" id = 0 -ip = "10.16.21.163" +ip = "" client_ip_src = "" # 发送信息用户的ip client_ip_des = "" # 接收信息用户的ip processed_message = () # 重加密后的数据 @@ -32,54 +35,54 @@ def send_ip(): url = server_address + "/get_node?ip=" + ip # ip = get_local_ip() # type: ignore global id - id = requests.get(url) + id = requests.get(url, timeout=3) -# 用socket获取本机ip +# 用环境变量获取本机ip def get_local_ip(): - # 创建一个套接字对象 - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - # 连接到一个外部的服务器,这将自动绑定到本地IP地址 - s.connect(("8.8.8.8", 80)) - # 获取本地IP地址 - local_ip = s.getsockname()[0] - s.close() global ip - ip = local_ip + ip = os.environ.get("HOST_IP", "IP not set") def init(): - # get_local_ip() - global id + get_local_ip() send_ip() task = asyncio.create_task(send_heartbeat_internal()) print("Finish init") def clear(): - pass + print("exit node") # 接收用户发来的消息,经过处理之后,再将消息发送给其他用户 async def send_heartbeat_internal() -> None: - timeout = 3 + timeout = 30 global ip url = server_address + "/heartbeat?ip=" + ip while True: # print('successful send my_heart') try: - folderol = requests.get(url) + folderol = requests.get(url, timeout=3) except: print("Central server error") - + # 删除超时的节点(假设你有一个异步的数据库操作函数) await asyncio.sleep(timeout) +class Req(BaseModel): + source_ip: str + dest_ip: str + capsule: capsule + ct: int + rk: list + + @app.post("/user_src") # 接收用户1发送的信息 -async def receive_user_src_message(message: Request): +async def user_src(message: Req): global client_ip_src, client_ip_des # kfrag , capsule_ct ,client_ip_src , client_ip_des = json_data[] # 看梁俊勇 """ @@ -90,14 +93,16 @@ async def receive_user_src_message(message: Request): "rk": rk_list[i], } """ + print("node: ", message) + source_ip = message.source_ip + dest_ip = message.dest_ip + capsule = message.capsule + ct = message.ct + capsule_ct = (capsule, ct.to_bytes(32)) + rk = message.rk - data = await message.json() - source_ip = data.get("source_ip") - dest_ip = data.get("dest_ip") - capsule_ct = data.get("capsule_ct") - rk = data.get("rk") - - processed_message = ReEncrypt(rk, capsule_ct) + a, b = ReEncrypt(rk, capsule_ct) + processed_message = (a, int.from_bytes(b)) await send_user_des_message(source_ip, dest_ip, processed_message) return HTTPException(status_code=200, detail="message recieved") @@ -107,12 +112,12 @@ async def send_user_des_message(source_ip: str, dest_ip: str, re_message): # # 发送 HTTP POST 请求 response = requests.post( - "http://" + dest_ip + "/receive_messages?message", json=data + "http://" + dest_ip + ":8002" + "/receive_messages", json=data ) - print(response) + print("send stauts:" ,response.text) if __name__ == "__main__": import uvicorn # pylint: disable=e0401 - uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True) + uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True,log_level="debug") diff --git a/src/server.py b/src/server.py index 18d6214..f0530cf 100644 --- a/src/server.py +++ b/src/server.py @@ -6,34 +6,41 @@ import sqlite3 import asyncio import time + @asynccontextmanager async def lifespan(app: FastAPI): init() yield clean_env() -app = FastAPI(lifespan = lifespan) + +app = FastAPI(lifespan=lifespan) # 连接到数据库(如果数据库不存在,则会自动创建) -conn = sqlite3.connect('server.db') +conn = sqlite3.connect("server.db") # 创建游标对象,用于执行SQL语句 cursor = conn.cursor() # 创建表: id: int; ip: TEXT -cursor.execute('''CREATE TABLE IF NOT EXISTS nodes ( +cursor.execute( + """CREATE TABLE IF NOT EXISTS nodes ( id INTEGER PRIMARY KEY AUTOINCREMENT, ip TEXT NOT NULL, last_heartbeat INTEGER - )''') + )""" +) + def init(): asyncio.create_task(receive_heartbeat_internal()) + def clean_env(): clear_database() # 关闭游标和连接 cursor.close() conn.close() + @app.get("/server/show_nodes") async def show_nodes() -> list: nodes_list = [] @@ -44,37 +51,42 @@ async def show_nodes() -> list: nodes_list.append(row) return nodes_list + @app.get("/server/get_node") async def get_node(ip: str) -> int: - ''' - 中心服务器与节点交互, 节点发送ip, 中心服务器接收ip存入数据库并将ip转换为int作为节点id返回给节点 - params: - ip: node ip - return: - id: ip按点分割成四部分, 每部分转二进制后拼接再转十进制作为节点id - ''' + """ + 中心服务器与节点交互, 节点发送ip, 中心服务器接收ip存入数据库并将ip转换为int作为节点id返回给节点 + params: + ip: node ip + return: + id: ip按点分割成四部分, 每部分转二进制后拼接再转十进制作为节点id + """ ip_parts = ip.split(".") ip_int = 0 for i in range(4): ip_int += int(ip_parts[i]) << (24 - (8 * i)) - + # 获取当前时间 current_time = int(time.time()) # 插入数据 - cursor.execute("INSERT INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)", (ip_int, ip, current_time)) + cursor.execute( + "INSERT INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)", + (ip_int, ip, current_time), + ) conn.commit() return ip_int + @app.get("/server/delete_node") async def delete_node(ip: str) -> None: - ''' + """ param: ip: 待删除节点的ip地址 return: None - ''' + """ # 查询要删除的节点 cursor.execute("SELECT * FROM nodes WHERE ip=?", (ip,)) row = cursor.fetchone() @@ -86,46 +98,52 @@ async def delete_node(ip: str) -> None: else: print(f"Node with IP {ip} not found.") + # 接收节点心跳包 @app.get("/server/heartbeat") async def receive_heartbeat(ip: str): - cursor.execute("UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)) - return {"status": "received"} - + cursor.execute( + "UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip) + ) + return {"status": "received"} + + async def receive_heartbeat_internal(): while 1: - timeout = 7 + timeout = 70 # 删除超时的节点 - # cursor.execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,)) - # conn.commit() + cursor.execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,)) + conn.commit() await asyncio.sleep(timeout) + @app.get("/server/send_nodes_list") async def send_nodes_list(count: int) -> list: - ''' + """ 中心服务器与客户端交互, 客户端发送所需节点个数, 中心服务器从数据库中顺序取出节点封装成list格式返回给客户端 - params: - count: 所需节点个数 - return: + params: + count: 所需节点个数 + return: nodes_list: list - ''' + """ nodes_list = [] # 查询数据库中的节点数据 cursor.execute("SELECT * FROM nodes LIMIT ?", (count,)) rows = cursor.fetchall() - for row in rows: id, ip, last_heartbeat = row nodes_list.append(ip) return nodes_list + # @app.get("/server/clear_database") def clear_database() -> None: cursor.execute("DELETE FROM nodes") conn.commit() + if __name__ == "__main__": import uvicorn # pylint: disable=e0401 diff --git a/src/tpre.py b/src/tpre.py index 3019cd0..3f35f83 100644 --- a/src/tpre.py +++ b/src/tpre.py @@ -272,7 +272,9 @@ def f(x: int, f_modulus: list, T: int) -> int: return res -def GenerateReKey(sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int,...]) -> list: +def GenerateReKey( + sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int, ...] +) -> list: """ param: skA, pkB, N(节点总数), T(阈值) @@ -341,7 +343,7 @@ def Checkcapsule(capsule: capsule) -> bool: # 验证胶囊的有效性 return flag -def ReEncapsulate(kFrag: list, capsule: capsule) -> Tuple[point, point, int, point]: +def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, point]: id, rk, Xa, U1 = kFrag E, V, s = capsule if not Checkcapsule(capsule): @@ -355,7 +357,7 @@ def ReEncapsulate(kFrag: list, capsule: capsule) -> Tuple[point, point, int, poi def ReEncrypt( - kFrag: list, C: Tuple[capsule, bytes] + kFrag: tuple, C: Tuple[capsule, bytes] ) -> Tuple[Tuple[point, point, int, point], bytes]: capsule, enc_Data = C