diff --git a/Cargo.lock b/Cargo.lock index fbfd4a6..93f95a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -109,6 +109,7 @@ dependencies = [ "indoc", "libc", "memoffset", + "num-bigint", "once_cell", "portable-atomic", "pyo3-build-config", diff --git a/Cargo.toml b/Cargo.toml index 1d96762..fff2875 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,4 +11,4 @@ crate-type = ["cdylib"] [dependencies] num-bigint = "0.4.6" num-traits = "0.2.19" -pyo3 = { version = "0.22.2", features = ["extension-module"] } +pyo3 = { version = "0.22.2", features = ["extension-module", "num-bigint"] } diff --git a/ecc_rs.pyi b/ecc_rs.pyi index 1f0a7cf..9bc63b8 100644 --- a/ecc_rs.pyi +++ b/ecc_rs.pyi @@ -1,4 +1,4 @@ # ecc_rs.pyi -def add(p1: tuple[str, str], p2: tuple[str, str]) -> tuple[str, str]: ... -def multiply(point: tuple[str, str], n: str) -> tuple[str, str]: ... +def add(p1: tuple[int, int], p2: tuple[int, int]) -> tuple[int, int]: ... +def multiply(point: tuple[int, int], n: int) -> tuple[int, int]: ... diff --git a/src/lib.rs b/src/lib.rs index 929bcc7..f751bcf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -194,13 +194,13 @@ fn point_multiplication(p: &Point, n: &BigUint) -> Point { /// SM2 addition #[pyfunction] -fn add(p1: (String, String), p2: (String, String)) -> (String, String) { +fn add(p1: (BigUint, BigUint), p2: (BigUint, BigUint)) -> (BigUint, BigUint) { let curve = sm2p256v1(); - let x1 = BigUint::parse_bytes(p1.0.as_bytes(), 10).unwrap(); - let y1 = BigUint::parse_bytes(p1.1.as_bytes(), 10).unwrap(); - let x2 = BigUint::parse_bytes(p2.0.as_bytes(), 10).unwrap(); - let y2 = BigUint::parse_bytes(p2.1.as_bytes(), 10).unwrap(); + let x1 = p1.0; + let y1 = p1.1; + let x2 = p2.0; + let y2 = p2.1; // 检查 x 和 y 是否小于曲线参数 p if x1 >= curve.p || y1 >= curve.p { @@ -223,30 +223,28 @@ fn add(p1: (String, String), p2: (String, String)) -> (String, String) { }; let result = point_addition(&point1, &point2); - - (result.x.to_str_radix(10), result.y.to_str_radix(10)) + (result.x, result.y) } /// SM2 multiply #[pyfunction] -fn multiply(point: (String, String), n: String) -> (String, String) { +fn multiply(point: (BigUint, BigUint), n: BigUint) -> (BigUint, BigUint) { let curve = sm2p256v1(); + // Construct the point with BigUint values let point = Point { - x: BigUint::parse_bytes(point.0.as_bytes(), 10).unwrap(), - y: BigUint::parse_bytes(point.1.as_bytes(), 10).unwrap(), + x: point.0.clone(), + y: point.1.clone(), curve: curve.clone(), }; - - let scalar_bn = BigUint::parse_bytes(n.as_bytes(), 10).unwrap(); - let result = point_multiplication(&point, &scalar_bn); - - (result.x.to_str_radix(10), result.y.to_str_radix(10)) + // Perform point multiplication + let result = point_multiplication(&point, &n); + (result.x, result.y) } /// A Python module implemented in Rust. #[pymodule] fn ecc_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction_bound!(multiply, m)?)?; - m.add_function(wrap_pyfunction_bound!(add, m)?)?; + m.add_function(wrap_pyfunction!(multiply, m)?)?; + m.add_function(wrap_pyfunction!(add, m)?)?; Ok(()) } diff --git a/test.py b/test.py index 7f26eda..5985b59 100644 --- a/test.py +++ b/test.py @@ -2,23 +2,18 @@ import ecc_rs # Example point coordinates for P1 and P2 as tuples (x1, y1) and (x2, y2) p1 = ( - "1234567890123456789012345678901234567890123456789012345678901234", - "9876543210987654321098765432109876543210987654321098765432109876", + 1234567890123456789012345678901234567890123456789012345678901234, + 9876543210987654321098765432109876543210987654321098765432109876, ) p2 = ( - "2234567890123456789012345678901234567890123456789012345678901234", - "2876543210987654321098765432109876543210987654321098765432109876", + 2234567890123456789012345678901234567890123456789012345678901234, + 2876543210987654321098765432109876543210987654321098765432109876, ) -print(ecc_rs.__all__) # Add the two points result_x, result_y = ecc_rs.add(p1, p2) print(f"Resulting Point: x = {result_x}, y = {result_y}") -# Convert the result to integers if needed -result_x_int = int(result_x) -result_y_int = int(result_y) -print(f"Resulting Point as integers: x = {result_x_int}, y = {result_y_int}") -result = ecc_rs.multiply(p1, "2") +result = ecc_rs.multiply(p1, 2) print(f"Resulting Point: x = {result[0]}, y = {result[1]}")