From 794ca5182160955f2d163a99b6e9cbfbb07d199f Mon Sep 17 00:00:00 2001 From: Przemko Robakowski Date: Thu, 31 Mar 2022 03:52:16 +0200 Subject: [PATCH] Use first available auth server (#11229) (#11598) Currently we use random auth server from the list but if it's unavailable (for example it was restarted but there's still entry in cache, dynamodb backend etc) we return error. This change tries all servers (in random order) and uses first that is available. Closes #10019 (cherry picked from commit 35a9bbc887e877bbcc388658364d71deff0d9794) --- lib/srv/alpnproxy/auth/auth_proxy.go | 23 +++++-- lib/srv/alpnproxy/auth/auth_proxy_test.go | 76 +++++++++++++++++++++++ 2 files changed, 93 insertions(+), 6 deletions(-) create mode 100644 lib/srv/alpnproxy/auth/auth_proxy_test.go diff --git a/lib/srv/alpnproxy/auth/auth_proxy.go b/lib/srv/alpnproxy/auth/auth_proxy.go index ec8740fae261e..f7b5aae4115b3 100644 --- a/lib/srv/alpnproxy/auth/auth_proxy.go +++ b/lib/srv/alpnproxy/auth/auth_proxy.go @@ -18,6 +18,7 @@ package alpnproxyauth import ( "context" + "fmt" "io" "math/rand" "net" @@ -113,13 +114,23 @@ func (s *AuthProxyDialerService) dialLocalAuthServer(ctx context.Context) (net.C if len(authServers) == 0 { return nil, trace.NotFound("empty auth servers list") } - //TODO(smallinksy) Better support for HA. Add dial retry on auth network errors. - authServerIndex := rand.Intn(len(authServers)) - conn, err := net.Dial("tcp", authServers[authServerIndex].GetAddr()) - if err != nil { - return nil, trace.Wrap(err) + var errors []error + + // iterate over the addresses in random order + for len(authServers) > 0 { + l := len(authServers) + authServerIndex := rand.Intn(l) + addr := authServers[authServerIndex].GetAddr() + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", addr) + if err == nil { + return conn, nil + } + errors = append(errors, fmt.Errorf("%s: %w", addr, err)) + authServers[authServerIndex] = authServers[l-1] + authServers = authServers[:l-1] } - return conn, nil + return nil, trace.NewAggregate(errors...) } func (s *AuthProxyDialerService) dialRemoteAuthServer(ctx context.Context, clusterName string) (net.Conn, error) { diff --git a/lib/srv/alpnproxy/auth/auth_proxy_test.go b/lib/srv/alpnproxy/auth/auth_proxy_test.go new file mode 100644 index 0000000000000..ff25f26412124 --- /dev/null +++ b/lib/srv/alpnproxy/auth/auth_proxy_test.go @@ -0,0 +1,76 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package alpnproxyauth + +import ( + "context" + "net" + "testing" + "time" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" + "github.com/stretchr/testify/require" +) + +type mockAuthGetter struct { + servers []types.Server +} + +func (m mockAuthGetter) GetClusterName(...services.MarshalOption) (types.ClusterName, error) { + return nil, nil +} + +func (m mockAuthGetter) GetAuthServers() ([]types.Server, error) { + return m.servers, nil +} + +func TestDialLocalAuthServerNoServers(t *testing.T) { + s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: []types.Server{}}) + _, err := s.dialLocalAuthServer(context.Background()) + require.Error(t, err, "dialLocalAuthServer expected to fail") + require.Equal(t, "empty auth servers list", err.Error()) +} + +func TestDialLocalAuthServerNoAvailableServers(t *testing.T) { + server1, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: "invalid:8000"}) + require.NoError(t, err) + s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: []types.Server{server1}}) + _, err = s.dialLocalAuthServer(context.Background()) + require.Error(t, err, "dialLocalAuthServer expected to fail") + require.Contains(t, err.Error(), "invalid:8000:") +} + +func TestDialLocalAuthServerAvailableServers(t *testing.T) { + socket, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer socket.Close() + server, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: socket.Addr().String()}) + require.NoError(t, err) + servers := []types.Server{server} + // multiple invalid servers to minimize chance that we select good one first try + for i := 0; i < 20; i++ { + server, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: "invalid2:8000"}) + require.NoError(t, err) + servers = append(servers, server) + } + s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: servers}) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err = s.dialLocalAuthServer(ctx) + require.NoError(t, err) +}