forked from sangge/tpre-python
fix: fix all bug
This commit is contained in:
parent
6df22e2072
commit
bfb5c27bcb
116
src/client.py
116
src/client.py
@ -13,7 +13,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import pickle
|
import pickle
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@ -89,6 +89,9 @@ def init_config():
|
|||||||
|
|
||||||
# execute on exit
|
# execute on exit
|
||||||
def clean_env():
|
def clean_env():
|
||||||
|
global message, node_response
|
||||||
|
message = b""
|
||||||
|
node_response = False
|
||||||
with sqlite3.connect("client.db") as db:
|
with sqlite3.connect("client.db") as db:
|
||||||
db.execute("DELETE FROM node")
|
db.execute("DELETE FROM node")
|
||||||
db.execute("DELETE FROM message")
|
db.execute("DELETE FROM message")
|
||||||
@ -141,7 +144,7 @@ async def receive_messages(message: C):
|
|||||||
(bin_C_capsule, str(C_ct), ip),
|
(bin_C_capsule, str(C_ct), ip),
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
await check_merge(C_ct, ip)
|
check_merge(C_ct, ip)
|
||||||
return HTTPException(status_code=200, detail="Message received")
|
return HTTPException(status_code=200, detail="Message received")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error occurred: {e}")
|
print(f"Error occurred: {e}")
|
||||||
@ -150,13 +153,14 @@ async def receive_messages(message: C):
|
|||||||
|
|
||||||
|
|
||||||
# check record count
|
# check record count
|
||||||
async def check_merge(ct: int, ip: str):
|
def check_merge(ct: int, ip: str):
|
||||||
global sk, pk, node_response, message
|
global sk, pk, node_response, message
|
||||||
"""
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS senderinfo (
|
CREATE TABLE IF NOT EXISTS senderinfo (
|
||||||
id INTEGER PRIMARY KEY,
|
id INTEGER PRIMARY KEY,
|
||||||
ip TEXT,
|
ip TEXT,
|
||||||
publickey TEXT,
|
pkx TEXT,
|
||||||
|
pky TEXT,
|
||||||
threshold INTEGER
|
threshold INTEGER
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
@ -173,10 +177,10 @@ async def check_merge(ct: int, ip: str):
|
|||||||
# [(capsule, ct), ...]
|
# [(capsule, ct), ...]
|
||||||
cfrag_cts = cursor.fetchall()
|
cfrag_cts = cursor.fetchall()
|
||||||
|
|
||||||
# get T
|
# get _sender_pk
|
||||||
cursor = db.execute(
|
cursor = db.execute(
|
||||||
"""
|
"""
|
||||||
SELECT publickey, threshold
|
SELECT pkx, pky
|
||||||
FROM senderinfo
|
FROM senderinfo
|
||||||
WHERE ip = ?
|
WHERE ip = ?
|
||||||
""",
|
""",
|
||||||
@ -184,21 +188,31 @@ async def check_merge(ct: int, ip: str):
|
|||||||
)
|
)
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
try:
|
try:
|
||||||
pk_sender, T = result[0] # result[0] = (pk, threshold)
|
pkx, pky = result[0] # result[0] = (pkx, pky)
|
||||||
|
pk_sender = (int(pkx), int(pky))
|
||||||
except:
|
except:
|
||||||
pk_sender, T = 0, -1
|
pk_sender, T = 0, -1
|
||||||
|
|
||||||
if len(cfrag_cts) <= T:
|
T = 2
|
||||||
print(T)
|
if len(cfrag_cts) >= T:
|
||||||
# Deserialization
|
# Deserialization
|
||||||
temp_cfrag_cts = []
|
temp_cfrag_cts = []
|
||||||
for i in cfrag_cts:
|
for i in cfrag_cts:
|
||||||
capsule = pickle.loads(i[0])
|
capsule = pickle.loads(i[0])
|
||||||
temp_cfrag_cts.append((capsule, int(i[1])))
|
temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(32)))
|
||||||
|
|
||||||
cfrags = mergecfrag(temp_cfrag_cts)
|
cfrags = mergecfrag(temp_cfrag_cts)
|
||||||
message = DecryptFrags(sk, pk, pk_sender, cfrags) # type: ignore
|
|
||||||
|
print("sk:", type(sk))
|
||||||
|
print("pk:", type(pk))
|
||||||
|
print("pk_sender:", type(pk_sender))
|
||||||
|
print("cfrags:", type(cfrags))
|
||||||
|
message = DecryptFrags(sk, pk, pk_sender, cfrags)
|
||||||
|
|
||||||
|
print("merge success", message)
|
||||||
node_response = True
|
node_response = True
|
||||||
|
|
||||||
|
print("merge:", node_response)
|
||||||
|
|
||||||
|
|
||||||
# send message to node
|
# send message to node
|
||||||
@ -276,44 +290,19 @@ async def request_message(i_m: Request_Message):
|
|||||||
# content = {"message": "post timeout", "error": str(e)}
|
# content = {"message": "post timeout", "error": str(e)}
|
||||||
# return JSONResponse(content, status_code=400)
|
# return JSONResponse(content, status_code=400)
|
||||||
|
|
||||||
try:
|
|
||||||
url = "http://" + dest_ip + ":" + dest_port + "/get_pk"
|
|
||||||
print(url)
|
|
||||||
response = requests.get(url,timeout=4)
|
|
||||||
print(response.text)
|
|
||||||
if response.status_code == 200:
|
|
||||||
data = response.json()
|
|
||||||
pkx = int(data["pkx"])
|
|
||||||
pky = int(data["pky"])
|
|
||||||
public_key = (pkx, pky)
|
|
||||||
threshold = 2
|
|
||||||
with sqlite3.connect("client.db") as db:
|
|
||||||
db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO senderinfo
|
|
||||||
(ip, public_key, threshold)
|
|
||||||
VALUES
|
|
||||||
(?, ?, ?)
|
|
||||||
""",
|
|
||||||
(str(dest_ip), public_key, threshold),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print("Database error")
|
|
||||||
content = {"message": "Database Error","error": str(e)}
|
|
||||||
return JSONResponse(content, status_code=400)
|
|
||||||
|
|
||||||
# wait 3s to receive message from nodes
|
# wait 3s to receive message from nodes
|
||||||
for _ in range(3):
|
for _ in range(10):
|
||||||
|
print("wait:", node_response)
|
||||||
if node_response:
|
if node_response:
|
||||||
data = message
|
data = message
|
||||||
|
|
||||||
# reset message and node_response
|
# reset message and node_response
|
||||||
message = b""
|
message = b""
|
||||||
node_response = False
|
node_response = False
|
||||||
|
|
||||||
# return message to frontend
|
# return message to frontend
|
||||||
return {"message": data}
|
return {"message": str(data)}
|
||||||
time.sleep(1)
|
await asyncio.sleep(0.2)
|
||||||
content = {"message": "receive timeout"}
|
content = {"message": "receive timeout"}
|
||||||
return JSONResponse(content, status_code=400)
|
return JSONResponse(content, status_code=400)
|
||||||
|
|
||||||
@ -381,15 +370,50 @@ def get_node_list(count: int, server_addr: str):
|
|||||||
print("Success add node ip")
|
print("Success add node ip")
|
||||||
else:
|
else:
|
||||||
print("Failed:", response.status_code, response.text)
|
print("Failed:", response.status_code, response.text)
|
||||||
|
|
||||||
|
|
||||||
|
# send pk to others
|
||||||
@app.get("/get_pk")
|
@app.get("/get_pk")
|
||||||
async def get_pk():
|
async def get_pk():
|
||||||
global pk
|
global pk, sk
|
||||||
|
print(sk)
|
||||||
return {"pkx": str(pk[0]), "pky": str(pk[1])}
|
return {"pkx": str(pk[0]), "pky": str(pk[1])}
|
||||||
|
|
||||||
|
|
||||||
pk = point
|
class pk_model(BaseModel):
|
||||||
sk = int
|
pkx: str
|
||||||
|
pky: str
|
||||||
|
ip: str
|
||||||
|
|
||||||
|
|
||||||
|
# recieve pk from frontend
|
||||||
|
@app.post("/recieve_pk")
|
||||||
|
async def recieve_pk(pk: pk_model):
|
||||||
|
pkx = pk.pkx
|
||||||
|
pky = pk.pky
|
||||||
|
dest_ip = pk.ip
|
||||||
|
try:
|
||||||
|
threshold = 2
|
||||||
|
with sqlite3.connect("client.db") as db:
|
||||||
|
db.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO senderinfo
|
||||||
|
(ip, pkx, pky, threshold)
|
||||||
|
VALUES
|
||||||
|
(?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(str(dest_ip), pkx, pky, threshold),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# raise error
|
||||||
|
print("Database error")
|
||||||
|
content = {"message": "Database Error", "error": str(e)}
|
||||||
|
return JSONResponse(content, status_code=400)
|
||||||
|
return {"message": "save pk in database"}
|
||||||
|
|
||||||
|
|
||||||
|
pk = (0, 0)
|
||||||
|
sk = 0
|
||||||
server_address = str
|
server_address = str
|
||||||
node_response = False
|
node_response = False
|
||||||
message = bytes
|
message = bytes
|
||||||
@ -398,4 +422,4 @@ local_ip = get_own_ip()
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn # pylint: disable=e0401
|
import uvicorn # pylint: disable=e0401
|
||||||
|
|
||||||
uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True,log_level="debug")
|
uvicorn.run("client:app", host="0.0.0.0", port=8002, reload=True, log_level="debug")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import requests
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
def send_post_request(ip_addr, message_name):
|
def send_post_request(ip_addr, message_name):
|
||||||
@ -9,13 +10,30 @@ def send_post_request(ip_addr, message_name):
|
|||||||
return response.text
|
return response.text
|
||||||
|
|
||||||
|
|
||||||
|
def get_pk(ip_addr):
|
||||||
|
url = f"http://" + ip_addr + ":8002/get_pk"
|
||||||
|
response = requests.get(url, timeout=1)
|
||||||
|
print(response.text)
|
||||||
|
json_pk = json.loads(response.text)
|
||||||
|
payload = {"pkx": json_pk["pkx"], "pky": json_pk["pky"], "ip": ip_addr}
|
||||||
|
response = requests.post("http://localhost:8002/recieve_pk", json=payload)
|
||||||
|
|
||||||
|
return response.text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Send POST request to a specified IP.")
|
parser = argparse.ArgumentParser(description="Send POST request to a specified IP.")
|
||||||
parser.add_argument("ip_addr", help="IP address to send request to.")
|
parser.add_argument("ip_addr", help="IP address to send request to.")
|
||||||
parser.add_argument("message_name", help="Message name to send.")
|
parser.add_argument("message_name", help="Message name to send.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
response = get_pk(args.ip_addr)
|
||||||
|
print(response)
|
||||||
|
|
||||||
response = send_post_request(args.ip_addr, args.message_name)
|
response = send_post_request(args.ip_addr, args.message_name)
|
||||||
|
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
|
4
src/client_demo.py
Normal file
4
src/client_demo.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from tpre import *
|
||||||
|
|
||||||
|
# local {"pkx":"110913495319893280527511520027612816833094668640322629943553195742251267532611","pky":"42442813417048462506373786007682778510807282038950736216326706485290996455738"}
|
||||||
|
# pkb (110913495319893280527511520027612816833094668640322629943553195742251267532611,42442813417048462506373786007682778510807282038950736216326706485290996455738
|
@ -272,7 +272,9 @@ def f(x: int, f_modulus: list, T: int) -> int:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def GenerateReKey(sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int,...]) -> list:
|
def GenerateReKey(
|
||||||
|
sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int, ...]
|
||||||
|
) -> list:
|
||||||
"""
|
"""
|
||||||
param:
|
param:
|
||||||
skA, pkB, N(节点总数), T(阈值)
|
skA, pkB, N(节点总数), T(阈值)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user