feat: update ecc operation call

This commit is contained in:
sangge-redmi 2025-04-04 23:19:27 +08:00
parent dd86253162
commit 5144334558
3 changed files with 167 additions and 67 deletions

View File

@ -19,3 +19,6 @@ dependencies = ["gmssl-python>=2.2.2,<3.0.0", "fastapi", "uvicorn", "requests"]
[project.optional-dependencies]
test = ["httpx", "pytest"]
dev = ["httpx", "pytest", "pyright", "ruff"]
[tool.ruff]
select = ["E", "F", "N", "B", "I", "C4", "UP", "SIM"]

View File

@ -1,62 +1,144 @@
from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT
from typing import Tuple
"""这个模块实现了椭圆曲线加密(ECC)的基本操作.
它提供了点加法点乘法和其他与SM2P256V1曲线相关的操作,
并支持可选的Rust实现来提高性能
"""
import random
import ecc_rs
from dataclasses import dataclass
from typing import Tuple
from gmssl import DO_DECRYPT, DO_ENCRYPT, Sm2Key, Sm3, Sm4Cbc
point = Tuple[int, int]
capsule = Tuple[point, point, int]
try:
import ecc_rs
# 生成密钥对模块
class CurveFp:
def __init__(self, A, B, P, N, Gx, Gy, name):
self.A = A
self.B = B
self.P = P
self.N = N
self.Gx = Gx
self.Gy = Gy
self.name = name
RUST_ECC = True
def add(a: point, b: point) -> point:
"""Add two points on the curve.
sm2p256v1 = CurveFp(
name="sm2p256v1",
A=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC,
B=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93,
P=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
N=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
Gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
)
Args:
a: first point
b: second point
flag: if 1, use Rust implementation
"""
if RUST_ECC is True and ecc_rs is not None:
return ecc_rs.add(a, b)
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
def multiply(a: point, n: int) -> point:
"""Multiply a point by a scalar.
Args:
a: point
n: scalar
flag: if 1, use Rust implementation
def multiply(a: point, n: int, flag: int = 0) -> point:
if flag == 1:
"""
if RUST_ECC is True and ecc_rs is not None:
result = ecc_rs.multiply(a, n)
return result
else:
return ecc_rs.multiply(a, n)
N = sm2p256v1.N
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
except ImportError:
RUST_ECC = False
ecc_rs = None
def add(a: point, b: point) -> point:
"""Add two points on the curve.
Args:
a: first point
b: second point
"""
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
def multiply(a: point, n: int) -> point:
"""Multiply a point by a scalar.
Args:
a: point
n: scalar
"""
N = sm2p256v1.N
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
def add(a: point, b: point, flag: int = 0) -> point:
if flag == 1:
result = ecc_rs.add(a, b)
return result
else:
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
@dataclass
class CurveParams:
"""Definition of SM2P256V1 curve parameters."""
a: int
b: int
p: int
n: int
gx: int
gy: int
name: str
# 生成密钥对模块
class CurveFp:
"""Definition of SM2P256V1 curve class."""
def __init__(self, params: CurveParams) -> None:
"""Initialize curve with parameters.
Args:
params: Collection of curve parameters
"""
self.A = params.a
self.B = params.b
self.P = params.p
self.N = params.n
self.Gx = params.gx
self.Gy = params.gy
self.name = params.name
curve_params = CurveParams(
name="sm2p256v1",
a=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC,
b=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93,
p=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
n=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
)
sm2p256v1 = CurveFp(params=curve_params)
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
def inv(a: int, n: int) -> int:
"""Return the modular inverse of a mod n.
Args:
a: input integer
n: modulus
Returns:
modular inverse of a mod n
"""
if a == 0:
return 0
lm, hm = 1, 0
@ -80,7 +162,9 @@ def fromJacobian(Xp_Yp_Zp: Tuple[int, int, int], P: int) -> point:
def jacobianDouble(
Xp_Yp_Zp: Tuple[int, int, int], A: int, P: int
Xp_Yp_Zp: Tuple[int, int, int],
A: int,
P: int,
) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp
if not Yp:
@ -95,7 +179,10 @@ def jacobianDouble(
def jacobianAdd(
Xp_Yp_Zp: Tuple[int, int, int], Xq_Yq_Zq: Tuple[int, int, int], A: int, P: int
Xp_Yp_Zp: Tuple[int, int, int],
Xq_Yq_Zq: Tuple[int, int, int],
A: int,
P: int,
) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp
Xq, Yq, Zq = Xq_Yq_Zq
@ -123,7 +210,11 @@ def jacobianAdd(
def jacobianMultiply(
Xp_Yp_Zp: Tuple[int, int, int], n: int, N: int, A: int, P: int
Xp_Yp_Zp: Tuple[int, int, int],
n: int,
N: int,
A: int,
P: int,
) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp
if Yp == 0 or n == 0:
@ -191,9 +282,9 @@ def KDF(G: point) -> int:
def GenerateKeyPair() -> Tuple[point, int]:
"""
return:
"""return:
public_key, secret_key
"""
sm2 = Sm2Key() # pylint: disable=e0602
sm2.generate_key()
@ -207,15 +298,15 @@ def GenerateKeyPair() -> Tuple[point, int]:
return public_key, secret_key
def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]:
enca = Encapsulate(pk)
def Encrypt(public_key: point, message: bytes) -> Tuple[capsule, bytes]:
enca = Encapsulate(public_key)
K = enca[0].to_bytes(16)
capsule = enca[1]
if len(K) != 16:
raise ValueError("invalid key length")
iv = b"tpretpretpretpre"
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602
enc_Data = sm4_enc.update(m)
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT)
enc_Data = sm4_enc.update(message)
enc_Data += sm4_enc.finish()
enc_message = (capsule, bytes(enc_Data))
return enc_message
@ -231,15 +322,14 @@ def Decapsulate(ska: int, capsule: capsule) -> int:
def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
"""
params:
"""params:
sk_A: secret key
C: (capsule, enc_data)
"""
capsule, enc_Data = C
K = Decapsulate(sk_A, capsule)
K = Decapsulate(sk_A, capsule).to_bytes(16)
iv = b"tpretpretpretpre"
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT)
dec_Data = sm4_dec.update(enc_Data)
dec_Data += sm4_dec.finish()
return bytes(dec_Data)
@ -247,7 +337,7 @@ def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
# GenerateRekey
def hash5(id: int, D: int) -> int:
sm3 = Sm3() # pylint: disable=e0602
sm3 = Sm3()
sm3.update(id.to_bytes(32))
sm3.update(D.to_bytes(32))
hash = sm3.digest()
@ -266,13 +356,14 @@ def hash6(triple_G: Tuple[point, point, point]) -> int:
def f(x: int, f_modulus: list, T: int) -> int:
"""
功能: 通过多项式插值来实现信息的分散和重构
"""功能: 通过多项式插值来实现信息的分散和重构
例如: 随机生成一个多项式f(x)=4x+5,质数P=11,其中f(0)=5,将多项式的系数分别分配给两个人,例如第一个人得到(1, 9),第二个人得到(2, 2).如果两个人都收集到了这两个点,那么可以使用拉格朗日插值法恢复原始的多项式,进而得到秘密信息"5"
param:
x, f_modulus(多项式系数列表), T(门限)
return:
Return:
res
"""
res = 0
for i in range(T):
@ -282,13 +373,18 @@ def f(x: int, f_modulus: list, T: int) -> int:
def GenerateReKey(
sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int, ...]
sk_A: int,
pk_B: point,
N: int,
T: int,
id_tuple: Tuple[int, ...],
) -> list:
"""
param:
"""param:
skA, pkB, N(节点总数), T(阈值)
return:
Return:
rki(0 <= i <= N-1)
"""
# 计算临时密钥对(x_A, X_A)
x_A = random.randint(0, sm2p256v1.N - 1)
@ -369,7 +465,8 @@ def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, po
def ReEncrypt(
kFrag: tuple, C: Tuple[capsule, bytes]
kFrag: tuple,
C: Tuple[capsule, bytes],
) -> Tuple[Tuple[point, point, int, point], bytes]:
capsule, enc_Data = C
@ -394,11 +491,10 @@ def MergeCFrag(cfrag_cts: list) -> list:
def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
"""
return:
"""return:
K: sm4 key
"""
"""
Elist = []
Vlist = []
idlist = []
@ -434,7 +530,8 @@ def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
E2 = add(Ek, E2)
V2 = add(Vk, V2)
X_Ab = multiply(
X_Alist[0], sk_B
X_Alist[0],
sk_B,
) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
d = hash3((X_Alist[0], pk_B, X_Ab))
EV = add(E2, V2) # E2 + V2

View File

@ -15,14 +15,14 @@ def test_rust_vs_python_multiply():
# Rust实现
start_time = time.time()
for _ in range(10):
_ = multiply(g, mul_times, 1)
_ = multiply(g, mul_times)
rust_time = time.time() - start_time
print(f"\nRust multiply 执行时间: {rust_time:.6f}")
# Python实现
start_time = time.time()
for _ in range(10):
_ = multiply(g, mul_times, 0)
_ = multiply(g, mul_times)
python_time = time.time() - start_time
print(f"Python multiply 执行时间: {python_time:.6f}")
@ -33,14 +33,14 @@ def test_rust_vs_python_add():
# Rust实现
start_time = time.time()
for _ in range(10):
_ = add(g, g, 1)
_ = add(g, g)
rust_time = time.time() - start_time
print(f"\nRust add 执行时间: {rust_time:.6f}")
# Python实现
start_time = time.time()
for _ in range(10):
_ = add(g, g, 0)
_ = add(g, g)
python_time = time.time() - start_time
print(f"Python add 执行时间: {python_time:.6f}")