Skip to content

Commit

Permalink
Trace ssh sessions (#14966)
Browse files Browse the repository at this point in the history
Adds a wrapper around `ssh.Session` which injects tracing context
in a similar manner to the `ssh.Client` wrapper. All usages of
`ssh.Session` have now been replaced and have the appropriate
`context.Context` passed along

Part of #12241
  • Loading branch information
rosstimothy authored Aug 4, 2022
1 parent 361ea8e commit 0cb248d
Show file tree
Hide file tree
Showing 22 changed files with 772 additions and 228 deletions.
188 changes: 179 additions & 9 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ func isTracingSupported(clt *ssh.Client) (tracingCapability, error) {
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

_, span := tracer.Start(
ctx, span := tracer.Start(
ctx,
"ssh.DialContext",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
Expand All @@ -121,12 +120,19 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err
)
defer span.End()

conn, err := c.Client.Dial(n, addr)
if err != nil {
return nil, trace.Wrap(err)
c.mu.RLock()
// create the wrapper while the lock is held
wrapper := &clientWrapper{
capability: c.capability,
Conn: c.Client.Conn,
opts: c.opts,
ctx: ctx,
contexts: make(map[string][]context.Context),
}
c.mu.RUnlock()

return conn, nil
conn, err := wrapper.Dial(n, addr)
return conn, trace.Wrap(err)
}

// SendRequest sends a global request, and returns the
Expand Down Expand Up @@ -204,10 +210,10 @@ func (c *Client) OpenChannel(ctx context.Context, name string, data []byte) (*Ch

// NewSession creates a new SSH session that is passed tracing context
// so that spans may be correlated properly over the ssh connection.
func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) {
func (c *Client) NewSession(ctx context.Context) (*Session, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

_, span := tracer.Start(
ctx, span := tracer.Start(
ctx,
"ssh.NewSession",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
Expand Down Expand Up @@ -249,8 +255,172 @@ func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) {
c.capability = capability
}

// create the wrapper while the lock is still held
wrapper := &clientWrapper{
capability: c.capability,
Conn: c.Client.Conn,
opts: c.opts,
ctx: ctx,
contexts: make(map[string][]context.Context),
}

c.mu.Unlock()

session, err := c.Client.NewSession()
// get a session from the wrapper
session, err := wrapper.NewSession()
return session, trace.Wrap(err)
}

// clientWrapper wraps the ssh.Conn for individual ssh.Client
// operations to intercept internal calls by the ssh.Client to
// OpenChannel. This allows for internal operations within the
// ssh.Client to have their payload wrapped in an Envelope to
// forward tracing context when tracing is enabled.
type clientWrapper struct {
// Conn is the ssh.Conn that requests will be forwarded to
ssh.Conn
// capability the tracingCapability of the ssh server
capability tracingCapability
// ctx the context which should be used to create spans from
ctx context.Context
// opts the tracing options to use for creating spans with
opts []tracing.Option

// lock protects the context queue
lock sync.Mutex
// contexts a LIFO queue of context.Context per channel name.
contexts map[string][]context.Context
}

// NewSession opens a new Session for this client.
func (c *clientWrapper) NewSession() (*Session, error) {
// create a client that will defer to us when
// opening the "session" channel so that we
// can add an Envelope to the request
client := &ssh.Client{
Conn: c,
}

session, err := client.NewSession()
if err != nil {
return nil, trace.Wrap(err)
}

// wrap the session so all session requests on the channel
// can be traced
return &Session{
Session: session,
wrapper: c,
}, nil
}

// Dial initiates a connection to the addr from the remote host.
func (c *clientWrapper) Dial(n, addr string) (net.Conn, error) {
// create a client that will defer to us when
// opening the "direct-tcpip" channel so that we
// can add an Envelope to the request
client := &ssh.Client{
Conn: c,
}

return client.Dial(n, addr)
}

// addContext adds the provided context.Context to the end of
// the list for the provided channel name
func (c *clientWrapper) addContext(ctx context.Context, name string) {
c.lock.Lock()
defer c.lock.Unlock()

c.contexts[name] = append(c.contexts[name], ctx)
}

// nextContext returns the first context.Context for the provided
// channel name
func (c *clientWrapper) nextContext(name string) context.Context {
c.lock.Lock()
defer c.lock.Unlock()

contexts, ok := c.contexts[name]
switch {
case !ok, len(contexts) <= 0:
return context.Background()
case len(contexts) == 1:
delete(c.contexts, name)
return contexts[0]
default:
c.contexts[name] = contexts[1:]
return contexts[0]
}
}

// OpenChannel tries to open a channel. If tracing is enabled,
// the provided payload is wrapped in an Envelope to forward
// any tracing context.
func (c *clientWrapper) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
config := tracing.NewConfig(c.opts)
tracer := config.TracerProvider.Tracer(instrumentationName)
ctx, span := tracer.Start(
c.ctx,
fmt.Sprintf("ssh.OpenChannel/%s", name),
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
append(
peerAttr(c.Conn.RemoteAddr()),
semconv.RPCServiceKey.String("ssh.Client"),
semconv.RPCMethodKey.String("OpenChannel"),
semconv.RPCSystemKey.String("ssh"),
)...,
),
)
defer span.End()

ch, reqs, err := c.Conn.OpenChannel(name, wrapPayload(ctx, c.capability, config.TextMapPropagator, data))
if err != nil {
span.SetStatus(codes.Error, err.Error())
span.RecordError(err)
}

return channelWrapper{
Channel: ch,
manager: c,
}, reqs, err
}

// channelWrapper wraps an ssh.Channel to allow for requests to
// contain tracing context.
type channelWrapper struct {
ssh.Channel
manager *clientWrapper
}

// SendRequest sends a channel request. If tracing is enabled,
// the provided payload is wrapped in an Envelope to forward
// any tracing context.
//
// It is the callers' responsibility to ensure that addContext is
// called with the appropriate context.Context prior to any
// requests being sent along the channel.
func (c channelWrapper) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
config := tracing.NewConfig(c.manager.opts)
ctx, span := config.TracerProvider.Tracer(instrumentationName).Start(
c.manager.nextContext(name),
fmt.Sprintf("ssh.ChannelRequest/%s", name),
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
attribute.Bool("want_reply", wantReply),
semconv.RPCServiceKey.String("ssh.Channel"),
semconv.RPCMethodKey.String("SendRequest"),
semconv.RPCSystemKey.String("ssh"),
),
)
defer span.End()

ok, err := c.Channel.SendRequest(name, wantReply, wrapPayload(ctx, c.manager.capability, config.TextMapPropagator, payload))
if err != nil {
span.SetStatus(codes.Error, err.Error())
span.RecordError(err)
}

return ok, err
}
9 changes: 4 additions & 5 deletions api/observability/tracing/ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
)

func TestIsTracingSupported(t *testing.T) {

rejected := &ssh.OpenChannelError{
Reason: ssh.Prohibited,
Message: "rejected!",
Expand Down Expand Up @@ -186,11 +185,11 @@ func TestNewSession(t *testing.T) {

cases := []struct {
name string
assertionFunc func(t *testing.T, clt *Client, session *ssh.Session, err error)
assertionFunc func(t *testing.T, clt *Client, session *Session, err error)
}{
{
name: "session prohibited",
assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) {
assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) {
// creating a new session should return any errors captured when creating the client
// and not actually probe the server
require.Error(t, err)
Expand All @@ -202,7 +201,7 @@ func TestNewSession(t *testing.T) {
},
{
name: "other failure to open tracing channel",
assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) {
assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) {
// this time through we should probe the server without getting a prohibited error,
// but things still failed, so we shouldn't know the capability
require.NoError(t, err)
Expand All @@ -214,7 +213,7 @@ func TestNewSession(t *testing.T) {
},
{
name: "active session",
assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) {
assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) {
// all is good now, we should have an active session
require.NoError(t, err)
require.NotNil(t, sess)
Expand Down
Loading

0 comments on commit 0cb248d

Please sign in to comment.