first commit

This commit is contained in:
2024-09-04 09:58:38 +08:00
commit 3f8e0ff6c9
6 changed files with 580 additions and 0 deletions

254
src/lib.rs Normal file
View File

@@ -0,0 +1,254 @@
#![allow(dead_code)]
use num_bigint::BigUint;
use num_traits::{One, Zero};
use pyo3::prelude::*;
/// Define curve
#[derive(Debug, Clone)]
struct CurveFp {
name: &'static str,
a: BigUint,
b: BigUint,
p: BigUint,
n: BigUint,
gx: BigUint,
gy: BigUint,
}
#[derive(Debug, Clone)]
struct Point {
x: BigUint,
y: BigUint,
curve: CurveFp,
}
// Initialize the SM2 Curve
fn sm2p256v1() -> CurveFp {
CurveFp {
name: "sm2p256v1",
a: BigUint::parse_bytes(
b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC",
16,
)
.unwrap(),
b: BigUint::parse_bytes(
b"28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93",
16,
)
.unwrap(),
p: BigUint::parse_bytes(
b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF",
16,
)
.unwrap(),
n: BigUint::parse_bytes(
b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123",
16,
)
.unwrap(),
gx: BigUint::parse_bytes(
b"32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7",
16,
)
.unwrap(),
gy: BigUint::parse_bytes(
b"BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0",
16,
)
.unwrap(),
}
}
/// Calculates the greatest common divisor (GCD) of two `BigUint` numbers.
///
/// The function uses the Euclidean algorithm to compute the GCD. The algorithm is based on
/// the principle that the greatest common divisor of two numbers does not change if the larger
/// number is replaced by its difference with the smaller number.
///
/// # Arguments
///
/// * `a` - The first `BigUint` number.
/// * `b` - The second `BigUint` number.
///
/// # Returns
///
/// * `BigUint` - The greatest common divisor of `a` and `b`.
///
/// # Example
///
/// ```
/// let a = BigUint::from(60u32);
/// let b = BigUint::from(48u32);
/// let result = gcd(a, b);
/// assert_eq!(result, BigUint::from(12u32));
/// ```
fn gcd(a: &BigUint, b: &BigUint) -> BigUint {
let mut c = a.clone();
let mut d = b.clone();
while !a.is_zero() {
let temp = c.clone();
c = d % c;
d = temp;
}
d
}
/// Computes the modular inverse of `a` under modulo `m`.
///
/// # Arguments
///
/// * `a` - A reference to a BigUint representing the number to find the modular inverse of.
/// * `p` - A reference to a BigUint representing the modulus.
///
/// # Return1
///
/// * A BigUint representing the modular inverse of `a` modulo `p`.
///
/// # Panics
///
/// This function will panic if the modular inverse does not exist (i.e., if `a` and `p` are not coprime).
fn mod_inverse(a: &BigUint, m: &BigUint) -> BigUint {
let (mut t, mut new_t) = (BigUint::zero(), BigUint::one());
let (mut r, mut new_r) = (m.clone(), a.clone());
while !new_r.is_zero() {
let quotient = &r / &new_r;
let temp_t = new_t.clone();
new_t = if t < (quotient.clone() * &temp_t) {
// BigUint can't be negative,
// so we use mod to handle the case where t < quotient * temp_t
(&t + m - &(quotient.clone() * &temp_t) % m) % m
} else {
&t - &(quotient.clone() * &temp_t)
};
let temp_r = new_r.clone();
new_r = &r - &(quotient.clone() * &temp_r);
r = temp_r;
t = temp_t;
}
if r > BigUint::one() {
panic!("Modular inverse does not exist");
}
if t < BigUint::zero() {
t += m;
}
t
}
fn point_addition(p1: &Point, p2: &Point) -> Point {
let curve = &p1.curve;
let p = &curve.p;
if p1.x.is_zero() && p1.y.is_zero() {
return p2.clone();
}
if p2.x.is_zero() && p2.y.is_zero() {
return p1.clone();
}
let lambda = if p1.x == p2.x && p1.y == p2.y {
let num = (BigUint::from(3u32) * &p1.x * &p1.x + &curve.a) % p;
let denom = (BigUint::from(2u32) * &p1.y) % p;
(num * mod_inverse(&denom, p)) % p
} else {
let num = ((&p2.y + p) - &p1.y) % p;
let denom = ((&p2.x + p) - &p1.x) % p;
(num * mod_inverse(&denom, p)) % p
};
println!("{lambda}");
let x3 = (lambda.clone() * &lambda - &p1.x - &p2.x) % p;
let y3 = (lambda * (&p1.x + p - &x3) - &p1.y) % p;
Point {
x: x3,
y: y3,
curve: curve.clone(),
}
}
fn point_multiplication(p: &Point, n: &BigUint) -> Point {
let mut result = Point {
x: BigUint::zero(),
y: BigUint::zero(),
curve: p.curve.clone(),
};
let mut addend = p.clone();
let mut k = n.clone();
while !k.is_zero() {
if &k % 2u32 == BigUint::one() {
result = point_addition(&result, &addend);
}
addend = point_addition(&addend, &addend);
k >>= 1;
}
result
}
/// SM2 addition
#[pyfunction]
fn add(p1: (String, String), p2: (String, String)) -> (String, String) {
let curve = sm2p256v1();
let x1 = BigUint::parse_bytes(p1.0.as_bytes(), 10).unwrap();
let y1 = BigUint::parse_bytes(p1.1.as_bytes(), 10).unwrap();
let x2 = BigUint::parse_bytes(p2.0.as_bytes(), 10).unwrap();
let y2 = BigUint::parse_bytes(p2.1.as_bytes(), 10).unwrap();
// 检查 x 和 y 是否小于曲线参数 p
if x1 >= curve.p || y1 >= curve.p {
panic!("Point p1 coordinates are out of range");
}
if x2 >= curve.p || y2 >= curve.p {
panic!("Point p2 coordinates are out of range");
}
let point1 = Point {
x: x1,
y: y1,
curve: curve.clone(),
};
let point2 = Point {
x: x2,
y: y2,
curve: curve.clone(),
};
let result = point_addition(&point1, &point2);
(result.x.to_str_radix(10), result.y.to_str_radix(10))
}
/// SM2 multiply
#[pyfunction]
fn multiply(point: (String, String), n: String) -> (String, String) {
let curve = sm2p256v1();
let point = Point {
x: BigUint::parse_bytes(point.0.as_bytes(), 10).unwrap(),
y: BigUint::parse_bytes(point.1.as_bytes(), 10).unwrap(),
curve: curve.clone(),
};
let scalar_bn = BigUint::parse_bytes(n.as_bytes(), 10).unwrap();
let result = point_multiplication(&point, &scalar_bn);
(result.x.to_str_radix(10), result.y.to_str_radix(10))
}
/// A Python module implemented in Rust.
#[pymodule]
fn ecc_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction_bound!(multiply, m)?)?;
m.add_function(wrap_pyfunction_bound!(add, m)?)?;
Ok(())
}