forked from sangge/tpre-python
		
	Merge branch 'main' of https://git.mamahaha.work/dqy/mimajingsai
This commit is contained in:
		
							
								
								
									
										11
									
								
								README_en.md
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								README_en.md
									
									
									
									
									
								
							| @@ -21,6 +21,7 @@ The project uses the Chinese national standard cryptography algorithm to impleme | |||||||
|  |  | ||||||
| ## Environment Dependencies | ## Environment Dependencies | ||||||
|  |  | ||||||
|  | ### Bare mental version(UNTESTED) | ||||||
| System requirements:   | System requirements:   | ||||||
| - Linux | - Linux | ||||||
| - Windows(may need to complie and install gmssl yourself) | - Windows(may need to complie and install gmssl yourself) | ||||||
| @@ -30,22 +31,24 @@ The project relies on the following software: | |||||||
| - gmssl | - gmssl | ||||||
| - gmssl-python | - gmssl-python | ||||||
|  |  | ||||||
|  | ### Docker version | ||||||
|  | docker version:   | ||||||
|  | - Version:           24.0.5   | ||||||
|  | - API version:       1.43   | ||||||
|  | - Go version:        go1.20.6   | ||||||
| ## Installation Steps | ## Installation Steps | ||||||
|  |  | ||||||
| ### Pre-installation | ### Pre-installation | ||||||
| This project depends on gmssl, so you need to compile it from source first.   | This project depends on gmssl, so you need to compile it from source first.   | ||||||
| Visit [GmSSL](https://github.com/guanzhi/GmSSL) to learn how to install.   | Visit [GmSSL](https://github.com/guanzhi/GmSSL) to learn how to install.   | ||||||
|  |  | ||||||
|  | Then install essential python libs   | ||||||
| ```bash | ```bash | ||||||
| pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple | pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
|  |  | ||||||
| ## Docker Installation | ## Docker Installation | ||||||
| my docker version: |  | ||||||
| - Version:           24.0.5 |  | ||||||
| - API version:       1.43 |  | ||||||
| - Go version:        go1.20.6 |  | ||||||
|  |  | ||||||
| ### Use base image and build yourself | ### Use base image and build yourself | ||||||
| ```bash | ```bash | ||||||
|   | |||||||
| @@ -25,6 +25,8 @@ def init(): | |||||||
|     global pk, sk, server_address |     global pk, sk, server_address | ||||||
|     init_db() |     init_db() | ||||||
|     pk, sk = GenerateKeyPair() |     pk, sk = GenerateKeyPair() | ||||||
|  |  | ||||||
|  |     # load config from config file | ||||||
|     init_config() |     init_config() | ||||||
|     get_node_list(2, server_address)  # type: ignore |     get_node_list(2, server_address)  # type: ignore | ||||||
|  |  | ||||||
| @@ -94,14 +96,13 @@ class C(BaseModel): | |||||||
|     Tuple: Tuple[capsule, int] |     Tuple: Tuple[capsule, int] | ||||||
|     ip: str |     ip: str | ||||||
|  |  | ||||||
|  | # receive messages from nodes | ||||||
| # receive messages from node |  | ||||||
| @app.post("/receive_messages") | @app.post("/receive_messages") | ||||||
| async def receive_messages(message: C): | async def receive_messages(message: C): | ||||||
|     """ |     """ | ||||||
|     receive capsule and ip from nodes |     receive capsule and ip from nodes | ||||||
|     params: |     params: | ||||||
|     C: capsule and ct |     Tuple: capsule and ct | ||||||
|     ip: sender ip |     ip: sender ip | ||||||
|     return: |     return: | ||||||
|     status_code |     status_code | ||||||
| @@ -131,7 +132,7 @@ async def receive_messages(message: C): | |||||||
|                 (C_capsule, C_ct, ip), |                 (C_capsule, C_ct, ip), | ||||||
|             ) |             ) | ||||||
|             db.commit() |             db.commit() | ||||||
|             await check_merge(db, C_ct, ip) |             await check_merge(C_ct, ip) | ||||||
|             return HTTPException(status_code=200, detail="Message received") |             return HTTPException(status_code=200, detail="Message received") | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             print(f"Error occurred: {e}") |             print(f"Error occurred: {e}") | ||||||
| @@ -140,8 +141,9 @@ async def receive_messages(message: C): | |||||||
|  |  | ||||||
|  |  | ||||||
| # check record count | # check record count | ||||||
| async def check_merge(db, ct: int, ip: str): | async def check_merge(ct: int, ip: str): | ||||||
|     global sk, pk, node_response, message |     global sk, pk, node_response, message | ||||||
|  |     with sqlite3.connect("client.db") as db: | ||||||
|     # Check if the combination of ct_column and ip_column appears more than once. |     # Check if the combination of ct_column and ip_column appears more than once. | ||||||
|         cursor = db.execute( |         cursor = db.execute( | ||||||
|             """ |             """ | ||||||
| @@ -154,7 +156,7 @@ async def check_merge(db, ct: int, ip: str): | |||||||
|         # [(capsule, ct), ...] |         # [(capsule, ct), ...] | ||||||
|         cfrag_cts = cursor.fetchall() |         cfrag_cts = cursor.fetchall() | ||||||
|  |  | ||||||
|     # get N |         # get T | ||||||
|         cursor = db.execute( |         cursor = db.execute( | ||||||
|             """ |             """ | ||||||
|         SELECT publickey, threshold  |         SELECT publickey, threshold  | ||||||
| @@ -164,7 +166,8 @@ async def check_merge(db, ct: int, ip: str): | |||||||
|             (ip), |             (ip), | ||||||
|         ) |         ) | ||||||
|         result = cursor.fetchall() |         result = cursor.fetchall() | ||||||
|     pk_sender, T = result[0] |         pk_sender, T = result[0] # result[0] = (pk, threshold) | ||||||
|  |          | ||||||
|     if len(cfrag_cts) >= T: |     if len(cfrag_cts) >= T: | ||||||
|         cfrags = mergecfrag(cfrag_cts) |         cfrags = mergecfrag(cfrag_cts) | ||||||
|         message = DecryptFrags(sk, pk, pk_sender, cfrags)  # type: ignore |         message = DecryptFrags(sk, pk, pk_sender, cfrags)  # type: ignore | ||||||
| @@ -177,22 +180,31 @@ async def send_messages( | |||||||
| ): | ): | ||||||
|     global pk, sk |     global pk, sk | ||||||
|     id_list = [] |     id_list = [] | ||||||
|  |     # calculate id of nodes | ||||||
|     for node_ip in node_ips: |     for node_ip in node_ips: | ||||||
|         ip_parts = node_ip.split(".") |         ip_parts = node_ip.split(".") | ||||||
|         id = 0 |         id = 0 | ||||||
|         for i in range(4): |         for i in range(4): | ||||||
|             id += int(ip_parts[i]) << (24 - (8 * i)) |             id += int(ip_parts[i]) << (24 - (8 * i)) | ||||||
|         id_list.append(id) |         id_list.append(id) | ||||||
|  |      | ||||||
|  |     # generate rk | ||||||
|     rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list))  # type: ignore |     rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list))  # type: ignore | ||||||
|  |      | ||||||
|  |     capsule_ct = Encrypt(pk, message)  # type: ignore | ||||||
|  |  | ||||||
|     for i in range(len(node_ips)): |     for i in range(len(node_ips)): | ||||||
|         url = "http://" + node_ips[i] + ":8001" + "/recieve_message" |         url = "http://" + node_ips[i] + ":8001" + "/user_src?message" | ||||||
|  |  | ||||||
|         payload = { |         payload = { | ||||||
|             "source_ip": local_ip, |             "source_ip": local_ip, | ||||||
|             "dest_ip": dest_ip, |             "dest_ip": dest_ip, | ||||||
|             "message": message, |             "capsule_ct": capsule_ct, | ||||||
|             "rk": rk_list[i], |             "rk": rk_list[i], | ||||||
|         } |         } | ||||||
|         response = requests.post(url, json=payload) |         response = requests.post(url, json=payload) | ||||||
|  |         if response.status_code == 200: | ||||||
|  |             print(f"send to {node_ips[i]} successful") | ||||||
|     return 0 |     return 0 | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -203,22 +215,32 @@ class IP_Message(BaseModel): | |||||||
|     pk: int |     pk: int | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Request_Message(BaseModel): | ||||||
|  |     dest_ip: str | ||||||
|  |     message_name: str | ||||||
|  |  | ||||||
|  |  | ||||||
| # request message from others | # request message from others | ||||||
| @app.post("/request_message") | @app.post("/request_message") | ||||||
| async def request_message(i_m: IP_Message): | async def request_message(i_m: Request_Message): | ||||||
|     global message, node_response, pk |     global message, node_response, pk | ||||||
|     dest_ip = i_m.dest_ip |     dest_ip = i_m.dest_ip | ||||||
|  |     # dest_ip = dest_ip.split(":")[0] | ||||||
|     message_name = i_m.message_name |     message_name = i_m.message_name | ||||||
|     source_ip = get_own_ip() |     source_ip = get_own_ip() | ||||||
|     dest_port = "8003" |     dest_port = "8003" | ||||||
|     url = "http://" + dest_ip + dest_port + "/recieve_request" |     url = "http://" + dest_ip + ":" + dest_port + "/recieve_request?i_m" | ||||||
|     payload = { |     payload = { | ||||||
|         "dest_ip": dest_ip, |         "dest_ip": dest_ip, | ||||||
|         "message_name": message_name, |         "message_name": message_name, | ||||||
|         "source_ip": source_ip, |         "source_ip": source_ip, | ||||||
|         "pk": pk, |         "pk": pk, | ||||||
|     } |     } | ||||||
|  |     try: | ||||||
|         response = requests.post(url, json=payload) |         response = requests.post(url, json=payload) | ||||||
|  |     except: | ||||||
|  |         print("can't post") | ||||||
|  |         return {"message": "can't post"} | ||||||
|     if response.status_code == 200: |     if response.status_code == 200: | ||||||
|         data = response.json() |         data = response.json() | ||||||
|         public_key = int(data["public_key"]) |         public_key = int(data["public_key"]) | ||||||
| @@ -234,11 +256,37 @@ async def request_message(i_m: IP_Message): | |||||||
|                 (public_key, threshold), |                 (public_key, threshold), | ||||||
|             ) |             ) | ||||||
|      |      | ||||||
|     # wait to recieve message from nodes |  | ||||||
|  |      | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         if response.status_code == 200: | ||||||
|  |             data = response.json() | ||||||
|  |             public_key = int(data["public_key"]) | ||||||
|  |             threshold = int(data["threshold"]) | ||||||
|  |             with sqlite3.connect("client.db") as db: | ||||||
|  |                 db.execute( | ||||||
|  |                     """ | ||||||
|  |             INSERT INTO senderinfo | ||||||
|  |             (public_key, threshold) | ||||||
|  |             VALUES | ||||||
|  |             (?, ?) | ||||||
|  |             """, | ||||||
|  |                     (public_key, threshold), | ||||||
|  |                 ) | ||||||
|  |     except: | ||||||
|  |         print("Database error") | ||||||
|  |         return {"message": "Database Error"} | ||||||
|  |  | ||||||
|  |     # wait 10s to recieve message from nodes | ||||||
|     for _ in range(10): |     for _ in range(10): | ||||||
|         if node_response: |         if node_response: | ||||||
|             data = message |             data = message | ||||||
|  |              | ||||||
|  |             # reset message and node_response | ||||||
|             message = b"" |             message = b"" | ||||||
|  |             node_response = False | ||||||
|  |  | ||||||
|             # return message to frontend |             # return message to frontend | ||||||
|             return {"message": data} |             return {"message": data} | ||||||
|         time.sleep(1) |         time.sleep(1) | ||||||
| @@ -267,7 +315,11 @@ async def recieve_request(i_m: IP_Message): | |||||||
|             (threshold,), |             (threshold,), | ||||||
|         ) |         ) | ||||||
|         node_ips = cursor.fetchall() |         node_ips = cursor.fetchall() | ||||||
|  |          | ||||||
|  |     # message name | ||||||
|     message = b"hello world" + random.randbytes(8) |     message = b"hello world" + random.randbytes(8) | ||||||
|  |      | ||||||
|  |     # send message to nodes | ||||||
|     await send_messages(node_ips, message, dest_ip, pk_B, threshold)  # type: ignore |     await send_messages(node_ips, message, dest_ip, pk_B, threshold)  # type: ignore | ||||||
|     response = {"threshold": threshold, "public_key": own_public_key} |     response = {"threshold": threshold, "public_key": own_public_key} | ||||||
|     return response |     return response | ||||||
| @@ -282,8 +334,6 @@ def get_own_ip() -> str: | |||||||
| # get node list from central server | # get node list from central server | ||||||
| def get_node_list(count: int, server_addr: str): | def get_node_list(count: int, server_addr: str): | ||||||
|     url = "http://" + server_addr + "/server/send_nodes_list?count=" + str(count) |     url = "http://" + server_addr + "/server/send_nodes_list?count=" + str(count) | ||||||
|     # payload = {"count": count} |  | ||||||
|     # response = requests.post(url, json=payload) |  | ||||||
|     response = requests.get(url) |     response = requests.get(url) | ||||||
|     # Checking the response |     # Checking the response | ||||||
|     if response.status_code == 200: |     if response.status_code == 200: | ||||||
|   | |||||||
							
								
								
									
										23
									
								
								src/client_cli.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								src/client_cli.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | |||||||
|  | import argparse | ||||||
|  | import requests | ||||||
|  |  | ||||||
|  | def send_post_request(ip_addr, message_name): | ||||||
|  |     url = f"http://localhost:20234/request_message/?i_m" | ||||||
|  |     data = { | ||||||
|  |         "dest_ip": ip_addr, | ||||||
|  |         "message_name": message_name | ||||||
|  |     } | ||||||
|  |     response = requests.post(url, json=data) | ||||||
|  |     return response.text | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     parser = argparse.ArgumentParser(description="Send POST request to a specified IP.") | ||||||
|  |     parser.add_argument("ip_addr", help="IP address to send request to.") | ||||||
|  |     parser.add_argument("message_name", help="Message name to send.") | ||||||
|  |  | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     response = send_post_request(args.ip_addr, args.message_name) | ||||||
|  |     print(response) | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main() | ||||||
| @@ -24,7 +24,8 @@ T = 5 | |||||||
|  |  | ||||||
| # 5 | # 5 | ||||||
| start_time = time.time() | start_time = time.time() | ||||||
| rekeys = GenerateReKey(sk_a, pk_b, N, T) | id_tuple = tuple(range(N)) | ||||||
|  | rekeys = GenerateReKey(sk_a, pk_b, N, T, id_tuple) | ||||||
| end_time = time.time() | end_time = time.time() | ||||||
| elapsed_time = end_time - start_time | elapsed_time = end_time - start_time | ||||||
| print(f"代码块5运行时间:{elapsed_time}秒") | print(f"代码块5运行时间:{elapsed_time}秒") | ||||||
|   | |||||||
							
								
								
									
										35
									
								
								src/node.py
									
									
									
									
									
								
							
							
						
						
									
										35
									
								
								src/node.py
									
									
									
									
									
								
							| @@ -1,4 +1,4 @@ | |||||||
| from fastapi import FastAPI, Request | from fastapi import FastAPI, Request, HTTPException | ||||||
| import requests | import requests | ||||||
| from contextlib import asynccontextmanager | from contextlib import asynccontextmanager | ||||||
| import socket | import socket | ||||||
| @@ -77,24 +77,39 @@ async def send_heartbeat_internal() -> None: | |||||||
|  |  | ||||||
| @app.post("/user_src")  # 接收用户1发送的信息 | @app.post("/user_src")  # 接收用户1发送的信息 | ||||||
| async def receive_user_src_message(message: Request): | async def receive_user_src_message(message: Request): | ||||||
|     json_data = await message.json() |  | ||||||
|     global client_ip_src, client_ip_des |     global client_ip_src, client_ip_des | ||||||
|     # kfrag , capsule_ct ,client_ip_src , client_ip_des   = json_data[]  # 看梁俊勇 |     # kfrag , capsule_ct ,client_ip_src , client_ip_des   = json_data[]  # 看梁俊勇 | ||||||
|     global processed_message |     """ | ||||||
|     processed_message = ReEncrypt(kfrag, capsule_ct) |     payload = { | ||||||
|  |             "source_ip": local_ip, | ||||||
|  |             "dest_ip": dest_ip, | ||||||
|  |             "capsule_ct": capsule_ct, | ||||||
|  |             "rk": rk_list[i], | ||||||
|  |         } | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     data = await message.json() | ||||||
|  |     source_ip = data.get("source_ip") | ||||||
|  |     dest_ip = data.get("dest_ip") | ||||||
|  |     capsule_ct = data.get("capsule_ct") | ||||||
|  |     rk = data.get("rk") | ||||||
|  |  | ||||||
|  |     processed_message = ReEncrypt(rk, capsule_ct) | ||||||
|  |     await send_user_des_message(source_ip, dest_ip, processed_message) | ||||||
|  |     return HTTPException(status_code=200, detail="message recieved") | ||||||
|  |  | ||||||
|  |  | ||||||
| def send_user_des_message():  # 发送消息给用户2 | async def send_user_des_message(source_ip: str, dest_ip: str, re_message):  # 发送消息给用户2 | ||||||
|     global processed_message, client_ip_src, client_ip_des |     data = {"Tuple": re_message, "ip": source_ip}  # 类型不匹配 | ||||||
|  |  | ||||||
|     data = {"Tuple": processed_message, "ip": client_ip_src}  # 类型不匹配 |  | ||||||
|  |  | ||||||
|     # 发送 HTTP POST 请求 |     # 发送 HTTP POST 请求 | ||||||
|     response = requests.post("http://" + client_ip_des + "/receive_messages", json=data) |     response = requests.post( | ||||||
|  |         "http://" + dest_ip + "/receive_messages?message", json=data | ||||||
|  |     ) | ||||||
|     print(response) |     print(response) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     import uvicorn  # pylint: disable=e0401 |     import uvicorn  # pylint: disable=e0401 | ||||||
|  |  | ||||||
|     uvicorn.run("node:app", host="0.0.0.0", port=8001, reload=True) |     uvicorn.run("node:app", host="0.0.0.0", port=8000, reload=True) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user