diff --git a/src/client.py b/src/client.py index 4c27c44..ea0ad94 100644 --- a/src/client.py +++ b/src/client.py @@ -1,4 +1,3 @@ -<<<<<<< HEAD from fastapi import FastAPI, HTTPException import requests import os @@ -81,21 +80,11 @@ def init_db(): def init_config(): import configparser - print("Starting function: init_config") - global server_address config = configparser.ConfigParser() - print("Attempting to read client.ini...") config.read("client.ini") - if "settings" in config and "server_address" in config["settings"]: - server_address = config["settings"]["server_address"] - print(f"Config loaded successfully. Server address: {server_address}") - else: - print("Error: 'settings' section or 'server_address' key not found in client.ini") - - print("Function init_config executed successfully!") - + server_address = config["settings"]["server_address"] # execute on exit @@ -106,6 +95,7 @@ def clean_env(): with sqlite3.connect("client.db") as db: db.execute("DELETE FROM node") db.execute("DELETE FROM message") + db.execute("DELETE FROM senderinfo") db.commit() print("Exit app") @@ -132,26 +122,22 @@ async def receive_messages(message: C): return: status_code """ - - print("Starting function: receive_messages") - + print(f"Received message: {message}") + if not message.Tuple or not message.ip: - print("Invalid input data") + print("Invalid input data received.") raise HTTPException(status_code=400, detail="Invalid input data") C_capsule, C_ct = message.Tuple ip = message.ip - print(f"Received message: Capsule = {C_capsule}, C_ct = {C_ct}, IP = {ip}") # Serialization - print("Serializing the capsule...") bin_C_capsule = pickle.dumps(C_capsule) - print("Serialization successful") + # insert record into database with sqlite3.connect("client.db") as db: try: - print("Attempting to insert data into 'message' table...") db.execute( """ INSERT INTO message @@ -162,30 +148,29 @@ async def receive_messages(message: C): (bin_C_capsule, str(C_ct), ip), ) db.commit() - print("Data insertion successful") - + print("Data inserted successfully into database.") check_merge(C_ct, ip) - print("check_merge executed successfully") - return HTTPException(status_code=200, detail="Message received") except Exception as e: print(f"Error occurred: {e}") db.rollback() return HTTPException(status_code=400, detail="Database error") - print("Function receive_messages executed successfully!") - - # check record count def check_merge(ct: int, ip: str): - print("Starting function: check_merge") - global sk, pk, node_response, message - + """ + CREATE TABLE IF NOT EXISTS senderinfo ( + id INTEGER PRIMARY KEY, + ip TEXT, + pkx TEXT, + pky TEXT, + threshold INTEGER + ) + """ with sqlite3.connect("client.db") as db: # Check if the combination of ct_column and ip_column appears more than once. - print("Fetching data from 'message' table...") cursor = db.execute( """ SELECT capsule, ct @@ -194,11 +179,10 @@ def check_merge(ct: int, ip: str): """, (str(ct), ip), ) + # [(capsule, ct), ...] cfrag_cts = cursor.fetchall() - print(f"Number of records fetched from 'message' table: {len(cfrag_cts)}") # get _sender_pk - print("Fetching sender's public key...") cursor = db.execute( """ SELECT pkx, pky @@ -211,10 +195,8 @@ def check_merge(ct: int, ip: str): try: pkx, pky = result[0] # result[0] = (pkx, pky) pk_sender = (int(pkx), int(pky)) - print(f"Successfully fetched sender's public key: {pk_sender}") except: pk_sender, T = 0, -1 - print("Failed to fetch sender's public key") T = 2 if len(cfrag_cts) >= T: @@ -223,38 +205,29 @@ def check_merge(ct: int, ip: str): for i in cfrag_cts: capsule = pickle.loads(i[0]) temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(32))) - print("Deserialization completed") cfrags = mergecfrag(temp_cfrag_cts) - print("Attempting decryption...") print("sk:", type(sk)) print("pk:", type(pk)) print("pk_sender:", type(pk_sender)) print("cfrags:", type(cfrags)) message = DecryptFrags(sk, pk, pk_sender, cfrags) - print(f"Decryption successful, message: {message}") + print("merge success", message) node_response = True - print(f"Node response set to: {node_response}") - else: - print("Insufficient number of cfrag_cts, skipping decryption") - - print("Function check_merge executed successfully!") + print("merge:", node_response) # send message to node async def send_messages( node_ips: tuple[str, ...], message: bytes, dest_ip: str, pk_B: point, shreshold: int ): - print("Starting function: send_messages") - global pk, sk id_list = [] # calculate id of nodes - print("Calculating ID of nodes...") for node_ip in node_ips: node_ip = node_ip[0] ip_parts = node_ip.split(".") @@ -262,12 +235,8 @@ async def send_messages( for i in range(4): id += int(ip_parts[i]) << (24 - (8 * i)) id_list.append(id) - print(f"Calculated IDs: {id_list}") - # generate rk - print("Generating rekey...") rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list)) # type: ignore - print(f"Generated ReKey: {rk_list}") capsule, ct = Encrypt(pk, message) # type: ignore # capsule_ct = (capsule, int.from_bytes(ct)) @@ -281,20 +250,14 @@ async def send_messages( "ct": int.from_bytes(ct), "rk": rk_list[i], } - print(f"Sending payload to {url}:") print(json.dumps(payload)) response = requests.post(url, json=payload) if response.status_code == 200: print(f"send to {node_ips[i]} successful") - else: - print(f"send to {node_ips[i]} failed with status code {response.status_code}") - - print("Function send_messages executed successfully!") return 0 - class IP_Message(BaseModel): dest_ip: str message_name: str @@ -310,8 +273,6 @@ class Request_Message(BaseModel): # request message from others @app.post("/request_message") async def request_message(i_m: Request_Message): - print("Starting function: request_message") - global message, node_response, pk dest_ip = i_m.dest_ip # dest_ip = dest_ip.split(":")[0] @@ -325,88 +286,64 @@ async def request_message(i_m: Request_Message): "source_ip": source_ip, "pk": pk, } - - print(f"Requesting message from: {url}") try: response = requests.post(url, json=payload, timeout=1) - print(f"Response received from {url}: {response.text}") + # print("menxian and pk", response.text) - except requests.Timeout as e: - print(f"Request to {url} timed out!") + except requests.Timeout: print("can't post") + # content = {"message": "post timeout", "error": str(e)} + # return JSONResponse(content, status_code=400) - # wait 2s to receive message from nodes + # wait 3s to receive message from nodes for _ in range(10): - print(f"Waiting for response... (iteration {_ + 1})") - print("Current node_response:", node_response) + print("wait:", node_response) if node_response: data = message - + # reset message and node_response - print("Resetting message and node_response...") message = b"" node_response = False # return message to frontend - print("Returning message to frontend:", str(data)) return {"message": str(data)} await asyncio.sleep(0.2) - - print("Timeout occurred while waiting for response.") content = {"message": "receive timeout"} return JSONResponse(content, status_code=400) - -# request message from others -@app.post("/request_message") -async def request_message(i_m: Request_Message): - print("Starting function: request_message") - - global message, node_response, pk - dest_ip = i_m.dest_ip - # dest_ip = dest_ip.split(":")[0] - message_name = i_m.message_name +# receive request from others +@app.post("/receive_request") +async def receive_request(i_m: IP_Message): + global pk source_ip = get_own_ip() - dest_port = "8002" - url = "http://" + dest_ip + ":" + dest_port + "/receive_request" - payload = { - "dest_ip": dest_ip, - "message_name": message_name, - "source_ip": source_ip, - "pk": pk, - } + if source_ip != i_m.dest_ip: + return HTTPException(status_code=400, detail="Wrong ip") + dest_ip = i_m.source_ip + # threshold = random.randrange(1, 2) + threshold = 2 + own_public_key = pk + pk_B = i_m.pk - print(f"Requesting message from: {url}") - try: - response = requests.post(url, json=payload, timeout=1) - print(f"Response received from {url}: {response.text}") + with sqlite3.connect("client.db") as db: + cursor = db.execute( + """ + SELECT nodeip + FROM node + LIMIT ? + """, + (threshold,), + ) + node_ips = cursor.fetchall() - except requests.Timeout as e: - print(f"Request to {url} timed out!") - print("can't post") - - # wait 2s to receive message from nodes - for _ in range(10): - print(f"Waiting for response... (iteration {_ + 1})") - print("Current node_response:", node_response) - if node_response: - data = message - - # reset message and node_response - print("Resetting message and node_response...") - message = b"" - node_response = False - - # return message to frontend - print("Returning message to frontend:", str(data)) - return {"message": str(data)} - await asyncio.sleep(0.2) - - print("Timeout occurred while waiting for response.") - content = {"message": "receive timeout"} - return JSONResponse(content, status_code=400) + # message name + message = b"hello world" + random.randbytes(8) + # send message to nodes + await send_messages(tuple(node_ips), message, dest_ip, pk_B, threshold) + response = {"threshold": threshold, "public_key": own_public_key} + print("###############RESPONSE = ", response) + return response def get_own_ip() -> str: @@ -453,21 +390,16 @@ class pk_model(BaseModel): pky: str ip: str + +# recieve pk from frontend @app.post("/recieve_pk") async def recieve_pk(pk: pk_model): - print("Starting function: recieve_pk") - pkx = pk.pkx pky = pk.pky dest_ip = pk.ip - - print(f"Received pkx: {pkx}, pky: {pky}, IP: {dest_ip}") - try: threshold = 2 - print("Connecting to client.db...") with sqlite3.connect("client.db") as db: - print("Connected to client.db, inserting data...") db.execute( """ INSERT INTO senderinfo @@ -477,16 +409,14 @@ async def recieve_pk(pk: pk_model): """, (str(dest_ip), pkx, pky, threshold), ) - print("Data inserted successfully!") except Exception as e: # raise error - print("Database error:", str(e)) + print("Database error") content = {"message": "Database Error", "error": str(e)} return JSONResponse(content, status_code=400) - - print("Function recieve_pk executed successfully!") return {"message": "save pk in database"} + pk = (0, 0) sk = 0 server_address = str @@ -498,5 +428,3 @@ if __name__ == "__main__": import uvicorn # pylint: disable=e0401 uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True, log_level="debug") -======= ->>>>>>> parent of 7b6e456 (feat: init client)