Skip to content

Commit

Permalink
fix(ssh): add authentication middleware
Browse files Browse the repository at this point in the history
We need to verify that the key used to establish the connection is the
same key used for authentication, otherwise, refuse connection.
  • Loading branch information
aymanbagabas committed Sep 26, 2023
1 parent 9021825 commit d6ca5fb
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
37 changes: 37 additions & 0 deletions server/ssh/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,48 @@ import (
"github.com/charmbracelet/soft-serve/server/sshutils"
"github.com/charmbracelet/soft-serve/server/store"
"github.com/charmbracelet/ssh"
"github.com/charmbracelet/wish"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/spf13/cobra"
gossh "golang.org/x/crypto/ssh"
)

// ErrPermissionDenied is returned when a user is not allowed connect.
var ErrPermissionDenied = fmt.Errorf("permission denied")

// AuthenticationMiddleware handles authentication.
func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler {
return func(s ssh.Session) {
// XXX: The authentication key is set in the context but gossh doesn't
// validate the authentication. We need to verify that the _last_ key
// that was approved is the one that's being used.

pk := s.PublicKey()
if pk == nil {
// There is no public key stored in the context, public-key auth
// was never requested, skip
sh(s)
return
}

perms := s.Permissions().Permissions
if perms == nil {
wish.Fatalln(s, ErrPermissionDenied)
return
}

// Check if the key is the same as the one we have in context
fp := perms.Extensions["pubkey-fp"]
if fp != gossh.FingerprintSHA256(pk) {
wish.Fatalln(s, ErrPermissionDenied)
return
}

sh(s)
}
}

// ContextMiddleware adds the config, backend, and logger to the session context.
func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler {
return func(sh ssh.Handler) ssh.Handler {
Expand Down
48 changes: 48 additions & 0 deletions server/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) {
LoggingMiddleware,
// Context middleware.
ContextMiddleware(cfg, dbx, datastore, be, logger),
// Authentication middleware.
// gossh.PublicKeyHandler doesn't guarantee that the public key
// is in fact the one used for authentication, so we need to
// check it again here.
AuthenticationMiddleware,
),
}

Expand All @@ -91,6 +96,16 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) {
return nil, err
}

if config.IsDebug() {
s.srv.ServerConfigCallback = func(ctx ssh.Context) *gossh.ServerConfig {
return &gossh.ServerConfig{
AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) {
logger.Debug("authentication", "user", conn.User(), "method", method, "err", err)
},
}
}
}

if cfg.SSH.MaxTimeout > 0 {
s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second
}
Expand Down Expand Up @@ -130,6 +145,19 @@ func (s *SSHServer) Shutdown(ctx context.Context) error {
return s.srv.Shutdown(ctx)
}

func initializePermissions(ctx ssh.Context) {
perms := ctx.Permissions()
if perms == nil || perms.Permissions == nil {
perms = &ssh.Permissions{Permissions: &gossh.Permissions{}}
}
if perms.Extensions == nil {
perms.Extensions = make(map[string]string)
}
if perms.Permissions.Extensions == nil {
perms.Permissions.Extensions = make(map[string]string)
}
}

// PublicKeyAuthHandler handles public key authentication.
func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) {
if pk == nil {
Expand All @@ -144,6 +172,15 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed
if user != nil {
ctx.SetValue(proto.ContextKeyUser, user)
allowed = true

// XXX: store the first "approved" public-key fingerprint in the
// permissions block to use for authentication later.
initializePermissions(ctx)
perms := ctx.Permissions()

// Set the public key fingerprint to be used for authentication.
perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk)
ctx.SetValue(ssh.ContextKeyPermissions, perms)
}

return
Expand All @@ -154,5 +191,16 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed
func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool {
ac := s.be.AllowKeyless(ctx)
keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc()

// If we're allowing keyless access, reset the public key fingerprint
if ac {
initializePermissions(ctx)
perms := ctx.Permissions()

// XXX: reset the public-key fingerprint. This is used to validate the
// public key being used to authenticate.
perms.Extensions["pubkey-fp"] = ""
ctx.SetValue(ssh.ContextKeyPermissions, perms)
}
return ac
}

0 comments on commit d6ca5fb

Please sign in to comment.