diff --git a/pkg/connman/connman.go b/pkg/connman/connman.go index fbc6c8a799d..22c17ef3b8d 100644 --- a/pkg/connman/connman.go +++ b/pkg/connman/connman.go @@ -7,7 +7,7 @@ import ( "github.com/shellhub-io/shellhub/pkg/revdial" "github.com/shellhub-io/shellhub/pkg/wsconnadapter" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) var ErrNoConnection = errors.New("no connection") @@ -27,12 +27,12 @@ func New() *ConnectionManager { } func (m *ConnectionManager) Set(key string, conn *wsconnadapter.Adapter, connPath string) { - dialer := revdial.NewDialer(conn, connPath) + dialer := revdial.NewDialer(conn.Logger, conn, connPath) m.dialers.Store(key, dialer) if size := m.dialers.Size(key); size > 1 { - logrus.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "key": key, "size": size, }).Warning("Multiple connections stored for the same identifier.") @@ -67,7 +67,7 @@ func (m *ConnectionManager) Dial(ctx context.Context, key string) (net.Conn, err } if size := m.dialers.Size(key); size > 1 { - logrus.WithFields(logrus.Fields{ + log.WithFields(log.Fields{ "key": key, "size": size, }).Warning("Multiple connections found for the same identifier during reverse tunnel dialing.") diff --git a/pkg/httptunnel/httptunnel.go b/pkg/httptunnel/httptunnel.go index a597bfaafa5..2230d24de5a 100644 --- a/pkg/httptunnel/httptunnel.go +++ b/pkg/httptunnel/httptunnel.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/http" + "strings" "github.com/gorilla/websocket" "github.com/labstack/echo/v4" @@ -75,14 +76,26 @@ func (t *Tunnel) Router() http.Handler { return c.String(http.StatusInternalServerError, err.Error()) } - id, err := t.ConnectionHandler(c.Request()) + key, err := t.ConnectionHandler(c.Request()) if err != nil { conn.Close() return c.String(http.StatusBadRequest, err.Error()) } - t.connman.Set(id, wsconnadapter.New(conn), t.DialerPath) + requestID := c.Request().Header.Get("X-Request-ID") + parts := strings.Split(key, ":") + tenant := parts[0] + device := parts[1] + + t.connman.Set( + key, + wsconnadapter. + New(conn). + WithID(requestID). + WithDevice(tenant, device), + t.DialerPath, + ) return nil }) diff --git a/pkg/revdial/revdial.go b/pkg/revdial/revdial.go index 39fe0911194..715afffe9ec 100644 --- a/pkg/revdial/revdial.go +++ b/pkg/revdial/revdial.go @@ -24,7 +24,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "net" "net/http" "strings" @@ -34,11 +33,13 @@ import ( "github.com/gorilla/websocket" "github.com/shellhub-io/shellhub/pkg/clock" "github.com/shellhub-io/shellhub/pkg/wsconnadapter" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) -var ErrDialerClosed = errors.New("revdial.Dialer closed") -var ErrDialerTimedout = errors.New("revdial.Dialer timedout") +var ( + ErrDialerClosed = errors.New("revdial.Dialer closed") + ErrDialerTimedout = errors.New("revdial.Dialer timedout") +) // dialerUniqParam is the parameter name of the GET URL form value // containing the Dialer's random unique ID. @@ -59,11 +60,10 @@ type Dialer struct { connReady chan bool donec chan struct{} closeOnce sync.Once + logger *log.Entry } -var ( - dialers = sync.Map{} -) +var dialers = sync.Map{} // NewDialer returns the side of the connection which will initiate // new connections. This will typically be the side which did the HTTP @@ -71,7 +71,7 @@ var ( // connection. The connPath is the HTTP path and optional query (but // without scheme or host) on the dialer where the ConnHandler is // mounted. -func NewDialer(c net.Conn, connPath string) *Dialer { +func NewDialer(logger *log.Entry, c net.Conn, connPath string) *Dialer { d := &Dialer{ path: connPath, uniqID: newUniqID(), @@ -80,6 +80,7 @@ func NewDialer(c net.Conn, connPath string) *Dialer { connReady: make(chan bool), incomingConn: make(chan net.Conn), pickupFailed: make(chan error), + logger: logger, } join := "?" @@ -90,6 +91,8 @@ func NewDialer(c net.Conn, connPath string) *Dialer { d.register() go d.serve() // nolint:errcheck + d.logger.Debug("new dialer connection") + return d } @@ -121,6 +124,8 @@ func (d *Dialer) Close() error { } func (d *Dialer) close() { + d.logger.Debug("dialer connection closed") + d.unregister() d.conn.Close() d.donec <- struct{}{} @@ -132,21 +137,34 @@ func (d *Dialer) Dial(ctx context.Context) (net.Conn, error) { // First, tell serve that we want a connection: select { case d.connReady <- true: + d.logger.Debug("message true to conn ready channel") case <-d.donec: + d.logger.Debug("dial done") + return nil, ErrDialerClosed case <-ctx.Done(): + d.logger.Debug("dial done due context cancellation") + return nil, ctx.Err() } // Then pick it up: select { case c := <-d.incomingConn: + d.logger.Debug("new incoming connection") + return c, nil case err := <-d.pickupFailed: + d.logger.Debug("failed to pick-up connection") + return nil, err case <-d.donec: + d.logger.Debug("dial done on pick-up") + return nil, ErrDialerClosed case <-ctx.Done(): + d.logger.Debug("dial done on pick-up due context cancellation") + return nil, ctx.Err() } } @@ -165,21 +183,24 @@ func (d *Dialer) serve() error { go func() { defer d.Close() + defer d.logger.Debug("dialer serve done") br := bufio.NewReader(d.conn) for { line, err := br.ReadSlice('\n') if err != nil { + d.logger.WithError(err).Trace("failed to read the agent's command") + unexpectedError := websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) if !errors.Is(err, net.ErrClosed) && unexpectedError { - logrus.WithError(err).Error("revdial.Dialer failed to read") + d.logger.WithError(err).Error("revdial.Dialer failed to read") } return } var msg controlMsg if err := json.Unmarshal(line, &msg); err != nil { - log.Printf("revdial.Dialer read invalid JSON: %q: %v", line, err) + d.logger.WithError(err).WithField("line", line).Printf("revdial.Dialer read invalid JSON") return } @@ -190,16 +211,21 @@ func (d *Dialer) serve() error { select { case d.pickupFailed <- err: case <-d.donec: + d.logger.WithError(err).Debug("failed to pick-up connection") + return } case "keep-alive": default: // Ignore unknown messages + log.WithField("message", msg.Command).Debug("unknown message received") } } }() for { if err := d.sendMessage(controlMsg{Command: "keep-alive"}); err != nil { + d.logger.WithError(err).Debug("failed to send keep-alive message to device") + return err } @@ -213,6 +239,8 @@ func (d *Dialer) serve() error { Command: "conn-ready", ConnPath: d.pickupPath, }); err != nil { + d.logger.WithError(err).Debug("failed to send conn-ready message to device") + return err } case <-d.donec: @@ -225,6 +253,8 @@ func (d *Dialer) serve() error { func (d *Dialer) sendMessage(m controlMsg) error { if err := d.conn.SetWriteDeadline(clock.Now().Add(10 * time.Second)); err != nil { + d.logger.WithError(err).Debug("failed to set the write dead line to device") + return err } @@ -232,6 +262,8 @@ func (d *Dialer) sendMessage(m controlMsg) error { j = append(j, '\n') if _, err := d.conn.Write(j); err != nil { + d.logger.WithError(err).Debug("failed to write on the connection") + return err } diff --git a/pkg/wsconnadapter/wsconnadapter.go b/pkg/wsconnadapter/wsconnadapter.go index 19ab06b57ca..662d14af374 100644 --- a/pkg/wsconnadapter/wsconnadapter.go +++ b/pkg/wsconnadapter/wsconnadapter.go @@ -4,11 +4,12 @@ import ( "errors" "io" "net" + "os" "sync" "time" "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) // an adapter for representing WebSocket connection as a net.Conn @@ -28,11 +29,35 @@ type Adapter struct { reader io.Reader stopPingCh chan struct{} pongCh chan bool + Logger *log.Entry +} + +func (a *Adapter) WithID(requestID string) *Adapter { + a.Logger = a.Logger.WithFields(log.Fields{ + "request-id": requestID, + }) + + return a +} + +func (a *Adapter) WithDevice(tenant string, device string) *Adapter { + a.Logger = a.Logger.WithFields(log.Fields{ + "tenant": tenant, + "device": device, + }) + + return a } func New(conn *websocket.Conn) *Adapter { adapter := &Adapter{ conn: conn, + Logger: log.NewEntry(&log.Logger{ + Out: os.Stderr, + Formatter: log.StandardLogger().Formatter, + Hooks: log.StandardLogger().Hooks, + Level: log.StandardLogger().Level, + }), } return adapter @@ -40,6 +65,8 @@ func New(conn *websocket.Conn) *Adapter { func (a *Adapter) Ping() chan bool { if a.pongCh != nil { + a.Logger.Debug("pong channel is not null") + return a.pongCh } @@ -47,15 +74,19 @@ func (a *Adapter) Ping() chan bool { a.pongCh = make(chan bool) timeout := time.AfterFunc(pongTimeout, func() { + a.Logger.Debug("close connection due pong timeout") + _ = a.Close() }) a.conn.SetPongHandler(func(data string) error { timeout.Reset(pongTimeout) + a.Logger.Trace("pong timeout") // non-blocking channel write select { case a.pongCh <- true: + a.Logger.Trace("write true to pong channel") default: } @@ -71,9 +102,11 @@ func (a *Adapter) Ping() chan bool { select { case <-ticker.C: if err := a.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil { - logrus.WithError(err).Error("Failed to write ping message") + a.Logger.WithError(err).Error("failed to write ping message") } case <-a.stopPingCh: + a.Logger.Debug("stop ping message received") + return } } @@ -112,6 +145,10 @@ func (a *Adapter) Read(b []byte) (int, error) { } } + a.Logger.WithError(err). + WithField("bytes", bytesRead). + Trace("bytes read from wsconnadapter") + return bytesRead, err } @@ -121,22 +158,31 @@ func (a *Adapter) Write(b []byte) (int, error) { nextWriter, err := a.conn.NextWriter(websocket.BinaryMessage) if err != nil { + a.Logger.WithError(err).Trace("failed to get the next writer") + return 0, err } bytesWritten, err := nextWriter.Write(b) nextWriter.Close() + a.Logger.WithError(err). + WithField("bytes", bytesWritten). + Trace("bytes written from wsconnadapter") + return bytesWritten, err } func (a *Adapter) Close() error { select { case <-a.stopPingCh: + a.Logger.Debug("stop ping message received") default: if a.stopPingCh != nil { a.stopPingCh <- struct{}{} close(a.stopPingCh) + + a.Logger.Debug("stop ping channel closed") } } @@ -153,6 +199,8 @@ func (a *Adapter) RemoteAddr() net.Addr { func (a *Adapter) SetDeadline(t time.Time) error { if err := a.SetReadDeadline(t); err != nil { + a.Logger.WithError(err).Trace("failed to set the deadline") + return err }