diff --git a/ecc_rs.pyi b/ecc_rs.pyi index b8bd9c2..a516e4d 100644 --- a/ecc_rs.pyi +++ b/ecc_rs.pyi @@ -2,5 +2,50 @@ point = tuple[int, int] -def add(p1: point, p2: point) -> point: ... -def multiply(point: point, n: int) -> point: ... +def add(p1: point, p2: point) -> point: + """ + Adds two points on the SM2 elliptic curve. + + Performs point addition on the SM2 curve defined by `sm2p256v1()`. + + Args: + p1 (point): A tuple representing the x and y coordinates of the first point. + p2 (point): A tuple representing the x and y coordinates of the second point. + + Raises: + Panic: If the x or y coordinates of either point are not valid on the curve. + + Returns: + point: A tuple representing the x and y coordinates of the resulting point after addition. + + Example: + >>> p1 = (x1, y1) + >>> p2 = (x2, y2) + >>> result = add(p1, p2) + """ + ... + +def multiply(point: point, n: int) -> point: + """ + multiply(point, n) + + Performs scalar multiplication of a point on the SM2 curve. + + performs the multiplication operation on the SM2 curve defined by `sm2p256v1()`. + + Args: + point (point): representing the x and y coordinates of the point. + n (int): representing the scalar to multiply the point by. + + Raises: + Panic: If the x or y coordinates of the point are not less than the curve parameter `p`. + + Returns: + point: representing the x and y coordinates of the result point. + + Example: + >>> point = (g, g) + >>> n = 10 + >>> result = multiply(point, n) + """ + ... diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..3871c33 --- /dev/null +++ b/readme.md @@ -0,0 +1,4 @@ +# ecc_rs + +a simple rust implementation of SM2 binding for python. +powered by pyo3. diff --git a/src/lib.rs b/src/lib.rs index 06688d0..f799ce0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -325,12 +325,12 @@ fn add(p1: (BigUint, BigUint), p2: (BigUint, BigUint)) -> (BigUint, BigUint) { let x2 = p2.0; let y2 = p2.1; - // 检查 x 和 y 是否小于曲线参数 p - if x1 >= curve.p || y1 >= curve.p { - panic!("Point p1 coordinates are out of range"); - } - if x2 >= curve.p || y2 >= curve.p { - panic!("Point p2 coordinates are out of range"); + // Check if x and y are less than the curve parameter p + match (x1 >= curve.p, y1 >= curve.p, x2 >= curve.p, y2 >= curve.p) { + (true, _, _, _) | (_, true, _, _) | (_, _, true, _) | (_, _, _, true) => { + panic!("Point coordinates are out of range"); + } + _ => (), } let point1 = Point { @@ -351,8 +351,18 @@ fn add(p1: (BigUint, BigUint), p2: (BigUint, BigUint)) -> (BigUint, BigUint) { /// SM2 multiply #[pyfunction] -fn multiply(point: (BigUint, BigUint), n: BigUint) -> (BigUint, BigUint) { +fn multiply(point: (BigUint, BigUint), mut n: BigUint) -> (BigUint, BigUint) { let curve = sm2p256v1(); + + if n >= curve.n { + n %= &curve.n; + } + + match (point.0 < curve.p, point.1 < curve.p) { + (false, _) | (_, false) => panic!("Point coordinates are out of range"), + _ => {} + } + // Construct the point with BigUint values let point = Point { x: point.0,