merge into main #36
@@ -1,10 +1,13 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -39,6 +42,9 @@ def clean_env():
|
|||||||
clear_database()
|
clear_database()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def home():
|
async def home():
|
||||||
return {"message": "Hello, World!"}
|
return {"message": "Hello, World!"}
|
||||||
@@ -54,10 +60,23 @@ async def show_nodes() -> list:
|
|||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
nodes_list.append(row)
|
nodes_list.append(row)
|
||||||
|
# TODO: use JSONResponse
|
||||||
return nodes_list
|
return nodes_list
|
||||||
|
|
||||||
|
|
||||||
def validate_ip(ip: str) -> bool:
|
def validate_ip(ip: str) -> bool:
|
||||||
|
"""
|
||||||
|
Validate an IP address.
|
||||||
|
|
||||||
|
This function checks if the provided string is a valid IP address.
|
||||||
|
Both IPv4 and IPv6 are considered valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip (str): The IP address to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the IP address is valid, False otherwise.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
ipaddress.ip_address(ip)
|
ipaddress.ip_address(ip)
|
||||||
return True
|
return True
|
||||||
@@ -82,10 +101,13 @@ async def get_node(ip: str) -> int:
|
|||||||
ip_int = 0
|
ip_int = 0
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
ip_int += int(ip_parts[i]) << (24 - (8 * i))
|
ip_int += int(ip_parts[i]) << (24 - (8 * i))
|
||||||
|
|
||||||
|
# TODO: replace print with logger
|
||||||
print("IP", ip, "对应的ID为", ip_int)
|
print("IP", ip, "对应的ID为", ip_int)
|
||||||
|
|
||||||
# 获取当前时间
|
# 获取当前时间
|
||||||
current_time = int(time.time())
|
current_time = int(time.time())
|
||||||
|
# TODO: replace print with logger
|
||||||
print("当前时间: ", current_time)
|
print("当前时间: ", current_time)
|
||||||
|
|
||||||
with sqlite3.connect("server.db") as db:
|
with sqlite3.connect("server.db") as db:
|
||||||
@@ -96,16 +118,19 @@ async def get_node(ip: str) -> int:
|
|||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
# TODO: use JSONResponse
|
||||||
return ip_int
|
return ip_int
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: try to use @app.delete("/node")
|
||||||
@app.get("/server/delete_node")
|
@app.get("/server/delete_node")
|
||||||
async def delete_node(ip: str) -> None:
|
async def delete_node(ip: str):
|
||||||
"""
|
"""
|
||||||
param:
|
Delete a node by ip.
|
||||||
ip: 待删除节点的ip地址
|
|
||||||
return:
|
Args:
|
||||||
None
|
ip (str): The ip of the node to be deleted.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with sqlite3.connect("server.db") as db:
|
with sqlite3.connect("server.db") as db:
|
||||||
@@ -117,28 +142,44 @@ async def delete_node(ip: str) -> None:
|
|||||||
# 执行删除操作
|
# 执行删除操作
|
||||||
db.execute("DELETE FROM nodes WHERE ip=?", (ip,))
|
db.execute("DELETE FROM nodes WHERE ip=?", (ip,))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
# TODO: replace print with logger
|
||||||
print(f"Node with IP {ip} deleted successfully.")
|
print(f"Node with IP {ip} deleted successfully.")
|
||||||
|
return {"message", f"Node with IP {ip} deleted successfully."}
|
||||||
else:
|
else:
|
||||||
print(f"Node with IP {ip} not found.")
|
print(f"Node with IP {ip} not found.")
|
||||||
|
raise HTTPException(status_code=404, detail=f"Node with IP {ip} not found.")
|
||||||
|
|
||||||
|
|
||||||
# 接收节点心跳包
|
# 接收节点心跳包
|
||||||
@app.get("/server/heartbeat")
|
@app.get("/server/heartbeat")
|
||||||
async def receive_heartbeat(ip: str):
|
async def receive_heartbeat(ip: str):
|
||||||
|
"""
|
||||||
|
Receive a heartbeat from a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip (str): The IP address of the node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSONResponse: A message indicating the result of the operation.
|
||||||
|
"""
|
||||||
if not validate_ip(ip):
|
if not validate_ip(ip):
|
||||||
content = {"message": "invalid ip "}
|
content = {"message": "invalid ip format"}
|
||||||
return JSONResponse(content, status_code=400)
|
return JSONResponse(content, status_code=400)
|
||||||
print("收到来自", ip, "的心跳包")
|
print("收到来自", ip, "的心跳包")
|
||||||
|
logger.info("收到来自", ip, "的心跳包")
|
||||||
|
|
||||||
with sqlite3.connect("server.db") as db:
|
with sqlite3.connect("server.db") as db:
|
||||||
db.execute(
|
db.execute(
|
||||||
"UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)
|
"UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip)
|
||||||
)
|
)
|
||||||
return {"status": "received"}
|
content = {"status": "received"}
|
||||||
|
return JSONResponse(content, status_code=200)
|
||||||
|
|
||||||
|
|
||||||
async def receive_heartbeat_internal():
|
async def receive_heartbeat_internal():
|
||||||
while 1:
|
|
||||||
timeout = 70
|
timeout = 70
|
||||||
|
while 1:
|
||||||
with sqlite3.connect("server.db") as db:
|
with sqlite3.connect("server.db") as db:
|
||||||
# 删除超时的节点
|
# 删除超时的节点
|
||||||
db.execute(
|
db.execute(
|
||||||
@@ -171,10 +212,10 @@ async def send_nodes_list(count: int) -> list:
|
|||||||
|
|
||||||
print("收到来自客户端的节点列表请求...")
|
print("收到来自客户端的节点列表请求...")
|
||||||
print(nodes_list)
|
print(nodes_list)
|
||||||
|
# TODO: use JSONResponse
|
||||||
return nodes_list
|
return nodes_list
|
||||||
|
|
||||||
|
|
||||||
# @app.get("/server/clear_database")
|
|
||||||
def clear_database() -> None:
|
def clear_database() -> None:
|
||||||
with sqlite3.connect("server.db") as db:
|
with sqlite3.connect("server.db") as db:
|
||||||
db.execute("DELETE FROM nodes")
|
db.execute("DELETE FROM nodes")
|
||||||
@@ -182,6 +223,6 @@ def clear_database() -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn # pylint: disable=e0401
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
|
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)
|
||||||
|
Reference in New Issue
Block a user