From e42d659c1ff769de0f06b32e87ffaf11382e4138 Mon Sep 17 00:00:00 2001 From: Jack Lloyd Date: Sun, 4 Aug 2024 19:03:01 -0400 Subject: [PATCH] Add support for AVX2-VAES On an AMD Zen3 system, results in 50% performance improvement for bulk AES. --- src/lib/block/aes/aes.cpp | 66 +++ src/lib/block/aes/aes.h | 15 + src/lib/block/aes/aes_vaes/aes_vaes.cpp | 629 +++++++++++++++++++++++ src/lib/block/aes/aes_vaes/info.txt | 16 + src/lib/utils/simd/simd_avx2/simd_avx2.h | 25 +- src/tests/data/block/aes.vec | 2 +- 6 files changed, 747 insertions(+), 6 deletions(-) create mode 100644 src/lib/block/aes/aes_vaes/aes_vaes.cpp create mode 100644 src/lib/block/aes/aes_vaes/info.txt diff --git a/src/lib/block/aes/aes.cpp b/src/lib/block/aes/aes.cpp index 3c09d2a8829..1f81a667ef3 100644 --- a/src/lib/block/aes/aes.cpp +++ b/src/lib/block/aes/aes.cpp @@ -740,6 +740,12 @@ void aes_key_schedule(const uint8_t key[], } size_t aes_parallelism() { +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return 8; // pipelined + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return 4; // pipelined @@ -757,6 +763,12 @@ size_t aes_parallelism() { } const char* aes_provider() { +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return "vaes"; + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return "cpu"; @@ -813,6 +825,12 @@ bool AES_256::has_keying_material() const { void AES_128::encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { assert_key_material_set(); +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return x86_vaes_encrypt_n(in, out, blocks); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return hw_aes_encrypt_n(in, out, blocks); @@ -831,6 +849,12 @@ void AES_128::encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const void AES_128::decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { assert_key_material_set(); +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return x86_vaes_decrypt_n(in, out, blocks); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return hw_aes_decrypt_n(in, out, blocks); @@ -853,6 +877,12 @@ void AES_128::key_schedule(std::span key) { } #endif +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return aes_key_schedule(key.data(), key.size(), m_EK, m_DK, CPUID::is_little_endian()); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return aes_key_schedule(key.data(), key.size(), m_EK, m_DK, CPUID::is_little_endian()); @@ -876,6 +906,12 @@ void AES_128::clear() { void AES_192::encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { assert_key_material_set(); +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return x86_vaes_encrypt_n(in, out, blocks); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return hw_aes_encrypt_n(in, out, blocks); @@ -894,6 +930,12 @@ void AES_192::encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const void AES_192::decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { assert_key_material_set(); +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return x86_vaes_decrypt_n(in, out, blocks); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return hw_aes_decrypt_n(in, out, blocks); @@ -916,6 +958,12 @@ void AES_192::key_schedule(std::span key) { } #endif +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return aes_key_schedule(key.data(), key.size(), m_EK, m_DK, CPUID::is_little_endian()); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return aes_key_schedule(key.data(), key.size(), m_EK, m_DK, CPUID::is_little_endian()); @@ -939,6 +987,12 @@ void AES_192::clear() { void AES_256::encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { assert_key_material_set(); +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return x86_vaes_encrypt_n(in, out, blocks); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return hw_aes_encrypt_n(in, out, blocks); @@ -957,6 +1011,12 @@ void AES_256::encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const void AES_256::decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { assert_key_material_set(); +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return x86_vaes_decrypt_n(in, out, blocks); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return hw_aes_decrypt_n(in, out, blocks); @@ -979,6 +1039,12 @@ void AES_256::key_schedule(std::span key) { } #endif +#if defined(BOTAN_HAS_AES_VAES) + if(CPUID::has_avx2_vaes()) { + return aes_key_schedule(key.data(), key.size(), m_EK, m_DK, CPUID::is_little_endian()); + } +#endif + #if defined(BOTAN_HAS_HW_AES_SUPPORT) if(CPUID::has_hw_aes()) { return aes_key_schedule(key.data(), key.size(), m_EK, m_DK, CPUID::is_little_endian()); diff --git a/src/lib/block/aes/aes.h b/src/lib/block/aes/aes.h index 0071e0edbaa..6676aedf574 100644 --- a/src/lib/block/aes/aes.h +++ b/src/lib/block/aes/aes.h @@ -50,6 +50,11 @@ class AES_128 final : public Block_Cipher_Fixed_Params<16, 16> { void hw_aes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; #endif +#if defined(BOTAN_HAS_AES_VAES) + void x86_vaes_encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; + void x86_vaes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; +#endif + secure_vector m_EK, m_DK; }; @@ -88,6 +93,11 @@ class AES_192 final : public Block_Cipher_Fixed_Params<16, 24> { void hw_aes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; #endif +#if defined(BOTAN_HAS_AES_VAES) + void x86_vaes_encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; + void x86_vaes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; +#endif + void key_schedule(std::span key) override; secure_vector m_EK, m_DK; @@ -128,6 +138,11 @@ class AES_256 final : public Block_Cipher_Fixed_Params<16, 32> { void hw_aes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; #endif +#if defined(BOTAN_HAS_AES_VAES) + void x86_vaes_encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; + void x86_vaes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const; +#endif + void key_schedule(std::span key) override; secure_vector m_EK, m_DK; diff --git a/src/lib/block/aes/aes_vaes/aes_vaes.cpp b/src/lib/block/aes/aes_vaes/aes_vaes.cpp new file mode 100644 index 00000000000..40161af707d --- /dev/null +++ b/src/lib/block/aes/aes_vaes/aes_vaes.cpp @@ -0,0 +1,629 @@ +/* +* (C) 2024 Jack Lloyd +* +* Botan is released under the Simplified BSD License (see license.txt) +*/ + +#include + +#include +#include +#include + +namespace Botan { + +namespace { + +BOTAN_FORCE_INLINE void keyxor(SIMD_8x32 K, SIMD_8x32& B0, SIMD_8x32& B1, SIMD_8x32& B2, SIMD_8x32& B3) { + B0 ^= K; + B1 ^= K; + B2 ^= K; + B3 ^= K; +} + +BOTAN_FUNC_ISA_INLINE("vaes,avx2") void aesenc(SIMD_8x32 K, SIMD_8x32& B) { + B = SIMD_8x32(_mm256_aesenc_epi128(B.raw(), K.raw())); +} + +BOTAN_FUNC_ISA_INLINE("vaes,avx2") +void aesenc(SIMD_8x32 K, SIMD_8x32& B0, SIMD_8x32& B1, SIMD_8x32& B2, SIMD_8x32& B3) { + B0 = SIMD_8x32(_mm256_aesenc_epi128(B0.raw(), K.raw())); + B1 = SIMD_8x32(_mm256_aesenc_epi128(B1.raw(), K.raw())); + B2 = SIMD_8x32(_mm256_aesenc_epi128(B2.raw(), K.raw())); + B3 = SIMD_8x32(_mm256_aesenc_epi128(B3.raw(), K.raw())); +} + +BOTAN_FUNC_ISA_INLINE("vaes,avx2") void aesenclast(SIMD_8x32 K, SIMD_8x32& B) { + B = SIMD_8x32(_mm256_aesenclast_epi128(B.raw(), K.raw())); +} + +BOTAN_FUNC_ISA_INLINE("vaes,avx2") +void aesenclast(SIMD_8x32 K, SIMD_8x32& B0, SIMD_8x32& B1, SIMD_8x32& B2, SIMD_8x32& B3) { + B0 = SIMD_8x32(_mm256_aesenclast_epi128(B0.raw(), K.raw())); + B1 = SIMD_8x32(_mm256_aesenclast_epi128(B1.raw(), K.raw())); + B2 = SIMD_8x32(_mm256_aesenclast_epi128(B2.raw(), K.raw())); + B3 = SIMD_8x32(_mm256_aesenclast_epi128(B3.raw(), K.raw())); +} + +BOTAN_FUNC_ISA_INLINE("vaes,avx2") void aesdec(SIMD_8x32 K, SIMD_8x32& B) { + B = SIMD_8x32(_mm256_aesdec_epi128(B.raw(), K.raw())); +} + +BOTAN_FUNC_ISA_INLINE("vaes,avx2") +void aesdec(SIMD_8x32 K, SIMD_8x32& B0, SIMD_8x32& B1, SIMD_8x32& B2, SIMD_8x32& B3) { + B0 = SIMD_8x32(_mm256_aesdec_epi128(B0.raw(), K.raw())); + B1 = SIMD_8x32(_mm256_aesdec_epi128(B1.raw(), K.raw())); + B2 = SIMD_8x32(_mm256_aesdec_epi128(B2.raw(), K.raw())); + B3 = SIMD_8x32(_mm256_aesdec_epi128(B3.raw(), K.raw())); +} + +BOTAN_FUNC_ISA_INLINE("vaes,avx2") void aesdeclast(SIMD_8x32 K, SIMD_8x32& B) { + B = SIMD_8x32(_mm256_aesdeclast_epi128(B.raw(), K.raw())); +} + +BOTAN_FUNC_ISA_INLINE("vaes,avx2") +void aesdeclast(SIMD_8x32 K, SIMD_8x32& B0, SIMD_8x32& B1, SIMD_8x32& B2, SIMD_8x32& B3) { + B0 = SIMD_8x32(_mm256_aesdeclast_epi128(B0.raw(), K.raw())); + B1 = SIMD_8x32(_mm256_aesdeclast_epi128(B1.raw(), K.raw())); + B2 = SIMD_8x32(_mm256_aesdeclast_epi128(B2.raw(), K.raw())); + B3 = SIMD_8x32(_mm256_aesdeclast_epi128(B3.raw(), K.raw())); +} + +} // namespace + +/* +* AES-128 Encryption +*/ +BOTAN_FUNC_ISA("vaes,avx2") void AES_128::x86_vaes_encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { + const SIMD_8x32 K0 = SIMD_8x32::load_le128(&m_EK[4 * 0]); + const SIMD_8x32 K1 = SIMD_8x32::load_le128(&m_EK[4 * 1]); + const SIMD_8x32 K2 = SIMD_8x32::load_le128(&m_EK[4 * 2]); + const SIMD_8x32 K3 = SIMD_8x32::load_le128(&m_EK[4 * 3]); + const SIMD_8x32 K4 = SIMD_8x32::load_le128(&m_EK[4 * 4]); + const SIMD_8x32 K5 = SIMD_8x32::load_le128(&m_EK[4 * 5]); + const SIMD_8x32 K6 = SIMD_8x32::load_le128(&m_EK[4 * 6]); + const SIMD_8x32 K7 = SIMD_8x32::load_le128(&m_EK[4 * 7]); + const SIMD_8x32 K8 = SIMD_8x32::load_le128(&m_EK[4 * 8]); + const SIMD_8x32 K9 = SIMD_8x32::load_le128(&m_EK[4 * 9]); + const SIMD_8x32 K10 = SIMD_8x32::load_le128(&m_EK[4 * 10]); + + while(blocks >= 8) { + SIMD_8x32 B0 = SIMD_8x32::load_le(in); + SIMD_8x32 B1 = SIMD_8x32::load_le(in + 16 * 2); + SIMD_8x32 B2 = SIMD_8x32::load_le(in + 16 * 4); + SIMD_8x32 B3 = SIMD_8x32::load_le(in + 16 * 6); + + keyxor(K0, B0, B1, B2, B3); + aesenc(K1, B0, B1, B2, B3); + aesenc(K2, B0, B1, B2, B3); + aesenc(K3, B0, B1, B2, B3); + aesenc(K4, B0, B1, B2, B3); + aesenc(K5, B0, B1, B2, B3); + aesenc(K6, B0, B1, B2, B3); + aesenc(K7, B0, B1, B2, B3); + aesenc(K8, B0, B1, B2, B3); + aesenc(K9, B0, B1, B2, B3); + aesenclast(K10, B0, B1, B2, B3); + + B0.store_le(out); + B1.store_le(out + 16 * 2); + B2.store_le(out + 16 * 4); + B3.store_le(out + 16 * 6); + + blocks -= 8; + in += 8 * 16; + out += 8 * 16; + } + + while(blocks >= 2) { + SIMD_8x32 B = SIMD_8x32::load_le(in); + + B ^= K0; + aesenc(K1, B); + aesenc(K2, B); + aesenc(K3, B); + aesenc(K4, B); + aesenc(K5, B); + aesenc(K6, B); + aesenc(K7, B); + aesenc(K8, B); + aesenc(K9, B); + aesenclast(K10, B); + + B.store_le(out); + + in += 2 * 16; + out += 2 * 16; + blocks -= 2; + } + + if(blocks > 0) { + SIMD_8x32 B = SIMD_8x32::load_le128(in); + + B ^= K0; + aesenc(K1, B); + aesenc(K2, B); + aesenc(K3, B); + aesenc(K4, B); + aesenc(K5, B); + aesenc(K6, B); + aesenc(K7, B); + aesenc(K8, B); + aesenc(K9, B); + aesenclast(K10, B); + + B.store_le128(out); + } +} + +/* +* AES-128 Decryption +*/ +BOTAN_FUNC_ISA("vaes,avx2") void AES_128::x86_vaes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { + const SIMD_8x32 K0 = SIMD_8x32::load_le128(&m_DK[4 * 0]); + const SIMD_8x32 K1 = SIMD_8x32::load_le128(&m_DK[4 * 1]); + const SIMD_8x32 K2 = SIMD_8x32::load_le128(&m_DK[4 * 2]); + const SIMD_8x32 K3 = SIMD_8x32::load_le128(&m_DK[4 * 3]); + const SIMD_8x32 K4 = SIMD_8x32::load_le128(&m_DK[4 * 4]); + const SIMD_8x32 K5 = SIMD_8x32::load_le128(&m_DK[4 * 5]); + const SIMD_8x32 K6 = SIMD_8x32::load_le128(&m_DK[4 * 6]); + const SIMD_8x32 K7 = SIMD_8x32::load_le128(&m_DK[4 * 7]); + const SIMD_8x32 K8 = SIMD_8x32::load_le128(&m_DK[4 * 8]); + const SIMD_8x32 K9 = SIMD_8x32::load_le128(&m_DK[4 * 9]); + const SIMD_8x32 K10 = SIMD_8x32::load_le128(&m_DK[4 * 10]); + + while(blocks >= 8) { + SIMD_8x32 B0 = SIMD_8x32::load_le(in + 16 * 0); + SIMD_8x32 B1 = SIMD_8x32::load_le(in + 16 * 2); + SIMD_8x32 B2 = SIMD_8x32::load_le(in + 16 * 4); + SIMD_8x32 B3 = SIMD_8x32::load_le(in + 16 * 6); + + keyxor(K0, B0, B1, B2, B3); + aesdec(K1, B0, B1, B2, B3); + aesdec(K2, B0, B1, B2, B3); + aesdec(K3, B0, B1, B2, B3); + aesdec(K4, B0, B1, B2, B3); + aesdec(K5, B0, B1, B2, B3); + aesdec(K6, B0, B1, B2, B3); + aesdec(K7, B0, B1, B2, B3); + aesdec(K8, B0, B1, B2, B3); + aesdec(K9, B0, B1, B2, B3); + aesdeclast(K10, B0, B1, B2, B3); + + B0.store_le(out + 16 * 0); + B1.store_le(out + 16 * 2); + B2.store_le(out + 16 * 4); + B3.store_le(out + 16 * 6); + + blocks -= 8; + in += 8 * 16; + out += 8 * 16; + } + + while(blocks >= 2) { + SIMD_8x32 B = SIMD_8x32::load_le(in); + + B ^= K0; + aesdec(K1, B); + aesdec(K2, B); + aesdec(K3, B); + aesdec(K4, B); + aesdec(K5, B); + aesdec(K6, B); + aesdec(K7, B); + aesdec(K8, B); + aesdec(K9, B); + aesdeclast(K10, B); + + B.store_le(out); + + in += 2 * 16; + out += 2 * 16; + blocks -= 2; + } + + if(blocks > 0) { + SIMD_8x32 B = SIMD_8x32::load_le128(in); + + B ^= K0; + aesdec(K1, B); + aesdec(K2, B); + aesdec(K3, B); + aesdec(K4, B); + aesdec(K5, B); + aesdec(K6, B); + aesdec(K7, B); + aesdec(K8, B); + aesdec(K9, B); + aesdeclast(K10, B); + + B.store_le128(out); + } +} + +/* +* AES-192 Encryption +*/ +BOTAN_FUNC_ISA("vaes,avx2") void AES_192::x86_vaes_encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { + const SIMD_8x32 K0 = SIMD_8x32::load_le128(&m_EK[4 * 0]); + const SIMD_8x32 K1 = SIMD_8x32::load_le128(&m_EK[4 * 1]); + const SIMD_8x32 K2 = SIMD_8x32::load_le128(&m_EK[4 * 2]); + const SIMD_8x32 K3 = SIMD_8x32::load_le128(&m_EK[4 * 3]); + const SIMD_8x32 K4 = SIMD_8x32::load_le128(&m_EK[4 * 4]); + const SIMD_8x32 K5 = SIMD_8x32::load_le128(&m_EK[4 * 5]); + const SIMD_8x32 K6 = SIMD_8x32::load_le128(&m_EK[4 * 6]); + const SIMD_8x32 K7 = SIMD_8x32::load_le128(&m_EK[4 * 7]); + const SIMD_8x32 K8 = SIMD_8x32::load_le128(&m_EK[4 * 8]); + const SIMD_8x32 K9 = SIMD_8x32::load_le128(&m_EK[4 * 9]); + const SIMD_8x32 K10 = SIMD_8x32::load_le128(&m_EK[4 * 10]); + const SIMD_8x32 K11 = SIMD_8x32::load_le128(&m_EK[4 * 11]); + const SIMD_8x32 K12 = SIMD_8x32::load_le128(&m_EK[4 * 12]); + + while(blocks >= 8) { + SIMD_8x32 B0 = SIMD_8x32::load_le(in + 16 * 0); + SIMD_8x32 B1 = SIMD_8x32::load_le(in + 16 * 2); + SIMD_8x32 B2 = SIMD_8x32::load_le(in + 16 * 4); + SIMD_8x32 B3 = SIMD_8x32::load_le(in + 16 * 6); + + keyxor(K0, B0, B1, B2, B3); + aesenc(K1, B0, B1, B2, B3); + aesenc(K2, B0, B1, B2, B3); + aesenc(K3, B0, B1, B2, B3); + aesenc(K4, B0, B1, B2, B3); + aesenc(K5, B0, B1, B2, B3); + aesenc(K6, B0, B1, B2, B3); + aesenc(K7, B0, B1, B2, B3); + aesenc(K8, B0, B1, B2, B3); + aesenc(K9, B0, B1, B2, B3); + aesenc(K10, B0, B1, B2, B3); + aesenc(K11, B0, B1, B2, B3); + aesenclast(K12, B0, B1, B2, B3); + + B0.store_le(out + 16 * 0); + B1.store_le(out + 16 * 2); + B2.store_le(out + 16 * 4); + B3.store_le(out + 16 * 6); + + blocks -= 8; + in += 8 * 16; + out += 8 * 16; + } + + while(blocks >= 2) { + SIMD_8x32 B = SIMD_8x32::load_le(in); + + B ^= K0; + aesenc(K1, B); + aesenc(K2, B); + aesenc(K3, B); + aesenc(K4, B); + aesenc(K5, B); + aesenc(K6, B); + aesenc(K7, B); + aesenc(K8, B); + aesenc(K9, B); + aesenc(K10, B); + aesenc(K11, B); + aesenclast(K12, B); + + B.store_le(out); + + in += 2 * 16; + out += 2 * 16; + blocks -= 2; + } + + if(blocks > 0) { + SIMD_8x32 B = SIMD_8x32::load_le128(in); + + B ^= K0; + aesenc(K1, B); + aesenc(K2, B); + aesenc(K3, B); + aesenc(K4, B); + aesenc(K5, B); + aesenc(K6, B); + aesenc(K7, B); + aesenc(K8, B); + aesenc(K9, B); + aesenc(K10, B); + aesenc(K11, B); + aesenclast(K12, B); + + B.store_le128(out); + } +} + +/* +* AES-192 Decryption +*/ +BOTAN_FUNC_ISA("vaes,avx2") void AES_192::x86_vaes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { + const SIMD_8x32 K0 = SIMD_8x32::load_le128(&m_DK[4 * 0]); + const SIMD_8x32 K1 = SIMD_8x32::load_le128(&m_DK[4 * 1]); + const SIMD_8x32 K2 = SIMD_8x32::load_le128(&m_DK[4 * 2]); + const SIMD_8x32 K3 = SIMD_8x32::load_le128(&m_DK[4 * 3]); + const SIMD_8x32 K4 = SIMD_8x32::load_le128(&m_DK[4 * 4]); + const SIMD_8x32 K5 = SIMD_8x32::load_le128(&m_DK[4 * 5]); + const SIMD_8x32 K6 = SIMD_8x32::load_le128(&m_DK[4 * 6]); + const SIMD_8x32 K7 = SIMD_8x32::load_le128(&m_DK[4 * 7]); + const SIMD_8x32 K8 = SIMD_8x32::load_le128(&m_DK[4 * 8]); + const SIMD_8x32 K9 = SIMD_8x32::load_le128(&m_DK[4 * 9]); + const SIMD_8x32 K10 = SIMD_8x32::load_le128(&m_DK[4 * 10]); + const SIMD_8x32 K11 = SIMD_8x32::load_le128(&m_DK[4 * 11]); + const SIMD_8x32 K12 = SIMD_8x32::load_le128(&m_DK[4 * 12]); + + while(blocks >= 8) { + SIMD_8x32 B0 = SIMD_8x32::load_le(in + 16 * 0); + SIMD_8x32 B1 = SIMD_8x32::load_le(in + 16 * 2); + SIMD_8x32 B2 = SIMD_8x32::load_le(in + 16 * 4); + SIMD_8x32 B3 = SIMD_8x32::load_le(in + 16 * 6); + + keyxor(K0, B0, B1, B2, B3); + aesdec(K1, B0, B1, B2, B3); + aesdec(K2, B0, B1, B2, B3); + aesdec(K3, B0, B1, B2, B3); + aesdec(K4, B0, B1, B2, B3); + aesdec(K5, B0, B1, B2, B3); + aesdec(K6, B0, B1, B2, B3); + aesdec(K7, B0, B1, B2, B3); + aesdec(K8, B0, B1, B2, B3); + aesdec(K9, B0, B1, B2, B3); + aesdec(K10, B0, B1, B2, B3); + aesdec(K11, B0, B1, B2, B3); + aesdeclast(K12, B0, B1, B2, B3); + + B0.store_le(out + 16 * 0); + B1.store_le(out + 16 * 2); + B2.store_le(out + 16 * 4); + B3.store_le(out + 16 * 6); + + blocks -= 8; + in += 8 * 16; + out += 8 * 16; + } + + while(blocks >= 2) { + SIMD_8x32 B = SIMD_8x32::load_le(in); + + B ^= K0; + aesdec(K1, B); + aesdec(K2, B); + aesdec(K3, B); + aesdec(K4, B); + aesdec(K5, B); + aesdec(K6, B); + aesdec(K7, B); + aesdec(K8, B); + aesdec(K9, B); + aesdec(K10, B); + aesdec(K11, B); + aesdeclast(K12, B); + + B.store_le(out); + + in += 2 * 16; + out += 2 * 16; + blocks -= 2; + } + + if(blocks > 0) { + SIMD_8x32 B = SIMD_8x32::load_le128(in); + + B ^= K0; + aesdec(K1, B); + aesdec(K2, B); + aesdec(K3, B); + aesdec(K4, B); + aesdec(K5, B); + aesdec(K6, B); + aesdec(K7, B); + aesdec(K8, B); + aesdec(K9, B); + aesdec(K10, B); + aesdec(K11, B); + aesdeclast(K12, B); + + B.store_le128(out); + } +} + +BOTAN_FUNC_ISA("vaes,avx2") void AES_256::x86_vaes_encrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { + const SIMD_8x32 K0 = SIMD_8x32::load_le128(&m_EK[4 * 0]); + const SIMD_8x32 K1 = SIMD_8x32::load_le128(&m_EK[4 * 1]); + const SIMD_8x32 K2 = SIMD_8x32::load_le128(&m_EK[4 * 2]); + const SIMD_8x32 K3 = SIMD_8x32::load_le128(&m_EK[4 * 3]); + const SIMD_8x32 K4 = SIMD_8x32::load_le128(&m_EK[4 * 4]); + const SIMD_8x32 K5 = SIMD_8x32::load_le128(&m_EK[4 * 5]); + const SIMD_8x32 K6 = SIMD_8x32::load_le128(&m_EK[4 * 6]); + const SIMD_8x32 K7 = SIMD_8x32::load_le128(&m_EK[4 * 7]); + const SIMD_8x32 K8 = SIMD_8x32::load_le128(&m_EK[4 * 8]); + const SIMD_8x32 K9 = SIMD_8x32::load_le128(&m_EK[4 * 9]); + const SIMD_8x32 K10 = SIMD_8x32::load_le128(&m_EK[4 * 10]); + const SIMD_8x32 K11 = SIMD_8x32::load_le128(&m_EK[4 * 11]); + const SIMD_8x32 K12 = SIMD_8x32::load_le128(&m_EK[4 * 12]); + const SIMD_8x32 K13 = SIMD_8x32::load_le128(&m_EK[4 * 13]); + const SIMD_8x32 K14 = SIMD_8x32::load_le128(&m_EK[4 * 14]); + + while(blocks >= 8) { + SIMD_8x32 B0 = SIMD_8x32::load_le(in + 16 * 0); + SIMD_8x32 B1 = SIMD_8x32::load_le(in + 16 * 2); + SIMD_8x32 B2 = SIMD_8x32::load_le(in + 16 * 4); + SIMD_8x32 B3 = SIMD_8x32::load_le(in + 16 * 6); + + keyxor(K0, B0, B1, B2, B3); + aesenc(K1, B0, B1, B2, B3); + aesenc(K2, B0, B1, B2, B3); + aesenc(K3, B0, B1, B2, B3); + aesenc(K4, B0, B1, B2, B3); + aesenc(K5, B0, B1, B2, B3); + aesenc(K6, B0, B1, B2, B3); + aesenc(K7, B0, B1, B2, B3); + aesenc(K8, B0, B1, B2, B3); + aesenc(K9, B0, B1, B2, B3); + aesenc(K10, B0, B1, B2, B3); + aesenc(K11, B0, B1, B2, B3); + aesenc(K12, B0, B1, B2, B3); + aesenc(K13, B0, B1, B2, B3); + aesenclast(K14, B0, B1, B2, B3); + + B0.store_le(out + 16 * 0); + B1.store_le(out + 16 * 2); + B2.store_le(out + 16 * 4); + B3.store_le(out + 16 * 6); + + blocks -= 8; + in += 8 * 16; + out += 8 * 16; + } + + while(blocks >= 2) { + SIMD_8x32 B = SIMD_8x32::load_le(in); + + B ^= K0; + aesenc(K1, B); + aesenc(K2, B); + aesenc(K3, B); + aesenc(K4, B); + aesenc(K5, B); + aesenc(K6, B); + aesenc(K7, B); + aesenc(K8, B); + aesenc(K9, B); + aesenc(K10, B); + aesenc(K11, B); + aesenc(K12, B); + aesenc(K13, B); + aesenclast(K14, B); + + B.store_le(out); + + in += 2 * 16; + out += 2 * 16; + blocks -= 2; + } + + if(blocks > 0) { + SIMD_8x32 B = SIMD_8x32::load_le128(in); + + B ^= K0; + aesenc(K1, B); + aesenc(K2, B); + aesenc(K3, B); + aesenc(K4, B); + aesenc(K5, B); + aesenc(K6, B); + aesenc(K7, B); + aesenc(K8, B); + aesenc(K9, B); + aesenc(K10, B); + aesenc(K11, B); + aesenc(K12, B); + aesenc(K13, B); + aesenclast(K14, B); + + B.store_le128(out); + } +} + +/* +* AES-256 Decryption +*/ +BOTAN_FUNC_ISA("vaes,avx2") void AES_256::x86_vaes_decrypt_n(const uint8_t in[], uint8_t out[], size_t blocks) const { + const SIMD_8x32 K0 = SIMD_8x32::load_le128(&m_DK[4 * 0]); + const SIMD_8x32 K1 = SIMD_8x32::load_le128(&m_DK[4 * 1]); + const SIMD_8x32 K2 = SIMD_8x32::load_le128(&m_DK[4 * 2]); + const SIMD_8x32 K3 = SIMD_8x32::load_le128(&m_DK[4 * 3]); + const SIMD_8x32 K4 = SIMD_8x32::load_le128(&m_DK[4 * 4]); + const SIMD_8x32 K5 = SIMD_8x32::load_le128(&m_DK[4 * 5]); + const SIMD_8x32 K6 = SIMD_8x32::load_le128(&m_DK[4 * 6]); + const SIMD_8x32 K7 = SIMD_8x32::load_le128(&m_DK[4 * 7]); + const SIMD_8x32 K8 = SIMD_8x32::load_le128(&m_DK[4 * 8]); + const SIMD_8x32 K9 = SIMD_8x32::load_le128(&m_DK[4 * 9]); + const SIMD_8x32 K10 = SIMD_8x32::load_le128(&m_DK[4 * 10]); + const SIMD_8x32 K11 = SIMD_8x32::load_le128(&m_DK[4 * 11]); + const SIMD_8x32 K12 = SIMD_8x32::load_le128(&m_DK[4 * 12]); + const SIMD_8x32 K13 = SIMD_8x32::load_le128(&m_DK[4 * 13]); + const SIMD_8x32 K14 = SIMD_8x32::load_le128(&m_DK[4 * 14]); + + while(blocks >= 8) { + SIMD_8x32 B0 = SIMD_8x32::load_le(in + 16 * 0); + SIMD_8x32 B1 = SIMD_8x32::load_le(in + 16 * 2); + SIMD_8x32 B2 = SIMD_8x32::load_le(in + 16 * 4); + SIMD_8x32 B3 = SIMD_8x32::load_le(in + 16 * 6); + + keyxor(K0, B0, B1, B2, B3); + aesdec(K1, B0, B1, B2, B3); + aesdec(K2, B0, B1, B2, B3); + aesdec(K3, B0, B1, B2, B3); + aesdec(K4, B0, B1, B2, B3); + aesdec(K5, B0, B1, B2, B3); + aesdec(K6, B0, B1, B2, B3); + aesdec(K7, B0, B1, B2, B3); + aesdec(K8, B0, B1, B2, B3); + aesdec(K9, B0, B1, B2, B3); + aesdec(K10, B0, B1, B2, B3); + aesdec(K11, B0, B1, B2, B3); + aesdec(K12, B0, B1, B2, B3); + aesdec(K13, B0, B1, B2, B3); + aesdeclast(K14, B0, B1, B2, B3); + + B0.store_le(out + 16 * 0); + B1.store_le(out + 16 * 2); + B2.store_le(out + 16 * 4); + B3.store_le(out + 16 * 6); + + blocks -= 8; + in += 8 * 16; + out += 8 * 16; + } + + while(blocks >= 2) { + SIMD_8x32 B = SIMD_8x32::load_le(in); + + B ^= K0; + aesdec(K1, B); + aesdec(K2, B); + aesdec(K3, B); + aesdec(K4, B); + aesdec(K5, B); + aesdec(K6, B); + aesdec(K7, B); + aesdec(K8, B); + aesdec(K9, B); + aesdec(K10, B); + aesdec(K11, B); + aesdec(K12, B); + aesdec(K13, B); + aesdeclast(K14, B); + + B.store_le(out); + + in += 2 * 16; + out += 2 * 16; + blocks -= 2; + } + + if(blocks > 0) { + SIMD_8x32 B = SIMD_8x32::load_le128(in); + + B ^= K0; + aesdec(K1, B); + aesdec(K2, B); + aesdec(K3, B); + aesdec(K4, B); + aesdec(K5, B); + aesdec(K6, B); + aesdec(K7, B); + aesdec(K8, B); + aesdec(K9, B); + aesdec(K10, B); + aesdec(K11, B); + aesdec(K12, B); + aesdec(K13, B); + aesdeclast(K14, B); + + B.store_le128(out); + } +} + +} // namespace Botan diff --git a/src/lib/block/aes/aes_vaes/info.txt b/src/lib/block/aes/aes_vaes/info.txt new file mode 100644 index 00000000000..0cfad790a20 --- /dev/null +++ b/src/lib/block/aes/aes_vaes/info.txt @@ -0,0 +1,16 @@ + +AES_VAES -> 20240803 + + + +name -> "AES-VAES" +brief -> "AES using VAES" + + + +simd_avx2 + + + +vaes + diff --git a/src/lib/utils/simd/simd_avx2/simd_avx2.h b/src/lib/utils/simd/simd_avx2/simd_avx2.h index 963e8cdf1cc..7001526a2ca 100644 --- a/src/lib/utils/simd/simd_avx2/simd_avx2.h +++ b/src/lib/utils/simd/simd_avx2/simd_avx2.h @@ -52,12 +52,27 @@ class SIMD_8x32 final { return SIMD_8x32(_mm256_loadu_si256(reinterpret_cast(in))); } + BOTAN_AVX2_FN + static SIMD_8x32 load_le128(const uint8_t* in) noexcept { + return SIMD_8x32(_mm256_broadcastsi128_si256(_mm_loadu_si128(reinterpret_cast(in)))); + } + + BOTAN_AVX2_FN + static SIMD_8x32 load_le128(const uint32_t* in) noexcept { + return SIMD_8x32(_mm256_broadcastsi128_si256(_mm_loadu_si128(reinterpret_cast(in)))); + } + BOTAN_AVX2_FN static SIMD_8x32 load_be(const uint8_t* in) noexcept { return load_le(in).bswap(); } BOTAN_AVX2_FN void store_le(uint8_t out[]) const noexcept { _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), m_avx2); } + BOTAN_AVX2_FN + void store_le128(uint8_t out[]) const noexcept { + _mm_storeu_si128(reinterpret_cast<__m128i*>(out), _mm256_extracti128_si256(raw(), 0)); + } + BOTAN_AVX2_FN void store_be(uint8_t out[]) const noexcept { bswap().store_le(out); } @@ -224,7 +239,7 @@ class SIMD_8x32 final { BOTAN_AVX2_FN static SIMD_8x32 choose(const SIMD_8x32& mask, const SIMD_8x32& a, const SIMD_8x32& b) noexcept { #if defined(__AVX512VL__) - return _mm256_ternarylogic_epi32(mask.handle(), a.handle(), b.handle(), 0xca); + return _mm256_ternarylogic_epi32(mask.raw(), a.raw(), b.raw(), 0xca); #else return (mask & a) ^ mask.andc(b); #endif @@ -233,7 +248,7 @@ class SIMD_8x32 final { BOTAN_AVX2_FN static SIMD_8x32 majority(const SIMD_8x32& x, const SIMD_8x32& y, const SIMD_8x32& z) noexcept { #if defined(__AVX512VL__) - return _mm256_ternarylogic_epi32(x.handle(), y.handle(), z.handle(), 0xe8); + return _mm256_ternarylogic_epi32(x.raw(), y.raw(), z.raw(), 0xe8); #else return SIMD_8x32::choose(x ^ y, z, y); #endif @@ -245,7 +260,7 @@ class SIMD_8x32 final { BOTAN_AVX2_FN static void zero_registers() noexcept { _mm256_zeroall(); } - __m256i BOTAN_AVX2_FN handle() const noexcept { return m_avx2; } + __m256i BOTAN_AVX2_FN raw() const noexcept { return m_avx2; } BOTAN_AVX2_FN SIMD_8x32(__m256i x) noexcept : m_avx2(x) {} @@ -253,8 +268,8 @@ class SIMD_8x32 final { private: BOTAN_AVX2_FN static void swap_tops(SIMD_8x32& A, SIMD_8x32& B) { - SIMD_8x32 T0 = _mm256_permute2x128_si256(A.handle(), B.handle(), 0 + (2 << 4)); - SIMD_8x32 T1 = _mm256_permute2x128_si256(A.handle(), B.handle(), 1 + (3 << 4)); + SIMD_8x32 T0 = _mm256_permute2x128_si256(A.raw(), B.raw(), 0 + (2 << 4)); + SIMD_8x32 T1 = _mm256_permute2x128_si256(A.raw(), B.raw(), 1 + (3 << 4)); A = T0; B = T1; } diff --git a/src/tests/data/block/aes.vec b/src/tests/data/block/aes.vec index e08df20d420..b9f3df2fc8b 100644 --- a/src/tests/data/block/aes.vec +++ b/src/tests/data/block/aes.vec @@ -1,7 +1,7 @@ # Test vectors from NIST CAVP AESAVS # http://csrc.nist.gov/groups/STM/cavp/documents/aes/AESAVS.pdf -#test cpuid aesni armv8aes power_crypto ssse3 neon altivec +#test cpuid aesni avx2_vaes armv8aes power_crypto ssse3 neon altivec [AES-128] Key = 000102030405060708090A0B0C0D0E0F