diff --git a/src/lib/pubkey/hss_lms/hss.cpp b/src/lib/pubkey/hss_lms/hss.cpp index 5d555600ab5..6938918b6ca 100644 --- a/src/lib/pubkey/hss_lms/hss.cpp +++ b/src/lib/pubkey/hss_lms/hss.cpp @@ -120,7 +120,7 @@ HSS_LMS_PrivateKeyInternal::HSS_LMS_PrivateKeyInternal(const HSS_LMS_Params& hss std::shared_ptr HSS_LMS_PrivateKeyInternal::from_bytes_or_throw( std::span key_bytes) { if(key_bytes.size() < sizeof(HSS_Level) + sizeof(HSS_Sig_Idx)) { - throw Decoding_Error("To few private key bytes."); + throw Decoding_Error("Too few private key bytes."); } BufferSlicer slicer(key_bytes); @@ -166,9 +166,7 @@ std::shared_ptr HSS_LMS_PrivateKeyInternal::from_byt } secure_vector HSS_LMS_PrivateKeyInternal::to_bytes() const { - secure_vector sk_bytes(sizeof(HSS_Level) + sizeof(HSS_Sig_Idx) + - hss_params().L().get() * - (sizeof(LMS_Algorithm_Type) + sizeof(LMOTS_Algorithm_Type))); + secure_vector sk_bytes(size()); BufferStuffer stuffer(sk_bytes); stuffer.append_be(hss_params().L()); @@ -179,9 +177,9 @@ secure_vector HSS_LMS_PrivateKeyInternal::to_bytes() const { stuffer.append_be(params.lms_params().algorithm_type()); stuffer.append_be(params.lmots_params().algorithm_type()); } - - sk_bytes.insert(sk_bytes.end(), seed().begin(), seed().end()); - sk_bytes.insert(sk_bytes.end(), identifier().begin(), identifier().end()); + stuffer.append(seed()); + stuffer.append(identifier()); + BOTAN_ASSERT_NOMSG(stuffer.full()); return sk_bytes; } @@ -199,6 +197,14 @@ HSS_Sig_Idx HSS_LMS_PrivateKeyInternal::reserve_next_idx() { return next_idx; } +size_t HSS_LMS_PrivateKeyInternal::size() const { + size_t sk_size = sizeof(HSS_Level) + sizeof(HSS_Sig_Idx); + // The concatenated algorithm types for all layers + sk_size += hss_params().L().get() * (sizeof(LMS_Algorithm_Type) + sizeof(LMOTS_Algorithm_Type)); + sk_size += seed().size() + identifier().size(); + return sk_size; +} + HSS_LMS_PrivateKeyInternal::HSS_LMS_PrivateKeyInternal(HSS_LMS_Params hss_params, LMS_Seed hss_seed, LMS_Identifier identifier) : @@ -299,7 +305,7 @@ HSS_LMS_PublicKeyInternal HSS_LMS_PublicKeyInternal::create(const HSS_LMS_Privat std::shared_ptr HSS_LMS_PublicKeyInternal::from_bytes_or_throw( std::span key_bytes) { if(key_bytes.size() < sizeof(HSS_Level)) { - throw Decoding_Error("To few public key bytes."); + throw Decoding_Error("Too few public key bytes."); } BufferSlicer slicer(key_bytes); @@ -308,7 +314,7 @@ std::shared_ptr HSS_LMS_PublicKeyInternal::from_bytes throw Decoding_Error("Invalid number of HSS layers in public HSS-LMS key."); } - LMS_PublicKey lms_pub_key = LMS_PublicKey::from_bytes_of_throw(slicer); + LMS_PublicKey lms_pub_key = LMS_PublicKey::from_bytes_or_throw(slicer); if(!slicer.empty()) { throw Decoding_Error("Public HSS-LMS key contains more bytes than expected."); @@ -372,7 +378,7 @@ HSS_Signature::Signed_Pub_Key::Signed_Pub_Key(LMS_Signature sig, LMS_PublicKey p HSS_Signature HSS_Signature::from_bytes_or_throw(std::span sig_bytes) { if(sig_bytes.size() < sizeof(uint32_t)) { - throw Decoding_Error("To few HSS signature bytes."); + throw Decoding_Error("Too few HSS signature bytes."); } BufferSlicer slicer(sig_bytes); @@ -384,7 +390,7 @@ HSS_Signature HSS_Signature::from_bytes_or_throw(std::span sig_by std::vector signed_pub_keys; for(size_t i = 0; i < Nspk; ++i) { LMS_Signature sig = LMS_Signature::from_bytes_or_throw(slicer); - LMS_PublicKey pub_key = LMS_PublicKey::from_bytes_of_throw(slicer); + LMS_PublicKey pub_key = LMS_PublicKey::from_bytes_or_throw(slicer); signed_pub_keys.push_back(Signed_Pub_Key(std::move(sig), std::move(pub_key))); } diff --git a/src/lib/pubkey/hss_lms/hss.h b/src/lib/pubkey/hss_lms/hss.h index 9d32b2ebfa4..4772216de3f 100644 --- a/src/lib/pubkey/hss_lms/hss.h +++ b/src/lib/pubkey/hss_lms/hss.h @@ -207,6 +207,11 @@ class BOTAN_TEST_API HSS_LMS_PrivateKeyInternal final { */ HSS_Sig_Idx reserve_next_idx(); + /** + * @brief Returns the size in bytes the key would have in its encoded format. + */ + size_t size() const; + /** * @brief Derive the seed and identifier of an LMS tree from its parent LMS tree. * diff --git a/src/lib/pubkey/hss_lms/hss_lms_utils.cpp b/src/lib/pubkey/hss_lms/hss_lms_utils.cpp index b38b34efb46..ffd8a965c7b 100644 --- a/src/lib/pubkey/hss_lms/hss_lms_utils.cpp +++ b/src/lib/pubkey/hss_lms/hss_lms_utils.cpp @@ -15,11 +15,6 @@ PseudorandomKeyGeneration::PseudorandomKeyGeneration(std::span id m_input_buffer.resize(identifier.size() + sizeof(uint32_t) + sizeof(uint16_t) + sizeof(uint8_t)); BufferStuffer input_stuffer(m_input_buffer); input_stuffer.append(identifier); - - m_q = input_stuffer.next(sizeof(uint32_t)); - m_i = input_stuffer.next(sizeof(uint16_t)); - m_j = input_stuffer.next(sizeof(uint8_t)).data(); - BOTAN_ASSERT_NOMSG(input_stuffer.full()); } void PseudorandomKeyGeneration::gen(std::span out, HashFunction& hash, std::span seed) const { diff --git a/src/lib/pubkey/hss_lms/hss_lms_utils.h b/src/lib/pubkey/hss_lms/hss_lms_utils.h index e89d3d9e090..a580c308f0b 100644 --- a/src/lib/pubkey/hss_lms/hss_lms_utils.h +++ b/src/lib/pubkey/hss_lms/hss_lms_utils.h @@ -35,17 +35,17 @@ class PseudorandomKeyGeneration { /** * @brief Specify the value for the u32str(q) hash input field */ - void set_q(uint32_t q) { store_be(q, m_q.data()); } + void set_q(uint32_t q) { store_be(q, std::span(m_input_buffer).last<7>().first<4>().data()); } /** * @brief Specify the value for the u16str(i) hash input field */ - void set_i(uint16_t i) { store_be(i, m_i.data()); } + void set_i(uint16_t i) { store_be(i, std::span(m_input_buffer).last<3>().first<2>().data()); } /** * @brief Specify the value for the u8str(j) hash input field */ - void set_j(uint8_t j) { *m_j = j; } + void set_j(uint8_t j) { m_input_buffer.back() = j; } /** * @brief Create a hash value using the preconfigured prefix and a @p seed @@ -66,12 +66,6 @@ class PseudorandomKeyGeneration { private: /// Input buffer containing the prefix: 'identifier || u32str(q) || u16str(i) || u8str(j)' std::vector m_input_buffer; - /// Subspan of m_input_buffer representing 'u32str(q)' - std::span m_q; - /// Subspan of m_input_buffer representing 'u26str(i)' - std::span m_i; - /// Pointer to m_input_buffer at 'u8str(j)' - uint8_t* m_j; }; } // namespace Botan diff --git a/src/lib/pubkey/hss_lms/lm_ots.cpp b/src/lib/pubkey/hss_lms/lm_ots.cpp index caff64e1f1e..4c2240f05fc 100644 --- a/src/lib/pubkey/hss_lms/lm_ots.cpp +++ b/src/lib/pubkey/hss_lms/lm_ots.cpp @@ -101,45 +101,113 @@ std::vector gen_Q_with_cksm(const LMOTS_Params& params, } // namespace LMOTS_Params LMOTS_Params::create_or_throw(LMOTS_Algorithm_Type type) { - uint8_t type_value = checked_cast_to_or_throw(type, "Unsupported LM-OTS algorithm type"); - - if(type >= LMOTS_Algorithm_Type::SHA256_N32_W1 && type <= LMOTS_Algorithm_Type::SHA256_N32_W8) { - uint8_t w = 1 << (type_value - checked_cast_to(LMOTS_Algorithm_Type::SHA256_N32_W1)); - return LMOTS_Params(type, "SHA-256", w); - } - if(type >= LMOTS_Algorithm_Type::SHA256_N24_W1 && type <= LMOTS_Algorithm_Type::SHA256_N24_W8) { - uint8_t w = 1 << (type_value - checked_cast_to(LMOTS_Algorithm_Type::SHA256_N24_W1)); - return LMOTS_Params(type, "Truncated(SHA-256,192)", w); - } - if(type >= LMOTS_Algorithm_Type::SHAKE_N32_W1 && type <= LMOTS_Algorithm_Type::SHAKE_N32_W8) { - uint8_t w = 1 << (type_value - checked_cast_to(LMOTS_Algorithm_Type::SHAKE_N32_W1)); - return LMOTS_Params(type, "SHAKE-256(256)", w); - } - if(type >= LMOTS_Algorithm_Type::SHAKE_N24_W1 && type <= LMOTS_Algorithm_Type::SHAKE_N24_W8) { - uint8_t w = 1 << (type_value - checked_cast_to(LMOTS_Algorithm_Type::SHAKE_N24_W1)); - return LMOTS_Params(type, "SHAKE-256(192)", w); - } + auto [hash_name, w] = [](const LMOTS_Algorithm_Type& lmots_type) -> std::pair { + switch(lmots_type) { + case LMOTS_Algorithm_Type::SHA256_N32_W1: + return {"SHA-256", 1}; + case LMOTS_Algorithm_Type::SHA256_N32_W2: + return {"SHA-256", 2}; + case LMOTS_Algorithm_Type::SHA256_N32_W4: + return {"SHA-256", 4}; + case LMOTS_Algorithm_Type::SHA256_N32_W8: + return {"SHA-256", 8}; + case LMOTS_Algorithm_Type::SHA256_N24_W1: + return {"Truncated(SHA-256,192)", 1}; + case LMOTS_Algorithm_Type::SHA256_N24_W2: + return {"Truncated(SHA-256,192)", 2}; + case LMOTS_Algorithm_Type::SHA256_N24_W4: + return {"Truncated(SHA-256,192)", 4}; + case LMOTS_Algorithm_Type::SHA256_N24_W8: + return {"Truncated(SHA-256,192)", 8}; + case LMOTS_Algorithm_Type::SHAKE_N32_W1: + return {"SHAKE-256(256)", 1}; + case LMOTS_Algorithm_Type::SHAKE_N32_W2: + return {"SHAKE-256(256)", 2}; + case LMOTS_Algorithm_Type::SHAKE_N32_W4: + return {"SHAKE-256(256)", 4}; + case LMOTS_Algorithm_Type::SHAKE_N32_W8: + return {"SHAKE-256(256)", 8}; + case LMOTS_Algorithm_Type::SHAKE_N24_W1: + return {"SHAKE-256(192)", 1}; + case LMOTS_Algorithm_Type::SHAKE_N24_W2: + return {"SHAKE-256(192)", 2}; + case LMOTS_Algorithm_Type::SHAKE_N24_W4: + return {"SHAKE-256(192)", 4}; + case LMOTS_Algorithm_Type::SHAKE_N24_W8: + return {"SHAKE-256(192)", 8}; + case LMOTS_Algorithm_Type::RESERVED: + throw Decoding_Error("Unsupported LMS algorithm type"); + } + throw Decoding_Error("Unsupported LMS algorithm type"); + }(type); - throw Decoding_Error("Unsupported LM-OTS algorithm type"); + return LMOTS_Params(type, hash_name, w); } LMOTS_Params LMOTS_Params::create_or_throw(std::string_view hash_name, uint8_t w) { - BOTAN_ARG_CHECK(w == 1 || w == 2 || w == 4 || w == 8, "Invalid w value"); - auto type_offset = high_bit(w) - 1; - LMOTS_Algorithm_Type base_type; - - if(hash_name == "SHA-256") { - base_type = LMOTS_Algorithm_Type::SHA256_N32_W1; - } else if(hash_name == "Truncated(SHA-256,192)") { - base_type = LMOTS_Algorithm_Type::SHA256_N24_W1; - } else if(hash_name == "SHAKE-256(256)") { - base_type = LMOTS_Algorithm_Type::SHAKE_N32_W1; - } else if(hash_name == "SHAKE-256(192)") { - base_type = LMOTS_Algorithm_Type::SHAKE_N24_W1; - } else { - throw Decoding_Error("Unsupported hash function"); + if(w != 1 && w != 2 && w != 4 && w != 8) { + throw Decoding_Error("Invalid Winternitz parameter"); } - auto type = checked_cast_to(checked_cast_to(base_type) + type_offset); + LMOTS_Algorithm_Type type = [](std::string_view hash, uint8_t w_p) -> LMOTS_Algorithm_Type { + if(hash == "SHA-256") { + switch(w_p) { + case 1: + return LMOTS_Algorithm_Type::SHA256_N32_W1; + case 2: + return LMOTS_Algorithm_Type::SHA256_N32_W2; + case 4: + return LMOTS_Algorithm_Type::SHA256_N32_W4; + case 8: + return LMOTS_Algorithm_Type::SHA256_N32_W8; + default: + throw Decoding_Error("Unsupported Winternitz parameter"); + } + } + if(hash == "Truncated(SHA-256,192)") { + switch(w_p) { + case 1: + return LMOTS_Algorithm_Type::SHA256_N24_W1; + case 2: + return LMOTS_Algorithm_Type::SHA256_N24_W2; + case 4: + return LMOTS_Algorithm_Type::SHA256_N24_W4; + case 8: + return LMOTS_Algorithm_Type::SHA256_N24_W8; + default: + throw Decoding_Error("Unsupported Winternitz parameter"); + } + } + if(hash == "SHAKE-256(256)") { + switch(w_p) { + case 1: + return LMOTS_Algorithm_Type::SHAKE_N32_W1; + case 2: + return LMOTS_Algorithm_Type::SHAKE_N32_W2; + case 4: + return LMOTS_Algorithm_Type::SHAKE_N32_W4; + case 8: + return LMOTS_Algorithm_Type::SHAKE_N32_W8; + default: + throw Decoding_Error("Unsupported Winternitz parameter"); + } + } + if(hash == "SHAKE-256(192)") { + switch(w_p) { + case 1: + return LMOTS_Algorithm_Type::SHAKE_N24_W1; + case 2: + return LMOTS_Algorithm_Type::SHAKE_N24_W2; + case 4: + return LMOTS_Algorithm_Type::SHAKE_N24_W4; + case 8: + return LMOTS_Algorithm_Type::SHAKE_N24_W8; + default: + throw Decoding_Error("Unsupported Winternitz parameter"); + } + } + throw Decoding_Error("Unsupported hash function"); + }(hash_name, w); + return LMOTS_Params(type, hash_name, w); } @@ -171,7 +239,7 @@ LMOTS_Signature LMOTS_Signature::from_bytes_or_throw(BufferSlicer& slicer) { size_t total_remaining_bytes = slicer.remaining(); // Alg. 6a. 1. (last 4 bytes) / Alg. 4b. 1. if(total_remaining_bytes < sizeof(LMOTS_Algorithm_Type)) { - throw Decoding_Error("To few signature bytes while parsing LMOTS signature."); + throw Decoding_Error("Too few signature bytes while parsing LMOTS signature."); } // Alg. 6a. 2.b. / Alg. 4b. 2.a. auto algorithm_type = slicer.copy_be(); @@ -180,7 +248,7 @@ LMOTS_Signature LMOTS_Signature::from_bytes_or_throw(BufferSlicer& slicer) { LMOTS_Params params = LMOTS_Params::create_or_throw(algorithm_type); if(total_remaining_bytes < size(params)) { - throw Decoding_Error("To few signature bytes while parsing LMOTS signature."); + throw Decoding_Error("Too few signature bytes while parsing LMOTS signature."); } // Alg. 4b. 2.d. diff --git a/src/lib/pubkey/hss_lms/lms.cpp b/src/lib/pubkey/hss_lms/lms.cpp index 8462e7ee325..5f50609e10b 100644 --- a/src/lib/pubkey/hss_lms/lms.cpp +++ b/src/lib/pubkey/hss_lms/lms.cpp @@ -54,15 +54,12 @@ class TreeAddress final { uint32_t m_r; }; -std::function, const TreeAddress&, StrongSpan, StrongSpan)> -get_hash_pair_func_for_identifier(const LMS_Params& lms_params, LMS_Identifier identifier) { - // hash object must be shared, otherwise std::function would not be copyable, which is not allowed - std::shared_ptr hash = HashFunction::create_or_throw(lms_params.hash_name()); - return [hash, I = std::move(identifier)](StrongSpan out, - const TreeAddress& address, - StrongSpan left, - StrongSpan right) { +auto get_hash_pair_func_for_identifier(const LMS_Params& lms_params, LMS_Identifier identifier) { + return [hash = HashFunction::create_or_throw(lms_params.hash_name()), I = std::move(identifier)]( + StrongSpan out, + const TreeAddress& address, + StrongSpan left, + StrongSpan right) { auto lms_address = dynamic_cast(address); hash->update(I); @@ -85,11 +82,9 @@ void lms_gen_leaf(StrongSpan out, hash.final(out); } -std::function out, const TreeAddress& address)> lms_gen_leaf_func( - const LMS_PrivateKey& lms_sk) { - // hash object must be shared, otherwise std::function would not be copyable, which is not allowed - std::shared_ptr hash = HashFunction::create_or_throw(lms_sk.lms_params().hash_name()); - return [lms_sk, hash](StrongSpan out, const TreeAddress& tree_address) { +auto lms_gen_leaf_func(const LMS_PrivateKey& lms_sk) { + return [hash = HashFunction::create_or_throw(lms_sk.lms_params().hash_name()), lms_sk]( + StrongSpan out, const TreeAddress& tree_address) { auto lmots_sk = LMOTS_Private_Key(lms_sk.lmots_params(), lms_sk.identifier(), tree_address.q(), lms_sk.seed()); auto lmots_pk = LMOTS_Public_Key(lmots_sk); lms_gen_leaf(out, lmots_pk, tree_address, *hash); @@ -110,54 +105,138 @@ void lms_treehash(StrongSpan out_root, lms_sk.lms_params().m(), LMS_TreeLayerIndex(lms_sk.lms_params().h()), 0, - hash_pair_func, - gen_leaf, + std::move(hash_pair_func), + std::move(gen_leaf), lms_tree_address); } } // namespace LMS_Params LMS_Params::create_or_throw(LMS_Algorithm_Type type) { - uint8_t type_value = checked_cast_to_or_throw(type, "Unsupported LMS algorithm type"); - - if(type >= LMS_Algorithm_Type::SHA256_M32_H5 && type <= LMS_Algorithm_Type::SHA256_M32_H25) { - uint8_t h = 5 * (type_value - checked_cast_to(LMS_Algorithm_Type::SHA256_M32_H5) + 1); - return LMS_Params(type, "SHA-256", h); - } - if(type >= LMS_Algorithm_Type::SHA256_M24_H5 && type <= LMS_Algorithm_Type::SHA256_M24_H25) { - uint8_t h = 5 * (type_value - checked_cast_to(LMS_Algorithm_Type::SHA256_M24_H5) + 1); - return LMS_Params(type, "Truncated(SHA-256,192)", h); - } - if(type >= LMS_Algorithm_Type::SHAKE_M32_H5 && type <= LMS_Algorithm_Type::SHAKE_M32_H25) { - uint8_t h = 5 * (type_value - checked_cast_to(LMS_Algorithm_Type::SHAKE_M32_H5) + 1); - return LMS_Params(type, "SHAKE-256(256)", h); - } - if(type >= LMS_Algorithm_Type::SHAKE_M24_H5 && type <= LMS_Algorithm_Type::SHAKE_M24_H25) { - uint8_t h = 5 * (type_value - checked_cast_to(LMS_Algorithm_Type::SHAKE_M24_H5) + 1); - return LMS_Params(type, "SHAKE-256(192)", h); - } + auto [hash_name, height] = [](const LMS_Algorithm_Type& lms_type) -> std::pair { + switch(lms_type) { + case LMS_Algorithm_Type::SHA256_M32_H5: + return {"SHA-256", 5}; + case LMS_Algorithm_Type::SHA256_M32_H10: + return {"SHA-256", 10}; + case LMS_Algorithm_Type::SHA256_M32_H15: + return {"SHA-256", 15}; + case LMS_Algorithm_Type::SHA256_M32_H20: + return {"SHA-256", 20}; + case LMS_Algorithm_Type::SHA256_M32_H25: + return {"SHA-256", 25}; + case LMS_Algorithm_Type::SHA256_M24_H5: + return {"Truncated(SHA-256,192)", 5}; + case LMS_Algorithm_Type::SHA256_M24_H10: + return {"Truncated(SHA-256,192)", 10}; + case LMS_Algorithm_Type::SHA256_M24_H15: + return {"Truncated(SHA-256,192)", 15}; + case LMS_Algorithm_Type::SHA256_M24_H20: + return {"Truncated(SHA-256,192)", 20}; + case LMS_Algorithm_Type::SHA256_M24_H25: + return {"Truncated(SHA-256,192)", 25}; + case LMS_Algorithm_Type::SHAKE_M32_H5: + return {"SHAKE-256(256)", 5}; + case LMS_Algorithm_Type::SHAKE_M32_H10: + return {"SHAKE-256(256)", 10}; + case LMS_Algorithm_Type::SHAKE_M32_H15: + return {"SHAKE-256(256)", 15}; + case LMS_Algorithm_Type::SHAKE_M32_H20: + return {"SHAKE-256(256)", 20}; + case LMS_Algorithm_Type::SHAKE_M32_H25: + return {"SHAKE-256(256)", 25}; + case LMS_Algorithm_Type::SHAKE_M24_H5: + return {"SHAKE-256(192)", 5}; + case LMS_Algorithm_Type::SHAKE_M24_H10: + return {"SHAKE-256(192)", 10}; + case LMS_Algorithm_Type::SHAKE_M24_H15: + return {"SHAKE-256(192)", 15}; + case LMS_Algorithm_Type::SHAKE_M24_H20: + return {"SHAKE-256(192)", 20}; + case LMS_Algorithm_Type::SHAKE_M24_H25: + return {"SHAKE-256(192)", 25}; + case LMS_Algorithm_Type::RESERVED: + throw Decoding_Error("Unsupported LMS algorithm type"); + } + throw Decoding_Error("Unsupported LMS algorithm type"); + }(type); - throw Decoding_Error("Unsupported LMS algorithm type"); + return LMS_Params(type, hash_name, height); } -LMS_Params LMS_Params::create_or_throw(std::string_view hash_name, uint8_t h) { - BOTAN_ARG_CHECK(h == 5 || h == 10 || h == 15 || h == 20 || h == 25, "Invalid h value"); - auto type_offset = h / 5 - 1; - LMS_Algorithm_Type base_type; - - if(hash_name == "SHA-256") { - base_type = LMS_Algorithm_Type::SHA256_M32_H5; - } else if(hash_name == "Truncated(SHA-256,192)") { - base_type = LMS_Algorithm_Type::SHA256_M24_H5; - } else if(hash_name == "SHAKE-256(256)") { - base_type = LMS_Algorithm_Type::SHAKE_M32_H5; - } else if(hash_name == "SHAKE-256(192)") { - base_type = LMS_Algorithm_Type::SHAKE_M24_H5; - } else { - throw Decoding_Error("Unsupported hash function"); +LMS_Params LMS_Params::create_or_throw(std::string_view hash_name, uint8_t height) { + if(height != 5 && height != 10 && height != 15 && height != 20 && height != 25) { + throw Decoding_Error("Invalid height"); } - auto type = checked_cast_to(checked_cast_to(base_type) + type_offset); - return LMS_Params(type, hash_name, h); + LMS_Algorithm_Type type = [](std::string_view hash, uint8_t h) -> LMS_Algorithm_Type { + if(hash == "SHA-256") { + switch(h) { + case 5: + return LMS_Algorithm_Type::SHA256_M32_H5; + case 10: + return LMS_Algorithm_Type::SHA256_M32_H10; + case 15: + return LMS_Algorithm_Type::SHA256_M32_H15; + case 20: + return LMS_Algorithm_Type::SHA256_M32_H20; + case 25: + return LMS_Algorithm_Type::SHA256_M32_H25; + default: + throw Decoding_Error("Unsupported height for hash function"); + } + } + if(hash == "Truncated(SHA-256,192)") { + switch(h) { + case 5: + return LMS_Algorithm_Type::SHA256_M24_H5; + case 10: + return LMS_Algorithm_Type::SHA256_M24_H10; + case 15: + return LMS_Algorithm_Type::SHA256_M24_H15; + case 20: + return LMS_Algorithm_Type::SHA256_M24_H20; + case 25: + return LMS_Algorithm_Type::SHA256_M24_H25; + default: + throw Decoding_Error("Unsupported height for hash function"); + } + } + if(hash == "SHAKE-256(256)") { + switch(h) { + case 5: + return LMS_Algorithm_Type::SHAKE_M32_H5; + case 10: + return LMS_Algorithm_Type::SHAKE_M32_H10; + case 15: + return LMS_Algorithm_Type::SHAKE_M32_H15; + case 20: + return LMS_Algorithm_Type::SHAKE_M32_H20; + case 25: + return LMS_Algorithm_Type::SHAKE_M32_H25; + default: + throw Decoding_Error("Unsupported height for hash function"); + } + } + if(hash == "SHAKE-256(192)") { + switch(h) { + case 5: + return LMS_Algorithm_Type::SHAKE_M24_H5; + case 10: + return LMS_Algorithm_Type::SHAKE_M24_H10; + case 15: + return LMS_Algorithm_Type::SHAKE_M24_H15; + case 20: + return LMS_Algorithm_Type::SHAKE_M24_H20; + case 25: + return LMS_Algorithm_Type::SHAKE_M24_H25; + default: + throw Decoding_Error("Unsupported height for hash function"); + } + } + throw Decoding_Error("Unsupported hash function"); + }(hash_name, height); + + return LMS_Params(type, hash_name, height); } LMS_Params::LMS_Params(LMS_Algorithm_Type algorithm_type, std::string_view hash_name, uint8_t h) : @@ -188,11 +267,11 @@ LMS_PublicKey LMS_PrivateKey::sign_and_get_pk(StrongSpan ou return LMS_PublicKey(lms_params(), lmots_params(), identifier(), std::move(pk_buffer)); } -LMS_PublicKey LMS_PublicKey::from_bytes_of_throw(BufferSlicer& slicer) { +LMS_PublicKey LMS_PublicKey::from_bytes_or_throw(BufferSlicer& slicer) { size_t total_remaining_bytes = slicer.remaining(); // Alg. 6. 1. (4 bytes are sufficient until the next check) if(total_remaining_bytes < sizeof(LMS_Algorithm_Type)) { - throw Decoding_Error("To few bytes while parsing LMS public key."); + throw Decoding_Error("Too few bytes while parsing LMS public key."); } // Alg. 6. 2.a. auto lms_type = slicer.copy_be(); @@ -200,7 +279,7 @@ LMS_PublicKey LMS_PublicKey::from_bytes_of_throw(BufferSlicer& slicer) { auto lms_params = LMS_Params::create_or_throw(lms_type); // Alg. 6. 2.d. if(total_remaining_bytes < size(lms_params)) { - throw Decoding_Error("To few bytes while parsing LMS public key."); + throw Decoding_Error("Too few bytes while parsing LMS public key."); } // Alg. 6. 2.b. auto lmots_type = slicer.copy_be(); @@ -248,7 +327,7 @@ LMS_Signature LMS_Signature::from_bytes_or_throw(BufferSlicer& slicer) { size_t total_remaining_bytes = slicer.remaining(); // Alg. 6a 1. (next 4 bytes are checked in LMOTS_Signature::from_bytes_or_throw) if(total_remaining_bytes < sizeof(LMS_Tree_Node_Idx)) { - throw Decoding_Error("To few signature bytes while parsing LMS signature."); + throw Decoding_Error("Too few signature bytes while parsing LMS signature."); } // Alg. 6a 2.a. auto q = slicer.copy_be(); @@ -258,7 +337,7 @@ LMS_Signature LMS_Signature::from_bytes_or_throw(BufferSlicer& slicer) { LMOTS_Params lmots_params = LMOTS_Params::create_or_throw(lmots_sig.algorithm_type()); if(slicer.remaining() < sizeof(LMS_Algorithm_Type)) { - throw Decoding_Error("To few signature bytes while parsing LMS signature."); + throw Decoding_Error("Too few signature bytes while parsing LMS signature."); } // Alg. 6a 2.f. auto lms_type = slicer.copy_be(); @@ -266,13 +345,13 @@ LMS_Signature LMS_Signature::from_bytes_or_throw(BufferSlicer& slicer) { LMS_Params lms_params = LMS_Params::create_or_throw(lms_type); // Alg. 6a 2.i. (signature is not exactly [...] bytes long) if(total_remaining_bytes < size(lms_params, lmots_params)) { - throw Decoding_Error("To few signature bytes while parsing LMS signature."); + throw Decoding_Error("Too few signature bytes while parsing LMS signature."); } // Alg. 6a 2.j. - auto auth_path = slicer.take(lms_params.m() * lms_params.h()); + auto auth_path = slicer.copy(lms_params.m() * lms_params.h()); - return LMS_Signature(q, std::move(lmots_sig), lms_type, LMS_AuthenticationPath(auth_path)); + return LMS_Signature(q, std::move(lmots_sig), lms_type, std::move(auth_path)); } LMS_PublicKey::LMS_PublicKey(const LMS_PrivateKey& sk) : LMS_Instance(sk), m_lms_root(sk.lms_params().m()) { @@ -313,35 +392,37 @@ std::optional LMS_PublicKey::lms_compute_root_from_sig(const LMS_ lmots_params().algorithm_type() != sig.lmots_sig().algorithm_type()) { return std::nullopt; } - - const LMS_Params lms_params = LMS_Params::create_or_throw(sig.lms_type()); - const LMOTS_Signature& lmots_sig = sig.lmots_sig(); - const LMOTS_Params lmots_params = LMOTS_Params::create_or_throw(lmots_sig.algorithm_type()); - const LMOTS_K Kc = lmots_compute_pubkey_from_sig(lmots_sig, msg, identifier(), sig.q()); - const auto hash = HashFunction::create_or_throw(lms_params.hash_name()); - - auto hash_pair_func = get_hash_pair_func_for_identifier(lms_params, identifier()); - - auto lms_address = TreeAddress(lms_params.h()); - lms_address.set_address(LMS_TreeLayerIndex(0), LMS_Tree_Node_Idx(sig.q().get())); - - LMOTS_Public_Key pk_candidate(lmots_params, identifier(), sig.q(), Kc); - LMS_Tree_Node tmp(lms_params.m()); - lms_gen_leaf(tmp, pk_candidate, lms_address, *hash); - - LMS_Tree_Node root(lms_params.m()); - - compute_root(StrongSpan(root), - sig.auth_path(), - sig.q(), - StrongSpan(tmp), - lms_params.m(), - LMS_TreeLayerIndex(lms_params.h()), - 0, - hash_pair_func, - lms_address); - - return LMS_Tree_Node(root); + try { + const LMS_Params lms_params = LMS_Params::create_or_throw(sig.lms_type()); + const LMOTS_Signature& lmots_sig = sig.lmots_sig(); + const LMOTS_Params lmots_params = LMOTS_Params::create_or_throw(lmots_sig.algorithm_type()); + const LMOTS_K Kc = lmots_compute_pubkey_from_sig(lmots_sig, msg, identifier(), sig.q()); + const auto hash = HashFunction::create_or_throw(lms_params.hash_name()); + + auto hash_pair_func = get_hash_pair_func_for_identifier(lms_params, identifier()); + + auto lms_address = TreeAddress(lms_params.h()); + lms_address.set_address(LMS_TreeLayerIndex(0), LMS_Tree_Node_Idx(sig.q().get())); + + LMOTS_Public_Key pk_candidate(lmots_params, identifier(), sig.q(), Kc); + LMS_Tree_Node tmp(lms_params.m()); + lms_gen_leaf(tmp, pk_candidate, lms_address, *hash); + + LMS_Tree_Node root(lms_params.m()); + + compute_root(StrongSpan(root), + sig.auth_path(), + sig.q(), + StrongSpan(tmp), + lms_params.m(), + LMS_TreeLayerIndex(lms_params.h()), + 0, + std::move(hash_pair_func), + lms_address); + return LMS_Tree_Node(root); + } catch(const Decoding_Error&) { + return std::nullopt; + } } size_t LMS_Signature::size(const LMS_Params& lms_params, const LMOTS_Params& lmots_params) { diff --git a/src/lib/pubkey/hss_lms/lms.h b/src/lib/pubkey/hss_lms/lms.h index 2016a8bd1ab..468bcace615 100644 --- a/src/lib/pubkey/hss_lms/lms.h +++ b/src/lib/pubkey/hss_lms/lms.h @@ -225,7 +225,7 @@ class BOTAN_TEST_API LMS_PublicKey : public LMS_Instance { * @return The LMS public key. * @throws Decoding_Error If parsing the public key fails. */ - static LMS_PublicKey from_bytes_of_throw(BufferSlicer& slicer); + static LMS_PublicKey from_bytes_or_throw(BufferSlicer& slicer); /** * @brief Construct a public key for given public key data diff --git a/src/lib/utils/bit_ops.h b/src/lib/utils/bit_ops.h index 345c451eba9..3cb25367c1d 100644 --- a/src/lib/utils/bit_ops.h +++ b/src/lib/utils/bit_ops.h @@ -145,10 +145,8 @@ constexpr uint8_t ceil_log2(T x) * * @returns ceil(a/b) */ -template -inline constexpr T ceil_division(T a, T b) - requires(std::is_integral_v && std::is_unsigned_v) -{ +template +inline constexpr T ceil_division(T a, T b) { return (a + b - 1) / b; } diff --git a/src/lib/utils/concepts.h b/src/lib/utils/concepts.h index fa586dd22fc..e440750ee55 100644 --- a/src/lib/utils/concepts.h +++ b/src/lib/utils/concepts.h @@ -130,6 +130,12 @@ concept integral = std::is_integral_v; template concept enum_type = std::is_enum_v; +template +concept unsigned_type = std::unsigned_integral || // Normal uint + (concepts::strong_type && + std::unsigned_integral) || // Strong type on uint + (std::is_enum_v && std::unsigned_integral>); // Enum on uint + } // namespace concepts template diff --git a/src/lib/utils/stl_util.h b/src/lib/utils/stl_util.h index 6ddebbddae1..2aa8d7c17a5 100644 --- a/src/lib/utils/stl_util.h +++ b/src/lib/utils/stl_util.h @@ -173,12 +173,12 @@ class BufferSlicer final { std::copy(data.begin(), data.end(), sink.begin()); } - template + template auto copy_be() { return load_be(take(sizeof(T)).data(), 0); } - template + template auto copy_le() { return load_le(take(sizeof(T)).data(), 0); } @@ -236,9 +236,14 @@ class BufferStuffer { std::fill(sink.begin(), sink.end(), b); } - template - auto append_be(T value) { - return store_be(value, next(sizeof(T)).data()); + template + void append_be(T value) { + store_be(value, next(sizeof(T)).data()); + } + + template + void append_le(T value) { + store_le(value, next(sizeof(T)).data()); } bool full() const { return m_buffer.empty(); } diff --git a/src/lib/utils/tree_hash/info.txt b/src/lib/utils/tree_hash/info.txt index ddbcded3f24..3b20bbe7699 100644 --- a/src/lib/utils/tree_hash/info.txt +++ b/src/lib/utils/tree_hash/info.txt @@ -4,6 +4,8 @@ TREE_HASH -> 20231006 name -> "Tree Hash" +brief -> "Generic implementation of Merkle Tree Hashing" +type -> "Internal" diff --git a/src/lib/x509/x509_obj.cpp b/src/lib/x509/x509_obj.cpp index fcd4604325b..36c6e8d0b9d 100644 --- a/src/lib/x509/x509_obj.cpp +++ b/src/lib/x509/x509_obj.cpp @@ -173,6 +173,9 @@ std::string x509_signature_padding_for(const std::string& algo_name, } else if(algo_name == "XMSS") { // XMSS does not take any padding, but if the user insists, we pass it along return std::string(user_specified_padding); + } else if(algo_name == "HSS-LMS") { + // HSS-LMS does not take any padding, but if the user insists, we pass it along + return std::string(user_specified_padding); } else { throw Invalid_Argument("Unknown X.509 signing key type: " + algo_name); } diff --git a/src/tests/test_lms.cpp b/src/tests/test_lms.cpp index a70edb033f3..d90af100a06 100644 --- a/src/tests/test_lms.cpp +++ b/src/tests/test_lms.cpp @@ -42,7 +42,7 @@ class LMS_Test final : public Text_Based_Test { auto hash = Botan::HashFunction::create("SHA-256"); auto lms_pk_ref_slicer = Botan::BufferSlicer(pk_ref); - Botan::LMS_PublicKey lms_pk_ref = Botan::LMS_PublicKey::from_bytes_of_throw(lms_pk_ref_slicer); + Botan::LMS_PublicKey lms_pk_ref = Botan::LMS_PublicKey::from_bytes_or_throw(lms_pk_ref_slicer); // Test public key creation auto lms_sk = diff --git a/src/tests/unit_x509.cpp b/src/tests/unit_x509.cpp index 6bc3d3fc1ac..6f52f01b3d4 100644 --- a/src/tests/unit_x509.cpp +++ b/src/tests/unit_x509.cpp @@ -109,6 +109,9 @@ std::unique_ptr make_a_private_key(const std::string& algo) if(algo == "ECKCDSA" || algo == "ECGDSA") { return "brainpool256r1"; } + if(algo == "HSS-LMS") { + return "SHA-256,HW(5,4),HW(5,4)"; + } return ""; // default "" means choose acceptable algo-specific params }(); @@ -1144,7 +1147,7 @@ Test::Result test_valid_constraints(const Botan::Private_Key& key, const std::st result.test_eq("crl sign not permitted", crl_sign.compatible_with(key), false); result.test_eq("sign", sign_everything.compatible_with(key), false); } else if(pk_algo == "DSA" || pk_algo == "ECDSA" || pk_algo == "ECGDSA" || pk_algo == "ECKCDSA" || - pk_algo == "GOST-34.10" || pk_algo == "Dilithium") { + pk_algo == "GOST-34.10" || pk_algo == "Dilithium" || pk_algo == "HSS-LMS") { // these are signature algorithms only result.test_eq("all constraints not permitted", all.compatible_with(key), false); @@ -1437,6 +1440,8 @@ std::vector get_sig_paddings(const std::string& sig_algo, const std return {"Pure"}; } else if(sig_algo == "Dilithium") { return {"Randomized"}; + } else if(sig_algo == "HSS-LMS") { + return {""}; } else { return {}; } @@ -1448,7 +1453,7 @@ class X509_Cert_Unit_Tests final : public Test { std::vector results; const std::string sig_algos[]{ - "RSA", "DSA", "ECDSA", "ECGDSA", "ECKCDSA", "GOST-34.10", "Ed25519", "Dilithium"}; + "RSA", "DSA", "ECDSA", "ECGDSA", "ECKCDSA", "GOST-34.10", "Ed25519", "Dilithium", "HSS-LMS"}; for(const std::string& algo : sig_algos) { #if !defined(BOTAN_HAS_EMSA_PKCS1)