From 533b502e2e8eed9661c42f7dee051cb67b4f73f8 Mon Sep 17 00:00:00 2001 From: Alex Weibel Date: Thu, 18 Jul 2024 06:08:25 -0700 Subject: [PATCH] =?UTF-8?q?Update=20s2n=5Fconnection=5Fget=5Fkem=5Fgroup?= =?UTF-8?q?=5Fname()=20to=20work=20with=20ClientHelloRe=E2=80=A6=20(#4652)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/s2n_tls13_pq_handshake_test.c | 12 ++++++++++++ tls/s2n_connection.c | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/unit/s2n_tls13_pq_handshake_test.c b/tests/unit/s2n_tls13_pq_handshake_test.c index d8301c761de..31450fcf28e 100644 --- a/tests/unit/s2n_tls13_pq_handshake_test.c +++ b/tests/unit/s2n_tls13_pq_handshake_test.c @@ -237,6 +237,12 @@ int s2n_test_tls13_pq_handshake(const struct s2n_security_policy *client_sec_pol POSIX_ENSURE_EQ(expected_kem_group->kem, server_conn->kex_params.server_kem_group_params.kem_params.kem); POSIX_ENSURE_EQ(expected_kem_group->curve, server_conn->kex_params.server_kem_group_params.ecc_params.negotiated_curve); POSIX_ENSURE_EQ(NULL, server_conn->kex_params.server_ecc_evp_params.negotiated_curve); + + /* Ensure s2n_connection_get_kem_group_name() gives the correct answer for both client and server */ + POSIX_ENSURE_EQ(strlen(expected_kem_group->name), strlen(s2n_connection_get_kem_group_name(server_conn))); + POSIX_ENSURE_EQ(memcmp(expected_kem_group->name, s2n_connection_get_kem_group_name(server_conn), strlen(expected_kem_group->name)), 0); + POSIX_ENSURE_EQ(strlen(expected_kem_group->name), strlen(s2n_connection_get_kem_group_name(client_conn))); + POSIX_ENSURE_EQ(memcmp(expected_kem_group->name, s2n_connection_get_kem_group_name(client_conn), strlen(expected_kem_group->name)), 0); } else { POSIX_ENSURE_EQ(NULL, client_conn->kex_params.server_kem_group_params.kem_group); POSIX_ENSURE_EQ(NULL, client_conn->kex_params.server_kem_group_params.kem_params.kem); @@ -247,6 +253,12 @@ int s2n_test_tls13_pq_handshake(const struct s2n_security_policy *client_sec_pol POSIX_ENSURE_EQ(NULL, server_conn->kex_params.server_kem_group_params.kem_params.kem); POSIX_ENSURE_EQ(NULL, server_conn->kex_params.server_kem_group_params.ecc_params.negotiated_curve); POSIX_ENSURE_EQ(expected_curve, server_conn->kex_params.server_ecc_evp_params.negotiated_curve); + + /* Ensure s2n_connection_get_curve() gives the correct answer for both client and server */ + POSIX_ENSURE_EQ(strlen(expected_curve->name), strlen(s2n_connection_get_curve(server_conn))); + POSIX_ENSURE_EQ(memcmp(expected_curve->name, s2n_connection_get_curve(server_conn), strlen(expected_curve->name)), 0); + POSIX_ENSURE_EQ(strlen(expected_curve->name), strlen(s2n_connection_get_curve(client_conn))); + POSIX_ENSURE_EQ(memcmp(expected_curve->name, s2n_connection_get_curve(client_conn), strlen(expected_curve->name)), 0); } /* Verify basic properties of secrets */ diff --git a/tls/s2n_connection.c b/tls/s2n_connection.c index 86323d681dd..fa218f9ba7e 100644 --- a/tls/s2n_connection.c +++ b/tls/s2n_connection.c @@ -970,11 +970,11 @@ const char *s2n_connection_get_kem_group_name(struct s2n_connection *conn) { PTR_ENSURE_REF(conn); - if (conn->actual_protocol_version < S2N_TLS13 || !conn->kex_params.client_kem_group_params.kem_group) { + if (conn->actual_protocol_version < S2N_TLS13 || !conn->kex_params.server_kem_group_params.kem_group) { return "NONE"; } - return conn->kex_params.client_kem_group_params.kem_group->name; + return conn->kex_params.server_kem_group_params.kem_group->name; } static S2N_RESULT s2n_connection_get_client_supported_version(struct s2n_connection *conn,