diff --git a/api/observability/tracing/ssh/client.go b/api/observability/tracing/ssh/client.go index af11c3e0312fc..5e6906cdf92f0 100644 --- a/api/observability/tracing/ssh/client.go +++ b/api/observability/tracing/ssh/client.go @@ -103,8 +103,7 @@ func isTracingSupported(clt *ssh.Client) (tracingCapability, error) { // The resulting connection has a zero LocalAddr() and RemoteAddr(). func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) - - _, span := tracer.Start( + ctx, span := tracer.Start( ctx, "ssh.DialContext", oteltrace.WithSpanKind(oteltrace.SpanKindClient), @@ -121,12 +120,19 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err ) defer span.End() - conn, err := c.Client.Dial(n, addr) - if err != nil { - return nil, trace.Wrap(err) + c.mu.RLock() + // create the wrapper while the lock is held + wrapper := &clientWrapper{ + capability: c.capability, + Conn: c.Client.Conn, + opts: c.opts, + ctx: ctx, + contexts: make(map[string][]context.Context), } + c.mu.RUnlock() - return conn, nil + conn, err := wrapper.Dial(n, addr) + return conn, trace.Wrap(err) } // SendRequest sends a global request, and returns the @@ -204,10 +210,10 @@ func (c *Client) OpenChannel(ctx context.Context, name string, data []byte) (*Ch // NewSession creates a new SSH session that is passed tracing context // so that spans may be correlated properly over the ssh connection. -func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) { +func (c *Client) NewSession(ctx context.Context) (*Session, error) { tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName) - _, span := tracer.Start( + ctx, span := tracer.Start( ctx, "ssh.NewSession", oteltrace.WithSpanKind(oteltrace.SpanKindClient), @@ -249,8 +255,172 @@ func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) { c.capability = capability } + // create the wrapper while the lock is still held + wrapper := &clientWrapper{ + capability: c.capability, + Conn: c.Client.Conn, + opts: c.opts, + ctx: ctx, + contexts: make(map[string][]context.Context), + } + c.mu.Unlock() - session, err := c.Client.NewSession() + // get a session from the wrapper + session, err := wrapper.NewSession() return session, trace.Wrap(err) } + +// clientWrapper wraps the ssh.Conn for individual ssh.Client +// operations to intercept internal calls by the ssh.Client to +// OpenChannel. This allows for internal operations within the +// ssh.Client to have their payload wrapped in an Envelope to +// forward tracing context when tracing is enabled. +type clientWrapper struct { + // Conn is the ssh.Conn that requests will be forwarded to + ssh.Conn + // capability the tracingCapability of the ssh server + capability tracingCapability + // ctx the context which should be used to create spans from + ctx context.Context + // opts the tracing options to use for creating spans with + opts []tracing.Option + + // lock protects the context queue + lock sync.Mutex + // contexts a LIFO queue of context.Context per channel name. + contexts map[string][]context.Context +} + +// NewSession opens a new Session for this client. +func (c *clientWrapper) NewSession() (*Session, error) { + // create a client that will defer to us when + // opening the "session" channel so that we + // can add an Envelope to the request + client := &ssh.Client{ + Conn: c, + } + + session, err := client.NewSession() + if err != nil { + return nil, trace.Wrap(err) + } + + // wrap the session so all session requests on the channel + // can be traced + return &Session{ + Session: session, + wrapper: c, + }, nil +} + +// Dial initiates a connection to the addr from the remote host. +func (c *clientWrapper) Dial(n, addr string) (net.Conn, error) { + // create a client that will defer to us when + // opening the "direct-tcpip" channel so that we + // can add an Envelope to the request + client := &ssh.Client{ + Conn: c, + } + + return client.Dial(n, addr) +} + +// addContext adds the provided context.Context to the end of +// the list for the provided channel name +func (c *clientWrapper) addContext(ctx context.Context, name string) { + c.lock.Lock() + defer c.lock.Unlock() + + c.contexts[name] = append(c.contexts[name], ctx) +} + +// nextContext returns the first context.Context for the provided +// channel name +func (c *clientWrapper) nextContext(name string) context.Context { + c.lock.Lock() + defer c.lock.Unlock() + + contexts, ok := c.contexts[name] + switch { + case !ok, len(contexts) <= 0: + return context.Background() + case len(contexts) == 1: + delete(c.contexts, name) + return contexts[0] + default: + c.contexts[name] = contexts[1:] + return contexts[0] + } +} + +// OpenChannel tries to open a channel. If tracing is enabled, +// the provided payload is wrapped in an Envelope to forward +// any tracing context. +func (c *clientWrapper) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { + config := tracing.NewConfig(c.opts) + tracer := config.TracerProvider.Tracer(instrumentationName) + ctx, span := tracer.Start( + c.ctx, + fmt.Sprintf("ssh.OpenChannel/%s", name), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + append( + peerAttr(c.Conn.RemoteAddr()), + semconv.RPCServiceKey.String("ssh.Client"), + semconv.RPCMethodKey.String("OpenChannel"), + semconv.RPCSystemKey.String("ssh"), + )..., + ), + ) + defer span.End() + + ch, reqs, err := c.Conn.OpenChannel(name, wrapPayload(ctx, c.capability, config.TextMapPropagator, data)) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + span.RecordError(err) + } + + return channelWrapper{ + Channel: ch, + manager: c, + }, reqs, err +} + +// channelWrapper wraps an ssh.Channel to allow for requests to +// contain tracing context. +type channelWrapper struct { + ssh.Channel + manager *clientWrapper +} + +// SendRequest sends a channel request. If tracing is enabled, +// the provided payload is wrapped in an Envelope to forward +// any tracing context. +// +// It is the callers' responsibility to ensure that addContext is +// called with the appropriate context.Context prior to any +// requests being sent along the channel. +func (c channelWrapper) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + config := tracing.NewConfig(c.manager.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + c.manager.nextContext(name), + fmt.Sprintf("ssh.ChannelRequest/%s", name), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.Bool("want_reply", wantReply), + semconv.RPCServiceKey.String("ssh.Channel"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + ok, err := c.Channel.SendRequest(name, wantReply, wrapPayload(ctx, c.manager.capability, config.TextMapPropagator, payload)) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + span.RecordError(err) + } + + return ok, err +} diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index 86bdb1a5634b5..ece5506480186 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -25,7 +25,6 @@ import ( ) func TestIsTracingSupported(t *testing.T) { - rejected := &ssh.OpenChannelError{ Reason: ssh.Prohibited, Message: "rejected!", @@ -186,11 +185,11 @@ func TestNewSession(t *testing.T) { cases := []struct { name string - assertionFunc func(t *testing.T, clt *Client, session *ssh.Session, err error) + assertionFunc func(t *testing.T, clt *Client, session *Session, err error) }{ { name: "session prohibited", - assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) { // creating a new session should return any errors captured when creating the client // and not actually probe the server require.Error(t, err) @@ -202,7 +201,7 @@ func TestNewSession(t *testing.T) { }, { name: "other failure to open tracing channel", - assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) { // this time through we should probe the server without getting a prohibited error, // but things still failed, so we shouldn't know the capability require.NoError(t, err) @@ -214,7 +213,7 @@ func TestNewSession(t *testing.T) { }, { name: "active session", - assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) { + assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) { // all is good now, we should have an active session require.NoError(t, err) require.NotNil(t, sess) diff --git a/api/observability/tracing/ssh/session.go b/api/observability/tracing/ssh/session.go new file mode 100644 index 0000000000000..0a0252366ee08 --- /dev/null +++ b/api/observability/tracing/ssh/session.go @@ -0,0 +1,277 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel/attribute" + semconv "go.opentelemetry.io/otel/semconv/v1.10.0" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/observability/tracing" +) + +// Session is a wrapper around ssh.Session that adds tracing support +type Session struct { + *ssh.Session + wrapper *clientWrapper +} + +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(ctx context.Context, name string, wantReply bool, payload []byte) (bool, error) { + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + fmt.Sprintf("ssh.SessionRequest/%s", name), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.Bool("want_reply", wantReply), + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + return s.Session.SendRequest(name, wantReply, wrapPayload(ctx, s.wrapper.capability, config.TextMapPropagator, payload)) +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Run. +func (s *Session) Setenv(ctx context.Context, name, value string) error { + const request = "env" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + fmt.Sprintf("ssh.Setenv/%s", name), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.Setenv(name, value) +} + +// RequestPty requests the association of a pty with the session on the remote host. +func (s *Session) RequestPty(ctx context.Context, term string, h, w int, termmodes ssh.TerminalModes) error { + const request = "pty-req" + config := tracing.NewConfig(s.wrapper.opts) + tracer := config.TracerProvider.Tracer(instrumentationName) + ctx, span := tracer.Start( + ctx, + fmt.Sprintf("ssh.RequestPty/%s", term), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + attribute.Int("width", w), + attribute.Int("height", h), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.RequestPty(term, h, w, termmodes) +} + +// RequestSubsystem requests the association of a subsystem with the session on the remote host. +// A subsystem is a predefined command that runs in the background when the ssh session is initiated. +func (s *Session) RequestSubsystem(ctx context.Context, subsystem string) error { + const request = "subsystem" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + fmt.Sprintf("ssh.RequestSubsystem/%s", subsystem), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.RequestSubsystem(subsystem) +} + +// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns. +func (s *Session) WindowChange(ctx context.Context, h, w int) error { + const request = "window-change" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + "ssh.WindowChange", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + attribute.Int("height", h), + attribute.Int("width", w), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.WindowChange(h, w) +} + +// Signal sends the given signal to the remote process. +// sig is one of the SIG* constants. +func (s *Session) Signal(ctx context.Context, sig ssh.Signal) error { + const request = "signal" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + fmt.Sprintf("ssh.Signal/%s", sig), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.Signal(sig) +} + +// Start runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start or Shell. +func (s *Session) Start(ctx context.Context, cmd string) error { + const request = "exec" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + fmt.Sprintf("ssh.Start/%s", cmd), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.Start(cmd) +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Run, Start, Shell, Output, or CombinedOutput. +func (s *Session) Shell(ctx context.Context) error { + const request = "shell" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + "ssh.Shell", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.Shell() +} + +// Run runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start, Shell, Output, +// or CombinedOutput. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Run(ctx context.Context, cmd string) error { + const request = "exec" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + fmt.Sprintf("ssh.Run/%s", cmd), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.Run(cmd) +} + +// Output runs cmd on the remote host and returns its standard output. +func (s *Session) Output(ctx context.Context, cmd string) ([]byte, error) { + const request = "exec" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + fmt.Sprintf("ssh.Output/%s", cmd), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.Output(cmd) +} + +// CombinedOutput runs cmd on the remote host and returns its combined +// standard output and standard error. +func (s *Session) CombinedOutput(ctx context.Context, cmd string) ([]byte, error) { + const request = "exec" + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + fmt.Sprintf("ssh.CombinedOutput/%s", cmd), + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + s.wrapper.addContext(ctx, request) + return s.Session.CombinedOutput(cmd) +} diff --git a/api/observability/tracing/ssh/ssh.go b/api/observability/tracing/ssh/ssh.go index 56590120e09fe..3c2135d6cb990 100644 --- a/api/observability/tracing/ssh/ssh.go +++ b/api/observability/tracing/ssh/ssh.go @@ -82,13 +82,28 @@ func ContextFromNewChannel(nch ssh.NewChannel, opts ...tracing.Option) (context. // initiates the SSH handshake, and then sets up a Client. For access // to incoming channels and requests, use net.Dial with NewClientConn // instead. -func Dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) { +func Dial(ctx context.Context, network, addr string, config *ssh.ClientConfig, opts ...tracing.Option) (*Client, error) { + tracer := tracing.NewConfig(opts).TracerProvider.Tracer(instrumentationName) + ctx, span := tracer.Start( + ctx, + "ssh/Dial", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + attribute.String("network", network), + attribute.String("address", addr), + semconv.RPCServiceKey.String("ssh"), + semconv.RPCMethodKey.String("Dial"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + dialer := net.Dialer{Timeout: config.Timeout} conn, err := dialer.DialContext(ctx, network, addr) if err != nil { return nil, err } - c, chans, reqs, err := NewClientConn(ctx, conn, addr, config) + c, chans, reqs, err := NewClientConn(ctx, conn, addr, config, opts...) if err != nil { return nil, err } @@ -97,7 +112,24 @@ func Dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) ( // NewClientConn creates a new SSH client connection that is passed tracing context so that spans may be correlated // properly over the ssh connection. -func NewClientConn(ctx context.Context, conn net.Conn, addr string, config *ssh.ClientConfig) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { +func NewClientConn(ctx context.Context, conn net.Conn, addr string, config *ssh.ClientConfig, opts ...tracing.Option) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { + tracer := tracing.NewConfig(opts).TracerProvider.Tracer(instrumentationName) + ctx, span := tracer.Start( + ctx, + "ssh/NewClientConn", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + append( + peerAttr(conn.RemoteAddr()), + attribute.String("address", addr), + semconv.RPCServiceKey.String("ssh"), + semconv.RPCMethodKey.String("NewClientConn"), + semconv.RPCSystemKey.String("ssh"), + )..., + ), + ) + defer span.End() + hp := &sshutils.HandshakePayload{ TracingContext: tracing.PropagationContextFromContext(ctx), } @@ -122,13 +154,13 @@ func NewClientConn(ctx context.Context, conn net.Conn, addr string, config *ssh. } // NewClientConnWithDeadline establishes new client connection with specified deadline -func NewClientConnWithDeadline(ctx context.Context, conn net.Conn, addr string, config *ssh.ClientConfig) (*Client, error) { +func NewClientConnWithDeadline(ctx context.Context, conn net.Conn, addr string, config *ssh.ClientConfig, opts ...tracing.Option) (*Client, error) { if config.Timeout > 0 { if err := conn.SetReadDeadline(time.Now().Add(config.Timeout)); err != nil { return nil, trace.Wrap(err) } } - c, chans, reqs, err := NewClientConn(ctx, conn, addr, config) + c, chans, reqs, err := NewClientConn(ctx, conn, addr, config, opts...) if err != nil { return nil, err } @@ -137,7 +169,7 @@ func NewClientConnWithDeadline(ctx context.Context, conn net.Conn, addr string, return nil, trace.Wrap(err) } } - return NewClient(c, chans, reqs), nil + return NewClient(c, chans, reqs, opts...), nil } // peerAttr returns attributes about the peer address. diff --git a/api/observability/tracing/ssh/ssh_test.go b/api/observability/tracing/ssh/ssh_test.go index cbc7d2c370c77..dbe545faab9c0 100644 --- a/api/observability/tracing/ssh/ssh_test.go +++ b/api/observability/tracing/ssh/ssh_test.go @@ -209,9 +209,28 @@ func (h handler) channelHandler(ch ssh.NewChannel) { return } case "session": - // TODO(tross): add tracing checks when session tracing lands - if subtle.ConstantTimeCompare(ch.ExtraData(), []byte(testPayload)) == 1 { - h.errChan <- errors.New("payload mismatch") + switch h.tracingSupported { + case tracingUnsupported: + if subtle.ConstantTimeCompare(ch.ExtraData(), []byte(testPayload)) == 1 { + h.errChan <- errors.New("payload mismatch") + } + case tracingSupported: + var envelope Envelope + if err := json.Unmarshal(ch.ExtraData(), &envelope); err != nil { + h.errChan <- trace.Wrap(err, "failed to unmarshal envelope") + ch.Accept() + return + } + if len(envelope.PropagationContext) <= 0 { + h.errChan <- errors.New("empty propagation context") + ch.Accept() + return + } + if len(envelope.Payload) > 0 { + h.errChan <- errors.New("payload mismatch") + ch.Accept() + return + } } _, chReqs, err := ch.Accept() @@ -253,16 +272,41 @@ func (h handler) subsystemHandler(req *ssh.Request) { } }() - // TODO(tross): add tracing checks when session tracing lands + switch h.tracingSupported { + case tracingUnsupported: + var msg subsystemRequestMsg + if err := ssh.Unmarshal(req.Payload, &msg); err != nil { + h.errChan <- trace.Wrap(err, "failed to unmarshal payload") + return + } - var msg subsystemRequestMsg - if err := ssh.Unmarshal(req.Payload, &msg); err != nil { - h.errChan <- trace.Wrap(err, "failed to unmarshal payload") - return - } + if msg.Subsystem != "test" { + h.errChan <- errors.New("received wrong subsystem") + } + case tracingSupported: + var envelope Envelope + if err := json.Unmarshal(req.Payload, &envelope); err != nil { + h.errChan <- trace.Wrap(err, "failed to unmarshal envelope") + return + } + if len(envelope.PropagationContext) <= 0 { + h.errChan <- errors.New("empty propagation context") + return + } - if msg.Subsystem != "test" { - h.errChan <- errors.New("received wrong subsystem") + var msg subsystemRequestMsg + if err := ssh.Unmarshal(envelope.Payload, &msg); err != nil { + h.errChan <- trace.Wrap(err, "failed to unmarshal payload") + return + } + if msg.Subsystem != "test" { + h.errChan <- errors.New("received wrong subsystem") + return + } + default: + if err := req.Reply(false, nil); err != nil { + h.errChan <- err + } } } @@ -331,7 +375,7 @@ func TestClient(t *testing.T) { default: } - require.NoError(t, session.RequestSubsystem("test")) + require.NoError(t, session.RequestSubsystem(ctx, "test")) select { case err := <-errChan: diff --git a/integration/integration_test.go b/integration/integration_test.go index a38727a8ef80a..f1bcfbf6ebd10 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -4616,6 +4616,7 @@ func runAndMatch(tc *client.TeleportClient, attempts int, command []string, patt func testWindowChange(t *testing.T, suite *integrationTestSuite) { tr := utils.NewTracer(utils.ThisFunction()).Start() defer tr.Stop() + ctx := context.Background() teleport := suite.newTeleport(t, nil, true) defer teleport.StopAll() @@ -4639,7 +4640,7 @@ func testWindowChange(t *testing.T, suite *integrationTestSuite) { cl.Stdout = personA cl.Stdin = personA - err = cl.SSH(context.TODO(), []string{}, false) + err = cl.SSH(ctx, []string{}, false) if !isSSHError(err) { require.NoError(t, err) } @@ -4666,8 +4667,8 @@ func testWindowChange(t *testing.T, suite *integrationTestSuite) { cl.Stdin = personB // Change the size of the window immediately after it is created. - cl.OnShellCreated = func(s *ssh.Session, c *tracessh.Client, terminal io.ReadWriteCloser) (exit bool, err error) { - err = s.WindowChange(48, 160) + cl.OnShellCreated = func(s *tracessh.Session, c *tracessh.Client, terminal io.ReadWriteCloser) (exit bool, err error) { + err = s.WindowChange(ctx, 48, 160) if err != nil { return true, trace.Wrap(err) } @@ -4675,7 +4676,7 @@ func testWindowChange(t *testing.T, suite *integrationTestSuite) { } for i := 0; i < 10; i++ { - err = cl.Join(context.TODO(), types.SessionPeerMode, apidefaults.Namespace, session.ID(sessionID), personB) + err = cl.Join(ctx, types.SessionPeerMode, apidefaults.Namespace, session.ID(sessionID), personB) if err == nil || isSSHError(err) { err = nil break diff --git a/lib/client/api.go b/lib/client/api.go index bb5dbb4c58cd5..86b84ea7a3dc6 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1390,7 +1390,7 @@ type TeleportClient struct { // hasn't begun yet. // // It allows clients to cancel SSH action -type ShellCreatedCallback func(s *ssh.Session, c *tracessh.Client, terminal io.ReadWriteCloser) (exit bool, err error) +type ShellCreatedCallback func(s *tracessh.Session, c *tracessh.Client, terminal io.ReadWriteCloser) (exit bool, err error) // NewClient creates a TeleportClient object and fully configures it func NewClient(c *Config) (tc *TeleportClient, err error) { diff --git a/lib/client/client.go b/lib/client/client.go index f2b85dc0b1ed4..430fd1da0dfa9 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -117,7 +117,7 @@ func (proxy *ProxyClient) GetSites(ctx context.Context) ([]types.Site, error) { // the session request, and then RequestSubsystem call could hang // forever go func() { - if err := proxySession.RequestSubsystem("proxysites"); err != nil { + if err := proxySession.RequestSubsystem(ctx, "proxysites"); err != nil { log.Warningf("Failed to request subsystem: %v", err) } }() @@ -1425,7 +1425,7 @@ func (proxy *ProxyClient) dialAuthServer(ctx context.Context, clusterName string return nil, trace.Wrap(err) } - err = proxySession.RequestSubsystem("proxy:" + address) + err = proxySession.RequestSubsystem(ctx, "proxy:"+address) if err != nil { // read the stderr output from the failed SSH session and append // it to the end of our own message: @@ -1476,11 +1476,11 @@ func (n *NodeAddr) ProxyFormat() string { // requestSubsystem sends a subsystem request on the session. If the passed // in context is canceled first, unblocks. -func requestSubsystem(ctx context.Context, session *ssh.Session, name string) error { +func requestSubsystem(ctx context.Context, session *tracessh.Session, name string) error { errCh := make(chan error) go func() { - er := session.RequestSubsystem(name) + er := session.RequestSubsystem(ctx, name) errCh <- er }() @@ -1558,7 +1558,7 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeAdd // pass the true client IP (if specified) to the proxy so it could pass it into the // SSH session for proper audit if len(proxy.clientAddr) > 0 { - if err = proxySession.Setenv(sshutils.TrueClientAddrVar, proxy.clientAddr); err != nil { + if err = proxySession.Setenv(ctx, sshutils.TrueClientAddrVar, proxy.clientAddr); err != nil { log.Error(err) } } @@ -1576,7 +1576,7 @@ func (proxy *ProxyClient) ConnectToNode(ctx context.Context, nodeAddress NodeAdd return nil, trace.Wrap(err) } - err = agent.RequestAgentForwarding(proxySession) + err = agent.RequestAgentForwarding(proxySession.Session) if err != nil { return nil, trace.Wrap(err) } @@ -1868,7 +1868,7 @@ func (c *NodeClient) ExecuteSCP(ctx context.Context, cmd scp.Command) error { runC := make(chan error, 1) go func() { - err := s.Run(shellCmd) + err := s.Run(ctx, shellCmd) if err != nil && errors.Is(err, &ssh.ExitMissingError{}) { // TODO(dmitri): currently, if the session is aborted with (*session).Close, // the remote side cannot send exit-status and this error results. diff --git a/lib/client/session.go b/lib/client/session.go index 08c302e037d64..31831ceeed58d 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -30,6 +30,7 @@ import ( "time" "github.com/gravitational/teleport" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/client/escape" "github.com/gravitational/teleport/lib/client/terminal" @@ -193,7 +194,7 @@ func (ns *NodeSession) NodeClient() *NodeClient { return ns.nodeClient } -func (ns *NodeSession) regularSession(ctx context.Context, callback func(s *ssh.Session) error) error { +func (ns *NodeSession) regularSession(ctx context.Context, callback func(s *tracessh.Session) error) error { session, err := ns.createServerSession(ctx) if err != nil { return trace.Wrap(err) @@ -204,9 +205,9 @@ func (ns *NodeSession) regularSession(ctx context.Context, callback func(s *ssh. return trace.Wrap(callback(session)) } -type interactiveCallback func(serverSession *ssh.Session, shell io.ReadWriteCloser) error +type interactiveCallback func(serverSession *tracessh.Session, shell io.ReadWriteCloser) error -func (ns *NodeSession) createServerSession(ctx context.Context) (*ssh.Session, error) { +func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Session, error) { sess, err := ns.nodeClient.Client.NewSession(ctx) if err != nil { return nil, trace.Wrap(err) @@ -223,7 +224,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*ssh.Session, e evarsToPass := []string{"LANG", "LANGUAGE"} for _, evar := range evarsToPass { if value := os.Getenv(evar); value != "" { - err = sess.Setenv(evar, value) + err = sess.Setenv(ctx, evar, value) if err != nil { log.Warn(err) } @@ -231,7 +232,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*ssh.Session, e } // pass environment variables set by client for key, val := range ns.env { - err = sess.Setenv(key, val) + err = sess.Setenv(ctx, key, val) if err != nil { log.Warn(err) } @@ -248,7 +249,7 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*ssh.Session, e if err != nil { return nil, trace.Wrap(err) } - err = agent.RequestAgentForwarding(sess) + err = agent.RequestAgentForwarding(sess.Session) if err != nil { return nil, trace.Wrap(err) } @@ -287,7 +288,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio return trace.Wrap(err) } // allocate terminal on the server: - remoteTerm, err := ns.allocateTerminal(termType, sess) + remoteTerm, err := ns.allocateTerminal(ctx, termType, sess) if err != nil { return trace.Wrap(err) } @@ -310,7 +311,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio // start piping input into the remote shell and pipe the output from // the remote shell into stdout: - ns.pipeInOut(remoteTerm, mode, sess) + ns.pipeInOut(ctx, remoteTerm, mode, sess) // wait for the session to end <-ns.closer.C @@ -326,7 +327,7 @@ func (ns *NodeSession) interactiveSession(ctx context.Context, mode types.Sessio } // allocateTerminal creates (allocates) a server-side terminal for this session. -func (ns *NodeSession) allocateTerminal(termType string, s *ssh.Session) (io.ReadWriteCloser, error) { +func (ns *NodeSession) allocateTerminal(ctx context.Context, termType string, s *tracessh.Session) (io.ReadWriteCloser, error) { var err error // read the size of the terminal window: @@ -343,10 +344,13 @@ func (ns *NodeSession) allocateTerminal(termType string, s *ssh.Session) (io.Rea } // ... and request a server-side terminal of the same size: - err = s.RequestPty(termType, + err = s.RequestPty( + ctx, + termType, height, width, - ssh.TerminalModes{}) + ssh.TerminalModes{}, + ) if err != nil { return nil, trace.Wrap(err) } @@ -363,7 +367,7 @@ func (ns *NodeSession) allocateTerminal(termType string, s *ssh.Session) (io.Rea return nil, trace.Wrap(err) } if ns.terminal.IsAttached() { - go ns.updateTerminalSize(s) + go ns.updateTerminalSize(ctx, s) } go func() { if _, err := io.Copy(os.Stderr, stderr); err != nil { @@ -379,7 +383,7 @@ func (ns *NodeSession) allocateTerminal(termType string, s *ssh.Session) (io.Rea ), nil } -func (ns *NodeSession) updateTerminalSize(s *ssh.Session) { +func (ns *NodeSession) updateTerminalSize(ctx context.Context, s *tracessh.Session) { terminalEvents := ns.terminal.Subscribe() lastWidth, lastHeight, err := ns.terminal.Size() @@ -420,6 +424,7 @@ func (ns *NodeSession) updateTerminalSize(s *ssh.Session) { // Send the "window-change" request over the channel. _, err = s.SendRequest( + ctx, sshutils.WindowChangeRequest, false, ssh.Marshal(sshutils.WinChangeReqParams{ @@ -485,13 +490,13 @@ func (ns *NodeSession) updateTerminalSize(s *ssh.Session) { // runShell executes user's shell on the remote node under an interactive session func (ns *NodeSession) runShell(ctx context.Context, mode types.SessionParticipantMode, beforeStart func(io.Writer), callback ShellCreatedCallback) error { - return ns.interactiveSession(ctx, mode, func(s *ssh.Session, shell io.ReadWriteCloser) error { + return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, shell io.ReadWriteCloser) error { if beforeStart != nil { beforeStart(s.Stdout) } // start the shell on the server: - if err := s.Shell(); err != nil { + if err := s.Shell(ctx); err != nil { return trace.Wrap(err) } // call the client-supplied callback @@ -521,8 +526,8 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // keyboard based signals will be propogated to the TTY on the server which is // where all signal handling will occur. if interactive { - return ns.interactiveSession(ctx, mode, func(s *ssh.Session, term io.ReadWriteCloser) error { - err := s.Start(strings.Join(cmd, " ")) + return ns.interactiveSession(ctx, mode, func(s *tracessh.Session, term io.ReadWriteCloser) error { + err := s.Start(ctx, strings.Join(cmd, " ")) if err != nil { return trace.Wrap(err) } @@ -548,13 +553,13 @@ func (ns *NodeSession) runCommand(ctx context.Context, mode types.SessionPartici // Unfortunately at the moment the Go SSH library Teleport uses does not // support sending SSH_MSG_DISCONNECT. Instead we close the SSH channel and // SSH client, and try and exit as gracefully as possible. - return ns.regularSession(ctx, func(s *ssh.Session) error { + return ns.regularSession(ctx, func(s *tracessh.Session) error { var err error runContext, cancel := context.WithCancel(ctx) go func() { defer cancel() - err = s.Run(strings.Join(cmd, " ")) + err = s.Run(ctx, strings.Join(cmd, " ")) }() select { @@ -688,7 +693,7 @@ func handlePeerControls(term *terminal.Terminal, enableEscapeSequences bool, rem // pipeInOut launches two goroutines: one to pipe the local input into the remote shell, // and another to pipe the output of the remote shell into the local output -func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser, mode types.SessionParticipantMode, sess *ssh.Session) { +func (ns *NodeSession) pipeInOut(ctx context.Context, shell io.ReadWriteCloser, mode types.SessionParticipantMode, sess *tracessh.Session) { // copy from the remote shell to the local output go func() { defer ns.closer.Close() @@ -704,7 +709,7 @@ func (ns *NodeSession) pipeInOut(shell io.ReadWriteCloser, mode types.SessionPar defer ns.closer.Close() handleNonPeerControls(mode, ns.terminal, func() { - _, err := sess.SendRequest(teleport.ForceTerminateRequest, true, nil) + _, err := sess.SendRequest(ctx, teleport.ForceTerminateRequest, true, nil) if err != nil { fmt.Printf("\n\rError while sending force termination request: %v\n\r", err.Error()) } diff --git a/lib/client/x11_session.go b/lib/client/x11_session.go index 452b772b524bb..6f395ff0301fa 100644 --- a/lib/client/x11_session.go +++ b/lib/client/x11_session.go @@ -21,16 +21,18 @@ import ( "os" "time" - "github.com/gravitational/teleport/lib/sshutils" - "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" + + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/sshutils/x11" ) // handleX11Forwarding handles X11 channel requests for the given server session. // If X11 forwarding is not requested by the client, or it is rejected by the server, // then X11 channel requests will be rejected. -func (ns *NodeSession) handleX11Forwarding(ctx context.Context, sess *ssh.Session) error { +func (ns *NodeSession) handleX11Forwarding(ctx context.Context, sess *tracessh.Session) error { if !ns.nodeClient.TC.EnableX11Forwarding { return ns.rejectX11Channels(ctx) } @@ -54,7 +56,7 @@ func (ns *NodeSession) handleX11Forwarding(ctx context.Context, sess *ssh.Sessio return trace.Wrap(err) } - if err := x11.RequestForwarding(sess, ns.spoofedXAuthEntry); err != nil { + if err := x11.RequestForwarding(sess.Session, ns.spoofedXAuthEntry); err != nil { // Notify the user that x11 forwarding request failed regardless of debug level log.Print("X11 forwarding request failed") log.WithError(err).Debug("X11 forwarding request error") @@ -134,7 +136,7 @@ func (ns *NodeSession) setXAuthData(ctx context.Context, display x11.Display) er } // serveX11Channels serves incoming X11 channels by starting X11 forwarding with the session. -func (ns *NodeSession) serveX11Channels(ctx context.Context, sess *ssh.Session) error { +func (ns *NodeSession) serveX11Channels(ctx context.Context, sess *tracessh.Session) error { err := x11.ServeChannelRequests(ctx, ns.nodeClient.Client.Client, func(ctx context.Context, nch ssh.NewChannel) { if !ns.x11RefuseTime.IsZero() && time.Now().After(ns.x11RefuseTime) { nch.Reject(ssh.Prohibited, "rejected X11 channel request after ForwardX11Timeout") diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 3d0f44d10886f..60ce57e57aebe 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -293,7 +293,7 @@ type ServerContext struct { // RemoteSession holds a SSH session to a remote server. Only used by the // recording proxy. - RemoteSession *ssh.Session + RemoteSession *tracessh.Session // clientLastActive records the last time there was activity from the client clientLastActive time.Time diff --git a/lib/srv/exec.go b/lib/srv/exec.go index 119e8dcbaae76..7dd39b327fb85 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -18,6 +18,7 @@ package srv import ( "bufio" + "context" "fmt" "io" "os" @@ -31,6 +32,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/events" @@ -68,7 +70,7 @@ type Exec interface { SetCommand(string) // Start will start the execution of the command. - Start(channel ssh.Channel) (*ExecResult, error) + Start(ctx context.Context, channel ssh.Channel) (*ExecResult, error) // Wait will block while the command executes. Wait() *ExecResult @@ -135,7 +137,7 @@ func (e *localExec) SetCommand(command string) { // Start launches the given command returns (nil, nil) if successful. // ExecResult is only used to communicate an error while launching. -func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) { +func (e *localExec) Start(ctx context.Context, channel ssh.Channel) (*ExecResult, error) { // Parse the command to see if it is scp. err := e.transformSecureCopy() if err != nil { @@ -286,7 +288,7 @@ func waitForContinue(contfd *os.File) error { // remoteExec is used to run an "exec" SSH request and return the result. type remoteExec struct { command string - session *ssh.Session + session *tracessh.Session ctx *ServerContext } @@ -307,7 +309,7 @@ func (e *remoteExec) SetCommand(command string) { // Start launches the given command returns (nil, nil) if successful. // ExecResult is only used to communicate an error while launching. -func (e *remoteExec) Start(ch ssh.Channel) (*ExecResult, error) { +func (e *remoteExec) Start(ctx context.Context, ch ssh.Channel) (*ExecResult, error) { // hook up stdout/err the channel so the user can interact with the command e.session.Stdout = ch e.session.Stderr = ch.Stderr() @@ -324,7 +326,7 @@ func (e *remoteExec) Start(ch ssh.Channel) (*ExecResult, error) { inputWriter.Close() }() - err = e.session.Start(e.command) + err = e.session.Start(ctx, e.command) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 316cdc84cc2ec..efc100f476e63 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -1008,13 +1008,13 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, if scx.JoinOnly { switch req.Type { case tracessh.TracingRequest: - return s.handleTracingRequest(req, scx) + return s.handleTracingRequest(ctx, req, scx) case sshutils.PTYRequest: - return s.termHandlers.HandlePTYReq(ch, req, scx) + return s.termHandlers.HandlePTYReq(ctx, ch, req, scx) case sshutils.ShellRequest: return s.termHandlers.HandleShell(ctx, ch, req, scx) case sshutils.WindowChangeRequest: - return s.termHandlers.HandleWinChange(ch, req, scx) + return s.termHandlers.HandleWinChange(ctx, ch, req, scx) case teleport.ForceTerminateRequest: return s.termHandlers.HandleForceTerminate(ch, req, scx) case sshutils.EnvRequest: @@ -1038,19 +1038,19 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, switch req.Type { case tracessh.TracingRequest: - return s.handleTracingRequest(req, scx) + return s.handleTracingRequest(ctx, req, scx) case sshutils.ExecRequest: return s.termHandlers.HandleExec(ctx, ch, req, scx) case sshutils.PTYRequest: - return s.termHandlers.HandlePTYReq(ch, req, scx) + return s.termHandlers.HandlePTYReq(ctx, ch, req, scx) case sshutils.ShellRequest: return s.termHandlers.HandleShell(ctx, ch, req, scx) case sshutils.WindowChangeRequest: - return s.termHandlers.HandleWinChange(ch, req, scx) + return s.termHandlers.HandleWinChange(ctx, ch, req, scx) case teleport.ForceTerminateRequest: return s.termHandlers.HandleForceTerminate(ch, req, scx) case sshutils.EnvRequest: - return s.handleEnv(ch, req, scx) + return s.handleEnv(ctx, ch, req, scx) case sshutils.SubsystemRequest: return s.handleSubsystem(ctx, ch, req, scx) case sshutils.X11ForwardRequest: @@ -1084,7 +1084,7 @@ func (s *Server) handleAgentForward(ch ssh.Channel, req *ssh.Request, ctx *srv.S } // Make an "auth-agent-req@openssh.com" request on the remote host. - err = agent.RequestAgentForwarding(ctx.RemoteSession) + err = agent.RequestAgentForwarding(ctx.RemoteSession.Session) if err != nil { return trace.Wrap(err) } @@ -1117,14 +1117,14 @@ func (s *Server) handleX11ChannelRequest(ctx context.Context, nch ssh.NewChannel defer cancel() go func() { - err := sshutils.ForwardRequests(ctx, cin, sch) + err := sshutils.ForwardRequests(ctx, cin, tracessh.NewTraceChannel(sch, tracing.WithTracerProvider(s.tracerProvider))) if err != nil { s.log.WithError(err).Debug("Failed to forward ssh request from client during X11 forwarding") } }() go func() { - err := sshutils.ForwardRequests(ctx, sin, cch) + err := sshutils.ForwardRequests(ctx, sin, tracessh.NewTraceChannel(cch, tracing.WithTracerProvider(s.tracerProvider))) if err != nil { s.log.WithError(err).Debug("Failed to forward ssh request from server during X11 forwarding") } @@ -1175,7 +1175,7 @@ func (s *Server) handleX11Forward(ctx context.Context, ch ssh.Channel, req *ssh. } // send X11 forwarding request to remote - ok, err := sshutils.ForwardRequest(scx.RemoteSession, req) + ok, err := sshutils.ForwardRequest(ctx, scx.RemoteSession, req) if err != nil { return trace.Wrap(err) } else if !ok { @@ -1218,22 +1218,22 @@ func (s *Server) handleSubsystem(ctx context.Context, ch ssh.Channel, req *ssh.R return nil } -func (s *Server) handleTracingRequest(req *ssh.Request, ctx *srv.ServerContext) error { - if _, err := ctx.RemoteSession.SendRequest(req.Type, false, req.Payload); err != nil { +func (s *Server) handleTracingRequest(ctx context.Context, req *ssh.Request, scx *srv.ServerContext) error { + if _, err := scx.RemoteSession.SendRequest(ctx, req.Type, false, req.Payload); err != nil { s.log.WithError(err).Debugf("Unable to set forward tracing context") } return nil } -func (s *Server) handleEnv(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error { +func (s *Server) handleEnv(ctx context.Context, ch ssh.Channel, req *ssh.Request, scx *srv.ServerContext) error { var e sshutils.EnvReqParams if err := ssh.Unmarshal(req.Payload, &e); err != nil { - ctx.Error(err) + scx.Error(err) return trace.Wrap(err, "failed to parse env request") } - err := ctx.RemoteSession.Setenv(e.Name, e.Value) + err := scx.RemoteSession.Setenv(ctx, e.Name, e.Value) if err != nil { s.log.Debugf("Unable to set environment variable: %v: %v", e.Name, e.Value) } diff --git a/lib/srv/forward/subsystem.go b/lib/srv/forward/subsystem.go index d88130701252a..782a6624f1ce7 100644 --- a/lib/srv/forward/subsystem.go +++ b/lib/srv/forward/subsystem.go @@ -22,11 +22,12 @@ import ( "golang.org/x/crypto/ssh" + "github.com/gravitational/trace" + "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/srv" - "github.com/gravitational/trace" log "github.com/sirupsen/logrus" ) @@ -77,7 +78,7 @@ func (r *remoteSubsystem) Start(ctx context.Context, channel ssh.Channel) error // request the subsystem from the remote node. if successful, the user can // interact with the remote subsystem with stdin, stdout, and stderr. - err = session.RequestSubsystem(r.subsytemName) + err = session.RequestSubsystem(ctx, r.subsytemName) if err != nil { // emit an event to the audit log with the reason remote execution failed r.emitAuditEvent(err) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 82b6b63a6fa7b..a72092d82f64d 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1637,11 +1637,11 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, case tracessh.TracingRequest: return nil case sshutils.PTYRequest: - return s.termHandlers.HandlePTYReq(ch, req, serverContext) + return s.termHandlers.HandlePTYReq(ctx, ch, req, serverContext) case sshutils.ShellRequest: return s.termHandlers.HandleShell(ctx, ch, req, serverContext) case sshutils.WindowChangeRequest: - return s.termHandlers.HandleWinChange(ch, req, serverContext) + return s.termHandlers.HandleWinChange(ctx, ch, req, serverContext) case teleport.ForceTerminateRequest: return s.termHandlers.HandleForceTerminate(ch, req, serverContext) case sshutils.EnvRequest: @@ -1676,11 +1676,11 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, case sshutils.ExecRequest: return s.termHandlers.HandleExec(ctx, ch, req, serverContext) case sshutils.PTYRequest: - return s.termHandlers.HandlePTYReq(ch, req, serverContext) + return s.termHandlers.HandlePTYReq(ctx, ch, req, serverContext) case sshutils.ShellRequest: return s.termHandlers.HandleShell(ctx, ch, req, serverContext) case sshutils.WindowChangeRequest: - return s.termHandlers.HandleWinChange(ch, req, serverContext) + return s.termHandlers.HandleWinChange(ctx, ch, req, serverContext) case teleport.ForceTerminateRequest: return s.termHandlers.HandleForceTerminate(ch, req, serverContext) case sshutils.EnvRequest: diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index deec5b9aca497..b60a895299dda 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -529,10 +529,10 @@ func TestAgentForwardPermission(t *testing.T) { // to interoperate with OpenSSH, requests for agent forwarding always succeed. // however that does not mean the users agent will actually be forwarded. - require.NoError(t, agent.RequestAgentForwarding(se)) + require.NoError(t, agent.RequestAgentForwarding(se.Session)) // the output of env, we should not see SSH_AUTH_SOCK in the output - output, err := se.Output("env") + output, err := se.Output(ctx, "env") require.NoError(t, err) require.NotContains(t, string(output), "SSH_AUTH_SOCK") } @@ -580,17 +580,18 @@ func TestMaxSessions(t *testing.T) { func TestExecLongCommand(t *testing.T) { t.Parallel() f := newFixture(t) + ctx := context.Background() // Get the path to where the "echo" command is on disk. echoPath, err := exec.LookPath("echo") require.NoError(t, err) - se, err := f.ssh.clt.NewSession(context.Background()) + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) defer se.Close() // Write a message that larger than the maximum pipe size. - _, err = se.Output(fmt.Sprintf("%v %v", echoPath, strings.Repeat("a", maxPipeSize))) + _, err = se.Output(ctx, fmt.Sprintf("%v %v", echoPath, strings.Repeat("a", maxPipeSize))) require.NoError(t, err) } @@ -599,14 +600,15 @@ func TestExecLongCommand(t *testing.T) { func TestOpenExecSessionSetsSession(t *testing.T) { t.Parallel() f := newFixtureWithoutDiskBasedLogging(t) + ctx := context.Background() - se, err := f.ssh.clt.NewSession(context.Background()) + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) defer se.Close() // This will trigger an exec request, which will start a non-interactive session, // which then triggers setting env for SSH_SESSION_ID. - output, err := se.Output("env") + output, err := se.Output(ctx, "env") require.NoError(t, err) require.Contains(t, string(output), teleport.SSHSessionID) } @@ -630,7 +632,7 @@ func TestAgentForward(t *testing.T) { require.NoError(t, err) t.Cleanup(func() { se.Close() }) - err = agent.RequestAgentForwarding(se) + err = agent.RequestAgentForwarding(se.Session) require.NoError(t, err) // prepare to send virtual "keyboard input" into the shell: @@ -639,7 +641,7 @@ func TestAgentForward(t *testing.T) { t.Cleanup(func() { keyboard.Close() }) // start interactive SSH session (new shell): - err = se.Shell() + err = se.Shell(ctx) require.NoError(t, err) // create a temp file to collect the shell output into: @@ -819,7 +821,7 @@ func x11EchoSession(ctx context.Context, t *testing.T, clt *tracessh.Client) x11 // Client requests x11 forwarding for the server session. clientXAuthEntry, err := x11.NewFakeXAuthEntry(x11.Display{}) require.NoError(t, err) - err = x11.RequestForwarding(se, clientXAuthEntry) + err = x11.RequestForwarding(se.Session, clientXAuthEntry) require.NoError(t, err) // prepare to send virtual "keyboard input" into the shell: @@ -827,7 +829,7 @@ func x11EchoSession(ctx context.Context, t *testing.T, clt *tracessh.Client) x11 require.NoError(t, err) // start interactive SSH session with x11 forwarding enabled (new shell): - err = se.Shell() + err = se.Shell(ctx) require.NoError(t, err) // create a temp file to collect the shell output into: @@ -987,19 +989,21 @@ func TestKeyAlgorithms(t *testing.T) { func TestInvalidSessionID(t *testing.T) { t.Parallel() f := newFixture(t) + ctx := context.Background() - session, err := f.ssh.clt.NewSession(context.Background()) + session, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) - err = session.Setenv(sshutils.SessionEnvVar, "foo") + err = session.Setenv(ctx, sshutils.SessionEnvVar, "foo") require.NoError(t, err) - err = session.Shell() + err = session.Shell(ctx) require.Error(t, err) } func TestSessionHijack(t *testing.T) { t.Parallel() + ctx := context.Background() _, err := user.Lookup(teleportTestUser) if err != nil { t.Skipf("user %v is not found, skipping test", teleportTestUser) @@ -1018,22 +1022,22 @@ func TestSessionHijack(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - client, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig) + client, err := tracessh.Dial(ctx, "tcp", f.ssh.srv.Addr(), sshConfig) require.NoError(t, err) defer func() { err := client.Close() require.NoError(t, err) }() - se, err := client.NewSession(context.Background()) + se, err := client.NewSession(ctx) require.NoError(t, err) defer se.Close() firstSessionID := string(sess.NewID()) - err = se.Setenv(sshutils.SessionEnvVar, firstSessionID) + err = se.Setenv(ctx, sshutils.SessionEnvVar, firstSessionID) require.NoError(t, err) - err = se.Shell() + err = se.Shell(ctx) require.NoError(t, err) // user 2 does not have s.user as a listed principal @@ -1046,33 +1050,34 @@ func TestSessionHijack(t *testing.T) { HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - client2, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), sshConfig2) + client2, err := tracessh.Dial(ctx, "tcp", f.ssh.srv.Addr(), sshConfig2) require.NoError(t, err) defer func() { err := client2.Close() require.NoError(t, err) }() - se2, err := client2.NewSession(context.Background()) + se2, err := client2.NewSession(ctx) require.NoError(t, err) defer se2.Close() - err = se2.Setenv(sshutils.SessionEnvVar, firstSessionID) + err = se2.Setenv(ctx, sshutils.SessionEnvVar, firstSessionID) require.NoError(t, err) // attempt to hijack, should return error - err = se2.Shell() + err = se2.Shell(ctx) require.Error(t, err) } // testClient dials targetAddr via proxyAddr and executes 2+3 command func testClient(t *testing.T, f *sshTestFixture, proxyAddr, targetAddr, remoteAddr string, sshConfig *ssh.ClientConfig) { + ctx := context.Background() // Connect to node using registered address - client, err := tracessh.Dial(context.Background(), "tcp", proxyAddr, sshConfig) + client, err := tracessh.Dial(ctx, "tcp", proxyAddr, sshConfig) require.NoError(t, err) defer client.Close() - se, err := client.NewSession(context.Background()) + se, err := client.NewSession(ctx) require.NoError(t, err) defer se.Close() @@ -1083,7 +1088,7 @@ func testClient(t *testing.T, f *sshTestFixture, proxyAddr, targetAddr, remoteAd require.NoError(t, err) // Request opening TCP connection to the remote host - require.NoError(t, se.RequestSubsystem(fmt.Sprintf("proxy:%v", targetAddr))) + require.NoError(t, se.RequestSubsystem(ctx, fmt.Sprintf("proxy:%v", targetAddr))) local, err := utils.ParseAddr("tcp://" + proxyAddr) require.NoError(t, err) @@ -1102,7 +1107,7 @@ func testClient(t *testing.T, f *sshTestFixture, proxyAddr, targetAddr, remoteAd // Open SSH connection via TCP conn, chans, reqs, err := tracessh.NewClientConn( - context.Background(), + ctx, pipeNetConn, f.ssh.srv.Addr(), sshConfig, @@ -1115,11 +1120,11 @@ func testClient(t *testing.T, f *sshTestFixture, proxyAddr, targetAddr, remoteAd require.NoError(t, err) defer client2.Close() - se2, err := client2.NewSession(context.Background()) + se2, err := client2.NewSession(ctx) require.NoError(t, err) defer se2.Close() - out, err := se2.Output("echo hello") + out, err := se2.Output(ctx, "echo hello") require.NoError(t, err) require.Equal(t, "hello\n", string(out)) @@ -1340,31 +1345,33 @@ func TestProxyDirectAccess(t *testing.T) { // TestPTY requests PTY for an interactive session func TestPTY(t *testing.T) { t.Parallel() + ctx := context.Background() f := newFixture(t) - se, err := f.ssh.clt.NewSession(context.Background()) + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) defer se.Close() // request PTY with valid size - require.NoError(t, se.RequestPty("xterm", 30, 30, ssh.TerminalModes{})) + require.NoError(t, se.RequestPty(ctx, "xterm", 30, 30, ssh.TerminalModes{})) // request PTY with invalid size, should still work (selects defaults) - require.NoError(t, se.RequestPty("xterm", 0, 0, ssh.TerminalModes{})) + require.NoError(t, se.RequestPty(ctx, "xterm", 0, 0, ssh.TerminalModes{})) } // TestEnv requests setting environment variables. (We are currently ignoring these requests) func TestEnv(t *testing.T) { t.Parallel() + ctx := context.Background() f := newFixture(t) - se, err := f.ssh.clt.NewSession(context.Background()) + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) defer se.Close() - require.NoError(t, se.Setenv("HOME", "/")) + require.NoError(t, se.Setenv(ctx, "HOME", "/")) } // TestNoAuth tries to log in with no auth methods and should be rejected @@ -1390,19 +1397,20 @@ func TestPasswordAuth(t *testing.T) { func TestClientDisconnect(t *testing.T) { t.Parallel() f := newFixtureWithoutDiskBasedLogging(t) + ctx := context.Background() config := &ssh.ClientConfig{ User: f.user, Auth: []ssh.AuthMethod{ssh.PublicKeys(f.up.certSigner)}, HostKeyCallback: ssh.FixedHostKey(f.signer.PublicKey()), } - clt, err := tracessh.Dial(context.Background(), "tcp", f.ssh.srv.Addr(), config) + clt, err := tracessh.Dial(ctx, "tcp", f.ssh.srv.Addr(), config) require.NoError(t, err) require.NotNil(t, clt) - se, err := f.ssh.clt.NewSession(context.Background()) + se, err := f.ssh.clt.NewSession(ctx) require.NoError(t, err) - require.NoError(t, se.Shell()) + require.NoError(t, se.Shell(ctx)) require.NoError(t, clt.Close()) } @@ -1502,7 +1510,7 @@ func TestLimiter(t *testing.T) { se0, err := clt0.NewSession(ctx) require.NoError(t, err) - require.NoError(t, se0.Shell()) + require.NoError(t, se0.Shell(ctx)) // current connections = 1 clt, err := tracessh.Dial(ctx, "tcp", srv.Addr(), config) @@ -1510,7 +1518,7 @@ func TestLimiter(t *testing.T) { require.NotNil(t, clt) se, err := clt.NewSession(ctx) require.NoError(t, err) - require.NoError(t, se.Shell()) + require.NoError(t, se.Shell(ctx)) // current connections = 2 _, err = tracessh.Dial(ctx, "tcp", srv.Addr(), config) @@ -1529,7 +1537,7 @@ func TestLimiter(t *testing.T) { require.NotNil(t, clt) se, err = clt.NewSession(ctx) require.NoError(t, err) - require.NoError(t, se.Shell()) + require.NoError(t, se.Shell(ctx)) // current connections = 2 _, err = tracessh.Dial(ctx, "tcp", srv.Addr(), config) @@ -1880,7 +1888,7 @@ func TestX11ProxySupport(t *testing.T) { require.NotNil(t, xchs) // Send an X11 forwarding request to the server - ok, err = sess.SendRequest(sshutils.X11ForwardRequest, true, nil) + ok, err = sess.SendRequest(ctx, sshutils.X11ForwardRequest, true, nil) require.NoError(t, err) require.True(t, ok) @@ -2008,11 +2016,11 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { // Request the PuTTY-specific "simple@putty.projects.tartarus.org" channel type // This request should be ignored, and the connection should remain open. - _, err = se.SendRequest(sshutils.PuTTYSimpleRequest, false, []byte{}) + _, err = se.SendRequest(ctx, sshutils.PuTTYSimpleRequest, false, []byte{}) require.NoError(t, err) // Request proxy subsystem routing TCP connection to the remote host - require.NoError(t, se.RequestSubsystem(fmt.Sprintf("proxy:%v", f.ssh.srvAddress))) + require.NoError(t, se.RequestSubsystem(ctx, fmt.Sprintf("proxy:%v", f.ssh.srvAddress))) local, err := utils.ParseAddr("tcp://" + proxy.Addr()) require.NoError(t, err) @@ -2044,7 +2052,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { require.NoError(t, err) defer se2.Close() - out, err := se2.Output("echo hello again") + out, err := se2.Output(ctx, "echo hello again") require.NoError(t, err) require.Equal(t, "hello again\n", string(out)) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index d79642a07c6f7..0d00c81b25594 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -301,7 +301,7 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann return trace.Wrap(err) } - err = sess.startExec(channel, scx, tempUser) + err = sess.startExec(ctx, channel, scx, tempUser) if err != nil { sess.Close() return trace.Wrap(err) @@ -342,8 +342,8 @@ func (s *SessionRegistry) GetTerminalSize(sessionID string) (*term.Winsize, erro // NotifyWinChange is called to notify all members in the party that the PTY // size has changed. The notification is sent as a global SSH request and it // is the responsibility of the client to update it's window size upon receipt. -func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *ServerContext) error { - session := ctx.getSession() +func (s *SessionRegistry) NotifyWinChange(ctx context.Context, params rsession.TerminalParams, scx *ServerContext) error { + session := scx.getSession() if session == nil { s.log.Debug("Unable to update window size, no session found in context.") return nil @@ -355,13 +355,13 @@ func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *S Metadata: apievents.Metadata{ Type: events.ResizeEvent, Code: events.TerminalResizeCode, - ClusterName: ctx.ClusterName, + ClusterName: scx.ClusterName, }, ServerMetadata: session.serverMeta, SessionMetadata: apievents.SessionMetadata{ SessionID: string(sid), }, - UserMetadata: ctx.Identity.GetUserMetadata(), + UserMetadata: scx.Identity.GetUserMetadata(), TerminalSize: params.Serialize(), } @@ -372,7 +372,7 @@ func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *S } // Update the size of the server side PTY. - err := session.term.SetWinSize(params) + err := session.term.SetWinSize(ctx, params) if err != nil { return trace.Wrap(err) } @@ -380,7 +380,7 @@ func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *S // If sessions are being recorded at the proxy, sessions can not be shared. // In that situation, PTY size information does not need to be propagated // back to all clients and we can return right away. - if services.IsRecordAtProxy(ctx.SessionRecordingConfig.GetMode()) { + if services.IsRecordAtProxy(scx.SessionRecordingConfig.GetMode()) { return nil } @@ -389,7 +389,7 @@ func (s *SessionRegistry) NotifyWinChange(params rsession.TerminalParams, ctx *S // OpenSSH clients will ignore this and not update their own local PTY. for _, p := range session.getParties() { // Don't send the window change notification back to the originator. - if p.ctx.ID() == ctx.ID() { + if p.ctx.ID() == scx.ID() { continue } @@ -655,7 +655,7 @@ func (s *session) Stop() { if err := s.term.Close(); err != nil { s.log.WithError(err).Debug("Failed to close the shell") } - if err := s.term.Kill(); err != nil { + if err := s.term.Kill(context.TODO()); err != nil { s.log.WithError(err).Debug("Failed to kill the shell") } } @@ -968,7 +968,7 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser s.BroadcastMessage("Creating session with ID: %v...", s.id) s.BroadcastMessage(SessionControlsInfoBroadcast) - if err := s.startTerminal(scx); err != nil { + if err := s.startTerminal(ctx, scx); err != nil { return trace.Wrap(err) } @@ -1070,22 +1070,22 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser return nil } -func (s *session) startTerminal(ctx *ServerContext) error { +func (s *session) startTerminal(ctx context.Context, scx *ServerContext) error { s.mu.Lock() defer s.mu.Unlock() // allocate a terminal or take the one previously allocated via a // separate "allocate TTY" SSH request var err error - if s.term = ctx.GetTerm(); s.term != nil { - ctx.SetTerm(nil) - } else if s.term, err = NewTerminal(ctx); err != nil { - ctx.Infof("Unable to allocate new terminal: %v", err) + if s.term = scx.GetTerm(); s.term != nil { + scx.SetTerm(nil) + } else if s.term, err = NewTerminal(scx); err != nil { + scx.Infof("Unable to allocate new terminal: %v", err) return trace.Wrap(err) } - if err := s.term.Run(); err != nil { - ctx.Errorf("Unable to run shell command: %v.", err) + if err := s.term.Run(ctx); err != nil { + scx.Errorf("Unable to run shell command: %v.", err) return trace.ConvertSystemError(err) } @@ -1164,49 +1164,49 @@ func newEventOnlyRecorder(s *session, ctx *ServerContext) (events.StreamWriter, return rec, nil } -func (s *session) startExec(channel ssh.Channel, ctx *ServerContext, tempUser *user.User) error { +func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *ServerContext, tempUser *user.User) error { // Emit a session.start event for the exec session. - s.emitSessionStartEvent(ctx) + s.emitSessionStartEvent(scx) // Start execution. If the program failed to start, send that result back. // Note this is a partial start. Teleport will have re-exec'ed itself and // wait until it's been placed in a cgroup and told to continue. - result, err := ctx.ExecRequest.Start(channel) + result, err := scx.ExecRequest.Start(ctx, channel) if err != nil { return trace.Wrap(err) } if result != nil { - ctx.Debugf("Exec request (%v) result: %v.", ctx.ExecRequest, result) - ctx.SendExecResult(*result) + scx.Debugf("Exec request (%v) result: %v.", scx.ExecRequest, result) + scx.SendExecResult(*result) } // Open a BPF recording session. If BPF was not configured, not available, // or running in a recording proxy, OpenSession is a NOP. sessionContext := &bpf.SessionContext{ - Context: ctx.srv.Context(), - PID: ctx.ExecRequest.PID(), + Context: scx.srv.Context(), + PID: scx.ExecRequest.PID(), Emitter: s.Recorder(), - Namespace: ctx.srv.GetNamespace(), + Namespace: scx.srv.GetNamespace(), SessionID: string(s.id), - ServerID: ctx.srv.HostUUID(), - Login: ctx.Identity.Login, - User: ctx.Identity.TeleportUser, - Events: ctx.Identity.AccessChecker.EnhancedRecordingSet(), + ServerID: scx.srv.HostUUID(), + Login: scx.Identity.Login, + User: scx.Identity.TeleportUser, + Events: scx.Identity.AccessChecker.EnhancedRecordingSet(), } - cgroupID, err := ctx.srv.GetBPF().OpenSession(sessionContext) + cgroupID, err := scx.srv.GetBPF().OpenSession(sessionContext) if err != nil { - ctx.Errorf("Failed to open enhanced recording (exec) session: %v: %v.", ctx.ExecRequest.GetCommand(), err) + scx.Errorf("Failed to open enhanced recording (exec) session: %v: %v.", scx.ExecRequest.GetCommand(), err) return trace.Wrap(err) } // If a cgroup ID was assigned then enhanced session recording was enabled. if cgroupID > 0 { s.setHasEnhancedRecording(true) - ctx.srv.GetRestrictedSessionManager().OpenSession(sessionContext, cgroupID) + scx.srv.GetRestrictedSessionManager().OpenSession(sessionContext, cgroupID) } if tempUser != nil { - sessionUser, err := user.Lookup(ctx.Identity.Login) + sessionUser, err := user.Lookup(scx.Identity.Login) if err != nil { return trace.Wrap(err) } @@ -1217,26 +1217,26 @@ func (s *session) startExec(channel ssh.Channel, ctx *ServerContext, tempUser *u } // Process has been placed in a cgroup, continue execution. - ctx.ExecRequest.Continue() + scx.ExecRequest.Continue() // Process is running, wait for it to stop. go func() { - result = ctx.ExecRequest.Wait() + result = scx.ExecRequest.Wait() if result != nil { - ctx.SendExecResult(*result) + scx.SendExecResult(*result) } // Wait a little bit to let all events filter through before closing the // BPF session so everything can be recorded. time.Sleep(2 * time.Second) - ctx.srv.GetRestrictedSessionManager().CloseSession(sessionContext, cgroupID) + scx.srv.GetRestrictedSessionManager().CloseSession(sessionContext, cgroupID) // Close the BPF recording session. If BPF was not configured, not available, // or running in a recording proxy, this is simply a NOP. - err = ctx.srv.GetBPF().CloseSession(sessionContext) + err = scx.srv.GetBPF().CloseSession(sessionContext) if err != nil { - ctx.Errorf("Failed to close enhanced recording (exec) session: %v: %v.", s.id, err) + scx.Errorf("Failed to close enhanced recording (exec) session: %v: %v.", s.id, err) } s.emitSessionEndEvent() diff --git a/lib/srv/term.go b/lib/srv/term.go index 0a13b4165cdd1..fe02936a86524 100644 --- a/lib/srv/term.go +++ b/lib/srv/term.go @@ -17,6 +17,7 @@ limitations under the License. package srv import ( + "context" "io" "os" "os/exec" @@ -28,6 +29,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/lib/services" rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" @@ -54,7 +56,7 @@ type Terminal interface { AddParty(delta int) // Run will run the terminal. - Run() error + Run(ctx context.Context) error // Wait will block until the terminal is complete. Wait() (*ExecResult, error) @@ -64,7 +66,7 @@ type Terminal interface { Continue() // Kill will force kill the terminal. - Kill() error + Kill(ctx context.Context) error // PTY returns the PTY backing the terminal. PTY() io.ReadWriter @@ -82,7 +84,7 @@ type Terminal interface { GetWinSize() (*term.Winsize, error) // SetWinSize sets the window size of the terminal. - SetWinSize(params rsession.TerminalParams) error + SetWinSize(ctx context.Context, params rsession.TerminalParams) error // GetTerminalParams is a fast call to get cached terminal parameters // and avoid extra system call. @@ -169,7 +171,7 @@ func (t *terminal) AddParty(delta int) { } // Run will run the terminal. -func (t *terminal) Run() error { +func (t *terminal) Run(ctx context.Context) error { var err error defer t.closeTTY() @@ -229,7 +231,7 @@ func (t *terminal) Continue() { } // Kill will force kill the terminal. -func (t *terminal) Kill() error { +func (t *terminal) Kill(ctx context.Context) error { if t.cmd != nil && t.cmd.Process != nil { if err := t.cmd.Process.Kill(); err != nil { if err.Error() != "os: process already finished" { @@ -314,7 +316,7 @@ func (t *terminal) GetWinSize() (*term.Winsize, error) { } // SetWinSize sets the window size of the terminal. -func (t *terminal) SetWinSize(params rsession.TerminalParams) error { +func (t *terminal) SetWinSize(ctx context.Context, params rsession.TerminalParams) error { t.mu.Lock() defer t.mu.Unlock() if t.pty == nil { @@ -417,7 +419,7 @@ type remoteTerminal struct { ctx *ServerContext - session *ssh.Session + session *tracessh.Session params rsession.TerminalParams termModes ssh.TerminalModes ptyBuffer *ptyBuffer @@ -458,9 +460,9 @@ func (b *ptyBuffer) Write(p []byte) (n int, err error) { return b.w.Write(p) } -func (t *remoteTerminal) Run() error { +func (t *remoteTerminal) Run(ctx context.Context) error { // prepare the remote remote session by setting environment variables - t.prepareRemoteSession(t.session, t.ctx) + t.prepareRemoteSession(ctx, t.session, t.ctx) // combine stdout and stderr stdout, err := t.session.StdoutPipe() @@ -484,7 +486,7 @@ func (t *remoteTerminal) Run() error { t.termType = defaultTerm } - if err := t.session.RequestPty(t.termType, t.params.H, t.params.W, t.termModes); err != nil { + if err := t.session.RequestPty(ctx, t.termType, t.params.H, t.params.W, t.termModes); err != nil { return trace.Wrap(err) } @@ -492,7 +494,7 @@ func (t *remoteTerminal) Run() error { if t.ctx.ExecRequest.GetCommand() != "" { t.log.Debugf("Running exec request within a PTY") - if err := t.session.Start(t.ctx.ExecRequest.GetCommand()); err != nil { + if err := t.session.Start(ctx, t.ctx.ExecRequest.GetCommand()); err != nil { return trace.Wrap(err) } @@ -501,7 +503,7 @@ func (t *remoteTerminal) Run() error { // we want an interactive shell t.log.Debugf("Requesting an interactive terminal of type %v", t.termType) - if err := t.session.Shell(); err != nil { + if err := t.session.Shell(ctx); err != nil { return trace.Wrap(err) } return nil @@ -532,8 +534,8 @@ func (t *remoteTerminal) Wait() (*ExecResult, error) { // Continue does nothing for remote command execution. func (t *remoteTerminal) Continue() {} -func (t *remoteTerminal) Kill() error { - err := t.session.Signal(ssh.SIGKILL) +func (t *remoteTerminal) Kill(ctx context.Context) error { + err := t.session.Signal(ctx, ssh.SIGKILL) if err != nil { return trace.Wrap(err) } @@ -575,11 +577,11 @@ func (t *remoteTerminal) GetWinSize() (*term.Winsize, error) { return t.params.Winsize(), nil } -func (t *remoteTerminal) SetWinSize(params rsession.TerminalParams) error { +func (t *remoteTerminal) SetWinSize(ctx context.Context, params rsession.TerminalParams) error { t.mu.Lock() defer t.mu.Unlock() - err := t.windowChange(params.W, params.H) + err := t.windowChange(ctx, params.W, params.H) if err != nil { return trace.Wrap(err) } @@ -611,7 +613,7 @@ func (t *remoteTerminal) SetTerminalModes(termModes ssh.TerminalModes) { t.termModes = termModes } -func (t *remoteTerminal) windowChange(w int, h int) error { +func (t *remoteTerminal) windowChange(ctx context.Context, w int, h int) error { type windowChangeRequest struct { W uint32 H uint32 @@ -624,22 +626,22 @@ func (t *remoteTerminal) windowChange(w int, h int) error { Wpx: uint32(w * 8), Hpx: uint32(h * 8), } - _, err := t.session.SendRequest(sshutils.WindowChangeRequest, false, ssh.Marshal(&req)) + _, err := t.session.SendRequest(ctx, sshutils.WindowChangeRequest, false, ssh.Marshal(&req)) return err } // prepareRemoteSession prepares the more session for execution. -func (t *remoteTerminal) prepareRemoteSession(session *ssh.Session, ctx *ServerContext) { +func (t *remoteTerminal) prepareRemoteSession(ctx context.Context, session *tracessh.Session, scx *ServerContext) { envs := map[string]string{ - teleport.SSHTeleportUser: ctx.Identity.TeleportUser, - teleport.SSHSessionWebproxyAddr: ctx.ProxyPublicAddress(), - teleport.SSHTeleportHostUUID: ctx.srv.ID(), - teleport.SSHTeleportClusterName: ctx.ClusterName, - teleport.SSHSessionID: string(ctx.SessionID()), + teleport.SSHTeleportUser: scx.Identity.TeleportUser, + teleport.SSHSessionWebproxyAddr: scx.ProxyPublicAddress(), + teleport.SSHTeleportHostUUID: scx.srv.ID(), + teleport.SSHTeleportClusterName: scx.ClusterName, + teleport.SSHSessionID: string(scx.SessionID()), } for k, v := range envs { - if err := session.Setenv(k, v); err != nil { + if err := session.Setenv(ctx, k, v); err != nil { t.log.Debugf("Unable to set environment variable: %v: %v", k, v) } } diff --git a/lib/srv/termhandlers.go b/lib/srv/termhandlers.go index 95193104389c1..0eebb453917a1 100644 --- a/lib/srv/termhandlers.go +++ b/lib/srv/termhandlers.go @@ -58,7 +58,7 @@ func (t *TermHandlers) HandleExec(ctx context.Context, ch ssh.Channel, req *ssh. // HandlePTYReq handles requests of type "pty-req" which allocate a TTY for // "exec" or "shell" requests. The "pty-req" includes the size of the TTY as // well as the terminal type requested. -func (t *TermHandlers) HandlePTYReq(ch ssh.Channel, req *ssh.Request, ctx *ServerContext) error { +func (t *TermHandlers) HandlePTYReq(ctx context.Context, ch ssh.Channel, req *ssh.Request, scx *ServerContext) error { // parse and extract the requested window size of the pty ptyRequest, err := parsePTYReq(req) if err != nil { @@ -74,28 +74,28 @@ func (t *TermHandlers) HandlePTYReq(ch ssh.Channel, req *ssh.Request, ctx *Serve if err != nil { return trace.Wrap(err) } - ctx.Debugf("Requested terminal %q of size %v", ptyRequest.Env, *params) + scx.Debugf("Requested terminal %q of size %v", ptyRequest.Env, *params) // get an existing terminal or create a new one - term := ctx.GetTerm() + term := scx.GetTerm() if term == nil { // a regular or forwarding terminal will be allocated - term, err = NewTerminal(ctx) + term, err = NewTerminal(scx) if err != nil { return trace.Wrap(err) } - ctx.SetTerm(term) - ctx.termAllocated = true + scx.SetTerm(term) + scx.termAllocated = true } - if err := term.SetWinSize(*params); err != nil { - ctx.Errorf("Failed setting window size: %v", err) + if err := term.SetWinSize(ctx, *params); err != nil { + scx.Errorf("Failed setting window size: %v", err) } term.SetTermType(ptyRequest.Env) term.SetTerminalModes(termModes) // update the session - if err := t.SessionRegistry.NotifyWinChange(*params, ctx); err != nil { - ctx.Errorf("Unable to update session: %v", err) + if err := t.SessionRegistry.NotifyWinChange(ctx, *params, scx); err != nil { + scx.Errorf("Unable to update session: %v", err) } return nil @@ -124,7 +124,7 @@ func (t *TermHandlers) HandleShell(ctx context.Context, ch ssh.Channel, req *ssh // HandleWinChange handles requests of type "window-change" which update the // size of the PTY running on the server and update any other members in the // party. -func (t *TermHandlers) HandleWinChange(ch ssh.Channel, req *ssh.Request, ctx *ServerContext) error { +func (t *TermHandlers) HandleWinChange(ctx context.Context, ch ssh.Channel, req *ssh.Request, scx *ServerContext) error { params, err := parseWinChange(req) if err != nil { return trace.Wrap(err) @@ -132,7 +132,7 @@ func (t *TermHandlers) HandleWinChange(ch ssh.Channel, req *ssh.Request, ctx *Se // Update any other members in the party that the window size has changed // and to update their terminal windows accordingly. - err = t.SessionRegistry.NotifyWinChange(*params, ctx) + err = t.SessionRegistry.NotifyWinChange(ctx, *params, scx) if err != nil { return trace.Wrap(err) } diff --git a/lib/sshutils/forward.go b/lib/sshutils/forward.go index e22ed973721ef..43a9e1367b576 100644 --- a/lib/sshutils/forward.go +++ b/lib/sshutils/forward.go @@ -24,12 +24,12 @@ import ( // RequestForwarder represents a resource capable of sending // an ssh request such as an ssh.Channel or ssh.Session. type RequestForwarder interface { - SendRequest(name string, wantReply bool, payload []byte) (bool, error) + SendRequest(ctx context.Context, name string, wantReply bool, payload []byte) (bool, error) } // ForwardRequest is a helper for forwarding a request across a session or channel. -func ForwardRequest(sender RequestForwarder, req *ssh.Request) (bool, error) { - reply, err := sender.SendRequest(req.Type, req.WantReply, req.Payload) +func ForwardRequest(ctx context.Context, sender RequestForwarder, req *ssh.Request) (bool, error) { + reply, err := sender.SendRequest(ctx, req.Type, req.WantReply, req.Payload) if err != nil || !req.WantReply { return reply, trace.Wrap(err) } @@ -48,7 +48,7 @@ func ForwardRequests(ctx context.Context, sin <-chan *ssh.Request, sender Reques } switch sreq.Type { case WindowChangeRequest: - if _, err := ForwardRequest(sender, sreq); err != nil { + if _, err := ForwardRequest(ctx, sender, sreq); err != nil { return trace.Wrap(err) } default: diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 90319812dff67..7b738660aaab9 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -167,7 +167,7 @@ type TerminalHandler struct { hostUUID string // sshSession holds the "shell" SSH channel to the node. - sshSession *ssh.Session + sshSession *tracessh.Session // terminalContext is used to signal when the terminal sesson is closing. terminalContext context.Context @@ -349,9 +349,9 @@ func (t *TerminalHandler) makeClient(ws *websocket.Conn, r *http.Request) (*clie // Save the *ssh.Session after the shell has been created. The session is // used to update all other parties window size to that of the web client and // to allow future window changes. - tc.OnShellCreated = func(s *ssh.Session, c *tracessh.Client, _ io.ReadWriteCloser) (bool, error) { + tc.OnShellCreated = func(s *tracessh.Session, c *tracessh.Client, _ io.ReadWriteCloser) (bool, error) { t.sshSession = s - t.windowChange(&t.params.Term) + t.windowChange(r.Context(), &t.params.Term) return false, nil } @@ -547,12 +547,13 @@ func (t *TerminalHandler) streamEvents(ws *websocket.Conn, tc *client.TeleportCl // windowChange is called when the browser window is resized. It sends a // "window-change" channel request to the server. -func (t *TerminalHandler) windowChange(params *session.TerminalParams) { +func (t *TerminalHandler) windowChange(ctx context.Context, params *session.TerminalParams) { if t.sshSession == nil { return } _, err := t.sshSession.SendRequest( + ctx, sshutils.WindowChangeRequest, false, ssh.Marshal(sshutils.WinChangeReqParams{ @@ -701,7 +702,7 @@ func (t *TerminalHandler) read(out []byte, ws *websocket.Conn) (n int, err error // Send the window change request in a goroutine so reads are not blocked // by network connectivity issues. - go t.windowChange(params) + go t.windowChange(context.TODO(), params) return 0, nil default: diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index f8454a511e2da..765a51002d2df 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -210,13 +210,13 @@ func sshProxy(ctx context.Context, tc *libclient.TeleportClient, sp sshProxyPara if err != nil { return trace.Wrap(err) } - err = agent.RequestAgentForwarding(sess) + err = agent.RequestAgentForwarding(sess.Session) if err != nil { return trace.Wrap(err) } sshUserHost := fmt.Sprintf("%s:%s", sp.targetHost, sp.targetPort) - if err = sess.RequestSubsystem(proxySubsystemName(sshUserHost, sp.clusterName)); err != nil { + if err = sess.RequestSubsystem(ctx, proxySubsystemName(sshUserHost, sp.clusterName)); err != nil { return trace.Wrap(err) } if err := proxySession(ctx, sess); err != nil { @@ -271,7 +271,7 @@ func makeSSHClient(ctx context.Context, conn net.Conn, addr string, cfg *ssh.Cli return tracessh.NewClient(cc, chs, reqs), nil } -func proxySession(ctx context.Context, sess *ssh.Session) error { +func proxySession(ctx context.Context, sess *tracessh.Session) error { stdout, err := sess.StdoutPipe() if err != nil { return trace.Wrap(err)