mimajingsai/src/server.py

151 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import ipaddress
import sqlite3
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan handler passed to FastAPI.

    Runs `init()` before the application starts serving and `clean_env()`
    when it stops.

    params:
    app: the FastAPI application (not used by the body)
    """
    init()
    try:
        yield
    finally:
        # Run the cleanup even when an exception is thrown into the generator
        # at the yield point; otherwise the database handles would stay open.
        clean_env()
# Connect to the database (the file is created automatically if it is missing).
conn = sqlite3.connect("server.db")
# Module-wide cursor used to execute SQL statements.
cursor = conn.cursor()

# Node table: id (integer primary key), ip (text), last_heartbeat (integer).
_CREATE_NODES_TABLE = """CREATE TABLE IF NOT EXISTS nodes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL,
last_heartbeat INTEGER
)"""
cursor.execute(_CREATE_NODES_TABLE)

# The application object; `lifespan` supplies its startup/shutdown handling.
app = FastAPI(lifespan=lifespan)
# Strong references to background tasks. The event loop only keeps weak
# references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_background_tasks = set()


def init():
    """
    Startup hook: launch the periodic heartbeat-monitor coroutine as a task.

    Must be called while an event loop is running, because
    `asyncio.create_task` schedules the coroutine on the running loop.
    """
    task = asyncio.create_task(receive_heartbeat_internal())
    _background_tasks.add(task)
    # Drop the reference once the task ends so the set does not grow.
    task.add_done_callback(_background_tasks.discard)
def clean_env():
    """
    Shutdown hook: empty the node table, then close the cursor and connection.

    An exception raised by `clear_database()` still propagates to the caller,
    but only after the database handles have been closed.
    """
    try:
        clear_database()
    finally:
        # Close the cursor and the connection even if clearing failed.
        cursor.close()
        conn.close()
@app.get("/server/show_nodes")
async def show_nodes() -> list:
    """
    List every registered node.

    return:
    list: one entry per row of the `nodes` table, each holding all of the
    row's columns
    """
    # Query all rows and hand them back unchanged.
    cursor.execute("SELECT * FROM nodes")
    return cursor.fetchall()
@app.get("/server/get_node")
async def get_node(ip: str) -> int:
    """
    Register a node with the central server and return its id.

    The node sends its IPv4 address. The server stores the address together
    with the current time as the node's first heartbeat, and answers with the
    address converted to an integer, which is also the row's primary key.

    params:
    ip: dotted-quad IPv4 address of the node, e.g. "192.168.1.10"
    return:
    id: the four octets of the address concatenated in binary (first octet
    most significant) and read as a decimal integer
    raises:
    HTTPException (400): `ip` is not a valid IPv4 address
    """
    try:
        # Validates the octet count and range, and yields the same integer as
        # shifting each octet into place by hand.
        ip_int = int(ipaddress.IPv4Address(ip))
    except ipaddress.AddressValueError as err:
        raise HTTPException(
            status_code=400, detail=f"Invalid IPv4 address: {ip!r}"
        ) from err
    # Current time, in whole seconds since the epoch.
    current_time = int(time.time())
    # OR REPLACE: the id is derived from the ip, so a node that registers
    # again would otherwise violate the primary key and fail the request.
    cursor.execute(
        "INSERT OR REPLACE INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)",
        (ip_int, ip, current_time),
    )
    conn.commit()
    return ip_int
@app.get("/server/delete_node")
async def delete_node(ip: str) -> None:
    """
    Remove a node from the registry and print the outcome.

    param:
    ip: IP address of the node to delete
    return:
    None
    """
    # Look the node up first so a missing node can be reported.
    cursor.execute("SELECT * FROM nodes WHERE ip=?", (ip,))
    if cursor.fetchone() is None:
        print(f"Node with IP {ip} not found.")
        return
    # The node exists: delete it and persist the change.
    cursor.execute("DELETE FROM nodes WHERE ip=?", (ip,))
    conn.commit()
    print(f"Node with IP {ip} deleted successfully.")
# Receive a heartbeat packet from a node
@app.get("/server/heartbeat")
async def receive_heartbeat(ip: str):
    """
    Record a heartbeat by updating the node's `last_heartbeat` timestamp.

    params:
    ip: IP address of the node sending the heartbeat
    return:
    dict: always {"status": "received"}, whether or not a row matched `ip`
    """
    cursor.execute(
        "UPDATE nodes SET last_heartbeat = ? WHERE ip = ?",
        # Whole seconds, matching the value written when the node registers.
        (int(time.time()), ip),
    )
    # Commit so the update is persisted and the write transaction is closed.
    conn.commit()
    return {"status": "received"}
async def receive_heartbeat_internal():
    """
    Background loop that wakes up every `timeout` seconds, forever.

    The statement that would delete nodes whose last heartbeat is older than
    `timeout` is commented out, so each pass currently does nothing but sleep.
    """
    timeout = 7
    while True:
        # Delete timed-out nodes (currently disabled):
        # cursor.execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,))
        # conn.commit()
        await asyncio.sleep(timeout)
@app.get("/server/send_nodes_list")
async def send_nodes_list(count: int) -> list:
    """
    Give a client the IP addresses of up to `count` nodes.

    The client sends the number of nodes it needs; the server reads that many
    rows from the database in table order and returns their IPs as a list.

    params:
    count: number of nodes requested
    return:
    nodes_list: list of node IP strings
    """
    # Fetch at most `count` rows and keep only the ip column of each.
    cursor.execute("SELECT * FROM nodes LIMIT ?", (count,))
    return [node_ip for _node_id, node_ip, _last_seen in cursor.fetchall()]
# Route registration is disabled; this is called as a plain function only.
# @app.get("/server/clear_database")
def clear_database() -> None:
    """Delete every row from the `nodes` table and commit the change."""
    delete_all_rows = "DELETE FROM nodes"
    cursor.execute(delete_all_rows)
    conn.commit()
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn  # pylint: disable=e0401

    # Serve the `app` object of module `server` on all interfaces, port 8000,
    # with auto-reload enabled.
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )