Skip to content

Commit

Permalink
Merge pull request #3539 from randombit/fix/eph_key_callbacks
Browse files Browse the repository at this point in the history
FIX: allow custom KEX logic for TLS 1.2 server
  • Loading branch information
reneme authored May 3, 2023
2 parents 1881cd8 + 717b510 commit b37705f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 37 deletions.
36 changes: 14 additions & 22 deletions src/lib/tls/tls12/msg_client_kex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,33 +289,25 @@ Client_Key_Exchange::Client_Key_Exchange(const std::vector<uint8_t>& contents,
kex_algo == Kex_Algo::ECDH ||
kex_algo == Kex_Algo::ECDHE_PSK)
{
const Private_Key& private_key = state.server_kex()->server_kex_key();
const PK_Key_Agreement_Key& ka_key = state.server_kex()->server_kex_key();

const PK_Key_Agreement_Key* ka_key =
dynamic_cast<const PK_Key_Agreement_Key*>(&private_key);
const std::vector<uint8_t> client_pubkey =
(ka_key.algo_name() == "DH") ? reader.get_range<uint8_t>(2, 0, 65535)
: reader.get_range<uint8_t>(1, 1, 255);

if(!ka_key)
throw Internal_Error("Expected key agreement key type but got " +
private_key.algo_name());

std::vector<uint8_t> client_pubkey;

if(ka_key->algo_name() == "DH")
{
client_pubkey = reader.get_range<uint8_t>(2, 0, 65535);
}
else
{
client_pubkey = reader.get_range<uint8_t>(1, 1, 255);
}
const auto shared_group = state.server_kex()->shared_group();
BOTAN_STATE_CHECK(shared_group && shared_group.value() != Group_Params::NONE);

try
{
PK_Key_Agreement ka(*ka_key, rng, "Raw");

secure_vector<uint8_t> shared_secret = ka.derive_key(0, client_pubkey).bits_of();
auto shared_secret =
state.callbacks().tls_ephemeral_key_agreement(shared_group.value(),
ka_key,
client_pubkey,
rng,
policy);

if(ka_key->algo_name() == "DH")
if(ka_key.algo_name() == "DH")
shared_secret = CT::strip_leading_zeros(shared_secret);

if(kex_algo == Kex_Algo::ECDHE_PSK)
Expand All @@ -338,7 +330,7 @@ Client_Key_Exchange::Client_Key_Exchange(const std::vector<uint8_t>& contents,
* failure condition, randomize the pre-master output and carry on,
* allowing the protocol to fail later in the finished checks.
*/
rng.random_vec(m_pre_master, ka_key->public_value().size());
rng.random_vec(m_pre_master, ka_key.public_value().size());
}

reader.assert_done();
Expand Down
36 changes: 22 additions & 14 deletions src/lib/tls/tls12/msg_server_kex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
{
const std::vector<Group_Params> dh_groups = state.client_hello()->supported_dh_groups();

Group_Params shared_group = Group_Params::NONE;
m_shared_group = Group_Params::NONE;

/*
If the client does not send any DH groups in the supported groups
Expand All @@ -60,20 +60,28 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,

if(dh_groups.empty())
{
shared_group = policy.default_dh_group();
m_shared_group = policy.default_dh_group();
}
else
{
shared_group = policy.choose_key_exchange_group(dh_groups, {});
m_shared_group = policy.choose_key_exchange_group(dh_groups, {});
}

if(shared_group == Group_Params::NONE)
if(m_shared_group.value() == Group_Params::NONE)
throw TLS_Exception(Alert::HandshakeFailure,
"Could not agree on a DH group with the client");

BOTAN_ASSERT(group_param_is_dh(shared_group), "DH groups for the DH ciphersuites god");

m_kex_key = state.callbacks().tls_generate_ephemeral_key(shared_group, rng);
BOTAN_ASSERT(group_param_is_dh(m_shared_group.value()), "DH ciphersuite is using a finite field group");

// Note: TLS 1.2 allows defining and using arbitrary DH groups (additional
// to the named and standardized ones). This API doesn't allow the
// server to make use of that at the moment. TLS 1.3 does not
// provide this flexibility!
//
// A possible implementation strategy in case one would ever need that:
// `Policy::default_dh_group()` could return a `std::variant<Group_Params,
// DL_Group>`, allowing it to define arbitrary groups.
m_kex_key = state.callbacks().tls_generate_ephemeral_key(m_shared_group.value(), rng);
auto dh = dynamic_cast<DH_PrivateKey*>(m_kex_key.get());
if(!dh)
{
Expand All @@ -91,16 +99,16 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
if(ec_groups.empty())
throw Internal_Error("Client sent no ECC extension but we negotiated ECDH");

Group_Params shared_group = policy.choose_key_exchange_group(ec_groups, {});
m_shared_group = policy.choose_key_exchange_group(ec_groups, {});

if(shared_group == Group_Params::NONE)
if(m_shared_group.value() == Group_Params::NONE)
throw TLS_Exception(Alert::HandshakeFailure, "No shared ECC group with client");

std::vector<uint8_t> ecdh_public_val;

if(shared_group == Group_Params::X25519)
if(m_shared_group.value() == Group_Params::X25519)
{
m_kex_key = state.callbacks().tls_generate_ephemeral_key(shared_group, rng);
m_kex_key = state.callbacks().tls_generate_ephemeral_key(m_shared_group.value(), rng);
if(!m_kex_key)
{
throw TLS_Exception(Alert::InternalError,
Expand All @@ -110,7 +118,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
}
else
{
m_kex_key = state.callbacks().tls_generate_ephemeral_key(shared_group, rng);
m_kex_key = state.callbacks().tls_generate_ephemeral_key(m_shared_group.value(), rng);
auto ecdh = dynamic_cast<ECDH_PrivateKey*>(m_kex_key.get());
if(!ecdh)
{
Expand All @@ -123,7 +131,7 @@ Server_Key_Exchange::Server_Key_Exchange(Handshake_IO& io,
EC_Point_Format::Compressed : EC_Point_Format::Uncompressed);
}

const uint16_t named_curve_id = static_cast<uint16_t>(shared_group);
const uint16_t named_curve_id = static_cast<uint16_t>(m_shared_group.value());
m_params.push_back(3); // named curve
m_params.push_back(get_byte<0>(named_curve_id));
m_params.push_back(get_byte<1>(named_curve_id));
Expand Down Expand Up @@ -258,7 +266,7 @@ bool Server_Key_Exchange::verify(const Public_Key& server_key,
#endif
}

const Private_Key& Server_Key_Exchange::server_kex_key() const
const PK_Key_Agreement_Key& Server_Key_Exchange::server_kex_key() const
{
BOTAN_ASSERT_NONNULL(m_kex_key);
return *m_kex_key;
Expand Down
9 changes: 8 additions & 1 deletion src/lib/tls/tls_messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,13 @@ class BOTAN_UNSTABLE_API Server_Key_Exchange final : public Handshake_Message
const Policy& policy) const;

// Only valid for certain kex types
const Private_Key& server_kex_key() const;
const PK_Key_Agreement_Key& server_kex_key() const;

/**
* @returns the agreed upon KEX group or std::nullopt if the KEX type does
* not depend on a group
*/
const std::optional<Group_Params>& shared_group() const { return m_shared_group; }

Server_Key_Exchange(Handshake_IO& io,
Handshake_State& state,
Expand All @@ -889,6 +895,7 @@ class BOTAN_UNSTABLE_API Server_Key_Exchange final : public Handshake_Message
std::vector<uint8_t> serialize() const override;

std::unique_ptr<PK_Key_Agreement_Key> m_kex_key;
std::optional<Group_Params> m_shared_group;

std::vector<uint8_t> m_params;

Expand Down

0 comments on commit b37705f

Please sign in to comment.