diff --git a/src/server.py b/src/server.py index c4b8fa3..0d0b1a6 100644 --- a/src/server.py +++ b/src/server.py @@ -6,33 +6,40 @@ import sqlite3 import asyncio import time + @asynccontextmanager async def lifespan(app: FastAPI): init() yield clean_env() -app = FastAPI(lifespan = lifespan) + +app = FastAPI(lifespan=lifespan) # 连接到数据库(如果数据库不存在,则会自动创建) -conn = sqlite3.connect('server.db') +conn = sqlite3.connect("server.db") # 创建游标对象,用于执行SQL语句 cursor = conn.cursor() # 创建表: id: int; ip: TEXT -cursor.execute('''CREATE TABLE IF NOT EXISTS nodes ( +cursor.execute( + """CREATE TABLE IF NOT EXISTS nodes ( id INTEGER PRIMARY KEY AUTOINCREMENT, ip TEXT NOT NULL, last_heartbeat INTEGER - )''') + )""" +) + def init(): asyncio.create_task(receive_heartbeat_internal()) + def clean_env(): # 关闭游标和连接 cursor.close() conn.close() + @app.get("/server/show_nodes") async def show_nodes() -> list: nodes_list = [] @@ -43,37 +50,42 @@ async def show_nodes() -> list: nodes_list.append(row) return nodes_list + @app.get("/server/get_node") async def get_node(ip: str) -> int: - ''' - 中心服务器与节点交互, 节点发送ip, 中心服务器接收ip存入数据库并将ip转换为int作为节点id返回给节点 - params: - ip: node ip - return: - id: ip按点分割成四部分, 每部分转二进制后拼接再转十进制作为节点id - ''' + """ + 中心服务器与节点交互, 节点发送ip, 中心服务器接收ip存入数据库并将ip转换为int作为节点id返回给节点 + params: + ip: node ip + return: + id: ip按点分割成四部分, 每部分转二进制后拼接再转十进制作为节点id + """ ip_parts = ip.split(".") ip_int = 0 for i in range(4): ip_int += int(ip_parts[i]) << (24 - (8 * i)) - + # 获取当前时间 current_time = int(time.time()) # 插入数据 - cursor.execute("INSERT INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)", (ip_int, ip, current_time)) + cursor.execute( + "INSERT INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)", + (ip_int, ip, current_time), + ) conn.commit() return ip_int + @app.get("/server/delete_node") async def delete_node(ip: str) -> None: - ''' + """ param: ip: 待删除节点的ip地址 return: None - ''' + """ # 查询要删除的节点 cursor.execute("SELECT * FROM nodes WHERE ip=?", (ip,)) row = cursor.fetchone() @@ -85,12 +97,16 @@ async def delete_node(ip: str) -> None: else: print(f"Node with IP {ip} not found.") + # 接收节点心跳包 @app.post("/server/heartbeat") async def receive_heartbeat(ip: str): - cursor.execute("UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)) - return {"status": "received"} - + cursor.execute( + "UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip) + ) + return {"status": "received"} + + async def receive_heartbeat_internal(): while 1: timeout = 70 @@ -99,33 +115,35 @@ async def receive_heartbeat_internal(): # conn.commit() await asyncio.sleep(timeout) + @app.get("/server/send_nodes_list") async def send_nodes_list(count: int) -> list: - ''' + """ 中心服务器与客户端交互, 客户端发送所需节点个数, 中心服务器从数据库中顺序取出节点封装成list格式返回给客户端 - params: - count: 所需节点个数 - return: + params: + count: 所需节点个数 + return: nodes_list: list - ''' + """ nodes_list = [] # 查询数据库中的节点数据 cursor.execute("SELECT * FROM nodes LIMIT ?", (count,)) rows = cursor.fetchall() - for row in rows: id, ip, last_heartbeat = row nodes_list.append(ip) return nodes_list + @app.get("/server/clear_database") async def clear_database() -> None: cursor.execute("DELETE FROM nodes") conn.commit() + if __name__ == "__main__": import uvicorn # pylint: disable=e0401 - uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file + uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)