Skip to content

Commit

Permalink
allow tracing support to be probed multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
rosstimothy committed Jul 21, 2022
1 parent 1cbf7da commit 91d1228
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 40 deletions.
2 changes: 1 addition & 1 deletion api/observability/tracing/ssh/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
// Channel is a wrapper around ssh.Channel that adds tracing support.
type Channel struct {
ssh.Channel
tracingSupported bool
tracingSupported tracingCapability
opts []tracing.Option
}

Expand Down
92 changes: 69 additions & 23 deletions api/observability/tracing/ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net"
"sync"

"github.com/gravitational/trace"
"go.opentelemetry.io/otel/attribute"
Expand All @@ -33,11 +34,24 @@ import (
// Client is a wrapper around ssh.Client that adds tracing support.
type Client struct {
*ssh.Client
opts []tracing.Option
tracingSupported bool
rejectedError error
opts []tracing.Option

// mu protects capability and rejectedError which may change based
// on the outcome probing the server for tracing capabilities that
// may occur trying to establish a session
mu sync.RWMutex
capability tracingCapability
rejectedError error
}

type tracingCapability int

const (
tracingUnknown tracingCapability = iota
tracingUnsupported
tracingSupported
)

// NewClient creates a new Client.
//
// The server being connected to is probed to determine if it supports
Expand All @@ -46,39 +60,43 @@ type Client struct {
// server will be wrapped in an Envelope with tracing context. All Session
// and Channel created from the returned Client will honor the clients view
// of whether they should provide tracing context.
//
// Note: a channel is used instead of a global request in order prevent blocking
// forever in the event that the connection is rejected. In that case, the server
// doesn't service any global requests and writes the error to the first opened
// channel.
func NewClient(c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, opts ...tracing.Option) *Client {
clt := &Client{
Client: ssh.NewClient(c, chans, reqs),
opts: opts,
}

// Check if the server supports tracing
ch, _, err := clt.Client.OpenChannel(TracingChannel, nil)
clt.capability, clt.rejectedError = isTracingSupported(clt.Client)

return clt
}

// isTracingSupported determines whether the ssh server supports
// tracing payloads by trying to open a TracingChannel.
//
// Note: a channel is used instead of a global request in order prevent blocking
// forever in the event that the connection is rejected. In that case, the server
// doesn't service any global requests and writes the error to the first opened
// channel.
func isTracingSupported(clt *ssh.Client) (tracingCapability, error) {
ch, _, err := clt.OpenChannel(TracingChannel, nil)
if err != nil {
var openError *ssh.OpenChannelError
// prohibited errors due to locks and session control are expected by callers of NewSession
if errors.As(err, &openError) {
switch openError.Reason {
case ssh.Prohibited:
// prohibited errors due to locks and session control are expected by callers of NewSession
clt.rejectedError = err
default:
return tracingUnknown, err
case ssh.UnknownChannelType:
return tracingUnsupported, nil
}

return clt
}

return clt
return tracingUnknown, nil
}

_ = ch.Close()
clt.tracingSupported = true

return clt
return tracingSupported, nil
}

// DialContext initiates a connection to the addr from the remote host.
Expand Down Expand Up @@ -134,7 +152,11 @@ func (c *Client) SendRequest(ctx context.Context, name string, wantReply bool, p
)
defer span.End()

ok, resp, err := c.Client.SendRequest(name, wantReply, wrapPayload(ctx, c.tracingSupported, config.TextMapPropagator, payload))
c.mu.RLock()
capability := c.capability
c.mu.RUnlock()

ok, resp, err := c.Client.SendRequest(name, wantReply, wrapPayload(ctx, capability, config.TextMapPropagator, payload))
if err != nil {
span.SetStatus(codes.Error, err.Error())
span.RecordError(err)
Expand Down Expand Up @@ -164,7 +186,11 @@ func (c *Client) OpenChannel(ctx context.Context, name string, data []byte) (*Ch
)
defer span.End()

ch, reqs, err := c.Client.OpenChannel(name, wrapPayload(ctx, c.tracingSupported, config.TextMapPropagator, data))
c.mu.RLock()
capability := c.capability
c.mu.RUnlock()

ch, reqs, err := c.Client.OpenChannel(name, wrapPayload(ctx, capability, config.TextMapPropagator, data))
if err != nil {
span.SetStatus(codes.Error, err.Error())
span.RecordError(err)
Expand Down Expand Up @@ -196,15 +222,35 @@ func (c *Client) NewSession(ctx context.Context) (*ssh.Session, error) {
)
defer span.End()

c.mu.Lock()

// If the TracingChannel was rejected when the client was created,
// the connection was prohibited due to a lock or session control.
// Callers to NewSession are expecting to receive the reason the session
// was rejected, so we need to propagate the rejectedError here.
if c.rejectedError != nil {
return nil, trace.Wrap(c.rejectedError)
err := c.rejectedError
c.rejectedError = nil
c.capability = tracingUnknown
c.mu.Unlock()
return nil, trace.Wrap(err)
}

session, err := c.Client.NewSession()
// If the tracing capabilities of the server are unknown due to
// prohibited errors from previous attempts to check, we need to
// do another check to see if our connection will be permitted
// this time.
if c.capability == tracingUnknown {
capability, err := isTracingSupported(c.Client)
if err != nil {
c.mu.Unlock()
return nil, trace.Wrap(err)
}
c.capability = capability
}

c.mu.Unlock()

session, err := c.Client.NewSession()
return session, trace.Wrap(err)
}
Loading

0 comments on commit 91d1228

Please sign in to comment.