diff --git a/doc/README_app_en.md b/doc/README_app_en.md index d24eab1..4a69b6d 100644 --- a/doc/README_app_en.md +++ b/doc/README_app_en.md @@ -4,7 +4,7 @@ /request_node get method -pr + docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=110.41.130.197 git.mamahaha.work/sangge/tpre:base bash diff --git a/src/client.py b/src/client.py index 012d1dd..85a5634 100644 --- a/src/client.py +++ b/src/client.py @@ -95,6 +95,7 @@ def clean_env(): with sqlite3.connect("client.db") as db: db.execute("DELETE FROM node") db.execute("DELETE FROM message") + db.execute("DELETE FROM senderinfo") db.commit() print("Exit app") @@ -121,8 +122,10 @@ async def receive_messages(message: C): return: status_code """ + print(f"Received message: {message}") if not message.Tuple or not message.ip: + print("Invalid input data received.") raise HTTPException(status_code=400, detail="Invalid input data") C_capsule, C_ct = message.Tuple @@ -144,6 +147,7 @@ async def receive_messages(message: C): (bin_C_capsule, str(C_ct), ip), ) db.commit() + print("Data inserted successfully into database.") check_merge(C_ct, ip) return HTTPException(status_code=200, detail="Message received") except Exception as e: @@ -211,7 +215,7 @@ def check_merge(ct: int, ip: str): print("merge success", message) node_response = True - + print("merge:", node_response) @@ -230,11 +234,13 @@ async def send_messages( for i in range(4): id += int(ip_parts[i]) << (24 - (8 * i)) id_list.append(id) + print(f"Calculated IDs: {id_list}") # generate rk rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list)) # type: ignore - + print(f"Generated ReKey list: {rk_list}") capsule, ct = Encrypt(pk, message) # type: ignore # capsule_ct = (capsule, int.from_bytes(ct)) + print(f"Encrypted message to capsule={capsule}, ct={ct}") for i in range(len(node_ips)): url = "http://" + node_ips[i][0] + ":8001" + "/user_src" @@ -245,11 +251,15 @@ async def send_messages( "ct": int.from_bytes(ct), "rk": rk_list[i], } - print(json.dumps(payload)) + print(f"Sending payload to {url}: {json.dumps(payload)}") response = requests.post(url, json=payload) if response.status_code == 200: print(f"send to {node_ips[i]} successful") + else: + print( + f"Failed to send to {node_ips[i]}. Response code: {response.status_code}, Response text: {response.text}" + ) return 0 @@ -269,6 +279,9 @@ class Request_Message(BaseModel): @app.post("/request_message") async def request_message(i_m: Request_Message): global message, node_response, pk + print( + f"Function 'request_message' called with: dest_ip={i_m.dest_ip}, message_name={i_m.message_name}" + ) dest_ip = i_m.dest_ip # dest_ip = dest_ip.split(":")[0] message_name = i_m.message_name @@ -281,21 +294,25 @@ async def request_message(i_m: Request_Message): "source_ip": source_ip, "pk": pk, } + print(f"Sending request to {url} with payload: {payload}") try: response = requests.post(url, json=payload, timeout=1) + print(f"Response received from {url}: {response.text}") # print("menxian and pk", response.text) except requests.Timeout: - print("can't post") + print("Timeout error: can't post to the destination.") + # print("can't post") # content = {"message": "post timeout", "error": str(e)} # return JSONResponse(content, status_code=400) # wait 3s to receive message from nodes for _ in range(10): - print("wait:", node_response) + print(f"Waiting for node_response... Current value: {node_response}") + # print("wait:", node_response) if node_response: data = message - + print(f"Node response received with message: {data}") # reset message and node_response message = b"" node_response = False @@ -303,6 +320,7 @@ async def request_message(i_m: Request_Message): # return message to frontend return {"message": str(data)} await asyncio.sleep(0.2) + print("Timeout while waiting for node_response.") content = {"message": "receive timeout"} return JSONResponse(content, status_code=400) @@ -311,14 +329,20 @@ async def request_message(i_m: Request_Message): @app.post("/receive_request") async def receive_request(i_m: IP_Message): global pk + print( + f"Function 'receive_request' called with: dest_ip={i_m.dest_ip}, source_ip={i_m.source_ip}, pk={i_m.pk}" + ) source_ip = get_own_ip() + print(f"Own IP: {source_ip}") if source_ip != i_m.dest_ip: + print("Mismatch in destination IP.") return HTTPException(status_code=400, detail="Wrong ip") dest_ip = i_m.source_ip # threshold = random.randrange(1, 2) threshold = 2 own_public_key = pk pk_B = i_m.pk + print(f"Using own public key: {own_public_key} and received public key: {pk_B}") with sqlite3.connect("client.db") as db: cursor = db.execute( @@ -330,16 +354,18 @@ async def receive_request(i_m: IP_Message): (threshold,), ) node_ips = cursor.fetchall() + print(f"Selected node IPs from database: {node_ips}") # message name # message_name = i_m.message_name # message = xxxxx message = b"hello world" + random.randbytes(8) + print(f"Generated message: {message}") # send message to nodes await send_messages(tuple(node_ips), message, dest_ip, pk_B, threshold) response = {"threshold": threshold, "public_key": own_public_key} - print("###############RESPONSE = ", response) + print(f"Sending response: {response}") return response diff --git a/src/demo.py b/src/demo.py index 0f0646a..a5d00ed 100644 --- a/src/demo.py +++ b/src/demo.py @@ -1,60 +1,71 @@ from tpre import * import time -for N in range(4,21,4): - # N = 10 - # T = 5 - T = N // 2 - print(f"当前门限值: N = {N}, T = {T}") - - start_total_time = time.time() - # 1 - start_time = time.time() - pk_a, sk_a = GenerateKeyPair() - m = b"hello world" - end_time = time.time() - elapsed_time = end_time - start_time - print(f"密钥生成运行时间:{elapsed_time}秒") +# for T in range(2, 20, 2): +N = 10 +T = N // 2 +# print(f"当前门限值: N = {N}, T = {T}") - # 2 - start_time = time.time() - capsule_ct = Encrypt(pk_a, m) - end_time = time.time() - elapsed_time = end_time - start_time - print(f"加密算法运行时间:{elapsed_time}秒") +start_total_time = time.time() +# 1 +start_time = time.time() +pk_a, sk_a = GenerateKeyPair() +# print("pk_a: ", pk_a) +# print("sk_a: ", sk_a) +end_time = time.time() +elapsed_time = end_time - start_time +# print(f"密钥生成运行时间:{elapsed_time}秒") - # 3 - pk_b, sk_b = GenerateKeyPair() +# 2 +start_time = time.time() +m = b"hello world" +capsule_ct = Encrypt(pk_a, m) +capsule = capsule_ct[0] +print("check capsule: ", Checkcapsule(capsule)) +capsule = (capsule[0], capsule[1], -1) +print("check capsule: ", Checkcapsule(capsule)) +# print("capsule_ct: ", capsule_ct) +end_time = time.time() +elapsed_time = end_time - start_time +# print(f"加密算法运行时间:{elapsed_time}秒") - - # 5 - start_time = time.time() - id_tuple = tuple(range(N)) - rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) - end_time = time.time() - elapsed_time = end_time - start_time - print(f"重加密密钥生成算法运行时间:{elapsed_time}秒") +# 3 +pk_b, sk_b = GenerateKeyPair() - # 7 - start_time = time.time() - cfrag_cts = [] +# 5 +start_time = time.time() +id_tuple = tuple(range(N)) +rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) +# print("rekeys: ", rekeys) +end_time = time.time() +elapsed_time = end_time - start_time +# print(f"重加密密钥生成算法运行时间:{elapsed_time}秒") - for rekey in rekeys: - cfrag_ct = ReEncrypt(rekey, capsule_ct) - cfrag_cts.append(cfrag_ct) - end_time = time.time() - elapsed_time = (end_time - start_time) / len(rekeys) - print(f"重加密算法运行时间:{elapsed_time}秒") +# 7 +start_time = time.time() +cfrag_cts = [] - # 9 - start_time = time.time() - cfrags = mergecfrag(cfrag_cts) - m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) - end_time = time.time() - elapsed_time = end_time - start_time - end_total_time = time.time() - total_time = end_total_time - start_total_time - print(f"解密算法运行时间:{elapsed_time}秒") - print("成功解密:", m) - print(f"算法总运行时间:{total_time}秒") - print() +for rekey in rekeys: + cfrag_ct = ReEncrypt(rekey, capsule_ct) + # cfrag_ct = ReEncrypt(rekeys[0], capsule_ct) + cfrag_cts.append(cfrag_ct) +# print("cfrag_cts: ", cfrag_cts) +end_time = time.time() +re_elapsed_time = (end_time - start_time) / len(rekeys) +# print(f"重加密算法运行时间:{re_elapsed_time}秒") + +# 9 +start_time = time.time() +cfrags = mergecfrag(cfrag_cts) +# print("cfrags: ", cfrags) +# m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) +m = DecryptFrags(sk_a, pk_b, pk_a, cfrags) +# print("m = ", m) +end_time = time.time() +elapsed_time = end_time - start_time +end_total_time = time.time() +total_time = end_total_time - start_total_time - re_elapsed_time * len(rekeys) +# print(f"解密算法运行时间:{elapsed_time}秒") +# print("成功解密:", m) +# print(f"算法总运行时间:{total_time}秒") +# print() diff --git a/src/demo2.py b/src/demo2.py new file mode 100644 index 0000000..6f466f5 --- /dev/null +++ b/src/demo2.py @@ -0,0 +1,93 @@ +from tpre import * +import time +import openpyxl + +# 初始化Excel工作簿和工作表 +wb = openpyxl.Workbook() +ws = wb.active +ws.title = "算法性能结果" +headers = [ + "门限值 N", + "门限值 T", + "密钥生成运行时间", + "加密算法运行时间", + "重加密密钥生成算法运行时间", + "重加密算法运行时间", + "解密算法运行时间", + "算法总运行时间", +] +ws.append(headers) + + +for N in range(4, 21, 2): + T = N // 2 + print(f"当前门限值: N = {N}, T = {T}") + + start_total_time = time.time() + # 1 + start_time = time.time() + pk_a, sk_a = GenerateKeyPair() + m = b"hello world" + end_time = time.time() + elapsed_time_key_gen = end_time - start_time + print(f"密钥生成运行时间:{elapsed_time_key_gen}秒") + + # ... [中间代码不变] + # 2 + start_time = time.time() + capsule_ct = Encrypt(pk_a, m) + end_time = time.time() + elapsed_time_enc = end_time - start_time + print(f"加密算法运行时间:{elapsed_time_enc}秒") + + # 3 + pk_b, sk_b = GenerateKeyPair() + + # 5 + start_time = time.time() + id_tuple = tuple(range(N)) + rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) + end_time = time.time() + elapsed_time_rekey_gen = end_time - start_time + print(f"重加密密钥生成算法运行时间:{elapsed_time_rekey_gen}秒") + + # 7 + start_time = time.time() + cfrag_cts = [] + + for rekey in rekeys: + cfrag_ct = ReEncrypt(rekey, capsule_ct) + cfrag_cts.append(cfrag_ct) + end_time = time.time() + re_elapsed_time = (end_time - start_time) / len(rekeys) + print(f"重加密算法运行时间:{re_elapsed_time}秒") + + # 9 + start_time = time.time() + cfrags = mergecfrag(cfrag_cts) + m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) + end_time = time.time() + elapsed_time_dec = end_time - start_time + end_total_time = time.time() + total_time = end_total_time - start_total_time - re_elapsed_time * len(rekeys) + print(f"解密算法运行时间:{elapsed_time_dec}秒") + print("成功解密:", m) + print(f"算法总运行时间:{total_time}秒") + print() + + # 将结果保存到Excel + ws.append( + [ + N, + T, + elapsed_time_key_gen, + elapsed_time_enc, + elapsed_time_rekey_gen, + re_elapsed_time, + elapsed_time_dec, + total_time, + ] + ) + +# 保存Excel文件 +wb.save("结果.xlsx") diff --git a/src/node.py b/src/node.py index 0388e57..7d4e061 100644 --- a/src/node.py +++ b/src/node.py @@ -36,6 +36,7 @@ def send_ip(): # ip = get_local_ip() # type: ignore global id id = requests.get(url, timeout=3) + print("中心服务器返回节点ID为: ", id) # 用环境变量获取本机ip @@ -84,7 +85,10 @@ class Req(BaseModel): @app.post("/user_src") # 接收用户1发送的信息 async def user_src(message: Req): global client_ip_src, client_ip_des - # kfrag , capsule_ct ,client_ip_src , client_ip_des = json_data[] # 看梁俊勇 + print( + f"Function 'user_src' called with: source_ip={message.source_ip}, dest_ip={message.dest_ip}, capsule={message.capsule}, ct={message.ct}, rk={message.rk}" + ) + # kfrag , capsule_ct ,client_ip_src , client_ip_des = json_data[] """ payload = { "source_ip": local_ip, @@ -100,10 +104,12 @@ async def user_src(message: Req): ct = message.ct capsule_ct = (capsule, ct.to_bytes(32)) rk = message.rk - + print(f"Computed capsule_ct: {capsule_ct}") a, b = ReEncrypt(rk, capsule_ct) processed_message = (a, int.from_bytes(b)) + print(f"Re-encrypted message: {processed_message}") await send_user_des_message(source_ip, dest_ip, processed_message) + print("Message sent to destination user.") return HTTPException(status_code=200, detail="message recieved") @@ -114,10 +120,10 @@ async def send_user_des_message(source_ip: str, dest_ip: str, re_message): # response = requests.post( "http://" + dest_ip + ":8002" + "/receive_messages", json=data ) - print("send stauts:" ,response.text) + print("send stauts:", response.text) if __name__ == "__main__": import uvicorn # pylint: disable=e0401 - uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True,log_level="debug") + uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True, log_level="debug") diff --git a/src/server.py b/src/server.py index f0530cf..3e66739 100644 --- a/src/server.py +++ b/src/server.py @@ -65,9 +65,11 @@ async def get_node(ip: str) -> int: ip_int = 0 for i in range(4): ip_int += int(ip_parts[i]) << (24 - (8 * i)) + print("IP", ip, "对应的ID为", ip_int) # 获取当前时间 current_time = int(time.time()) + print("当前时间: ", current_time) # 插入数据 cursor.execute( @@ -102,6 +104,7 @@ async def delete_node(ip: str) -> None: # 接收节点心跳包 @app.get("/server/heartbeat") async def receive_heartbeat(ip: str): + print("收到来自", ip, "的心跳包") cursor.execute( "UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip) ) @@ -112,7 +115,9 @@ async def receive_heartbeat_internal(): while 1: timeout = 70 # 删除超时的节点 - cursor.execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,)) + cursor.execute( + "DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,) + ) conn.commit() await asyncio.sleep(timeout) @@ -135,6 +140,8 @@ async def send_nodes_list(count: int) -> list: id, ip, last_heartbeat = row nodes_list.append(ip) + print("收到来自客户端的节点列表请求...") + print(nodes_list) return nodes_list