refactor: refactor some code

This commit is contained in:
sangge-redmi 2024-09-02 11:21:39 +08:00
parent 8b89e1b722
commit 82309c5505
13 changed files with 73 additions and 349 deletions

View File

@ -2,14 +2,19 @@ from fastapi import FastAPI, HTTPException
import requests import requests
import os import os
from typing import Tuple from typing import Tuple
from tpre import * from tpre import (
GenerateKeyPair,
Encrypt,
DecryptFrags,
GenerateReKey,
MergeCFrag,
point,
)
import sqlite3 import sqlite3
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pydantic import BaseModel from pydantic import BaseModel
import socket import socket
import random import random
import time
import base64
import json import json
import pickle import pickle
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -18,12 +23,12 @@ import asyncio
# 测试文本 # 测试文本
test_msessgaes = { test_msessgaes = {
"name": b"proxy re-encryption", "name": b"proxy re-encryption",
"environment": b"distributed environment" "environment": b"distributed environment",
} }
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
init() init()
yield yield
clean_env() clean_env()
@ -200,7 +205,7 @@ def check_merge(ct: int, ip: str):
try: try:
pkx, pky = result[0] # result[0] = (pkx, pky) pkx, pky = result[0] # result[0] = (pkx, pky)
pk_sender = (int(pkx), int(pky)) pk_sender = (int(pkx), int(pky))
except: except IndexError:
pk_sender, T = 0, -1 pk_sender, T = 0, -1
T = 2 T = 2
@ -212,7 +217,7 @@ def check_merge(ct: int, ip: str):
byte_length = (ct.bit_length() + 7) // 8 byte_length = (ct.bit_length() + 7) // 8
temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(byte_length))) temp_cfrag_cts.append((capsule, int(i[1]).to_bytes(byte_length)))
cfrags = mergecfrag(temp_cfrag_cts) cfrags = MergeCFrag(temp_cfrag_cts)
print("sk:", type(sk)) print("sk:", type(sk))
print("pk:", type(pk)) print("pk:", type(pk))
@ -371,7 +376,7 @@ async def receive_request(i_m: IP_Message):
try: try:
message = test_msessgaes[i_m.message_name] message = test_msessgaes[i_m.message_name]
except: except IndexError:
message = b"hello world" + random.randbytes(8) message = b"hello world" + random.randbytes(8)
print(f"Message to be send: {message}") print(f"Message to be send: {message}")
@ -391,7 +396,7 @@ def get_own_ip() -> str:
s.connect(("8.8.8.8", 80)) # 通过连接Google DNS获取IP s.connect(("8.8.8.8", 80)) # 通过连接Google DNS获取IP
ip = s.getsockname()[0] ip = s.getsockname()[0]
s.close() s.close()
except: except IndexError:
raise ValueError("Unable to get IP") raise ValueError("Unable to get IP")
return str(ip) return str(ip)

View File

@ -4,14 +4,14 @@ import json
def send_post_request(ip_addr, message_name): def send_post_request(ip_addr, message_name):
url = f"http://localhost:8002/request_message" url = "http://localhost:8002/request_message"
data = {"dest_ip": ip_addr, "message_name": message_name} data = {"dest_ip": ip_addr, "message_name": message_name}
response = requests.post(url, json=data) response = requests.post(url, json=data)
return response.text return response.text
def get_pk(ip_addr): def get_pk(ip_addr):
url = f"http://" + ip_addr + ":8002/get_pk" url = "http://" + ip_addr + ":8002/get_pk"
response = requests.get(url, timeout=1) response = requests.get(url, timeout=1)
print(response.text) print(response.text)
json_pk = json.loads(response.text) json_pk = json.loads(response.text)
@ -21,19 +21,18 @@ def get_pk(ip_addr):
return response.text return response.text
def main(): def main():
parser = argparse.ArgumentParser(description="Send POST request to a specified IP.") parser = argparse.ArgumentParser(description="Send POST request to a specified IP.")
parser.add_argument("ip_addr", help="IP address to send request to.") parser.add_argument("ip_addr", help="IP address to send request to.")
parser.add_argument("message_name", help="Message name to send.") parser.add_argument("message_name", help="Message name to send.")
args = parser.parse_args() args = parser.parse_args()
response = get_pk(args.ip_addr) response = get_pk(args.ip_addr)
print(response) print(response)
response = send_post_request(args.ip_addr, args.message_name) response = send_post_request(args.ip_addr, args.message_name)
print(response) print(response)

View File

@ -64,7 +64,7 @@ for N in range(4, 21, 2):
# 9 # 9
start_time = time.time() start_time = time.time()
cfrags = mergecfrag(cfrag_cts) cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time() end_time = time.time()
elapsed_time_dec = end_time - start_time elapsed_time_dec = end_time - start_time

View File

@ -1,4 +1,11 @@
from tpre import * from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time import time
N = 20 N = 20
@ -52,7 +59,7 @@ for i in range(1, 10):
# 9 # 9
start_time = time.time() start_time = time.time()
cfrags = mergecfrag(cfrag_cts) cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time

View File

@ -28,7 +28,7 @@ while total_time < 1:
total_time += elapsed_time total_time += elapsed_time
# print(f"加密算法运行时间:{elapsed_time}秒") # print(f"加密算法运行时间:{elapsed_time}秒")
# 3 # 3
pk_b, sk_b = GenerateKeyPair() pk_b, sk_b = GenerateKeyPair()
# 5 # 5
@ -54,7 +54,7 @@ while total_time < 1:
# 9 # 9
start_time = time.time() start_time = time.time()
cfrags = mergecfrag(cfrag_cts) cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time

View File

@ -1,18 +1,16 @@
from fastapi import FastAPI, Request, HTTPException from fastapi import FastAPI, HTTPException
import requests import requests
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import socket import socket
import asyncio import asyncio
from pydantic import BaseModel from pydantic import BaseModel
from tpre import * from tpre import capsule, ReEncrypt
import os import os
from typing import Any, Tuple
import base64
import logging import logging
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
init() init()
yield yield
clear() clear()
@ -55,7 +53,7 @@ def get_local_ip():
s.connect(("8.8.8.8", 80)) # 通过连接Google DNS获取IP s.connect(("8.8.8.8", 80)) # 通过连接Google DNS获取IP
ip = str(s.getsockname()[0]) ip = str(s.getsockname()[0])
s.close() s.close()
except: except IndexError:
raise ValueError("Unable to get IP") raise ValueError("Unable to get IP")
@ -81,7 +79,7 @@ async def send_heartbeat_internal() -> None:
# print('successful send my_heart') # print('successful send my_heart')
try: try:
requests.get(url, timeout=3) requests.get(url, timeout=3)
except: except requests.exceptions.RequestException:
logger.error("Central server error") logger.error("Central server error")
print("Central server error") print("Central server error")
@ -139,7 +137,9 @@ async def user_src(message: Req):
return HTTPException(status_code=200, detail="message recieved") return HTTPException(status_code=200, detail="message recieved")
async def send_user_des_message(source_ip: str, dest_ip: str, re_message): # 发送消息给用户2 async def send_user_des_message(
source_ip: str, dest_ip: str, re_message
): # 发送消息给用户2
data = {"Tuple": re_message, "ip": source_ip} # 类型不匹配 data = {"Tuple": re_message, "ip": source_ip} # 类型不匹配
# 发送 HTTP POST 请求 # 发送 HTTP POST 请求

View File

@ -1,5 +1,4 @@
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import sqlite3 import sqlite3
@ -9,7 +8,7 @@ import ipaddress
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_: FastAPI):
init() init()
yield yield
clean_env() clean_env()
@ -22,6 +21,7 @@ def init():
asyncio.create_task(receive_heartbeat_internal()) asyncio.create_task(receive_heartbeat_internal())
init_db() init_db()
def init_db(): def init_db():
conn = sqlite3.connect("server.db") conn = sqlite3.connect("server.db")
cursor = conn.cursor() cursor = conn.cursor()
@ -165,7 +165,8 @@ async def send_nodes_list(count: int) -> list:
rows = cursor.fetchall() rows = cursor.fetchall()
for row in rows: for row in rows:
id, ip, last_heartbeat = row # id, ip, last_heartbeat = row
_, ip, _ = row
nodes_list.append(ip) nodes_list.append(ip)
print("收到来自客户端的节点列表请求...") print("收到来自客户端的节点列表请求...")

View File

@ -1,7 +1,7 @@
import unittest import unittest
import sqlite3 import sqlite3
import os import os
from server import * from server import init_db
class TestServer(unittest.TestCase): class TestServer(unittest.TestCase):
@ -21,8 +21,6 @@ class TestServer(unittest.TestCase):
# 关闭数据库连接 # 关闭数据库连接
conn.close() conn.close()
os.remove("server.db") os.remove("server.db")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,14 +0,0 @@
from setuptools import setup, Extension
# 定义您的扩展
ext = Extension(
"tpreECC",
sources=["tpreECC.c"],
)
setup(
name="tpreECC",
version="1.0",
description="basic ECC written in C",
ext_modules=[ext],
)

View File

@ -1,4 +1,11 @@
from tpre import * from tpre import (
GenerateKeyPair,
Encrypt,
GenerateReKey,
ReEncrypt,
MergeCFrag,
DecryptFrags,
)
import time import time
N = 20 N = 20
@ -50,7 +57,7 @@ print(f"重加密算法运行时间:{elapsed_time}秒")
# 9 # 9
start_time = time.time() start_time = time.time()
cfrags = mergecfrag(cfrag_cts) cfrags = MergeCFrag(cfrag_cts)
m = DecryptFrags(sk_b, pk_b, pk_a, cfrags) m = DecryptFrags(sk_b, pk_b, pk_a, cfrags)
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time

View File

@ -1,7 +1,6 @@
from gmssl import * # pylint: disable = e0401 from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT
from typing import Tuple, Callable from typing import Tuple
import random import random
import traceback
point = Tuple[int, int] point = Tuple[int, int]
capsule = Tuple[point, point, int] capsule = Tuple[point, point, int]
@ -29,7 +28,6 @@ sm2p256v1 = CurveFp(
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0, Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
) )
point = Tuple[int, int]
# 生成元 # 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy) g = (sm2p256v1.Gx, sm2p256v1.Gy)
@ -209,12 +207,13 @@ def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]:
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602 sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602
enc_Data = sm4_enc.update(m) enc_Data = sm4_enc.update(m)
enc_Data += sm4_enc.finish() enc_Data += sm4_enc.finish()
enc_message = (capsule, enc_Data) enc_message = (capsule, bytes(enc_Data))
return enc_message return enc_message
def Decapsulate(ska: int, capsule: capsule) -> int: def Decapsulate(ska: int, capsule: capsule) -> int:
E, V, s = capsule # E, V, s = capsule
E, V, _ = capsule
EVa = multiply(add(E, V), ska) # (E*V)^ska EVa = multiply(add(E, V), ska) # (E*V)^ska
K = KDF(EVa) K = KDF(EVa)
@ -233,7 +232,7 @@ def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602 sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602
dec_Data = sm4_dec.update(enc_Data) dec_Data = sm4_dec.update(enc_Data)
dec_Data += sm4_dec.finish() dec_Data += sm4_dec.finish()
return dec_Data return bytes(dec_Data)
# GenerateRekey # GenerateRekey
@ -305,8 +304,9 @@ def GenerateReKey(
# 计算KF # 计算KF
KF = [] KF = []
for i in range(N): for i in range(N):
y = random.randint(0, sm2p256v1.N - 1) # seems unused?
Y = multiply(g, y) # y = random.randint(0, sm2p256v1.N - 1)
# Y = multiply(g, y)
s_x = hash5(id_tuple[i], D) # id需要设置 s_x = hash5(id_tuple[i], D) # id需要设置
r_k = f(s_x, f_modulus, T) r_k = f(s_x, f_modulus, T)
U1 = multiply(U, r_k) U1 = multiply(U, r_k)
@ -344,8 +344,10 @@ def Checkcapsule(capsule: capsule) -> bool: # 验证胶囊的有效性
def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, point]: def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, point]:
id, rk, Xa, U1 = kFrag # id, rk, Xa, U1 = kFrag
E, V, s = capsule id, rk, Xa, _ = kFrag
# E, V, s = capsule
E, V, _ = capsule
if not Checkcapsule(capsule): if not Checkcapsule(capsule):
raise ValueError("Invalid capsule") raise ValueError("Invalid capsule")
E1 = multiply(E, rk) E1 = multiply(E, rk)
@ -369,7 +371,7 @@ def ReEncrypt(
# 将加密节点加密后产生的t个capsule,ct合并在一起,产生cfrags = {{capsule1,capsule2,...},ct} # 将加密节点加密后产生的t个capsule,ct合并在一起,产生cfrags = {{capsule1,capsule2,...},ct}
def mergecfrag(cfrag_cts: list) -> list: def MergeCFrag(cfrag_cts: list) -> list:
ct_list = [] ct_list = []
cfrags_list = [] cfrags_list = []
cfrags = [] cfrags = []
@ -421,7 +423,9 @@ def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
Vk = multiply(Vlist[k], bis[k]) Vk = multiply(Vlist[k], bis[k])
E2 = add(Ek, E2) E2 = add(Ek, E2)
V2 = add(Vk, V2) V2 = add(Vk, V2)
X_Ab = multiply(X_Alist[0], sk_B) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值 X_Ab = multiply(
X_Alist[0], sk_B
) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
d = hash3((X_Alist[0], pk_B, X_Ab)) d = hash3((X_Alist[0], pk_B, X_Ab))
EV = add(E2, V2) # E2 + V2 EV = add(E2, V2) # E2 + V2
EVd = multiply(EV, d) # (E2 + V2)^d EVd = multiply(EV, d) # (E2 + V2)^d
@ -447,4 +451,4 @@ def DecryptFrags(sk_B: int, pk_B: point, pk_A: point, cfrags: list) -> bytes:
print(e) print(e)
print("key error") print("key error")
dec_Data = b"" dec_Data = b""
return dec_Data return bytes(dec_Data)

View File

@ -1,284 +0,0 @@
#include <Python.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
// define TPRE Big Number
typedef uint64_t TPRE_BN[8]
// GF(p)
typedef TPRE_BN SM2_Fp;
// GF(n)
typedef TPRE_BN SM2_Fn;
// 定义一个结构体来表示雅各比坐标系的点
typedef struct
{
TPRE_BN X;
TPRE_BN Y;
TPRE_BN Z;
} JACOBIAN_POINT;
// 定义一个结构体来表示点
typedef struct
{
uint8_t x[32];
uint8_t y[32];
} TPRE_POINT;
const TPRE_BN SM2_P = {
0xffffffff,
0xffffffff,
0x00000000,
0xffffffff,
0xffffffff,
0xffffffff,
0xffffffff,
0xfffffffe,
};
const TPRE_BN SM2_A = {
0xfffffffc,
0xffffffff,
0x00000000,
0xffffffff,
0xffffffff,
0xffffffff,
0xffffffff,
0xfffffffe,
};
const TPRE_BN SM2_B = {
0x4d940e93,
0xddbcbd41,
0x15ab8f92,
0xf39789f5,
0xcf6509a7,
0x4d5a9e4b,
0x9d9f5e34,
0x28e9fa9e,
};
// 生成元GX, GY
const SM2_JACOBIAN_POINT _SM2_G = {
{
0x334c74c7,
0x715a4589,
0xf2660be1,
0x8fe30bbf,
0x6a39c994,
0x5f990446,
0x1f198119,
0x32c4ae2c,
},
{
0x2139f0a0,
0x02df32e5,
0xc62a4740,
0xd0a9877c,
0x6b692153,
0x59bdcee3,
0xf4f6779c,
0xbc3736a2,
},
{
1,
0,
0,
0,
0,
0,
0,
0,
},
};
const SM2_JACOBIAN_POINT *SM2_G = &_SM2_G;
const TPRE_BN SM2_N = {
0x39d54123,
0x53bbf409,
0x21c6052b,
0x7203df6b,
0xffffffff,
0xffffffff,
0xffffffff,
0xfffffffe,
};
// u = (p - 1)/4, u + 1 = (p + 1)/4
const TPRE_BN SM2_U_PLUS_ONE = {
0x00000000,
0x40000000,
0xc0000000,
0xffffffff,
0xffffffff,
0xffffffff,
0xbfffffff,
0x3fffffff,
};
const TPRE_BN SM2_ONE = {1, 0, 0, 0, 0, 0, 0, 0};
const TPRE_BN SM2_TWO = {2, 0, 0, 0, 0, 0, 0, 0};
const TPRE_BN SM2_THREE = {3, 0, 0, 0, 0, 0, 0, 0};
#define GETU32(p) \
((uint32_t)(p)[0] << 24 | \
(uint32_t)(p)[1] << 16 | \
(uint32_t)(p)[2] << 8 | \
(uint32_t)(p)[3])
// 点乘
static void multiply(TPRE_POINT r, const TPRE_POINT a, int64_t n)
{
Point result;
// ...实现乘法逻辑...
return result;
}
// 点加
static void add(TPRE_POINT *R, TPRE_POINT *P, TPRE_POINT *Q)
{
JACOBIAN_POINT P_;
JACOBIAN_POINT Q_;
jacobianPoint_from_bytes(&P_, (uint8_t *)P)
jacobianPoint_from_bytes(&Q_, (uint8_t *)Q)
jacobianPoint_add(&P_, &P_, &Q_);
jacobianPoint_to_bytes(&P_, (uint8_t *)R);
}
// 求逆
static void inv()
{
}
// jacobianPoint点加
static void jacobianPoint_add(JACOBIAN_POINT *R, const JACOBIAN_POINT *P, const JACOBIAN_POINT *Q)
{
const uint64_t *X1 = P->X;
const uint64_t *Y1 = P->Y;
const uint64_t *Z1 = P->Z;
const uint64_t *x2 = Q->X;
const uint64_t *y2 = Q->Y;
SM2_BN T1;
SM2_BN T2;
SM2_BN T3;
SM2_BN T4;
SM2_BN X3;
SM2_BN Y3;
SM2_BN Z3;
if (sm2_jacobian_point_is_at_infinity(Q))
{
sm2_jacobian_point_copy(R, P);
return;
}
if (sm2_jacobian_point_is_at_infinity(P))
{
sm2_jacobian_point_copy(R, Q);
return;
}
assert(sm2_bn_is_one(Q->Z));
sm2_fp_sqr(T1, Z1);
sm2_fp_mul(T2, T1, Z1);
sm2_fp_mul(T1, T1, x2);
sm2_fp_mul(T2, T2, y2);
sm2_fp_sub(T1, T1, X1);
sm2_fp_sub(T2, T2, Y1);
if (sm2_bn_is_zero(T1))
{
if (sm2_bn_is_zero(T2))
{
SM2_JACOBIAN_POINT _Q, *Q = &_Q;
sm2_jacobian_point_set_xy(Q, x2, y2);
sm2_jacobian_point_dbl(R, Q);
return;
}
else
{
sm2_jacobian_point_set_infinity(R);
return;
}
}
sm2_fp_mul(Z3, Z1, T1);
sm2_fp_sqr(T3, T1);
sm2_fp_mul(T4, T3, T1);
sm2_fp_mul(T3, T3, X1);
sm2_fp_dbl(T1, T3);
sm2_fp_sqr(X3, T2);
sm2_fp_sub(X3, X3, T1);
sm2_fp_sub(X3, X3, T4);
sm2_fp_sub(T3, T3, X3);
sm2_fp_mul(T3, T3, T2);
sm2_fp_mul(T4, T4, Y1);
sm2_fp_sub(Y3, T3, T4);
sm2_bn_copy(R->X, X3);
sm2_bn_copy(R->Y, Y3);
sm2_bn_copy(R->Z, Z3);
}
// bytes转jacobianPoint
static void jacobianPoint_from_bytes(JACOBIAN_POINT *P, const uint8_t in[64])
{
}
// jacobianPoint转bytes
static void jacobianPoint_to_bytes(JACOBIAN_POINT *P, const uint8_t in[64])
{
}
static void BN_from_bytes(TPRE_BN *r, const uint8_t in[32])
{
int i;
for (i = 7; i >= 0; i--)
{
r[i] = GETU32(in);
in += sizeof(uint32_t);
}
}
// 点乘的Python接口函数
static PyObject *py_multiply(PyObject *self, PyObject *args)
{
return
}
// 点加的Python接口函数
static PyObject *py_add(PyObject *self, PyObject *args)
{
return
}
// 求逆的Python接口函数
static PyObject *py_inv(PyObject *self, PyObject *args)
{
return
}
// 模块方法定义
static PyMethodDef MyMethods[] = {
{"multiply", py_multiply, METH_VARARGS, "Multiply a point on the sm2p256v1 curve"},
{"add", py_add, METH_VARARGS, "Add a point on thesm2p256v1 curve"},
{"inv", py_inv, METH_VARARGS, "Calculate an inv of a number"},
{NULL, NULL, 0, NULL} // 哨兵
};
// 模块定义
static struct PyModuleDef tpreECC = {
PyModuleDef_HEAD_INIT,
"tpreECC",
NULL, // 模块文档
-1,
MyMethods};
// 初始化模块
PyMODINIT_FUNC PyInit_tpre(void)
{
return PyModule_Create(&tpreECC);
}

View File

@ -1,4 +1,5 @@
from tpre import * from tpre import hash2, hash3, hash4, multiply, g, sm2p256v1
import random
import unittest import unittest