forked from sangge/tpre-python
		
	Merge pull request 'fix: fix all bug' (#33) from sangge/mimajingsai:main into main
Reviewed-on: dqy/mimajingsai#33
This commit is contained in:
		
							
								
								
									
										112
									
								
								src/client.py
									
									
									
									
									
								
							
							
						
						
									
										112
									
								
								src/client.py
									
									
									
									
									
								
							| @@ -13,7 +13,7 @@ import base64 | |||||||
| import json | import json | ||||||
| import pickle | import pickle | ||||||
| from fastapi.responses import JSONResponse | from fastapi.responses import JSONResponse | ||||||
|  | import asyncio | ||||||
|  |  | ||||||
|  |  | ||||||
| @asynccontextmanager | @asynccontextmanager | ||||||
| @@ -89,6 +89,9 @@ def init_config(): | |||||||
|  |  | ||||||
| # execute on exit | # execute on exit | ||||||
| def clean_env(): | def clean_env(): | ||||||
|  |     global message, node_response | ||||||
|  |     message = b"" | ||||||
|  |     node_response = False | ||||||
|     with sqlite3.connect("client.db") as db: |     with sqlite3.connect("client.db") as db: | ||||||
|         db.execute("DELETE FROM node") |         db.execute("DELETE FROM node") | ||||||
|         db.execute("DELETE FROM message") |         db.execute("DELETE FROM message") | ||||||
| @@ -141,7 +144,7 @@ async def receive_messages(message: C): | |||||||
|                 (bin_C_capsule, str(C_ct), ip), |                 (bin_C_capsule, str(C_ct), ip), | ||||||
|             ) |             ) | ||||||
|             db.commit() |             db.commit() | ||||||
|             await check_merge(C_ct, ip) |             check_merge(C_ct, ip) | ||||||
|             return HTTPException(status_code=200, detail="Message received") |             return HTTPException(status_code=200, detail="Message received") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"Error occurred: {e}") |             print(f"Error occurred: {e}") | ||||||
| @@ -150,13 +153,14 @@ async def receive_messages(message: C): | |||||||
|  |  | ||||||
|  |  | ||||||
| # check record count | # check record count | ||||||
| async def check_merge(ct: int, ip: str): | def check_merge(ct: int, ip: str): | ||||||
|     global sk, pk, node_response, message |     global sk, pk, node_response, message | ||||||
|     """ |     """ | ||||||
|     CREATE TABLE IF NOT EXISTS senderinfo ( |     CREATE TABLE IF NOT EXISTS senderinfo ( | ||||||
|         id INTEGER PRIMARY KEY, |         id INTEGER PRIMARY KEY, | ||||||
|         ip TEXT, |         ip TEXT, | ||||||
|         publickey TEXT, |         pkx TEXT, | ||||||
|  |         pky TEXT, | ||||||
|         threshold INTEGER |         threshold INTEGER | ||||||
|     ) |     ) | ||||||
|     """ |     """ | ||||||
| @@ -173,10 +177,10 @@ async def check_merge(ct: int, ip: str): | |||||||
|         # [(capsule, ct), ...] |         # [(capsule, ct), ...] | ||||||
|         cfrag_cts = cursor.fetchall() |         cfrag_cts = cursor.fetchall() | ||||||
|  |  | ||||||
|         # get T |         # get _sender_pk | ||||||
|         cursor = db.execute( |         cursor = db.execute( | ||||||
|             """ |             """ | ||||||
|         SELECT publickey, threshold  |         SELECT pkx, pky | ||||||
|         FROM senderinfo |         FROM senderinfo | ||||||
|         WHERE ip = ? |         WHERE ip = ? | ||||||
|         """, |         """, | ||||||
| @@ -184,22 +188,32 @@ async def check_merge(ct: int, ip: str): | |||||||
|         ) |         ) | ||||||
|         result = cursor.fetchall() |         result = cursor.fetchall() | ||||||
|         try: |         try: | ||||||
|             pk_sender, T = result[0]  # result[0] = (pk, threshold) |             pkx, pky = result[0]  # result[0] = (pkx, pky) | ||||||
|  |             pk_sender = (int(pkx), int(pky)) | ||||||
|         except: |         except: | ||||||
|             pk_sender, T = 0, -1 |             pk_sender, T = 0, -1 | ||||||
|  |  | ||||||
|     if len(cfrag_cts) <= T: |     T = 2 | ||||||
|         print(T) |     if len(cfrag_cts) >= T: | ||||||
|         # Deserialization |         # Deserialization | ||||||
|         temp_cfrag_cts = [] |         temp_cfrag_cts = [] | ||||||
|         for i in cfrag_cts: |         for i in cfrag_cts: | ||||||
|             capsule = pickle.loads(i[0]) |             capsule = pickle.loads(i[0]) | ||||||
|             temp_cfrag_cts.append((capsule, int(i[1]))) |             temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(32))) | ||||||
|  |  | ||||||
|         cfrags = mergecfrag(temp_cfrag_cts) |         cfrags = mergecfrag(temp_cfrag_cts) | ||||||
|         message = DecryptFrags(sk, pk, pk_sender, cfrags)  # type: ignore |  | ||||||
|  |         print("sk:", type(sk)) | ||||||
|  |         print("pk:", type(pk)) | ||||||
|  |         print("pk_sender:", type(pk_sender)) | ||||||
|  |         print("cfrags:", type(cfrags)) | ||||||
|  |         message = DecryptFrags(sk, pk, pk_sender, cfrags) | ||||||
|  |  | ||||||
|  |         print("merge success", message) | ||||||
|         node_response = True |         node_response = True | ||||||
|          |          | ||||||
|  |         print("merge:", node_response) | ||||||
|  |  | ||||||
|  |  | ||||||
| # send message to node | # send message to node | ||||||
| async def send_messages( | async def send_messages( | ||||||
| @@ -276,34 +290,9 @@ async def request_message(i_m: Request_Message): | |||||||
|         # content = {"message": "post timeout", "error": str(e)} |         # content = {"message": "post timeout", "error": str(e)} | ||||||
|         # return JSONResponse(content, status_code=400) |         # return JSONResponse(content, status_code=400) | ||||||
|  |  | ||||||
|     try: |  | ||||||
|         url = "http://" + dest_ip + ":" + dest_port + "/get_pk" |  | ||||||
|         print(url) |  | ||||||
|         response = requests.get(url,timeout=4) |  | ||||||
|         print(response.text) |  | ||||||
|         if response.status_code == 200: |  | ||||||
|             data = response.json() |  | ||||||
|             pkx = int(data["pkx"]) |  | ||||||
|             pky = int(data["pky"]) |  | ||||||
|             public_key = (pkx, pky) |  | ||||||
|             threshold = 2 |  | ||||||
|             with sqlite3.connect("client.db") as db: |  | ||||||
|                 db.execute( |  | ||||||
|                     """ |  | ||||||
|             INSERT INTO senderinfo |  | ||||||
|             (ip, public_key, threshold) |  | ||||||
|             VALUES |  | ||||||
|             (?, ?, ?) |  | ||||||
|             """, |  | ||||||
|                     (str(dest_ip), public_key, threshold), |  | ||||||
|                 ) |  | ||||||
|     except Exception as e: |  | ||||||
|         print("Database error") |  | ||||||
|         content = {"message": "Database Error","error": str(e)} |  | ||||||
|         return JSONResponse(content, status_code=400) |  | ||||||
|  |  | ||||||
|     # wait 3s to receive message from nodes |     # wait 3s to receive message from nodes | ||||||
|     for _ in range(3): |     for _ in range(10): | ||||||
|  |         print("wait:", node_response) | ||||||
|         if node_response: |         if node_response: | ||||||
|             data = message |             data = message | ||||||
|              |              | ||||||
| @@ -312,8 +301,8 @@ async def request_message(i_m: Request_Message): | |||||||
|             node_response = False |             node_response = False | ||||||
|  |  | ||||||
|             # return message to frontend |             # return message to frontend | ||||||
|             return {"message": data} |             return {"message": str(data)} | ||||||
|         time.sleep(1) |         await asyncio.sleep(0.2) | ||||||
|     content = {"message": "receive timeout"} |     content = {"message": "receive timeout"} | ||||||
|     return JSONResponse(content, status_code=400) |     return JSONResponse(content, status_code=400) | ||||||
|  |  | ||||||
| @@ -382,14 +371,49 @@ def get_node_list(count: int, server_addr: str): | |||||||
|     else: |     else: | ||||||
|         print("Failed:", response.status_code, response.text) |         print("Failed:", response.status_code, response.text) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # send pk to others | ||||||
| @app.get("/get_pk") | @app.get("/get_pk") | ||||||
| async def get_pk(): | async def get_pk(): | ||||||
|     global pk |     global pk, sk | ||||||
|  |     print(sk) | ||||||
|     return {"pkx": str(pk[0]), "pky": str(pk[1])} |     return {"pkx": str(pk[0]), "pky": str(pk[1])} | ||||||
|  |  | ||||||
|  |  | ||||||
| pk = point | class pk_model(BaseModel): | ||||||
| sk = int |     pkx: str | ||||||
|  |     pky: str | ||||||
|  |     ip: str | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # recieve pk from frontend | ||||||
|  | @app.post("/recieve_pk") | ||||||
|  | async def recieve_pk(pk: pk_model): | ||||||
|  |     pkx = pk.pkx | ||||||
|  |     pky = pk.pky | ||||||
|  |     dest_ip = pk.ip | ||||||
|  |     try: | ||||||
|  |         threshold = 2 | ||||||
|  |         with sqlite3.connect("client.db") as db: | ||||||
|  |             db.execute( | ||||||
|  |                 """ | ||||||
|  |         INSERT INTO senderinfo | ||||||
|  |         (ip, pkx, pky, threshold) | ||||||
|  |         VALUES | ||||||
|  |         (?, ?, ?, ?) | ||||||
|  |         """, | ||||||
|  |                 (str(dest_ip), pkx, pky, threshold), | ||||||
|  |             ) | ||||||
|  |     except Exception as e: | ||||||
|  |         # raise error | ||||||
|  |         print("Database error") | ||||||
|  |         content = {"message": "Database Error", "error": str(e)} | ||||||
|  |         return JSONResponse(content, status_code=400) | ||||||
|  |     return {"message": "save pk in database"} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | pk = (0, 0) | ||||||
|  | sk = 0 | ||||||
| server_address = str | server_address = str | ||||||
| node_response = False | node_response = False | ||||||
| message = bytes | message = bytes | ||||||
| @@ -398,4 +422,4 @@ local_ip = get_own_ip() | |||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     import uvicorn  # pylint: disable=e0401 |     import uvicorn  # pylint: disable=e0401 | ||||||
|  |  | ||||||
|     uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True,log_level="debug") |     uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True, log_level="debug") | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| import argparse | import argparse | ||||||
| import requests | import requests | ||||||
|  | import json | ||||||
|  |  | ||||||
|  |  | ||||||
| def send_post_request(ip_addr, message_name): | def send_post_request(ip_addr, message_name): | ||||||
| @@ -9,13 +10,30 @@ def send_post_request(ip_addr, message_name): | |||||||
|     return response.text |     return response.text | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_pk(ip_addr): | ||||||
|  |     url = f"http://" + ip_addr + ":8002/get_pk" | ||||||
|  |     response = requests.get(url, timeout=1) | ||||||
|  |     print(response.text) | ||||||
|  |     json_pk = json.loads(response.text) | ||||||
|  |     payload = {"pkx": json_pk["pkx"], "pky": json_pk["pky"], "ip": ip_addr} | ||||||
|  |     response = requests.post("http://localhost:8002/recieve_pk", json=payload) | ||||||
|  |  | ||||||
|  |     return response.text | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|     parser = argparse.ArgumentParser(description="Send POST request to a specified IP.") |     parser = argparse.ArgumentParser(description="Send POST request to a specified IP.") | ||||||
|     parser.add_argument("ip_addr", help="IP address to send request to.") |     parser.add_argument("ip_addr", help="IP address to send request to.") | ||||||
|     parser.add_argument("message_name", help="Message name to send.") |     parser.add_argument("message_name", help="Message name to send.") | ||||||
|  |  | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |      | ||||||
|  |     response = get_pk(args.ip_addr) | ||||||
|  |     print(response) | ||||||
|  |      | ||||||
|     response = send_post_request(args.ip_addr, args.message_name) |     response = send_post_request(args.ip_addr, args.message_name) | ||||||
|  |      | ||||||
|     print(response) |     print(response) | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								src/client_demo.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								src/client_demo.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | |||||||
|  | from tpre import * | ||||||
|  |  | ||||||
|  | # local {"pkx":"110913495319893280527511520027612816833094668640322629943553195742251267532611","pky":"42442813417048462506373786007682778510807282038950736216326706485290996455738"} | ||||||
|  | # pkb (110913495319893280527511520027612816833094668640322629943553195742251267532611,42442813417048462506373786007682778510807282038950736216326706485290996455738 | ||||||
| @@ -272,7 +272,9 @@ def f(x: int, f_modulus: list, T: int) -> int: | |||||||
|     return res |     return res | ||||||
|  |  | ||||||
|  |  | ||||||
| def GenerateReKey(sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int,...]) -> list: | def GenerateReKey( | ||||||
|  |     sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int, ...] | ||||||
|  | ) -> list: | ||||||
|     """ |     """ | ||||||
|     param: |     param: | ||||||
|     skA, pkB, N(节点总数), T(阈值) |     skA, pkB, N(节点总数), T(阈值) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user