Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssl: Add an SslCtxCb hook to Ssl::HandshakerFactory. #15179

Merged
merged 13 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/envoy/ssl/context_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ class ContextConfig {
* @return the set of capabilities for handshaker instances created by this context.
*/
virtual HandshakerCapabilities capabilities() const PURE;

/**
* @return a callback for configuring an SSL_CTX before use.
*/
virtual SslCtxCb sslctxCb() const PURE;
};

class ClientContextConfig : public virtual ContextConfig {
Expand Down
11 changes: 11 additions & 0 deletions include/envoy/ssl/handshaker.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ using HandshakerSharedPtr = std::shared_ptr<Handshaker>;
using HandshakerFactoryCb =
std::function<HandshakerSharedPtr(bssl::UniquePtr<SSL>, int, HandshakeCallbacks*)>;

// Callback for modifying an SSL_CTX.
using SslCtxCb = std::function<void(SSL_CTX*)>;

class HandshakerFactoryContext {
public:
virtual ~HandshakerFactoryContext() = default;
Expand Down Expand Up @@ -115,6 +118,14 @@ class HandshakerFactory : public Config::TypedFactory {
* capability.
*/
virtual HandshakerCapabilities capabilities() const PURE;

/**
* Implementations should return a callback for configuring an SSL_CTX context
* before it is used to create any SSL objects. Providing
* |handshaker_factory_context| as an argument allows callsites to access the
* API and other factory context methods.
*/
virtual SslCtxCb sslctxCb(HandshakerFactoryContext& handshaker_factory_context) const PURE;
};

} // namespace Ssl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ ContextConfigImpl::ContextConfigImpl(
factory_context.messageValidationVisitor());
}
capabilities_ = handshaker_factory->capabilities();
sslctx_cb_ = handshaker_factory->sslctxCb(handshaker_factory_context);
}

Ssl::CertificateValidationContextConfigPtr ContextConfigImpl::getCombinedValidationContextConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig {
void setSecretUpdateCallback(std::function<void()> callback) override;
Ssl::HandshakerFactoryCb createHandshaker() const override;
Ssl::HandshakerCapabilities capabilities() const override { return capabilities_; }
Ssl::SslCtxCb sslctxCb() const override { return sslctx_cb_; }

Ssl::CertificateValidationContextConfigPtr getCombinedValidationContextConfig(
const envoy::extensions::transport_sockets::tls::v3::CertificateValidationContext&
Expand Down Expand Up @@ -99,6 +100,7 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig {

Ssl::HandshakerFactoryCb handshaker_factory_cb_;
Ssl::HandshakerCapabilities capabilities_;
Ssl::SslCtxCb sslctx_cb_;
};

class ClientContextConfigImpl : public ContextConfigImpl, public Envoy::Ssl::ClientContextConfig {
Expand Down
8 changes: 8 additions & 0 deletions source/extensions/transport_sockets/tls/context_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c
}
}

// As late as possible, run the custom SSL_CTX configuration callback on each
// SSL_CTX, if set.
if (auto sslctx_cb = config.sslctxCb(); sslctx_cb) {
for (TlsContext& ctx : tls_contexts_) {
sslctx_cb(ctx.ssl_ctx_.get());
}
}

// Add supported cipher suites from the TLS 1.3 spec:
// https://tools.ietf.org/html/rfc8446#appendix-B.4
// AES-CCM cipher suites are removed (no BoringSSL support).
Expand Down
5 changes: 5 additions & 0 deletions source/extensions/transport_sockets/tls/ssl_handshaker.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ class HandshakerFactoryImpl : public Ssl::HandshakerFactory {
return Ssl::HandshakerCapabilities{};
}

Ssl::SslCtxCb sslctxCb(Ssl::HandshakerFactoryContext&) const override {
// The default handshaker impl doesn't additionally modify SSL_CTX.
return nullptr;
}

static HandshakerFactory* getDefaultHandshakerFactory() {
static HandshakerFactoryImpl default_handshaker_factory;
return &default_handshaker_factory;
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/transport_sockets/tls/ssl_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class SslSocket : public Network::TransportSocket,
void onFailure() override;
Network::TransportSocketCallbacks* transportSocketCallbacks() override { return callbacks_; }

SSL* rawSslForTest() const { return rawSsl(); }

protected:
SSL* rawSsl() const { return info_->ssl_.get(); }

Expand Down
20 changes: 20 additions & 0 deletions test/extensions/transport_sockets/tls/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,26 @@ envoy_cc_test(
],
)

envoy_cc_test(
name = "handshaker_factory_test",
srcs = ["handshaker_factory_test.cc"],
deps = [
"//source/common/stream_info:stream_info_lib",
"//source/extensions/transport_sockets/tls:context_config_lib",
"//source/extensions/transport_sockets/tls:context_lib",
"//source/extensions/transport_sockets/tls:ssl_handshaker_lib",
"//source/extensions/transport_sockets/tls:ssl_socket_lib",
"//source/server:process_context_lib",
"//test/mocks/buffer:buffer_mocks",
"//test/mocks/network:network_mocks",
"//test/mocks/runtime:runtime_mocks",
"//test/mocks/server:server_mocks",
"//test/mocks/ssl:ssl_mocks",
"//test/mocks/stats:stats_mocks",
"//test/test_common:registry_lib",
],
)

envoy_cc_benchmark_binary(
name = "tls_throughput_benchmark",
srcs = ["tls_throughput_benchmark.cc"],
Expand Down
149 changes: 149 additions & 0 deletions test/extensions/transport_sockets/tls/handshaker_factory_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#include <openssl/ssl3.h>

#include "envoy/network/transport_socket.h"
#include "envoy/ssl/handshaker.h"

#include "common/stream_info/stream_info_impl.h"

#include "server/process_context_impl.h"

#include "extensions/transport_sockets/tls/context_config_impl.h"
#include "extensions/transport_sockets/tls/context_manager_impl.h"
#include "extensions/transport_sockets/tls/ssl_handshaker.h"
#include "extensions/transport_sockets/tls/ssl_socket.h"

#include "test/mocks/network/connection.h"
#include "test/mocks/server/transport_socket_factory_context.h"
#include "test/test_common/registry.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "openssl/evp.h"
#include "openssl/hmac.h"
#include "openssl/ssl.h"

namespace Envoy {
namespace Extensions {
namespace TransportSockets {
namespace Tls {
namespace {

// Test-only custom process object which accepts an `SslCtxCb` for in-test SSL_CTX
// manipulation.
class CustomProcessObjectForTest : public ProcessObject {
public:
CustomProcessObjectForTest(Ssl::SslCtxCb cb) : cb_(cb) {}

Ssl::SslCtxCb getSslCtxCb() { return cb_; }

static CustomProcessObjectForTest* get(const ProcessContextOptRef& process_context_opt_ref) {
auto& process_context = process_context_opt_ref.value().get();
auto& process_object = dynamic_cast<CustomProcessObjectForTest&>(process_context.get());
return &process_object;
}

private:
Ssl::SslCtxCb cb_;
};

// Example SslHandshakerFactoryImpl demonstrating special-case behavior; in this
// case, using a process context to modify the SSL_CTX.
class HandshakerFactoryImplForTest
: public Extensions::TransportSockets::Tls::HandshakerFactoryImpl {
std::string name() const override { return "envoy.testonly_handshaker"; }

Ssl::SslCtxCb sslctxCb(Ssl::HandshakerFactoryContext& handshaker_factory_context) const override {
// Get process object, cast to custom process object, and return custom
// callback.
return CustomProcessObjectForTest::get(handshaker_factory_context.api().processContext())
->getSslCtxCb();
}
};

class HandshakerFactoryTest : public testing::Test {
protected:
HandshakerFactoryTest()
: context_manager_(
std::make_unique<Extensions::TransportSockets::Tls::ContextManagerImpl>(time_system_)),
registered_factory_(handshaker_factory_) {
// UpstreamTlsContext proto expects to use the newly-registered handshaker.
envoy::config::core::v3::TypedExtensionConfig* custom_handshaker_ =
tls_context_.mutable_common_tls_context()->mutable_custom_handshaker();
custom_handshaker_->set_name("envoy.testonly_handshaker");
}

// Helper for downcasting a socket to a test socket so we can examine its
// SSL_CTX.
SSL_CTX* extractSslCtx(Network::TransportSocket* socket) {
SslSocket* ssl_socket = dynamic_cast<SslSocket*>(socket);
SSL* ssl = ssl_socket->rawSslForTest();
return SSL_get_SSL_CTX(ssl);
}

Event::GlobalTimeSystem time_system_;
Stats::IsolatedStoreImpl stats_store_;
std::unique_ptr<Extensions::TransportSockets::Tls::ContextManagerImpl> context_manager_;
HandshakerFactoryImplForTest handshaker_factory_;
Registry::InjectFactory<Ssl::HandshakerFactory> registered_factory_;
envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context_;
};

TEST_F(HandshakerFactoryTest, SetMockFunctionCb) {
testing::MockFunction<void(SSL_CTX*)> cb;
EXPECT_CALL(cb, Call);

CustomProcessObjectForTest custom_process_object_for_test(cb.AsStdFunction());
auto process_context_impl = std::make_unique<Envoy::ProcessContextImpl>(
static_cast<Envoy::ProcessObject&>(custom_process_object_for_test));

NiceMock<Server::Configuration::MockTransportSocketFactoryContext> mock_factory_ctx;
EXPECT_CALL(mock_factory_ctx.api_, processContext())
.WillRepeatedly(
testing::Return(std::reference_wrapper<Envoy::ProcessContext>(*process_context_impl)));

Extensions::TransportSockets::Tls::ClientSslSocketFactory socket_factory(
/*config=*/
std::make_unique<Extensions::TransportSockets::Tls::ClientContextConfigImpl>(
tls_context_, "", mock_factory_ctx),
*context_manager_, stats_store_);

std::unique_ptr<Network::TransportSocket> socket = socket_factory.createTransportSocket(nullptr);

SSL_CTX* ssl_ctx = extractSslCtx(socket.get());

// Compare to the next test, where our custom `sslctxcb` reaches in and sets
// this option.
EXPECT_FALSE(SSL_CTX_get_options(ssl_ctx) & SSL_OP_NO_TLSv1);
}

TEST_F(HandshakerFactoryTest, SetSpecificSslCtxOption) {
CustomProcessObjectForTest custom_process_object_for_test(
/*cb=*/[](SSL_CTX* ssl_ctx) { SSL_CTX_set_options(ssl_ctx, SSL_OP_NO_TLSv1); });
auto process_context_impl = std::make_unique<Envoy::ProcessContextImpl>(
static_cast<Envoy::ProcessObject&>(custom_process_object_for_test));

NiceMock<Server::Configuration::MockTransportSocketFactoryContext> mock_factory_ctx;
EXPECT_CALL(mock_factory_ctx.api_, processContext())
.WillRepeatedly(
testing::Return(std::reference_wrapper<Envoy::ProcessContext>(*process_context_impl)));

Extensions::TransportSockets::Tls::ClientSslSocketFactory socket_factory(
/*config=*/
std::make_unique<Extensions::TransportSockets::Tls::ClientContextConfigImpl>(
tls_context_, "", mock_factory_ctx),
*context_manager_, stats_store_);

std::unique_ptr<Network::TransportSocket> socket = socket_factory.createTransportSocket(nullptr);

SSL_CTX* ssl_ctx = extractSslCtx(socket.get());

// Compare to the previous test, where our mock `sslctxcb` is called, but does
// not set this option.
EXPECT_TRUE(SSL_CTX_get_options(ssl_ctx) & SSL_OP_NO_TLSv1);
}

} // namespace
} // namespace Tls
} // namespace TransportSockets
} // namespace Extensions
} // namespace Envoy
2 changes: 2 additions & 0 deletions test/mocks/ssl/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class MockClientContextConfig : public ClientContextConfig {

MOCK_METHOD(Ssl::HandshakerFactoryCb, createHandshaker, (), (const, override));
MOCK_METHOD(Ssl::HandshakerCapabilities, capabilities, (), (const, override));
MOCK_METHOD(Ssl::SslCtxCb, sslctxCb, (), (const, override));

MOCK_METHOD(const std::string&, serverNameIndication, (), (const));
MOCK_METHOD(bool, allowRenegotiation, (), (const));
Expand All @@ -118,6 +119,7 @@ class MockServerContextConfig : public ServerContextConfig {

MOCK_METHOD(Ssl::HandshakerFactoryCb, createHandshaker, (), (const, override));
MOCK_METHOD(Ssl::HandshakerCapabilities, capabilities, (), (const, override));
MOCK_METHOD(Ssl::SslCtxCb, sslctxCb, (), (const, override));

MOCK_METHOD(bool, requireClientCertificate, (), (const));
MOCK_METHOD(OcspStaplePolicy, ocspStaplePolicy, (), (const));
Expand Down