Skip to content

Commit

Permalink
Apply review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
FAlbertDev authored and reneme committed Feb 16, 2024
1 parent 2e0c28b commit b6c3bc8
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 74 deletions.
2 changes: 1 addition & 1 deletion src/cli/speed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2133,7 +2133,7 @@ class Speed final : public Command {
"SHA-256,HW(10,1),HW(10,1)",
"SHA-256,HW(10,1),HW(10,1),HW(10,1)"};

for(auto params : hss_lms_instances) {
for(const auto& params : hss_lms_instances) {
auto keygen_timer = make_timer(params, provider, "keygen");

std::unique_ptr<Botan::Private_Key> key(
Expand Down
14 changes: 7 additions & 7 deletions src/lib/pubkey/hss_lms/hss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ secure_vector<uint8_t> HSS_LMS_PrivateKeyInternal::to_bytes() const {
stuffer.append(store_be(params.lms_params().algorithm_type()));
stuffer.append(store_be(params.lmots_params().algorithm_type()));
}
stuffer.append(seed());
stuffer.append(identifier());
stuffer.append(m_hss_seed);
stuffer.append(m_identifier);
BOTAN_ASSERT_NOMSG(stuffer.full());

return sk_bytes;
Expand All @@ -201,7 +201,7 @@ size_t HSS_LMS_PrivateKeyInternal::size() const {
size_t sk_size = sizeof(HSS_Level) + sizeof(HSS_Sig_Idx);
// The concatenated algorithm types for all layers
sk_size += hss_params().L().get() * (sizeof(LMS_Algorithm_Type) + sizeof(LMOTS_Algorithm_Type));
sk_size += seed().size() + identifier().size();
sk_size += m_hss_seed.size() + m_identifier.size();
return sk_size;
}

Expand Down Expand Up @@ -266,7 +266,7 @@ secure_vector<uint8_t> HSS_LMS_PrivateKeyInternal::sign(std::span<const uint8_t>

LMS_PrivateKey HSS_LMS_PrivateKeyInternal::hss_derive_root_lms_private_key() const {
auto& top_params = hss_params().params_at_level(HSS_Level(0));
return LMS_PrivateKey(top_params.lms_params(), top_params.lmots_params(), identifier(), seed());
return LMS_PrivateKey(top_params.lms_params(), top_params.lmots_params(), m_identifier, m_hss_seed);
}

LMS_PrivateKey HSS_LMS_PrivateKeyInternal::hss_derive_child_lms_private_key(
Expand Down Expand Up @@ -347,7 +347,7 @@ size_t HSS_LMS_PublicKeyInternal::size() const {
}

bool HSS_LMS_PublicKeyInternal::verify_signature(std::span<const uint8_t> msg, const HSS_Signature& sig) const {
if(checked_cast_to<HSS_Level>(sig.Nspk()) + 1 != L()) {
if(checked_cast_to<HSS_Level>(sig.Nspk()) + 1 != m_L) {
// HSS levels in the public key does not match with the signature's
return false;
}
Expand All @@ -356,7 +356,7 @@ bool HSS_LMS_PublicKeyInternal::verify_signature(std::span<const uint8_t> msg, c
const auto hash_name = lms_pk->lms_params().hash_name();

// Verify the signature by the above layer over the LMS public keys for layer 1 to Nspk.
for(uint8_t layer = 0; layer < sig.Nspk(); ++layer) {
for(HSS_Level layer(0); layer < sig.Nspk(); ++layer) {
const HSS_Signature::Signed_Pub_Key& signed_pub_key = sig.signed_pub_key(layer);
if(signed_pub_key.public_key().lms_params().hash_name() != hash_name ||
signed_pub_key.public_key().lmots_params().hash_name() != hash_name) {
Expand Down Expand Up @@ -399,7 +399,7 @@ HSS_Signature HSS_Signature::from_bytes_or_throw(std::span<const uint8_t> sig_by
if(!slicer.empty()) {
throw Decoding_Error("HSS-LMS signature contains more bytes than expected.");
}
return HSS_Signature(checked_cast_to<uint8_t>(Nspk), std::move(signed_pub_keys), std::move(sig));
return HSS_Signature(std::move(signed_pub_keys), std::move(sig));
}

size_t HSS_Signature::size(const HSS_LMS_Params& params) {
Expand Down
26 changes: 5 additions & 21 deletions src/lib/pubkey/hss_lms/hss.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace Botan {
using HSS_Sig_Idx = Strong<uint64_t, struct HSS_Sig_Idx_, EnableArithmeticWithPlainNumber>;

/**
* @brief The index of a node within a specific LMS tree layer
* @brief The HSS layer in the HSS multi tree starting at 0 from the root
*/
using HSS_Level = Strong<uint32_t, struct HSS_Level_, EnableArithmeticWithPlainNumber>;

Expand Down Expand Up @@ -191,16 +191,6 @@ class BOTAN_TEST_API HSS_LMS_PrivateKeyInternal final {
private:
HSS_LMS_PrivateKeyInternal(HSS_LMS_Params hss_params, LMS_Seed hss_seed, LMS_Identifier identifier);

/**
* @brief Returns the seed contained in the private key used for key derivation
*/
const LMS_Seed& seed() const { return m_hss_seed; }

/**
* @brief The identifier of the top level LMS tree
*/
const LMS_Identifier& identifier() const { return m_identifier; }

/**
* @brief Get the index of the next signature to generate and
* increase the counter by one.
Expand Down Expand Up @@ -305,11 +295,6 @@ class BOTAN_TEST_API HSS_LMS_PublicKeyInternal final {
bool verify_signature(std::span<const uint8_t> msg, const HSS_Signature& sig) const;

private:
/**
* @brief Returns the number of layers of LMS trees in the HSS tree.
*/
HSS_Level L() const { return m_L; }

HSS_Level m_L;
LMS_PublicKey m_top_lms_pub_key;
};
Expand Down Expand Up @@ -370,15 +355,15 @@ class BOTAN_TEST_API HSS_Signature final {
/**
* @brief Returns the number of signed public keys (Nspk = L-1).
*/
uint8_t Nspk() const { return m_Nspk; }
HSS_Level Nspk() const { return HSS_Level(static_cast<uint32_t>(m_signed_pub_keys.size())); }

/**
* @brief Returns the signed LMS key signed by a specific layer.
*
* @param layer The layer by which the LMS key is signed.
* @return The LMS key and the signature by its parent layer.
*/
const Signed_Pub_Key& signed_pub_key(uint8_t layer) const { return m_signed_pub_keys.at(layer); }
const Signed_Pub_Key& signed_pub_key(HSS_Level layer) const { return m_signed_pub_keys.at(layer.get()); }

/**
* @brief Returns the LMS signature by the bottom layer of the signed message.
Expand All @@ -389,10 +374,9 @@ class BOTAN_TEST_API HSS_Signature final {
/**
* @brief Private constructor using the individual signature fields.
*/
HSS_Signature(uint8_t Nspk, std::vector<Signed_Pub_Key> signed_pub_keys, LMS_Signature sig) :
m_Nspk(Nspk), m_signed_pub_keys(std::move(signed_pub_keys)), m_sig(std::move(sig)) {}
HSS_Signature(std::vector<Signed_Pub_Key> signed_pub_keys, LMS_Signature sig) :
m_signed_pub_keys(std::move(signed_pub_keys)), m_sig(std::move(sig)) {}

uint8_t m_Nspk;
std::vector<Signed_Pub_Key> m_signed_pub_keys;
LMS_Signature m_sig;
};
Expand Down
9 changes: 3 additions & 6 deletions src/lib/pubkey/hss_lms/hss_lms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ class HSS_LMS_Verification_Operation final : public PK_Ops::Verification {
}

bool is_valid_signature(const uint8_t* sig, size_t sig_len) override {
std::vector<uint8_t> message_to_verify = std::move(m_msg_buffer);
m_msg_buffer.clear();
std::vector<uint8_t> message_to_verify = std::exchange(m_msg_buffer, {});
try {
const auto signature = HSS_Signature::from_bytes_or_throw({sig, sig_len});
bool sig_valid = m_public->verify_signature(message_to_verify, signature);
return sig_valid;
return m_public->verify_signature(message_to_verify, signature);
} catch(const Decoding_Error&) {
// Signature could not be decoded
return false;
Expand Down Expand Up @@ -167,8 +165,7 @@ class HSS_LMS_Signature_Operation final : public PK_Ops::Signature {
}

secure_vector<uint8_t> sign(RandomNumberGenerator&) override {
std::vector<uint8_t> message_to_sign = std::move(m_msg_buffer);
m_msg_buffer.clear();
std::vector<uint8_t> message_to_sign = std::exchange(m_msg_buffer, {});
return m_private->sign(message_to_sign);
}

Expand Down
12 changes: 8 additions & 4 deletions src/lib/pubkey/hss_lms/hss_lms_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
#include <botan/internal/stl_util.h>

namespace Botan {
PseudorandomKeyGeneration::PseudorandomKeyGeneration(std::span<const uint8_t> identifier) {
m_input_buffer.resize(identifier.size() + sizeof(uint32_t) + sizeof(uint16_t) + sizeof(uint8_t));
BufferStuffer input_stuffer(m_input_buffer);
input_stuffer.append(identifier);
PseudorandomKeyGeneration::PseudorandomKeyGeneration(std::span<const uint8_t> identifier) :
m_input_buffer(identifier.size() + sizeof(uint32_t) + sizeof(uint16_t) + sizeof(uint8_t)),
m_q(m_input_buffer.data() + identifier.size(), sizeof(uint32_t)),
m_i(m_input_buffer.data() + identifier.size() + sizeof(uint32_t), sizeof(uint16_t)),
m_j(m_input_buffer.data() + identifier.size() + sizeof(uint32_t) + sizeof(uint16_t), sizeof(uint8_t))

{
copy_mem(m_input_buffer.data(), identifier.data(), identifier.size());
}

void PseudorandomKeyGeneration::gen(std::span<uint8_t> out, HashFunction& hash, std::span<const uint8_t> seed) const {
Expand Down
10 changes: 7 additions & 3 deletions src/lib/pubkey/hss_lms/hss_lms_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ class PseudorandomKeyGeneration {
/**
* @brief Specify the value for the u32str(q) hash input field
*/
void set_q(uint32_t q) { store_be(q, std::span(m_input_buffer).last<7>().first<4>().data()); }
void set_q(uint32_t q) { store_be(q, m_q.data()); }

/**
* @brief Specify the value for the u16str(i) hash input field
*/
void set_i(uint16_t i) { store_be(i, std::span(m_input_buffer).last<3>().first<2>().data()); }
void set_i(uint16_t i) { store_be(i, m_i.data()); }

/**
* @brief Specify the value for the u8str(j) hash input field
*/
void set_j(uint8_t j) { m_input_buffer.back() = j; }
void set_j(uint8_t j) { m_j[0] = j; }

/**
* @brief Create a hash value using the preconfigured prefix and a @p seed
Expand All @@ -66,6 +66,10 @@ class PseudorandomKeyGeneration {
private:
/// Input buffer containing the prefix: 'identifier || u32str(q) || u16str(i) || u8str(j)'
std::vector<uint8_t> m_input_buffer;

std::span<uint8_t, sizeof(uint32_t)> m_q;
std::span<uint8_t, sizeof(uint16_t)> m_i;
std::span<uint8_t, sizeof(uint8_t)> m_j;
};

} // namespace Botan
Expand Down
2 changes: 1 addition & 1 deletion src/lib/pubkey/hss_lms/info.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ HSS_LMS -> 20230925
</defines>

<module_info>
name -> "HSS_LMS"
name -> "HSS-LMS"
</module_info>

<header:public>
Expand Down
21 changes: 7 additions & 14 deletions src/lib/pubkey/hss_lms/lm_ots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,9 @@ class Chain_Generator {
std::span<uint8_t> out) {
BOTAN_ARG_CHECK(start <= end, "Start value is bigger than end value");

if(start == end) {
copy_into(out, in);
return;
}
copy_into(out, in);
m_gen.set_i(chain_idx);

// Unroll first iteration of the loop
m_gen.set_j(start++);
m_gen.gen(out, hash, in);

for(uint8_t j = start; j < end; ++j) {
m_gen.set_j(j);
m_gen.gen(out, hash, out);
Expand Down Expand Up @@ -86,8 +79,8 @@ std::vector<uint8_t> gen_Q_with_cksm(const LMOTS_Params& params,
BufferStuffer qwc_stuffer(Q_with_cksm);
const auto hash = HashFunction::create_or_throw(params.hash_name());
hash->update(identifier);
hash->update_be(q.get());
hash->update_be(D_MESG);
hash->update(store_be(q));
hash->update(store_be(D_MESG));
hash->update(C);
hash->update(msg);
auto Q_span = qwc_stuffer.next(params.n());
Expand Down Expand Up @@ -312,8 +305,8 @@ void LMOTS_Private_Key::derive_random_C(std::span<uint8_t> out, HashFunction& ha
LMOTS_Public_Key::LMOTS_Public_Key(const LMOTS_Private_Key& lmots_sk) : OTS_Instance(lmots_sk) {
const auto pk_hash = HashFunction::create_or_throw(lmots_sk.params().hash_name());
pk_hash->update(lmots_sk.identifier());
pk_hash->update_be(lmots_sk.q().get());
pk_hash->update_be(D_PBLC);
pk_hash->update(store_be(lmots_sk.q()));
pk_hash->update(store_be(D_PBLC));

Chain_Generator chain_gen(lmots_sk.identifier(), lmots_sk.q());
const auto hash = HashFunction::create_or_throw(lmots_sk.params().hash_name());
Expand All @@ -339,8 +332,8 @@ LMOTS_K lmots_compute_pubkey_from_sig(const LMOTS_Signature& sig,
// Prefill the final hash object
const auto pk_hash = HashFunction::create_or_throw(params.hash_name());
pk_hash->update(identifier);
pk_hash->update_be(q.get());
pk_hash->update_be(D_PBLC);
pk_hash->update(store_be(q));
pk_hash->update(store_be(D_PBLC));

Chain_Generator chain_gen(identifier, q);
const auto hash = HashFunction::create_or_throw(params.hash_name());
Expand Down
8 changes: 4 additions & 4 deletions src/lib/pubkey/hss_lms/lms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ auto get_hash_pair_func_for_identifier(const LMS_Params& lms_params, LMS_Identif
auto lms_address = dynamic_cast<const TreeAddress&>(address);

hash->update(I);
hash->update_be(lms_address.r());
hash->update_be(D_INTR);
hash->update(store_be(lms_address.r()));
hash->update(store_be(D_INTR));
hash->update(left);
hash->update(right);
hash->final(out);
Expand All @@ -76,8 +76,8 @@ void lms_gen_leaf(StrongSpan<LMS_Tree_Node> out,
const TreeAddress& tree_address,
HashFunction& hash) {
hash.update(lmots_pk.identifier());
hash.update_be(tree_address.r());
hash.update_be(D_LEAF);
hash.update(store_be(tree_address.r()));
hash.update(store_be(D_LEAF));
hash.update(lmots_pk.K());
hash.final(out);
}
Expand Down
19 changes: 7 additions & 12 deletions src/lib/utils/tree_hash/tree_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,13 @@ inline void treehash(
BOTAN_ASSERT_NOMSG(out_root.size() == node_size);
BOTAN_ASSERT(out_auth_path.has_value() == leaf_idx.has_value(),
"Both leaf index and auth path buffer is given or neither.");
bool is_signing = leaf_idx.has_value();
if(is_signing) {
BOTAN_ASSERT_NOMSG(out_auth_path.value().size() == node_size * total_tree_height.get());
}
const bool is_signing = leaf_idx.has_value();
BOTAN_ASSERT_NOMSG(!is_signing || out_auth_path.value().size() == node_size * total_tree_height.get());

const TreeNodeIndex max_idx(uint32_t((1 << total_tree_height.get()) - 1));

std::vector<uint8_t> stack(total_tree_height.get() * node_size);
std::vector<TreeNode> last_visited_left_child_at_layer(total_tree_height.get(), TreeNode(node_size));

TreeNode current_node(node_size); // Current logical node

// Traverse the tree from the left-most leaf, matching siblings and up until
Expand Down Expand Up @@ -162,18 +161,14 @@ inline void treehash(
copy_into(auth_path_location, current_node);
}

// At this point we know that we'll need to use the stack. Get a
// reference to the correct location.
auto stack_location = StrongSpan<TreeNode>(std::span(stack).subspan(h.get() * node_size, node_size));

// Check if we're at a left child; if so, stop going up the stack
// Check if we're at a left child; if so, stop going up the tree
// Exception: if we've reached the end of the tree, keep on going (so
// we combine the last 4 nodes into the one root node in two more
// iterations)
if((internal_idx & 1) == 0U && idx < max_idx) {
// We've hit a left child; save the current for when we get the
// corresponding right child.
copy_into(stack_location, current_node);
copy_into(last_visited_left_child_at_layer.at(h.get()), current_node);
break;
}

Expand All @@ -184,7 +179,7 @@ inline void treehash(
internal_idx_offset /= 2;
tree_address.set_address(h + 1, internal_idx / 2 + internal_idx_offset);

node_pair_hash(current_node, tree_address, stack_location, current_node);
node_pair_hash(current_node, tree_address, last_visited_left_child_at_layer.at(h.get()), current_node);

internal_idx /= 2;
if(internal_leaf.has_value()) {
Expand Down
2 changes: 1 addition & 1 deletion src/tests/test_lmots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ BOTAN_REGISTER_TEST("pubkey", "lmots", LMOTS_Test);
} // namespace
} // namespace Botan_Tests

#endif // BOTAN_HAS_HSS_LMS
#endif // BOTAN_HAS_HSS_LMS

0 comments on commit b6c3bc8

Please sign in to comment.