first commit
This commit is contained in:
254
src/lib.rs
Normal file
254
src/lib.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
#![allow(dead_code)]
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::{One, Zero};
|
||||
use pyo3::prelude::*;
|
||||
|
||||
/// Define curve
|
||||
#[derive(Debug, Clone)]
|
||||
struct CurveFp {
|
||||
name: &'static str,
|
||||
a: BigUint,
|
||||
b: BigUint,
|
||||
p: BigUint,
|
||||
n: BigUint,
|
||||
gx: BigUint,
|
||||
gy: BigUint,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Point {
|
||||
x: BigUint,
|
||||
y: BigUint,
|
||||
curve: CurveFp,
|
||||
}
|
||||
|
||||
// Initialize the SM2 Curve
|
||||
fn sm2p256v1() -> CurveFp {
|
||||
CurveFp {
|
||||
name: "sm2p256v1",
|
||||
a: BigUint::parse_bytes(
|
||||
b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC",
|
||||
16,
|
||||
)
|
||||
.unwrap(),
|
||||
b: BigUint::parse_bytes(
|
||||
b"28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93",
|
||||
16,
|
||||
)
|
||||
.unwrap(),
|
||||
p: BigUint::parse_bytes(
|
||||
b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF",
|
||||
16,
|
||||
)
|
||||
.unwrap(),
|
||||
n: BigUint::parse_bytes(
|
||||
b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123",
|
||||
16,
|
||||
)
|
||||
.unwrap(),
|
||||
gx: BigUint::parse_bytes(
|
||||
b"32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7",
|
||||
16,
|
||||
)
|
||||
.unwrap(),
|
||||
gy: BigUint::parse_bytes(
|
||||
b"BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0",
|
||||
16,
|
||||
)
|
||||
.unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculates the greatest common divisor (GCD) of two `BigUint` numbers.
|
||||
///
|
||||
/// The function uses the Euclidean algorithm to compute the GCD. The algorithm is based on
|
||||
/// the principle that the greatest common divisor of two numbers does not change if the larger
|
||||
/// number is replaced by its difference with the smaller number.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - The first `BigUint` number.
|
||||
/// * `b` - The second `BigUint` number.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// * `BigUint` - The greatest common divisor of `a` and `b`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// let a = BigUint::from(60u32);
|
||||
/// let b = BigUint::from(48u32);
|
||||
/// let result = gcd(a, b);
|
||||
/// assert_eq!(result, BigUint::from(12u32));
|
||||
/// ```
|
||||
fn gcd(a: &BigUint, b: &BigUint) -> BigUint {
|
||||
let mut c = a.clone();
|
||||
let mut d = b.clone();
|
||||
while !a.is_zero() {
|
||||
let temp = c.clone();
|
||||
c = d % c;
|
||||
d = temp;
|
||||
}
|
||||
d
|
||||
}
|
||||
|
||||
/// Computes the modular inverse of `a` under modulo `m`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - A reference to a BigUint representing the number to find the modular inverse of.
|
||||
/// * `p` - A reference to a BigUint representing the modulus.
|
||||
///
|
||||
/// # Return1
|
||||
///
|
||||
/// * A BigUint representing the modular inverse of `a` modulo `p`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if the modular inverse does not exist (i.e., if `a` and `p` are not coprime).
|
||||
fn mod_inverse(a: &BigUint, m: &BigUint) -> BigUint {
|
||||
let (mut t, mut new_t) = (BigUint::zero(), BigUint::one());
|
||||
let (mut r, mut new_r) = (m.clone(), a.clone());
|
||||
|
||||
while !new_r.is_zero() {
|
||||
let quotient = &r / &new_r;
|
||||
|
||||
let temp_t = new_t.clone();
|
||||
new_t = if t < (quotient.clone() * &temp_t) {
|
||||
// BigUint can't be negative,
|
||||
// so we use mod to handle the case where t < quotient * temp_t
|
||||
(&t + m - &(quotient.clone() * &temp_t) % m) % m
|
||||
} else {
|
||||
&t - &(quotient.clone() * &temp_t)
|
||||
};
|
||||
|
||||
let temp_r = new_r.clone();
|
||||
new_r = &r - &(quotient.clone() * &temp_r);
|
||||
r = temp_r;
|
||||
t = temp_t;
|
||||
}
|
||||
|
||||
if r > BigUint::one() {
|
||||
panic!("Modular inverse does not exist");
|
||||
}
|
||||
|
||||
if t < BigUint::zero() {
|
||||
t += m;
|
||||
}
|
||||
|
||||
t
|
||||
}
|
||||
fn point_addition(p1: &Point, p2: &Point) -> Point {
|
||||
let curve = &p1.curve;
|
||||
let p = &curve.p;
|
||||
|
||||
if p1.x.is_zero() && p1.y.is_zero() {
|
||||
return p2.clone();
|
||||
}
|
||||
if p2.x.is_zero() && p2.y.is_zero() {
|
||||
return p1.clone();
|
||||
}
|
||||
|
||||
let lambda = if p1.x == p2.x && p1.y == p2.y {
|
||||
let num = (BigUint::from(3u32) * &p1.x * &p1.x + &curve.a) % p;
|
||||
let denom = (BigUint::from(2u32) * &p1.y) % p;
|
||||
(num * mod_inverse(&denom, p)) % p
|
||||
} else {
|
||||
let num = ((&p2.y + p) - &p1.y) % p;
|
||||
let denom = ((&p2.x + p) - &p1.x) % p;
|
||||
|
||||
(num * mod_inverse(&denom, p)) % p
|
||||
};
|
||||
|
||||
println!("{lambda}");
|
||||
|
||||
let x3 = (lambda.clone() * &lambda - &p1.x - &p2.x) % p;
|
||||
let y3 = (lambda * (&p1.x + p - &x3) - &p1.y) % p;
|
||||
|
||||
Point {
|
||||
x: x3,
|
||||
y: y3,
|
||||
curve: curve.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn point_multiplication(p: &Point, n: &BigUint) -> Point {
|
||||
let mut result = Point {
|
||||
x: BigUint::zero(),
|
||||
y: BigUint::zero(),
|
||||
curve: p.curve.clone(),
|
||||
};
|
||||
|
||||
let mut addend = p.clone();
|
||||
let mut k = n.clone();
|
||||
|
||||
while !k.is_zero() {
|
||||
if &k % 2u32 == BigUint::one() {
|
||||
result = point_addition(&result, &addend);
|
||||
}
|
||||
addend = point_addition(&addend, &addend);
|
||||
k >>= 1;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// SM2 addition
|
||||
#[pyfunction]
|
||||
fn add(p1: (String, String), p2: (String, String)) -> (String, String) {
|
||||
let curve = sm2p256v1();
|
||||
|
||||
let x1 = BigUint::parse_bytes(p1.0.as_bytes(), 10).unwrap();
|
||||
let y1 = BigUint::parse_bytes(p1.1.as_bytes(), 10).unwrap();
|
||||
let x2 = BigUint::parse_bytes(p2.0.as_bytes(), 10).unwrap();
|
||||
let y2 = BigUint::parse_bytes(p2.1.as_bytes(), 10).unwrap();
|
||||
|
||||
// 检查 x 和 y 是否小于曲线参数 p
|
||||
if x1 >= curve.p || y1 >= curve.p {
|
||||
panic!("Point p1 coordinates are out of range");
|
||||
}
|
||||
if x2 >= curve.p || y2 >= curve.p {
|
||||
panic!("Point p2 coordinates are out of range");
|
||||
}
|
||||
|
||||
let point1 = Point {
|
||||
x: x1,
|
||||
y: y1,
|
||||
curve: curve.clone(),
|
||||
};
|
||||
|
||||
let point2 = Point {
|
||||
x: x2,
|
||||
y: y2,
|
||||
curve: curve.clone(),
|
||||
};
|
||||
|
||||
let result = point_addition(&point1, &point2);
|
||||
|
||||
(result.x.to_str_radix(10), result.y.to_str_radix(10))
|
||||
}
|
||||
|
||||
/// SM2 multiply
|
||||
#[pyfunction]
|
||||
fn multiply(point: (String, String), n: String) -> (String, String) {
|
||||
let curve = sm2p256v1();
|
||||
let point = Point {
|
||||
x: BigUint::parse_bytes(point.0.as_bytes(), 10).unwrap(),
|
||||
y: BigUint::parse_bytes(point.1.as_bytes(), 10).unwrap(),
|
||||
curve: curve.clone(),
|
||||
};
|
||||
|
||||
let scalar_bn = BigUint::parse_bytes(n.as_bytes(), 10).unwrap();
|
||||
let result = point_multiplication(&point, &scalar_bn);
|
||||
|
||||
(result.x.to_str_radix(10), result.y.to_str_radix(10))
|
||||
}
|
||||
|
||||
/// A Python module implemented in Rust.
|
||||
#[pymodule]
|
||||
fn ecc_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction_bound!(multiply, m)?)?;
|
||||
m.add_function(wrap_pyfunction_bound!(add, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user