Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v9] Trace ssh sessions #15615

Merged
merged 2 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 179 additions & 9 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ func isTracingSupported(clt *ssh.Client) (tracingCapability, error) {
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

_, span := tracer.Start(
ctx, span := tracer.Start(
ctx,
"ssh.DialContext",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
Expand All @@ -121,12 +120,19 @@ func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, err
)
defer span.End()

conn, err := c.Client.Dial(n, addr)
if err != nil {
return nil, trace.Wrap(err)
c.mu.RLock()
// create the wrapper while the lock is held
wrapper := &clientWrapper{
capability: c.capability,
Conn: c.Client.Conn,
opts: c.opts,
ctx: ctx,
contexts: make(map[string][]context.Context),
}
c.mu.RUnlock()

return conn, nil
conn, err := wrapper.Dial(n, addr)
return conn, trace.Wrap(err)
}

// SendRequest sends a global request, and returns the
Expand Down Expand Up @@ -204,10 +210,10 @@ func (c *Client) OpenChannel(ctx context.Context, name string, data []byte) (*Ch

// NewSession creates a new SSH session that is passed tracing context
// so that spans may be correlated properly over the ssh connection.
func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) {
func (c *Client) NewSession(ctx context.Context) (*Session, error) {
tracer := tracing.NewConfig(c.opts).TracerProvider.Tracer(instrumentationName)

_, span := tracer.Start(
ctx, span := tracer.Start(
ctx,
"ssh.NewSession",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
Expand Down Expand Up @@ -249,8 +255,172 @@ func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) {
c.capability = capability
}

// create the wrapper while the lock is still held
wrapper := &clientWrapper{
capability: c.capability,
Conn: c.Client.Conn,
opts: c.opts,
ctx: ctx,
contexts: make(map[string][]context.Context),
}

c.mu.Unlock()

session, err := c.Client.NewSession()
// get a session from the wrapper
session, err := wrapper.NewSession()
return session, trace.Wrap(err)
}

// clientWrapper wraps the ssh.Conn for individual ssh.Client
// operations to intercept internal calls by the ssh.Client to
// OpenChannel. This allows for internal operations within the
// ssh.Client to have their payload wrapped in an Envelope to
// forward tracing context when tracing is enabled.
//
// Every method of the embedded ssh.Conn other than OpenChannel is
// promoted unchanged from the wrapped connection.
type clientWrapper struct {
	// Conn is the ssh.Conn that requests will be forwarded to
	ssh.Conn
	// capability the tracingCapability of the ssh server; it is handed
	// to wrapPayload whenever a channel is opened or a channel request
	// is sent
	capability tracingCapability
	// ctx the context which should be used to create spans from; it is
	// the parent of the span started by OpenChannel
	ctx context.Context
	// opts the tracing options to use for creating spans with
	opts []tracing.Option

	// lock protects the context queue
	lock sync.Mutex
	// contexts a FIFO queue of context.Context per channel request name:
	// addContext appends to the tail and nextContext removes from the head.
	contexts map[string][]context.Context
}

// NewSession opens a new Session for this client. The returned Session
// keeps a reference to this wrapper so that requests made on the session
// channel can be traced.
func (c *clientWrapper) NewSession() (*Session, error) {
	// Build a throwaway ssh.Client whose Conn is this wrapper, so that
	// the "session" channel is opened through our OpenChannel and the
	// request can carry an Envelope.
	sshClient := &ssh.Client{Conn: c}

	sess, err := sshClient.NewSession()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Pair the raw session with the wrapper that produced it.
	traced := &Session{
		Session: sess,
		wrapper: c,
	}
	return traced, nil
}

// Dial initiates a connection to the addr from the remote host.
// n and addr are passed through untouched to ssh.Client.Dial.
func (c *clientWrapper) Dial(n, addr string) (net.Conn, error) {
	// Build a throwaway ssh.Client whose Conn is this wrapper, so that
	// the "direct-tcpip" channel is opened through our OpenChannel and
	// the request can carry an Envelope.
	dialer := &ssh.Client{Conn: c}

	conn, err := dialer.Dial(n, addr)
	return conn, err
}

// addContext enqueues ctx at the tail of the list kept for the
// provided channel name. The list is created on first use.
func (c *clientWrapper) addContext(ctx context.Context, name string) {
	c.lock.Lock()
	defer c.lock.Unlock()

	// a missing key yields a nil slice, which append handles fine
	pending := c.contexts[name]
	pending = append(pending, ctx)
	c.contexts[name] = pending
}

// nextContext removes and returns the oldest context.Context queued
// for the provided channel name. When nothing is queued for name,
// context.Background() is returned instead. The map entry is deleted
// once its last context has been handed out.
func (c *clientWrapper) nextContext(name string) context.Context {
	c.lock.Lock()
	defer c.lock.Unlock()

	// a missing key yields a nil slice, so one length check covers
	// both "never added" and "explicitly empty"
	queue := c.contexts[name]
	if len(queue) == 0 {
		return context.Background()
	}

	head, rest := queue[0], queue[1:]
	if len(rest) == 0 {
		delete(c.contexts, name)
	} else {
		c.contexts[name] = rest
	}

	return head
}

// OpenChannel tries to open a channel. If tracing is enabled,
// the provided payload is wrapped in an Envelope to forward
// any tracing context.
//
// A client span named "ssh.OpenChannel/<name>" is started from the
// wrapper's context and ended when the call returns; any error from the
// underlying connection is recorded on that span.
//
// On success the returned ssh.Channel is a channelWrapper, so requests
// sent on it are traced as well. On failure the returned ssh.Channel is
// nil and the error from the underlying connection is returned as is.
func (c *clientWrapper) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
	config := tracing.NewConfig(c.opts)
	tracer := config.TracerProvider.Tracer(instrumentationName)
	ctx, span := tracer.Start(
		c.ctx,
		fmt.Sprintf("ssh.OpenChannel/%s", name),
		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
		oteltrace.WithAttributes(
			append(
				peerAttr(c.Conn.RemoteAddr()),
				semconv.RPCServiceKey.String("ssh.Client"),
				semconv.RPCMethodKey.String("OpenChannel"),
				semconv.RPCSystemKey.String("ssh"),
			)...,
		),
	)
	defer span.End()

	ch, reqs, err := c.Conn.OpenChannel(name, wrapPayload(ctx, c.capability, config.TextMapPropagator, data))
	if err != nil {
		span.SetStatus(codes.Error, err.Error())
		span.RecordError(err)

		// Return a nil ssh.Channel rather than a channelWrapper around
		// a nil Channel: the wrapper would be a non-nil interface value
		// that panics on first use.
		return nil, reqs, err
	}

	return channelWrapper{
		Channel: ch,
		manager: c,
	}, reqs, nil
}

// channelWrapper wraps an ssh.Channel to allow for requests to
// contain tracing context. Only SendRequest is overridden; every
// other ssh.Channel method is promoted from the embedded Channel.
type channelWrapper struct {
	// Channel is the underlying ssh.Channel that requests are sent on
	ssh.Channel
	// manager is the clientWrapper that opened this channel; it supplies
	// the tracing options, the server's tracing capability, and the
	// queued contexts used when sending requests
	manager *clientWrapper
}

// SendRequest sends a channel request. If tracing is enabled,
// the provided payload is wrapped in an Envelope to forward
// any tracing context.
//
// The span for the request is parented on the next context queued
// for the request name, so it is the callers' responsibility to
// ensure that addContext is called with the appropriate
// context.Context prior to any requests being sent along the channel.
// Without a queued context the span is started from
// context.Background().
func (c channelWrapper) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
	cfg := tracing.NewConfig(c.manager.opts)
	tracer := cfg.TracerProvider.Tracer(instrumentationName)

	// consume the context that was queued for this request name
	parent := c.manager.nextContext(name)
	spanName := fmt.Sprintf("ssh.ChannelRequest/%s", name)

	ctx, span := tracer.Start(
		parent,
		spanName,
		oteltrace.WithSpanKind(oteltrace.SpanKindClient),
		oteltrace.WithAttributes(
			attribute.Bool("want_reply", wantReply),
			semconv.RPCServiceKey.String("ssh.Channel"),
			semconv.RPCMethodKey.String("SendRequest"),
			semconv.RPCSystemKey.String("ssh"),
		),
	)
	defer span.End()

	wrapped := wrapPayload(ctx, c.manager.capability, cfg.TextMapPropagator, payload)
	accepted, err := c.Channel.SendRequest(name, wantReply, wrapped)
	if err != nil {
		// surface the failure on the span; the error itself is
		// returned to the caller untouched
		span.SetStatus(codes.Error, err.Error())
		span.RecordError(err)
	}

	return accepted, err
}
9 changes: 4 additions & 5 deletions api/observability/tracing/ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
)

func TestIsTracingSupported(t *testing.T) {

rejected := &ssh.OpenChannelError{
Reason: ssh.Prohibited,
Message: "rejected!",
Expand Down Expand Up @@ -186,11 +185,11 @@ func TestNewSession(t *testing.T) {

cases := []struct {
name string
assertionFunc func(t *testing.T, clt *Client, session *ssh.Session, err error)
assertionFunc func(t *testing.T, clt *Client, session *Session, err error)
}{
{
name: "session prohibited",
assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) {
assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) {
// creating a new session should return any errors captured when creating the client
// and not actually probe the server
require.Error(t, err)
Expand All @@ -202,7 +201,7 @@ func TestNewSession(t *testing.T) {
},
{
name: "other failure to open tracing channel",
assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) {
assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) {
// this time through we should probe the server without getting a prohibited error,
// but things still failed, so we shouldn't know the capability
require.NoError(t, err)
Expand All @@ -214,7 +213,7 @@ func TestNewSession(t *testing.T) {
},
{
name: "active session",
assertionFunc: func(t *testing.T, clt *Client, sess *ssh.Session, err error) {
assertionFunc: func(t *testing.T, clt *Client, sess *Session, err error) {
// all is good now, we should have an active session
require.NoError(t, err)
require.NotNil(t, sess)
Expand Down
Loading