diff --git a/server/ssh/middleware.go b/server/ssh/middleware.go index 300dd3798..23bf5ea97 100644 --- a/server/ssh/middleware.go +++ b/server/ssh/middleware.go @@ -13,11 +13,45 @@ import ( "github.com/charmbracelet/soft-serve/server/sshutils" "github.com/charmbracelet/soft-serve/server/store" "github.com/charmbracelet/ssh" + "github.com/charmbracelet/wish" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/spf13/cobra" + gossh "golang.org/x/crypto/ssh" ) +// ErrPermissionDenied is returned when a user is not allowed connect. +var ErrPermissionDenied = fmt.Errorf("permission denied") + +// AuthenticationMiddleware handles authentication. +func AuthenticationMiddleware(sh ssh.Handler) ssh.Handler { + return func(s ssh.Session) { + // XXX: The authentication key is set in the context but gossh doesn't + // validate the authentication. We need to verify that the _last_ key + // that was approved is the one that's being used. + + pk := s.PublicKey() + if pk != nil { + // There is no public key stored in the context, public-key auth + // was never requested, skip + perms := s.Permissions().Permissions + if perms == nil { + wish.Fatalln(s, ErrPermissionDenied) + return + } + + // Check if the key is the same as the one we have in context + fp := perms.Extensions["pubkey-fp"] + if fp != gossh.FingerprintSHA256(pk) { + wish.Fatalln(s, ErrPermissionDenied) + return + } + } + + sh(s) + } +} + // ContextMiddleware adds the config, backend, and logger to the session context. func ContextMiddleware(cfg *config.Config, dbx *db.DB, datastore store.Store, be *backend.Backend, logger *log.Logger) func(ssh.Handler) ssh.Handler { return func(sh ssh.Handler) ssh.Handler { diff --git a/server/ssh/ssh.go b/server/ssh/ssh.go index d02824da5..4d57dc90b 100644 --- a/server/ssh/ssh.go +++ b/server/ssh/ssh.go @@ -77,6 +77,11 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) { LoggingMiddleware, // Context middleware. ContextMiddleware(cfg, dbx, datastore, be, logger), + // Authentication middleware. + // gossh.PublicKeyHandler doesn't guarantee that the public key + // is in fact the one used for authentication, so we need to + // check it again here. + AuthenticationMiddleware, ), } @@ -91,6 +96,16 @@ func NewSSHServer(ctx context.Context) (*SSHServer, error) { return nil, err } + if config.IsDebug() { + s.srv.ServerConfigCallback = func(ctx ssh.Context) *gossh.ServerConfig { + return &gossh.ServerConfig{ + AuthLogCallback: func(conn gossh.ConnMetadata, method string, err error) { + logger.Debug("authentication", "user", conn.User(), "method", method, "err", err) + }, + } + } + } + if cfg.SSH.MaxTimeout > 0 { s.srv.MaxTimeout = time.Duration(cfg.SSH.MaxTimeout) * time.Second } @@ -130,6 +145,19 @@ func (s *SSHServer) Shutdown(ctx context.Context) error { return s.srv.Shutdown(ctx) } +func initializePermissions(ctx ssh.Context) { + perms := ctx.Permissions() + if perms == nil || perms.Permissions == nil { + perms = &ssh.Permissions{Permissions: &gossh.Permissions{}} + } + if perms.Extensions == nil { + perms.Extensions = make(map[string]string) + } + if perms.Permissions.Extensions == nil { + perms.Permissions.Extensions = make(map[string]string) + } +} + // PublicKeyAuthHandler handles public key authentication. func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed bool) { if pk == nil { @@ -144,6 +172,15 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed if user != nil { ctx.SetValue(proto.ContextKeyUser, user) allowed = true + + // XXX: store the first "approved" public-key fingerprint in the + // permissions block to use for authentication later. + initializePermissions(ctx) + perms := ctx.Permissions() + + // Set the public key fingerprint to be used for authentication. + perms.Extensions["pubkey-fp"] = gossh.FingerprintSHA256(pk) + ctx.SetValue(ssh.ContextKeyPermissions, perms) } return @@ -154,5 +191,16 @@ func (s *SSHServer) PublicKeyHandler(ctx ssh.Context, pk ssh.PublicKey) (allowed func (s *SSHServer) KeyboardInteractiveHandler(ctx ssh.Context, _ gossh.KeyboardInteractiveChallenge) bool { ac := s.be.AllowKeyless(ctx) keyboardInteractiveCounter.WithLabelValues(strconv.FormatBool(ac)).Inc() + + // If we're allowing keyless access, reset the public key fingerprint + if ac { + initializePermissions(ctx) + perms := ctx.Permissions() + + // XXX: reset the public-key fingerprint. This is used to validate the + // public key being used to authenticate. + perms.Extensions["pubkey-fp"] = "" + ctx.SetValue(ssh.ContextKeyPermissions, perms) + } return ac }