main #22

Merged
ccyj merged 68 commits from sangge/tpre-python:main into main 2023-10-26 14:45:47 +08:00
2 changed files with 23 additions and 17 deletions
Showing only changes of commit 087e429ef7 - Show all commits

View File

@ -11,6 +11,7 @@ import random
import time
import base64
import json
import pickle
@asynccontextmanager
@ -40,7 +41,7 @@ def init_db():
"""
CREATE TABLE IF NOT EXISTS message (
id INTEGER PRIMARY KEY,
capsule TEXT,
capsule BLOB,
ct TEXT,
senderip TEXT
);
@ -113,29 +114,27 @@ async def receive_messages(message: C):
return:
status_code
"""
a, b = message.Tuple
C_tuple = (a, b)
ip = message.ip
if not C_tuple or not ip:
if not message.Tuple or not message.ip:
raise HTTPException(status_code=400, detail="Invalid input data")
C_capsule = C_tuple[0]
C_ct = C_tuple[1]
if not Checkcapsule(C_capsule):
raise HTTPException(status_code=400, detail="Invalid capsule")
C_capsule, C_ct = message.Tuple
ip = message.ip
# Serialization
bin_C_capsule = pickle.dumps(C_capsule)
# insert record into database
with sqlite3.connect("message.db") as db:
with sqlite3.connect("client.db") as db:
try:
db.execute(
"""
INSERT INTO message
(capsule_column, ct_column, ip_column)
(capsule, ct, senderip)
VALUES
(?, ?, ?)
""",
(C_capsule, C_ct, ip),
(bin_C_capsule, C_ct, ip),
)
db.commit()
await check_merge(C_ct, ip)
@ -175,7 +174,13 @@ async def check_merge(ct: int, ip: str):
pk_sender, T = result[0] # result[0] = (pk, threshold)
if len(cfrag_cts) >= T:
cfrags = mergecfrag(cfrag_cts)
# Deserialization
temp_cfrag_cts = []
for i in cfrag_cts:
capsule = pickle.loads(i[0])
temp_cfrag_cts.append((capsule, i[1]))
cfrags = mergecfrag(temp_cfrag_cts)
message = DecryptFrags(sk, pk, pk_sender, cfrags) # type: ignore
node_response = True
@ -199,7 +204,7 @@ async def send_messages(
rk_list = GenerateReKey(sk, pk_B, len(node_ips), shreshold, tuple(id_list)) # type: ignore
capsule, ct = Encrypt(pk, message) # type: ignore
#capsule_ct = (capsule, int.from_bytes(ct))
# capsule_ct = (capsule, int.from_bytes(ct))
for i in range(len(node_ips)):
url = "http://" + node_ips[i][0] + ":8001" + "/user_src"

View File

@ -78,7 +78,7 @@ class Req(BaseModel):
dest_ip: str
capsule: capsule
ct: int
rk: Any
rk: list
@app.post("/user_src") # 接收用户1发送的信息
@ -93,6 +93,7 @@ async def user_src(message: Req):
"rk": rk_list[i],
}
"""
print("node: ", message)
source_ip = message.source_ip
dest_ip = message.dest_ip
capsule = message.capsule
@ -113,7 +114,7 @@ async def send_user_des_message(source_ip: str, dest_ip: str, re_message): #
response = requests.post(
"http://" + dest_ip + ":8002" + "/receive_messages", json=data
)
print(response.text)
print("send stauts:" ,response.text)
if __name__ == "__main__":