Merge pull request 'main' (#19) from sangge/mimajingsai:main into main

Reviewed-on: dqy/mimajingsai#19
This commit is contained in:
dqy 2023-10-23 16:03:14 +08:00
commit 9efc8e2c7b
2 changed files with 114 additions and 92 deletions

View File

@ -17,15 +17,8 @@ async def lifespan(app: FastAPI):
yield yield
clean_env() clean_env()
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
pk = point
sk = int
server_address = str
node_response = False
message = bytes
def init(): def init():
global pk, sk, server_address global pk, sk, server_address
@ -100,7 +93,6 @@ class C(BaseModel):
Tuple: Tuple[capsule, int] Tuple: Tuple[capsule, int]
ip: str ip: str
# receive messages from node # receive messages from node
@app.post("/receive_messages") @app.post("/receive_messages")
async def receive_messages(message: C): async def receive_messages(message: C):
@ -178,7 +170,30 @@ async def check_merge(db, ct: int, ip: str):
# send message to node # send message to node
def send_message(ip: tuple[str, ...]): async def send_messages(node_ips: tuple[str, ...],
message: bytes,
dest_ip: str,
pk_B: point,
shreshold: int
):
global pk, sk
id_list = []
for node_ip in node_ips:
ip_parts = node_ip.split(".")
id = 0
for i in range(4):
id += int(ip_parts[i]) << (24 - (8 * i))
id_list.append(id)
rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list)) # type: ignore
for i in range(len(node_ips)):
url = "http://" + node_ips[i] + ":8001" + "/recieve_message"
payload = {
"source_ip": local_ip,
"dest_ip": dest_ip,
"message": message,
"rk": rk_list[i]
}
response = requests.post(url, json=payload)
return 0 return 0
@ -186,18 +201,22 @@ class IP_Message(BaseModel):
dest_ip: str dest_ip: str
message_name: str message_name: str
source_ip: str source_ip: str
pk: int
# request message from others # request message from others
@app.post("/request_message") @app.post("/request_message")
async def request_message(i_m: IP_Message): async def request_message(i_m: IP_Message):
global message, node_response global message, node_response, pk
dest_ip = i_m.dest_ip dest_ip = i_m.dest_ip
message_name = i_m.message_name message_name = i_m.message_name
source_ip = get_own_ip() source_ip = get_own_ip()
dest_port = "8003" dest_port = "8003"
url = "http://" + dest_ip + dest_port + "/recieve_request" url = "http://" + dest_ip + dest_port + "/recieve_request"
payload = {"dest_ip": dest_ip, "message_name": message_name, "source_ip": source_ip} payload = {"dest_ip": dest_ip,
"message_name": message_name,
"source_ip": source_ip,
"pk": pk
}
response = requests.post(url, json=payload) response = requests.post(url, json=payload)
if response.status_code == 200: if response.status_code == 200:
data = response.json() data = response.json()
@ -222,6 +241,7 @@ async def request_message(i_m: IP_Message):
# return message to frontend # return message to frontend
return {"message": data} return {"message": data}
time.sleep(1) time.sleep(1)
return {"message": "recieve timeout"}
# recieve request from others # recieve request from others
@ -233,8 +253,19 @@ async def recieve_request(i_m: IP_Message):
return HTTPException(status_code=400, detail="Wrong ip") return HTTPException(status_code=400, detail="Wrong ip")
dest_ip = i_m.source_ip dest_ip = i_m.source_ip
threshold = random.randrange(1, 6) threshold = random.randrange(1, 6)
public_key = pk own_public_key = pk
response = {"threshold": threshold,"public_key": public_key} pk_B = i_m.pk
with sqlite3.connect("client.db") as db:
cursor = db.execute("""
SELECT nodeip
FROM node
LIMIT ?
""",(threshold,))
node_ips = cursor.fetchall()
message = b"hello world" + random.randbytes(8)
await send_messages(node_ips, message, dest_ip, pk_B, threshold) # type: ignore
response = {"threshold": threshold,"public_key": own_public_key}
return response return response
@ -269,6 +300,14 @@ def get_node_list(count: int, server_addr: str):
print("Failed:", response.status_code, response.text) print("Failed:", response.status_code, response.text)
pk = point
sk = int
server_address = str
node_response = False
message = bytes
local_ip = get_own_ip()
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn # pylint: disable=e0401 import uvicorn # pylint: disable=e0401

View File

@ -1,7 +1,10 @@
from fastapi import FastAPI from fastapi import FastAPI,Request
import requests import requests
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import socket import socket
import asyncio
from pydantic import BaseModel
from tpre import *
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@ -12,18 +15,23 @@ async def lifespan(app: FastAPI):
clear() clear()
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
server_address ="http://中心服务器IP地址:端口号/ip" server_address ="http://中心服务器IP地址/server"
id = 0 id = 0
ip = ''
client_ip_src = '' # 发送信息用户的ip
client_ip_des = '' # 接收信息用户的ip
processed_message = () # 重加密后的数据
# class C(BaseModel):
# Tuple: Tuple[capsule, int]
# ip_src: str
# 向中心服务器发送自己的IP地址,并获取自己的id # 向中心服务器发送自己的IP地址,并获取自己的id
def send_ip(ip: str): def send_ip():
url = server_address url = server_address + '/get_node?ip = ' + ip
# ip = get_local_ip # type: ignore # ip = get_local_ip # type: ignore
data = {"ip": ip} global id
response = requests.post(url, data=data) id = requests.get(url)
data = response.json()
id = data['id']
return id
# 用socket获取本机ip # 用socket获取本机ip
def get_local_ip(): def get_local_ip():
@ -34,86 +42,61 @@ def get_local_ip():
# 获取本地IP地址 # 获取本地IP地址
local_ip = s.getsockname()[0] local_ip = s.getsockname()[0]
s.close() s.close()
return local_ip global ip
ip = local_ip
id = int
def init(): def init():
ip = get_local_ip() get_local_ip()
global id global id
id = send_ip(ip) send_ip()
task = asyncio.create_task(send_heartbeat_internal())
def clear(): def clear():
pass pass
@app.post("/heartbeat/")
async def receive_heartbeat():
return {"status": "received"}
# 接收用户发来的消息,经过处理之后,再将消息发送给其他用户 # 接收用户发来的消息,经过处理之后,再将消息发送给其他用户
@app.post("/send_message")
async def send_message(message: str):
# 处理消息
processed_message = message.upper()
# 发送消息给其他用户
url = "http://其他用户IP地址:端口号/receive_message"
data = {"message": processed_message}
response = requests.post(url, data=data)
return response.json()
async def send_heartbeat_internal() -> None:
import requests
def send_heartbeat(url: str) -> bool:
try:
response = requests.get(url, timeout=5) # 使用 GET 方法作为心跳请求
response.raise_for_status() # 检查响应是否为 200 OK
# 可选:根据响应内容进行进一步验证
# if response.json() != expected_response:
# return False
return True
except requests.RequestException:
return False
# 使用方式
url = "https://your-service-url.com/heartbeat"
if send_heartbeat(url):
print("Service is alive!")
else:
print("Service might be down or unreachable.")
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI
async def receive_heartbeat_internal() -> int:
while True: while True:
print('successful delete1') # print('successful send my_heart')
timeout = 10 global ip
url = server_address + '/get_node?ip = ' + ip
folderol = requests.get(url)
timeout = 30
# 删除超时的节点(假设你有一个异步的数据库操作函数) # 删除超时的节点(假设你有一个异步的数据库操作函数)
await async_cursor_execute("DELETE FROM nodes WHERE last_heartbeat < ?", (time.time() - timeout,))
await async_conn_commit()
print('successful delete')
await asyncio.sleep(timeout) await asyncio.sleep(timeout)
return 1
@asynccontextmanager
async def lifespan(app: FastAPI):
task = asyncio.create_task(receive_heartbeat_internal())
yield
task.cancel() # 取消我们之前创建的任务
await clean_env() # 假设这是一个异步函数
# 其他FastAPI应用的代码... @app.post("/user_src") # 接收用户1发送的信息
async def receive_user_src_message(message: Request):
json_data = await message.json()
global client_ip_src,client_ip_des
# kfrag , capsule_ct ,client_ip_src , client_ip_des = json_data[] # 看梁俊勇
global processed_message
processed_message = ReEncrypt(kfrag, capsule_ct)
def send_user_des_message(): # 发送消息给用户2
global processed_message,client_ip_src,client_ip_des
data = {
"Tuple": processed_message, # 类型不匹配
"ip": client_ip_src
}
# 发送 HTTP POST 请求
response = requests.post("http://"+ client_ip_des + "/receive_messages", json=data)
print(response)
if __name__ == "__main__":
import uvicorn # pylint: disable=e0401
uvicorn.run("node:app", host="0.0.0.0", port=8000, reload=True)