feat: update ecc operation call
This commit is contained in:
parent
dd86253162
commit
5144334558
@ -19,3 +19,6 @@ dependencies = ["gmssl-python>=2.2.2,<3.0.0", "fastapi", "uvicorn", "requests"]
|
||||
[project.optional-dependencies]
|
||||
test = ["httpx", "pytest"]
|
||||
dev = ["httpx", "pytest", "pyright", "ruff"]
|
||||
|
||||
[tool.ruff]
|
||||
select = ["E", "F", "N", "B", "I", "C4", "UP", "SIM"]
|
||||
|
223
src/tpre.py
223
src/tpre.py
@ -1,62 +1,144 @@
|
||||
from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT
|
||||
from typing import Tuple
|
||||
"""这个模块实现了椭圆曲线加密(ECC)的基本操作.
|
||||
|
||||
它提供了点加法、点乘法和其他与SM2P256V1曲线相关的操作,
|
||||
并支持可选的Rust实现来提高性能。
|
||||
"""
|
||||
|
||||
import random
|
||||
import ecc_rs
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
from gmssl import DO_DECRYPT, DO_ENCRYPT, Sm2Key, Sm3, Sm4Cbc
|
||||
|
||||
point = Tuple[int, int]
|
||||
capsule = Tuple[point, point, int]
|
||||
|
||||
try:
|
||||
import ecc_rs
|
||||
|
||||
# 生成密钥对模块
|
||||
class CurveFp:
|
||||
def __init__(self, A, B, P, N, Gx, Gy, name):
|
||||
self.A = A
|
||||
self.B = B
|
||||
self.P = P
|
||||
self.N = N
|
||||
self.Gx = Gx
|
||||
self.Gy = Gy
|
||||
self.name = name
|
||||
RUST_ECC = True
|
||||
|
||||
def add(a: point, b: point) -> point:
|
||||
"""Add two points on the curve.
|
||||
|
||||
sm2p256v1 = CurveFp(
|
||||
name="sm2p256v1",
|
||||
A=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC,
|
||||
B=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93,
|
||||
P=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
|
||||
N=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
|
||||
Gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
|
||||
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
|
||||
)
|
||||
Args:
|
||||
a: first point
|
||||
b: second point
|
||||
flag: if 1, use Rust implementation
|
||||
|
||||
"""
|
||||
if RUST_ECC is True and ecc_rs is not None:
|
||||
return ecc_rs.add(a, b)
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
|
||||
|
||||
# 生成元
|
||||
g = (sm2p256v1.Gx, sm2p256v1.Gy)
|
||||
def multiply(a: point, n: int) -> point:
|
||||
"""Multiply a point by a scalar.
|
||||
|
||||
Args:
|
||||
a: point
|
||||
n: scalar
|
||||
flag: if 1, use Rust implementation
|
||||
|
||||
def multiply(a: point, n: int, flag: int = 0) -> point:
|
||||
if flag == 1:
|
||||
"""
|
||||
if RUST_ECC is True and ecc_rs is not None:
|
||||
|
||||
result = ecc_rs.multiply(a, n)
|
||||
return result
|
||||
else:
|
||||
return ecc_rs.multiply(a, n)
|
||||
N = sm2p256v1.N
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
|
||||
|
||||
except ImportError:
|
||||
RUST_ECC = False
|
||||
ecc_rs = None
|
||||
|
||||
def add(a: point, b: point) -> point:
|
||||
"""Add two points on the curve.
|
||||
|
||||
Args:
|
||||
a: first point
|
||||
b: second point
|
||||
|
||||
"""
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
|
||||
|
||||
def multiply(a: point, n: int) -> point:
|
||||
"""Multiply a point by a scalar.
|
||||
|
||||
Args:
|
||||
a: point
|
||||
n: scalar
|
||||
|
||||
"""
|
||||
N = sm2p256v1.N
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
|
||||
|
||||
|
||||
def add(a: point, b: point, flag: int = 0) -> point:
|
||||
if flag == 1:
|
||||
result = ecc_rs.add(a, b)
|
||||
return result
|
||||
else:
|
||||
A = sm2p256v1.A
|
||||
P = sm2p256v1.P
|
||||
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
|
||||
@dataclass
|
||||
class CurveParams:
|
||||
"""Definition of SM2P256V1 curve parameters."""
|
||||
|
||||
a: int
|
||||
b: int
|
||||
p: int
|
||||
n: int
|
||||
gx: int
|
||||
gy: int
|
||||
name: str
|
||||
|
||||
|
||||
# 生成密钥对模块
|
||||
class CurveFp:
|
||||
"""Definition of SM2P256V1 curve class."""
|
||||
|
||||
def __init__(self, params: CurveParams) -> None:
|
||||
"""Initialize curve with parameters.
|
||||
|
||||
Args:
|
||||
params: Collection of curve parameters
|
||||
|
||||
"""
|
||||
self.A = params.a
|
||||
self.B = params.b
|
||||
self.P = params.p
|
||||
self.N = params.n
|
||||
self.Gx = params.gx
|
||||
self.Gy = params.gy
|
||||
self.name = params.name
|
||||
|
||||
|
||||
curve_params = CurveParams(
|
||||
name="sm2p256v1",
|
||||
a=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC,
|
||||
b=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93,
|
||||
p=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
|
||||
n=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
|
||||
gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
|
||||
gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
|
||||
)
|
||||
sm2p256v1 = CurveFp(params=curve_params)
|
||||
|
||||
|
||||
# 生成元
|
||||
g = (sm2p256v1.Gx, sm2p256v1.Gy)
|
||||
|
||||
|
||||
def inv(a: int, n: int) -> int:
|
||||
"""Return the modular inverse of a mod n.
|
||||
|
||||
Args:
|
||||
a: input integer
|
||||
n: modulus
|
||||
Returns:
|
||||
modular inverse of a mod n
|
||||
|
||||
"""
|
||||
if a == 0:
|
||||
return 0
|
||||
lm, hm = 1, 0
|
||||
@ -80,7 +162,9 @@ def fromJacobian(Xp_Yp_Zp: Tuple[int, int, int], P: int) -> point:
|
||||
|
||||
|
||||
def jacobianDouble(
|
||||
Xp_Yp_Zp: Tuple[int, int, int], A: int, P: int
|
||||
Xp_Yp_Zp: Tuple[int, int, int],
|
||||
A: int,
|
||||
P: int,
|
||||
) -> Tuple[int, int, int]:
|
||||
Xp, Yp, Zp = Xp_Yp_Zp
|
||||
if not Yp:
|
||||
@ -95,7 +179,10 @@ def jacobianDouble(
|
||||
|
||||
|
||||
def jacobianAdd(
|
||||
Xp_Yp_Zp: Tuple[int, int, int], Xq_Yq_Zq: Tuple[int, int, int], A: int, P: int
|
||||
Xp_Yp_Zp: Tuple[int, int, int],
|
||||
Xq_Yq_Zq: Tuple[int, int, int],
|
||||
A: int,
|
||||
P: int,
|
||||
) -> Tuple[int, int, int]:
|
||||
Xp, Yp, Zp = Xp_Yp_Zp
|
||||
Xq, Yq, Zq = Xq_Yq_Zq
|
||||
@ -123,7 +210,11 @@ def jacobianAdd(
|
||||
|
||||
|
||||
def jacobianMultiply(
|
||||
Xp_Yp_Zp: Tuple[int, int, int], n: int, N: int, A: int, P: int
|
||||
Xp_Yp_Zp: Tuple[int, int, int],
|
||||
n: int,
|
||||
N: int,
|
||||
A: int,
|
||||
P: int,
|
||||
) -> Tuple[int, int, int]:
|
||||
Xp, Yp, Zp = Xp_Yp_Zp
|
||||
if Yp == 0 or n == 0:
|
||||
@ -191,9 +282,9 @@ def KDF(G: point) -> int:
|
||||
|
||||
|
||||
def GenerateKeyPair() -> Tuple[point, int]:
|
||||
"""
|
||||
return:
|
||||
"""return:
|
||||
public_key, secret_key
|
||||
|
||||
"""
|
||||
sm2 = Sm2Key() # pylint: disable=e0602
|
||||
sm2.generate_key()
|
||||
@ -207,15 +298,15 @@ def GenerateKeyPair() -> Tuple[point, int]:
|
||||
return public_key, secret_key
|
||||
|
||||
|
||||
def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]:
|
||||
enca = Encapsulate(pk)
|
||||
def Encrypt(public_key: point, message: bytes) -> Tuple[capsule, bytes]:
|
||||
enca = Encapsulate(public_key)
|
||||
K = enca[0].to_bytes(16)
|
||||
capsule = enca[1]
|
||||
if len(K) != 16:
|
||||
raise ValueError("invalid key length")
|
||||
iv = b"tpretpretpretpre"
|
||||
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602
|
||||
enc_Data = sm4_enc.update(m)
|
||||
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT)
|
||||
enc_Data = sm4_enc.update(message)
|
||||
enc_Data += sm4_enc.finish()
|
||||
enc_message = (capsule, bytes(enc_Data))
|
||||
return enc_message
|
||||
@ -231,15 +322,14 @@ def Decapsulate(ska: int, capsule: capsule) -> int:
|
||||
|
||||
|
||||
def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
|
||||
"""
|
||||
params:
|
||||
"""params:
|
||||
sk_A: secret key
|
||||
C: (capsule, enc_data)
|
||||
"""
|
||||
capsule, enc_Data = C
|
||||
K = Decapsulate(sk_A, capsule)
|
||||
K = Decapsulate(sk_A, capsule).to_bytes(16)
|
||||
iv = b"tpretpretpretpre"
|
||||
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602
|
||||
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT)
|
||||
dec_Data = sm4_dec.update(enc_Data)
|
||||
dec_Data += sm4_dec.finish()
|
||||
return bytes(dec_Data)
|
||||
@ -247,7 +337,7 @@ def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
|
||||
|
||||
# GenerateRekey
|
||||
def hash5(id: int, D: int) -> int:
|
||||
sm3 = Sm3() # pylint: disable=e0602
|
||||
sm3 = Sm3()
|
||||
sm3.update(id.to_bytes(32))
|
||||
sm3.update(D.to_bytes(32))
|
||||
hash = sm3.digest()
|
||||
@ -266,13 +356,14 @@ def hash6(triple_G: Tuple[point, point, point]) -> int:
|
||||
|
||||
|
||||
def f(x: int, f_modulus: list, T: int) -> int:
|
||||
"""
|
||||
功能: 通过多项式插值来实现信息的分散和重构
|
||||
"""功能: 通过多项式插值来实现信息的分散和重构
|
||||
例如: 随机生成一个多项式f(x)=4x+5,质数P=11,其中f(0)=5,将多项式的系数分别分配给两个人,例如第一个人得到(1, 9),第二个人得到(2, 2).如果两个人都收集到了这两个点,那么可以使用拉格朗日插值法恢复原始的多项式,进而得到秘密信息"5"
|
||||
param:
|
||||
x, f_modulus(多项式系数列表), T(门限)
|
||||
return:
|
||||
|
||||
Return:
|
||||
res
|
||||
|
||||
"""
|
||||
res = 0
|
||||
for i in range(T):
|
||||
@ -282,13 +373,18 @@ def f(x: int, f_modulus: list, T: int) -> int:
|
||||
|
||||
|
||||
def GenerateReKey(
|
||||
sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int, ...]
|
||||
sk_A: int,
|
||||
pk_B: point,
|
||||
N: int,
|
||||
T: int,
|
||||
id_tuple: Tuple[int, ...],
|
||||
) -> list:
|
||||
"""
|
||||
param:
|
||||
"""param:
|
||||
skA, pkB, N(节点总数), T(阈值)
|
||||
return:
|
||||
|
||||
Return:
|
||||
rki(0 <= i <= N-1)
|
||||
|
||||
"""
|
||||
# 计算临时密钥对(x_A, X_A)
|
||||
x_A = random.randint(0, sm2p256v1.N - 1)
|
||||
@ -369,7 +465,8 @@ def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, po
|
||||
|
||||
|
||||
def ReEncrypt(
|
||||
kFrag: tuple, C: Tuple[capsule, bytes]
|
||||
kFrag: tuple,
|
||||
C: Tuple[capsule, bytes],
|
||||
) -> Tuple[Tuple[point, point, int, point], bytes]:
|
||||
capsule, enc_Data = C
|
||||
|
||||
@ -394,11 +491,10 @@ def MergeCFrag(cfrag_cts: list) -> list:
|
||||
|
||||
|
||||
def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
|
||||
"""
|
||||
return:
|
||||
"""return:
|
||||
K: sm4 key
|
||||
"""
|
||||
|
||||
"""
|
||||
Elist = []
|
||||
Vlist = []
|
||||
idlist = []
|
||||
@ -434,7 +530,8 @@ def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
|
||||
E2 = add(Ek, E2)
|
||||
V2 = add(Vk, V2)
|
||||
X_Ab = multiply(
|
||||
X_Alist[0], sk_B
|
||||
X_Alist[0],
|
||||
sk_B,
|
||||
) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
|
||||
d = hash3((X_Alist[0], pk_B, X_Ab))
|
||||
EV = add(E2, V2) # E2 + V2
|
||||
|
@ -15,14 +15,14 @@ def test_rust_vs_python_multiply():
|
||||
# Rust实现
|
||||
start_time = time.time()
|
||||
for _ in range(10):
|
||||
_ = multiply(g, mul_times, 1)
|
||||
_ = multiply(g, mul_times)
|
||||
rust_time = time.time() - start_time
|
||||
print(f"\nRust multiply 执行时间: {rust_time:.6f} 秒")
|
||||
|
||||
# Python实现
|
||||
start_time = time.time()
|
||||
for _ in range(10):
|
||||
_ = multiply(g, mul_times, 0)
|
||||
_ = multiply(g, mul_times)
|
||||
python_time = time.time() - start_time
|
||||
print(f"Python multiply 执行时间: {python_time:.6f} 秒")
|
||||
|
||||
@ -33,14 +33,14 @@ def test_rust_vs_python_add():
|
||||
# Rust实现
|
||||
start_time = time.time()
|
||||
for _ in range(10):
|
||||
_ = add(g, g, 1)
|
||||
_ = add(g, g)
|
||||
rust_time = time.time() - start_time
|
||||
print(f"\nRust add 执行时间: {rust_time:.6f} 秒")
|
||||
|
||||
# Python实现
|
||||
start_time = time.time()
|
||||
for _ in range(10):
|
||||
_ = add(g, g, 0)
|
||||
_ = add(g, g)
|
||||
python_time = time.time() - start_time
|
||||
print(f"Python add 执行时间: {python_time:.6f} 秒")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user