From 91d122834835f3c4fbcf8812bd4204ab7496d0ad Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Mon, 18 Jul 2022 17:09:35 -0400 Subject: [PATCH] allow tracing support to be probed multiple times --- api/observability/tracing/ssh/channel.go | 2 +- api/observability/tracing/ssh/client.go | 92 +++++-- api/observability/tracing/ssh/client_test.go | 253 +++++++++++++++++++ api/observability/tracing/ssh/ssh.go | 4 +- api/observability/tracing/ssh/ssh_test.go | 29 ++- 5 files changed, 340 insertions(+), 40 deletions(-) create mode 100644 api/observability/tracing/ssh/client_test.go diff --git a/api/observability/tracing/ssh/channel.go b/api/observability/tracing/ssh/channel.go index a7fa3118ce9c7..22691bb05dc51 100644 --- a/api/observability/tracing/ssh/channel.go +++ b/api/observability/tracing/ssh/channel.go @@ -30,7 +30,7 @@ import ( // Channel is a wrapper around ssh.Channel that adds tracing support. type Channel struct { ssh.Channel - tracingSupported bool + tracingSupported tracingCapability opts []tracing.Option } diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index 43e6e512b557a..af11c3e0312fc 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "net" + "sync" "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" @@ -33,11 +34,24 @@ import ( // Client is a wrapper around ssh.Client that adds tracing support. type Client struct { *ssh.Client - opts []tracing.Option - tracingSupported bool - rejectedError error + opts []tracing.Option + + // mu protects capability and rejectedError which may change based + // on the outcome probing the server for tracing capabilities that + // may occur trying to establish a session + mu sync.RWMutex + capability tracingCapability + rejectedError error } +type tracingCapability int + +const ( + tracingUnknown tracingCapability = iota + tracingUnsupported + tracingSupported +) + // NewClient creates a new Client. // // The server being connected to is probed to determine if it supports @@ -46,39 +60,43 @@ type Client struct { // server will be wrapped in an Envelope with tracing context. All Session // and Channel created from the returned Client will honor the clients view // of whether they should provide tracing context. -// -// Note: a channel is used instead of a global request in order prevent blocking -// forever in the event that the connection is rejected. In that case, the server -// doesn't service any global requests and writes the error to the first opened -// channel. func NewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, opts ...tracing.Option) *Client { clt := &Client{ Client: ssh.NewClient(c, chans, reqs), opts: opts, } - // Check if the server supports tracing - ch, _, err := clt.Client.OpenChannel(TracingChannel, nil) + clt.capability, clt.rejectedError = isTracingSupported(clt.Client) + + return clt +} + +// isTracingSupported determines whether the ssh server supports +// tracing payloads by trying to open a TracingChannel. +// +// Note: a channel is used instead of a global request in order prevent blocking +// forever in the event that the connection is rejected. In that case, the server +// doesn't service any global requests and writes the error to the first opened +// channel. +func isTracingSupported(clt *ssh.Client) (tracingCapability, error) { + ch, _, err := clt.OpenChannel(TracingChannel, nil) if err != nil { var openError *ssh.OpenChannelError + // prohibited errors due to locks and session control are expected by callers of NewSession if errors.As(err, &openError) { switch openError.Reason { case ssh.Prohibited: - // prohibited errors due to locks and session control are expected by callers of NewSession - clt.rejectedError = err - default: + return tracingUnknown, err + case ssh.UnknownChannelType: + return tracingUnsupported, nil } - - return clt } - return clt + return tracingUnknown, nil } _ = ch.Close() - clt.tracingSupported = true - - return clt + return tracingSupported, nil } // DialContext initiates a connection to the addr from the remote host. @@ -134,7 +152,11 @@ func (c *Client) SendRequest(ctx context.Context, name string, wantReply bool, p ) defer span.End() - ok, resp, err := c.Client.SendRequest(name, wantReply, wrapPayload(ctx, c.tracingSupported, config.TextMapPropagator, payload)) + c.mu.RLock() + capability := c.capability + c.mu.RUnlock() + + ok, resp, err := c.Client.SendRequest(name, wantReply, wrapPayload(ctx, capability, config.TextMapPropagator, payload)) if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) @@ -164,7 +186,11 @@ func (c *Client) OpenChannel(ctx context.Context, name string, data []byte) (*Ch ) defer span.End() - ch, reqs, err := c.Client.OpenChannel(name, wrapPayload(ctx, c.tracingSupported, config.TextMapPropagator, data)) + c.mu.RLock() + capability := c.capability + c.mu.RUnlock() + + ch, reqs, err := c.Client.OpenChannel(name, wrapPayload(ctx, capability, config.TextMapPropagator, data)) if err != nil { span.SetStatus(codes.Error, err.Error()) span.RecordError(err) @@ -196,15 +222,35 @@ func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) { ) defer span.End() + c.mu.Lock() + // If the TracingChannel was rejected when the client was created, // the connection was prohibited due to a lock or session control. // Callers to NewSession are expecting to receive the reason the session // was rejected, so we need to propagate the rejectedError here. if c.rejectedError != nil { - return nil, trace.Wrap(c.rejectedError) + err := c.rejectedError + c.rejectedError = nil + c.capability = tracingUnknown + c.mu.Unlock() + return nil, trace.Wrap(err) } - session, err := c.Client.NewSession() + // If the tracing capabilities of the server are unknown due to + // prohibited errors from previous attempts to check, we need to + // do another check to see if our connection will be permitted + // this time. + if c.capability == tracingUnknown { + capability, err := isTracingSupported(c.Client) + if err != nil { + c.mu.Unlock() + return nil, trace.Wrap(err) + } + c.capability = capability + } + + c.mu.Unlock() + session, err := c.Client.NewSession() return session, trace.Wrap(err) } diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go new file mode 100644 index 0000000000000..86bdb1a5634b5 --- /dev/null +++ b/api/observability/tracing/ssh/client_test.go @@ -0,0 +1,253 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "context" + "fmt" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestIsTracingSupported(t *testing.T) { + + rejected := &ssh.OpenChannelError{ + Reason: ssh.Prohibited, + Message: "rejected!", + } + + unknown := &ssh.OpenChannelError{ + Reason: ssh.UnknownChannelType, + Message: "unknown!", + } + + cases := []struct { + name string + channelErr *ssh.OpenChannelError + expectedCapability tracingCapability + errAssertion require.ErrorAssertionFunc + }{ + { + name: "rejected", + channelErr: rejected, + expectedCapability: tracingUnknown, + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.Equal(t, rejected.Error(), err.Error()) + }, + }, + { + name: "unknown", + channelErr: unknown, + expectedCapability: tracingUnsupported, + errAssertion: require.NoError, + }, + { + name: "supported", + channelErr: nil, + expectedCapability: tracingSupported, + errAssertion: require.NoError, + }, + { + name: "other error", + channelErr: &ssh.OpenChannelError{ + Reason: ssh.ConnectionFailed, + Message: "", + }, + expectedCapability: tracingUnknown, + errAssertion: require.NoError, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + errChan := make(chan error, 5) + + srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + for { + select { + case <-ctx.Done(): + return + + case ch := <-channels: + if ch == nil { + return + } + + if tt.channelErr != nil { + if err := ch.Reject(tt.channelErr.Reason, tt.channelErr.Message); err != nil { + errChan <- trace.Wrap(err, "failed to reject channel") + } + return + } + + _, _, err := ch.Accept() + if err != nil { + errChan <- trace.Wrap(err, "failed to accept channel") + return + } + } + } + }) + + go srv.Run(errChan) + + conn, chans, reqs := srv.GetClient(t) + client := ssh.NewClient(conn, chans, reqs) + + capabaility, err := isTracingSupported(client) + require.Equal(t, tt.expectedCapability, capabaility) + tt.errAssertion(t, err) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + }) + } +} + +func TestNewSession(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + errChan := make(chan error, 5) + + first := ssh.OpenChannelError{ + Reason: ssh.Prohibited, + Message: "first attempt", + } + + second := ssh.OpenChannelError{ + Reason: ssh.ConnectionFailed, + Message: "second attempt", + } + + srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + for i := 0; ; i++ { + select { + case <-ctx.Done(): + return + + case ch := <-channels: + switch { + case ch == nil: + return + case ch.ChannelType() == "session": + _, _, err := ch.Accept() + if err != nil { + errChan <- trace.Wrap(err, "failed to accept session channel") + return + } + case i == 0: + if err := ch.Reject(first.Reason, first.Message); err != nil { + errChan <- err + return + } + case i == 1: + if err := ch.Reject(second.Reason, second.Message); err != nil { + errChan <- err + return + } + case i > 2: + if _, _, err := ch.Accept(); err != nil { + errChan <- err + return + } + default: + if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %d", i)); err != nil { + errChan <- err + return + } + } + } + } + }) + + go srv.Run(errChan) + + cases := []struct { + name string + assertionFunc func(t *testing.T, clt *Client, session *ssh.Session, err error) + }{ + { + name: "session prohibited", + assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + // creating a new session should return any errors captured when creating the client + // and not actually probe the server + require.Error(t, err) + require.Equal(t, trace.Unwrap(err).Error(), first.Error()) + require.Nil(t, sess) + require.Nil(t, clt.rejectedError) + require.Equal(t, clt.capability, tracingUnknown) + }, + }, + { + name: "other failure to open tracing channel", + assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + // this time through we should probe the server without getting a prohibited error, + // but things still failed, so we shouldn't know the capability + require.NoError(t, err) + require.NotNil(t, sess) + require.NoError(t, clt.rejectedError) + require.Equal(t, clt.capability, tracingUnknown) + require.NoError(t, sess.Close()) + }, + }, + { + name: "active session", + assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + // all is good now, we should have an active session + require.NoError(t, err) + require.NotNil(t, sess) + require.NoError(t, clt.rejectedError) + require.Equal(t, clt.capability, tracingSupported) + require.NoError(t, sess.Close()) + }, + }, + } + + // check tracing status after first capability probe from creating the client + conn, chans, reqs := srv.GetClient(t) + client := NewClient(conn, chans, reqs) + require.Error(t, client.rejectedError) + require.Equal(t, client.rejectedError.Error(), first.Error()) + require.Equal(t, client.capability, tracingUnknown) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + sess, err := client.NewSession(ctx) + tt.assertionFunc(t, client, sess, err) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + }) + } +} diff --git a/api/observability/tracing/ssh/ssh.go b/api/observability/tracing/ssh/ssh.go index a076dbb709342..56590120e09fe 100644 --- a/api/observability/tracing/ssh/ssh.go +++ b/api/observability/tracing/ssh/ssh.go @@ -191,8 +191,8 @@ func createEnvelope(ctx context.Context, propagator propagation.TextMapPropagato // wrapPayload wraps the provided payload within an envelope if tracing is // enabled and there is any tracing information to propagate. Otherwise, the // original payload is returned -func wrapPayload(ctx context.Context, supported bool, propagator propagation.TextMapPropagator, payload []byte) []byte { - if !supported { +func wrapPayload(ctx context.Context, supported tracingCapability, propagator propagation.TextMapPropagator, payload []byte) []byte { + if supported != tracingSupported { return payload } diff --git a/api/observability/tracing/ssh/ssh_test.go b/api/observability/tracing/ssh/ssh_test.go index cb3298a59e1e0..e7ace705dfb31 100644 --- a/api/observability/tracing/ssh/ssh_test.go +++ b/api/observability/tracing/ssh/ssh_test.go @@ -125,7 +125,7 @@ func newServer(t *testing.T, handler func(*ssh.ServerConn, <-chan ssh.NewChannel } type handler struct { - tracingSupported bool + tracingSupported tracingCapability errChan chan error ctx context.Context } @@ -154,7 +154,7 @@ func (h handler) handle(sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs func (h handler) requestHandler(req *ssh.Request) { switch { - case req.Type == TracingChannel && h.tracingSupported: + case req.Type == TracingChannel && h.tracingSupported == tracingSupported: if err := req.Reply(true, nil); err != nil { h.errChan <- err } @@ -168,11 +168,11 @@ func (h handler) requestHandler(req *ssh.Request) { }() switch h.tracingSupported { - case false: + case tracingUnsupported: if subtle.ConstantTimeCompare(req.Payload, []byte(testPayload)) != 1 { h.errChan <- errors.New("payload mismatch") } - case true: + case tracingSupported: var envelope Envelope if err := json.Unmarshal(req.Payload, &envelope); err != nil { h.errChan <- trace.Wrap(err, "failed to unmarshal envelope") @@ -198,11 +198,11 @@ func (h handler) channelHandler(ch ssh.NewChannel) { switch ch.ChannelType() { case TracingChannel: switch h.tracingSupported { - case false: + case tracingUnsupported: if err := ch.Reject(ssh.UnknownChannelType, "unknown channel type"); err != nil { h.errChan <- trace.Wrap(err, "failed to reject channel") } - case true: + case tracingSupported: ch.Accept() return } @@ -267,15 +267,15 @@ func (h handler) subsystemHandler(req *ssh.Request) { func TestClient(t *testing.T) { cases := []struct { name string - tracingSupported bool + tracingSupported tracingCapability }{ { name: "server supports tracing", - tracingSupported: true, + tracingSupported: tracingSupported, }, { name: "server does not support tracing", - tracingSupported: false, + tracingSupported: tracingSupported, }, } @@ -304,7 +304,7 @@ func TestClient(t *testing.T) { tracing.WithTracerProvider(tp), tracing.WithTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})), ) - require.Equal(t, handler.tracingSupported, client.tracingSupported) + require.Equal(t, handler.tracingSupported, client.capability) ctx, span := tp.Tracer("test").Start(context.Background(), "test") t.Cleanup(func() { span.End() }) @@ -356,31 +356,32 @@ func TestWrapPayload(t *testing.T) { cases := []struct { name string ctx context.Context - supported bool + supported tracingCapability propagator propagation.TextMapPropagator payloadAssertion require.ComparisonAssertionFunc }{ { name: "unsupported returns provided payload", ctx: recordingCtx, + supported: tracingUnsupported, payloadAssertion: require.Equal, }, { name: "non-recording spans aren't propagated", - supported: true, + supported: tracingSupported, ctx: nonRecordingCtx, payloadAssertion: require.Equal, }, { name: "empty trace context is not propagated", - supported: true, + supported: tracingSupported, ctx: emptyCtx, payloadAssertion: require.Equal, }, { name: "recording spans are propagated", - supported: true, + supported: tracingSupported, ctx: recordingCtx, propagator: propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}), payloadAssertion: func(t require.TestingT, i interface{}, i2 interface{}, i3 ...interface{}) {