refactor server

This commit is contained in:
sangge-redmi 2024-09-05 10:32:14 +08:00
parent 53928b7f9e
commit 9654d8504b

View File

@ -1,10 +1,13 @@
from fastapi import FastAPI from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import sqlite3 import sqlite3
import asyncio import asyncio
import time import time
import ipaddress import ipaddress
import logging
logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
@ -39,6 +42,9 @@ def clean_env():
clear_database() clear_database()
# -----------------------------------------------------------------------------------------------
@app.get("/") @app.get("/")
async def home(): async def home():
return {"message": "Hello, World!"} return {"message": "Hello, World!"}
@ -54,10 +60,23 @@ async def show_nodes() -> list:
for row in rows: for row in rows:
nodes_list.append(row) nodes_list.append(row)
# TODO: use JSONResponse
return nodes_list return nodes_list
def validate_ip(ip: str) -> bool: def validate_ip(ip: str) -> bool:
"""
Validate an IP address.
This function checks if the provided string is a valid IP address.
Both IPv4 and IPv6 are considered valid.
Args:
ip (str): The IP address to validate.
Returns:
bool: True if the IP address is valid, False otherwise.
"""
try: try:
ipaddress.ip_address(ip) ipaddress.ip_address(ip)
return True return True
@ -82,10 +101,13 @@ async def get_node(ip: str) -> int:
ip_int = 0 ip_int = 0
for i in range(4): for i in range(4):
ip_int += int(ip_parts[i]) << (24 - (8 * i)) ip_int += int(ip_parts[i]) << (24 - (8 * i))
# TODO: replace print with logger
print("IP", ip, "对应的ID为", ip_int) print("IP", ip, "对应的ID为", ip_int)
# 获取当前时间 # 获取当前时间
current_time = int(time.time()) current_time = int(time.time())
# TODO: replace print with logger
print("当前时间: ", current_time) print("当前时间: ", current_time)
with sqlite3.connect("server.db") as db: with sqlite3.connect("server.db") as db:
@ -96,16 +118,19 @@ async def get_node(ip: str) -> int:
) )
db.commit() db.commit()
# TODO: use JSONResponse
return ip_int return ip_int
# TODO: try to use @app.delete("/node")
@app.get("/server/delete_node") @app.get("/server/delete_node")
async def delete_node(ip: str) -> None: async def delete_node(ip: str):
""" """
param: Delete a node by ip.
ip: 待删除节点的ip地址
return: Args:
None ip (str): The ip of the node to be deleted.
""" """
with sqlite3.connect("server.db") as db: with sqlite3.connect("server.db") as db:
@ -117,28 +142,44 @@ async def delete_node(ip: str) -> None:
# 执行删除操作 # 执行删除操作
db.execute("DELETE FROM nodes WHERE ip=?", (ip,)) db.execute("DELETE FROM nodes WHERE ip=?", (ip,))
db.commit() db.commit()
# TODO: replace print with logger
print(f"Node with IP {ip} deleted successfully.") print(f"Node with IP {ip} deleted successfully.")
return {"message", f"Node with IP {ip} deleted successfully."}
else: else:
print(f"Node with IP {ip} not found.") print(f"Node with IP {ip} not found.")
raise HTTPException(status_code=404, detail=f"Node with IP {ip} not found.")
# 接收节点心跳包 # 接收节点心跳包
@app.get("/server/heartbeat") @app.get("/server/heartbeat")
async def receive_heartbeat(ip: str): async def receive_heartbeat(ip: str):
"""
Receive a heartbeat from a node.
Args:
ip (str): The IP address of the node.
Returns:
JSONResponse: A message indicating the result of the operation.
"""
if not validate_ip(ip): if not validate_ip(ip):
content = {"message": "invalid ip "} content = {"message": "invalid ip format"}
return JSONResponse(content, status_code=400) return JSONResponse(content, status_code=400)
print("收到来自", ip, "的心跳包") print("收到来自", ip, "的心跳包")
logger.info("收到来自", ip, "的心跳包")
with sqlite3.connect("server.db") as db: with sqlite3.connect("server.db") as db:
db.execute( db.execute(
"UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip) "UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)
) )
return {"status": "received"} content = {"status": "received"}
return JSONResponse(content, status_code=200)
async def receive_heartbeat_internal(): async def receive_heartbeat_internal():
timeout = 70
while 1: while 1:
timeout = 70
with sqlite3.connect("server.db") as db: with sqlite3.connect("server.db") as db:
# 删除超时的节点 # 删除超时的节点
db.execute( db.execute(
@ -171,10 +212,10 @@ async def send_nodes_list(count: int) -> list:
print("收到来自客户端的节点列表请求...") print("收到来自客户端的节点列表请求...")
print(nodes_list) print(nodes_list)
# TODO: use JSONResponse
return nodes_list return nodes_list
# @app.get("/server/clear_database")
def clear_database() -> None: def clear_database() -> None:
with sqlite3.connect("server.db") as db: with sqlite3.connect("server.db") as db:
db.execute("DELETE FROM nodes") db.execute("DELETE FROM nodes")
@ -182,6 +223,6 @@ def clear_database() -> None:
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn # pylint: disable=e0401 import uvicorn
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True) uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)