Compare commits

..

5 Commits

Author SHA1 Message Date
880c34ce03 use less clone and builtin modinv 2024-09-05 15:18:33 +08:00
0b9ae82c17 add docstring 2024-09-05 10:16:46 +08:00
a8e3dd7f1e add new test 2024-09-05 09:36:57 +08:00
9c1b3996f3 redefine type lint 2024-09-05 09:36:44 +08:00
41038ebdf4 add jacobian algo 2024-09-05 09:36:27 +08:00
4 changed files with 289 additions and 122 deletions

View File

@@ -1,4 +1,51 @@
# ecc_rs.pyi # ecc_rs.pyi
def add(p1: tuple[int, int], p2: tuple[int, int]) -> tuple[int, int]: ... point = tuple[int, int]
def multiply(point: tuple[int, int], n: int) -> tuple[int, int]: ...
def add(p1: point, p2: point) -> point:
"""
Adds two points on the SM2 elliptic curve.
Performs point addition on the SM2 curve defined by `sm2p256v1()`.
Args:
p1 (point): A tuple representing the x and y coordinates of the first point.
p2 (point): A tuple representing the x and y coordinates of the second point.
Raises:
Panic: If the x or y coordinates of either point are not valid on the curve.
Returns:
point: A tuple representing the x and y coordinates of the resulting point after addition.
Example:
>>> p1 = (x1, y1)
>>> p2 = (x2, y2)
>>> result = add(p1, p2)
"""
...
def multiply(point: point, n: int) -> point:
"""
multiply(point, n)
Performs scalar multiplication of a point on the SM2 curve.
performs the multiplication operation on the SM2 curve defined by `sm2p256v1()`.
Args:
point (point): representing the x and y coordinates of the point.
n (int): representing the scalar to multiply the point by.
Raises:
Panic: If the x or y coordinates of the point are not less than the curve parameter `p`.
Returns:
point: representing the x and y coordinates of the result point.
Example:
>>> point = (g, g)
>>> n = 10
>>> result = multiply(point, n)
"""
...

4
readme.md Normal file
View File

@@ -0,0 +1,4 @@
# ecc_rs
a simple rust implementation of SM2 binding for python.
powered by pyo3.

View File

@@ -16,10 +16,19 @@ struct CurveFp {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct Point { struct Point<'a> {
x: BigUint, x: BigUint,
y: BigUint, y: BigUint,
curve: CurveFp, curve: &'a CurveFp,
}
/// 椭圆曲线上的点(雅可比坐标)
#[derive(Clone)]
struct JacobianPoint<'a> {
x: BigUint, // X 坐标
y: BigUint, // Y 坐标
z: BigUint, // Z 坐标
curve: &'a CurveFp, // 椭圆曲线的参数
} }
// Initialize the SM2 Curve // Initialize the SM2 Curve
@@ -59,137 +68,180 @@ fn sm2p256v1() -> CurveFp {
} }
} }
/// Calculates the greatest common divisor (GCD) of two `BigUint` numbers. fn point_addition<'a>(p1: &'a Point<'a>, p2: &'a Point<'a>) -> Point<'a> {
/// // 如果 p1 是零点,返回一个新构造的 p2 点,而不是克隆
/// The function uses the Euclidean algorithm to compute the GCD. The algorithm is based on
/// the principle that the greatest common divisor of two numbers does not change if the larger
/// number is replaced by its difference with the smaller number.
///
/// # Arguments
///
/// * `a` - The first `BigUint` number.
/// * `b` - The second `BigUint` number.
///
/// # Returns
///
/// * `BigUint` - The greatest common divisor of `a` and `b`.
///
/// # Example
///
/// ```
/// let a = BigUint::from(60u32);
/// let b = BigUint::from(48u32);
/// let result = gcd(a, b);
/// assert_eq!(result, BigUint::from(12u32));
/// ```
fn gcd(a: &BigUint, b: &BigUint) -> BigUint {
let mut c = a.clone();
let mut d = b.clone();
while !a.is_zero() {
let temp = c.clone();
c = d % c;
d = temp;
}
d
}
/// Computes the modular inverse of `a` under modulo `m`.
///
/// # Arguments
///
/// * `a` - A reference to a BigUint representing the number to find the modular inverse of.
/// * `p` - A reference to a BigUint representing the modulus.
///
/// # Return1
///
/// * A BigUint representing the modular inverse of `a` modulo `p`.
///
/// # Panics
///
/// This function will panic if the modular inverse does not exist (i.e., if `a` and `p` are not coprime).
fn mod_inverse(a: &BigUint, m: &BigUint) -> BigUint {
let (mut t, mut new_t) = (BigUint::zero(), BigUint::one());
let (mut r, mut new_r) = (m.clone(), a.clone());
while !new_r.is_zero() {
let quotient = &r / &new_r;
let temp_t = new_t.clone();
new_t = if t < (quotient.clone() * &temp_t) {
// BigUint can't be negative,
// so we use mod to handle the case where t < quotient * temp_t
(&t + m - &(quotient.clone() * &temp_t) % m) % m
} else {
&t - &(quotient.clone() * &temp_t)
};
let temp_r = new_r.clone();
new_r = &r - &(quotient.clone() * &temp_r);
r = temp_r;
t = temp_t;
}
if r > BigUint::one() {
panic!("Modular inverse does not exist");
}
if t < BigUint::zero() {
t += m;
}
t
}
fn point_addition(p1: &Point, p2: &Point) -> Point {
let curve = &p1.curve;
let p = &curve.p;
if p1.x.is_zero() && p1.y.is_zero() { if p1.x.is_zero() && p1.y.is_zero() {
return Point {
x: p2.x.clone(),
y: p2.y.clone(),
curve: p2.curve, // 如果 curve 是引用,保持引用即可
};
}
// 如果 p2 是零点,返回一个新构造的 p1 点,而不是克隆
if p2.x.is_zero() && p2.y.is_zero() {
return Point {
x: p1.x.clone(),
y: p1.y.clone(),
curve: p1.curve, // 如果 curve 是引用,保持引用即可
};
}
from_jacobian(jacobian_add(to_jacobian(p1), to_jacobian(p2)))
}
/// 将仿射坐标转换为雅可比坐标 (X, Y, Z)
fn to_jacobian<'a>(p: &'a Point) -> JacobianPoint<'a> {
JacobianPoint {
x: p.x.clone(),
y: p.y.clone(),
z: BigUint::one(), // Z = 1 表示仿射坐标
curve: p.curve,
}
}
/// 将雅可比坐标转换为仿射坐标
fn from_jacobian(p: JacobianPoint) -> Point {
if p.z.is_zero() {
return Point {
x: BigUint::zero(),
y: BigUint::zero(),
curve: p.curve,
};
}
let p_mod = &p.curve.p;
// 计算 Z 的模反
let z_inv = p.z.modinv(p_mod).expect("modinv failed");
let z_inv2 = (&z_inv * &z_inv) % p_mod; // Z_inv^2
let z_inv3 = (&z_inv2 * &z_inv) % p_mod; // Z_inv^3
// 计算 x = X * Z_inv^2, y = Y * Z_inv^3
let x_affine = (&p.x * &z_inv2) % p_mod;
let y_affine = (&p.y * &z_inv3) % p_mod;
Point {
x: x_affine,
y: y_affine,
curve: p.curve,
}
}
/// 雅可比坐标下的点加法
fn jacobian_add<'a>(p1: JacobianPoint<'a>, p2: JacobianPoint<'a>) -> JacobianPoint<'a> {
if p1.z.is_zero() {
return p2.clone(); return p2.clone();
} }
if p2.x.is_zero() && p2.y.is_zero() { if p2.z.is_zero() {
return p1.clone(); return p1.clone();
} }
let lambda = if p1.x == p2.x && p1.y == p2.y { let p_mod = &p1.curve.p;
let num = (BigUint::from(3u32) * &p1.x * &p1.x + &curve.a) % p;
let denom = (BigUint::from(2u32) * &p1.y) % p;
(num * mod_inverse(&denom, p)) % p
} else {
let num = ((&p2.y + p) - &p1.y) % p;
let denom = ((&p2.x + p) - &p1.x) % p;
(num * mod_inverse(&denom, p)) % p // U1 = X1 * Z2^2, U2 = X2 * Z1^2
}; let z1z1 = (&p1.z * &p1.z) % p_mod;
let z2z2 = (&p2.z * &p2.z) % p_mod;
let u1 = (&p1.x * &z2z2) % p_mod;
let u2 = (&p2.x * &z1z1) % p_mod;
let x3 = (lambda.clone() * &lambda - &p1.x - &p2.x) % p; // S1 = Y1 * Z2^3, S2 = Y2 * Z1^3
let y3 = (lambda * (&p1.x + p - &x3) - &p1.y) % p; let z1z1z1 = (&z1z1 * &p1.z) % p_mod;
let z2z2z2 = (&z2z2 * &p2.z) % p_mod;
let s1 = (&p1.y * &z2z2z2) % p_mod;
let s2 = (&p2.y * &z1z1z1) % p_mod;
Point { if u1 == u2 && s1 == s2 {
// 点倍运算 (p1 == p2)
return jacobian_double(p1);
}
// H = U2 - U1, R = S2 - S1
let h = (u2 + p_mod - &u1) % p_mod;
let r = (s2 + p_mod - &s1) % p_mod;
// X3 = R^2 - H^3 - 2 * U1 * H^2
let h2 = (&h * &h) % p_mod;
let h3 = (&h2 * &h) % p_mod;
let x3 = (&r * &r + p_mod - &h3 - (&BigUint::from(2u32) * &u1 * &h2) % p_mod) % p_mod;
// Y3 = R * (U1 * H^2 - X3) - S1 * H^3
let y3 =
(&r * ((&u1 * &h2 + p_mod - (&x3 % p_mod)) % p_mod) + p_mod - (&s1 * &h3) % p_mod) % p_mod;
// Z3 = Z1 * Z2 * H
let z3 = (&p1.z * &p2.z * &h) % p_mod;
JacobianPoint {
x: x3, x: x3,
y: y3, y: y3,
curve: curve.clone(), z: z3,
curve: p1.curve,
} }
} }
fn point_multiplication(p: &Point, n: &BigUint) -> Point { /// 雅可比坐标下的点倍运算
let mut result = Point { fn jacobian_double(p: JacobianPoint) -> JacobianPoint {
let p_mod = &p.curve.p;
if p.y.is_zero() {
return JacobianPoint {
x: BigUint::zero(),
y: BigUint::zero(),
z: BigUint::zero(),
curve: p.curve,
};
}
// S = 4 * X * Y^2
let y2 = (&p.y * &p.y) % p_mod;
let s = (&p.x * &y2 * BigUint::from(4u32)) % p_mod;
// M = 3 * X^2 + a * Z^4
let z2 = (&p.z * &p.z) % p_mod;
let z4 = (&z2 * &z2) % p_mod;
let m = ((&p.x * &p.x * BigUint::from(3u32)) + &p.curve.a * &z4) % p_mod;
// X3 = M^2 - 2 * S
let x3 = (&m * &m + p_mod - &s * BigUint::from(2u32)) % p_mod;
// Y3 = M * (S - X3) - 8 * Y^4
let y4 = (&y2 * &y2) % p_mod;
let y3 = (&m * (&s + p_mod - &x3) + p_mod - BigUint::from(8u32) * &y4) % p_mod;
// Z3 = 2 * Y * Z
let z3 = (&p.y * &p.z * BigUint::from(2u32)) % p_mod;
JacobianPoint {
x: x3,
y: y3,
z: z3,
curve: p.curve,
}
}
fn point_multiplication<'a>(p: &'a Point<'a>, n: &BigUint) -> Point<'a> {
let mut result = JacobianPoint {
x: BigUint::zero(), x: BigUint::zero(),
y: BigUint::zero(), y: BigUint::one(), // 无穷远点的雅可比坐标表示
curve: p.curve.clone(), z: BigUint::zero(),
curve: p.curve,
}; };
let mut addend = p.clone(); // 将输入点从仿射坐标转换为雅可比坐标
let mut addend = to_jacobian(p);
let mut k = n.clone(); let mut k = n.clone();
// 使用二进制展开法进行点乘运算
while !k.is_zero() { while !k.is_zero() {
if &k % 2u32 == BigUint::one() { if &k % 2u32 == BigUint::one() {
result = point_addition(&result, &addend); result = jacobian_add(result, addend.clone());
} }
addend = point_addition(&addend, &addend); addend = jacobian_double(addend); // 倍点运算
k >>= 1; k >>= 1;
} }
result // 将结果从雅可比坐标转换为仿射坐标
from_jacobian(result)
} }
/// SM2 addition /// SM2 addition
@@ -202,24 +254,24 @@ fn add(p1: (BigUint, BigUint), p2: (BigUint, BigUint)) -> (BigUint, BigUint) {
let x2 = p2.0; let x2 = p2.0;
let y2 = p2.1; let y2 = p2.1;
// 检查 x 和 y 是否小于曲线参数 p // Check if x and y are less than the curve parameter p
if x1 >= curve.p || y1 >= curve.p { match (x1 >= curve.p, y1 >= curve.p, x2 >= curve.p, y2 >= curve.p) {
panic!("Point p1 coordinates are out of range"); (true, _, _, _) | (_, true, _, _) | (_, _, true, _) | (_, _, _, true) => {
} panic!("Point coordinates are out of range");
if x2 >= curve.p || y2 >= curve.p { }
panic!("Point p2 coordinates are out of range"); _ => (),
} }
let point1 = Point { let point1 = Point {
x: x1, x: x1,
y: y1, y: y1,
curve: curve.clone(), curve: &curve,
}; };
let point2 = Point { let point2 = Point {
x: x2, x: x2,
y: y2, y: y2,
curve: curve.clone(), curve: &curve,
}; };
let result = point_addition(&point1, &point2); let result = point_addition(&point1, &point2);
@@ -228,13 +280,23 @@ fn add(p1: (BigUint, BigUint), p2: (BigUint, BigUint)) -> (BigUint, BigUint) {
/// SM2 multiply /// SM2 multiply
#[pyfunction] #[pyfunction]
fn multiply(point: (BigUint, BigUint), n: BigUint) -> (BigUint, BigUint) { fn multiply(point: (BigUint, BigUint), mut n: BigUint) -> (BigUint, BigUint) {
let curve = sm2p256v1(); let curve = sm2p256v1();
if n >= curve.n {
n %= &curve.n;
}
match (point.0 < curve.p, point.1 < curve.p) {
(false, _) | (_, false) => panic!("Point coordinates are out of range"),
_ => {}
}
// Construct the point with BigUint values // Construct the point with BigUint values
let point = Point { let point = Point {
x: point.0.clone(), x: point.0,
y: point.1.clone(), y: point.1,
curve: curve.clone(), curve: &curve,
}; };
// Perform point multiplication // Perform point multiplication
let result = point_multiplication(&point, &n); let result = point_multiplication(&point, &n);

54
test.py
View File

@@ -1,5 +1,7 @@
import ecc_rs import ecc_rs
import time
point = tuple[int, int]
# Example point coordinates for P1 and P2 as tuples (x1, y1) and (x2, y2) # Example point coordinates for P1 and P2 as tuples (x1, y1) and (x2, y2)
p1 = ( p1 = (
1234567890123456789012345678901234567890123456789012345678901234, 1234567890123456789012345678901234567890123456789012345678901234,
@@ -17,3 +19,55 @@ print(f"Resulting Point: x = {result_x}, y = {result_y}")
result = ecc_rs.multiply(p1, 2) result = ecc_rs.multiply(p1, 2)
print(f"Resulting Point: x = {result[0]}, y = {result[1]}") print(f"Resulting Point: x = {result[0]}, y = {result[1]}")
# 生成密钥对模块
class CurveFp:
def __init__(self, A, B, P, N, Gx, Gy, name):
self.A = A
self.B = B
self.P = P
self.N = N
self.Gx = Gx
self.Gy = Gy
self.name = name
sm2p256v1 = CurveFp(
name="sm2p256v1",
A=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC,
B=0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93,
P=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF,
N=0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123,
Gx=0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7,
Gy=0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0,
)
# 生成元
g = (sm2p256v1.Gx, sm2p256v1.Gy)
def multiply(a: point, n: int) -> point:
result = ecc_rs.multiply(a, n)
return result
def add(a: point, b: point) -> point:
result = ecc_rs.add(a, b)
return result
start_time = time.time() # 获取开始时间
for i in range(10):
result = multiply(g, 10000) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"rust multiply 执行时间: {elapsed_time:.6f}")
start_time = time.time() # 获取开始时间
for i in range(10):
result = add(g, g) # 执行函数
end_time = time.time() # 获取结束时间
elapsed_time = end_time - start_time # 计算执行时间
print(f"rust add 执行时间: {elapsed_time:.6f}")