feat: update ecc operation call

This commit is contained in:
sangge-redmi 2025-04-04 23:19:27 +08:00
parent dd86253162
commit 5144334558
3 changed files with 167 additions and 67 deletions

View File

@ -19,3 +19,6 @@ dependencies = ["gmssl-python>=2.2.2,<3.0.0", "fastapi", "uvicorn", "requests"]
[project.optional-dependencies] [project.optional-dependencies]
test = ["httpx", "pytest"] test = ["httpx", "pytest"]
dev = ["httpx", "pytest", "pyright", "ruff"] dev = ["httpx", "pytest", "pyright", "ruff"]
[tool.ruff]
select = ["E", "F", "N", "B", "I", "C4", "UP", "SIM"]

View File

@ -1,62 +1,144 @@
from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT """这个模块实现了椭圆曲线加密(ECC)的基本操作.
from typing import Tuple
它提供了点加法点乘法和其他与SM2P256V1曲线相关的操作,
并支持可选的Rust实现来提高性能
"""
import random import random
import ecc_rs from dataclasses import dataclass
from typing import Tuple
from gmssl import DO_DECRYPT, DO_ENCRYPT, Sm2Key, Sm3, Sm4Cbc
point = Tuple[int, int] point = Tuple[int, int]
capsule = Tuple[point, point, int] capsule = Tuple[point, point, int]
try:
import ecc_rs
# 生成密钥对模块 RUST_ECC = True
class CurveFp:
def __init__(self, A, B, P, N, Gx, Gy, name):
self.A = A
self.B = B
self.P = P
self.N = N
self.Gx = Gx
self.Gy = Gy
self.name = name
def add(a: point, b: point) -> point:
"""Add two points on the curve.
sm2p256v1 = CurveFp( Args:
name="sm2p256v1", a: first point
A=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC, b: second point
B=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93, flag: if 1, use Rust implementation
P=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
N=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
Gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
)
"""
if RUST_ECC is True and ecc_rs is not None:
return ecc_rs.add(a, b)
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
# 生成元 def multiply(a: point, n: int) -> point:
g = (sm2p256v1.Gx, sm2p256v1.Gy) """Multiply a point by a scalar.
Args:
a: point
n: scalar
flag: if 1, use Rust implementation
def multiply(a: point, n: int, flag: int = 0) -> point: """
if flag == 1: if RUST_ECC is True and ecc_rs is not None:
result = ecc_rs.multiply(a, n) return ecc_rs.multiply(a, n)
return result N = sm2p256v1.N
else: A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
except ImportError:
RUST_ECC = False
ecc_rs = None
def add(a: point, b: point) -> point:
"""Add two points on the curve.
Args:
a: first point
b: second point
"""
A = sm2p256v1.A
P = sm2p256v1.P
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P)
def multiply(a: point, n: int) -> point:
"""Multiply a point by a scalar.
Args:
a: point
n: scalar
"""
N = sm2p256v1.N N = sm2p256v1.N
A = sm2p256v1.A A = sm2p256v1.A
P = sm2p256v1.P P = sm2p256v1.P
return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P) return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P)
def add(a: point, b: point, flag: int = 0) -> point: @dataclass
if flag == 1: class CurveParams:
result = ecc_rs.add(a, b) """Definition of SM2P256V1 curve parameters."""
return result
else: a: int
A = sm2p256v1.A b: int
P = sm2p256v1.P p: int
return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P) n: int
gx: int
gy: int
name: str
# 生成密钥对模块
class CurveFp:
"""Definition of SM2P256V1 curve class."""
def __init__(self, params: CurveParams) -> None:
"""Initialize curve with parameters.
Args:
params: Collection of curve parameters
"""
self.A = params.a
self.B = params.b
self.P = params.p
self.N = params.n
self.Gx = params.gx
self.Gy = params.gy
self.name = params.name
curve_params = CurveParams(
name="sm2p256v1",
a=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC,
b=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93,
p=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
n=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
)
sm2p256v1 = CurveFp(params=curve_params)
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
def inv(a: int, n: int) -> int: def inv(a: int, n: int) -> int:
"""Return the modular inverse of a mod n.
Args:
a: input integer
n: modulus
Returns:
modular inverse of a mod n
"""
if a == 0: if a == 0:
return 0 return 0
lm, hm = 1, 0 lm, hm = 1, 0
@ -80,7 +162,9 @@ def fromJacobian(Xp_Yp_Zp: Tuple[int, int, int], P: int) -> point:
def jacobianDouble( def jacobianDouble(
Xp_Yp_Zp: Tuple[int, int, int], A: int, P: int Xp_Yp_Zp: Tuple[int, int, int],
A: int,
P: int,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp Xp, Yp, Zp = Xp_Yp_Zp
if not Yp: if not Yp:
@ -95,7 +179,10 @@ def jacobianDouble(
def jacobianAdd( def jacobianAdd(
Xp_Yp_Zp: Tuple[int, int, int], Xq_Yq_Zq: Tuple[int, int, int], A: int, P: int Xp_Yp_Zp: Tuple[int, int, int],
Xq_Yq_Zq: Tuple[int, int, int],
A: int,
P: int,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp Xp, Yp, Zp = Xp_Yp_Zp
Xq, Yq, Zq = Xq_Yq_Zq Xq, Yq, Zq = Xq_Yq_Zq
@ -123,7 +210,11 @@ def jacobianAdd(
def jacobianMultiply( def jacobianMultiply(
Xp_Yp_Zp: Tuple[int, int, int], n: int, N: int, A: int, P: int Xp_Yp_Zp: Tuple[int, int, int],
n: int,
N: int,
A: int,
P: int,
) -> Tuple[int, int, int]: ) -> Tuple[int, int, int]:
Xp, Yp, Zp = Xp_Yp_Zp Xp, Yp, Zp = Xp_Yp_Zp
if Yp == 0 or n == 0: if Yp == 0 or n == 0:
@ -191,9 +282,9 @@ def KDF(G: point) -> int:
def GenerateKeyPair() -> Tuple[point, int]: def GenerateKeyPair() -> Tuple[point, int]:
""" """return:
return:
public_key, secret_key public_key, secret_key
""" """
sm2 = Sm2Key() # pylint: disable=e0602 sm2 = Sm2Key() # pylint: disable=e0602
sm2.generate_key() sm2.generate_key()
@ -207,15 +298,15 @@ def GenerateKeyPair() -> Tuple[point, int]:
return public_key, secret_key return public_key, secret_key
def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]: def Encrypt(public_key: point, message: bytes) -> Tuple[capsule, bytes]:
enca = Encapsulate(pk) enca = Encapsulate(public_key)
K = enca[0].to_bytes(16) K = enca[0].to_bytes(16)
capsule = enca[1] capsule = enca[1]
if len(K) != 16: if len(K) != 16:
raise ValueError("invalid key length") raise ValueError("invalid key length")
iv = b"tpretpretpretpre" iv = b"tpretpretpretpre"
sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602 sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT)
enc_Data = sm4_enc.update(m) enc_Data = sm4_enc.update(message)
enc_Data += sm4_enc.finish() enc_Data += sm4_enc.finish()
enc_message = (capsule, bytes(enc_Data)) enc_message = (capsule, bytes(enc_Data))
return enc_message return enc_message
@ -231,15 +322,14 @@ def Decapsulate(ska: int, capsule: capsule) -> int:
def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes: def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
""" """params:
params:
sk_A: secret key sk_A: secret key
C: (capsule, enc_data) C: (capsule, enc_data)
""" """
capsule, enc_Data = C capsule, enc_Data = C
K = Decapsulate(sk_A, capsule) K = Decapsulate(sk_A, capsule).to_bytes(16)
iv = b"tpretpretpretpre" iv = b"tpretpretpretpre"
sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602 sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT)
dec_Data = sm4_dec.update(enc_Data) dec_Data = sm4_dec.update(enc_Data)
dec_Data += sm4_dec.finish() dec_Data += sm4_dec.finish()
return bytes(dec_Data) return bytes(dec_Data)
@ -247,7 +337,7 @@ def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes:
# GenerateRekey # GenerateRekey
def hash5(id: int, D: int) -> int: def hash5(id: int, D: int) -> int:
sm3 = Sm3() # pylint: disable=e0602 sm3 = Sm3()
sm3.update(id.to_bytes(32)) sm3.update(id.to_bytes(32))
sm3.update(D.to_bytes(32)) sm3.update(D.to_bytes(32))
hash = sm3.digest() hash = sm3.digest()
@ -266,13 +356,14 @@ def hash6(triple_G: Tuple[point, point, point]) -> int:
def f(x: int, f_modulus: list, T: int) -> int: def f(x: int, f_modulus: list, T: int) -> int:
""" """功能: 通过多项式插值来实现信息的分散和重构
功能: 通过多项式插值来实现信息的分散和重构
例如: 随机生成一个多项式f(x)=4x+5,质数P=11,其中f(0)=5,将多项式的系数分别分配给两个人,例如第一个人得到(1, 9),第二个人得到(2, 2).如果两个人都收集到了这两个点,那么可以使用拉格朗日插值法恢复原始的多项式,进而得到秘密信息"5" 例如: 随机生成一个多项式f(x)=4x+5,质数P=11,其中f(0)=5,将多项式的系数分别分配给两个人,例如第一个人得到(1, 9),第二个人得到(2, 2).如果两个人都收集到了这两个点,那么可以使用拉格朗日插值法恢复原始的多项式,进而得到秘密信息"5"
param: param:
x, f_modulus(多项式系数列表), T(门限) x, f_modulus(多项式系数列表), T(门限)
return:
Return:
res res
""" """
res = 0 res = 0
for i in range(T): for i in range(T):
@ -282,13 +373,18 @@ def f(x: int, f_modulus: list, T: int) -> int:
def GenerateReKey( def GenerateReKey(
sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int, ...] sk_A: int,
pk_B: point,
N: int,
T: int,
id_tuple: Tuple[int, ...],
) -> list: ) -> list:
""" """param:
param:
skA, pkB, N(节点总数), T(阈值) skA, pkB, N(节点总数), T(阈值)
return:
Return:
rki(0 <= i <= N-1) rki(0 <= i <= N-1)
""" """
# 计算临时密钥对(x_A, X_A) # 计算临时密钥对(x_A, X_A)
x_A = random.randint(0, sm2p256v1.N - 1) x_A = random.randint(0, sm2p256v1.N - 1)
@ -369,7 +465,8 @@ def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, po
def ReEncrypt( def ReEncrypt(
kFrag: tuple, C: Tuple[capsule, bytes] kFrag: tuple,
C: Tuple[capsule, bytes],
) -> Tuple[Tuple[point, point, int, point], bytes]: ) -> Tuple[Tuple[point, point, int, point], bytes]:
capsule, enc_Data = C capsule, enc_Data = C
@ -394,11 +491,10 @@ def MergeCFrag(cfrag_cts: list) -> list:
def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int: def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
""" """return:
return:
K: sm4 key K: sm4 key
"""
"""
Elist = [] Elist = []
Vlist = [] Vlist = []
idlist = [] idlist = []
@ -434,7 +530,8 @@ def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int:
E2 = add(Ek, E2) E2 = add(Ek, E2)
V2 = add(Vk, V2) V2 = add(Vk, V2)
X_Ab = multiply( X_Ab = multiply(
X_Alist[0], sk_B X_Alist[0],
sk_B,
) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值 ) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值
d = hash3((X_Alist[0], pk_B, X_Ab)) d = hash3((X_Alist[0], pk_B, X_Ab))
EV = add(E2, V2) # E2 + V2 EV = add(E2, V2) # E2 + V2

View File

@ -15,14 +15,14 @@ def test_rust_vs_python_multiply():
# Rust实现 # Rust实现
start_time = time.time() start_time = time.time()
for _ in range(10): for _ in range(10):
_ = multiply(g, mul_times, 1) _ = multiply(g, mul_times)
rust_time = time.time() - start_time rust_time = time.time() - start_time
print(f"\nRust multiply 执行时间: {rust_time:.6f}") print(f"\nRust multiply 执行时间: {rust_time:.6f}")
# Python实现 # Python实现
start_time = time.time() start_time = time.time()
for _ in range(10): for _ in range(10):
_ = multiply(g, mul_times, 0) _ = multiply(g, mul_times)
python_time = time.time() - start_time python_time = time.time() - start_time
print(f"Python multiply 执行时间: {python_time:.6f}") print(f"Python multiply 执行时间: {python_time:.6f}")
@ -33,14 +33,14 @@ def test_rust_vs_python_add():
# Rust实现 # Rust实现
start_time = time.time() start_time = time.time()
for _ in range(10): for _ in range(10):
_ = add(g, g, 1) _ = add(g, g)
rust_time = time.time() - start_time rust_time = time.time() - start_time
print(f"\nRust add 执行时间: {rust_time:.6f}") print(f"\nRust add 执行时间: {rust_time:.6f}")
# Python实现 # Python实现
start_time = time.time() start_time = time.time()
for _ in range(10): for _ in range(10):
_ = add(g, g, 0) _ = add(g, g)
python_time = time.time() - start_time python_time = time.time() - start_time
print(f"Python add 执行时间: {python_time:.6f}") print(f"Python add 执行时间: {python_time:.6f}")