Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use span in PK operations and avoid needless secure_vector #4239

Merged
merged 1 commit into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/lib/prov/pkcs11/p11_ecdh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,18 @@ class PKCS11_ECDH_KA_Operation final : public PK_Ops::Key_Agreement {
/// The encoding in V2.20 was not specified and resulted in different implementations choosing different encodings.
/// Applications relying only on a V2.20 encoding (e.g. the DER variant) other than the one specified now (raw) may not work with all V2.30 compliant tokens.
secure_vector<uint8_t> agree(size_t key_len,
const uint8_t other_key[],
size_t other_key_len,
const uint8_t salt[],
size_t salt_len) override {
std::span<const uint8_t> other_key,
std::span<const uint8_t> salt) override {
std::vector<uint8_t> der_encoded_other_key;
if(m_key.point_encoding() == PublicPointEncoding::Der) {
DER_Encoder(der_encoded_other_key).encode(other_key, other_key_len, ASN1_Type::OctetString);
DER_Encoder(der_encoded_other_key).encode(other_key.data(), other_key.size(), ASN1_Type::OctetString);
m_mechanism.set_ecdh_other_key(der_encoded_other_key.data(), der_encoded_other_key.size());
} else {
m_mechanism.set_ecdh_other_key(other_key, other_key_len);
m_mechanism.set_ecdh_other_key(other_key.data(), other_key.size());
}

if(salt != nullptr && salt_len > 0) {
m_mechanism.set_ecdh_salt(salt, salt_len);
if(!salt.empty()) {
m_mechanism.set_ecdh_salt(salt.data(), salt.size());
}

ObjectHandle secret_handle = 0;
Expand Down
25 changes: 13 additions & 12 deletions src/lib/prov/pkcs11/p11_ecdsa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,12 @@ class PKCS11_ECDSA_Signature_Operation final : public PK_Ops::Signature {
m_mechanism(MechanismWrapper::create_ecdsa_mechanism(hash)),
m_hash(hash) {}

void update(const uint8_t msg[], size_t msg_len) override {
void update(std::span<const uint8_t> input) override {
if(!m_initialized) {
// first call to update: initialize and cache message because we can not determine yet whether a single- or multiple-part operation will be performed
m_key.module()->C_SignInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());
m_initialized = true;
m_first_message = secure_vector<uint8_t>(msg, msg + msg_len);
m_first_message.assign(input.begin(), input.end());
return;
}

Expand All @@ -75,11 +75,11 @@ class PKCS11_ECDSA_Signature_Operation final : public PK_Ops::Signature {
m_first_message.clear();
}

m_key.module()->C_SignUpdate(m_key.session().handle(), msg, static_cast<Ulong>(msg_len));
m_key.module()->C_SignUpdate(m_key.session().handle(), input.data(), static_cast<Ulong>(input.size()));
}

secure_vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
secure_vector<uint8_t> signature;
std::vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
std::vector<uint8_t> signature;
if(!m_first_message.empty()) {
// single call to update: perform single-part operation
m_key.module()->C_Sign(m_key.session().handle(), m_first_message, signature);
Expand Down Expand Up @@ -121,12 +121,12 @@ class PKCS11_ECDSA_Verification_Operation final : public PK_Ops::Verification {
m_mechanism(MechanismWrapper::create_ecdsa_mechanism(hash)),
m_hash(hash) {}

void update(const uint8_t msg[], size_t msg_len) override {
void update(std::span<const uint8_t> input) override {
if(!m_initialized) {
// first call to update: initialize and cache message because we can not determine yet whether a single- or multiple-part operation will be performed
m_key.module()->C_VerifyInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());
m_initialized = true;
m_first_message = secure_vector<uint8_t>(msg, msg + msg_len);
m_first_message.assign(input.begin(), input.end());
return;
}

Expand All @@ -136,23 +136,24 @@ class PKCS11_ECDSA_Verification_Operation final : public PK_Ops::Verification {
m_first_message.clear();
}

m_key.module()->C_VerifyUpdate(m_key.session().handle(), msg, static_cast<Ulong>(msg_len));
m_key.module()->C_VerifyUpdate(m_key.session().handle(), input.data(), static_cast<Ulong>(input.size()));
}

bool is_valid_signature(const uint8_t sig[], size_t sig_len) override {
bool is_valid_signature(std::span<const uint8_t> sig) override {
ReturnValue return_value = ReturnValue::SignatureInvalid;
if(!m_first_message.empty()) {
// single call to update: perform single-part operation
m_key.module()->C_Verify(m_key.session().handle(),
m_first_message.data(),
static_cast<Ulong>(m_first_message.size()),
sig,
static_cast<Ulong>(sig_len),
sig.data(),
static_cast<Ulong>(sig.size()),
&return_value);
m_first_message.clear();
} else {
// multiple calls to update (or none): finish multiple-part operation
m_key.module()->C_VerifyFinal(m_key.session().handle(), sig, static_cast<Ulong>(sig_len), &return_value);
m_key.module()->C_VerifyFinal(
m_key.session().handle(), sig.data(), static_cast<Ulong>(sig.size()), &return_value);
}
m_initialized = false;
if(return_value != ReturnValue::OK && return_value != ReturnValue::SignatureInvalid) {
Expand Down
41 changes: 21 additions & 20 deletions src/lib/prov/pkcs11/p11_rsa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ class PKCS11_RSA_Decryption_Operation final : public PK_Ops::Decryption {

size_t plaintext_length(size_t /*ctext_len*/) const override { return m_key.get_n().bytes(); }

secure_vector<uint8_t> decrypt(uint8_t& valid_mask, const uint8_t ciphertext[], size_t ciphertext_len) override {
secure_vector<uint8_t> decrypt(uint8_t& valid_mask, std::span<const uint8_t> ctext) override {
valid_mask = 0;
m_key.module()->C_DecryptInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());

std::vector<uint8_t> encrypted_data(ciphertext, ciphertext + ciphertext_len);
std::vector<uint8_t> encrypted_data(ctext.begin(), ctext.end());

const size_t modulus_bytes = (m_key.get_n().bits() + 7) / 8;

Expand Down Expand Up @@ -173,8 +173,8 @@ class PKCS11_RSA_Decryption_Operation_Software_EME final : public PK_Ops::Decryp

size_t plaintext_length(size_t ctext_len) const override { return m_raw_decryptor.plaintext_length(ctext_len); }

secure_vector<uint8_t> raw_decrypt(const uint8_t input[], size_t input_len) override {
return m_raw_decryptor.decrypt(input, input_len);
secure_vector<uint8_t> raw_decrypt(std::span<const uint8_t> input) override {
return m_raw_decryptor.decrypt(input);
}

private:
Expand All @@ -194,13 +194,13 @@ class PKCS11_RSA_Encryption_Operation final : public PK_Ops::Encryption {

size_t max_input_bits() const override { return m_bits; }

secure_vector<uint8_t> encrypt(const uint8_t msg[], size_t msg_len, RandomNumberGenerator& /*rng*/) override {
std::vector<uint8_t> encrypt(std::span<const uint8_t> input, RandomNumberGenerator& /*rng*/) override {
m_key.module()->C_EncryptInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());

secure_vector<uint8_t> encrytped_data;
std::vector<uint8_t> encrypted_data;
m_key.module()->C_Encrypt(
m_key.session().handle(), secure_vector<uint8_t>(msg, msg + msg_len), encrytped_data);
return encrytped_data;
m_key.session().handle(), secure_vector<uint8_t>(input.begin(), input.end()), encrypted_data);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to future us: PKCS11 is also in need of a std::span treatment. This secure_vector is allocated and then just used for its C-style pointer/length.

return encrypted_data;
}

private:
Expand All @@ -216,12 +216,12 @@ class PKCS11_RSA_Signature_Operation final : public PK_Ops::Signature {

size_t signature_length() const override { return m_key.get_n().bytes(); }

void update(const uint8_t msg[], size_t msg_len) override {
void update(std::span<const uint8_t> input) override {
if(!m_initialized) {
// first call to update: initialize and cache message because we can not determine yet whether a single- or multiple-part operation will be performed
m_key.module()->C_SignInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());
m_initialized = true;
m_first_message = secure_vector<uint8_t>(msg, msg + msg_len);
m_first_message.assign(input.begin(), input.end());
return;
}

Expand All @@ -231,11 +231,11 @@ class PKCS11_RSA_Signature_Operation final : public PK_Ops::Signature {
m_first_message.clear();
}

m_key.module()->C_SignUpdate(m_key.session().handle(), msg, static_cast<Ulong>(msg_len));
m_key.module()->C_SignUpdate(m_key.session().handle(), input.data(), static_cast<Ulong>(input.size()));
}

secure_vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
secure_vector<uint8_t> signature;
std::vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
std::vector<uint8_t> signature;
if(!m_first_message.empty()) {
// single call to update: perform single-part operation
m_key.module()->C_Sign(m_key.session().handle(), m_first_message, signature);
Expand Down Expand Up @@ -331,12 +331,12 @@ class PKCS11_RSA_Verification_Operation final : public PK_Ops::Verification {
PKCS11_RSA_Verification_Operation(const PKCS11_RSA_PublicKey& key, std::string_view padding) :
m_key(key), m_mechanism(MechanismWrapper::create_rsa_sign_mechanism(padding)) {}

void update(const uint8_t msg[], size_t msg_len) override {
void update(std::span<const uint8_t> input) override {
if(!m_initialized) {
// first call to update: initialize and cache message because we can not determine yet whether a single- or multiple-part operation will be performed
m_key.module()->C_VerifyInit(m_key.session().handle(), m_mechanism.data(), m_key.handle());
m_initialized = true;
m_first_message = secure_vector<uint8_t>(msg, msg + msg_len);
m_first_message.assign(input.begin(), input.end());
return;
}

Expand All @@ -346,23 +346,24 @@ class PKCS11_RSA_Verification_Operation final : public PK_Ops::Verification {
m_first_message.clear();
}

m_key.module()->C_VerifyUpdate(m_key.session().handle(), msg, static_cast<Ulong>(msg_len));
m_key.module()->C_VerifyUpdate(m_key.session().handle(), input.data(), static_cast<Ulong>(input.size()));
}

bool is_valid_signature(const uint8_t sig[], size_t sig_len) override {
bool is_valid_signature(std::span<const uint8_t> sig) override {
ReturnValue return_value = ReturnValue::SignatureInvalid;
if(!m_first_message.empty()) {
// single call to update: perform single-part operation
m_key.module()->C_Verify(m_key.session().handle(),
m_first_message.data(),
static_cast<Ulong>(m_first_message.size()),
sig,
static_cast<Ulong>(sig_len),
sig.data(),
static_cast<Ulong>(sig.size()),
&return_value);
m_first_message.clear();
} else {
// multiple calls to update (or none): finish multiple-part operation
m_key.module()->C_VerifyFinal(m_key.session().handle(), sig, static_cast<Ulong>(sig_len), &return_value);
m_key.module()->C_VerifyFinal(
m_key.session().handle(), sig.data(), static_cast<Ulong>(sig.size()), &return_value);
}
m_initialized = false;
if(return_value != ReturnValue::OK && return_value != ReturnValue::SignatureInvalid) {
Expand Down
6 changes: 3 additions & 3 deletions src/lib/prov/tpm/tpm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ class TPM_Signing_Operation final : public PK_Ops::Signature {

size_t signature_length() const override { return m_key.get_n().bytes(); }

void update(const uint8_t msg[], size_t msg_len) override { m_hash->update(msg, msg_len); }
void update(std::span<const uint8_t> msg) override { m_hash->update(msg); }

AlgorithmIdentifier algorithm_identifier() const override {
const std::string full_name = "RSA/EMSA3(" + m_hash->name() + ")";
const OID oid = OID::from_string(full_name);
return AlgorithmIdentifier(oid, AlgorithmIdentifier::USE_EMPTY_PARAM);
}

secure_vector<uint8_t> sign(RandomNumberGenerator&) override {
std::vector<uint8_t> sign(RandomNumberGenerator&) override {
/*
* v1.2 TPMs will only sign with PKCS #1 v1.5 padding. SHA-1 is built
* in, all other hash inputs (TSS_HASH_OTHER) are treated as the
Expand All @@ -352,7 +352,7 @@ class TPM_Signing_Operation final : public PK_Ops::Signature {
BYTE* sig_bytes = nullptr;
UINT32 sig_len = 0;
TSPI_CHECK_SUCCESS(::Tspi_Hash_Sign(tpm_hash, m_key.handle(), &sig_len, &sig_bytes));
secure_vector<uint8_t> sig(sig_bytes, sig_bytes + sig_len);
std::vector<uint8_t> sig(sig_bytes, sig_bytes + sig_len);

// TODO: RAII for Context_FreeMemory
TSPI_CHECK_SUCCESS(::Tspi_Context_FreeMemory(ctx, sig_bytes));
Expand Down
10 changes: 5 additions & 5 deletions src/lib/pubkey/curve448/ed448/ed448.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ class Ed448_Verify_Operation final : public PK_Ops::Verification {
}
}

void update(const uint8_t msg[], size_t msg_len) override { m_message->update({msg, msg_len}); }
void update(std::span<const uint8_t> input) override { m_message->update(input); }

bool is_valid_signature(const uint8_t sig[], size_t sig_len) override {
bool is_valid_signature(std::span<const uint8_t> sig) override {
const auto msg = m_message->get_and_clear();
try {
return verify_signature(m_pk, m_prehash_function.has_value(), {}, {sig, sig_len}, msg);
return verify_signature(m_pk, m_prehash_function.has_value(), {}, sig, msg);
} catch(Decoding_Error&) {
return false;
}
Expand Down Expand Up @@ -185,9 +185,9 @@ class Ed448_Sign_Operation final : public PK_Ops::Signature {
}
}

void update(const uint8_t msg[], size_t msg_len) override { m_message->update({msg, msg_len}); }
void update(std::span<const uint8_t> input) override { m_message->update(input); }

secure_vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
std::vector<uint8_t> sign(RandomNumberGenerator& /*rng*/) override {
BOTAN_ASSERT_NOMSG(m_sk.size() == ED448_LEN);
auto scope = CT::scoped_poison(m_sk);
const auto sig = sign_message(
Expand Down
10 changes: 5 additions & 5 deletions src/lib/pubkey/dilithium/dilithium_common/dilithium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class Dilithium_Signature_Operation final : public PK_Ops::Signature {
m_t0(ntt(m_priv_key->t0().clone())),
m_A(Dilithium_Algos::expand_A(m_priv_key->rho(), m_priv_key->mode())) {}

void update(const uint8_t msg[], size_t msg_len) override { m_h.update({msg, msg_len}); }
void update(std::span<const uint8_t> input) override { m_h.update(input); }

/**
* NIST FIPS 204 IPD, Algorithm 2 (ML-DSA.Sign)
Expand All @@ -203,7 +203,7 @@ class Dilithium_Signature_Operation final : public PK_Ops::Signature {
* s2 and t0 are done in the constructor of this class, as a 'signature
* operation' may be used to sign multiple messages.
*/
secure_vector<uint8_t> sign(RandomNumberGenerator& rng) override {
std::vector<uint8_t> sign(RandomNumberGenerator& rng) override {
auto scope = CT::scoped_poison(*m_priv_key);

const auto mu = m_h.final();
Expand Down Expand Up @@ -300,7 +300,7 @@ class Dilithium_Verification_Operation final : public PK_Ops::Verification {
m_t1_ntt_shifted(ntt(m_pub_key->t1() << DilithiumConstants::D)),
m_h(m_pub_key->mode().symmetric_primitives().get_message_hash(m_pub_key->tr())) {}

void update(const uint8_t msg[], size_t msg_len) override { m_h.update({msg, msg_len}); }
void update(std::span<const uint8_t> input) override { m_h.update(input); }

/**
* NIST FIPS 204 IPD, Algorithm 3 (ML-DSA.Verify)
Expand All @@ -309,10 +309,10 @@ class Dilithium_Verification_Operation final : public PK_Ops::Verification {
* matrix A is expanded from 'rho' in the constructor of this class, as
* a 'verification operation' may be used to verify multiple signatures.
*/
bool is_valid_signature(const uint8_t* sig, size_t sig_len) override {
bool is_valid_signature(std::span<const uint8_t> sig) override {
const auto& mode = m_pub_key->mode();
const auto& sympri = mode.symmetric_primitives();
StrongSpan<const DilithiumSerializedSignature> sig_bytes({sig, sig_len});
StrongSpan<const DilithiumSerializedSignature> sig_bytes(sig);

if(sig_bytes.size() != mode.signature_bytes()) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ using DilithiumHashedPublicKey = Strong<std::vector<uint8_t>, struct DilithiumHa
using DilithiumMessageRepresentative = Strong<std::vector<uint8_t>, struct DilithiumMessageRepresentative_>;

/// Serialized signature data
using DilithiumSerializedSignature = Strong<secure_vector<uint8_t>, struct DilithiumSerializedSignature_>;
using DilithiumSerializedSignature = Strong<std::vector<uint8_t>, struct DilithiumSerializedSignature_>;

/// Serialized representation of a commitment w1
using DilithiumSerializedCommitment = Strong<std::vector<uint8_t>, struct DilithiumSerializedCommitment_>;
Expand Down
Loading
Loading