Skip to content

Commit

Permalink
Improve TeleportClient host resolution
Browse files Browse the repository at this point in the history
Update tsh to prefer using ResolveSSHTarget to resolve the target
if labels, search, or predicate expressions were supplied. In the
event that the RPC is not implemented, tsh falls back to the
previous behavior of calling ListUnifiedResources, with one slight
addition. In the event that resolution is ambiguous tsh now
consults the cluster networking config to determine the
ROUTE_TO_MOST_RECENT strategy is in place. If so, the returned
server with the expiry the farthest in the future will be used.
  • Loading branch information
rosstimothy committed Dec 13, 2024
1 parent e2c8b10 commit 41254a8
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 42 deletions.
6 changes: 6 additions & 0 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4261,6 +4261,12 @@ func (c *Client) GetSSHTargets(ctx context.Context, req *proto.GetSSHTargetsRequ
return rsp, trace.Wrap(err)
}

// ResolveSSHTarget gets a server that would match an equivalent ssh dial request.
func (c *Client) ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error) {
rsp, err := c.grpc.ResolveSSHTarget(ctx, req)
return rsp, trace.Wrap(err)
}

// CreateSessionTracker creates a tracker resource for an active session.
func (c *Client) CreateSessionTracker(ctx context.Context, st types.SessionTracker) (types.SessionTracker, error) {
v1, ok := st.(*types.SessionTrackerV1)
Expand Down
120 changes: 109 additions & 11 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,7 +1489,7 @@ type TargetNode struct {
func (tc *TeleportClient) GetTargetNodes(ctx context.Context, clt client.ListUnifiedResourcesClient, options SSHOptions) ([]TargetNode, error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/getTargetNodes",
"teleportClient/GetTargetNodes",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()
Expand Down Expand Up @@ -1553,6 +1553,112 @@ func (tc *TeleportClient) GetTargetNodes(ctx context.Context, clt client.ListUni
}, nil
}

// GetTargetNode returns a single host matching the target host provided by users. Host resolution
// honors an explicit host, i.e. tsh ssh user@hostname, label based hosts, i.e. tsh ssh user@foo=bar,
// as well as respecting any proxy templates that are specified.
func (tc *TeleportClient) GetTargetNode(ctx context.Context, clt authclient.ClientI, options *SSHOptions) (*TargetNode, error) {
ctx, span := tc.Tracer.Start(
ctx,
"teleportClient/GetTargetNode",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
)
defer span.End()

if options != nil && options.HostAddress != "" {
return &TargetNode{
Hostname: options.HostAddress,
Addr: options.HostAddress,
}, nil
}

if len(tc.Labels) == 0 && len(tc.SearchKeywords) == 0 && tc.PredicateExpression == "" {
log.Debugf("Using provided host %s", tc.Host)

// detect the common error when users use host:port address format
_, port, err := net.SplitHostPort(tc.Host)
// client has used host:port notation
if err == nil {
return nil, trace.BadParameter("please use ssh subcommand with '--port=%v' flag instead of semicolon", port)
}

addr := net.JoinHostPort(tc.Host, strconv.Itoa(tc.HostPort))
return &TargetNode{
Hostname: tc.Host,
Addr: addr,
}, nil
}

// Query for nodes if labels, fuzzy search, or predicate expressions were provided.
log.Debugf("Attempting to resolve matching host from labels=%v|search=%v|predicate=%v", tc.Labels, tc.SearchKeywords, tc.PredicateExpression)
resp, err := clt.ResolveSSHTarget(ctx, &proto.ResolveSSHTargetRequest{
PredicateExpression: tc.PredicateExpression,
SearchKeywords: tc.SearchKeywords,
Labels: tc.Labels,
})
switch {
//TODO(tross): DELETE IN v20.0.0
case trace.IsNotImplemented(err):
resources, err := client.GetAllUnifiedResources(ctx, clt, &proto.ListUnifiedResourcesRequest{
Kinds: []string{types.KindNode},
SortBy: types.SortBy{Field: types.ResourceMetadataName},
Labels: tc.Labels,
SearchKeywords: tc.SearchKeywords,
PredicateExpression: tc.PredicateExpression,
})
if err != nil {
return nil, trace.Wrap(err)
}

if len(resources) == 0 {
return nil, trace.NotFound("no matching SSH hosts found for search terms or query expression")
}

if len(resources) > 1 {
// If routing does not allow choosing the most recent host, then abort with
// an ambiguous host error.
cnc, err := clt.GetClusterNetworkingConfig(ctx)
if err != nil || cnc.GetRoutingStrategy() != types.RoutingStrategy_MOST_RECENT {
return nil, trace.BadParameter("found multiple matching SSH hosts %v", resources[:2])
}

// Sort the resource by expiry so we can identify the most "recent".
slices.SortFunc(resources, func(a, b *types.EnrichedResource) int {
return a.Expiry().Compare(b.Expiry())
})

}

// Sorting above is oldest expiry to newest expiry, so proceed
// with the last item server in the slice.
server, ok := resources[len(resources)-1].ResourceWithLabels.(types.Server)
if !ok {
return nil, trace.BadParameter("recevied unexpected resource type %T", resources[0].ResourceWithLabels)
}

// Dialing is happening by UUID but a port is still required by
// the Proxy dial request. Zero is an indicator to the Proxy that
// it may chose the appropriate port based on the target server.
return &TargetNode{
Hostname: server.GetHostname(),
Addr: server.GetName() + ":0",
}, nil
case err == nil:
if resp.GetServer() == nil {
return nil, trace.NotFound("no matching SSH hosts found")
}

// Dialing is happening by UUID but a port is still required by
// the Proxy dial request. Zero is an indicator to the Proxy that
// it may chose the appropriate port based on the target server.
return &TargetNode{
Hostname: resp.GetServer().GetHostname(),
Addr: resp.GetServer().GetName() + ":0",
}, nil
default:
return nil, trace.Wrap(err)
}
}

// ReissueUserCerts issues new user certs based on params and stores them in
// the local key agent (usually on disk in ~/.tsh).
func (tc *TeleportClient) ReissueUserCerts(ctx context.Context, cachePolicy CertCachePolicy, params ReissueParams) error {
Expand Down Expand Up @@ -2434,19 +2540,11 @@ func (tc *TeleportClient) SFTP(ctx context.Context, source []string, destination
defer clt.Close()

// Respect any proxy templates and attempt host resolution.
resolvedNodes, err := tc.GetTargetNodes(ctx, clt.AuthClient, SSHOptions{})
target, err := tc.GetTargetNode(ctx, clt.AuthClient, nil)
if err != nil {
return trace.Wrap(err)
}

switch len(resolvedNodes) {
case 1:
case 0:
return trace.NotFound("no matching hosts found")
default:
return trace.BadParameter("multiple matching hosts found")
}

var cfg *sftp.Config
switch {
case isDownload:
Expand All @@ -2469,7 +2567,7 @@ func (tc *TeleportClient) SFTP(ctx context.Context, source []string, destination
}
}

return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, resolvedNodes[0].Addr, cfg))
return trace.Wrap(tc.TransferFiles(ctx, clt, tc.HostLogin, target.Addr, cfg))
}

// TransferFiles copies files between the current machine and the
Expand Down
187 changes: 187 additions & 0 deletions lib/client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/grpc/interceptors"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/cryptosuites"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
Expand Down Expand Up @@ -1380,6 +1381,192 @@ func TestGetTargetNodes(t *testing.T) {
}
}

type fakeGetTargetNodeClient struct {
authclient.ClientI

nodes []*types.ServerV2
resolved *types.ServerV2
resolveErr error
routeToMostRecent bool
}

func (f fakeGetTargetNodeClient) ListUnifiedResources(ctx context.Context, req *proto.ListUnifiedResourcesRequest) (*proto.ListUnifiedResourcesResponse, error) {
out := make([]*proto.PaginatedResource, 0, len(f.nodes))
for _, n := range f.nodes {
out = append(out, &proto.PaginatedResource{Resource: &proto.PaginatedResource_Node{Node: n}})
}

return &proto.ListUnifiedResourcesResponse{Resources: out}, nil
}

func (f fakeGetTargetNodeClient) ResolveSSHTarget(ctx context.Context, req *proto.ResolveSSHTargetRequest) (*proto.ResolveSSHTargetResponse, error) {
if f.resolveErr != nil {
return nil, f.resolveErr
}

return &proto.ResolveSSHTargetResponse{Server: f.resolved}, nil
}

func (f fakeGetTargetNodeClient) GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) {
cfg := types.DefaultClusterNetworkingConfig()
if f.routeToMostRecent {
cfg.SetRoutingStrategy(types.RoutingStrategy_MOST_RECENT)
}

return cfg, nil
}

func TestGetTargetNode(t *testing.T) {
now := time.Now()
then := now.Add(-5 * time.Hour)

tests := []struct {
name string
options *SSHOptions
labels map[string]string
search []string
predicate string
host string
port int
clt fakeGetTargetNodeClient
errAssertion require.ErrorAssertionFunc
expected TargetNode
}{
{
name: "options override",
options: &SSHOptions{
HostAddress: "test:1234",
},
host: "llama",
port: 56789,
errAssertion: require.NoError,
expected: TargetNode{Hostname: "test:1234", Addr: "test:1234"},
},
{
name: "explicit target",
host: "test",
port: 1234,
errAssertion: require.NoError,
expected: TargetNode{Hostname: "test", Addr: "test:1234"},
},
{
name: "resolved labels",
labels: map[string]string{"foo": "bar"},
errAssertion: require.NoError,
expected: TargetNode{Hostname: "resolved-labels", Addr: "abcd:0"},
clt: fakeGetTargetNodeClient{
nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "labels"}}},
resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-labels"}},
},
},
{
name: "fallback labels",
labels: map[string]string{"foo": "bar"},
errAssertion: require.NoError,
expected: TargetNode{Hostname: "labels", Addr: "abcd:0"},
clt: fakeGetTargetNodeClient{
nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "labels"}}},
resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-labels"}},
resolveErr: trace.NotImplemented(""),
},
},
{
name: "resolved search",
search: []string{"foo", "bar"},
errAssertion: require.NoError,
expected: TargetNode{Hostname: "resolved-search", Addr: "abcd:0"},
clt: fakeGetTargetNodeClient{
nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "search"}}},
resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-search"}},
},
},

{
name: "fallback search",
search: []string{"foo", "bar"},
errAssertion: require.NoError,
expected: TargetNode{Hostname: "search", Addr: "abcd:0"},
clt: fakeGetTargetNodeClient{
nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "search"}}},
resolveErr: trace.NotImplemented(""),
resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-search"}},
},
},
{
name: "resolved predicate",
predicate: `resource.spec.hostname == "test"`,
errAssertion: require.NoError,
expected: TargetNode{Hostname: "resolved-predicate", Addr: "abcd:0"},
clt: fakeGetTargetNodeClient{
nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}},
resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-predicate"}},
},
},
{
name: "fallback predicate",
predicate: `resource.spec.hostname == "test"`,
errAssertion: require.NoError,
expected: TargetNode{Hostname: "predicate", Addr: "abcd:0"},
clt: fakeGetTargetNodeClient{
nodes: []*types.ServerV2{{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "predicate"}}},
resolveErr: trace.NotImplemented(""),
resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-predicate"}},
},
},
{
name: "fallback ambiguous hosts",
predicate: `resource.spec.hostname == "test"`,
errAssertion: require.Error,
clt: fakeGetTargetNodeClient{
nodes: []*types.ServerV2{
{Metadata: types.Metadata{Name: "abcd-1"}, Spec: types.ServerSpecV2{Hostname: "predicate"}},
{Metadata: types.Metadata{Name: "abcd-2"}, Spec: types.ServerSpecV2{Hostname: "predicate"}},
},
resolveErr: trace.NotImplemented(""),
resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-predicate"}},
},
},
{
name: "fallback and route to recent",
predicate: `resource.spec.hostname == "test"`,
errAssertion: require.NoError,
expected: TargetNode{Hostname: "predicate-now", Addr: "abcd-1:0"},
clt: fakeGetTargetNodeClient{
nodes: []*types.ServerV2{
{Metadata: types.Metadata{Name: "abcd-0", Expires: &then}, Spec: types.ServerSpecV2{Hostname: "predicate-then"}},
{Metadata: types.Metadata{Name: "abcd-1", Expires: &now}, Spec: types.ServerSpecV2{Hostname: "predicate-now"}},
{Metadata: types.Metadata{Name: "abcd-2", Expires: &then}, Spec: types.ServerSpecV2{Hostname: "predicate-then-again"}},
},
resolveErr: trace.NotImplemented(""),
routeToMostRecent: true,
resolved: &types.ServerV2{Metadata: types.Metadata{Name: "abcd"}, Spec: types.ServerSpecV2{Hostname: "resolved-predicate"}},
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
clt := TeleportClient{
Config: Config{
Tracer: tracing.NoopTracer(""),
Labels: test.labels,
SearchKeywords: test.search,
PredicateExpression: test.predicate,
Host: test.host,
HostPort: test.port,
},
}

match, err := clt.GetTargetNode(context.Background(), test.clt, test.options)
test.errAssertion(t, err)
if match == nil {
match = &TargetNode{}
}
require.EqualValues(t, test.expected, *match)
})
}
}

func TestNonRetryableError(t *testing.T) {
orgError := trace.AccessDenied("do not enter")
err := &NonRetryableError{
Expand Down
Loading

0 comments on commit 41254a8

Please sign in to comment.