diff --git a/problems/p7/Cargo.toml b/problems/p7/Cargo.toml index 36cb3f3..62de4f3 100644 --- a/problems/p7/Cargo.toml +++ b/problems/p7/Cargo.toml @@ -5,3 +5,5 @@ edition = "2024" [dependencies] anyhow = "1.0.98" +hex = "0.4.3" +base64 = "0.22.1" diff --git a/problems/p7/src/main.rs b/problems/p7/src/main.rs index cc18f9a..ce7cd24 100644 --- a/problems/p7/src/main.rs +++ b/problems/p7/src/main.rs @@ -1,4 +1,7 @@ +#![allow(dead_code)] use anyhow::{Result, anyhow}; +use base64::{Engine, engine::general_purpose::STANDARD}; +use std::fs::read_to_string; const SBOX: [u8; 256] = [ 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, @@ -10,12 +13,12 @@ const SBOX: [u8; 256] = [ 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, - 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x35, 0x5c, 0x87, 0x5f, 0x5b, 0x62, 0x99, 0xaa, 0xa1, - 0x08, 0xba, 0x7f, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, - 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, - 0x9e, 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, - 0xdf, 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, - 0x16, 0x63, 0x0c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, + 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, + 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, + 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, + 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, + 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, + 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16, ]; const INV_SBOX: [u8; 256] = [ @@ -37,6 +40,19 @@ const INV_SBOX: [u8; 256] = [ 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d, ]; +const RCON: [[u8; 4]; 10] = [ + [0x01, 0x00, 0x00, 0x00], + [0x02, 0x00, 0x00, 0x00], + [0x04, 0x00, 0x00, 0x00], + [0x08, 0x00, 0x00, 0x00], + [0x10, 0x00, 0x00, 0x00], + [0x20, 0x00, 0x00, 0x00], + [0x40, 0x00, 0x00, 0x00], + [0x80, 0x00, 0x00, 0x00], + [0x1b, 0x00, 0x00, 0x00], + [0x36, 0x00, 0x00, 0x00], +]; + fn add_round_key(block: &[u8; 16], round_key: &[u8; 16]) -> [u8; 16] { std::array::from_fn(|i| block[i] ^ round_key[i]) } @@ -45,6 +61,10 @@ fn sub_bytes(block: &[u8; 16]) -> [u8; 16] { std::array::from_fn(|i| SBOX[block[i] as usize]) } +fn inv_sub_bytes(block: &[u8; 16]) -> [u8; 16] { + std::array::from_fn(|i| INV_SBOX[block[i] as usize]) +} + fn shift_rows(block: &[u8; 16]) -> [u8; 16] { let mut result = [0u8; 16]; // Row 0 (no shift) @@ -95,48 +115,166 @@ fn inv_shift_rows(block: &[u8; 16]) -> [u8; 16] { result } -fn inv_sub_bytes(block: &[u8; 16]) -> [u8; 16] { - std::array::from_fn(|i| INV_SBOX[block[i] as usize]) +fn gf_mul_01(a: u8) -> u8 { + a +} + +fn gf_mul_02(a: u8) -> u8 { + if a & 0x80 != 0 { + (a << 1) ^ 0x1b + } else { + a << 1 + } +} + +fn gf_mul_03(a: u8) -> u8 { + gf_mul_02(a) ^ a +} + +fn gf_mul_0e(a: u8) -> u8 { + gf_mul_02(gf_mul_02(gf_mul_02(a))) ^ gf_mul_04(a) ^ gf_mul_02(a) +} + +fn gf_mul_0b(a: u8) -> u8 { + gf_mul_08(a) ^ gf_mul_02(a) ^ a +} + +fn gf_mul_0d(a: u8) -> u8 { + gf_mul_08(a) ^ gf_mul_04(a) ^ a +} + +fn gf_mul_09(a: u8) -> u8 { + gf_mul_08(a) ^ a +} + +fn gf_mul_04(a: u8) -> u8 { + gf_mul_02(gf_mul_02(a)) +} + +fn gf_mul_08(a: u8) -> u8 { + gf_mul_02(gf_mul_04(a)) } fn mix_columns(block: &[u8; 16]) -> [u8; 16] { - todo!() + // [02 03 01 01] + // [01 02 03 01] + // [01 01 02 03] + // [03 01 01 02] + let mut result = [0u8; 16]; + for i in 0..4 { + let col = [ + block[i * 4], + block[i * 4 + 1], + block[i * 4 + 2], + block[i * 4 + 3], + ]; + result[i * 4] = + gf_mul_02(col[0]) ^ gf_mul_03(col[1]) ^ gf_mul_01(col[2]) ^ gf_mul_01(col[3]); + result[i * 4 + 1] = + gf_mul_01(col[0]) ^ gf_mul_02(col[1]) ^ gf_mul_03(col[2]) ^ gf_mul_01(col[3]); + result[i * 4 + 2] = + gf_mul_01(col[0]) ^ gf_mul_01(col[1]) ^ gf_mul_02(col[2]) ^ gf_mul_03(col[3]); + result[i * 4 + 3] = + gf_mul_03(col[0]) ^ gf_mul_01(col[1]) ^ gf_mul_01(col[2]) ^ gf_mul_02(col[3]); + } + result } fn inv_mix_columns(block: &[u8; 16]) -> [u8; 16] { - todo!() + // [0E 0B 0D 09] + // [09 0E 0B 0D] + // [0D 09 0E 0B] + // [0B 0D 09 0E] + let mut result = [0u8; 16]; + for i in 0..4 { + let col = [ + block[i * 4], + block[i * 4 + 1], + block[i * 4 + 2], + block[i * 4 + 3], + ]; + result[i * 4] = + gf_mul_0e(col[0]) ^ gf_mul_0b(col[1]) ^ gf_mul_0d(col[2]) ^ gf_mul_09(col[3]); + result[i * 4 + 1] = + gf_mul_09(col[0]) ^ gf_mul_0e(col[1]) ^ gf_mul_0b(col[2]) ^ gf_mul_0d(col[3]); + result[i * 4 + 2] = + gf_mul_0d(col[0]) ^ gf_mul_09(col[1]) ^ gf_mul_0e(col[2]) ^ gf_mul_0b(col[3]); + result[i * 4 + 3] = + gf_mul_0b(col[0]) ^ gf_mul_0d(col[1]) ^ gf_mul_09(col[2]) ^ gf_mul_0e(col[3]); + } + result } fn expand_key(key: &[u8; 16]) -> [[u8; 16]; 11] { let mut round_key: [[u8; 16]; 11] = [[0u8; 16]; 11]; round_key[0] = *key; - for round in 1..round_key.len() {} + + for round in 1..round_key.len() { + let prev_key = round_key[round - 1]; + let mut new_key = [0u8; 16]; + + // 对前一轮密钥的最后4字节进行g变换 + let g_result = g_func(prev_key[12..16].try_into().unwrap(), round - 1); + + // 新密钥的每4字节都要与前面的4字节异或 + for i in 0..4 { + new_key[i] = prev_key[i] ^ g_result[i]; + } + + for i in 4..8 { + new_key[i] = prev_key[i] ^ new_key[i - 4]; + } + + for i in 8..12 { + new_key[i] = prev_key[i] ^ new_key[i - 4]; + } + + for i in 12..16 { + new_key[i] = prev_key[i] ^ new_key[i - 4]; + } + + round_key[round] = new_key; + } round_key } fn rot_word(word: [u8; 4]) -> [u8; 4] { - todo!() + [word[1], word[2], word[3], word[0]] +} + +fn xor_rcon(word: [u8; 4], round: usize) -> [u8; 4] { + let mut result = word; + result[0] ^= RCON[round][0]; // 只对第一个字节进行Rcon异或 + result } fn sub_word(word: [u8; 4]) -> [u8; 4] { - todo!() + [ + SBOX[word[0] as usize], + SBOX[word[1] as usize], + SBOX[word[2] as usize], + SBOX[word[3] as usize], + ] } -fn get_rcon(word: [u8; 4]) -> [u8; 4] { - todo!() +fn g_func(word: [u8; 4], round: usize) -> [u8; 4] { + let mut result = rot_word(word); + result = sub_word(result); + result = xor_rcon(result, round); + result } fn aes_ecb_enc(input: &[u8], key: &[u8; 16]) -> Result> { if input.len() % 16 != 0 { return Err(anyhow!("Invalid input length")); } - let cipher: Vec = Vec::new(); - let round_key = expand_key(key); + let mut cipher: Vec = Vec::new(); + let round_keys = expand_key(key); for i in 0..(input.len() / 16) { let mut block: [u8; 16] = input[(i * 16)..(i * 16 + 16)].try_into()?; - block = add_round_key(&block, key); - for _ in 0..9 { + block = add_round_key(&block, &round_keys[0]); + for round_key in round_keys.iter().take(10).skip(1) { block = sub_bytes(&block); block = shift_rows(&block); block = mix_columns(&block); @@ -144,7 +282,8 @@ fn aes_ecb_enc(input: &[u8], key: &[u8; 16]) -> Result> { } block = sub_bytes(&block); block = shift_rows(&block); - block = add_round_key(&block, round_key); + block = add_round_key(&block, &round_keys[10]); + cipher.extend(block); } Ok(cipher) @@ -155,12 +294,74 @@ fn aes_ecb_dec(input: &[u8], key: &[u8; 16]) -> Result> { return Err(anyhow!("Invalid input length")); } - let plaintext: Vec = Vec::new(); - let round_key = expand_key(key); + let mut plaintext: Vec = Vec::new(); + let round_keys = expand_key(key); + + for i in 0..(input.len() / 16) { + let mut block: [u8; 16] = input[(i * 16)..(i * 16 + 16)].try_into()?; + block = add_round_key(&block, &round_keys[10]); + block = inv_shift_rows(&block); + block = inv_sub_bytes(&block); + for j in 0..9 { + block = add_round_key(&block, &round_keys[9 - j]); + block = inv_mix_columns(&block); + block = inv_shift_rows(&block); + block = inv_sub_bytes(&block); + } + block = add_round_key(&block, &round_keys[0]); + plaintext.extend(block); + } Ok(plaintext) } fn main() { - println!("Hello, world!"); + let key = "YELLOW SUBMARINE".as_bytes(); + let file_path = "./problems/p7/7.txt"; + let b64_cipher = read_to_string(file_path).expect("Failed to read file"); + let cipher = STANDARD + .decode(b64_cipher.trim().replace('\n', "")) + .expect("Failed to decode base64"); + let key_array: [u8; 16] = key.try_into().expect("Key must be 16 bytes long"); + let decrypted = aes_ecb_dec(&cipher, &key_array).expect("Decryption failed"); + let decrypted_str = String::from_utf8(decrypted).expect("Failed to convert to string"); + println!("Decrypted text: {decrypted_str}"); +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_aes_ecb_enc_dec() { + let key: [u8; 16] = [0x00; 16]; + let plaintext: Vec = vec![0x01; 32]; // 2 blocks of 16 bytes + let ciphertext = aes_ecb_enc(&plaintext, &key).unwrap(); + dbg!(&ciphertext); + let decrypted = aes_ecb_dec(&ciphertext, &key).unwrap(); + assert_eq!(decrypted, plaintext); + } + #[test] + fn test_aes_ecb_enc_with_openssl_cmd() { + use std::io::Write; + use std::process::{Command, Stdio}; + let key: [u8; 16] = [0x00; 16]; + let plaintext: Vec = vec![0x01; 32]; // 2 blocks of 16 bytes + // Encrypt using our implementation + let ciphertext = aes_ecb_enc(&plaintext, &key).unwrap(); + + // Call openssl with stdin + let mut child = Command::new("openssl") + .args(["enc", "-aes-128-ecb", "-K", &hex::encode(key), "-nopad"]) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .expect("failed to execute openssl"); + + let mut stdin = child.stdin.take().expect("failed to get stdin"); + stdin.write_all(&plaintext).unwrap(); + drop(stdin); + + let output = child.wait_with_output().unwrap(); + assert_eq!(ciphertext, output.stdout); + } }