forked from sangge/tpre-python
		
	Merge pull request 'main' (#32) from sangge/mimajingsai:main into main
Reviewed-on: dqy/mimajingsai#32
This commit is contained in:
		| @@ -52,6 +52,7 @@ pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple | |||||||
|  |  | ||||||
| ### Use base image and build yourself | ### Use base image and build yourself | ||||||
| ```bash | ```bash | ||||||
|  | docker build . -f basedockerfile -t git.mamahaha.work/sangge/tpre:base | ||||||
| docker pull git.mamahaha.work/sangge/tpre:base   | docker pull git.mamahaha.work/sangge/tpre:base   | ||||||
| docker build . -t your_image_name | docker build . -t your_image_name | ||||||
| docker run your_image_name | docker run your_image_name | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ pr | |||||||
| docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=110.41.130.197 git.mamahaha.work/sangge/tpre:base bash   | docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=110.41.130.197 git.mamahaha.work/sangge/tpre:base bash   | ||||||
|  |  | ||||||
|  |  | ||||||
| tpre3: docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=110.41.21.35 git.mamahaha.work/sangge/tpre:base bash | tpre3: docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/app -e HOST_IP=60.204.233.103 git.mamahaha.work/sangge/tpre:base bash | ||||||
|  |  | ||||||
|  |  | ||||||
| 110.41.155.96 tpre1   | 110.41.155.96 tpre1   | ||||||
| @@ -18,3 +18,10 @@ tpre3: docker run -it -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ~/mimajingsai:/a | |||||||
| 110.41.21.35 tpre3 | 110.41.21.35 tpre3 | ||||||
|  |  | ||||||
| python client_cli.py 110.41.21.35 aaa | python client_cli.py 110.41.21.35 aaa | ||||||
|  |  | ||||||
|  |  | ||||||
|  | apt update && apt install docker.io mosh -y | ||||||
|  |  | ||||||
|  | 60.204.236.38 tpre1 | ||||||
|  | 1.94.42.18 tpre2 | ||||||
|  | 60.204.233.103 tpre3 | ||||||
|   | |||||||
| @@ -1,3 +1,3 @@ | |||||||
| [settings] | [settings] | ||||||
| server_address = 110.41.155.96:8000 | server_address = 60.204.236.38:8000 | ||||||
| version = 1.0 | version = 1.0 | ||||||
|   | |||||||
| @@ -12,6 +12,8 @@ import time | |||||||
| import base64 | import base64 | ||||||
| import json | import json | ||||||
| import pickle | import pickle | ||||||
|  | from fastapi.responses import JSONResponse | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @asynccontextmanager | @asynccontextmanager | ||||||
| @@ -64,7 +66,8 @@ def init_db(): | |||||||
|             CREATE TABLE IF NOT EXISTS senderinfo ( |             CREATE TABLE IF NOT EXISTS senderinfo ( | ||||||
|                 id INTEGER PRIMARY KEY, |                 id INTEGER PRIMARY KEY, | ||||||
|                 ip TEXT, |                 ip TEXT, | ||||||
|                 publickey TEXT, |                 pkx TEXT, | ||||||
|  |                 pky TEXT, | ||||||
|                 threshold INTEGER |                 threshold INTEGER | ||||||
|             ) |             ) | ||||||
|             """ |             """ | ||||||
| @@ -88,6 +91,7 @@ def init_config(): | |||||||
| def clean_env(): | def clean_env(): | ||||||
|     with sqlite3.connect("client.db") as db: |     with sqlite3.connect("client.db") as db: | ||||||
|         db.execute("DELETE FROM node") |         db.execute("DELETE FROM node") | ||||||
|  |         db.execute("DELETE FROM message") | ||||||
|         db.commit() |         db.commit() | ||||||
|     print("Exit app") |     print("Exit app") | ||||||
|  |  | ||||||
| @@ -127,9 +131,6 @@ async def receive_messages(message: C): | |||||||
|     # insert record into database |     # insert record into database | ||||||
|     with sqlite3.connect("client.db") as db: |     with sqlite3.connect("client.db") as db: | ||||||
|         try: |         try: | ||||||
|             print("bin:", bin_C_capsule) |  | ||||||
|             print("ct:", C_ct) |  | ||||||
|             print("ip:", ip) |  | ||||||
|             db.execute( |             db.execute( | ||||||
|                 """ |                 """ | ||||||
|                 INSERT INTO message  |                 INSERT INTO message  | ||||||
| @@ -151,9 +152,15 @@ async def receive_messages(message: C): | |||||||
| # check record count | # check record count | ||||||
| async def check_merge(ct: int, ip: str): | async def check_merge(ct: int, ip: str): | ||||||
|     global sk, pk, node_response, message |     global sk, pk, node_response, message | ||||||
|  |     """ | ||||||
|  |     CREATE TABLE IF NOT EXISTS senderinfo ( | ||||||
|  |         id INTEGER PRIMARY KEY, | ||||||
|  |         ip TEXT, | ||||||
|  |         publickey TEXT, | ||||||
|  |         threshold INTEGER | ||||||
|  |     ) | ||||||
|  |     """ | ||||||
|     with sqlite3.connect("client.db") as db: |     with sqlite3.connect("client.db") as db: | ||||||
|         print("str(ct):", str(ct)) |  | ||||||
|         print("ip:", ip) |  | ||||||
|         # Check if the combination of ct_column and ip_column appears more than once. |         # Check if the combination of ct_column and ip_column appears more than once. | ||||||
|         cursor = db.execute( |         cursor = db.execute( | ||||||
|             """ |             """ | ||||||
| @@ -166,16 +173,6 @@ async def check_merge(ct: int, ip: str): | |||||||
|         # [(capsule, ct), ...] |         # [(capsule, ct), ...] | ||||||
|         cfrag_cts = cursor.fetchall() |         cfrag_cts = cursor.fetchall() | ||||||
|  |  | ||||||
|         cursor = db.execute( |  | ||||||
|             """ |  | ||||||
|         SELECT publickey, threshold  |  | ||||||
|         FROM senderinfo |  | ||||||
|         WHERE ip = ? |  | ||||||
|         """, |  | ||||||
|             ('127.1.1'), |  | ||||||
|         ) |  | ||||||
|         result = cursor.fetchall() |  | ||||||
|          |  | ||||||
|         # get T |         # get T | ||||||
|         cursor = db.execute( |         cursor = db.execute( | ||||||
|             """ |             """ | ||||||
| @@ -183,13 +180,16 @@ async def check_merge(ct: int, ip: str): | |||||||
|         FROM senderinfo |         FROM senderinfo | ||||||
|         WHERE ip = ? |         WHERE ip = ? | ||||||
|         """, |         """, | ||||||
|             (ip), |             (ip,), | ||||||
|         ) |         ) | ||||||
|         result = cursor.fetchall() |         result = cursor.fetchall() | ||||||
|         print("maybe error here?") |         try: | ||||||
|         pk_sender, T = result[0]  # result[0] = (pk, threshold) |             pk_sender, T = result[0]  # result[0] = (pk, threshold) | ||||||
|  |         except: | ||||||
|  |             pk_sender, T = 0, -1 | ||||||
|  |  | ||||||
|     if len(cfrag_cts) >= T: |     if len(cfrag_cts) <= T: | ||||||
|  |         print(T) | ||||||
|         # Deserialization |         # Deserialization | ||||||
|         temp_cfrag_cts = [] |         temp_cfrag_cts = [] | ||||||
|         for i in cfrag_cts: |         for i in cfrag_cts: | ||||||
| @@ -260,7 +260,7 @@ async def request_message(i_m: Request_Message): | |||||||
|     message_name = i_m.message_name |     message_name = i_m.message_name | ||||||
|     source_ip = get_own_ip() |     source_ip = get_own_ip() | ||||||
|     dest_port = "8002" |     dest_port = "8002" | ||||||
|     url = "http://" + dest_ip + ":" + dest_port + "/recieve_request" |     url = "http://" + dest_ip + ":" + dest_port + "/receive_request" | ||||||
|     payload = { |     payload = { | ||||||
|         "dest_ip": dest_ip, |         "dest_ip": dest_ip, | ||||||
|         "message_name": message_name, |         "message_name": message_name, | ||||||
| @@ -268,19 +268,25 @@ async def request_message(i_m: Request_Message): | |||||||
|         "pk": pk, |         "pk": pk, | ||||||
|     } |     } | ||||||
|     try: |     try: | ||||||
|         response = requests.post(url, json=payload, timeout=3) |         response = requests.post(url, json=payload, timeout=1) | ||||||
|         print(response.text) |         # print("menxian and pk", response.text) | ||||||
|  |  | ||||||
|     except: |     except requests.Timeout: | ||||||
|         print("can't post") |         print("can't post") | ||||||
|         return {"message": "can't post"} |         # content = {"message": "post timeout", "error": str(e)} | ||||||
|  |         # return JSONResponse(content, status_code=400) | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|  |         url = "http://" + dest_ip + ":" + dest_port + "/get_pk" | ||||||
|  |         print(url) | ||||||
|  |         response = requests.get(url,timeout=4) | ||||||
|  |         print(response.text) | ||||||
|         if response.status_code == 200: |         if response.status_code == 200: | ||||||
|             data = response.json() |             data = response.json() | ||||||
|             public_key = int(data["public_key"]) |             pkx = int(data["pkx"]) | ||||||
|             threshold = int(data["threshold"]) |             pky = int(data["pky"]) | ||||||
|             print(data) |             public_key = (pkx, pky) | ||||||
|  |             threshold = 2 | ||||||
|             with sqlite3.connect("client.db") as db: |             with sqlite3.connect("client.db") as db: | ||||||
|                 db.execute( |                 db.execute( | ||||||
|                     """ |                     """ | ||||||
| @@ -291,11 +297,12 @@ async def request_message(i_m: Request_Message): | |||||||
|             """, |             """, | ||||||
|                     (str(dest_ip), public_key, threshold), |                     (str(dest_ip), public_key, threshold), | ||||||
|                 ) |                 ) | ||||||
|     except: |     except Exception as e: | ||||||
|         print("Database error") |         print("Database error") | ||||||
|         return {"message": "Database Error"} |         content = {"message": "Database Error","error": str(e)} | ||||||
|  |         return JSONResponse(content, status_code=400) | ||||||
|  |  | ||||||
|     # wait 10s to recieve message from nodes |     # wait 3s to receive message from nodes | ||||||
|     for _ in range(3): |     for _ in range(3): | ||||||
|         if node_response: |         if node_response: | ||||||
|             data = message |             data = message | ||||||
| @@ -307,12 +314,13 @@ async def request_message(i_m: Request_Message): | |||||||
|             # return message to frontend |             # return message to frontend | ||||||
|             return {"message": data} |             return {"message": data} | ||||||
|         time.sleep(1) |         time.sleep(1) | ||||||
|     return {"message": "recieve timeout"} |     content = {"message": "receive timeout"} | ||||||
|  |     return JSONResponse(content, status_code=400) | ||||||
|  |  | ||||||
|  |  | ||||||
| # recieve request from others | # receive request from others | ||||||
| @app.post("/recieve_request") | @app.post("/receive_request") | ||||||
| async def recieve_request(i_m: IP_Message): | async def receive_request(i_m: IP_Message): | ||||||
|     global pk |     global pk | ||||||
|     source_ip = get_own_ip() |     source_ip = get_own_ip() | ||||||
|     if source_ip != i_m.dest_ip: |     if source_ip != i_m.dest_ip: | ||||||
| @@ -374,6 +382,11 @@ def get_node_list(count: int, server_addr: str): | |||||||
|     else: |     else: | ||||||
|         print("Failed:", response.status_code, response.text) |         print("Failed:", response.status_code, response.text) | ||||||
|          |          | ||||||
|  | @app.get("/get_pk") | ||||||
|  | async def get_pk(): | ||||||
|  |     global pk | ||||||
|  |     return {"pkx": str(pk[0]), "pky": str(pk[1])} | ||||||
|  |  | ||||||
|  |  | ||||||
| pk = point | pk = point | ||||||
| sk = int | sk = int | ||||||
| @@ -385,4 +398,4 @@ local_ip = get_own_ip() | |||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     import uvicorn  # pylint: disable=e0401 |     import uvicorn  # pylint: disable=e0401 | ||||||
|  |  | ||||||
|     uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True) |     uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True,log_level="debug") | ||||||
|   | |||||||
							
								
								
									
										91
									
								
								src/demo.py
									
									
									
									
									
								
							
							
						
						
									
										91
									
								
								src/demo.py
									
									
									
									
									
								
							| @@ -1,51 +1,60 @@ | |||||||
| from tpre import * | from tpre import * | ||||||
| import time | import time | ||||||
|  |  | ||||||
| # 1 | for N in range(4,21,4): | ||||||
| start_time = time.time() |     # N = 10 | ||||||
| pk_a, sk_a = GenerateKeyPair() |     # T = 5 | ||||||
| m = b"hello world" |     T = N // 2 | ||||||
| end_time = time.time() |     print(f"当前门限值: N = {N}, T = {T}") | ||||||
| elapsed_time = end_time - start_time |  | ||||||
| print(f"代码块1运行时间:{elapsed_time}秒") |  | ||||||
|      |      | ||||||
| # 2 |     start_total_time = time.time() | ||||||
| start_time = time.time() |     # 1 | ||||||
| capsule_ct = Encrypt(pk_a, m) |     start_time = time.time() | ||||||
| end_time = time.time() |     pk_a, sk_a = GenerateKeyPair() | ||||||
| elapsed_time = end_time - start_time |     m = b"hello world" | ||||||
| print(f"代码块2运行时间:{elapsed_time}秒") |     end_time = time.time() | ||||||
|  |     elapsed_time = end_time - start_time | ||||||
|  |     print(f"密钥生成运行时间:{elapsed_time}秒") | ||||||
|  |  | ||||||
| # 3 |     # 2 | ||||||
| pk_b, sk_b = GenerateKeyPair() |     start_time = time.time() | ||||||
|  |     capsule_ct = Encrypt(pk_a, m) | ||||||
|  |     end_time = time.time() | ||||||
|  |     elapsed_time = end_time - start_time | ||||||
|  |     print(f"加密算法运行时间:{elapsed_time}秒") | ||||||
|  |  | ||||||
| N = 10 |     # 3 | ||||||
| T = 5 |     pk_b, sk_b = GenerateKeyPair() | ||||||
|  |  | ||||||
| # 5 |  | ||||||
| start_time = time.time() |  | ||||||
| id_tuple = tuple(range(N)) |  | ||||||
| rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) |  | ||||||
| end_time = time.time() |  | ||||||
| elapsed_time = end_time - start_time |  | ||||||
| print(f"代码块5运行时间:{elapsed_time}秒") |  | ||||||
|      |      | ||||||
| # 7 |     # 5 | ||||||
| start_time = time.time() |     start_time = time.time() | ||||||
| cfrag_cts = [] |     id_tuple = tuple(range(N)) | ||||||
|  |     rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) | ||||||
|  |     end_time = time.time() | ||||||
|  |     elapsed_time = end_time - start_time | ||||||
|  |     print(f"重加密密钥生成算法运行时间:{elapsed_time}秒") | ||||||
|  |  | ||||||
| for rekey in rekeys: |     # 7 | ||||||
|     cfrag_ct = ReEncrypt(rekey, capsule_ct) |     start_time = time.time() | ||||||
|     cfrag_cts.append(cfrag_ct) |     cfrag_cts = [] | ||||||
| end_time = time.time() |  | ||||||
| elapsed_time = end_time - start_time |  | ||||||
| print(f"代码块7运行时间:{elapsed_time}秒") |  | ||||||
|  |  | ||||||
| # 9 |     for rekey in rekeys: | ||||||
| start_time = time.time() |         cfrag_ct = ReEncrypt(rekey, capsule_ct) | ||||||
| cfrags = mergecfrag(cfrag_cts) |         cfrag_cts.append(cfrag_ct) | ||||||
| m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) |     end_time = time.time() | ||||||
| end_time = time.time() |     elapsed_time = (end_time - start_time) / len(rekeys) | ||||||
| elapsed_time = end_time - start_time |     print(f"重加密算法运行时间:{elapsed_time}秒") | ||||||
| print(f"代码块9运行时间:{elapsed_time}秒") |  | ||||||
| print(m) |     # 9 | ||||||
|  |     start_time = time.time() | ||||||
|  |     cfrags = mergecfrag(cfrag_cts) | ||||||
|  |     m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) | ||||||
|  |     end_time = time.time() | ||||||
|  |     elapsed_time = end_time - start_time | ||||||
|  |     end_total_time = time.time() | ||||||
|  |     total_time = end_total_time - start_total_time | ||||||
|  |     print(f"解密算法运行时间:{elapsed_time}秒") | ||||||
|  |     print("成功解密:", m) | ||||||
|  |     print(f"算法总运行时间:{total_time}秒") | ||||||
|  |     print() | ||||||
|   | |||||||
| @@ -18,7 +18,7 @@ async def lifespan(app: FastAPI): | |||||||
|  |  | ||||||
|  |  | ||||||
| app = FastAPI(lifespan=lifespan) | app = FastAPI(lifespan=lifespan) | ||||||
| server_address = "http://110.41.155.96:8000/server" | server_address = "http://60.204.236.38:8000/server" | ||||||
| id = 0 | id = 0 | ||||||
| ip = "" | ip = "" | ||||||
| client_ip_src = ""  # 发送信息用户的ip | client_ip_src = ""  # 发送信息用户的ip | ||||||
| @@ -59,7 +59,7 @@ def clear(): | |||||||
|  |  | ||||||
|  |  | ||||||
| async def send_heartbeat_internal() -> None: | async def send_heartbeat_internal() -> None: | ||||||
|     timeout = 3 |     timeout = 30 | ||||||
|     global ip |     global ip | ||||||
|     url = server_address + "/heartbeat?ip=" + ip |     url = server_address + "/heartbeat?ip=" + ip | ||||||
|     while True: |     while True: | ||||||
| @@ -120,4 +120,4 @@ async def send_user_des_message(source_ip: str, dest_ip: str, re_message):  #  | |||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     import uvicorn  # pylint: disable=e0401 |     import uvicorn  # pylint: disable=e0401 | ||||||
|  |  | ||||||
|     uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True) |     uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True,log_level="debug") | ||||||
|   | |||||||
| @@ -110,7 +110,7 @@ async def receive_heartbeat(ip: str): | |||||||
|  |  | ||||||
| async def receive_heartbeat_internal(): | async def receive_heartbeat_internal(): | ||||||
|     while 1: |     while 1: | ||||||
|         timeout = 7 |         timeout = 70 | ||||||
|         # 删除超时的节点 |         # 删除超时的节点 | ||||||
|         cursor.execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,)) |         cursor.execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,)) | ||||||
|         conn.commit() |         conn.commit() | ||||||
|   | |||||||
| @@ -1 +0,0 @@ | |||||||
| 123565432 |  | ||||||
		Reference in New Issue
	
	Block a user