diff --git a/api/observability/tracing/option.go b/api/observability/tracing/option.go new file mode 100644 index 0000000000000..5b48637f5ab21 --- /dev/null +++ b/api/observability/tracing/option.go @@ -0,0 +1,73 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tracing + +import ( + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + oteltrace "go.opentelemetry.io/otel/trace" +) + +// Option applies an option value for a Config. +type Option interface { + apply(*Config) +} + +// Config stores tracing related properties to customize +// creating Tracers and extracting TraceContext +type Config struct { + TracerProvider oteltrace.TracerProvider + TextMapPropagator propagation.TextMapPropagator +} + +// NewConfig returns a Config configured with all the passed Option. +func NewConfig(opts []Option) *Config { + c := &Config{ + TracerProvider: otel.GetTracerProvider(), + TextMapPropagator: otel.GetTextMapPropagator(), + } + for _, o := range opts { + o.apply(c) + } + return c +} + +type tracerProviderOption struct{ tp oteltrace.TracerProvider } + +func (o tracerProviderOption) apply(c *Config) { + if o.tp != nil { + c.TracerProvider = o.tp + } +} + +// WithTracerProvider returns an Option to use the trace.TracerProvider when +// creating a trace.Tracer. +func WithTracerProvider(tp oteltrace.TracerProvider) Option { + return tracerProviderOption{tp: tp} +} + +type propagatorOption struct{ p propagation.TextMapPropagator } + +func (o propagatorOption) apply(c *Config) { + if o.p != nil { + c.TextMapPropagator = o.p + } +} + +// WithTextMapPropagator returns an Option to use the propagation.TextMapPropagator when extracting +// and injecting trace context. +func WithTextMapPropagator(p propagation.TextMapPropagator) Option { + return propagatorOption{p: p} +} diff --git a/api/observability/tracing/ssh/channel.go b/api/observability/tracing/ssh/channel.go new file mode 100644 index 0000000000000..22691bb05dc51 --- /dev/null +++ b/api/observability/tracing/ssh/channel.go @@ -0,0 +1,106 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "context" + "encoding/json" + "fmt" + + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/observability/tracing" +) + +// Channel is a wrapper around ssh.Channel that adds tracing support. +type Channel struct { + ssh.Channel + tracingSupported tracingCapability + opts []tracing.Option +} + +// NewTraceChannel creates a new Channel. +func NewTraceChannel(ch ssh.Channel, opts ...tracing.Option) *Channel { + return &Channel{ + Channel: ch, + opts: opts, + } +} + +// SendRequest sends a global request, and returns the +// reply. If tracing is enabled, the provided payload +// is wrapped in an Envelope to forward any tracing context. +func (c *Channel) SendRequest(ctx context.Context, name string, wantReply bool, payload []byte) (bool, error) { + config := tracing.NewConfig(c.opts) + tracer := config.TracerProvider.Tracer(instrumentationName) + + ctx, span := tracer.Start( + ctx, + fmt.Sprintf("ssh.ChannelRequest/%s", name), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Channel"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + ok, err := c.Channel.SendRequest(name, wantReply, wrapPayload(ctx, c.tracingSupported, config.TextMapPropagator, payload)) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + span.RecordError(err) + } + + return ok, err +} + +// NewChannel is a wrapper around ssh.NewChannel that allows an +// Envelope to be provided to new channels. +type NewChannel struct { + ssh.NewChannel + Envelope Envelope +} + +// NewTraceNewChannel wraps the ssh.NewChannel in a new NewChannel +// +// The provided ssh.NewChannel will have any Envelope provided +// via ExtraData extracted so that the original payload can be +// provided to callers of NewCh.ExtraData. +func NewTraceNewChannel(nch ssh.NewChannel) *NewChannel { + ch := &NewChannel{ + NewChannel: nch, + } + + data := nch.ExtraData() + + var envelope Envelope + if err := json.Unmarshal(data, &envelope); err == nil { + ch.Envelope = envelope + } else { + ch.Envelope.Payload = data + } + + return ch +} + +// ExtraData returns the arbitrary payload for this channel, as supplied +// by the client. This data is specific to the channel type. +func (n NewChannel) ExtraData() []byte { + return n.Envelope.Payload +} diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go new file mode 100644 index 0000000000000..af11c3e0312fc --- /dev/null +++ b/api/observability/tracing/ssh/client.go @@ -0,0 +1,256 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + + "github.com/gravitational/trace" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/observability/tracing" +) + +// Client is a wrapper around ssh.Client that adds tracing support. +type Client struct { + *ssh.Client + opts []tracing.Option + + // mu protects capability and rejectedError which may change based + // on the outcome probing the server for tracing capabilities that + // may occur trying to establish a session + mu sync.RWMutex + capability tracingCapability + rejectedError error +} + +type tracingCapability int + +const ( + tracingUnknown tracingCapability = iota + tracingUnsupported + tracingSupported +) + +// NewClient creates a new Client. +// +// The server being connected to is probed to determine if it supports +// ssh tracing. This is done by attempting to open a TracingChannel channel. +// If the channel is successfully opened then all payloads delivered to the +// server will be wrapped in an Envelope with tracing context. All Session +// and Channel created from the returned Client will honor the clients view +// of whether they should provide tracing context. +func NewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, opts ...tracing.Option) *Client { + clt := &Client{ + Client: ssh.NewClient(c, chans, reqs), + opts: opts, + } + + clt.capability, clt.rejectedError = isTracingSupported(clt.Client) + + return clt +} + +// isTracingSupported determines whether the ssh server supports +// tracing payloads by trying to open a TracingChannel. +// +// Note: a channel is used instead of a global request in order prevent blocking +// forever in the event that the connection is rejected. In that case, the server +// doesn't service any global requests and writes the error to the first opened +// channel. +func isTracingSupported(clt *ssh.Client) (tracingCapability, error) { + ch, _, err := clt.OpenChannel(TracingChannel, nil) + if err != nil { + var openError *ssh.OpenChannelError + // prohibited errors due to locks and session control are expected by callers of NewSession + if errors.As(err, &openError) { + switch openError.Reason { + case ssh.Prohibited: + return tracingUnknown, err + case ssh.UnknownChannelType: + return tracingUnsupported, nil + } + } + + return tracingUnknown, nil + } + + _ = ch.Close() + return tracingSupported, nil +} + +// DialContext initiates a connection to the addr from the remote host. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { + tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) + + _, span := tracer.Start( + ctx, + "ssh.DialContext", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + append( + peerAttr(c.Conn.RemoteAddr()), + attribute.String("network", n), + attribute.String("address", addr), + semconv.RPCServiceKey.String("ssh.Client"), + semconv.RPCMethodKey.String("Dial"), + semconv.RPCSystemKey.String("ssh"), + )..., + ), + ) + defer span.End() + + conn, err := c.Client.Dial(n, addr) + if err != nil { + return nil, trace.Wrap(err) + } + + return conn, nil +} + +// SendRequest sends a global request, and returns the +// reply. If tracing is enabled, the provided payload +// is wrapped in an Envelope to forward any tracing context. +func (c *Client) SendRequest(ctx context.Context, name string, wantReply bool, payload []byte) (bool, []byte, error) { + config := tracing.NewConfig(c.opts) + tracer := config.TracerProvider.Tracer(instrumentationName) + + ctx, span := tracer.Start( + ctx, + fmt.Sprintf("ssh.GlobalRequest/%s", name), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + append( + peerAttr(c.Conn.RemoteAddr()), + attribute.Bool("want_reply", wantReply), + semconv.RPCServiceKey.String("ssh.Client"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + )..., + ), + ) + defer span.End() + + c.mu.RLock() + capability := c.capability + c.mu.RUnlock() + + ok, resp, err := c.Client.SendRequest(name, wantReply, wrapPayload(ctx, capability, config.TextMapPropagator, payload)) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + span.RecordError(err) + } + + return ok, resp, err +} + +// OpenChannel tries to open a channel. If tracing is enabled, +// the provided payload is wrapped in an Envelope to forward +// any tracing context. +func (c *Client) OpenChannel(ctx context.Context, name string, data []byte) (*Channel, <-chan *ssh.Request, error) { + config := tracing.NewConfig(c.opts) + tracer := config.TracerProvider.Tracer(instrumentationName) + ctx, span := tracer.Start( + ctx, + fmt.Sprintf("ssh.OpenChannel/%s", name), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + append( + peerAttr(c.Conn.RemoteAddr()), + semconv.RPCServiceKey.String("ssh.Client"), + semconv.RPCMethodKey.String("OpenChannel"), + semconv.RPCSystemKey.String("ssh"), + )..., + ), + ) + defer span.End() + + c.mu.RLock() + capability := c.capability + c.mu.RUnlock() + + ch, reqs, err := c.Client.OpenChannel(name, wrapPayload(ctx, capability, config.TextMapPropagator, data)) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + span.RecordError(err) + } + + return &Channel{ + Channel: ch, + opts: c.opts, + }, reqs, err +} + +// NewSession creates a new SSH session that is passed tracing context +// so that spans may be correlated properly over the ssh connection. +func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) { + tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) + + _, span := tracer.Start( + ctx, + "ssh.NewSession", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + append( + peerAttr(c.Conn.RemoteAddr()), + semconv.RPCServiceKey.String("ssh.Client"), + semconv.RPCMethodKey.String("NewSession"), + semconv.RPCSystemKey.String("ssh"), + )..., + ), + ) + defer span.End() + + c.mu.Lock() + + // If the TracingChannel was rejected when the client was created, + // the connection was prohibited due to a lock or session control. + // Callers to NewSession are expecting to receive the reason the session + // was rejected, so we need to propagate the rejectedError here. + if c.rejectedError != nil { + err := c.rejectedError + c.rejectedError = nil + c.capability = tracingUnknown + c.mu.Unlock() + return nil, trace.Wrap(err) + } + + // If the tracing capabilities of the server are unknown due to + // prohibited errors from previous attempts to check, we need to + // do another check to see if our connection will be permitted + // this time. + if c.capability == tracingUnknown { + capability, err := isTracingSupported(c.Client) + if err != nil { + c.mu.Unlock() + return nil, trace.Wrap(err) + } + c.capability = capability + } + + c.mu.Unlock() + + session, err := c.Client.NewSession() + return session, trace.Wrap(err) +} diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go new file mode 100644 index 0000000000000..86bdb1a5634b5 --- /dev/null +++ b/api/observability/tracing/ssh/client_test.go @@ -0,0 +1,253 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "context" + "fmt" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestIsTracingSupported(t *testing.T) { + + rejected := &ssh.OpenChannelError{ + Reason: ssh.Prohibited, + Message: "rejected!", + } + + unknown := &ssh.OpenChannelError{ + Reason: ssh.UnknownChannelType, + Message: "unknown!", + } + + cases := []struct { + name string + channelErr *ssh.OpenChannelError + expectedCapability tracingCapability + errAssertion require.ErrorAssertionFunc + }{ + { + name: "rejected", + channelErr: rejected, + expectedCapability: tracingUnknown, + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.Equal(t, rejected.Error(), err.Error()) + }, + }, + { + name: "unknown", + channelErr: unknown, + expectedCapability: tracingUnsupported, + errAssertion: require.NoError, + }, + { + name: "supported", + channelErr: nil, + expectedCapability: tracingSupported, + errAssertion: require.NoError, + }, + { + name: "other error", + channelErr: &ssh.OpenChannelError{ + Reason: ssh.ConnectionFailed, + Message: "", + }, + expectedCapability: tracingUnknown, + errAssertion: require.NoError, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + errChan := make(chan error, 5) + + srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + for { + select { + case <-ctx.Done(): + return + + case ch := <-channels: + if ch == nil { + return + } + + if tt.channelErr != nil { + if err := ch.Reject(tt.channelErr.Reason, tt.channelErr.Message); err != nil { + errChan <- trace.Wrap(err, "failed to reject channel") + } + return + } + + _, _, err := ch.Accept() + if err != nil { + errChan <- trace.Wrap(err, "failed to accept channel") + return + } + } + } + }) + + go srv.Run(errChan) + + conn, chans, reqs := srv.GetClient(t) + client := ssh.NewClient(conn, chans, reqs) + + capabaility, err := isTracingSupported(client) + require.Equal(t, tt.expectedCapability, capabaility) + tt.errAssertion(t, err) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + }) + } +} + +func TestNewSession(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + errChan := make(chan error, 5) + + first := ssh.OpenChannelError{ + Reason: ssh.Prohibited, + Message: "first attempt", + } + + second := ssh.OpenChannelError{ + Reason: ssh.ConnectionFailed, + Message: "second attempt", + } + + srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + for i := 0; ; i++ { + select { + case <-ctx.Done(): + return + + case ch := <-channels: + switch { + case ch == nil: + return + case ch.ChannelType() == "session": + _, _, err := ch.Accept() + if err != nil { + errChan <- trace.Wrap(err, "failed to accept session channel") + return + } + case i == 0: + if err := ch.Reject(first.Reason, first.Message); err != nil { + errChan <- err + return + } + case i == 1: + if err := ch.Reject(second.Reason, second.Message); err != nil { + errChan <- err + return + } + case i > 2: + if _, _, err := ch.Accept(); err != nil { + errChan <- err + return + } + default: + if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %d", i)); err != nil { + errChan <- err + return + } + } + } + } + }) + + go srv.Run(errChan) + + cases := []struct { + name string + assertionFunc func(t *testing.T, clt *Client, session *ssh.Session, err error) + }{ + { + name: "session prohibited", + assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + // creating a new session should return any errors captured when creating the client + // and not actually probe the server + require.Error(t, err) + require.Equal(t, trace.Unwrap(err).Error(), first.Error()) + require.Nil(t, sess) + require.Nil(t, clt.rejectedError) + require.Equal(t, clt.capability, tracingUnknown) + }, + }, + { + name: "other failure to open tracing channel", + assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + // this time through we should probe the server without getting a prohibited error, + // but things still failed, so we shouldn't know the capability + require.NoError(t, err) + require.NotNil(t, sess) + require.NoError(t, clt.rejectedError) + require.Equal(t, clt.capability, tracingUnknown) + require.NoError(t, sess.Close()) + }, + }, + { + name: "active session", + assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + // all is good now, we should have an active session + require.NoError(t, err) + require.NotNil(t, sess) + require.NoError(t, clt.rejectedError) + require.Equal(t, clt.capability, tracingSupported) + require.NoError(t, sess.Close()) + }, + }, + } + + // check tracing status after first capability probe from creating the client + conn, chans, reqs := srv.GetClient(t) + client := NewClient(conn, chans, reqs) + require.Error(t, client.rejectedError) + require.Equal(t, client.rejectedError.Error(), first.Error()) + require.Equal(t, client.capability, tracingUnknown) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + sess, err := client.NewSession(ctx) + tt.assertionFunc(t, client, sess, err) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + }) + } +} diff --git a/api/observability/tracing/ssh/ssh.go b/api/observability/tracing/ssh/ssh.go index f8d7a1bee0157..56590120e09fe 100644 --- a/api/observability/tracing/ssh/ssh.go +++ b/api/observability/tracing/ssh/ssh.go @@ -21,58 +21,78 @@ import ( "net" "time" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/propagation" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + oteltrace "go.opentelemetry.io/otel/trace" + "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/trace" log "github.com/sirupsen/logrus" - oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" ) const ( // TracingRequest is sent by clients to server to pass along tracing context. TracingRequest = "tracing@goteleport.com" -) -// Client is a wrapper around ssh.Client that adds tracing support. -type Client struct { - *ssh.Client -} + // TracingChannel is a SSH channel used to indicate that servers support tracing. + TracingChannel = "tracing" -// NewClient creates a new Client. -func NewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) *Client { - return &Client{Client: ssh.NewClient(c, chans, reqs)} -} + // instrumentationName is the name of this instrumentation package. + instrumentationName = "otelssh" +) -// NewSession creates a new SSH session that is passed tracing context so that spans may be correlated -// properly over the ssh connection. -func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) { - session, err := c.Client.NewSession() - if err != nil { - return nil, trace.Wrap(err) +// ContextFromRequest extracts any tracing data provided via an Envelope +// in the ssh.Request payload. If the payload contains an Envelope, then +// the context returned will have tracing data populated from the remote +// tracing context and the ssh.Request payload will be replaced with the +// original payload from the client. +func ContextFromRequest(req *ssh.Request, opts ...tracing.Option) context.Context { + ctx := context.Background() + + var envelope Envelope + if err := json.Unmarshal(req.Payload, &envelope); err != nil { + return ctx } - span := oteltrace.SpanFromContext(ctx) - if !span.IsRecording() { - return session, nil - } + ctx = tracing.WithPropagationContext(ctx, envelope.PropagationContext, opts...) + req.Payload = envelope.Payload - traceCtx := tracing.PropagationContextFromContext(ctx) - if len(traceCtx) == 0 { - return session, nil - } + return ctx +} + +// ContextFromNewChannel extracts any tracing data provided via an Envelope +// in the ssh.NewChannel ExtraData. If the ExtraData contains an Envelope, then +// the context returned will have tracing data populated from the remote +// tracing context and the ssh.NewChannel wrapped in a TraceCh so that the +// original ExtraData from the client is exposed instead of the Envelope +// payload. +func ContextFromNewChannel(nch ssh.NewChannel, opts ...tracing.Option) (context.Context, ssh.NewChannel) { + ch := NewTraceNewChannel(nch) + ctx := tracing.WithPropagationContext(context.Background(), ch.Envelope.PropagationContext, opts...) + + return ctx, ch +} - payload, err := json.Marshal(traceCtx) +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) { + dialer := net.Dialer{Timeout: config.Timeout} + conn, err := dialer.DialContext(ctx, network, addr) if err != nil { - return nil, trace.Wrap(err) + return nil, err } - - if _, err := session.SendRequest(TracingRequest, false, payload); err != nil { - return nil, trace.Wrap(err) + c, chans, reqs, err := NewClientConn(ctx, conn, addr, config) + if err != nil { + return nil, err } - - return session, nil + return NewClient(c, chans, reqs), nil } // NewClientConn creates a new SSH client connection that is passed tracing context so that spans may be correlated @@ -119,3 +139,72 @@ func NewClientConnWithDeadline(ctx context.Context, conn net.Conn, addr string, } return NewClient(c, chans, reqs), nil } + +// peerAttr returns attributes about the peer address. +func peerAttr(addr net.Addr) []attribute.KeyValue { + host, port, err := net.SplitHostPort(addr.String()) + if err != nil { + return nil + } + + if host == "" { + host = "127.0.0.1" + } + + return []attribute.KeyValue{ + semconv.NetPeerIPKey.String(host), + semconv.NetPeerPortKey.String(port), + } +} + +// Envelope wraps the payload of all ssh messages with +// tracing context. Any servers that reply to a TracingChannel +// will attempt to parse the Envelope for all received requests and +// ensure that the original payload is provided to the handlers. +type Envelope struct { + PropagationContext tracing.PropagationContext + Payload []byte +} + +// createEnvelope wraps the provided payload with a tracing envelope +// that is used to propagate trace context . +func createEnvelope(ctx context.Context, propagator propagation.TextMapPropagator, payload []byte) Envelope { + envelope := Envelope{ + Payload: payload, + } + + span := oteltrace.SpanFromContext(ctx) + if !span.IsRecording() { + return envelope + } + + traceCtx := tracing.PropagationContextFromContext(ctx, tracing.WithTextMapPropagator(propagator)) + if len(traceCtx) == 0 { + return envelope + } + + envelope.PropagationContext = traceCtx + + return envelope +} + +// wrapPayload wraps the provided payload within an envelope if tracing is +// enabled and there is any tracing information to propagate. Otherwise, the +// original payload is returned +func wrapPayload(ctx context.Context, supported tracingCapability, propagator propagation.TextMapPropagator, payload []byte) []byte { + if supported != tracingSupported { + return payload + } + + envelope := createEnvelope(ctx, propagator, payload) + if len(envelope.PropagationContext) == 0 { + return payload + } + + wrappedPayload, err := json.Marshal(envelope) + if err == nil { + return wrappedPayload + } + + return payload +} diff --git a/api/observability/tracing/ssh/ssh_test.go b/api/observability/tracing/ssh/ssh_test.go new file mode 100644 index 0000000000000..cbc7d2c370c77 --- /dev/null +++ b/api/observability/tracing/ssh/ssh_test.go @@ -0,0 +1,410 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "net" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/propagation" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/observability/tracing" +) + +const testPayload = "test" + +type server struct { + listener net.Listener + config *ssh.ServerConfig + handler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request) + + cSigner ssh.Signer + hSigner ssh.Signer +} + +func (s *server) Run(errC chan error) { + for { + conn, err := s.listener.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + errC <- err + } + return + } + + go func() { + sconn, chans, reqs, err := ssh.NewServerConn(conn, s.config) + if err != nil { + errC <- err + return + } + s.handler(sconn, chans, reqs) + }() + } +} + +func (s *server) Stop() error { + return s.listener.Close() +} + +func generateSigner(t *testing.T) ssh.Signer { + private, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(private), + } + + privatePEM := pem.EncodeToMemory(block) + signer, err := ssh.ParsePrivateKey(privatePEM) + require.NoError(t, err) + + return signer +} + +func (s *server) GetClient(t *testing.T) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request) { + conn, err := net.Dial("tcp", s.listener.Addr().String()) + require.NoError(t, err) + + sconn, nc, r, err := ssh.NewClientConn(conn, "", &ssh.ClientConfig{ + Auth: []ssh.AuthMethod{ssh.PublicKeys(s.cSigner)}, + HostKeyCallback: ssh.FixedHostKey(s.hSigner.PublicKey()), + }) + require.NoError(t, err) + + return sconn, nc, r +} + +func newServer(t *testing.T, handler func(*ssh.ServerConn, <-chan ssh.NewChannel, <-chan *ssh.Request)) *server { + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + + cSigner := generateSigner(t) + hSigner := generateSigner(t) + + config := &ssh.ServerConfig{ + NoClientAuth: true, + } + config.AddHostKey(hSigner) + + srv := &server{ + listener: listener, + config: config, + handler: handler, + cSigner: cSigner, + hSigner: hSigner, + } + + t.Cleanup(func() { require.NoError(t, srv.Stop()) }) + + return srv +} + +type handler struct { + tracingSupported tracingCapability + errChan chan error + ctx context.Context +} + +func (h handler) handle(sconn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { + for { + select { + case <-h.ctx.Done(): + return + case req := <-reqs: + if req == nil { + return + } + + h.requestHandler(req) + + case ch := <-chans: + if ch == nil { + return + } + + h.channelHandler(ch) + } + } +} + +func (h handler) requestHandler(req *ssh.Request) { + switch { + case req.Type == TracingChannel && h.tracingSupported == tracingSupported: + if err := req.Reply(true, nil); err != nil { + h.errChan <- err + } + case req.Type == "test": + defer func() { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + h.errChan <- err + } + } + }() + + switch h.tracingSupported { + case tracingUnsupported: + if subtle.ConstantTimeCompare(req.Payload, []byte(testPayload)) != 1 { + h.errChan <- errors.New("payload mismatch") + } + case tracingSupported: + var envelope Envelope + if err := json.Unmarshal(req.Payload, &envelope); err != nil { + h.errChan <- trace.Wrap(err, "failed to unmarshal envelope") + return + } + if len(envelope.PropagationContext) <= 0 { + h.errChan <- errors.New("empty propagation context") + return + } + if subtle.ConstantTimeCompare(envelope.Payload, []byte(testPayload)) != 1 { + h.errChan <- errors.New("payload mismatch") + return + } + } + default: + if err := req.Reply(false, nil); err != nil { + h.errChan <- err + } + } +} + +func (h handler) channelHandler(ch ssh.NewChannel) { + switch ch.ChannelType() { + case TracingChannel: + switch h.tracingSupported { + case tracingUnsupported: + if err := ch.Reject(ssh.UnknownChannelType, "unknown channel type"); err != nil { + h.errChan <- trace.Wrap(err, "failed to reject channel") + } + case tracingSupported: + ch.Accept() + return + } + case "session": + // TODO(tross): add tracing checks when session tracing lands + if subtle.ConstantTimeCompare(ch.ExtraData(), []byte(testPayload)) == 1 { + h.errChan <- errors.New("payload mismatch") + } + + _, chReqs, err := ch.Accept() + if err != nil { + h.errChan <- trace.Wrap(err, "failed to accept channel") + return + } + + go func() { + for { + select { + case <-h.ctx.Done(): + return + case req := <-chReqs: + switch req.Type { + case "subsystem": + h.subsystemHandler(req) + } + } + } + }() + default: + if err := ch.Reject(ssh.UnknownChannelType, "unknown channel type"); err != nil { + h.errChan <- trace.Wrap(err, "failed to reject channel") + } + } +} + +type subsystemRequestMsg struct { + Subsystem string +} + +func (h handler) subsystemHandler(req *ssh.Request) { + defer func() { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + h.errChan <- err + } + } + }() + + // TODO(tross): add tracing checks when session tracing lands + + var msg subsystemRequestMsg + if err := ssh.Unmarshal(req.Payload, &msg); err != nil { + h.errChan <- trace.Wrap(err, "failed to unmarshal payload") + return + } + + if msg.Subsystem != "test" { + h.errChan <- errors.New("received wrong subsystem") + } +} + +func TestClient(t *testing.T) { + cases := []struct { + name string + tracingSupported tracingCapability + }{ + { + name: "server supports tracing", + tracingSupported: tracingSupported, + }, + { + name: "server does not support tracing", + tracingSupported: tracingSupported, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + errChan := make(chan error, 5) + + handler := handler{ + tracingSupported: tt.tracingSupported, + errChan: errChan, + ctx: ctx, + } + + srv := newServer(t, handler.handle) + go srv.Run(errChan) + + tp := sdktrace.NewTracerProvider() + conn, chans, reqs := srv.GetClient(t) + client := NewClient( + conn, + chans, + reqs, + tracing.WithTracerProvider(tp), + tracing.WithTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})), + ) + require.Equal(t, handler.tracingSupported, client.capability) + + ctx, span := tp.Tracer("test").Start(context.Background(), "test") + ok, resp, err := client.SendRequest(ctx, "test", true, []byte("test")) + span.End() + require.True(t, ok) + require.Empty(t, resp) + require.NoError(t, err) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + + session, err := client.NewSession(ctx) + require.NoError(t, err) + require.NotNil(t, session) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + + require.NoError(t, session.RequestSubsystem("test")) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } + }) + } +} + +func TestWrapPayload(t *testing.T) { + testPayload := []byte("test") + + nonRecordingCtx, nonRecordingSpan := otel.GetTracerProvider().Tracer("non-recording").Start(context.Background(), "test") + nonRecordingSpan.End() + + emptyCtx, emptySpan := sdktrace.NewTracerProvider().Tracer("empty-trace-context").Start(context.Background(), "test") + t.Cleanup(func() { emptySpan.End() }) + + recordingCtx, recordingSpan := sdktrace.NewTracerProvider().Tracer("recording").Start(context.Background(), "test") + t.Cleanup(func() { recordingSpan.End() }) + cases := []struct { + name string + ctx context.Context + supported tracingCapability + propagator propagation.TextMapPropagator + payloadAssertion require.ComparisonAssertionFunc + }{ + { + name: "unsupported returns provided payload", + ctx: recordingCtx, + supported: tracingUnsupported, + payloadAssertion: require.Equal, + }, + { + + name: "non-recording spans aren't propagated", + supported: tracingSupported, + ctx: nonRecordingCtx, + payloadAssertion: require.Equal, + }, + { + name: "empty trace context is not propagated", + supported: tracingSupported, + ctx: emptyCtx, + payloadAssertion: require.Equal, + }, + { + name: "recording spans are propagated", + supported: tracingSupported, + ctx: recordingCtx, + propagator: propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}), + payloadAssertion: func(t require.TestingT, i interface{}, i2 interface{}, i3 ...interface{}) { + payload, ok := i2.([]byte) + require.True(t, ok) + + require.NotEqual(t, testPayload, payload) + + var envelope Envelope + require.NoError(t, json.Unmarshal(payload, &envelope)) + require.Equal(t, testPayload, envelope.Payload) + require.NotEmpty(t, envelope.PropagationContext) + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + if tt.propagator == nil { + tt.propagator = otel.GetTextMapPropagator() + } + payload := wrapPayload(tt.ctx, tt.supported, tt.propagator, testPayload) + tt.payloadAssertion(t, testPayload, payload) + }) + } +} diff --git a/api/observability/tracing/tracing.go b/api/observability/tracing/tracing.go index e9ba4a674cf93..ceacbb0e9e5bc 100644 --- a/api/observability/tracing/tracing.go +++ b/api/observability/tracing/tracing.go @@ -19,6 +19,7 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" + oteltrace "go.opentelemetry.io/otel/trace" ) // PropagationContext contains tracing information to be passed across service boundaries @@ -26,14 +27,19 @@ type PropagationContext map[string]string // PropagationContextFromContext creates a PropagationContext from the given context.Context. If the context // does not contain any tracing information, the PropagationContext will be empty. -func PropagationContextFromContext(ctx context.Context) PropagationContext { +func PropagationContextFromContext(ctx context.Context, opts ...Option) PropagationContext { carrier := propagation.MapCarrier{} - otel.GetTextMapPropagator().Inject(ctx, &carrier) + NewConfig(opts).TextMapPropagator.Inject(ctx, &carrier) return PropagationContext(carrier) } // WithPropagationContext injects any tracing information from the given PropagationContext into the // given context.Context. -func WithPropagationContext(ctx context.Context, pc PropagationContext) context.Context { - return otel.GetTextMapPropagator().Extract(ctx, propagation.MapCarrier(pc)) +func WithPropagationContext(ctx context.Context, pc PropagationContext, opts ...Option) context.Context { + return NewConfig(opts).TextMapPropagator.Extract(ctx, propagation.MapCarrier(pc)) +} + +// DefaultProvider returns the global default TracerProvider. +func DefaultProvider() oteltrace.TracerProvider { + return otel.GetTracerProvider() } diff --git a/api/utils/sshutils/chconn.go b/api/utils/sshutils/chconn.go index ecd647b63ae6d..0de68696a5f0b 100644 --- a/api/utils/sshutils/chconn.go +++ b/api/utils/sshutils/chconn.go @@ -28,19 +28,27 @@ import ( "golang.org/x/crypto/ssh" ) +type Conn interface { + io.Closer + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + // NewChConn returns a new net.Conn implemented over // SSH channel -func NewChConn(conn ssh.Conn, ch ssh.Channel) *ChConn { +func NewChConn(conn Conn, ch ssh.Channel) *ChConn { return newChConn(conn, ch, false) } // NewExclusiveChConn returns a new net.Conn implemented over // SSH channel, whenever this connection closes -func NewExclusiveChConn(conn ssh.Conn, ch ssh.Channel) *ChConn { +func NewExclusiveChConn(conn Conn, ch ssh.Channel) *ChConn { return newChConn(conn, ch, true) } -func newChConn(conn ssh.Conn, ch ssh.Channel, exclusive bool) *ChConn { +func newChConn(conn Conn, ch ssh.Channel, exclusive bool) *ChConn { reader, writer := net.Pipe() c := &ChConn{ Channel: ch, @@ -68,7 +76,7 @@ type ChConn struct { mu sync.Mutex ssh.Channel - conn ssh.Conn + conn Conn // exclusive indicates that whenever this channel connection // is getting closed, the underlying connection is closed as well exclusive bool diff --git a/lib/client/client.go b/lib/client/client.go index c1758000112a9..f2b85dc0b1ed4 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1341,7 +1341,7 @@ type proxyResponse struct { // isRecordingProxy returns true if the proxy is in recording mode. Note, this // function can only be called after authentication has occurred and should be // called before the first session is created. -func (proxy *ProxyClient) isRecordingProxy() (bool, error) { +func (proxy *ProxyClient) isRecordingProxy(ctx context.Context) (bool, error) { responseCh := make(chan proxyResponse) // we have to run this in a goroutine because older version of Teleport handled @@ -1350,7 +1350,7 @@ func (proxy *ProxyClient) isRecordingProxy() (bool, error) { // don't hear anything back, most likley we are trying to connect to an older // version of Teleport and we should not try and forward our agent. go func() { - ok, responseBytes, err := proxy.Client.SendRequest(teleport.RecordingProxyReqType, true, nil) + ok, responseBytes, err := proxy.Client.SendRequest(ctx, teleport.RecordingProxyReqType, true, nil) if err != nil { responseCh <- proxyResponse{isRecord: false, err: trace.Wrap(err)} return @@ -1533,7 +1533,7 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeAdd // after auth but before we create the first session, find out if the proxy // is in recording mode or not - recordingProxy, err := proxy.isRecordingProxy() + recordingProxy, err := proxy.isRecordingProxy(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -1663,7 +1663,7 @@ func (proxy *ProxyClient) PortForwardToNode(ctx context.Context, nodeAddress Nod // after auth but before we create the first session, find out if the proxy // is in recording mode or not - recordingProxy, err := proxy.isRecordingProxy() + recordingProxy, err := proxy.isRecordingProxy(ctx) if err != nil { return nil, trace.Wrap(err) } @@ -2085,7 +2085,7 @@ func (c *NodeClient) GetRemoteTerminalSize(ctx context.Context, sessionID string ) defer span.End() - ok, payload, err := c.Client.SendRequest(teleport.TerminalSizeRequest, true, []byte(sessionID)) + ok, payload, err := c.Client.SendRequest(ctx, teleport.TerminalSizeRequest, true, []byte(sessionID)) if err != nil { return nil, trace.Wrap(err) } else if !ok { diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 38c55c1836c8a..628de502d65be 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -23,11 +23,14 @@ package reversetunnel import ( "context" "fmt" + "io" "strings" "sync" "time" "github.com/gravitational/teleport/api/constants" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/utils" @@ -58,7 +61,7 @@ type AgentStateCallback func(AgentState) // transporter handles the creation of new transports over ssh. type transporter interface { // Transport creates a new transport. - transport(context.Context, ssh.Channel, <-chan *ssh.Request, ssh.Conn) *transport + transport(context.Context, ssh.Channel, <-chan *ssh.Request, sshutils.Conn) *transport } // sshDialer is an ssh dialer that returns an SSHClient @@ -74,7 +77,11 @@ type versionGetter interface { // SSHClient is a client for an ssh connection. type SSHClient interface { - ssh.Conn + ssh.ConnMetadata + io.Closer + Wait() error + OpenChannel(ctx context.Context, name string, data []byte) (*tracessh.Channel, <-chan *ssh.Request, error) + SendRequest(ctx context.Context, name string, wantReply bool, payload []byte) (bool, []byte, error) Principals() []string GlobalRequests() <-chan *ssh.Request HandleChannelOpen(channelType string) <-chan ssh.NewChannel @@ -156,7 +163,7 @@ type agent struct { // an agent is connecting and protects wait groups from being waited on early. doneConnecting chan struct{} // hbChannel is the channel heartbeats are sent over. - hbChannel ssh.Channel + hbChannel *tracessh.Channel // hbRequests are requests going over the heartbeat channel. hbRequests <-chan *ssh.Request // discoveryC receives new discovery channels. @@ -388,7 +395,7 @@ func (a *agent) connect() error { // sendFirstHeartbeat opens the heartbeat channel and sends the first // heartbeat. func (a *agent) sendFirstHeartbeat(ctx context.Context) error { - channel, requests, err := a.client.OpenChannel(chanHeartbeat, nil) + channel, requests, err := a.client.OpenChannel(ctx, chanHeartbeat, nil) if err != nil { return trace.Wrap(err) } @@ -397,7 +404,7 @@ func (a *agent) sendFirstHeartbeat(ctx context.Context) error { a.hbRequests = requests // Send the first ping right away. - if _, err := a.hbChannel.SendRequest("ping", false, nil); err != nil { + if _, err := a.hbChannel.SendRequest(ctx, "ping", false, nil); err != nil { return trace.Wrap(err) } @@ -540,7 +547,7 @@ func (a *agent) handleDrainChannels() error { continue } bytes, _ := a.clock.Now().UTC().MarshalText() - _, err := a.hbChannel.SendRequest("ping", false, bytes) + _, err := a.hbChannel.SendRequest(a.ctx, "ping", false, bytes) if err != nil { a.log.Error(err) return trace.Wrap(err) diff --git a/lib/reversetunnel/agent_test.go b/lib/reversetunnel/agent_test.go index 37916a963d7f5..4ec54ef890852 100644 --- a/lib/reversetunnel/agent_test.go +++ b/lib/reversetunnel/agent_test.go @@ -26,6 +26,8 @@ import ( "time" "github.com/gravitational/teleport" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/reversetunnel/track" "github.com/gravitational/teleport/lib/utils" @@ -36,8 +38,8 @@ import ( ) type mockSSHClient struct { - MockSendRequest func(name string, wantReply bool, payload []byte) (bool, []byte, error) - MockOpenChannel func(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) + MockSendRequest func(ctx context.Context, name string, wantReply bool, payload []byte) (bool, []byte, error) + MockOpenChannel func(ctx context.Context, name string, data []byte) (*tracessh.Channel, <-chan *ssh.Request, error) MockClose func() error MockReply func(*ssh.Request, bool, []byte) error MockPrincipals []string @@ -57,16 +59,16 @@ func (m *mockSSHClient) RemoteAddr() net.Addr { return nil } func (m *mockSSHClient) LocalAddr() net.Addr { return nil } -func (m *mockSSHClient) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { +func (m *mockSSHClient) SendRequest(ctx context.Context, name string, wantReply bool, payload []byte) (bool, []byte, error) { if m.MockSendRequest != nil { - return m.MockSendRequest(name, wantReply, payload) + return m.MockSendRequest(ctx, name, wantReply, payload) } return false, nil, trace.NotImplemented("") } -func (m *mockSSHClient) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { +func (m *mockSSHClient) OpenChannel(ctx context.Context, name string, data []byte) (*tracessh.Channel, <-chan *ssh.Request, error) { if m.MockOpenChannel != nil { - return m.MockOpenChannel(name, data) + return m.MockOpenChannel(ctx, name, data) } return nil, nil, trace.NotImplemented("") } @@ -135,7 +137,7 @@ type mockAgentInjection struct { client SSHClient } -func (m *mockAgentInjection) transport(context.Context, ssh.Channel, <-chan *ssh.Request, ssh.Conn) *transport { +func (m *mockAgentInjection) transport(context.Context, ssh.Channel, <-chan *ssh.Request, sshutils.Conn) *transport { return &transport{} } @@ -272,20 +274,26 @@ func TestAgentStart(t *testing.T) { } }() - client.MockOpenChannel = func(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { + client.MockOpenChannel = func(ctx context.Context, name string, data []byte) (*tracessh.Channel, <-chan *ssh.Request, error) { // Block until the version request is handled to ensure we handle // global requests during startup. <-waitForVersion atomic.AddInt32(openChannels, 1) assert.Equal(t, name, chanHeartbeat, "Unexpected channel opened during startup.") - return &mockSSHChannel{MockSendRequest: func(name string, wantReply bool, payload []byte) (bool, error) { - atomic.AddInt32(sentPings, 1) - - assert.Equal(t, name, "ping", "Unexpected request name.") - assert.False(t, wantReply, "Expected no reply wanted.") - return true, nil - }}, make(<-chan *ssh.Request), nil + return tracessh.NewTraceChannel( + &mockSSHChannel{ + MockSendRequest: func(name string, wantReply bool, payload []byte) (bool, error) { + atomic.AddInt32(sentPings, 1) + + assert.Equal(t, name, "ping", "Unexpected request name.") + assert.False(t, wantReply, "Expected no reply wanted.") + return true, nil + }, + }, + ), + make(<-chan *ssh.Request), + nil } client.MockReply = func(r *ssh.Request, b1 bool, b2 []byte) error { diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 381aea95d2466..3aaf0308a925c 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/reversetunnel/track" @@ -543,7 +544,7 @@ func (p *AgentPool) getVersion(ctx context.Context) (string, error) { } // transport creates a new transport instance. -func (p *AgentPool) transport(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request, conn ssh.Conn) *transport { +func (p *AgentPool) transport(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request, conn sshutils.Conn) *transport { return &transport{ closeContext: ctx, component: p.Component, diff --git a/lib/reversetunnel/transport.go b/lib/reversetunnel/transport.go index 0768e49a5a436..5fbd6ab43c024 100644 --- a/lib/reversetunnel/transport.go +++ b/lib/reversetunnel/transport.go @@ -155,7 +155,7 @@ type transport struct { kubeDialAddr utils.NetAddr // sconn is a SSH connection to the remote host. Used for dial back nodes. - sconn ssh.Conn + sconn sshutils.Conn // reverseTunnelServer holds all reverse tunnel connections. reverseTunnelServer Server diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 7a25df55bf295..be1dcb95eecf7 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -26,10 +26,9 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" - "github.com/gravitational/trace" - "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" @@ -47,8 +46,11 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/google/uuid" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + oteltrace "go.opentelemetry.io/otel/trace" ) // Server is a forwarding server. Server is used to create a single in-memory @@ -157,6 +159,10 @@ type Server struct { // lockWatcher is the server's lock watcher. lockWatcher *services.LockWatcher + + // tracerProvider is used to create tracers capable + // of starting spans. + tracerProvider oteltrace.TracerProvider } // ServerConfig is the configuration needed to create an instance of a Server. @@ -209,6 +215,10 @@ type ServerConfig struct { // LockWatcher is a lock watcher. LockWatcher *services.LockWatcher + + // TracerProvider is used to create tracers capable + // of starting spans. + TracerProvider oteltrace.TracerProvider } // CheckDefaults makes sure all required parameters are passed in. @@ -246,6 +256,9 @@ func (s *ServerConfig) CheckDefaults() error { if s.LockWatcher == nil { return trace.BadParameter("missing parameter LockWatcher") } + if s.TracerProvider == nil { + s.TracerProvider = tracing.DefaultProvider() + } return nil } @@ -290,6 +303,7 @@ func New(c ServerConfig) (*Server, error) { StreamEmitter: c.Emitter, parentContext: c.ParentContext, lockWatcher: c.LockWatcher, + tracerProvider: c.TracerProvider, } // Set the ciphers, KEX, and MACs that the in-memory server will send to the @@ -538,7 +552,7 @@ func (s *Server) Serve() { go srv.StartKeepAliveLoop(srv.KeepAliveParams{ Conns: []srv.RequestSender{ s.sconn, - s.remoteClient, + s.remoteClient.Client, }, Interval: netConfig.GetKeepAliveInterval(), MaxCount: netConfig.GetKeepAliveCountMax(), @@ -624,17 +638,47 @@ func (s *Server) handleConnection(ctx context.Context, chans <-chan ssh.NewChann for { select { // Global out-of-band requests. - case newRequest := <-reqs: - if newRequest == nil { + case req := <-reqs: + if req == nil { return } - go s.handleGlobalRequest(newRequest) + reqCtx := tracessh.ContextFromRequest(req) + ctx, span := s.tracerProvider.Tracer("ssh").Start( + oteltrace.ContextWithRemoteSpanContext(ctx, oteltrace.SpanContextFromContext(reqCtx)), + fmt.Sprintf("ssh.Forward.GlobalRequest/%s", req.Type), + oteltrace.WithSpanKind(oteltrace.SpanKindServer), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.ForwardServer"), + semconv.RPCMethodKey.String("GlobalRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + + go func(span oteltrace.Span) { + defer span.End() + s.handleGlobalRequest(ctx, req) + }(span) // Channel requests. - case newChannel := <-chans: - if newChannel == nil { + case nch := <-chans: + if nch == nil { return } - go s.handleChannel(ctx, newChannel) + chanCtx, nch := tracessh.ContextFromNewChannel(nch) + ctx, span := s.tracerProvider.Tracer("ssh").Start( + oteltrace.ContextWithRemoteSpanContext(ctx, oteltrace.SpanContextFromContext(chanCtx)), + fmt.Sprintf("ssh.Forward.OpenChannel/%s", nch.ChannelType()), + oteltrace.WithSpanKind(oteltrace.SpanKindServer), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.ForwardServer"), + semconv.RPCMethodKey.String("OpenChannel"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + + go func(span oteltrace.Span) { + defer span.End() + s.handleChannel(ctx, nch) + }(span) // If the server is closing (either the heartbeat failed or Close() was // called, exit out of the connection handler loop. case <-ctx.Done(): @@ -653,7 +697,7 @@ func (s *Server) rejectChannel(chans <-chan ssh.NewChannel, errMessage string) { } } -func (s *Server) handleGlobalRequest(req *ssh.Request) { +func (s *Server) handleGlobalRequest(ctx context.Context, req *ssh.Request) { // Version requests are internal Teleport requests, they should not be // forwarded to the remote server. if req.Type == teleport.VersionRequest { @@ -664,7 +708,7 @@ func (s *Server) handleGlobalRequest(req *ssh.Request) { return } - ok, payload, err := s.remoteClient.SendRequest(req.Type, req.WantReply, req.Payload) + ok, payload, err := s.remoteClient.SendRequest(ctx, req.Type, req.WantReply, req.Payload) if err != nil { s.log.Warnf("Failed to forward global request %v: %v", req.Type, err) return @@ -682,6 +726,23 @@ func (s *Server) handleChannel(ctx context.Context, nch ssh.NewChannel) { channelType := nch.ChannelType() switch channelType { + // Channels of type "tracing-request" are sent to determine if ssh tracing envelopes + // are supported. Accepting the channel indicates to clients that they may wrap their + // ssh payload with tracing context. + case tracessh.TracingChannel: + ch, _, err := nch.Accept() + if err != nil { + s.log.Warnf("Unable to accept channel: %v", err) + if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)); err != nil { + s.log.Warnf("Failed to reject channel: %v", err) + } + return + } + + if err := ch.Close(); err != nil { + s.log.Warnf("Unable to close %q channel: %v", nch.ChannelType(), err) + } + return // Channels of type "session" handle requests that are involved in running // commands on a server, subsystem requests, and agent forwarding. case teleport.ChanSession: @@ -742,7 +803,7 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r defer s.log.Debugf("Completing direct-tcpip request from %v to %v in context %v.", scx.SrcAddr, scx.DstAddr, scx.ID()) // Create "direct-tcpip" channel from the remote host to the target host. - conn, err := s.remoteClient.Dial("tcp", scx.DstAddr) + conn, err := s.remoteClient.DialContext(ctx, "tcp", scx.DstAddr) if err != nil { scx.Infof("Failed to connect to: %v: %v", scx.DstAddr, err) return @@ -880,8 +941,22 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { scx.Debugf("Client %v disconnected", s.sconn.RemoteAddr()) return } + + reqCtx := tracessh.ContextFromRequest(req) + ctx, span := s.tracerProvider.Tracer("ssh").Start( + oteltrace.ContextWithRemoteSpanContext(ctx, oteltrace.SpanContextFromContext(reqCtx)), + fmt.Sprintf("ssh.Forward.SessionRequest/%s", req.Type), + oteltrace.WithSpanKind(oteltrace.SpanKindServer), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.ForwardServer"), + semconv.RPCMethodKey.String("SessionRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + if err := s.dispatch(ctx, ch, req, scx); err != nil { s.replyError(ch, req, err) + span.End() return } if req.WantReply { @@ -889,6 +964,7 @@ func (s *Server) handleSessionChannel(ctx context.Context, nch ssh.NewChannel) { scx.Errorf("failed sending OK response on %q request: %v", req.Type, err) } } + span.End() case result := <-scx.ExecResultCh: scx.Debugf("Exec request (%q) complete: %v", result.Command, result.Code) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 5b8f215ef98ee..a2af576fdba84 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -20,7 +20,6 @@ package regular import ( "context" - "encoding/json" "fmt" "io" "net" @@ -61,6 +60,8 @@ import ( "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" ) @@ -187,7 +188,7 @@ type Server struct { // utmpPath is the path to the user accounting database. utmpPath string - // wtmpPath is the path to the user accounting log. + // wtmpPath is the path to the user accounting s.Logger. wtmpPath string // allowTCPForwarding indicates whether the ssh server is allowed to offer @@ -214,6 +215,10 @@ type Server struct { // users is used to start the automatic user deletion loop users srv.HostUsers + + // tracerProvider is used to create tracers capable + // of starting spans. + tracerProvider oteltrace.TracerProvider } // GetClock returns server clock implementation @@ -650,6 +655,14 @@ func SetInventoryControlHandle(handle inventory.DownstreamHandle) ServerOption { } } +// SetTracerProvider sets the tracer provider. +func SetTracerProvider(provider oteltrace.TracerProvider) ServerOption { + return func(s *Server) error { + s.tracerProvider = provider + return nil + } +} + // New returns an unstarted server func New(addr utils.NetAddr, hostname string, @@ -719,6 +732,10 @@ func New(addr utils.NetAddr, s.connectedProxyGetter = reversetunnel.NewConnectedProxyGetter() } + if s.tracerProvider == nil { + s.tracerProvider = tracing.DefaultProvider() + } + var component string if s.proxyMode { component = teleport.ComponentProxy @@ -793,7 +810,7 @@ func New(addr utils.NetAddr, var heartbeat srv.HeartbeatI if heartbeatMode == srv.HeartbeatModeNode && s.inventoryHandle != nil { - log.Info("debug -> starting control-stream based heartbeat.") + s.Logger.Info("debug -> starting control-stream based heartbeat.") heartbeat, err = srv.NewSSHServerHeartbeat(srv.SSHServerHeartbeatConfig{ InventoryHandle: s.inventoryHandle, GetServer: s.getServerInfo, @@ -801,7 +818,7 @@ func New(addr utils.NetAddr, OnHeartbeat: s.onHeartbeat, }) } else { - log.Info("debug -> starting legacy heartbeat.") + s.Logger.Info("debug -> starting legacy heartbeat.") heartbeat, err = srv.NewHeartbeat(srv.HeartbeatConfig{ Mode: heartbeatMode, Context: ctx, @@ -892,7 +909,7 @@ func (s *Server) AdvertiseAddr() string { _, port, _ := net.SplitHostPort(listenAddr) ahost, aport, err := utils.ParseAdvertiseAddr(advertiseAddr.String()) if err != nil { - log.Warningf("Failed to parse advertise address %q, %v, using default value %q.", advertiseAddr, err, listenAddr) + s.Logger.Warningf("Failed to parse advertise address %q, %v, using default value %q.", advertiseAddr, err, listenAddr) return listenAddr } if aport == "" { @@ -968,7 +985,7 @@ func (s *Server) getServerInfo() *types.ServerV2 { rotation, err := s.getRotation(s.getRole()) if err != nil { if !trace.IsNotFound(err) { - log.Warningf("Failed to get rotation state: %v", err) + s.Logger.Warningf("Failed to get rotation state: %v", err) } } else { server.SetRotation(*rotation) @@ -1040,10 +1057,10 @@ func (s *Server) HandleRequest(ctx context.Context, r *ssh.Request) { default: if r.WantReply { if err := r.Reply(false, nil); err != nil { - log.Warnf("Failed to reply to %q request: %v", r.Type, err) + s.Logger.Warnf("Failed to reply to %q request: %v", r.Type, err) } } - log.Debugf("Discarding %q global request: %+v", r.Type, r) + s.Logger.Debugf("Discarding %q global request: %+v", r.Type, r) } } @@ -1085,7 +1102,7 @@ func (s *Server) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionCont if lockErr := s.lockWatcher.CheckLockInForce(lockingMode, lockTargets...); lockErr != nil { event.Reason = lockErr.Error() if err := s.EmitAuditEvent(s.ctx, event); err != nil { - log.WithError(err).Warn("Failed to emit session reject event.") + s.Logger.WithError(err).Warn("Failed to emit session reject event.") } return ctx, trace.Wrap(lockErr) } @@ -1124,7 +1141,7 @@ func (s *Server) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionCont event.Reason = events.SessionRejectedEvent event.Maximum = maxConnections if err := s.EmitAuditEvent(s.ctx, event); err != nil { - log.WithError(err).Warn("Failed to emit session reject event.") + s.Logger.WithError(err).Warn("Failed to emit session reject event.") } err = trace.AccessDenied("too many concurrent ssh connections for user %q (max=%d)", identityContext.TeleportUser, @@ -1158,18 +1175,34 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont channelType := nch.ChannelType() if s.proxyMode { switch channelType { + // Channels of type "tracing-request" are sent to determine if ssh tracing envelopes + // are supported. Accepting the channel indicates to clients that they may wrap their + // ssh payload with tracing context. + case tracessh.TracingChannel: + ch, _, err := nch.Accept() + if err != nil { + s.Logger.Warnf("Unable to accept channel: %v", err) + if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)); err != nil { + s.Logger.Warnf("Failed to reject channel: %v", err) + } + return + } + if err := ch.Close(); err != nil { + s.Logger.Warnf("Unable to close %q channel: %v", nch.ChannelType(), err) + } + return // Channels of type "direct-tcpip", for proxies, it's equivalent // of teleport proxy: subsystem case teleport.ChanDirectTCPIP: req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData()) if err != nil { - log.Errorf("Failed to parse request data: %v, err: %v.", string(nch.ExtraData()), err) + s.Logger.Errorf("Failed to parse request data: %v, err: %v.", string(nch.ExtraData()), err) rejectChannel(nch, ssh.UnknownChannelType, "failed to parse direct-tcpip request") return } ch, _, err := nch.Accept() if err != nil { - log.Warnf("Unable to accept channel: %v.", err) + s.Logger.Warnf("Unable to accept channel: %v.", err) rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } @@ -1181,7 +1214,7 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont case teleport.ChanSession: ch, requests, err := nch.Accept() if err != nil { - log.Warnf("Unable to accept channel: %v.", err) + s.Logger.Warnf("Unable to accept channel: %v.", err) rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } @@ -1194,6 +1227,22 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont } switch channelType { + // Channels of type "tracing-request" are sent to determine if ssh tracing envelopes + // are supported. Accepting the channel indicates to clients that they may wrap their + // ssh payload with tracing context. + case tracessh.TracingChannel: + ch, _, err := nch.Accept() + if err != nil { + if err := nch.Reject(ssh.ConnectionFailed, err.Error()); err != nil { + s.Logger.Warnf("Unable to reject %q channel: %v", nch.ChannelType(), err) + } + return + } + + if err := ch.Close(); err != nil { + s.Logger.Warnf("Unable to close %q channel: %v", nch.ChannelType(), err) + } + return // Channels of type "session" handle requests that are involved in running // commands on a server, subsystem requests, and agent forwarding. case teleport.ChanSession: @@ -1220,7 +1269,7 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont Reason: events.SessionRejectedReasonMaxSessions, Maximum: max, }); err != nil { - log.WithError(err).Warn("Failed to emit session reject event.") + s.Logger.WithError(err).Warn("Failed to emit session reject event.") } rejectChannel(nch, ssh.Prohibited, fmt.Sprintf("too many session channels for user %q (max=%d)", identityContext.TeleportUser, max)) return @@ -1229,7 +1278,7 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont } ch, requests, err := nch.Accept() if err != nil { - log.Warnf("Unable to accept channel: %v.", err) + s.Logger.Warnf("Unable to accept channel: %v.", err) rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) if decr != nil { decr() @@ -1246,13 +1295,13 @@ func (s *Server) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionCont case teleport.ChanDirectTCPIP: req, err := sshutils.ParseDirectTCPIPReq(nch.ExtraData()) if err != nil { - log.Errorf("Failed to parse request data: %v, err: %v.", string(nch.ExtraData()), err) + s.Logger.Errorf("Failed to parse request data: %v, err: %v.", string(nch.ExtraData()), err) rejectChannel(nch, ssh.UnknownChannelType, "failed to parse direct-tcpip request") return } ch, _, err := nch.Accept() if err != nil { - log.Warnf("Unable to accept channel: %v.", err) + s.Logger.Warnf("Unable to accept channel: %v.", err) rejectChannel(nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) return } @@ -1297,10 +1346,10 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con // forwarding is complete. ctx, scx, err := srv.NewServerContext(ctx, ccx, s, identityContext) if err != nil { - log.WithError(err).Error("Unable to create connection context.") + s.Logger.WithError(err).Error("Unable to create connection context.") writeStderr(channel, "Unable to create connection context.") if err := channel.Close(); err != nil { - log.WithError(err).Warn("Failed to close channel.") + s.Logger.WithError(err).Warn("Failed to close channel.") } return } @@ -1338,13 +1387,13 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con // parent and child. pr, err := cmd.StdoutPipe() if err != nil { - log.Errorf("Failed to setup stdout pipe: %v", err) + s.Logger.Errorf("Failed to setup stdout pipe: %v", err) writeStderr(channel, err.Error()) return } pw, err := cmd.StdinPipe() if err != nil { - log.Errorf("Failed to setup stdin pipe: %v", err) + s.Logger.Errorf("Failed to setup stdin pipe: %v", err) writeStderr(channel, err.Error()) return } @@ -1383,7 +1432,7 @@ Loop: select { case err := <-errorCh: if err != nil && err != io.EOF { - log.Warnf("Connection problem in \"direct-tcpip\" channel: %v %T.", trace.DebugReport(err), err) + s.Logger.Warnf("Connection problem in \"direct-tcpip\" channel: %v %T.", trace.DebugReport(err), err) } case <-ctx.Done(): break Loop @@ -1413,7 +1462,7 @@ Loop: Success: true, }, }); err != nil { - log.WithError(err).Warn("Failed to emit port forward event.") + s.Logger.WithError(err).Warn("Failed to emit port forward event.") } } @@ -1423,7 +1472,7 @@ Loop: func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, ch ssh.Channel, in <-chan *ssh.Request) { netConfig, err := s.GetAccessPoint().GetClusterNetworkingConfig(ctx) if err != nil { - log.Errorf("Unable to fetch cluster networking config: %v.", err) + s.Logger.Errorf("Unable to fetch cluster networking config: %v.", err) writeStderr(ch, "Unable to fetch cluster networking configuration.") return } @@ -1435,10 +1484,10 @@ func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.Connec cfg.MessageWriter = &stderrWriter{channel: ch} }) if err != nil { - log.WithError(err).Error("Unable to create connection context.") + s.Logger.WithError(err).Error("Unable to create connection context.") writeStderr(ch, "Unable to create connection context.") if err := ch.Close(); err != nil { - log.WithError(err).Warn("Failed to close channel.") + s.Logger.WithError(err).Warn("Failed to close channel.") } return } @@ -1492,28 +1541,29 @@ func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.Connec return } - // handle the tracing request inline here to update the context for this session. - // this should only be requested once, right after the session has been created. - if req.Type == tracessh.TracingRequest { - var traceCtx tracing.PropagationContext - if err := json.Unmarshal(req.Payload, &traceCtx); err != nil { - scx.WithError(err).Error("Failed to unmarshal tracing request.") - continue - } - - ctx = tracing.WithPropagationContext(ctx, traceCtx) - continue - } + reqCtx := tracessh.ContextFromRequest(req) + ctx, span := s.tracerProvider.Tracer("ssh").Start( + oteltrace.ContextWithRemoteSpanContext(ctx, oteltrace.SpanContextFromContext(reqCtx)), + fmt.Sprintf("ssh.Regular.SessionRequest/%s", req.Type), + oteltrace.WithSpanKind(oteltrace.SpanKindServer), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.RegularServer"), + semconv.RPCMethodKey.String("SessionRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) if err := s.dispatch(ctx, ch, req, scx); err != nil { s.replyError(ch, req, err) + span.End() return } if req.WantReply { if err := req.Reply(true, nil); err != nil { - log.Warnf("Failed to reply to %q request: %v", req.Type, err) + s.Logger.Warnf("Failed to reply to %q request: %v", req.Type, err) } } + span.End() case result := <-scx.ExecResultCh: scx.Debugf("Exec request (%q) complete: %v", result.Command, result.Code) @@ -1526,7 +1576,7 @@ func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.Connec return case <-ctx.Done(): - log.Debugf("Closing session due to cancellation.") + s.Logger.Debugf("Closing session due to cancellation.") return } } @@ -1541,6 +1591,8 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, // other than our own custom "subsystems" and environment manipulation. if s.proxyMode { switch req.Type { + case tracessh.TracingRequest: + return nil case sshutils.SubsystemRequest: return s.handleSubsystem(ctx, ch, req, serverContext) case sshutils.EnvRequest: @@ -1551,7 +1603,7 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, // recording proxy mode. err := s.handleAgentForwardProxy(req, serverContext) if err != nil { - log.Warn(err) + s.Logger.Warn(err) } return nil case sshutils.PuTTYSimpleRequest: @@ -1559,7 +1611,7 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, // as a proxy to indicate that it's in "simple" node and won't be requesting any other channels. // As we don't support this request, we ignore it. // https://the.earth.li/~sgtatham/putty/0.76/htmldoc/AppendixG.html#sshnames-channel - log.Debugf("%v: deliberately ignoring request for '%v' channel", s.Component(), sshutils.PuTTYSimpleRequest) + s.Logger.Debugf("%v: deliberately ignoring request for '%v' channel", s.Component(), sshutils.PuTTYSimpleRequest) return nil default: return trace.BadParameter( @@ -1571,6 +1623,8 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, // subset of all the possible request types. if serverContext.JoinOnly { switch req.Type { + case tracessh.TracingRequest: + return nil case sshutils.PTYRequest: return s.termHandlers.HandlePTYReq(ch, req, serverContext) case sshutils.ShellRequest: @@ -1597,7 +1651,7 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, // processing requests. err := s.handleAgentForwardNode(req, serverContext) if err != nil { - log.Warn(err) + s.Logger.Warn(err) } return nil default: @@ -1606,6 +1660,8 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, } switch req.Type { + case tracessh.TracingRequest: + return nil case sshutils.ExecRequest: return s.termHandlers.HandleExec(ctx, ch, req, serverContext) case sshutils.PTYRequest: @@ -1637,7 +1693,7 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, // processing requests. err := s.handleAgentForwardNode(req, serverContext) if err != nil { - log.Warn(err) + s.Logger.Warn(err) } return nil default: @@ -1722,12 +1778,12 @@ func (s *Server) handleX11Forward(ch ssh.Channel, req *ssh.Request, ctx *srv.Ser } if trace.IsAccessDenied(err) { // denied X11 requests are ok from a protocol perspective so we - // don't return them, just reply over ssh and emit the audit log. + // don't return them, just reply over ssh and emit the audit s.Logger. s.replyError(ch, req, err) err = nil } if err := s.EmitAuditEvent(s.ctx, event); err != nil { - log.WithError(err).Warn("Failed to emit x11-forward event.") + s.Logger.WithError(err).Warn("Failed to emit x11-forward event.") } }() @@ -1772,7 +1828,7 @@ func (s *Server) handleSubsystem(ctx context.Context, ch ssh.Channel, req *ssh.R } go func() { err := sb.Wait() - log.Debugf("Subsystem %v finished with result: %v.", sb, err) + s.Logger.Debugf("Subsystem %v finished with result: %v.", sb, err) serverContext.SendSubsystemResult(srv.SubsystemResult{Err: trace.Wrap(err)}) }() return nil @@ -1792,18 +1848,18 @@ func (s *Server) handleEnv(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerCont // handleKeepAlive accepts and replies to keepalive@openssh.com requests. func (s *Server) handleKeepAlive(req *ssh.Request) { - log.Debugf("Received %q: WantReply: %v", req.Type, req.WantReply) + s.Logger.Debugf("Received %q: WantReply: %v", req.Type, req.WantReply) // only reply if the sender actually wants a response if req.WantReply { err := req.Reply(true, nil) if err != nil { - log.Warnf("Unable to reply to %q request: %v", req.Type, err) + s.Logger.Warnf("Unable to reply to %q request: %v", req.Type, err) return } } - log.Debugf("Replied to %q", req.Type) + s.Logger.Debugf("Replied to %q", req.Type) } // handleRecordingProxy responds to global out-of-band with a bool which @@ -1811,7 +1867,7 @@ func (s *Server) handleKeepAlive(req *ssh.Request) { func (s *Server) handleRecordingProxy(req *ssh.Request) { var recordingProxy bool - log.Debugf("Global request (%v, %v) received", req.Type, req.WantReply) + s.Logger.Debugf("Global request (%v, %v) received", req.Type, req.WantReply) if req.WantReply { // get the cluster config, if we can't get it, reply false @@ -1819,7 +1875,7 @@ func (s *Server) handleRecordingProxy(req *ssh.Request) { if err != nil { err := req.Reply(false, nil) if err != nil { - log.Warnf("Unable to respond to global request (%v, %v): %v", req.Type, req.WantReply, err) + s.Logger.Warnf("Unable to respond to global request (%v, %v): %v", req.Type, req.WantReply, err) } return } @@ -1829,19 +1885,19 @@ func (s *Server) handleRecordingProxy(req *ssh.Request) { recordingProxy = services.IsRecordAtProxy(recConfig.GetMode()) err = req.Reply(true, []byte(strconv.FormatBool(recordingProxy))) if err != nil { - log.Warnf("Unable to respond to global request (%v, %v): %v: %v", req.Type, req.WantReply, recordingProxy, err) + s.Logger.Warnf("Unable to respond to global request (%v, %v): %v: %v", req.Type, req.WantReply, recordingProxy, err) return } } - log.Debugf("Replied to global request (%v, %v): %v", req.Type, req.WantReply, recordingProxy) + s.Logger.Debugf("Replied to global request (%v, %v): %v", req.Type, req.WantReply, recordingProxy) } // handleVersionRequest replies with the Teleport version of the server. func (s *Server) handleVersionRequest(req *ssh.Request) { err := req.Reply(true, []byte(teleport.Version)) if err != nil { - log.Debugf("Failed to reply to version request: %v.", err) + s.Logger.Debugf("Failed to reply to version request: %v.", err) } } @@ -1851,10 +1907,10 @@ func (s *Server) handleProxyJump(ctx context.Context, ccx *sshutils.ConnectionCo // session request is complete. ctx, scx, err := srv.NewServerContext(ctx, ccx, s, identityContext) if err != nil { - log.WithError(err).Error("Unable to create connection context.") + s.Logger.WithError(err).Error("Unable to create connection context.") writeStderr(ch, "Unable to create connection context.") if err := ch.Close(); err != nil { - log.WithError(err).Warn("Failed to close channel.") + s.Logger.WithError(err).Warn("Failed to close channel.") } return } @@ -1866,7 +1922,7 @@ func (s *Server) handleProxyJump(ctx context.Context, ccx *sshutils.ConnectionCo recConfig, err := s.GetAccessPoint().GetSessionRecordingConfig(ctx) if err != nil { - log.Errorf("Unable to fetch session recording config: %v.", err) + s.Logger.Errorf("Unable to fetch session recording config: %v.", err) writeStderr(ch, "Unable to fetch session recording configuration.") return } @@ -1899,7 +1955,7 @@ func (s *Server) handleProxyJump(ctx context.Context, ccx *sshutils.ConnectionCo if services.IsRecordAtProxy(recConfig.GetMode()) { err = s.handleAgentForwardProxy(&ssh.Request{}, scx) if err != nil { - log.Warningf("Failed to request agent in recording mode: %v", err) + s.Logger.Warningf("Failed to request agent in recording mode: %v", err) writeStderr(ch, "Failed to request agent") return } @@ -1907,7 +1963,7 @@ func (s *Server) handleProxyJump(ctx context.Context, ccx *sshutils.ConnectionCo netConfig, err := s.GetAccessPoint().GetClusterNetworkingConfig(ctx) if err != nil { - log.Errorf("Unable to fetch cluster networking config: %v.", err) + s.Logger.Errorf("Unable to fetch cluster networking config: %v.", err) writeStderr(ch, "Unable to fetch cluster networking configuration.") return } @@ -1930,13 +1986,13 @@ func (s *Server) handleProxyJump(ctx context.Context, ccx *sshutils.ConnectionCo port: fmt.Sprintf("%v", req.Port), }) if err != nil { - log.Errorf("Unable instantiate proxy subsystem: %v.", err) + s.Logger.Errorf("Unable instantiate proxy subsystem: %v.", err) writeStderr(ch, "Unable to instantiate proxy subsystem.") return } if err := subsys.Start(ctx, scx.ServerConn, ch, &ssh.Request{}, scx); err != nil { - log.Errorf("Unable to start proxy subsystem: %v.", err) + s.Logger.Errorf("Unable to start proxy subsystem: %v.", err) writeStderr(ch, "Unable to start proxy subsystem.") return } @@ -1945,7 +2001,7 @@ func (s *Server) handleProxyJump(ctx context.Context, ccx *sshutils.ConnectionCo go func() { defer close(wch) if err := subsys.Wait(); err != nil { - log.Errorf("Proxy subsystem failed: %v.", err) + s.Logger.Errorf("Proxy subsystem failed: %v.", err) writeStderr(ch, "Proxy subsystem failed.") } }() @@ -1956,7 +2012,7 @@ func (s *Server) handleProxyJump(ctx context.Context, ccx *sshutils.ConnectionCo } func (s *Server) replyError(ch ssh.Channel, req *ssh.Request, err error) { - log.Error(err) + s.Logger.Error(err) // Terminate the error with a newline when writing to remote channel's // stderr so the output does not mix with the rest of the output if the remote // side is not doing additional formatting for extended data. @@ -1965,7 +2021,7 @@ func (s *Server) replyError(ch ssh.Channel, req *ssh.Request, err error) { writeStderr(ch, message) if req.WantReply { if err := req.Reply(false, []byte(message)); err != nil { - log.Warnf("Failed to reply to %q request: %v", req.Type, err) + s.Logger.Warnf("Failed to reply to %q request: %v", req.Type, err) } } } diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 903cfe1bb2ee9..deec5b9aca497 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -33,15 +33,12 @@ import ( "testing" "time" - "github.com/gravitational/teleport/api/breaker" - "github.com/mailgun/timetools" - "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/agent" - "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/native" @@ -60,10 +57,12 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" + "github.com/mailgun/timetools" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" ) // teleportTestUser is additional user used for tests @@ -92,7 +91,7 @@ type sshInfo struct { srvAddress string srvPort string srvHostPort string - clt *ssh.Client + clt *tracessh.Client cltConfig *ssh.ClientConfig assertCltClose require.ErrorAssertionFunc } @@ -247,7 +246,7 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()), } - client, err := ssh.Dial("tcp", sshSrv.Addr(), cltConfig) + client, err := tracessh.Dial(ctx, "tcp", sshSrv.Addr(), cltConfig) require.NoError(t, err) f := &sshTestFixture{ @@ -268,7 +267,7 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO } t.Cleanup(func() { f.ssh.assertCltClose(t, client.Close()) }) - require.NoError(t, agent.ForwardToAgent(client, keyring)) + require.NoError(t, agent.ForwardToAgent(client.Client, keyring)) return f } @@ -342,7 +341,7 @@ func TestInactivityTimeout(t *testing.T) { // so change the assertion on closing the client to expect it to fail f.ssh.assertCltClose = require.Error - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(context.Background()) require.NoError(t, err) defer se.Close() @@ -382,7 +381,7 @@ func TestLockInForce(t *testing.T) { // so change the assertion on closing the client to expect it to fail. f.ssh.assertCltClose = require.Error - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) stderr, err := se.StderrPipe() @@ -418,23 +417,23 @@ func TestLockInForce(t *testing.T) { require.Equal(t, lockInForceMsg, string(text)) // As long as the lock is in force, new sessions cannot be opened. - newClient, err := ssh.Dial("tcp", f.ssh.srvAddress, f.ssh.cltConfig) + newClient, err := tracessh.Dial(ctx, "tcp", f.ssh.srvAddress, f.ssh.cltConfig) require.NoError(t, err) t.Cleanup(func() { // The client is expected to be closed by the lock monitor therefore expect // an error on this second attempt. require.Error(t, newClient.Close()) }) - _, err = newClient.NewSession() + _, err = newClient.NewSession(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), lockInForceMsg) // Once the lock is lifted, new sessions should go through without error. require.NoError(t, f.testSrv.Auth().DeleteLock(ctx, "test-lock")) - newClient2, err := ssh.Dial("tcp", f.ssh.srvAddress, f.ssh.cltConfig) + newClient2, err := tracessh.Dial(ctx, "tcp", f.ssh.srvAddress, f.ssh.cltConfig) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, newClient2.Close()) }) - _, err = newClient2.NewSession() + _, err = newClient2.NewSession(context.Background()) require.NoError(t, err) } @@ -461,7 +460,7 @@ func TestDirectTCPIP(t *testing.T) { httpClient := http.Client{ Transport: &http.Transport{ Dial: func(network string, addr string) (net.Conn, error) { - return f.ssh.clt.Dial("tcp", u.Host) + return f.ssh.clt.DialContext(context.Background(), "tcp", u.Host) }, }, } @@ -524,7 +523,7 @@ func TestAgentForwardPermission(t *testing.T) { role.SetOptions(roleOptions) require.NoError(t, f.testSrv.Auth().UpsertRole(ctx, role)) - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) t.Cleanup(func() { se.Close() }) @@ -557,18 +556,18 @@ func TestMaxSessions(t *testing.T) { require.NoError(t, err) for i := int64(0); i < maxSessions; i++ { - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) defer se.Close() } - _, err = f.ssh.clt.NewSession() + _, err = f.ssh.clt.NewSession(ctx) require.Error(t, err) require.Contains(t, err.Error(), "too many session channels") // verify that max sessions does not affect max connections. for i := int64(0); i <= maxSessions; i++ { - clt, err := ssh.Dial("tcp", f.ssh.srv.Addr(), f.ssh.cltConfig) + clt, err := tracessh.Dial(ctx, "tcp", f.ssh.srv.Addr(), f.ssh.cltConfig) require.NoError(t, err) require.NoError(t, clt.Close()) } @@ -586,7 +585,7 @@ func TestExecLongCommand(t *testing.T) { echoPath, err := exec.LookPath("echo") require.NoError(t, err) - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(context.Background()) require.NoError(t, err) defer se.Close() @@ -601,7 +600,7 @@ func TestOpenExecSessionSetsSession(t *testing.T) { t.Parallel() f := newFixtureWithoutDiskBasedLogging(t) - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(context.Background()) require.NoError(t, err) defer se.Close() @@ -627,7 +626,7 @@ func TestAgentForward(t *testing.T) { err = f.testSrv.Auth().UpsertRole(ctx, role) require.NoError(t, err) - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) t.Cleanup(func() { se.Close() }) @@ -677,7 +676,7 @@ func TestAgentForward(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - client, err := ssh.Dial("tcp", f.ssh.srv.Addr(), sshConfig) + client, err := tracessh.Dial(ctx, "tcp", f.ssh.srv.Addr(), sshConfig) require.NoError(t, err) err = client.Close() require.NoError(t, err) @@ -737,7 +736,7 @@ func TestX11Forward(t *testing.T) { // concurrent X11 sessions. serverDisplay := x11EchoSession(ctx, t, f.ssh.clt) - client2, err := ssh.Dial("tcp", f.ssh.srv.Addr(), f.ssh.cltConfig) + client2, err := tracessh.Dial(ctx, "tcp", f.ssh.srv.Addr(), f.ssh.cltConfig) require.NoError(t, err) serverDisplay2 := x11EchoSession(ctx, t, client2) @@ -770,8 +769,8 @@ func TestX11Forward(t *testing.T) { // x11EchoSession creates a new ssh session and handles x11 forwarding for the session, // echoing XServer requests received back to the client. Returns the Display opened on the // session, which is set in $DISPLAY. -func x11EchoSession(ctx context.Context, t *testing.T, clt *ssh.Client) x11.Display { - se, err := clt.NewSession() +func x11EchoSession(ctx context.Context, t *testing.T, clt *tracessh.Client) x11.Display { + se, err := clt.NewSession(context.Background()) require.NoError(t, err) t.Cleanup(func() { se.Close() }) @@ -794,7 +793,7 @@ func x11EchoSession(ctx context.Context, t *testing.T, clt *ssh.Client) x11.Disp // Handle any x11 channel requests received from the server // and start x11 forwarding to the client display. - err = x11.ServeChannelRequests(ctx, clt, func(ctx context.Context, nch ssh.NewChannel) { + err = x11.ServeChannelRequests(ctx, clt.Client, func(ctx context.Context, nch ssh.NewChannel) { sch, sin, err := nch.Accept() require.NoError(t, err) defer sch.Close() @@ -892,7 +891,7 @@ func TestAllowedUsers(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - client, err := ssh.Dial("tcp", f.ssh.srv.Addr(), sshConfig) + client, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig) require.NoError(t, err) require.NoError(t, client.Close()) @@ -906,7 +905,7 @@ func TestAllowedUsers(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - _, err = ssh.Dial("tcp", f.ssh.srv.Addr(), sshConfig) + _, err = tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig) require.Error(t, err) } @@ -956,7 +955,7 @@ func TestAllowedLabels(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - _, err = ssh.Dial("tcp", f.ssh.srv.Addr(), sshConfig) + _, err = tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig) if tt.outError { require.Error(t, err) } else { @@ -981,7 +980,7 @@ func TestKeyAlgorithms(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - _, err = ssh.Dial("tcp", f.ssh.srv.Addr(), sshConfig) + _, err = tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig) require.Error(t, err) } @@ -989,7 +988,7 @@ func TestInvalidSessionID(t *testing.T) { t.Parallel() f := newFixture(t) - session, err := f.ssh.clt.NewSession() + session, err := f.ssh.clt.NewSession(context.Background()) require.NoError(t, err) err = session.Setenv(sshutils.SessionEnvVar, "foo") @@ -1019,14 +1018,14 @@ func TestSessionHijack(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - client, err := ssh.Dial("tcp", f.ssh.srv.Addr(), sshConfig) + client, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig) require.NoError(t, err) defer func() { err := client.Close() require.NoError(t, err) }() - se, err := client.NewSession() + se, err := client.NewSession(context.Background()) require.NoError(t, err) defer se.Close() @@ -1047,14 +1046,14 @@ func TestSessionHijack(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - client2, err := ssh.Dial("tcp", f.ssh.srv.Addr(), sshConfig2) + client2, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig2) require.NoError(t, err) defer func() { err := client2.Close() require.NoError(t, err) }() - se2, err := client2.NewSession() + se2, err := client2.NewSession(context.Background()) require.NoError(t, err) defer se2.Close() @@ -1069,11 +1068,11 @@ func TestSessionHijack(t *testing.T) { // testClient dials targetAddr via proxyAddr and executes 2+3 command func testClient(t *testing.T, f *sshTestFixture, proxyAddr, targetAddr, remoteAddr string, sshConfig *ssh.ClientConfig) { // Connect to node using registered address - client, err := ssh.Dial("tcp", proxyAddr, sshConfig) + client, err := tracessh.Dial(context.Background(), "tcp", proxyAddr, sshConfig) require.NoError(t, err) defer client.Close() - se, err := client.NewSession() + se, err := client.NewSession(context.Background()) require.NoError(t, err) defer se.Close() @@ -1102,17 +1101,21 @@ func testClient(t *testing.T, f *sshTestFixture, proxyAddr, targetAddr, remoteAd defer pipeNetConn.Close() // Open SSH connection via TCP - conn, chans, reqs, err := ssh.NewClientConn(pipeNetConn, - f.ssh.srv.Addr(), sshConfig) + conn, chans, reqs, err := tracessh.NewClientConn( + context.Background(), + pipeNetConn, + f.ssh.srv.Addr(), + sshConfig, + ) require.NoError(t, err) defer conn.Close() // using this connection as regular SSH - client2 := ssh.NewClient(conn, chans, reqs) + client2 := tracessh.NewClient(conn, chans, reqs) require.NoError(t, err) defer client2.Close() - se2, err := client2.NewSession() + se2, err := client2.NewSession(context.Background()) require.NoError(t, err) defer se2.Close() @@ -1339,7 +1342,8 @@ func TestPTY(t *testing.T) { t.Parallel() f := newFixture(t) - se, err := f.ssh.clt.NewSession() + + se, err := f.ssh.clt.NewSession(context.Background()) require.NoError(t, err) defer se.Close() @@ -1356,7 +1360,7 @@ func TestEnv(t *testing.T) { f := newFixture(t) - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(context.Background()) require.NoError(t, err) defer se.Close() @@ -1367,7 +1371,7 @@ func TestEnv(t *testing.T) { func TestNoAuth(t *testing.T) { t.Parallel() f := newFixture(t) - _, err := ssh.Dial("tcp", f.ssh.srv.Addr(), &ssh.ClientConfig{}) + _, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), &ssh.ClientConfig{}) require.Error(t, err) } @@ -1379,7 +1383,7 @@ func TestPasswordAuth(t *testing.T) { Auth: []ssh.AuthMethod{ssh.Password("")}, HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - _, err := ssh.Dial("tcp", f.ssh.srv.Addr(), config) + _, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), config) require.Error(t, err) } @@ -1392,11 +1396,11 @@ func TestClientDisconnect(t *testing.T) { Auth: []ssh.AuthMethod{ssh.PublicKeys(f.up.certSigner)}, HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - clt, err := ssh.Dial("tcp", f.ssh.srv.Addr(), config) + clt, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), config) require.NoError(t, err) require.NotNil(t, clt) - se, err := f.ssh.clt.NewSession() + se, err := f.ssh.clt.NewSession(context.Background()) require.NoError(t, err) require.NoError(t, se.Shell()) require.NoError(t, clt.Close()) @@ -1492,24 +1496,24 @@ func TestLimiter(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - clt0, err := ssh.Dial("tcp", srv.Addr(), config) + clt0, err := tracessh.Dial(ctx, "tcp", srv.Addr(), config) require.NoError(t, err) require.NotNil(t, clt0) - se0, err := clt0.NewSession() + se0, err := clt0.NewSession(ctx) require.NoError(t, err) require.NoError(t, se0.Shell()) // current connections = 1 - clt, err := ssh.Dial("tcp", srv.Addr(), config) + clt, err := tracessh.Dial(ctx, "tcp", srv.Addr(), config) require.NoError(t, err) require.NotNil(t, clt) - se, err := clt.NewSession() + se, err := clt.NewSession(ctx) require.NoError(t, err) require.NoError(t, se.Shell()) // current connections = 2 - _, err = ssh.Dial("tcp", srv.Addr(), config) + _, err = tracessh.Dial(ctx, "tcp", srv.Addr(), config) require.Error(t, err) require.NoError(t, se.Close()) @@ -1520,15 +1524,15 @@ func TestLimiter(t *testing.T) { require.Eventually(t, getWaitForNumberOfConnsFunc(t, limiter, "127.0.0.1", 1), time.Second*10, time.Millisecond*100) // current connections = 1 - clt, err = ssh.Dial("tcp", srv.Addr(), config) + clt, err = tracessh.Dial(ctx, "tcp", srv.Addr(), config) require.NoError(t, err) require.NotNil(t, clt) - se, err = clt.NewSession() + se, err = clt.NewSession(ctx) require.NoError(t, err) require.NoError(t, se.Shell()) // current connections = 2 - _, err = ssh.Dial("tcp", srv.Addr(), config) + _, err = tracessh.Dial(ctx, "tcp", srv.Addr(), config) require.Error(t, err) require.NoError(t, se.Close()) @@ -1540,10 +1544,10 @@ func TestLimiter(t *testing.T) { // current connections = 1 // requests rate should exceed now - clt, err = ssh.Dial("tcp", srv.Addr(), config) + clt, err = tracessh.Dial(ctx, "tcp", srv.Addr(), config) require.NoError(t, err) require.NotNil(t, clt) - _, err = clt.NewSession() + _, err = clt.NewSession(ctx) require.Error(t, err) clt.Close() @@ -1555,7 +1559,7 @@ func TestLimiter(t *testing.T) { func TestServerAliveInterval(t *testing.T) { t.Parallel() f := newFixture(t) - ok, _, err := f.ssh.clt.SendRequest(teleport.KeepAliveReqType, true, nil) + ok, _, err := f.ssh.clt.SendRequest(context.Background(), teleport.KeepAliveReqType, true, nil) require.NoError(t, err) require.True(t, ok) } @@ -1577,7 +1581,7 @@ func TestGlobalRequestRecordingProxy(t *testing.T) { // send the request again, we have cluster config and when we parse the // response, it should be false because recording is occurring at the node. - ok, responseBytes, err := f.ssh.clt.SendRequest(teleport.RecordingProxyReqType, true, nil) + ok, responseBytes, err := f.ssh.clt.SendRequest(ctx, teleport.RecordingProxyReqType, true, nil) require.NoError(t, err) require.True(t, ok) response, err := strconv.ParseBool(string(responseBytes)) @@ -1595,7 +1599,7 @@ func TestGlobalRequestRecordingProxy(t *testing.T) { // send request again, now that we have cluster config and it's set to record // at the proxy, we should return true and when we parse the payload it should // also be true - ok, responseBytes, err = f.ssh.clt.SendRequest(teleport.RecordingProxyReqType, true, nil) + ok, responseBytes, err = f.ssh.clt.SendRequest(ctx, teleport.RecordingProxyReqType, true, nil) require.NoError(t, err) require.True(t, ok) response, err = strconv.ParseBool(string(responseBytes)) @@ -1690,70 +1694,77 @@ func newRawNode(t *testing.T, authSrv *auth.Server) *rawNode { // X11 forwarding request and then dials a single X11 channel which echoes all bytes written // to it. Used to verify the behavior of X11 forwarding in recording proxies. Returns a // node and an error channel that can be monitored for asynchronous failures. -func startX11EchoServer(ctx context.Context, t *testing.T, authSrv *auth.Server) (*rawNode, <-chan error) { - node := newRawNode(t, authSrv) +// x11Handler handles requests received by the x11 echo server +func x11Handler(ctx context.Context, conn *ssh.ServerConn, chs <-chan ssh.NewChannel) error { + // expect client to open a session channel + var nch ssh.NewChannel + select { + case nch = <-chs: + case <-time.After(time.Second * 3): + return trace.LimitExceeded("Timeout waiting for session channel") + case <-ctx.Done(): + return nil + } - sessionMain := func(ctx context.Context, conn *ssh.ServerConn, chs <-chan ssh.NewChannel) error { - defer conn.Close() + if nch.ChannelType() == tracessh.TracingChannel { + nch.Reject(ssh.UnknownChannelType, "") + return x11Handler(ctx, conn, chs) + } - // expect client to open a session channel - var nch ssh.NewChannel - select { - case nch = <-chs: - case <-time.After(time.Second * 3): - return trace.LimitExceeded("Timeout waiting for session channel") - case <-ctx.Done(): - return nil - } - if nch.ChannelType() != teleport.ChanSession { - return trace.BadParameter("Unexpected channel type: %q", nch.ChannelType()) - } + if nch.ChannelType() != teleport.ChanSession { + return trace.BadParameter("Unexpected channel type: %q", nch.ChannelType()) + } - sch, creqs, err := nch.Accept() - if err != nil { - return trace.Wrap(err) - } - defer sch.Close() + sch, creqs, err := nch.Accept() + if err != nil { + return trace.Wrap(err) + } + defer sch.Close() - // expect client to send an X11 forwarding request - var req *ssh.Request - select { - case req = <-creqs: - case <-time.After(time.Second * 3): - return trace.LimitExceeded("Timeout waiting for X11 forwarding request") - case <-ctx.Done(): - return nil - } + // expect client to send an X11 forwarding request + var req *ssh.Request + select { + case req = <-creqs: + case <-time.After(time.Second * 3): + return trace.LimitExceeded("Timeout waiting for X11 forwarding request") + case <-ctx.Done(): + return nil + } - if req.Type != sshutils.X11ForwardRequest { - return trace.BadParameter("Unexpected request type %q", req.Type) - } + if req.Type != sshutils.X11ForwardRequest { + return trace.BadParameter("Unexpected request type %q", req.Type) + } - if err = req.Reply(true, nil); err != nil { - return trace.Wrap(err) - } + if err = req.Reply(true, nil); err != nil { + return trace.Wrap(err) + } - // start a fake X11 channel - xch, _, err := conn.OpenChannel(sshutils.X11ChannelRequest, nil) - if err != nil { - return trace.Wrap(err) - } - defer xch.Close() - - // echo all bytes back across the X11 channel - _, err = io.Copy(xch, xch) - if err == nil { - xch.CloseWrite() - } else { - log.Errorf("X11 channel error: %v", err) - } + // start a fake X11 channel + xch, _, err := conn.OpenChannel(sshutils.X11ChannelRequest, nil) + if err != nil { + return trace.Wrap(err) + } + defer xch.Close() - return nil + // echo all bytes back across the X11 channel + _, err = io.Copy(xch, xch) + if err == nil { + xch.CloseWrite() + } else { + log.Errorf("X11 channel error: %v", err) } - errorCh := make(chan error, 1) + return nil +} - nodeMain := func() { +// startX11EchoServer starts a fake node which, for each incoming SSH connection, accepts an +// X11 forwarding request and then dials a single X11 channel which echoes all bytes written +// to it. Used to verify the behavior of X11 forwarding in recording proxies. Returns a +// node and an error channel that can be monitored for asynchronous failures. +func startX11EchoServer(ctx context.Context, t *testing.T, authSrv *auth.Server) (*rawNode, <-chan error) { + node := newRawNode(t, authSrv) + errorCh := make(chan error, 1) + go func() { for { conn, chs, _, err := node.accept() if err != nil { @@ -1761,14 +1772,13 @@ func startX11EchoServer(ctx context.Context, t *testing.T, authSrv *auth.Server) return } go func() { - if err := sessionMain(ctx, conn, chs); err != nil { + if err := x11Handler(ctx, conn, chs); err != nil { errorCh <- err } + conn.Close() }() } - } - - go nodeMain() + }() return node, errorCh } @@ -1827,7 +1837,7 @@ func TestX11ProxySupport(t *testing.T) { require.NoError(t, err) // verify that the proxy is in recording mode - ok, responseBytes, err := f.ssh.clt.SendRequest(teleport.RecordingProxyReqType, true, nil) + ok, responseBytes, err := f.ssh.clt.SendRequest(ctx, teleport.RecordingProxyReqType, true, nil) require.NoError(t, err) require.True(t, ok) response, err := strconv.ParseBool(string(responseBytes)) @@ -1848,7 +1858,7 @@ func TestX11ProxySupport(t *testing.T) { defer x11Cancel() // Create a direct TCP/IP connection from proxy to our X11 test server. - netConn, err := f.ssh.clt.Dial("tcp", node.addr) + netConn, err := f.ssh.clt.DialContext(x11Ctx, "tcp", node.addr) require.NoError(t, err) defer netConn.Close() @@ -1858,11 +1868,11 @@ func TestX11ProxySupport(t *testing.T) { cltConfig.HostKeyCallback = ssh.InsecureIgnoreHostKey() // Perform ssh handshake and setup client for X11 test server. - cltConn, chs, reqs, err := ssh.NewClientConn(netConn, node.addr, &cltConfig) + cltConn, chs, reqs, err := tracessh.NewClientConn(ctx, netConn, node.addr, &cltConfig) require.NoError(t, err) - clt := ssh.NewClient(cltConn, chs, reqs) + clt := tracessh.NewClient(cltConn, chs, reqs) - sess, err := clt.NewSession() + sess, err := clt.NewSession(ctx) require.NoError(t, err) // register X11 channel handler before requesting forwarding to avoid races @@ -1981,12 +1991,12 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { require.NoError(t, err) // Connect SSH client to proxy - client, err := ssh.Dial("tcp", proxy.Addr(), sshConfig) + client, err := tracessh.Dial(ctx, "tcp", proxy.Addr(), sshConfig) require.NoError(t, err) defer client.Close() - se, err := client.NewSession() + se, err := client.NewSession(ctx) require.NoError(t, err) defer se.Close() @@ -2020,17 +2030,17 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { defer pipeNetConn.Close() // Open SSH connection via proxy subsystem's TCP tunnel - conn, chans, reqs, err := ssh.NewClientConn(pipeNetConn, + conn, chans, reqs, err := tracessh.NewClientConn(ctx, pipeNetConn, f.ssh.srv.Addr(), sshConfig) require.NoError(t, err) defer conn.Close() // Run commands over this connection like regular SSH - client2 := ssh.NewClient(conn, chans, reqs) + client2 := tracessh.NewClient(conn, chans, reqs) require.NoError(t, err) defer client2.Close() - se2, err := client2.NewSession() + se2, err := client2.NewSession(ctx) require.NoError(t, err) defer se2.Close() diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go index bd1042fd21d6c..6a897b3d8842b 100644 --- a/lib/sshutils/server.go +++ b/lib/sshutils/server.go @@ -22,6 +22,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net" "sync" @@ -31,10 +32,13 @@ import ( "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/observability/tracing" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/limiter" @@ -86,6 +90,10 @@ type Server struct { // fips means Teleport started in a FedRAMP/FIPS 140-2 compliant // configuration. fips bool + + // tracerProvider is used to create tracers capable + // of starting spans. + tracerProvider oteltrace.TracerProvider } const ( @@ -139,6 +147,14 @@ func SetInsecureSkipHostValidation() ServerOption { } } +// SetTracerProvider sets the tracer provider for the server. +func SetTracerProvider(provider oteltrace.TracerProvider) ServerOption { + return func(s *Server) error { + s.tracerProvider = provider + return nil + } +} + func NewServer( component string, a utils.NetAddr, @@ -176,6 +192,11 @@ func NewServer( if s.shutdownPollPeriod == 0 { s.shutdownPollPeriod = defaults.ShutdownPollPeriod } + + if s.tracerProvider == nil { + s.tracerProvider = tracing.DefaultProvider() + } + err = s.checkArguments(a, h, hostSigners, ah) if err != nil { return nil, err @@ -497,8 +518,26 @@ func (s *Server) HandleConnection(conn net.Conn) { return } s.log.Debugf("Received out-of-band request: %+v.", req) + + reqCtx := tracessh.ContextFromRequest(req) + ctx, span := s.tracerProvider.Tracer("ssh").Start( + oteltrace.ContextWithRemoteSpanContext(ctx, oteltrace.SpanContextFromContext(reqCtx)), + fmt.Sprintf("ssh.GlobalRequest/%s", req.Type), + oteltrace.WithSpanKind(oteltrace.SpanKindServer), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Server"), + semconv.RPCMethodKey.String("GlobalRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + if s.reqHandler != nil { - go s.reqHandler.HandleRequest(ctx, req) + go func(span oteltrace.Span) { + defer span.End() + s.reqHandler.HandleRequest(ctx, req) + }(span) + } else { + span.End() } // handle channels: case nch := <-chans: @@ -506,7 +545,41 @@ func (s *Server) HandleConnection(conn net.Conn) { connClosed() return } - go s.newChanHandler.HandleNewChan(ctx, ccx, nch) + + // This is a request from clients to determine if tracing is enabled. + // Handle here so that we always alert clients that we can handle tracing envelopes. + if nch.ChannelType() == tracessh.TracingChannel { + ch, _, err := nch.Accept() + if err != nil { + s.log.Warnf("Unable to accept channel: %v", err) + if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)); err != nil { + s.log.Warnf("Failed to reject channel: %v", err) + } + continue + } + + if err := ch.Close(); err != nil { + s.log.Warnf("Unable to close %q channel: %v", nch.ChannelType(), err) + } + continue + } + + chanCtx, nch := tracessh.ContextFromNewChannel(nch) + ctx, span := s.tracerProvider.Tracer("ssh").Start( + oteltrace.ContextWithRemoteSpanContext(ctx, oteltrace.SpanContextFromContext(chanCtx)), + fmt.Sprintf("ssh.OpenChannel/%s", nch.ChannelType()), + oteltrace.WithSpanKind(oteltrace.SpanKindServer), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Server"), + semconv.RPCMethodKey.String("OpenChannel"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + + go func(span oteltrace.Span) { + defer span.End() + s.newChanHandler.HandleNewChan(ctx, ccx, nch) + }(span) // send keepalive pings to the clients case <-keepAliveTick.C: const wantReply = true