Skip to content

Commit

Permalink
refactor: Avoid unnecessary s2n_hmac calls in s2n_record_write (#4539)
Browse files Browse the repository at this point in the history
  • Loading branch information
goatgoose authored May 10, 2024
1 parent 4862e7f commit f28ed07
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 31 deletions.
109 changes: 109 additions & 0 deletions tests/unit/s2n_cbc_test.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

#include "s2n_test.h"
#include "testlib/s2n_testlib.h"

int main(int argc, char **argv)
{
BEGIN_TEST();

char dhparams_pem[S2N_MAX_TEST_PEM_SIZE] = { 0 };
EXPECT_SUCCESS(s2n_read_test_pem(S2N_DEFAULT_TEST_DHPARAMS, dhparams_pem, S2N_MAX_TEST_PEM_SIZE));

DEFER_CLEANUP(struct s2n_cert_chain_and_key *rsa_chain_and_key = NULL, s2n_cert_chain_and_key_ptr_free);
EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&rsa_chain_and_key,
S2N_DEFAULT_TEST_CERT_CHAIN, S2N_DEFAULT_TEST_PRIVATE_KEY));

DEFER_CLEANUP(struct s2n_cert_chain_and_key *ecdsa_chain_and_key = NULL, s2n_cert_chain_and_key_ptr_free);
EXPECT_SUCCESS(s2n_test_cert_chain_and_key_new(&ecdsa_chain_and_key,
S2N_DEFAULT_ECDSA_TEST_CERT_CHAIN, S2N_DEFAULT_ECDSA_TEST_PRIVATE_KEY));

/* Self-talk test */
{
DEFER_CLEANUP(struct s2n_config *client_config = s2n_config_new_minimal(), s2n_config_ptr_free);
EXPECT_NOT_NULL(client_config);
EXPECT_SUCCESS(s2n_config_set_unsafe_for_testing(client_config));

DEFER_CLEANUP(struct s2n_config *server_config = s2n_config_new_minimal(), s2n_config_ptr_free);
EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(server_config, rsa_chain_and_key));
EXPECT_SUCCESS(s2n_config_add_cert_chain_and_key_to_store(server_config, ecdsa_chain_and_key));
EXPECT_SUCCESS(s2n_config_add_dhparams(server_config, dhparams_pem));

/* Test both composite and non-composite CBC ciphers for all CBC cipher suites. */
size_t record_algs_tested = 0;
for (size_t cipher_suite_idx = 0; cipher_suite_idx < cipher_preferences_test_all.count; cipher_suite_idx++) {
uint8_t record_algs = cipher_preferences_test_all.suites[cipher_suite_idx]->num_record_algs;
for (size_t record_alg_idx = 0; record_alg_idx < record_algs; record_alg_idx++) {
struct s2n_cipher_suite test_cipher_suite = *cipher_preferences_test_all.suites[cipher_suite_idx];
test_cipher_suite.record_alg = test_cipher_suite.all_record_algs[record_alg_idx];

/* Skip non-CBC ciphers. */
uint8_t cipher = test_cipher_suite.record_alg->cipher->type;
if (cipher != S2N_CBC && cipher != S2N_COMPOSITE) {
continue;
}

/* Skip unsupported ciphers. */
if (!test_cipher_suite.record_alg->cipher->is_available()) {
continue;
}

struct s2n_cipher_suite *test_cipher_suite_ptr = &test_cipher_suite;
struct s2n_cipher_preferences test_cipher_preferences = {
.count = 1,
.suites = &test_cipher_suite_ptr,
};

struct s2n_security_policy test_security_policy = security_policy_test_all;
test_security_policy.cipher_preferences = &test_cipher_preferences;

DEFER_CLEANUP(struct s2n_connection *client = s2n_connection_new(S2N_CLIENT),
s2n_connection_ptr_free);
EXPECT_NOT_NULL(client);
EXPECT_SUCCESS(s2n_connection_set_config(client, client_config));
client->security_policy_override = &test_security_policy;

DEFER_CLEANUP(struct s2n_connection *server = s2n_connection_new(S2N_SERVER),
s2n_connection_ptr_free);
EXPECT_NOT_NULL(server);
EXPECT_SUCCESS(s2n_connection_set_config(server, server_config));
server->security_policy_override = &test_security_policy;

DEFER_CLEANUP(struct s2n_test_io_pair io_pair = { 0 }, s2n_io_pair_close);
EXPECT_SUCCESS(s2n_io_pair_init_non_blocking(&io_pair));
EXPECT_SUCCESS(s2n_connection_set_io_pair(client, &io_pair));
EXPECT_SUCCESS(s2n_connection_set_io_pair(server, &io_pair));

EXPECT_SUCCESS(s2n_negotiate_test_server_and_client(server, client));

uint8_t negotiated_cipher_suite[S2N_TLS_CIPHER_SUITE_LEN] = { 0 };
EXPECT_SUCCESS(s2n_connection_get_cipher_iana_value(server, negotiated_cipher_suite,
negotiated_cipher_suite + 1));
EXPECT_BYTEARRAY_EQUAL(negotiated_cipher_suite, test_cipher_suite.iana_value,
S2N_TLS_CIPHER_SUITE_LEN);

EXPECT_OK(s2n_send_and_recv_test(client, server));
EXPECT_OK(s2n_send_and_recv_test(server, client));

record_algs_tested += 1;
}
}

EXPECT_TRUE(record_algs_tested > 0);
}

END_TEST();
}
142 changes: 111 additions & 31 deletions tls/s2n_record_write.c
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,106 @@ static inline int s2n_record_encrypt(
return 0;
}

static S2N_RESULT s2n_record_write_mac(struct s2n_connection *conn, struct s2n_blob *record_header,
struct s2n_blob *plaintext, struct s2n_stuffer *out, uint32_t *bytes_written)
{
RESULT_ENSURE_REF(conn);
RESULT_ENSURE_REF(conn->server);
RESULT_ENSURE_REF(conn->client);
RESULT_ENSURE_REF(record_header);
RESULT_ENSURE_REF(plaintext);
RESULT_ENSURE_REF(out);
RESULT_ENSURE_REF(bytes_written);
*bytes_written = 0;

struct s2n_hmac_state *mac = &conn->server->server_record_mac;
const struct s2n_cipher_suite *cipher_suite = conn->server->cipher_suite;
uint8_t *sequence_number = conn->server->server_sequence_number;

if (conn->mode == S2N_CLIENT) {
mac = &conn->client->client_record_mac;
cipher_suite = conn->client->cipher_suite;
sequence_number = conn->client->client_sequence_number;
}

RESULT_ENSURE_REF(cipher_suite);
RESULT_ENSURE_REF(cipher_suite->record_alg);

if (cipher_suite->record_alg->hmac_alg == S2N_HMAC_NONE) {
/* If the S2N_HMAC_NONE algorithm is specified, a MAC should not be explicitly written.
* This is the case for AEAD and Composite cipher types, where the MAC is written as part
* of encryption. This is also the case for plaintext handshake records, where the null
* stream cipher is used.
*/
return S2N_RESULT_OK;
}

/**
*= https://www.rfc-editor.org/rfc/rfc5246#section-6.2.3.1
*# The MAC is generated as:
*#
*# MAC(MAC_write_key, seq_num +
*/
RESULT_GUARD_POSIX(s2n_hmac_update(mac, sequence_number, S2N_TLS_SEQUENCE_NUM_LEN));

struct s2n_stuffer header_stuffer = { 0 };
RESULT_GUARD_POSIX(s2n_stuffer_init_written(&header_stuffer, record_header));

/**
*= https://www.rfc-editor.org/rfc/rfc5246#section-6.2.3.1
*# TLSCompressed.type +
*/
void *record_type_byte = s2n_stuffer_raw_read(&header_stuffer, sizeof(uint8_t));
RESULT_ENSURE_REF(record_type_byte);
RESULT_GUARD_POSIX(s2n_hmac_update(mac, record_type_byte, sizeof(uint8_t)));

/**
*= https://www.rfc-editor.org/rfc/rfc5246#section-6.2.3.1
*# TLSCompressed.version +
*/
void *protocol_version_bytes = s2n_stuffer_raw_read(&header_stuffer, S2N_TLS_PROTOCOL_VERSION_LEN);
RESULT_ENSURE_REF(protocol_version_bytes);
if (conn->actual_protocol_version > S2N_SSLv3) {
/* SSLv3 doesn't include the protocol version in the MAC. */
RESULT_GUARD_POSIX(s2n_hmac_update(mac, protocol_version_bytes, S2N_TLS_PROTOCOL_VERSION_LEN));
}

/**
*= https://www.rfc-editor.org/rfc/rfc5246#section-6.2.3.1
*# TLSCompressed.length +
*
* Note that the length field refers to the length of the plaintext content, not the length of
* TLSCiphertext fragment written to the record header, which accounts for additional fields
* such as the padding and MAC.
*/
uint8_t content_length_bytes[sizeof(uint16_t)] = { 0 };
struct s2n_blob content_length_blob = { 0 };
RESULT_GUARD_POSIX(s2n_blob_init(&content_length_blob, content_length_bytes, sizeof(content_length_bytes)));
struct s2n_stuffer content_length_stuffer = { 0 };
RESULT_GUARD_POSIX(s2n_stuffer_init(&content_length_stuffer, &content_length_blob));
RESULT_GUARD_POSIX(s2n_stuffer_write_uint16(&content_length_stuffer, plaintext->size));
RESULT_GUARD_POSIX(s2n_hmac_update(mac, content_length_bytes, sizeof(content_length_bytes)));

/**
*= https://www.rfc-editor.org/rfc/rfc5246#section-6.2.3.1
*# TLSCompressed.fragment);
*#
*# where "+" denotes concatenation.
*/
RESULT_GUARD_POSIX(s2n_hmac_update(mac, plaintext->data, plaintext->size));

uint8_t mac_digest_size = 0;
RESULT_GUARD_POSIX(s2n_hmac_digest_size(mac->alg, &mac_digest_size));
uint8_t *digest = s2n_stuffer_raw_write(out, mac_digest_size);
RESULT_ENSURE_REF(digest);
RESULT_GUARD_POSIX(s2n_hmac_digest(mac, digest, mac_digest_size));
*bytes_written = mac_digest_size;

RESULT_GUARD_POSIX(s2n_hmac_reset(mac));

return S2N_RESULT_OK;
}

int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const struct iovec *in, int in_count, size_t offs, size_t to_write)
{
if (conn->ktls_send_enabled) {
Expand All @@ -274,14 +374,12 @@ int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const s
}

uint8_t *sequence_number = conn->server->server_sequence_number;
struct s2n_hmac_state *mac = &conn->server->server_record_mac;
struct s2n_session_key *session_key = &conn->server->server_key;
const struct s2n_cipher_suite *cipher_suite = conn->server->cipher_suite;
uint8_t *implicit_iv = conn->server->server_implicit_iv;

if (conn->mode == S2N_CLIENT) {
sequence_number = conn->client->client_sequence_number;
mac = &conn->client->client_record_mac;
session_key = &conn->client->client_key;
cipher_suite = conn->client->cipher_suite;
implicit_iv = conn->client->client_implicit_iv;
Expand All @@ -301,9 +399,6 @@ int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const s
POSIX_ENSURE(s2n_stuffer_data_available(&conn->out) == 0, S2N_ERR_RECORD_STUFFER_NEEDS_DRAINING);
}

uint8_t mac_digest_size = 0;
POSIX_GUARD(s2n_hmac_digest_size(mac->alg, &mac_digest_size));

/* Before we do anything, we need to figure out what the length of the
* fragment is going to be.
*/
Expand All @@ -324,9 +419,6 @@ int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const s
block_size = cipher_suite->record_alg->cipher->io.comp.block_size;
}

/* Start the MAC with the sequence number */
POSIX_GUARD(s2n_hmac_update(mac, sequence_number, S2N_TLS_SEQUENCE_NUM_LEN));

if (s2n_stuffer_is_freed(&conn->out)) {
/* If the output buffer has not been allocated yet, allocate
* at least enough memory to hold a record with the local maximum fragment length.
Expand Down Expand Up @@ -363,17 +455,6 @@ int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const s
POSIX_GUARD(s2n_stuffer_write_uint8(&record_stuffer, record_type));
POSIX_GUARD(s2n_record_write_protocol_version(conn, record_type, &record_stuffer));

/* First write a header that has the payload length, this is for the MAC */
POSIX_GUARD(s2n_stuffer_write_uint16(&record_stuffer, data_bytes_to_take));

if (conn->actual_protocol_version > S2N_SSLv3) {
POSIX_GUARD(s2n_hmac_update(mac, record_stuffer.blob.data, S2N_TLS_RECORD_HEADER_LENGTH));
} else {
/* SSLv3 doesn't include the protocol version in the MAC */
POSIX_GUARD(s2n_hmac_update(mac, record_stuffer.blob.data, 1));
POSIX_GUARD(s2n_hmac_update(mac, record_stuffer.blob.data + 3, 2));
}

/* Compute non-payload parts of the MAC(seq num, type, proto vers, fragment length) for composite ciphers.
* Composite "encrypt" will MAC the payload data and fill in padding.
*/
Expand Down Expand Up @@ -401,7 +482,6 @@ int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const s
/* ensure actual_fragment_length + S2N_TLS_RECORD_HEADER_LENGTH <= max record length */
const uint16_t max_record_length = is_tls13_record ? S2N_TLS13_MAXIMUM_RECORD_LENGTH : S2N_TLS_MAXIMUM_RECORD_LENGTH;
S2N_ERROR_IF(actual_fragment_length + S2N_TLS_RECORD_HEADER_LENGTH > max_record_length, S2N_ERR_RECORD_LENGTH_TOO_LARGE);
POSIX_GUARD(s2n_stuffer_wipe_n(&record_stuffer, 2));
POSIX_GUARD(s2n_stuffer_write_uint16(&record_stuffer, actual_fragment_length));

/* If we're AEAD, write the sequence number as an IV, and generate the AAD */
Expand Down Expand Up @@ -468,22 +548,22 @@ int s2n_record_writev(struct s2n_connection *conn, uint8_t content_type, const s
}
}

/* We are done with this sequence number, so we can increment it */
struct s2n_blob seq = { 0 };
POSIX_GUARD(s2n_blob_init(&seq, sequence_number, S2N_TLS_SEQUENCE_NUM_LEN));
POSIX_GUARD(s2n_increment_sequence_number(&seq));

/* Write the plaintext data */
POSIX_GUARD(s2n_stuffer_writev_bytes(&record_stuffer, in, in_count, offs, data_bytes_to_take));
void *orig_write_ptr = record_stuffer.blob.data + record_stuffer.write_cursor - data_bytes_to_take;
POSIX_GUARD(s2n_hmac_update(mac, orig_write_ptr, data_bytes_to_take));

/* Write the digest */
uint8_t *digest = s2n_stuffer_raw_write(&record_stuffer, mac_digest_size);
POSIX_ENSURE_REF(digest);
/* Write the MAC */
struct s2n_blob header_blob = { 0 };
POSIX_GUARD(s2n_blob_slice(&record_blob, &header_blob, 0, S2N_TLS_RECORD_HEADER_LENGTH));
struct s2n_blob plaintext_blob = { 0 };
POSIX_GUARD(s2n_blob_init(&plaintext_blob, orig_write_ptr, data_bytes_to_take));
uint32_t mac_digest_size = 0;
POSIX_GUARD_RESULT(s2n_record_write_mac(conn, &header_blob, &plaintext_blob, &record_stuffer, &mac_digest_size));

POSIX_GUARD(s2n_hmac_digest(mac, digest, mac_digest_size));
POSIX_GUARD(s2n_hmac_reset(mac));
/* We are done with this sequence number, so we can increment it */
struct s2n_blob seq = { 0 };
POSIX_GUARD(s2n_blob_init(&seq, sequence_number, S2N_TLS_SEQUENCE_NUM_LEN));
POSIX_GUARD(s2n_increment_sequence_number(&seq));

/* Write content type for TLS 1.3 record (RFC 8446 Section 5.2) */
if (is_tls13_record) {
Expand Down

0 comments on commit f28ed07

Please sign in to comment.