first commit
This commit is contained in:
		
							
								
								
									
										254
									
								
								src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										254
									
								
								src/lib.rs
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,254 @@ | ||||
| #![allow(dead_code)] | ||||
| use num_bigint::BigUint; | ||||
| use num_traits::{One, Zero}; | ||||
| use pyo3::prelude::*; | ||||
|  | ||||
| /// Define curve | ||||
| #[derive(Debug, Clone)] | ||||
| struct CurveFp { | ||||
|     name: &'static str, | ||||
|     a: BigUint, | ||||
|     b: BigUint, | ||||
|     p: BigUint, | ||||
|     n: BigUint, | ||||
|     gx: BigUint, | ||||
|     gy: BigUint, | ||||
| } | ||||
|  | ||||
| #[derive(Debug, Clone)] | ||||
| struct Point { | ||||
|     x: BigUint, | ||||
|     y: BigUint, | ||||
|     curve: CurveFp, | ||||
| } | ||||
|  | ||||
| // Initialize the SM2 Curve | ||||
| fn sm2p256v1() -> CurveFp { | ||||
|     CurveFp { | ||||
|         name: "sm2p256v1", | ||||
|         a: BigUint::parse_bytes( | ||||
|             b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC", | ||||
|             16, | ||||
|         ) | ||||
|         .unwrap(), | ||||
|         b: BigUint::parse_bytes( | ||||
|             b"28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", | ||||
|             16, | ||||
|         ) | ||||
|         .unwrap(), | ||||
|         p: BigUint::parse_bytes( | ||||
|             b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", | ||||
|             16, | ||||
|         ) | ||||
|         .unwrap(), | ||||
|         n: BigUint::parse_bytes( | ||||
|             b"FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", | ||||
|             16, | ||||
|         ) | ||||
|         .unwrap(), | ||||
|         gx: BigUint::parse_bytes( | ||||
|             b"32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", | ||||
|             16, | ||||
|         ) | ||||
|         .unwrap(), | ||||
|         gy: BigUint::parse_bytes( | ||||
|             b"BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", | ||||
|             16, | ||||
|         ) | ||||
|         .unwrap(), | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Calculates the greatest common divisor (GCD) of two `BigUint` numbers. | ||||
| /// | ||||
| /// The function uses the Euclidean algorithm to compute the GCD. The algorithm is based on | ||||
| /// the principle that the greatest common divisor of two numbers does not change if the larger | ||||
| /// number is replaced by its difference with the smaller number. | ||||
| /// | ||||
| /// # Arguments | ||||
| /// | ||||
| /// * `a` - The first `BigUint` number. | ||||
| /// * `b` - The second `BigUint` number. | ||||
| /// | ||||
| /// # Returns | ||||
| /// | ||||
| /// * `BigUint` - The greatest common divisor of `a` and `b`. | ||||
| /// | ||||
| /// # Example | ||||
| /// | ||||
| /// ``` | ||||
| /// let a = BigUint::from(60u32); | ||||
| /// let b = BigUint::from(48u32); | ||||
| /// let result = gcd(a, b); | ||||
| /// assert_eq!(result, BigUint::from(12u32)); | ||||
| /// ``` | ||||
| fn gcd(a: &BigUint, b: &BigUint) -> BigUint { | ||||
|     let mut c = a.clone(); | ||||
|     let mut d = b.clone(); | ||||
|     while !a.is_zero() { | ||||
|         let temp = c.clone(); | ||||
|         c = d % c; | ||||
|         d = temp; | ||||
|     } | ||||
|     d | ||||
| } | ||||
|  | ||||
| /// Computes the modular inverse of `a` under modulo `m`. | ||||
| /// | ||||
| /// # Arguments | ||||
| /// | ||||
| /// * `a` - A reference to a BigUint representing the number to find the modular inverse of. | ||||
| /// * `p` - A reference to a BigUint representing the modulus. | ||||
| /// | ||||
| /// # Return1 | ||||
| /// | ||||
| /// * A BigUint representing the modular inverse of `a` modulo `p`. | ||||
| /// | ||||
| /// # Panics | ||||
| /// | ||||
| /// This function will panic if the modular inverse does not exist (i.e., if `a` and `p` are not coprime). | ||||
| fn mod_inverse(a: &BigUint, m: &BigUint) -> BigUint { | ||||
|     let (mut t, mut new_t) = (BigUint::zero(), BigUint::one()); | ||||
|     let (mut r, mut new_r) = (m.clone(), a.clone()); | ||||
|  | ||||
|     while !new_r.is_zero() { | ||||
|         let quotient = &r / &new_r; | ||||
|  | ||||
|         let temp_t = new_t.clone(); | ||||
|         new_t = if t < (quotient.clone() * &temp_t) { | ||||
|             // BigUint can't be negative, | ||||
|             // so we use mod to handle the case where t < quotient * temp_t | ||||
|             (&t + m - &(quotient.clone() * &temp_t) % m) % m | ||||
|         } else { | ||||
|             &t - &(quotient.clone() * &temp_t) | ||||
|         }; | ||||
|  | ||||
|         let temp_r = new_r.clone(); | ||||
|         new_r = &r - &(quotient.clone() * &temp_r); | ||||
|         r = temp_r; | ||||
|         t = temp_t; | ||||
|     } | ||||
|  | ||||
|     if r > BigUint::one() { | ||||
|         panic!("Modular inverse does not exist"); | ||||
|     } | ||||
|  | ||||
|     if t < BigUint::zero() { | ||||
|         t += m; | ||||
|     } | ||||
|  | ||||
|     t | ||||
| } | ||||
| fn point_addition(p1: &Point, p2: &Point) -> Point { | ||||
|     let curve = &p1.curve; | ||||
|     let p = &curve.p; | ||||
|  | ||||
|     if p1.x.is_zero() && p1.y.is_zero() { | ||||
|         return p2.clone(); | ||||
|     } | ||||
|     if p2.x.is_zero() && p2.y.is_zero() { | ||||
|         return p1.clone(); | ||||
|     } | ||||
|  | ||||
|     let lambda = if p1.x == p2.x && p1.y == p2.y { | ||||
|         let num = (BigUint::from(3u32) * &p1.x * &p1.x + &curve.a) % p; | ||||
|         let denom = (BigUint::from(2u32) * &p1.y) % p; | ||||
|         (num * mod_inverse(&denom, p)) % p | ||||
|     } else { | ||||
|         let num = ((&p2.y + p) - &p1.y) % p; | ||||
|         let denom = ((&p2.x + p) - &p1.x) % p; | ||||
|  | ||||
|         (num * mod_inverse(&denom, p)) % p | ||||
|     }; | ||||
|  | ||||
|     println!("{lambda}"); | ||||
|  | ||||
|     let x3 = (lambda.clone() * &lambda - &p1.x - &p2.x) % p; | ||||
|     let y3 = (lambda * (&p1.x + p - &x3) - &p1.y) % p; | ||||
|  | ||||
|     Point { | ||||
|         x: x3, | ||||
|         y: y3, | ||||
|         curve: curve.clone(), | ||||
|     } | ||||
| } | ||||
|  | ||||
| fn point_multiplication(p: &Point, n: &BigUint) -> Point { | ||||
|     let mut result = Point { | ||||
|         x: BigUint::zero(), | ||||
|         y: BigUint::zero(), | ||||
|         curve: p.curve.clone(), | ||||
|     }; | ||||
|  | ||||
|     let mut addend = p.clone(); | ||||
|     let mut k = n.clone(); | ||||
|  | ||||
|     while !k.is_zero() { | ||||
|         if &k % 2u32 == BigUint::one() { | ||||
|             result = point_addition(&result, &addend); | ||||
|         } | ||||
|         addend = point_addition(&addend, &addend); | ||||
|         k >>= 1; | ||||
|     } | ||||
|  | ||||
|     result | ||||
| } | ||||
|  | ||||
| /// SM2 addition | ||||
| #[pyfunction] | ||||
| fn add(p1: (String, String), p2: (String, String)) -> (String, String) { | ||||
|     let curve = sm2p256v1(); | ||||
|  | ||||
|     let x1 = BigUint::parse_bytes(p1.0.as_bytes(), 10).unwrap(); | ||||
|     let y1 = BigUint::parse_bytes(p1.1.as_bytes(), 10).unwrap(); | ||||
|     let x2 = BigUint::parse_bytes(p2.0.as_bytes(), 10).unwrap(); | ||||
|     let y2 = BigUint::parse_bytes(p2.1.as_bytes(), 10).unwrap(); | ||||
|  | ||||
|     // 检查 x 和 y 是否小于曲线参数 p | ||||
|     if x1 >= curve.p || y1 >= curve.p { | ||||
|         panic!("Point p1 coordinates are out of range"); | ||||
|     } | ||||
|     if x2 >= curve.p || y2 >= curve.p { | ||||
|         panic!("Point p2 coordinates are out of range"); | ||||
|     } | ||||
|  | ||||
|     let point1 = Point { | ||||
|         x: x1, | ||||
|         y: y1, | ||||
|         curve: curve.clone(), | ||||
|     }; | ||||
|  | ||||
|     let point2 = Point { | ||||
|         x: x2, | ||||
|         y: y2, | ||||
|         curve: curve.clone(), | ||||
|     }; | ||||
|  | ||||
|     let result = point_addition(&point1, &point2); | ||||
|  | ||||
|     (result.x.to_str_radix(10), result.y.to_str_radix(10)) | ||||
| } | ||||
|  | ||||
| /// SM2 multiply | ||||
| #[pyfunction] | ||||
| fn multiply(point: (String, String), n: String) -> (String, String) { | ||||
|     let curve = sm2p256v1(); | ||||
|     let point = Point { | ||||
|         x: BigUint::parse_bytes(point.0.as_bytes(), 10).unwrap(), | ||||
|         y: BigUint::parse_bytes(point.1.as_bytes(), 10).unwrap(), | ||||
|         curve: curve.clone(), | ||||
|     }; | ||||
|  | ||||
|     let scalar_bn = BigUint::parse_bytes(n.as_bytes(), 10).unwrap(); | ||||
|     let result = point_multiplication(&point, &scalar_bn); | ||||
|  | ||||
|     (result.x.to_str_radix(10), result.y.to_str_radix(10)) | ||||
| } | ||||
|  | ||||
| /// A Python module implemented in Rust. | ||||
| #[pymodule] | ||||
| fn ecc_rs(m: &Bound<'_, PyModule>) -> PyResult<()> { | ||||
|     m.add_function(wrap_pyfunction_bound!(multiply, m)?)?; | ||||
|     m.add_function(wrap_pyfunction_bound!(add, m)?)?; | ||||
|     Ok(()) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user