From 5144334558edc42b54266ea2f5a4e538610a99f6 Mon Sep 17 00:00:00 2001 From: sangge-redmi <2251250136@qq.com> Date: Fri, 4 Apr 2025 23:19:27 +0800 Subject: [PATCH] feat: update ecc operation call --- pyproject.toml | 3 + src/tpre.py | 223 +++++++++++++++++++++-------- tests/ecc_speed_comparison_test.py | 8 +- 3 files changed, 167 insertions(+), 67 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5bee456..44e12b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,3 +19,6 @@ dependencies = ["gmssl-python>=2.2.2,<3.0.0", "fastapi", "uvicorn", "requests"] [project.optional-dependencies] test = ["httpx", "pytest"] dev = ["httpx", "pytest", "pyright", "ruff"] + +[tool.ruff] +select = ["E", "F", "N", "B", "I", "C4", "UP", "SIM"] diff --git a/src/tpre.py b/src/tpre.py index 5b0fcf4..979111e 100644 --- a/src/tpre.py +++ b/src/tpre.py @@ -1,62 +1,144 @@ -from gmssl import Sm3, Sm2Key, Sm4Cbc, DO_ENCRYPT, DO_DECRYPT -from typing import Tuple +"""这个模块实现了椭圆曲线加密(ECC)的基本操作. + +它提供了点加法、点乘法和其他与SM2P256V1曲线相关的操作, +并支持可选的Rust实现来提高性能。 +""" + import random -import ecc_rs +from dataclasses import dataclass +from typing import Tuple + +from gmssl import DO_DECRYPT, DO_ENCRYPT, Sm2Key, Sm3, Sm4Cbc point = Tuple[int, int] capsule = Tuple[point, point, int] +try: + import ecc_rs -# 生成密钥对模块 -class CurveFp: - def __init__(self, A, B, P, N, Gx, Gy, name): - self.A = A - self.B = B - self.P = P - self.N = N - self.Gx = Gx - self.Gy = Gy - self.name = name + RUST_ECC = True + def add(a: point, b: point) -> point: + """Add two points on the curve. -sm2p256v1 = CurveFp( - name="sm2p256v1", - A=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC, - B=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93, - P=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF, - N=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123, - Gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7, - Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0, -) + Args: + a: first point + b: second point + flag: if 1, use Rust implementation + """ + if RUST_ECC is True and ecc_rs is not None: + return ecc_rs.add(a, b) + A = sm2p256v1.A + P = sm2p256v1.P + return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P) -# 生成元 -g = (sm2p256v1.Gx, sm2p256v1.Gy) + def multiply(a: point, n: int) -> point: + """Multiply a point by a scalar. + Args: + a: point + n: scalar + flag: if 1, use Rust implementation -def multiply(a: point, n: int, flag: int = 0) -> point: - if flag == 1: + """ + if RUST_ECC is True and ecc_rs is not None: - result = ecc_rs.multiply(a, n) - return result - else: + return ecc_rs.multiply(a, n) + N = sm2p256v1.N + A = sm2p256v1.A + P = sm2p256v1.P + return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P) + +except ImportError: + RUST_ECC = False + ecc_rs = None + + def add(a: point, b: point) -> point: + """Add two points on the curve. + + Args: + a: first point + b: second point + + """ + A = sm2p256v1.A + P = sm2p256v1.P + return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P) + + def multiply(a: point, n: int) -> point: + """Multiply a point by a scalar. + + Args: + a: point + n: scalar + + """ N = sm2p256v1.N A = sm2p256v1.A P = sm2p256v1.P return fromJacobian(jacobianMultiply(toJacobian(a), n, N, A, P), P) -def add(a: point, b: point, flag: int = 0) -> point: - if flag == 1: - result = ecc_rs.add(a, b) - return result - else: - A = sm2p256v1.A - P = sm2p256v1.P - return fromJacobian(jacobianAdd(toJacobian(a), toJacobian(b), A, P), P) +@dataclass +class CurveParams: + """Definition of SM2P256V1 curve parameters.""" + + a: int + b: int + p: int + n: int + gx: int + gy: int + name: str + + +# 生成密钥对模块 +class CurveFp: + """Definition of SM2P256V1 curve class.""" + + def __init__(self, params: CurveParams) -> None: + """Initialize curve with parameters. + + Args: + params: Collection of curve parameters + + """ + self.A = params.a + self.B = params.b + self.P = params.p + self.N = params.n + self.Gx = params.gx + self.Gy = params.gy + self.name = params.name + + +curve_params = CurveParams( + name="sm2p256v1", + a=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC, + b=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93, + p=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF, + n=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123, + gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7, + gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0, +) +sm2p256v1 = CurveFp(params=curve_params) + + +# 生成元 +g = (sm2p256v1.Gx, sm2p256v1.Gy) def inv(a: int, n: int) -> int: + """Return the modular inverse of a mod n. + + Args: + a: input integer + n: modulus + Returns: + modular inverse of a mod n + + """ if a == 0: return 0 lm, hm = 1, 0 @@ -80,7 +162,9 @@ def fromJacobian(Xp_Yp_Zp: Tuple[int, int, int], P: int) -> point: def jacobianDouble( - Xp_Yp_Zp: Tuple[int, int, int], A: int, P: int + Xp_Yp_Zp: Tuple[int, int, int], + A: int, + P: int, ) -> Tuple[int, int, int]: Xp, Yp, Zp = Xp_Yp_Zp if not Yp: @@ -95,7 +179,10 @@ def jacobianDouble( def jacobianAdd( - Xp_Yp_Zp: Tuple[int, int, int], Xq_Yq_Zq: Tuple[int, int, int], A: int, P: int + Xp_Yp_Zp: Tuple[int, int, int], + Xq_Yq_Zq: Tuple[int, int, int], + A: int, + P: int, ) -> Tuple[int, int, int]: Xp, Yp, Zp = Xp_Yp_Zp Xq, Yq, Zq = Xq_Yq_Zq @@ -123,7 +210,11 @@ def jacobianAdd( def jacobianMultiply( - Xp_Yp_Zp: Tuple[int, int, int], n: int, N: int, A: int, P: int + Xp_Yp_Zp: Tuple[int, int, int], + n: int, + N: int, + A: int, + P: int, ) -> Tuple[int, int, int]: Xp, Yp, Zp = Xp_Yp_Zp if Yp == 0 or n == 0: @@ -191,9 +282,9 @@ def KDF(G: point) -> int: def GenerateKeyPair() -> Tuple[point, int]: - """ - return: + """return: public_key, secret_key + """ sm2 = Sm2Key() # pylint: disable=e0602 sm2.generate_key() @@ -207,15 +298,15 @@ def GenerateKeyPair() -> Tuple[point, int]: return public_key, secret_key -def Encrypt(pk: point, m: bytes) -> Tuple[capsule, bytes]: - enca = Encapsulate(pk) +def Encrypt(public_key: point, message: bytes) -> Tuple[capsule, bytes]: + enca = Encapsulate(public_key) K = enca[0].to_bytes(16) capsule = enca[1] if len(K) != 16: raise ValueError("invalid key length") iv = b"tpretpretpretpre" - sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) # pylint: disable=e0602 - enc_Data = sm4_enc.update(m) + sm4_enc = Sm4Cbc(K, iv, DO_ENCRYPT) + enc_Data = sm4_enc.update(message) enc_Data += sm4_enc.finish() enc_message = (capsule, bytes(enc_Data)) return enc_message @@ -231,15 +322,14 @@ def Decapsulate(ska: int, capsule: capsule) -> int: def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes: - """ - params: + """params: sk_A: secret key C: (capsule, enc_data) """ capsule, enc_Data = C - K = Decapsulate(sk_A, capsule) + K = Decapsulate(sk_A, capsule).to_bytes(16) iv = b"tpretpretpretpre" - sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) # pylint: disable= e0602 + sm4_dec = Sm4Cbc(K, iv, DO_DECRYPT) dec_Data = sm4_dec.update(enc_Data) dec_Data += sm4_dec.finish() return bytes(dec_Data) @@ -247,7 +337,7 @@ def Decrypt(sk_A: int, C: Tuple[capsule, bytes]) -> bytes: # GenerateRekey def hash5(id: int, D: int) -> int: - sm3 = Sm3() # pylint: disable=e0602 + sm3 = Sm3() sm3.update(id.to_bytes(32)) sm3.update(D.to_bytes(32)) hash = sm3.digest() @@ -266,13 +356,14 @@ def hash6(triple_G: Tuple[point, point, point]) -> int: def f(x: int, f_modulus: list, T: int) -> int: - """ - 功能: 通过多项式插值来实现信息的分散和重构 + """功能: 通过多项式插值来实现信息的分散和重构 例如: 随机生成一个多项式f(x)=4x+5,质数P=11,其中f(0)=5,将多项式的系数分别分配给两个人,例如第一个人得到(1, 9),第二个人得到(2, 2).如果两个人都收集到了这两个点,那么可以使用拉格朗日插值法恢复原始的多项式,进而得到秘密信息"5" param: x, f_modulus(多项式系数列表), T(门限) - return: + + Return: res + """ res = 0 for i in range(T): @@ -282,13 +373,18 @@ def f(x: int, f_modulus: list, T: int) -> int: def GenerateReKey( - sk_A: int, pk_B: point, N: int, T: int, id_tuple: Tuple[int, ...] + sk_A: int, + pk_B: point, + N: int, + T: int, + id_tuple: Tuple[int, ...], ) -> list: - """ - param: + """param: skA, pkB, N(节点总数), T(阈值) - return: + + Return: rki(0 <= i <= N-1) + """ # 计算临时密钥对(x_A, X_A) x_A = random.randint(0, sm2p256v1.N - 1) @@ -369,7 +465,8 @@ def ReEncapsulate(kFrag: tuple, capsule: capsule) -> Tuple[point, point, int, po def ReEncrypt( - kFrag: tuple, C: Tuple[capsule, bytes] + kFrag: tuple, + C: Tuple[capsule, bytes], ) -> Tuple[Tuple[point, point, int, point], bytes]: capsule, enc_Data = C @@ -394,11 +491,10 @@ def MergeCFrag(cfrag_cts: list) -> list: def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int: - """ - return: + """return: K: sm4 key - """ + """ Elist = [] Vlist = [] idlist = [] @@ -434,7 +530,8 @@ def DecapsulateFrags(sk_B: int, pk_B: point, pk_A: point, cFrags: list) -> int: E2 = add(Ek, E2) V2 = add(Vk, V2) X_Ab = multiply( - X_Alist[0], sk_B + X_Alist[0], + sk_B, ) # X_A^b X_A 的值是随机生成的xa,通过椭圆曲线上的倍点运算生成的固定的值 d = hash3((X_Alist[0], pk_B, X_Ab)) EV = add(E2, V2) # E2 + V2 diff --git a/tests/ecc_speed_comparison_test.py b/tests/ecc_speed_comparison_test.py index 9486d5a..903662e 100644 --- a/tests/ecc_speed_comparison_test.py +++ b/tests/ecc_speed_comparison_test.py @@ -15,14 +15,14 @@ def test_rust_vs_python_multiply(): # Rust实现 start_time = time.time() for _ in range(10): - _ = multiply(g, mul_times, 1) + _ = multiply(g, mul_times) rust_time = time.time() - start_time print(f"\nRust multiply 执行时间: {rust_time:.6f} 秒") # Python实现 start_time = time.time() for _ in range(10): - _ = multiply(g, mul_times, 0) + _ = multiply(g, mul_times) python_time = time.time() - start_time print(f"Python multiply 执行时间: {python_time:.6f} 秒") @@ -33,14 +33,14 @@ def test_rust_vs_python_add(): # Rust实现 start_time = time.time() for _ in range(10): - _ = add(g, g, 1) + _ = add(g, g) rust_time = time.time() - start_time print(f"\nRust add 执行时间: {rust_time:.6f} 秒") # Python实现 start_time = time.time() for _ in range(10): - _ = add(g, g, 0) + _ = add(g, g) python_time = time.time() - start_time print(f"Python add 执行时间: {python_time:.6f} 秒")