Skip to content

Commit

Permalink
feat(ssh): add logs on ssh connection
Browse files Browse the repository at this point in the history
  • Loading branch information
henrybarreto authored and gustavosbarreto committed Sep 17, 2024
1 parent 07ac3b3 commit f41c78d
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 18 deletions.
8 changes: 4 additions & 4 deletions pkg/connman/connman.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

"github.com/shellhub-io/shellhub/pkg/revdial"
"github.com/shellhub-io/shellhub/pkg/wsconnadapter"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

var ErrNoConnection = errors.New("no connection")
Expand All @@ -27,12 +27,12 @@ func New() *ConnectionManager {
}

func (m *ConnectionManager) Set(key string, conn *wsconnadapter.Adapter, connPath string) {
dialer := revdial.NewDialer(conn, connPath)
dialer := revdial.NewDialer(conn.Logger, conn, connPath)

m.dialers.Store(key, dialer)

if size := m.dialers.Size(key); size > 1 {
logrus.WithFields(logrus.Fields{
log.WithFields(log.Fields{
"key": key,
"size": size,
}).Warning("Multiple connections stored for the same identifier.")
Expand Down Expand Up @@ -67,7 +67,7 @@ func (m *ConnectionManager) Dial(ctx context.Context, key string) (net.Conn, err
}

if size := m.dialers.Size(key); size > 1 {
logrus.WithFields(logrus.Fields{
log.WithFields(log.Fields{
"key": key,
"size": size,
}).Warning("Multiple connections found for the same identifier during reverse tunnel dialing.")
Expand Down
17 changes: 15 additions & 2 deletions pkg/httptunnel/httptunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net"
"net/http"
"strings"

"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -75,14 +76,26 @@ func (t *Tunnel) Router() http.Handler {
return c.String(http.StatusInternalServerError, err.Error())
}

id, err := t.ConnectionHandler(c.Request())
key, err := t.ConnectionHandler(c.Request())
if err != nil {
conn.Close()

return c.String(http.StatusBadRequest, err.Error())
}

t.connman.Set(id, wsconnadapter.New(conn), t.DialerPath)
requestID := c.Request().Header.Get("X-Request-ID")
parts := strings.Split(key, ":")
tenant := parts[0]
device := parts[1]

t.connman.Set(
key,
wsconnadapter.
New(conn).
WithID(requestID).
WithDevice(tenant, device),
t.DialerPath,
)

return nil
})
Expand Down
52 changes: 42 additions & 10 deletions pkg/revdial/revdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"strings"
Expand All @@ -34,11 +33,13 @@ import (
"github.com/gorilla/websocket"
"github.com/shellhub-io/shellhub/pkg/clock"
"github.com/shellhub-io/shellhub/pkg/wsconnadapter"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

var ErrDialerClosed = errors.New("revdial.Dialer closed")
var ErrDialerTimedout = errors.New("revdial.Dialer timedout")
var (
ErrDialerClosed = errors.New("revdial.Dialer closed")
ErrDialerTimedout = errors.New("revdial.Dialer timedout")
)

// dialerUniqParam is the parameter name of the GET URL form value
// containing the Dialer's random unique ID.
Expand All @@ -59,19 +60,18 @@ type Dialer struct {
connReady chan bool
donec chan struct{}
closeOnce sync.Once
logger *log.Entry
}

var (
dialers = sync.Map{}
)
var dialers = sync.Map{}

// NewDialer returns the side of the connection which will initiate
// new connections. This will typically be the side which did the HTTP
// Hijack. The connection is (typically) the hijacked HTTP client
// connection. The connPath is the HTTP path and optional query (but
// without scheme or host) on the dialer where the ConnHandler is
// mounted.
func NewDialer(c net.Conn, connPath string) *Dialer {
func NewDialer(logger *log.Entry, c net.Conn, connPath string) *Dialer {
d := &Dialer{
path: connPath,
uniqID: newUniqID(),
Expand All @@ -80,6 +80,7 @@ func NewDialer(c net.Conn, connPath string) *Dialer {
connReady: make(chan bool),
incomingConn: make(chan net.Conn),
pickupFailed: make(chan error),
logger: logger,
}

join := "?"
Expand All @@ -90,6 +91,8 @@ func NewDialer(c net.Conn, connPath string) *Dialer {
d.register()
go d.serve() // nolint:errcheck

d.logger.Debug("new dialer connection")

return d
}

Expand Down Expand Up @@ -121,6 +124,8 @@ func (d *Dialer) Close() error {
}

func (d *Dialer) close() {
d.logger.Debug("dialer connection closed")

d.unregister()
d.conn.Close()
d.donec <- struct{}{}
Expand All @@ -132,21 +137,34 @@ func (d *Dialer) Dial(ctx context.Context) (net.Conn, error) {
// First, tell serve that we want a connection:
select {
case d.connReady <- true:
d.logger.Debug("message true to conn ready channel")
case <-d.donec:
d.logger.Debug("dial done")

return nil, ErrDialerClosed
case <-ctx.Done():
d.logger.Debug("dial done due context cancellation")

return nil, ctx.Err()
}

// Then pick it up:
select {
case c := <-d.incomingConn:
d.logger.Debug("new incoming connection")

return c, nil
case err := <-d.pickupFailed:
d.logger.Debug("failed to pick-up connection")

return nil, err
case <-d.donec:
d.logger.Debug("dial done on pick-up")

return nil, ErrDialerClosed
case <-ctx.Done():
d.logger.Debug("dial done on pick-up due context cancellation")

return nil, ctx.Err()
}
}
Expand All @@ -165,21 +183,24 @@ func (d *Dialer) serve() error {

go func() {
defer d.Close()
defer d.logger.Debug("dialer serve done")

br := bufio.NewReader(d.conn)
for {
line, err := br.ReadSlice('\n')
if err != nil {
d.logger.WithError(err).Trace("failed to read the agent's command")

unexpectedError := websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure)
if !errors.Is(err, net.ErrClosed) && unexpectedError {
logrus.WithError(err).Error("revdial.Dialer failed to read")
d.logger.WithError(err).Error("revdial.Dialer failed to read")
}

return
}
var msg controlMsg
if err := json.Unmarshal(line, &msg); err != nil {
log.Printf("revdial.Dialer read invalid JSON: %q: %v", line, err)
d.logger.WithError(err).WithField("line", line).Printf("revdial.Dialer read invalid JSON")

return
}
Expand All @@ -190,16 +211,21 @@ func (d *Dialer) serve() error {
select {
case d.pickupFailed <- err:
case <-d.donec:
d.logger.WithError(err).Debug("failed to pick-up connection")

return
}
case "keep-alive":
default:
// Ignore unknown messages
log.WithField("message", msg.Command).Debug("unknown message received")
}
}
}()
for {
if err := d.sendMessage(controlMsg{Command: "keep-alive"}); err != nil {
d.logger.WithError(err).Debug("failed to send keep-alive message to device")

return err
}

Expand All @@ -213,6 +239,8 @@ func (d *Dialer) serve() error {
Command: "conn-ready",
ConnPath: d.pickupPath,
}); err != nil {
d.logger.WithError(err).Debug("failed to send conn-ready message to device")

return err
}
case <-d.donec:
Expand All @@ -225,13 +253,17 @@ func (d *Dialer) serve() error {

func (d *Dialer) sendMessage(m controlMsg) error {
if err := d.conn.SetWriteDeadline(clock.Now().Add(10 * time.Second)); err != nil {
d.logger.WithError(err).Debug("failed to set the write dead line to device")

return err
}

j, _ := json.Marshal(m)
j = append(j, '\n')

if _, err := d.conn.Write(j); err != nil {
d.logger.WithError(err).Debug("failed to write on the connection")

return err
}

Expand Down
52 changes: 50 additions & 2 deletions pkg/wsconnadapter/wsconnadapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"errors"
"io"
"net"
"os"
"sync"
"time"

"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

// an adapter for representing WebSocket connection as a net.Conn
Expand All @@ -28,34 +29,64 @@ type Adapter struct {
reader io.Reader
stopPingCh chan struct{}
pongCh chan bool
Logger *log.Entry
}

func (a *Adapter) WithID(requestID string) *Adapter {
a.Logger = a.Logger.WithFields(log.Fields{
"request-id": requestID,
})

return a
}

func (a *Adapter) WithDevice(tenant string, device string) *Adapter {
a.Logger = a.Logger.WithFields(log.Fields{
"tenant": tenant,
"device": device,
})

return a
}

func New(conn *websocket.Conn) *Adapter {
adapter := &Adapter{
conn: conn,
Logger: log.NewEntry(&log.Logger{
Out: os.Stderr,
Formatter: log.StandardLogger().Formatter,
Hooks: log.StandardLogger().Hooks,
Level: log.StandardLogger().Level,
}),
}

return adapter
}

func (a *Adapter) Ping() chan bool {
if a.pongCh != nil {
a.Logger.Debug("pong channel is not null")

return a.pongCh
}

a.stopPingCh = make(chan struct{})
a.pongCh = make(chan bool)

timeout := time.AfterFunc(pongTimeout, func() {
a.Logger.Debug("close connection due pong timeout")

_ = a.Close()
})

a.conn.SetPongHandler(func(data string) error {
timeout.Reset(pongTimeout)
a.Logger.Trace("pong timeout")

// non-blocking channel write
select {
case a.pongCh <- true:
a.Logger.Trace("write true to pong channel")
default:
}

Expand All @@ -71,9 +102,11 @@ func (a *Adapter) Ping() chan bool {
select {
case <-ticker.C:
if err := a.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
logrus.WithError(err).Error("Failed to write ping message")
a.Logger.WithError(err).Error("failed to write ping message")
}
case <-a.stopPingCh:
a.Logger.Debug("stop ping message received")

return
}
}
Expand Down Expand Up @@ -112,6 +145,10 @@ func (a *Adapter) Read(b []byte) (int, error) {
}
}

a.Logger.WithError(err).
WithField("bytes", bytesRead).
Trace("bytes read from wsconnadapter")

return bytesRead, err
}

Expand All @@ -121,22 +158,31 @@ func (a *Adapter) Write(b []byte) (int, error) {

nextWriter, err := a.conn.NextWriter(websocket.BinaryMessage)
if err != nil {
a.Logger.WithError(err).Trace("failed to get the next writer")

return 0, err
}

bytesWritten, err := nextWriter.Write(b)
nextWriter.Close()

a.Logger.WithError(err).
WithField("bytes", bytesWritten).
Trace("bytes written from wsconnadapter")

return bytesWritten, err
}

func (a *Adapter) Close() error {
select {
case <-a.stopPingCh:
a.Logger.Debug("stop ping message received")
default:
if a.stopPingCh != nil {
a.stopPingCh <- struct{}{}
close(a.stopPingCh)

a.Logger.Debug("stop ping channel closed")
}
}

Expand All @@ -153,6 +199,8 @@ func (a *Adapter) RemoteAddr() net.Addr {

func (a *Adapter) SetDeadline(t time.Time) error {
if err := a.SetReadDeadline(t); err != nil {
a.Logger.WithError(err).Trace("failed to set the deadline")

return err
}

Expand Down

0 comments on commit f41c78d

Please sign in to comment.