refactor: refactor some code

This commit is contained in:
2024-09-02 11:21:39 +08:00
parent 8b89e1b722
commit 82309c5505
13 changed files with 73 additions and 349 deletions

View File

@@ -1,7 +1,6 @@
from gmssl import * # pylint: disable = e0401
from typing import Tuple, Callable
from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT
from typing import Tuple
import random
import traceback
point = Tuple[int, int]
capsule = Tuple[point, point, int]
@@ -29,7 +28,6 @@ sm2p256v1 = CurveFp(
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
)
point = Tuple[int, int]
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
@@ -209,12 +207,13 @@ def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]:
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602
enc_Data = sm4_enc.update(m)
enc_Data += sm4_enc.finish()
enc_message = (capsule, enc_Data)
enc_message = (capsule, bytes(enc_Data))
return enc_message
def Decapsulate(ska: int, capsule: capsule) -> int:
E, V, s = capsule
# E, V, s = capsule
E, V, _ = capsule
EVa = multiply(add(E, V), ska) # (E*V)^ska
K = KDF(EVa)
@@ -233,7 +232,7 @@ def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602
dec_Data = sm4_dec.update(enc_Data)
dec_Data += sm4_dec.finish()
return dec_Data
return bytes(dec_Data)
# GenerateRekey
@@ -305,8 +304,9 @@ def GenerateReKey(
# 计算KF
KF = []
for i in range(N):
y = random.randint(0, sm2p256v1.N - 1)
Y = multiply(g, y)
# seems unused?
# y = random.randint(0, sm2p256v1.N - 1)
# Y = multiply(g, y)
s_x = hash5(id_tuple[i], D) # id需要设置
r_k = f(s_x, f_modulus, T)
U1 = multiply(U, r_k)
@@ -344,8 +344,10 @@ def Checkcapsule(capsule: capsule) -> bool: # 验证胶囊的有效性
def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, point]:
id, rk, Xa, U1 = kFrag
E, V, s = capsule
# id, rk, Xa, U1 = kFrag
id, rk, Xa, _ = kFrag
# E, V, s = capsule
E, V, _ = capsule
if not Checkcapsule(capsule):
raise ValueError("Invalid capsule")
E1 = multiply(E, rk)
@@ -369,7 +371,7 @@ def ReEncrypt(
# 将加密节点加密后产生的t个capsule,ct合并在一起,产生cfrags = {{capsule1,capsule2,...},ct}
def mergecfrag(cfrag_cts: list) -> list:
def MergeCFrag(cfrag_cts: list) -> list:
ct_list = []
cfrags_list = []
cfrags = []
@@ -421,7 +423,9 @@ def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
Vk = multiply(Vlist[k], bis[k])
E2 = add(Ek, E2)
V2 = add(Vk, V2)
X_Ab = multiply(X_Alist[0], sk_B) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
X_Ab = multiply(
X_Alist[0], sk_B
) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
d = hash3((X_Alist[0], pk_B, X_Ab))
EV = add(E2, V2) # E2 + V2
EVd = multiply(EV, d) # (E2 + V2)^d
@@ -447,4 +451,4 @@ def DecryptFrags(sk_B: int, pk_B: point, pk_A: point, cfrags: list) -> bytes:
print(e)
print("key error")
dec_Data = b""
return dec_Data
return bytes(dec_Data)