Skip to content

Commit

Permalink
Merge pull request #3411 from randombit/feature/remove_pqc_der_encodings
Browse files Browse the repository at this point in the history
Kyber Encoding Improvements
  • Loading branch information
reneme authored Mar 24, 2023
2 parents a735162 + d41524a commit ab8ee4c
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 413 deletions.
39 changes: 0 additions & 39 deletions src/lib/pubkey/dilithium/dilithium_common/dilithium.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,45 +34,6 @@
namespace Botan {
namespace {

/**
* Helper class to ease unmarshalling of concatenated fixed-length values
*/
class BufferSlicer final
{
public:
BufferSlicer(std::span<const uint8_t> buffer) : m_remaining(buffer)
{}

template <typename ContainerT>
auto take_as(const size_t count)
{
const auto result = take(count);
return ContainerT(result.begin(), result.end());
}

auto take_vector(const size_t count) { return take_as<std::vector<uint8_t>>(count); }
auto take_secure_vector(const size_t count) { return take_as<secure_vector<uint8_t>>(count); }

std::span<const uint8_t> take(const size_t count)
{
BOTAN_STATE_CHECK(remaining() >= count);
auto result = m_remaining.first(count);
m_remaining = m_remaining.subspan(count);
return result;
}

void copy_into(std::span<uint8_t> sink)
{
const auto data = take(sink.size());
std::copy(data.begin(), data.end(), sink.begin());
}

size_t remaining() const { return m_remaining.size(); }

private:
std::span<const uint8_t> m_remaining;
};

std::pair<Dilithium::PolynomialVector, Dilithium::PolynomialVector>
calculate_t0_and_t1(const DilithiumModeConstants& mode,
const std::vector<uint8_t>& rho,
Expand Down
227 changes: 49 additions & 178 deletions src/lib/pubkey/kyber/kyber_common/kyber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,16 +403,15 @@ class Polynomial
: Polynomial::cbd3(mode.PRF(seed, nonce, outlen));
}

template <typename Alloc>
static Polynomial from_bytes(const std::vector<uint8_t, Alloc>& a, const size_t offset = 0)
static Polynomial from_bytes(std::span<const uint8_t> a)
{
Polynomial r;
for(size_t i = 0; i < r.m_coeffs.size() / 2; ++i)
{
r.m_coeffs[2 * i] =
((a[3 * i + 0 + offset] >> 0) | (static_cast<uint16_t>(a[3 * i + 1 + offset]) << 8)) & 0xFFF;
((a[3 * i + 0] >> 0) | (static_cast<uint16_t>(a[3 * i + 1]) << 8)) & 0xFFF;
r.m_coeffs[2 * i + 1] =
((a[3 * i + 1 + offset] >> 4) | (static_cast<uint16_t>(a[3 * i + 2 + offset]) << 4)) & 0xFFF;
((a[3 * i + 1] >> 4) | (static_cast<uint16_t>(a[3 * i + 2]) << 4)) & 0xFFF;
}
return r;
}
Expand Down Expand Up @@ -679,14 +678,16 @@ class PolynomialVector
{
}

template <typename Alloc>
static PolynomialVector from_bytes(const std::vector<uint8_t, Alloc>& a, const KyberConstants& mode)
static PolynomialVector from_bytes(std::span<const uint8_t> a, const KyberConstants& mode)
{
BOTAN_ASSERT(a.size() == mode.polynomial_vector_byte_length(), "wrong byte length for frombytes");

PolynomialVector r(mode.k());
for(size_t i = 0; i < mode.k(); ++i)
{ r.m_vec[i] = Polynomial::from_bytes(a, i * KyberConstants::kSerializedPolynomialByteLength); }
{
r.m_vec[i] = Polynomial::from_bytes(a.subspan(0, KyberConstants::kSerializedPolynomialByteLength));
a = a.subspan(KyberConstants::kSerializedPolynomialByteLength);
}
return r;
}

Expand Down Expand Up @@ -803,7 +804,7 @@ class PolynomialMatrix
public:
PolynomialMatrix() = delete;

static PolynomialMatrix generate(const std::vector<uint8_t>& seed, const bool transposed,
static PolynomialMatrix generate(std::span<const uint8_t> seed, const bool transposed,
const KyberConstants& mode)
{
BOTAN_ASSERT(seed.size() == KyberConstants::kSymBytes, "unexpected seed size");
Expand Down Expand Up @@ -1110,7 +1111,7 @@ class Kyber_PublicKeyInternal
{
public:
Kyber_PublicKeyInternal(KyberConstants mode,
const std::vector<uint8_t>& polynomials,
std::span<const uint8_t> polynomials,
std::vector<uint8_t> seed)
: m_mode(std::move(mode)),
m_polynomials(PolynomialVector::from_bytes(polynomials, m_mode)),
Expand Down Expand Up @@ -1400,77 +1401,41 @@ size_t Kyber_PublicKey::estimated_strength() const
return m_public->mode().estimated_strength();
}

void Kyber_PublicKey::initialize_from_encoding(const std::vector<uint8_t>& pub_key,
KyberMode m,
KyberKeyEncoding encoding)
std::shared_ptr<Kyber_PublicKeyInternal>
Kyber_PublicKey::initialize_from_encoding(std::span<const uint8_t> pub_key, KyberMode m)
{
KyberConstants mode(m);

std::vector<uint8_t> poly_vec, seed;

switch(encoding)
if(pub_key.size() != mode.public_key_byte_length())
{
case KyberKeyEncoding::Full:
BER_Decoder(pub_key)
.start_sequence()
.decode(poly_vec, ASN1_Type::OctetString)
.decode(seed, ASN1_Type::OctetString)
.end_cons();
break;
case KyberKeyEncoding::Raw:
if(pub_key.size() != mode.public_key_byte_length())
{
throw Invalid_Argument("kyber public key does not have the correct byte count");
}
poly_vec = std::vector<uint8_t>(pub_key.begin(), pub_key.end() - KyberConstants::kSeedLength);
seed = std::vector<uint8_t>(pub_key.end() - KyberConstants::kSeedLength, pub_key.end());
break;
throw Invalid_Argument("kyber public key does not have the correct byte count");
}

if(poly_vec.size() != mode.polynomial_vector_byte_length())
{
throw Invalid_Argument("kyber public key t-param does not have the correct byte count");
}
BufferSlicer s(pub_key);

if(seed.size() != KyberConstants::kSeedLength)
{
throw Invalid_Argument("kyber public key rho-param does not have the correct byte count");
}
auto poly_vec = s.take(mode.polynomial_vector_byte_length());
auto seed = s.take_vector(KyberConstants::kSeedLength);
BOTAN_ASSERT_NOMSG(s.empty());

m_public = std::make_shared<Kyber_PublicKeyInternal>(std::move(mode), std::move(poly_vec), std::move(seed));
return std::make_shared<Kyber_PublicKeyInternal>(std::move(mode), poly_vec, std::move(seed));
}

Kyber_PublicKey::Kyber_PublicKey(const AlgorithmIdentifier& alg_id,
const std::vector<uint8_t>& key_bits) :
Kyber_PublicKey(key_bits,
KyberMode(alg_id.oid()),
KyberKeyEncoding::Full)
std::span<const uint8_t> key_bits)
: Kyber_PublicKey(key_bits, KyberMode(alg_id.oid()))
{}

Kyber_PublicKey::Kyber_PublicKey(const std::vector<uint8_t>& pub_key,
KyberMode m,
KyberKeyEncoding encoding)
: Kyber_PublicKey()
{
initialize_from_encoding(pub_key, m, encoding);
}
Kyber_PublicKey::Kyber_PublicKey(std::span<const uint8_t> pub_key, KyberMode m)
: m_public(initialize_from_encoding(pub_key, m))
{}

Kyber_PublicKey::Kyber_PublicKey(const Kyber_PublicKey& other)
: m_public(std::make_shared<Kyber_PublicKeyInternal>(*other.m_public)), m_key_encoding(other.m_key_encoding)
{
}
: m_public(std::make_shared<Kyber_PublicKeyInternal>(*other.m_public))
{}

std::vector<uint8_t> Kyber_PublicKey::public_key_bits() const
{
switch(m_key_encoding)
{
case KyberKeyEncoding::Full:
return public_key_bits_der();
case KyberKeyEncoding::Raw:
return public_key_bits_raw();
}

unreachable();
return public_key_bits_raw();
}

const std::vector<uint8_t>& Kyber_PublicKey::public_key_bits_raw() const
Expand All @@ -1483,19 +1448,6 @@ const std::vector<uint8_t>& Kyber_PublicKey::H_public_key_bits_raw() const
return m_public->H_public_key_bits_raw();
}

std::vector<uint8_t> Kyber_PublicKey::public_key_bits_der() const
{
std::vector<uint8_t> output;
DER_Encoder der(output);

der.start_sequence()
.encode(m_public->polynomials().to_bytes<std::vector<uint8_t>>(), ASN1_Type::OctetString)
.encode(m_public->seed(), ASN1_Type::OctetString)
.end_cons();

return output;
}

size_t Kyber_PublicKey::key_length() const
{
return m_public->mode().public_key_byte_length();
Expand All @@ -1514,8 +1466,11 @@ Kyber_PrivateKey::Kyber_PrivateKey(RandomNumberGenerator& rng, KyberMode m)
auto seed = G->process(rng.random_vec(KyberConstants::kSymBytes));

const auto middle = G->output_length() / 2;
std::vector<uint8_t> seed1(seed.begin(), seed.begin() + middle);
secure_vector<uint8_t> seed2(seed.begin() + middle, seed.end());

BufferSlicer s(seed);
auto seed1 = s.take_vector(middle);
auto seed2 = s.take(middle);
BOTAN_ASSERT_NOMSG(s.empty());

auto a = PolynomialMatrix::generate(seed1, false, mode);
auto skpv = PolynomialVector::getnoise_eta1(seed2, 0, mode);
Expand All @@ -1535,130 +1490,46 @@ Kyber_PrivateKey::Kyber_PrivateKey(RandomNumberGenerator& rng, KyberMode m)
}

Kyber_PrivateKey::Kyber_PrivateKey(const AlgorithmIdentifier& alg_id,
const secure_vector<uint8_t>& key_bits) :
Kyber_PrivateKey(key_bits,
KyberMode(alg_id.oid()),
KyberKeyEncoding::Full)
{}
std::span<const uint8_t> key_bits) :
Kyber_PrivateKey(key_bits, KyberMode(alg_id.oid())) {}

Kyber_PrivateKey::Kyber_PrivateKey(const secure_vector<uint8_t>& sk,
KyberMode m,
KyberKeyEncoding encoding)
Kyber_PrivateKey::Kyber_PrivateKey(std::span<const uint8_t> sk, KyberMode m)
{
KyberConstants mode(m);

if(encoding == KyberKeyEncoding::Full)
if(mode.private_key_byte_length() != sk.size())
{
secure_vector<uint8_t> z, skpv;
BER_Object pub_key;

std::vector<uint8_t> pkpv, seed;

auto dec = BER_Decoder(sk)
.start_sequence()
.decode_and_check<size_t>(0, "kyber private key does have a version other than 0")
.decode(z, ASN1_Type::OctetString)
.decode(skpv, ASN1_Type::OctetString);

try
{
dec.start_sequence().decode(pkpv, ASN1_Type::OctetString).decode(seed, ASN1_Type::OctetString).end_cons();
}
catch(const BER_Decoding_Error&)
{
throw Invalid_Argument("reading private key without an embedded public key is not supported");
}

// skipping the public key hash
dec.discard_remaining().end_cons();

if(skpv.size() != mode.polynomial_vector_byte_length())
{
throw Invalid_Argument("kyber private key sample-param does not have the correct byte count");
}

if(z.size() != KyberConstants::kZLength)
{
throw Invalid_Argument("kyber private key z-param does not have the correct byte count");
}

m_public = std::make_shared<Kyber_PublicKeyInternal>(m, std::move(pkpv), std::move(seed));
m_private = std::make_shared<Kyber_PrivateKeyInternal>(std::move(mode),
PolynomialVector::from_bytes(skpv, mode), std::move(z));
throw Invalid_Argument("kyber private key does not have the correct byte count");
}
else if(encoding == KyberKeyEncoding::Raw)
{
if(mode.private_key_byte_length() != sk.size())
{
throw Invalid_Argument("kyber private key does not have the correct byte count");
}

const auto off_pub_key = mode.polynomial_vector_byte_length();
const auto pub_key_len = mode.public_key_byte_length();
BufferSlicer s(sk);

auto skpv = secure_vector<uint8_t>(sk.begin(), sk.begin() + off_pub_key);
auto pub_key = std::vector<uint8_t>(sk.begin() + off_pub_key, sk.begin() + off_pub_key + pub_key_len);
// skipping the public key hash
auto z = secure_vector<uint8_t>(sk.end() - KyberConstants::kZLength, sk.end());
auto skpv = PolynomialVector::from_bytes(s.take(mode.polynomial_vector_byte_length()), mode);
auto pub_key = s.take(mode.public_key_byte_length());
s.skip(KyberConstants::kPublicKeyHashLength);
auto z = s.take_secure_vector(KyberConstants::kZLength);

initialize_from_encoding(pub_key, m, encoding);
m_private = std::make_shared<Kyber_PrivateKeyInternal>(std::move(mode),
PolynomialVector::from_bytes(skpv, mode), std::move(z));
}
BOTAN_ASSERT_NOMSG(s.empty());

m_public = initialize_from_encoding(pub_key, m);
m_private = std::make_shared<Kyber_PrivateKeyInternal>(std::move(mode), std::move(skpv), std::move(z));

BOTAN_ASSERT(m_private && m_public, "reading private key encoding");
}

std::unique_ptr<Public_Key> Kyber_PrivateKey::public_key() const
{
auto public_key = std::make_unique<Kyber_PublicKey>(*this);
public_key->set_binary_encoding(binary_encoding());
return public_key;
return std::make_unique<Kyber_PublicKey>(*this);
}

secure_vector<uint8_t> Kyber_PrivateKey::private_key_bits() const
{
switch(m_key_encoding)
{
case KyberKeyEncoding::Full:
return private_key_bits_der();
case KyberKeyEncoding::Raw:
return private_key_bits_raw();
}

unreachable();
}

secure_vector<uint8_t> Kyber_PrivateKey::private_key_bits_raw() const
{
const auto pub_key = public_key_bits_raw();
const auto pub_key_sv = secure_vector<uint8_t>(pub_key.begin(), pub_key.end());
const auto pub_key_hash = H_public_key_bits_raw();

return concat(m_private->polynomials().to_bytes<secure_vector<uint8_t>>(),
pub_key_sv, pub_key_hash,
public_key_bits_raw(),
H_public_key_bits_raw(),
m_private->z());
}

secure_vector<uint8_t> Kyber_PrivateKey::private_key_bits_der() const
{
secure_vector<uint8_t> output;
DER_Encoder der(output);

const auto pub_key = public_key_bits_der();
const auto pub_key_hash = m_private->mode().H()->process(pub_key);

der.start_sequence()
.encode(size_t(0), ASN1_Type::Integer, ASN1_Class::Universal)
.encode(m_private->z(), ASN1_Type::OctetString)
.encode(m_private->polynomials().to_bytes<secure_vector<uint8_t>>(), ASN1_Type::OctetString)
.raw_bytes(pub_key)
.encode(pub_key_hash, ASN1_Type::OctetString)
.end_cons();

return output;
}

std::unique_ptr<PK_Ops::KEM_Encryption>
Kyber_PublicKey::create_kem_encryption_op(
const std::string& params,
Expand Down
Loading

0 comments on commit ab8ee4c

Please sign in to comment.