Skip to content

Commit

Permalink
feat(ssh): add logs on ssh connection
Browse files Browse the repository at this point in the history
  • Loading branch information
henrybarreto committed Sep 16, 2024
1 parent 5d752db commit c416ad6
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 19 deletions.
1 change: 0 additions & 1 deletion gateway/nginx/conf.d/shellhub.conf
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ server {
{{ end -}}
proxy_set_header X-Device-UID $device_uid;
proxy_set_header X-Tenant-ID $tenant_id;
proxy_set_header X-Request-ID $request_id;
proxy_http_version 1.1;
proxy_cache_bypass $http_upgrade;
proxy_redirect off;
Expand Down
21 changes: 17 additions & 4 deletions pkg/connman/connman.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"context"
"errors"
"net"
"os"
"strings"

"github.com/shellhub-io/shellhub/pkg/revdial"
"github.com/shellhub-io/shellhub/pkg/wsconnadapter"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

var ErrNoConnection = errors.New("no connection")
Expand All @@ -27,12 +29,23 @@ func New() *ConnectionManager {
}

func (m *ConnectionManager) Set(key string, conn *wsconnadapter.Adapter, connPath string) {
dialer := revdial.NewDialer(conn, connPath)
parts := strings.Split(key, ":")
logger := (&log.Logger{
Out: os.Stderr,
Formatter: log.StandardLogger().Formatter,
Hooks: log.StandardLogger().Hooks,
Level: log.StandardLogger().Level,
}).WithFields(log.Fields{
"tenant": parts[0],
"device": parts[1],
})

dialer := revdial.NewDialer(logger, conn, connPath)

m.dialers.Store(key, dialer)

if size := m.dialers.Size(key); size > 1 {
logrus.WithFields(logrus.Fields{
log.WithFields(log.Fields{
"key": key,
"size": size,
}).Warning("Multiple connections stored for the same identifier.")
Expand Down Expand Up @@ -67,7 +80,7 @@ func (m *ConnectionManager) Dial(ctx context.Context, key string) (net.Conn, err
}

if size := m.dialers.Size(key); size > 1 {
logrus.WithFields(logrus.Fields{
log.WithFields(log.Fields{
"key": key,
"size": size,
}).Warning("Multiple connections found for the same identifier during reverse tunnel dialing.")
Expand Down
24 changes: 22 additions & 2 deletions pkg/httptunnel/httptunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import (
"io"
"net"
"net/http"
"strings"

"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/shellhub-io/shellhub/pkg/connman"
"github.com/shellhub-io/shellhub/pkg/revdial"
"github.com/shellhub-io/shellhub/pkg/wsconnadapter"
log "github.com/sirupsen/logrus"
)

var upgrader = websocket.Upgrader{
Expand Down Expand Up @@ -75,14 +77,32 @@ func (t *Tunnel) Router() http.Handler {
return c.String(http.StatusInternalServerError, err.Error())
}

id, err := t.ConnectionHandler(c.Request())
key, err := t.ConnectionHandler(c.Request())
if err != nil {
conn.Close()

return c.String(http.StatusBadRequest, err.Error())
}

t.connman.Set(id, wsconnadapter.New(conn), t.DialerPath)
parts := strings.Split(key, ":")
requestID := c.Request().Header.Get("X-Request-ID")
tenant := parts[0]
device := parts[1]

log.WithFields(log.Fields{
"request-id": requestID,
"tenant": tenant,
"device": device,
}).Debug("new ssh connection")

t.connman.Set(
key,
wsconnadapter.
New(conn).
WithID(requestID).
WithDevice(tenant, device),
t.DialerPath,
)

return nil
})
Expand Down
37 changes: 27 additions & 10 deletions pkg/revdial/revdial.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"strings"
Expand All @@ -34,11 +33,13 @@ import (
"github.com/gorilla/websocket"
"github.com/shellhub-io/shellhub/pkg/clock"
"github.com/shellhub-io/shellhub/pkg/wsconnadapter"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

var ErrDialerClosed = errors.New("revdial.Dialer closed")
var ErrDialerTimedout = errors.New("revdial.Dialer timedout")
var (
ErrDialerClosed = errors.New("revdial.Dialer closed")
ErrDialerTimedout = errors.New("revdial.Dialer timedout")
)

// dialerUniqParam is the parameter name of the GET URL form value
// containing the Dialer's random unique ID.
Expand All @@ -59,19 +60,18 @@ type Dialer struct {
connReady chan bool
donec chan struct{}
closeOnce sync.Once
logger *log.Entry
}

var (
dialers = sync.Map{}
)
var dialers = sync.Map{}

// NewDialer returns the side of the connection which will initiate
// new connections. This will typically be the side which did the HTTP
// Hijack. The connection is (typically) the hijacked HTTP client
// connection. The connPath is the HTTP path and optional query (but
// without scheme or host) on the dialer where the ConnHandler is
// mounted.
func NewDialer(c net.Conn, connPath string) *Dialer {
func NewDialer(logger *log.Entry, c net.Conn, connPath string) *Dialer {
d := &Dialer{
path: connPath,
uniqID: newUniqID(),
Expand All @@ -80,6 +80,7 @@ func NewDialer(c net.Conn, connPath string) *Dialer {
connReady: make(chan bool),
incomingConn: make(chan net.Conn),
pickupFailed: make(chan error),
logger: logger,
}

join := "?"
Expand Down Expand Up @@ -121,6 +122,8 @@ func (d *Dialer) Close() error {
}

func (d *Dialer) close() {
d.logger.Debug("dialer connection closed")

d.unregister()
d.conn.Close()
d.donec <- struct{}{}
Expand Down Expand Up @@ -165,21 +168,24 @@ func (d *Dialer) serve() error {

go func() {
defer d.Close()
defer d.logger.Debug("dialer serve done due reader error")

br := bufio.NewReader(d.conn)
for {
line, err := br.ReadSlice('\n')
if err != nil {
d.logger.WithError(err).Debug("failed to read the agent's command")

unexpectedError := websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure)
if !errors.Is(err, net.ErrClosed) && unexpectedError {
logrus.WithError(err).Error("revdial.Dialer failed to read")
d.logger.WithError(err).Error("revdial.Dialer failed to read")
}

return
}
var msg controlMsg
if err := json.Unmarshal(line, &msg); err != nil {
log.Printf("revdial.Dialer read invalid JSON: %q: %v", line, err)
d.logger.Printf("revdial.Dialer read invalid JSON: %q: %v", line, err)

return
}
Expand All @@ -190,16 +196,21 @@ func (d *Dialer) serve() error {
select {
case d.pickupFailed <- err:
case <-d.donec:
d.logger.WithError(err).Debug("failed to pick-up connection")

return
}
case "keep-alive":
default:
// Ignore unknown messages
log.WithField("message", msg.Command).Debug("unknown message received")
}
}
}()
for {
if err := d.sendMessage(controlMsg{Command: "keep-alive"}); err != nil {
d.logger.WithError(err).Debug("failed to send keep-alive message to device")

return err
}

Expand All @@ -213,6 +224,8 @@ func (d *Dialer) serve() error {
Command: "conn-ready",
ConnPath: d.pickupPath,
}); err != nil {
d.logger.WithError(err).Debug("failed to send conn-ready message to device")

return err
}
case <-d.donec:
Expand All @@ -225,13 +238,17 @@ func (d *Dialer) serve() error {

func (d *Dialer) sendMessage(m controlMsg) error {
if err := d.conn.SetWriteDeadline(clock.Now().Add(10 * time.Second)); err != nil {
d.logger.WithError(err).Debug("failed to set the write dead line to device")

return err
}

j, _ := json.Marshal(m)
j = append(j, '\n')

if _, err := d.conn.Write(j); err != nil {
d.logger.WithError(err).Debug("failed to write on the connection")

return err
}

Expand Down
49 changes: 47 additions & 2 deletions pkg/wsconnadapter/wsconnadapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"errors"
"io"
"net"
"os"
"sync"
"time"

"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)

// an adapter for representing WebSocket connection as a net.Conn
Expand All @@ -28,34 +29,64 @@ type Adapter struct {
reader io.Reader
stopPingCh chan struct{}
pongCh chan bool
logger *log.Entry
}

func (a *Adapter) WithID(requestID string) *Adapter {
a.logger = a.logger.WithFields(log.Fields{
"request-id": requestID,
})

return a
}

func (a *Adapter) WithDevice(tenant string, device string) *Adapter {
a.logger = a.logger.WithFields(log.Fields{
"tenant": tenant,
"device": device,
})

return a
}

func New(conn *websocket.Conn) *Adapter {
adapter := &Adapter{
conn: conn,
logger: log.NewEntry(&log.Logger{
Out: os.Stderr,
Formatter: log.StandardLogger().Formatter,
Hooks: log.StandardLogger().Hooks,
Level: log.StandardLogger().Level,
}),
}

return adapter
}

func (a *Adapter) Ping() chan bool {
if a.pongCh != nil {
a.logger.Debug("pong channel is not null")

return a.pongCh
}

a.stopPingCh = make(chan struct{})
a.pongCh = make(chan bool)

timeout := time.AfterFunc(pongTimeout, func() {
a.logger.Debug("close connection due pong timeout")

_ = a.Close()
})

a.conn.SetPongHandler(func(data string) error {
timeout.Reset(pongTimeout)
a.logger.Debug("pong timeout")

// non-blocking channel write
select {
case a.pongCh <- true:
a.logger.Debug("write true to pong channel")
default:
}

Expand All @@ -71,9 +102,11 @@ func (a *Adapter) Ping() chan bool {
select {
case <-ticker.C:
if err := a.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
logrus.WithError(err).Error("Failed to write ping message")
a.logger.WithError(err).Error("Failed to write ping message")
}
case <-a.stopPingCh:
a.logger.Debug("Stop ping message received")

return
}
}
Expand Down Expand Up @@ -112,6 +145,10 @@ func (a *Adapter) Read(b []byte) (int, error) {
}
}

log.WithError(err).
WithField("bytes", bytesRead).
Trace("bytes read from wsconnadapter")

return bytesRead, err
}

Expand All @@ -121,12 +158,18 @@ func (a *Adapter) Write(b []byte) (int, error) {

nextWriter, err := a.conn.NextWriter(websocket.BinaryMessage)
if err != nil {
a.logger.WithError(err).Trace("failed to get the next writer")

return 0, err
}

bytesWritten, err := nextWriter.Write(b)
nextWriter.Close()

log.WithError(err).
WithField("bytes", bytesWritten).
Trace("bytes written from wsconnadapter")

return bytesWritten, err
}

Expand All @@ -153,6 +196,8 @@ func (a *Adapter) RemoteAddr() net.Addr {

func (a *Adapter) SetDeadline(t time.Time) error {
if err := a.SetReadDeadline(t); err != nil {
a.logger.WithError(err).Trace("failed to set the deadline")

return err
}

Expand Down

0 comments on commit c416ad6

Please sign in to comment.