diff --git a/source/extensions/upstreams/http/udp/upstream_request.cc b/source/extensions/upstreams/http/udp/upstream_request.cc index adb92f2e5763..1fd5de949251 100644 --- a/source/extensions/upstreams/http/udp/upstream_request.cc +++ b/source/extensions/upstreams/http/udp/upstream_request.cc @@ -26,15 +26,34 @@ namespace Http { namespace Udp { void UdpConnPool::newStream(Router::GenericConnectionPoolCallbacks* callbacks) { - Router::UpstreamToDownstream& upstream_to_downstream = callbacks->upstreamToDownstream(); - Network::SocketPtr socket = createSocket(host_); + Envoy::Network::SocketPtr socket = createSocket(host_); + auto source_address_selector = host_->cluster().getUpstreamLocalAddressSelector(); + auto upstream_local_address = source_address_selector->getUpstreamLocalAddress( + host_->address(), /*socket_options=*/nullptr); + if (!Envoy::Network::Socket::applyOptions(upstream_local_address.socket_options_, *socket, + envoy::config::core::v3::SocketOption::STATE_PREBIND)) { + callbacks->onPoolFailure(ConnectionPool::PoolFailureReason::LocalConnectionFailure, + "Failed to apply socket option for UDP upstream", host_); + return; + } + if (upstream_local_address.address_) { + Envoy::Api::SysCallIntResult bind_result = socket->bind(upstream_local_address.address_); + if (bind_result.return_value_ < 0) { + callbacks->onPoolFailure(ConnectionPool::PoolFailureReason::LocalConnectionFailure, + "Failed to bind for UDP upstream", host_); + return; + } + } + const Network::ConnectionInfoProvider& connection_info_provider = socket->connectionInfoProvider(); + Router::UpstreamToDownstream& upstream_to_downstream = callbacks->upstreamToDownstream(); ASSERT(upstream_to_downstream.connection().has_value()); Event::Dispatcher& dispatcher = upstream_to_downstream.connection()->dispatcher(); auto upstream = std::make_unique(&upstream_to_downstream, std::move(socket), host_, dispatcher); StreamInfo::StreamInfoImpl stream_info(dispatcher.timeSource(), nullptr); + callbacks->onPoolReady(std::move(upstream), host_, connection_info_provider, stream_info, {}); } diff --git a/test/extensions/upstreams/http/udp/BUILD b/test/extensions/upstreams/http/udp/BUILD index e94d6b501009..2d757dd91c2f 100644 --- a/test/extensions/upstreams/http/udp/BUILD +++ b/test/extensions/upstreams/http/udp/BUILD @@ -28,6 +28,7 @@ envoy_cc_test( "//test/mocks/upstream:upstream_mocks", "//test/test_common:environment_lib", "//test/test_common:simulated_time_system_lib", + "//test/test_common:threadsafe_singleton_injector_lib", "//test/test_common:utility_lib", ], ) diff --git a/test/extensions/upstreams/http/udp/upstream_request_test.cc b/test/extensions/upstreams/http/udp/upstream_request_test.cc index 3f89801321e2..423b21589589 100644 --- a/test/extensions/upstreams/http/udp/upstream_request_test.cc +++ b/test/extensions/upstreams/http/udp/upstream_request_test.cc @@ -2,6 +2,7 @@ #include "source/common/buffer/buffer_impl.h" #include "source/common/network/address_impl.h" +#include "source/common/network/socket_option_factory.h" #include "source/common/network/utility.h" #include "source/common/router/config_impl.h" #include "source/common/router/router.h" @@ -16,6 +17,7 @@ #include "test/mocks/router/router_filter_interface.h" #include "test/mocks/server/factory_context.h" #include "test/mocks/server/instance.h" +#include "test/test_common/threadsafe_singleton_injector.h" #include "test/test_common/utility.h" #include "gmock/gmock.h" @@ -185,6 +187,68 @@ TEST_F(UdpUpstreamTest, SocketConnectError) { EXPECT_FALSE(udp_upstream_->encodeHeaders(connect_udp_headers_, false).ok()); } +class UdpConnPoolTest : public ::testing::Test { +public: + UdpConnPoolTest() { + ON_CALL(*mock_thread_local_cluster_.lb_.host_, address) + .WillByDefault( + Return(Network::Utility::parseInternetAddressAndPortNoThrow("127.0.0.1:80", false))); + udp_conn_pool_ = std::make_unique(mock_thread_local_cluster_, nullptr); + EXPECT_CALL(*mock_thread_local_cluster_.lb_.host_, address).Times(2); + EXPECT_CALL(*mock_thread_local_cluster_.lb_.host_, cluster); + mock_thread_local_cluster_.lb_.host_->cluster_.source_address_ = + Network::Utility::parseInternetAddressAndPortNoThrow("127.0.0.1:10001", false); + } + +protected: + NiceMock mock_thread_local_cluster_; + std::unique_ptr udp_conn_pool_; + Router::MockGenericConnectionPoolCallbacks mock_callback_; +}; + +TEST_F(UdpConnPoolTest, BindToUpstreamLocalAddress) { + EXPECT_CALL(mock_callback_, upstreamToDownstream); + NiceMock downstream_connection_; + EXPECT_CALL(mock_callback_.upstream_to_downstream_, connection) + .WillRepeatedly( + Return(Envoy::OptRef(downstream_connection_))); + EXPECT_CALL(mock_callback_, onPoolReady); + // Mock syscall to make the bind call succeed. + NiceMock mock_os_sys_calls; + Envoy::TestThreadsafeSingletonInjector os_sys_calls( + &mock_os_sys_calls); + EXPECT_CALL(mock_os_sys_calls, bind).WillOnce(Return(Api::SysCallIntResult{0, 0})); + udp_conn_pool_->newStream(&mock_callback_); +} + +TEST_F(UdpConnPoolTest, ApplySocketOptionsFailure) { + Upstream::UpstreamLocalAddress upstream_local_address = { + mock_thread_local_cluster_.lb_.host_->cluster_.source_address_, + Network::SocketOptionFactory::buildIpFreebindOptions()}; + // Return a socket option to make the setsockopt syscall is called. + EXPECT_CALL(*mock_thread_local_cluster_.lb_.host_->cluster_.upstream_local_address_selector_, + getUpstreamLocalAddressImpl) + .WillOnce(Return(upstream_local_address)); + EXPECT_CALL(mock_callback_, onPoolFailure); + // Mock syscall to make the setsockopt call fail. + NiceMock mock_os_sys_calls; + Envoy::TestThreadsafeSingletonInjector os_sys_calls( + &mock_os_sys_calls); + // Use ON_CALL since the applyOptions call fail without calling the setsockopt_ in Windows. + ON_CALL(mock_os_sys_calls, setsockopt_).WillByDefault(Return(-1)); + udp_conn_pool_->newStream(&mock_callback_); +} + +TEST_F(UdpConnPoolTest, BindFailure) { + EXPECT_CALL(mock_callback_, onPoolFailure); + // Mock syscall to make the bind call fail. + NiceMock mock_os_sys_calls; + Envoy::TestThreadsafeSingletonInjector os_sys_calls( + &mock_os_sys_calls); + EXPECT_CALL(mock_os_sys_calls, bind).WillOnce(Return(Api::SysCallIntResult{-1, 0})); + udp_conn_pool_->newStream(&mock_callback_); +} + } // namespace Udp } // namespace Http } // namespace Upstreams