forked from sangge/tpre-python
		
	main #22
| @@ -28,7 +28,6 @@ def init(): | |||||||
|  |  | ||||||
|     # load config from config file |     # load config from config file | ||||||
|     init_config() |     init_config() | ||||||
|  |  | ||||||
|     get_node_list(2, server_address)  # type: ignore |     get_node_list(2, server_address)  # type: ignore | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -100,6 +99,7 @@ class C(BaseModel): | |||||||
|     Tuple: Tuple[capsule, int] |     Tuple: Tuple[capsule, int] | ||||||
|     ip: str |     ip: str | ||||||
|  |  | ||||||
|  |  | ||||||
| # receive messages from nodes | # receive messages from nodes | ||||||
| @app.post("/receive_messages") | @app.post("/receive_messages") | ||||||
| async def receive_messages(message: C): | async def receive_messages(message: C): | ||||||
| @@ -184,6 +184,7 @@ async def send_messages( | |||||||
| ): | ): | ||||||
|     global pk, sk |     global pk, sk | ||||||
|     id_list = [] |     id_list = [] | ||||||
|  |     print(node_ips) | ||||||
|     # calculate id of nodes |     # calculate id of nodes | ||||||
|     for node_ip in node_ips: |     for node_ip in node_ips: | ||||||
|         ip_parts = node_ip.split(".") |         ip_parts = node_ip.split(".") | ||||||
| @@ -246,6 +247,20 @@ async def request_message(i_m: Request_Message): | |||||||
|     except: |     except: | ||||||
|         print("can't post") |         print("can't post") | ||||||
|         return {"message": "can't post"} |         return {"message": "can't post"} | ||||||
|  |     if response.status_code == 200: | ||||||
|  |         data = response.json() | ||||||
|  |         public_key = int(data["public_key"]) | ||||||
|  |         threshold = int(data["threshold"]) | ||||||
|  |         with sqlite3.connect("client.db") as db: | ||||||
|  |             db.execute( | ||||||
|  |                 """ | ||||||
|  |         INSERT INTO senderinfo | ||||||
|  |         (public_key, threshold) | ||||||
|  |         VALUES | ||||||
|  |         (?, ?) | ||||||
|  |         """, | ||||||
|  |                 (public_key, threshold), | ||||||
|  |             ) | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         if response.status_code == 200: |         if response.status_code == 200: | ||||||
| @@ -290,7 +305,8 @@ async def recieve_request(i_m: IP_Message): | |||||||
|     if source_ip != i_m.dest_ip: |     if source_ip != i_m.dest_ip: | ||||||
|         return HTTPException(status_code=400, detail="Wrong ip") |         return HTTPException(status_code=400, detail="Wrong ip") | ||||||
|     dest_ip = i_m.source_ip |     dest_ip = i_m.source_ip | ||||||
|     threshold = random.randrange(1, 2) |     # threshold = random.randrange(1, 2) | ||||||
|  |     threshold = 2 | ||||||
|     own_public_key = pk |     own_public_key = pk | ||||||
|     pk_B = i_m.pk |     pk_B = i_m.pk | ||||||
|  |  | ||||||
| @@ -315,7 +331,6 @@ async def recieve_request(i_m: IP_Message): | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_own_ip() -> str: | def get_own_ip() -> str: | ||||||
|      |  | ||||||
|     ip = os.environ.get("HOST_IP", "IP not set") |     ip = os.environ.get("HOST_IP", "IP not set") | ||||||
|     return ip |     return ip | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,34 +6,41 @@ import sqlite3 | |||||||
| import asyncio | import asyncio | ||||||
| import time | import time | ||||||
|  |  | ||||||
|  |  | ||||||
| @asynccontextmanager | @asynccontextmanager | ||||||
| async def lifespan(app: FastAPI): | async def lifespan(app: FastAPI): | ||||||
|     init() |     init() | ||||||
|     yield |     yield | ||||||
|     clean_env() |     clean_env() | ||||||
|  |  | ||||||
|  |  | ||||||
| app = FastAPI(lifespan=lifespan) | app = FastAPI(lifespan=lifespan) | ||||||
|  |  | ||||||
| # 连接到数据库(如果数据库不存在,则会自动创建) | # 连接到数据库(如果数据库不存在,则会自动创建) | ||||||
| conn = sqlite3.connect('server.db') | conn = sqlite3.connect("server.db") | ||||||
| # 创建游标对象,用于执行SQL语句 | # 创建游标对象,用于执行SQL语句 | ||||||
| cursor = conn.cursor() | cursor = conn.cursor() | ||||||
| # 创建表: id: int; ip: TEXT | # 创建表: id: int; ip: TEXT | ||||||
| cursor.execute('''CREATE TABLE IF NOT EXISTS nodes ( | cursor.execute( | ||||||
|  |     """CREATE TABLE IF NOT EXISTS nodes ( | ||||||
|                    id INTEGER PRIMARY KEY AUTOINCREMENT, |                    id INTEGER PRIMARY KEY AUTOINCREMENT, | ||||||
|                    ip TEXT NOT NULL, |                    ip TEXT NOT NULL, | ||||||
|                    last_heartbeat INTEGER |                    last_heartbeat INTEGER | ||||||
|                )''') |                )""" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def init(): | def init(): | ||||||
|     asyncio.create_task(receive_heartbeat_internal()) |     asyncio.create_task(receive_heartbeat_internal()) | ||||||
|  |  | ||||||
|  |  | ||||||
| def clean_env(): | def clean_env(): | ||||||
|     clear_database() |     clear_database() | ||||||
|     # 关闭游标和连接 |     # 关闭游标和连接 | ||||||
|     cursor.close() |     cursor.close() | ||||||
|     conn.close() |     conn.close() | ||||||
|  |  | ||||||
|  |  | ||||||
| @app.get("/server/show_nodes") | @app.get("/server/show_nodes") | ||||||
| async def show_nodes() -> list: | async def show_nodes() -> list: | ||||||
|     nodes_list = [] |     nodes_list = [] | ||||||
| @@ -44,15 +51,16 @@ async def show_nodes() -> list: | |||||||
|         nodes_list.append(row) |         nodes_list.append(row) | ||||||
|     return nodes_list |     return nodes_list | ||||||
|  |  | ||||||
|  |  | ||||||
| @app.get("/server/get_node") | @app.get("/server/get_node") | ||||||
| async def get_node(ip: str) -> int: | async def get_node(ip: str) -> int: | ||||||
|     ''' |     """ | ||||||
|     中心服务器与节点交互, 节点发送ip, 中心服务器接收ip存入数据库并将ip转换为int作为节点id返回给节点 |     中心服务器与节点交互, 节点发送ip, 中心服务器接收ip存入数据库并将ip转换为int作为节点id返回给节点 | ||||||
|     params: |     params: | ||||||
|     ip: node ip |     ip: node ip | ||||||
|     return: |     return: | ||||||
|     id: ip按点分割成四部分, 每部分转二进制后拼接再转十进制作为节点id |     id: ip按点分割成四部分, 每部分转二进制后拼接再转十进制作为节点id | ||||||
|     ''' |     """ | ||||||
|     ip_parts = ip.split(".") |     ip_parts = ip.split(".") | ||||||
|     ip_int = 0 |     ip_int = 0 | ||||||
|     for i in range(4): |     for i in range(4): | ||||||
| @@ -62,19 +70,23 @@ async def get_node(ip: str) -> int: | |||||||
|     current_time = int(time.time()) |     current_time = int(time.time()) | ||||||
|  |  | ||||||
|     # 插入数据 |     # 插入数据 | ||||||
|     cursor.execute("INSERT INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)", (ip_int, ip, current_time)) |     cursor.execute( | ||||||
|  |         "INSERT INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)", | ||||||
|  |         (ip_int, ip, current_time), | ||||||
|  |     ) | ||||||
|     conn.commit() |     conn.commit() | ||||||
|  |  | ||||||
|     return ip_int |     return ip_int | ||||||
|  |  | ||||||
|  |  | ||||||
| @app.get("/server/delete_node") | @app.get("/server/delete_node") | ||||||
| async def delete_node(ip: str) -> None: | async def delete_node(ip: str) -> None: | ||||||
|     ''' |     """ | ||||||
|     param: |     param: | ||||||
|     ip: 待删除节点的ip地址 |     ip: 待删除节点的ip地址 | ||||||
|     return: |     return: | ||||||
|     None |     None | ||||||
|     ''' |     """ | ||||||
|     # 查询要删除的节点 |     # 查询要删除的节点 | ||||||
|     cursor.execute("SELECT * FROM nodes WHERE ip=?", (ip,)) |     cursor.execute("SELECT * FROM nodes WHERE ip=?", (ip,)) | ||||||
|     row = cursor.fetchone() |     row = cursor.fetchone() | ||||||
| @@ -86,12 +98,16 @@ async def delete_node(ip: str) -> None: | |||||||
|     else: |     else: | ||||||
|         print(f"Node with IP {ip} not found.") |         print(f"Node with IP {ip} not found.") | ||||||
|  |  | ||||||
|  |  | ||||||
| # 接收节点心跳包 | # 接收节点心跳包 | ||||||
| @app.get("/server/heartbeat") | @app.get("/server/heartbeat") | ||||||
| async def receive_heartbeat(ip: str): | async def receive_heartbeat(ip: str): | ||||||
|         cursor.execute("UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)) |     cursor.execute( | ||||||
|  |         "UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip) | ||||||
|  |     ) | ||||||
|     return {"status": "received"} |     return {"status": "received"} | ||||||
|  |  | ||||||
|  |  | ||||||
| async def receive_heartbeat_internal(): | async def receive_heartbeat_internal(): | ||||||
|     while 1: |     while 1: | ||||||
|         timeout = 7 |         timeout = 7 | ||||||
| @@ -100,32 +116,34 @@ async def receive_heartbeat_internal(): | |||||||
|         conn.commit() |         conn.commit() | ||||||
|         await asyncio.sleep(timeout) |         await asyncio.sleep(timeout) | ||||||
|  |  | ||||||
|  |  | ||||||
| @app.get("/server/send_nodes_list") | @app.get("/server/send_nodes_list") | ||||||
| async def send_nodes_list(count: int) -> list: | async def send_nodes_list(count: int) -> list: | ||||||
|     ''' |     """ | ||||||
|     中心服务器与客户端交互, 客户端发送所需节点个数, 中心服务器从数据库中顺序取出节点封装成list格式返回给客户端 |     中心服务器与客户端交互, 客户端发送所需节点个数, 中心服务器从数据库中顺序取出节点封装成list格式返回给客户端 | ||||||
|     params: |     params: | ||||||
|     count: 所需节点个数 |     count: 所需节点个数 | ||||||
|     return: |     return: | ||||||
|     nodes_list: list |     nodes_list: list | ||||||
|     ''' |     """ | ||||||
|     nodes_list = [] |     nodes_list = [] | ||||||
|  |  | ||||||
|     # 查询数据库中的节点数据 |     # 查询数据库中的节点数据 | ||||||
|     cursor.execute("SELECT * FROM nodes LIMIT ?", (count,)) |     cursor.execute("SELECT * FROM nodes LIMIT ?", (count,)) | ||||||
|     rows = cursor.fetchall() |     rows = cursor.fetchall() | ||||||
|  |  | ||||||
|     for row in rows: |     for row in rows: | ||||||
|         id, ip, last_heartbeat = row |         id, ip, last_heartbeat = row | ||||||
|         nodes_list.append(ip) |         nodes_list.append(ip) | ||||||
|  |  | ||||||
|     return nodes_list |     return nodes_list | ||||||
|  |  | ||||||
|  |  | ||||||
| # @app.get("/server/clear_database") | # @app.get("/server/clear_database") | ||||||
| def clear_database() -> None: | def clear_database() -> None: | ||||||
|     cursor.execute("DELETE FROM nodes") |     cursor.execute("DELETE FROM nodes") | ||||||
|     conn.commit() |     conn.commit() | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     import uvicorn  # pylint: disable=e0401 |     import uvicorn  # pylint: disable=e0401 | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user