Compare commits

..

3 Commits

Author SHA1 Message Date
31d5d63897 accept int as input, output 2024-09-04 14:54:42 +08:00
0f289a728b add pyi 2024-09-04 11:32:54 +08:00
c629cd00c3 remove print 2024-09-04 10:22:10 +08:00
6 changed files with 29 additions and 33 deletions

1
Cargo.lock generated
View File

@@ -109,6 +109,7 @@ dependencies = [
"indoc",
"libc",
"memoffset",
"num-bigint",
"once_cell",
"portable-atomic",
"pyo3-build-config",

View File

@@ -11,4 +11,4 @@ crate-type = ["cdylib"]
[dependencies]
num-bigint = "0.4.6"
num-traits = "0.2.19"
pyo3 = "0.22.0"
pyo3 = { version = "0.22.2", features = ["extension-module", "num-bigint"] }

4
ecc_rs.pyi Normal file
View File

@@ -0,0 +1,4 @@
# ecc_rs.pyi
def add(p1: tuple[int, int], p2: tuple[int, int]) -> tuple[int, int]: ...
def multiply(point: tuple[int, int], n: int) -> tuple[int, int]: ...

View File

@@ -6,10 +6,10 @@ build-backend = "maturin"
name = "ecc_rs"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
]
dynamic = ["version"]
[tool.maturin]
features = ["pyo3/extension-module"]
bindings = "pyo3"

View File

@@ -161,8 +161,6 @@ fn point_addition(p1: &Point, p2: &Point) -> Point {
(num * mod_inverse(&denom, p)) % p
};
println!("{lambda}");
let x3 = (lambda.clone() * &lambda - &p1.x - &p2.x) % p;
let y3 = (lambda * (&p1.x + p - &x3) - &p1.y) % p;
@@ -196,13 +194,13 @@ fn point_multiplication(p: &Point, n: &BigUint) -> Point {
/// SM2 addition
#[pyfunction]
fn add(p1: (String, String), p2: (String, String)) -> (String, String) {
fn add(p1: (BigUint, BigUint), p2: (BigUint, BigUint)) -> (BigUint, BigUint) {
let curve = sm2p256v1();
let x1 = BigUint::parse_bytes(p1.0.as_bytes(), 10).unwrap();
let y1 = BigUint::parse_bytes(p1.1.as_bytes(), 10).unwrap();
let x2 = BigUint::parse_bytes(p2.0.as_bytes(), 10).unwrap();
let y2 = BigUint::parse_bytes(p2.1.as_bytes(), 10).unwrap();
let x1 = p1.0;
let y1 = p1.1;
let x2 = p2.0;
let y2 = p2.1;
// 检查 x 和 y 是否小于曲线参数 p
if x1 >= curve.p || y1 >= curve.p {
@@ -225,30 +223,28 @@ fn add(p1: (String, String), p2: (String, String)) -> (String, String) {
};
let result = point_addition(&point1, &point2);
(result.x.to_str_radix(10), result.y.to_str_radix(10))
(result.x, result.y)
}
/// SM2 multiply
#[pyfunction]
fn multiply(point: (String, String), n: String) -> (String, String) {
fn multiply(point: (BigUint, BigUint), n: BigUint) -> (BigUint, BigUint) {
let curve = sm2p256v1();
// Construct the point with BigUint values
let point = Point {
x: BigUint::parse_bytes(point.0.as_bytes(), 10).unwrap(),
y: BigUint::parse_bytes(point.1.as_bytes(), 10).unwrap(),
x: point.0.clone(),
y: point.1.clone(),
curve: curve.clone(),
};
let scalar_bn = BigUint::parse_bytes(n.as_bytes(), 10).unwrap();
let result = point_multiplication(&point, &scalar_bn);
(result.x.to_str_radix(10), result.y.to_str_radix(10))
// Perform point multiplication
let result = point_multiplication(&point, &n);
(result.x, result.y)
}
/// A Python module implemented in Rust.
#[pymodule]
fn ecc_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction_bound!(multiply, m)?)?;
m.add_function(wrap_pyfunction_bound!(add, m)?)?;
m.add_function(wrap_pyfunction!(multiply, m)?)?;
m.add_function(wrap_pyfunction!(add, m)?)?;
Ok(())
}

15
test.py
View File

@@ -2,23 +2,18 @@ import ecc_rs
# Example point coordinates for P1 and P2 as tuples (x1, y1) and (x2, y2)
p1 = (
"1234567890123456789012345678901234567890123456789012345678901234",
"9876543210987654321098765432109876543210987654321098765432109876",
1234567890123456789012345678901234567890123456789012345678901234,
9876543210987654321098765432109876543210987654321098765432109876,
)
p2 = (
"2234567890123456789012345678901234567890123456789012345678901234",
"2876543210987654321098765432109876543210987654321098765432109876",
2234567890123456789012345678901234567890123456789012345678901234,
2876543210987654321098765432109876543210987654321098765432109876,
)
print(ecc_rs.__all__)
# Add the two points
result_x, result_y = ecc_rs.add(p1, p2)
print(f"Resulting Point: x = {result_x}, y = {result_y}")
# Convert the result to integers if needed
result_x_int = int(result_x)
result_y_int = int(result_y)
print(f"Resulting Point as integers: x = {result_x_int}, y = {result_y_int}")
result = ecc_rs.multiply(p1, "2")
result = ecc_rs.multiply(p1, 2)
print(f"Resulting Point: x = {result[0]}, y = {result[1]}")