From 190f9e0cfe16e779f622c16dce8e833600e5fb45 Mon Sep 17 00:00:00 2001
From: Marc Barry <4965634+marc-barry@users.noreply.github.com>
Date: Wed, 15 May 2024 17:04:32 -0400
Subject: [PATCH] external authorization: set the SNI value from server name if
 it isn't available on the connection/socket (#34100)

Signed-off-by: Marc Barry <4965634+marc-barry@users.noreply.github.com>
---
 changelogs/current.yaml                       |  5 +++
 .../common/ext_authz/check_request_utils.cc   | 18 +++++---
 .../common/ext_authz/check_request_utils.h    |  2 +-
 .../ext_authz/check_request_utils_test.cc     | 44 ++++++++++++++++---
 4 files changed, 55 insertions(+), 14 deletions(-)

diff --git a/changelogs/current.yaml b/changelogs/current.yaml
index dcd787a0581f..5300deb60298 100644
--- a/changelogs/current.yaml
+++ b/changelogs/current.yaml
@@ -91,6 +91,11 @@ bug_fixes:
   change: |
     Fix BalsaParser resetting state too early, guarded by default-true
     ``envoy.reloadable_features.http1_balsa_delay_reset``.
+- area: ext_authz
+  change: |
+    Set the SNI value from the requested server name if it isn't available on the connection/socket. This applies when
+    ``include_tls_session`` is true. The requested server name is set on a connection when filters such as the TLS
+    inspector are used.
 
 removed_config_or_runtime:
 # *Normally occurs at the end of the* :ref:`deprecation period <deprecated>`
diff --git a/source/extensions/filters/common/ext_authz/check_request_utils.cc b/source/extensions/filters/common/ext_authz/check_request_utils.cc
index 9cb73acd3ed3..cddf4dc9c07d 100644
--- a/source/extensions/filters/common/ext_authz/check_request_utils.cc
+++ b/source/extensions/filters/common/ext_authz/check_request_utils.cc
@@ -223,10 +223,14 @@ void CheckRequestUtils::setAttrContextRequest(
 
 void CheckRequestUtils::setTLSSession(
     envoy::service::auth::v3::AttributeContext::TLSSession& session,
-    const Ssl::ConnectionInfoConstSharedPtr ssl_info) {
-  if (!ssl_info->sni().empty()) {
+    const Envoy::Network::Connection& connection) {
+  const Ssl::ConnectionInfoConstSharedPtr ssl_info = connection.ssl();
+  if (ssl_info != nullptr && !ssl_info->sni().empty()) {
     const std::string server_name(ssl_info->sni());
     session.set_sni(server_name);
+  } else if (!connection.requestedServerName().empty()) {
+    const std::string server_name(connection.requestedServerName());
+    session.set_sni(server_name);
   }
 }
 
@@ -248,6 +252,7 @@ void CheckRequestUtils::createHttpCheck(
   // *cb->connection(), callbacks->streamInfo() and callbacks->decodingBuffer() are not qualified as
   // const.
   auto* cb = const_cast<Envoy::Http::StreamDecoderFilterCallbacks*>(callbacks);
+
   setAttrContextPeer(*attrs->mutable_source(), *cb->connection(), service, false,
                      include_peer_certificate);
   setAttrContextPeer(*attrs->mutable_destination(), *cb->connection(), EMPTY_STRING, true,
@@ -256,8 +261,8 @@ void CheckRequestUtils::createHttpCheck(
                         cb->decodingBuffer(), headers, max_request_bytes, pack_as_bytes,
                         encode_raw_headers, allowed_headers_matcher, disallowed_headers_matcher);
 
-  if (include_tls_session && cb->connection()->ssl() != nullptr) {
-    setTLSSession(*attrs->mutable_tls_session(), cb->connection()->ssl());
+  if (include_tls_session) {
+    setTLSSession(*attrs->mutable_tls_session(), *cb->connection());
   }
   (*attrs->mutable_destination()->mutable_labels()) = destination_labels;
   // Fill in the context extensions and metadata context.
@@ -280,8 +285,9 @@ void CheckRequestUtils::createTcpCheck(
                      include_peer_certificate);
   setAttrContextPeer(*attrs->mutable_destination(), cb->connection(), server_name, true,
                      include_peer_certificate);
-  if (include_tls_session && cb->connection().ssl() != nullptr) {
-    setTLSSession(*attrs->mutable_tls_session(), cb->connection().ssl());
+
+  if (include_tls_session) {
+    setTLSSession(*attrs->mutable_tls_session(), cb->connection());
   }
   (*attrs->mutable_destination()->mutable_labels()) = destination_labels;
 }
diff --git a/source/extensions/filters/common/ext_authz/check_request_utils.h b/source/extensions/filters/common/ext_authz/check_request_utils.h
index 73563363cce5..0dd2f00802e6 100644
--- a/source/extensions/filters/common/ext_authz/check_request_utils.h
+++ b/source/extensions/filters/common/ext_authz/check_request_utils.h
@@ -143,7 +143,7 @@ class CheckRequestUtils {
       bool encode_raw_headers, const MatcherSharedPtr& allowed_headers_matcher,
       const MatcherSharedPtr& disallowed_headers_matcher);
   static void setTLSSession(envoy::service::auth::v3::AttributeContext::TLSSession& session,
-                            const Ssl::ConnectionInfoConstSharedPtr ssl_info);
+                            const Envoy::Network::Connection& connection);
   static std::string getHeaderStr(const Envoy::Http::HeaderEntry* entry);
   static Envoy::Http::HeaderMap::Iterate fillHttpHeaders(const Envoy::Http::HeaderEntry&, void*);
 };
diff --git a/test/extensions/filters/common/ext_authz/check_request_utils_test.cc b/test/extensions/filters/common/ext_authz/check_request_utils_test.cc
index d20861e5856d..7026d9b0a727 100644
--- a/test/extensions/filters/common/ext_authz/check_request_utils_test.cc
+++ b/test/extensions/filters/common/ext_authz/check_request_utils_test.cc
@@ -98,7 +98,11 @@ class CheckRequestUtilsTest : public testing::Test {
 
     EXPECT_EQ(want_tls_session != nullptr, request.attributes().has_tls_session());
     if (want_tls_session != nullptr) {
-      EXPECT_EQ(want_tls_session->sni(), request.attributes().tls_session().sni());
+      if (!want_tls_session->sni().empty()) {
+        EXPECT_EQ(want_tls_session->sni(), request.attributes().tls_session().sni());
+      } else {
+        EXPECT_EQ(requested_server_name_, request.attributes().tls_session().sni());
+      }
     }
   }
 
@@ -223,10 +227,10 @@ TEST_F(CheckRequestUtilsTest, TcpPeerCertificate) {
 // Verify that createTcpCheck populates the tls session details correctly.
 TEST_F(CheckRequestUtilsTest, TcpTlsSession) {
   envoy::service::auth::v3::CheckRequest request;
-  EXPECT_CALL(net_callbacks_, connection()).Times(5).WillRepeatedly(ReturnRef(connection_));
+  EXPECT_CALL(net_callbacks_, connection()).Times(4).WillRepeatedly(ReturnRef(connection_));
   connection_.stream_info_.downstream_connection_info_provider_->setRemoteAddress(addr_);
   connection_.stream_info_.downstream_connection_info_provider_->setLocalAddress(addr_);
-  EXPECT_CALL(Const(connection_), ssl()).Times(4).WillRepeatedly(Return(ssl_));
+  EXPECT_CALL(Const(connection_), ssl()).Times(3).WillRepeatedly(Return(ssl_));
   EXPECT_CALL(*ssl_, uriSanPeerCertificate()).WillOnce(Return(std::vector<std::string>{"source"}));
   EXPECT_CALL(*ssl_, uriSanLocalCertificate())
       .WillOnce(Return(std::vector<std::string>{"destination"}));
@@ -240,6 +244,29 @@ TEST_F(CheckRequestUtilsTest, TcpTlsSession) {
   EXPECT_EQ(want_tls_session.sni(), request.attributes().tls_session().sni());
 }
 
+// Verify that createTcpCheck populates the tls session details correctly from the connection when
+// TLS session information isn't present.
+TEST_F(CheckRequestUtilsTest, TcpTlsSessionNoSessionSni) {
+  envoy::service::auth::v3::CheckRequest request;
+  EXPECT_CALL(net_callbacks_, connection()).Times(4).WillRepeatedly(ReturnRef(connection_));
+  connection_.stream_info_.downstream_connection_info_provider_->setRemoteAddress(addr_);
+  connection_.stream_info_.downstream_connection_info_provider_->setLocalAddress(addr_);
+  EXPECT_CALL(connection_, requestedServerName())
+      .Times(3)
+      .WillRepeatedly(Return(requested_server_name_));
+  EXPECT_CALL(Const(connection_), ssl()).Times(3).WillRepeatedly(Return(ssl_));
+  EXPECT_CALL(*ssl_, uriSanPeerCertificate()).WillOnce(Return(std::vector<std::string>{"source"}));
+  EXPECT_CALL(*ssl_, uriSanLocalCertificate())
+      .WillOnce(Return(std::vector<std::string>{"destination"}));
+  envoy::service::auth::v3::AttributeContext_TLSSession want_tls_session;
+  EXPECT_CALL(*ssl_, sni()).WillOnce(ReturnRef(want_tls_session.sni()));
+
+  CheckRequestUtils::createTcpCheck(&net_callbacks_, request, false, true,
+                                    Protobuf::Map<std::string, std::string>());
+  EXPECT_TRUE(request.attributes().has_tls_session());
+  EXPECT_EQ(requested_server_name_, request.attributes().tls_session().sni());
+}
+
 // Verify that createHttpCheck's dependencies are invoked when it's called.
 // Verify that check request object has no request data.
 // Verify that a client supplied EnvoyAuthPartialBody will not affect the
@@ -691,11 +718,11 @@ TEST_F(CheckRequestUtilsTest, CheckAttrContextPeerCertificate) {
 // Verify that the SNI is populated correctly.
 TEST_F(CheckRequestUtilsTest, CheckAttrContextPeerTLSSession) {
   EXPECT_CALL(callbacks_, connection())
-      .Times(4)
+      .Times(3)
       .WillRepeatedly(Return(OptRef<const Network::Connection>{connection_}));
   connection_.stream_info_.downstream_connection_info_provider_->setRemoteAddress(addr_);
   connection_.stream_info_.downstream_connection_info_provider_->setLocalAddress(addr_);
-  EXPECT_CALL(Const(connection_), ssl()).Times(4).WillRepeatedly(Return(ssl_));
+  EXPECT_CALL(Const(connection_), ssl()).Times(3).WillRepeatedly(Return(ssl_));
   EXPECT_CALL(callbacks_, streamId()).WillOnce(Return(0));
   EXPECT_CALL(callbacks_, decodingBuffer()).WillOnce(Return(buffer_.get()));
   EXPECT_CALL(callbacks_, streamInfo()).WillOnce(ReturnRef(req_info_));
@@ -715,11 +742,14 @@ TEST_F(CheckRequestUtilsTest, CheckAttrContextPeerTLSSession) {
 // Verify that the SNI is populated correctly.
 TEST_F(CheckRequestUtilsTest, CheckAttrContextPeerTLSSessionWithoutSNI) {
   EXPECT_CALL(callbacks_, connection())
-      .Times(4)
+      .Times(3)
       .WillRepeatedly(Return(OptRef<const Network::Connection>{connection_}));
   connection_.stream_info_.downstream_connection_info_provider_->setRemoteAddress(addr_);
   connection_.stream_info_.downstream_connection_info_provider_->setLocalAddress(addr_);
-  EXPECT_CALL(Const(connection_), ssl()).Times(4).WillRepeatedly(Return(ssl_));
+  EXPECT_CALL(Const(connection_), ssl()).Times(3).WillRepeatedly(Return(ssl_));
+  EXPECT_CALL(connection_, requestedServerName())
+      .Times(2)
+      .WillRepeatedly(Return(requested_server_name_));
   EXPECT_CALL(callbacks_, streamId()).WillOnce(Return(0));
   EXPECT_CALL(callbacks_, decodingBuffer()).WillOnce(Return(buffer_.get()));
   EXPECT_CALL(callbacks_, streamInfo()).WillOnce(ReturnRef(req_info_));