Skip to content

Commit

Permalink
Use span in PK operations and avoid needless secure_vector
Browse files Browse the repository at this point in the history
Fixes #4014
  • Loading branch information
randombit committed Jul 21, 2024
1 parent 5abbd7b commit c125c52
Show file tree
Hide file tree
Showing 31 changed files with 292 additions and 305 deletions.
14 changes: 6 additions & 8 deletions src/lib/prov/pkcs11/p11_ecdh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,18 @@ class PKCS11_ECDH_KA_Operation final : public PK_Ops::Key_Agreement {
/// The encoding in V2.20 was not specified and resulted in different implementations choosing different encodings.
/// Applications relying only on a V2.20 encoding (e.g. the DER variant) other than the one specified now (raw) may not work with all V2.30 compliant tokens.
secure_vector<uint8_t> agree(size_t key_len,
const uint8_t other_key[],
size_t other_key_len,
const uint8_t salt[],
size_t salt_len) override {
std::span<const uint8_t> other_key,
std::span<const uint8_t> salt) override {
std::vector<uint8_t> der_encoded_other_key;
if(m_key.point_encoding() == PublicPointEncoding::Der) {
DER_Encoder(der_encoded_other_key).encode(other_key, other_key_len, ASN1_Type::OctetString);
DER_Encoder(der_encoded_other_key).encode(other_key.data(), other_key.size(), ASN1_Type::OctetString);
m_mechanism.set_ecdh_other_key(der_encoded_other_key.data(), der_encoded_other_key.size());
} else {
m_mechanism.set_ecdh_other_key(other_key, other_key_len);
m_mechanism.set_ecdh_other_key(other_key.data(), other_key.size());
}

if(salt != nullptr && salt_len > 0) {
m_mechanism.set_ecdh_salt(salt, salt_len);
if(!salt.empty()) {
m_mechanism.set_ecdh_salt(salt.data(), salt.size());
}

ObjectHandle secret_handle = 0;
Expand Down
25 changes: 13 additions & 12 deletions src/lib/prov/pkcs11/p11_ecdsa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ class PKCS11_ECDSA_Signature_Operation final : public PK_Ops::Signature {
m_mechanism(MechanismWrapper::create_ecdsa_mechanism(hash)),
m_hash(hash) {}

void update(const uint8_t msg[], size_t msg_len) override {
void update(std::span<const uint8_t> input) override {
if(!m_initialized) {
// first call to update: initialize and cache message because we can not determine yet whether a single- or multiple-part operation will be performed
m_key.module()->C_SignInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());
m_initialized = true;
m_first_message = secure_vector<uint8_t>(msg, msg + msg_len);
m_first_message.assign(input.begin(), input.end());
return;
}

Expand All @@ -75,11 +75,11 @@ class PKCS11_ECDSA_Signature_Operation final : public PK_Ops::Signature {
m_first_message.clear();
}

m_key.module()->C_SignUpdate(m_key.session().handle(), msg, static_cast<Ulong>(msg_len));
m_key.module()->C_SignUpdate(m_key.session().handle(), input.data(), static_cast<Ulong>(input.size()));
}

secure_vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
secure_vector<uint8_t> signature;
std::vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
std::vector<uint8_t> signature;
if(!m_first_message.empty()) {
// single call to update: perform single-part operation
m_key.module()->C_Sign(m_key.session().handle(), m_first_message, signature);
Expand Down Expand Up @@ -121,12 +121,12 @@ class PKCS11_ECDSA_Verification_Operation final : public PK_Ops::Verification {
m_mechanism(MechanismWrapper::create_ecdsa_mechanism(hash)),
m_hash(hash) {}

void update(const uint8_t msg[], size_t msg_len) override {
void update(std::span<const uint8_t> input) override {
if(!m_initialized) {
// first call to update: initialize and cache message because we can not determine yet whether a single- or multiple-part operation will be performed
m_key.module()->C_VerifyInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());
m_initialized = true;
m_first_message = secure_vector<uint8_t>(msg, msg + msg_len);
m_first_message.assign(input.begin(), input.end());
return;
}

Expand All @@ -136,23 +136,24 @@ class PKCS11_ECDSA_Verification_Operation final : public PK_Ops::Verification {
m_first_message.clear();
}

m_key.module()->C_VerifyUpdate(m_key.session().handle(), msg, static_cast<Ulong>(msg_len));
m_key.module()->C_VerifyUpdate(m_key.session().handle(), input.data(), static_cast<Ulong>(input.size()));
}

bool is_valid_signature(const uint8_t sig[], size_t sig_len) override {
bool is_valid_signature(std::span<const uint8_t> sig) override {
ReturnValue return_value = ReturnValue::SignatureInvalid;
if(!m_first_message.empty()) {
// single call to update: perform single-part operation
m_key.module()->C_Verify(m_key.session().handle(),
m_first_message.data(),
static_cast<Ulong>(m_first_message.size()),
sig,
static_cast<Ulong>(sig_len),
sig.data(),
static_cast<Ulong>(sig.size()),
&return_value);
m_first_message.clear();
} else {
// multiple calls to update (or none): finish multiple-part operation
m_key.module()->C_VerifyFinal(m_key.session().handle(), sig, static_cast<Ulong>(sig_len), &return_value);
m_key.module()->C_VerifyFinal(
m_key.session().handle(), sig.data(), static_cast<Ulong>(sig.size()), &return_value);
}
m_initialized = false;
if(return_value != ReturnValue::OK && return_value != ReturnValue::SignatureInvalid) {
Expand Down
41 changes: 21 additions & 20 deletions src/lib/prov/pkcs11/p11_rsa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ class PKCS11_RSA_Decryption_Operation final : public PK_Ops::Decryption {

size_t plaintext_length(size_t /*ctext_len*/) const override { return m_key.get_n().bytes(); }

secure_vector<uint8_t> decrypt(uint8_t& valid_mask, const uint8_t ciphertext[], size_t ciphertext_len) override {
secure_vector<uint8_t> decrypt(uint8_t& valid_mask, std::span<const uint8_t> ctext) override {
valid_mask = 0;
m_key.module()->C_DecryptInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());

std::vector<uint8_t> encrypted_data(ciphertext, ciphertext + ciphertext_len);
std::vector<uint8_t> encrypted_data(ctext.begin(), ctext.end());

const size_t modulus_bytes = (m_key.get_n().bits() + 7) / 8;

Expand Down Expand Up @@ -173,8 +173,8 @@ class PKCS11_RSA_Decryption_Operation_Software_EME final : public PK_Ops::Decryp

size_t plaintext_length(size_t ctext_len) const override { return m_raw_decryptor.plaintext_length(ctext_len); }

secure_vector<uint8_t> raw_decrypt(const uint8_t input[], size_t input_len) override {
return m_raw_decryptor.decrypt(input, input_len);
secure_vector<uint8_t> raw_decrypt(std::span<const uint8_t> input) override {
return m_raw_decryptor.decrypt(input.data(), input.size());
}

private:
Expand All @@ -194,13 +194,13 @@ class PKCS11_RSA_Encryption_Operation final : public PK_Ops::Encryption {

size_t max_input_bits() const override { return m_bits; }

secure_vector<uint8_t> encrypt(const uint8_t msg[], size_t msg_len, RandomNumberGenerator& /*rng*/) override {
std::vector<uint8_t> encrypt(std::span<const uint8_t> input, RandomNumberGenerator& /*rng*/) override {
m_key.module()->C_EncryptInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());

secure_vector<uint8_t> encrytped_data;
std::vector<uint8_t> encrypted_data;
m_key.module()->C_Encrypt(
m_key.session().handle(), secure_vector<uint8_t>(msg, msg + msg_len), encrytped_data);
return encrytped_data;
m_key.session().handle(), secure_vector<uint8_t>(input.begin(), input.end()), encrypted_data);
return encrypted_data;
}

private:
Expand All @@ -216,12 +216,12 @@ class PKCS11_RSA_Signature_Operation final : public PK_Ops::Signature {

size_t signature_length() const override { return m_key.get_n().bytes(); }

void update(const uint8_t msg[], size_t msg_len) override {
void update(std::span<const uint8_t> input) override {
if(!m_initialized) {
// first call to update: initialize and cache message because we can not determine yet whether a single- or multiple-part operation will be performed
m_key.module()->C_SignInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());
m_initialized = true;
m_first_message = secure_vector<uint8_t>(msg, msg + msg_len);
m_first_message.assign(input.begin(), input.end());
return;
}

Expand All @@ -231,11 +231,11 @@ class PKCS11_RSA_Signature_Operation final : public PK_Ops::Signature {
m_first_message.clear();
}

m_key.module()->C_SignUpdate(m_key.session().handle(), msg, static_cast<Ulong>(msg_len));
m_key.module()->C_SignUpdate(m_key.session().handle(), input.data(), static_cast<Ulong>(input.size()));
}

secure_vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
secure_vector<uint8_t> signature;
std::vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
std::vector<uint8_t> signature;
if(!m_first_message.empty()) {
// single call to update: perform single-part operation
m_key.module()->C_Sign(m_key.session().handle(), m_first_message, signature);
Expand Down Expand Up @@ -331,12 +331,12 @@ class PKCS11_RSA_Verification_Operation final : public PK_Ops::Verification {
PKCS11_RSA_Verification_Operation(const PKCS11_RSA_PublicKey& key, std::string_view padding) :
m_key(key), m_mechanism(MechanismWrapper::create_rsa_sign_mechanism(padding)) {}

void update(const uint8_t msg[], size_t msg_len) override {
void update(std::span<const uint8_t> input) override {
if(!m_initialized) {
// first call to update: initialize and cache message because we can not determine yet whether a single- or multiple-part operation will be performed
m_key.module()->C_VerifyInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());
m_initialized = true;
m_first_message = secure_vector<uint8_t>(msg, msg + msg_len);
m_first_message.assign(input.begin(), input.end());
return;
}

Expand All @@ -346,23 +346,24 @@ class PKCS11_RSA_Verification_Operation final : public PK_Ops::Verification {
m_first_message.clear();
}

m_key.module()->C_VerifyUpdate(m_key.session().handle(), msg, static_cast<Ulong>(msg_len));
m_key.module()->C_VerifyUpdate(m_key.session().handle(), input.data(), static_cast<Ulong>(input.size()));
}

bool is_valid_signature(const uint8_t sig[], size_t sig_len) override {
bool is_valid_signature(std::span<const uint8_t> sig) override {
ReturnValue return_value = ReturnValue::SignatureInvalid;
if(!m_first_message.empty()) {
// single call to update: perform single-part operation
m_key.module()->C_Verify(m_key.session().handle(),
m_first_message.data(),
static_cast<Ulong>(m_first_message.size()),
sig,
static_cast<Ulong>(sig_len),
sig.data(),
static_cast<Ulong>(sig.size()),
&return_value);
m_first_message.clear();
} else {
// multiple calls to update (or none): finish multiple-part operation
m_key.module()->C_VerifyFinal(m_key.session().handle(), sig, static_cast<Ulong>(sig_len), &return_value);
m_key.module()->C_VerifyFinal(
m_key.session().handle(), sig.data(), static_cast<Ulong>(sig.size()), &return_value);
}
m_initialized = false;
if(return_value != ReturnValue::OK && return_value != ReturnValue::SignatureInvalid) {
Expand Down
6 changes: 3 additions & 3 deletions src/lib/prov/tpm/tpm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ class TPM_Signing_Operation final : public PK_Ops::Signature {

size_t signature_length() const override { return m_key.get_n().bytes(); }

void update(const uint8_t msg[], size_t msg_len) override { m_hash->update(msg, msg_len); }
void update(std::span<const uint8_t> msg) override { m_hash->update(msg); }

AlgorithmIdentifier algorithm_identifier() const override {
const std::string full_name = "RSA/EMSA3(" + m_hash->name() + ")";
const OID oid = OID::from_string(full_name);
return AlgorithmIdentifier(oid, AlgorithmIdentifier::USE_EMPTY_PARAM);
}

secure_vector<uint8_t> sign(RandomNumberGenerator&) override {
std::vector<uint8_t> sign(RandomNumberGenerator&) override {
/*
* v1.2 TPMs will only sign with PKCS #1 v1.5 padding. SHA-1 is built
* in, all other hash inputs (TSS_HASH_OTHER) are treated as the
Expand All @@ -352,7 +352,7 @@ class TPM_Signing_Operation final : public PK_Ops::Signature {
BYTE* sig_bytes = nullptr;
UINT32 sig_len = 0;
TSPI_CHECK_SUCCESS(::Tspi_Hash_Sign(tpm_hash, m_key.handle(), &sig_len, &sig_bytes));
secure_vector<uint8_t> sig(sig_bytes, sig_bytes + sig_len);
std::vector<uint8_t> sig(sig_bytes, sig_bytes + sig_len);

// TODO: RAII for Context_FreeMemory
TSPI_CHECK_SUCCESS(::Tspi_Context_FreeMemory(ctx, sig_bytes));
Expand Down
10 changes: 5 additions & 5 deletions src/lib/pubkey/curve448/ed448/ed448.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ class Ed448_Verify_Operation final : public PK_Ops::Verification {
}
}

void update(const uint8_t msg[], size_t msg_len) override { m_message->update({msg, msg_len}); }
void update(std::span<const uint8_t> input) override { m_message->update(input); }

bool is_valid_signature(const uint8_t sig[], size_t sig_len) override {
bool is_valid_signature(std::span<const uint8_t> sig) override {
const auto msg = m_message->get_and_clear();
try {
return verify_signature(m_pk, m_prehash_function.has_value(), {}, {sig, sig_len}, msg);
return verify_signature(m_pk, m_prehash_function.has_value(), {}, sig, msg);
} catch(Decoding_Error&) {
return false;
}
Expand Down Expand Up @@ -185,9 +185,9 @@ class Ed448_Sign_Operation final : public PK_Ops::Signature {
}
}

void update(const uint8_t msg[], size_t msg_len) override { m_message->update({msg, msg_len}); }
void update(std::span<const uint8_t> input) override { m_message->update(input); }

secure_vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
std::vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
BOTAN_ASSERT_NOMSG(m_sk.size() == ED448_LEN);
auto scope = CT::scoped_poison(m_sk);
const auto sig = sign_message(
Expand Down
10 changes: 5 additions & 5 deletions src/lib/pubkey/dilithium/dilithium_common/dilithium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class Dilithium_Signature_Operation final : public PK_Ops::Signature {
m_t0(ntt(m_priv_key->t0().clone())),
m_A(Dilithium_Algos::expand_A(m_priv_key->rho(), m_priv_key->mode())) {}

void update(const uint8_t msg[], size_t msg_len) override { m_h.update({msg, msg_len}); }
void update(std::span<const uint8_t> input) override { m_h.update(input); }

/**
* NIST FIPS 204 IPD, Algorithm 2 (ML-DSA.Sign)
Expand All @@ -203,7 +203,7 @@ class Dilithium_Signature_Operation final : public PK_Ops::Signature {
* s2 and t0 are done in the constructor of this class, as a 'signature
* operation' may be used to sign multiple messages.
*/
secure_vector<uint8_t> sign(RandomNumberGenerator& rng) override {
std::vector<uint8_t> sign(RandomNumberGenerator& rng) override {
auto scope = CT::scoped_poison(*m_priv_key);

const auto mu = m_h.final();
Expand Down Expand Up @@ -300,7 +300,7 @@ class Dilithium_Verification_Operation final : public PK_Ops::Verification {
m_t1_ntt_shifted(ntt(m_pub_key->t1() << DilithiumConstants::D)),
m_h(m_pub_key->mode().symmetric_primitives().get_message_hash(m_pub_key->tr())) {}

void update(const uint8_t msg[], size_t msg_len) override { m_h.update({msg, msg_len}); }
void update(std::span<const uint8_t> input) override { m_h.update(input); }

/**
* NIST FIPS 204 IPD, Algorithm 3 (ML-DSA.Verify)
Expand All @@ -309,10 +309,10 @@ class Dilithium_Verification_Operation final : public PK_Ops::Verification {
* matrix A is expanded from 'rho' in the constructor of this class, as
* a 'verification operation' may be used to verify multiple signatures.
*/
bool is_valid_signature(const uint8_t* sig, size_t sig_len) override {
bool is_valid_signature(std::span<const uint8_t> sig) override {
const auto& mode = m_pub_key->mode();
const auto& sympri = mode.symmetric_primitives();
StrongSpan<const DilithiumSerializedSignature> sig_bytes({sig, sig_len});
StrongSpan<const DilithiumSerializedSignature> sig_bytes(sig);

if(sig_bytes.size() != mode.signature_bytes()) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ using DilithiumHashedPublicKey = Strong<std::vector<uint8_t>, struct DilithiumHa
using DilithiumMessageRepresentative = Strong<std::vector<uint8_t>, struct DilithiumMessageRepresentative_>;

/// Serialized signature data
using DilithiumSerializedSignature = Strong<secure_vector<uint8_t>, struct DilithiumSerializedSignature_>;
using DilithiumSerializedSignature = Strong<std::vector<uint8_t>, struct DilithiumSerializedSignature_>;

/// Serialized representation of a commitment w1
using DilithiumSerializedCommitment = Strong<std::vector<uint8_t>, struct DilithiumSerializedCommitment_>;
Expand Down
Loading

0 comments on commit c125c52

Please sign in to comment.