From aeccc72b97546ffc079f73714e43b09fca724854 Mon Sep 17 00:00:00 2001 From: sangge <2251250136@qq.com> Date: Mon, 23 Oct 2023 17:00:09 +0800 Subject: [PATCH] fix: fix some errors --- src/client.py | 162 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 99 insertions(+), 63 deletions(-) diff --git a/src/client.py b/src/client.py index 6c05624..29d83c1 100644 --- a/src/client.py +++ b/src/client.py @@ -17,6 +17,7 @@ async def lifespan(app: FastAPI): yield clean_env() + app = FastAPI(lifespan=lifespan) @@ -24,7 +25,10 @@ def init(): global pk, sk, server_address init_db() pk, sk = GenerateKeyPair() + + # load config from config file init_config() + # get_node_list(6, server_address) # type: ignore @@ -99,7 +103,7 @@ async def receive_messages(message: C): """ receive capsule and ip from nodes params: - C: capsule and ct + Tuple: capsule and ct ip: sender ip return: status_code @@ -129,7 +133,7 @@ async def receive_messages(message: C): (C_capsule, C_ct, ip), ) db.commit() - await check_merge(db, C_ct, ip) + await check_merge(C_ct, ip) return HTTPException(status_code=200, detail="Message received") except Exception as e: print(f"Error occurred: {e}") @@ -138,31 +142,33 @@ async def receive_messages(message: C): # check record count -async def check_merge(db, ct: int, ip: str): +async def check_merge(ct: int, ip: str): global sk, pk, node_response, message + with sqlite3.connect("client.db") as db: # Check if the combination of ct_column and ip_column appears more than once. - cursor = db.execute( - """ - SELECT capsule, ct - FROM message - WHERE ct = ? AND senderip = ? - """, - (ct, ip), - ) - # [(capsule, ct), ...] - cfrag_cts = cursor.fetchall() + cursor = db.execute( + """ + SELECT capsule, ct + FROM message + WHERE ct = ? AND senderip = ? + """, + (ct, ip), + ) + # [(capsule, ct), ...] + cfrag_cts = cursor.fetchall() - # get N - cursor = db.execute( - """ - SELECT publickey, threshold - FROM senderinfo - WHERE senderip = ? - """, - (ip), - ) - result = cursor.fetchall() - pk_sender, T = result[0] + # get T + cursor = db.execute( + """ + SELECT publickey, threshold + FROM senderinfo + WHERE senderip = ? + """, + (ip), + ) + result = cursor.fetchall() + pk_sender, T = result[0] # result[0] = (pk, threshold) + if len(cfrag_cts) >= T: cfrags = mergecfrag(cfrag_cts) message = DecryptFrags(sk, pk, pk_sender, cfrags) # type: ignore @@ -170,12 +176,9 @@ async def check_merge(db, ct: int, ip: str): # send message to node -async def send_messages(node_ips: tuple[str, ...], - message: bytes, - dest_ip: str, - pk_B: point, - shreshold: int - ): +async def send_messages( + node_ips: tuple[str, ...], message: bytes, dest_ip: str, pk_B: point, shreshold: int +): global pk, sk id_list = [] for node_ip in node_ips: @@ -184,16 +187,21 @@ async def send_messages(node_ips: tuple[str, ...], for i in range(4): id += int(ip_parts[i]) << (24 - (8 * i)) id_list.append(id) - rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list)) # type: ignore + rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list)) # type: ignore + capsule_ct = Encrypt(pk, message) # type: ignore + for i in range(len(node_ips)): - url = "http://" + node_ips[i] + ":8001" + "/recieve_message" + url = "http://" + node_ips[i] + ":8001" + "/user_src?message" + payload = { "source_ip": local_ip, "dest_ip": dest_ip, - "message": message, - "rk": rk_list[i] + "capsule_ct": capsule_ct, + "rk": rk_list[i], } response = requests.post(url, json=payload) + if response.status_code == 200: + print(f"send to {node_ips[i]} successful") return 0 @@ -203,41 +211,63 @@ class IP_Message(BaseModel): source_ip: str pk: int + +class Request_Message(BaseModel): + dest_ip: str + message_name: str + + # request message from others @app.post("/request_message") -async def request_message(i_m: IP_Message): +async def request_message(i_m: Request_Message): global message, node_response, pk dest_ip = i_m.dest_ip + # dest_ip = dest_ip.split(":")[0] message_name = i_m.message_name source_ip = get_own_ip() dest_port = "8003" - url = "http://" + dest_ip + dest_port + "/recieve_request" - payload = {"dest_ip": dest_ip, - "message_name": message_name, - "source_ip": source_ip, - "pk": pk - } - response = requests.post(url, json=payload) - if response.status_code == 200: - data = response.json() - public_key = int(data["public_key"]) - threshold = int(data["threshold"]) - with sqlite3.connect("client.db") as db: - db.execute( - """ - INSERT INTO senderinfo - (public_key, threshold) - VALUES - (?, ?) - """, - (public_key, threshold), - ) + url = "http://" + dest_ip + ":" + dest_port + "/recieve_request?i_m" + payload = { + "dest_ip": dest_ip, + "message_name": message_name, + "source_ip": source_ip, + "pk": pk, + } + try: + response = requests.post(url, json=payload) - # wait to recieve message from nodes + except: + print("can't post") + return {"message": "can't post"} + + try: + if response.status_code == 200: + data = response.json() + public_key = int(data["public_key"]) + threshold = int(data["threshold"]) + with sqlite3.connect("client.db") as db: + db.execute( + """ + INSERT INTO senderinfo + (public_key, threshold) + VALUES + (?, ?) + """, + (public_key, threshold), + ) + except: + print("Database error") + return {"message": "Database Error"} + + # wait 10s to recieve message from nodes for _ in range(10): if node_response: data = message + + # reset message and node_response message = b"" + node_response = False + # return message to frontend return {"message": data} time.sleep(1) @@ -255,17 +285,24 @@ async def recieve_request(i_m: IP_Message): threshold = random.randrange(1, 6) own_public_key = pk pk_B = i_m.pk - + with sqlite3.connect("client.db") as db: - cursor = db.execute(""" + cursor = db.execute( + """ SELECT nodeip FROM node LIMIT ? - """,(threshold,)) + """, + (threshold,), + ) node_ips = cursor.fetchall() + + # message name message = b"hello world" + random.randbytes(8) - await send_messages(node_ips, message, dest_ip, pk_B, threshold) # type: ignore - response = {"threshold": threshold,"public_key": own_public_key} + + # send message to nodes + await send_messages(node_ips, message, dest_ip, pk_B, threshold) # type: ignore + response = {"threshold": threshold, "public_key": own_public_key} return response @@ -300,7 +337,6 @@ def get_node_list(count: int, server_addr: str): print("Failed:", response.status_code, response.text) - pk = point sk = int server_address = str