Merge pull request 'main' (#19) from sangge/mimajingsai:main into main

Reviewed-on: ccyj/mimajingsai#19
This commit is contained in:
ccyj 2023-10-21 20:40:13 +08:00
commit 1aac74592f
6 changed files with 441 additions and 25 deletions

2
.gitignore vendored
View File

@ -8,3 +8,5 @@ example.py
ReEncrypt.py ReEncrypt.py
src/temp_message_file src/temp_message_file
src/temp_key_file src/temp_key_file
src/client.db
src/server.db

3
src/client.ini Normal file
View File

@ -0,0 +1,3 @@
[settings]
server_address = "127.0.0.1:8000"
version = 1.0

View File

@ -0,0 +1,275 @@
from fastapi import FastAPI, HTTPException
import requests
import os
from typing import Tuple
from tpre import *
import sqlite3
from contextlib import asynccontextmanager
from pydantic import BaseModel
import socket
import random
import time
@asynccontextmanager
async def lifespan(app: FastAPI):
init()
yield
clean_env()
app = FastAPI(lifespan=lifespan)
pk = point
sk = int
server_address = str
node_response = False
message = bytes
def init():
global pk, sk, server_address
init_db()
pk, sk = GenerateKeyPair()
init_config()
# get_node_list(6, server_address) # type: ignore
def init_db():
with sqlite3.connect("client.db") as db:
# message table
db.execute(
"""
CREATE TABLE IF NOT EXISTS message (
id INTEGER PRIMARY KEY,
capsule TEXT,
ct TEXT,
senderip TEXT
);
"""
)
# node ip table
db.execute(
"""
CREATE TABLE IF NOT EXISTS node (
id INTEGER PRIMARY KEY,
nodeip TEXT
);
"""
)
# sender info table
db.execute(
"""
CREATE TABLE IF NOT EXISTS senderinfo (
id INTEGER PRIMARY KEY,
ip TEXT,
publickey TEXT,
threshold INTEGER
)
"""
)
db.commit()
print("Init Database Successful")
# load config from config file
def init_config():
import configparser
global server_address
config = configparser.ConfigParser()
config.read("client.ini")
server_address = config["settings"]["server_address"]
# execute on exit
def clean_env():
print("Exit app")
# main page
@app.get("/")
async def read_root():
return {"message": "Hello, World!"}
class C(BaseModel):
Tuple: Tuple[capsule, int]
ip: str
# receive messages from node
@app.post("/receive_messages")
async def receive_messages(message: C):
"""
receive capsule and ip from nodes
params:
C: capsule and ct
ip: sender ip
return:
status_code
"""
C_tuple = message.Tuple
ip = message.ip
if not C_tuple or not ip:
raise HTTPException(status_code=400, detail="Invalid input data")
C_capsule = C_tuple[0]
C_ct = C_tuple[1]
if not Checkcapsule(C_capsule):
raise HTTPException(status_code=400, detail="Invalid capsule")
# insert record into database
with sqlite3.connect("message.db") as db:
try:
db.execute(
"""
INSERT INTO message
(capsule_column, ct_column, ip_column)
VALUES
(?, ?, ?)
""",
(C_capsule, C_ct, ip),
)
db.commit()
await check_merge(db, C_ct, ip)
return HTTPException(status_code=200, detail="Message received")
except Exception as e:
print(f"Error occurred: {e}")
db.rollback()
return HTTPException(status_code=400, detail="Database error")
# check record count
async def check_merge(db, ct: int, ip: str):
global sk, pk, node_response, message
# Check if the combination of ct_column and ip_column appears more than once.
cursor = db.execute(
"""
SELECT capsule, ct
FROM message
WHERE ct = ? AND senderip = ?
""",
(ct, ip),
)
# [(capsule, ct), ...]
cfrag_cts = cursor.fetchall()
# get N
cursor = db.execute(
"""
SELECT publickey, threshold
FROM senderinfo
WHERE senderip = ?
""",
(ip),
)
result = cursor.fetchall()
pk_sender, T = result[0]
if len(cfrag_cts) >= T:
cfrags = mergecfrag(cfrag_cts)
message = DecryptFrags(sk, pk, pk_sender, cfrags) # type: ignore
node_response = True
# send message to node
def send_message(ip: tuple[str, ...]):
return 0
class IP_Message(BaseModel):
dest_ip: str
message_name: str
source_ip: str
# request message from others
@app.post("/request_message")
async def request_message(i_m: IP_Message):
global message, node_response
dest_ip = i_m.dest_ip
message_name = i_m.message_name
source_ip = get_own_ip()
dest_port = "8003"
url = "http://" + dest_ip + dest_port + "/recieve_request"
payload = {"dest_ip": dest_ip, "message_name": message_name, "source_ip": source_ip}
response = requests.post(url, json=payload)
if response.status_code == 200:
data = response.json()
public_key = int(data["public_key"])
threshold = int(data["threshold"])
with sqlite3.connect("client.db") as db:
db.execute(
"""
INSERT INTO senderinfo
(public_key, threshold)
VALUES
(?, ?)
""",
(public_key, threshold),
)
# wait to recieve message from nodes
for _ in range(10):
if node_response:
data = message
message = b""
# return message to frontend
return {"message": data}
time.sleep(1)
# recieve request from others
@app.post("/recieve_request")
async def recieve_request(i_m: IP_Message):
global pk
source_ip = get_own_ip()
if source_ip != i_m.dest_ip:
return HTTPException(status_code=400, detail="Wrong ip")
dest_ip = i_m.source_ip
threshold = random.randrange(1, 6)
public_key = pk
response = {"threshold": threshold,"public_key": public_key}
return response
def get_own_ip() -> str:
hostname = socket.gethostname()
ip = socket.gethostbyname(hostname)
return ip
# get node list from central server
def get_node_list(count: int, server_addr: str):
url = "http://" + server_addr + "/server/send_nodes_list"
payload = {"count": count}
response = requests.post(url, json=payload)
# Checking the response
if response.status_code == 200:
print("Success get node list")
node_ip = response.text
# insert node ip to database
with sqlite3.connect("client.db") as db:
db.executemany(
"""
INSERT INTO node
nodeip
VALUE (?)
""",
node_ip,
)
db.commit()
print("Success add node ip")
else:
print("Failed:", response.status_code, response.text)
if __name__ == "__main__":
import uvicorn # pylint: disable=e0401
uvicorn.run("client:app", host="0.0.0.0", port=8003, reload="True")

View File

@ -68,3 +68,52 @@ async def send_message(message: str):
data = {"message": processed_message} data = {"message": processed_message}
response = requests.post(url, data=data) response = requests.post(url, data=data)
return response.json() return response.json()
import requests
def send_heartbeat(url: str) -> bool:
try:
response = requests.get(url, timeout=5) # 使用 GET 方法作为心跳请求
response.raise_for_status() # 检查响应是否为 200 OK
# 可选:根据响应内容进行进一步验证
# if response.json() != expected_response:
# return False
return True
except requests.RequestException:
return False
# 使用方式
url = "https://your-service-url.com/heartbeat"
if send_heartbeat(url):
print("Service is alive!")
else:
print("Service might be down or unreachable.")
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
async def receive_heartbeat_internal() -> int:
while True:
print('successful delete1')
timeout = 10
# 删除超时的节点(假设你有一个异步的数据库操作函数)
await async_cursor_execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,))
await async_conn_commit()
print('successful delete')
await asyncio.sleep(timeout)
return 1
@asynccontextmanager
async def lifespan(app: FastAPI):
task = asyncio.create_task(receive_heartbeat_internal())
yield
task.cancel() # 取消我们之前创建的任务
await clean_env() # 假设这是一个异步函数
# 其他FastAPI应用的代码...

View File

@ -1,9 +1,49 @@
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager
from typing import Tuple, Callable from typing import Tuple, Callable
import sqlite3
import asyncio
import time
import random
app = FastAPI() @asynccontextmanager
async def lifespan(app: FastAPI):
init()
yield
clean_env()
app = FastAPI(lifespan = lifespan)
# 连接到数据库(如果数据库不存在,则会自动创建)
conn = sqlite3.connect('server.db')
# 创建游标对象用于执行SQL语句
cursor = conn.cursor()
# 创建表: id: int; ip: TEXT
cursor.execute('''CREATE TABLE IF NOT EXISTS nodes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL,
last_heartbeat INTEGER
)''')
def init():
task = asyncio.create_task(receive_heartbeat_internal())
def clean_env():
# 关闭游标和连接
cursor.close()
conn.close()
@app.get("/server/show_nodes")
async def show_nodes() -> list:
nodes_list = []
# 查询数据
cursor.execute("SELECT * FROM nodes")
rows = cursor.fetchall()
for row in rows:
nodes_list.append(row)
return nodes_list
@app.get("/server/get_node") @app.get("/server/get_node")
async def get_node(ip: str) -> int: async def get_node(ip: str) -> int:
@ -14,19 +54,57 @@ async def get_node(ip: str) -> int:
return: return:
id: ip按点分割成四部分, 每部分转二进制后拼接再转十进制作为节点id id: ip按点分割成四部分, 每部分转二进制后拼接再转十进制作为节点id
''' '''
# ip存入数据库, id = hash(int(ip))
ip_parts = ip.split(".") ip_parts = ip.split(".")
ip_int = 0 ip_int = 0
for i in range(4): for i in range(4):
ip_int += int(ip_parts[i]) << (24 - (8 * i)) ip_int += int(ip_parts[i]) << (24 - (8 * i))
# 获取当前时间
current_time = int(time.time())
# 插入数据
cursor.execute("INSERT INTO nodes (id, ip, last_heartbeat) VALUES (?, ?, ?)", (ip_int, ip, current_time))
conn.commit()
return ip_int return ip_int
@app.get("/server/delete_node") @app.get("/server/delete_node")
async def delete_node(ip: str) -> None: async def delete_node(ip: str) -> None:
# 按照节点ip遍历数据库, 删除该行数据 '''
param:
ip: 待删除节点的ip地址
return:
None
'''
# 查询要删除的节点
cursor.execute("SELECT * FROM nodes WHERE ip=?", (ip,))
row = cursor.fetchone()
if row is not None:
# 执行删除操作
cursor.execute("DELETE FROM nodes WHERE ip=?", (ip,))
conn.commit()
print(f"Node with IP {ip} deleted successfully.")
else:
print(f"Node with IP {ip} not found.")
@app.post("/server/send_nodes_list") # 接收节点心跳包
@app.post("/server/heartbeat")
async def receive_heartbeat(ip: str):
cursor.execute("UPDATE nodes SET last_heartbeat = ? WHERE ip = ?", (time.time(), ip))
return {"status": "received"}
async def receive_heartbeat_internal() -> int:
while 1:
print('successful delete1')
timeout = 10
# 删除超时的节点
cursor.execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,))
conn.commit()
print('successful delete')
await asyncio.sleep(timeout)
return 1
@app.get("/server/send_nodes_list")
async def send_nodes_list(count: int) -> JSONResponse: async def send_nodes_list(count: int) -> JSONResponse:
''' '''
中心服务器与客户端交互, 客户端发送所需节点个数, 中心服务器从数据库中顺序取出节点封装成json格式返回给客户端 中心服务器与客户端交互, 客户端发送所需节点个数, 中心服务器从数据库中顺序取出节点封装成json格式返回给客户端
@ -36,9 +114,24 @@ async def send_nodes_list(count: int) -> JSONResponse:
JSONResponse: {id: ip,...} JSONResponse: {id: ip,...}
''' '''
nodes_list = {} nodes_list = {}
for i in range(count):
# 访问数据库取出节点数据 # 查询数据库中的节点数据
node = (id, ip) cursor.execute("SELECT * FROM nodes LIMIT ?", (count,))
nodes_list[node[0]] = node[1] rows = cursor.fetchall()
for row in rows:
id, ip, last_heartbeat = row
nodes_list[id] = ip
json_result = jsonable_encoder(nodes_list) json_result = jsonable_encoder(nodes_list)
return JSONResponse(content=json_result) return JSONResponse(content=json_result)
@app.get("/server/clear_database")
async def clear_database() -> None:
cursor.execute("DELETE FROM nodes")
conn.commit()
if __name__ == "__main__":
import uvicorn # pylint: disable=e0401
uvicorn.run("server:app", host="0.0.0.0", port=8000, reload=True)

View File

@ -1,6 +0,0 @@
from gmssl import * #pylint: disable = e0401
sm3 = Sm3() #pylint: disable = e0602
sm3.update(b'abc')
dgst = sm3.digest()
print("sm3('abc') : " + dgst.hex())