From 53c4b22dc4d5137f262d9f0b75e1148629abea63 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Thu, 30 Nov 2023 18:44:18 -0500 Subject: [PATCH] feat: add more unittests --- cmd/cmd.go | 2 +- cmd/soft/serve/server.go | 2 +- pkg/access/context_test.go | 20 +++++++++++++++++ pkg/backend/backend.go | 5 ++--- pkg/config/context_test.go | 29 ++++++++++++++++++++++++ pkg/config/file_test.go | 15 +++++++++++++ pkg/config/ssh.go | 27 +++++++++++++++++++++- pkg/config/ssh_test.go | 26 +++++++++++++++++++++ pkg/cron/cron_test.go | 31 +++++++++++++++++++++++++ pkg/daemon/daemon_test.go | 2 +- pkg/db/context_test.go | 28 +++++++++++++++++++++++ pkg/db/db_test.go | 17 ++++++++++++++ pkg/db/errors_test.go | 25 +++++++++++++++++++++ pkg/db/internal/test/test.go | 29 ++++++++++++++++++++++++ pkg/db/migrate/migrate_test.go | 22 ++++++++++++++++++ pkg/git/git.go | 16 ++++++++----- pkg/git/git_test.go | 41 ++++++++++++++++++++++++++++++++++ pkg/git/lfs.go | 9 -------- pkg/hooks/gen.go | 6 ----- pkg/hooks/gen_test.go | 40 +++++++++++++++++++++++++++++++++ pkg/jwk/jwk.go | 2 +- pkg/jwk/jwk_test.go | 22 ++++++++++++++++++ pkg/log/log.go | 3 +++ pkg/log/log_test.go | 39 ++++++++++++++++++++++++++++++++ pkg/ssh/cmd/git.go | 2 +- pkg/ssh/session_test.go | 2 +- pkg/web/auth.go | 2 +- pkg/web/git.go | 2 +- testscript/script_test.go | 2 +- 29 files changed, 435 insertions(+), 33 deletions(-) create mode 100644 pkg/access/context_test.go create mode 100644 pkg/config/context_test.go create mode 100644 pkg/config/file_test.go create mode 100644 pkg/config/ssh_test.go create mode 100644 pkg/cron/cron_test.go create mode 100644 pkg/db/context_test.go create mode 100644 pkg/db/db_test.go create mode 100644 pkg/db/errors_test.go create mode 100644 pkg/db/internal/test/test.go create mode 100644 pkg/db/migrate/migrate_test.go create mode 100644 pkg/hooks/gen_test.go create mode 100644 pkg/jwk/jwk_test.go create mode 100644 pkg/log/log_test.go diff --git a/cmd/cmd.go b/cmd/cmd.go index 9aead6369..556594e15 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -33,7 +33,7 @@ func InitBackendContext(cmd *cobra.Command, _ []string) error { ctx = db.WithContext(ctx, dbx) dbstore := database.New(ctx, dbx) ctx = store.WithContext(ctx, dbstore) - be := backend.New(ctx, cfg, dbx) + be := backend.New(ctx, cfg, dbx, dbstore) ctx = backend.WithContext(ctx, be) cmd.SetContext(ctx) diff --git a/cmd/soft/serve/server.go b/cmd/soft/serve/server.go index 223e55a69..33679091f 100644 --- a/cmd/soft/serve/server.go +++ b/cmd/soft/serve/server.go @@ -147,7 +147,7 @@ func (s *Server) Shutdown(ctx context.Context) error { for _, j := range jobs.List() { s.Cron.Remove(j.ID) } - s.Cron.Shutdown() + s.Cron.Stop() return nil }) // defer s.DB.Close() // nolint: errcheck diff --git a/pkg/access/context_test.go b/pkg/access/context_test.go new file mode 100644 index 000000000..c4bcf4a60 --- /dev/null +++ b/pkg/access/context_test.go @@ -0,0 +1,20 @@ +package access + +import ( + "context" + "testing" +) + +func TestGoodFromContext(t *testing.T) { + ctx := WithContext(context.TODO(), AdminAccess) + if ac := FromContext(ctx); ac != AdminAccess { + t.Errorf("FromContext(ctx) => %d, want %d", ac, AdminAccess) + } +} + +func TestBadFromContext(t *testing.T) { + ctx := context.TODO() + if ac := FromContext(ctx); ac != -1 { + t.Errorf("FromContext(ctx) => %d, want %d", ac, -1) + } +} diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index 2d9eb2a05..ba8796b54 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -23,14 +23,13 @@ type Backend struct { } // New returns a new Soft Serve backend. -func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend { - dbstore := store.FromContext(ctx) +func New(ctx context.Context, cfg *config.Config, db *db.DB, st store.Store) *Backend { logger := log.FromContext(ctx).WithPrefix("backend") b := &Backend{ ctx: ctx, cfg: cfg, db: db, - store: dbstore, + store: st, logger: logger, manager: task.NewManager(ctx), } diff --git a/pkg/config/context_test.go b/pkg/config/context_test.go new file mode 100644 index 000000000..db7f8b257 --- /dev/null +++ b/pkg/config/context_test.go @@ -0,0 +1,29 @@ +package config + +import ( + "context" + "reflect" + "testing" +) + +func TestBadFromContext(t *testing.T) { + ctx := context.TODO() + if c := FromContext(ctx); c != nil { + t.Errorf("FromContext(ctx) => %v, want %v", c, nil) + } +} + +func TestGoodFromContext(t *testing.T) { + ctx := WithContext(context.TODO(), &Config{}) + if c := FromContext(ctx); c == nil { + t.Errorf("FromContext(ctx) => %v, want %v", c, &Config{}) + } +} + +func TestGoodFromContextWithDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + ctx := WithContext(context.TODO(), cfg) + if c := FromContext(ctx); c == nil || !reflect.DeepEqual(c, cfg) { + t.Errorf("FromContext(ctx) => %v, want %v", c, cfg) + } +} diff --git a/pkg/config/file_test.go b/pkg/config/file_test.go new file mode 100644 index 000000000..81efad71c --- /dev/null +++ b/pkg/config/file_test.go @@ -0,0 +1,15 @@ +package config + +import "testing" + +func TestNewConfigFile(t *testing.T) { + for _, cfg := range []*Config{ + nil, + DefaultConfig(), + &Config{}, + } { + if s := newConfigFile(cfg); s == "" { + t.Errorf("newConfigFile(nil) => %q, want non-empty string", s) + } + } +} diff --git a/pkg/config/ssh.go b/pkg/config/ssh.go index 102b39141..f7d33604b 100644 --- a/pkg/config/ssh.go +++ b/pkg/config/ssh.go @@ -1,8 +1,33 @@ package config -import "github.com/charmbracelet/keygen" +import ( + "errors" + + "github.com/charmbracelet/keygen" +) + +var ( + // ErrNilConfig is returned when a nil config is passed to a function. + ErrNilConfig = errors.New("nil config") + + // ErrEmptySSHKeyPath is returned when the SSH key path is empty. + ErrEmptySSHKeyPath = errors.New("empty SSH key path") +) // KeyPair returns the server's SSH key pair. func (c SSHConfig) KeyPair() (*keygen.SSHKeyPair, error) { return keygen.New(c.KeyPath, keygen.WithKeyType(keygen.Ed25519)) } + +// KeyPair returns the server's SSH key pair. +func KeyPair(cfg *Config) (*keygen.SSHKeyPair, error) { + if cfg == nil { + return nil, ErrNilConfig + } + + if cfg.SSH.KeyPath == "" { + return nil, ErrEmptySSHKeyPath + } + + return keygen.New(cfg.SSH.KeyPath, keygen.WithKeyType(keygen.Ed25519)) +} diff --git a/pkg/config/ssh_test.go b/pkg/config/ssh_test.go new file mode 100644 index 000000000..4f68ec149 --- /dev/null +++ b/pkg/config/ssh_test.go @@ -0,0 +1,26 @@ +package config + +import "testing" + +func TestBadSSHKeyPair(t *testing.T) { + for _, cfg := range []*Config{ + nil, + {}, + } { + if _, err := KeyPair(cfg); err == nil { + t.Errorf("cfg.SSH.KeyPair() => _, nil, want non-nil error") + } + } +} + +func TestGoodSSHKeyPair(t *testing.T) { + cfg := &Config{ + SSH: SSHConfig{ + KeyPath: "testdata/ssh_host_ed25519_key", + }, + } + + if _, err := KeyPair(cfg); err != nil { + t.Errorf("cfg.SSH.KeyPair() => _, %v, want nil error", err) + } +} diff --git a/pkg/cron/cron_test.go b/pkg/cron/cron_test.go new file mode 100644 index 000000000..c254191b2 --- /dev/null +++ b/pkg/cron/cron_test.go @@ -0,0 +1,31 @@ +package cron + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/charmbracelet/log" +) + +func TestCronLogger(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&buf) + logger.SetLevel(log.DebugLevel) + clogger := cronLogger{logger} + clogger.Info("foo") + clogger.Error(fmt.Errorf("bar"), "test") + if buf.String() != "DEBU foo\nERRO test err=bar\n" { + t.Errorf("unexpected log output: %s", buf.String()) + } +} + +func TestSchedularAddRemove(t *testing.T) { + s := NewScheduler(context.TODO()) + id, err := s.AddFunc("* * * * *", func() {}) + if err != nil { + t.Fatal(err) + } + s.Remove(id) +} diff --git a/pkg/daemon/daemon_test.go b/pkg/daemon/daemon_test.go index 7ebe31b53..0debd24e6 100644 --- a/pkg/daemon/daemon_test.go +++ b/pkg/daemon/daemon_test.go @@ -50,7 +50,7 @@ func TestMain(m *testing.M) { } datastore := database.New(ctx, dbx) ctx = store.WithContext(ctx, datastore) - be := backend.New(ctx, cfg, dbx) + be := backend.New(ctx, cfg, dbx, datastore) ctx = backend.WithContext(ctx, be) d, err := NewGitDaemon(ctx) if err != nil { diff --git a/pkg/db/context_test.go b/pkg/db/context_test.go new file mode 100644 index 000000000..5da523975 --- /dev/null +++ b/pkg/db/context_test.go @@ -0,0 +1,28 @@ +package db_test + +import ( + "context" + "testing" + + "github.com/charmbracelet/soft-serve/pkg/db" + "github.com/charmbracelet/soft-serve/pkg/db/internal/test" +) + +func TestBadFromContext(t *testing.T) { + ctx := context.TODO() + if c := db.FromContext(ctx); c != nil { + t.Errorf("FromContext(ctx) => %v, want %v", c, nil) + } +} + +func TestGoodFromContext(t *testing.T) { + ctx := context.TODO() + dbx, err := test.OpenSqlite(ctx, t) + if err != nil { + t.Fatal(err) + } + ctx = db.WithContext(ctx, dbx) + if c := db.FromContext(ctx); c == nil { + t.Errorf("FromContext(ctx) => %v, want %v", c, dbx) + } +} diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go new file mode 100644 index 000000000..3ca95ad37 --- /dev/null +++ b/pkg/db/db_test.go @@ -0,0 +1,17 @@ +package db + +import ( + "context" + "strings" + "testing" +) + +func TestOpenUnknownDriver(t *testing.T) { + _, err := Open(context.TODO(), "invalid", "") + if err == nil { + t.Error("Open(invalid) => nil, want error") + } + if !strings.Contains(err.Error(), "unknown driver") { + t.Errorf("Open(invalid) => %v, want error containing 'unknown driver'", err) + } +} diff --git a/pkg/db/errors_test.go b/pkg/db/errors_test.go new file mode 100644 index 000000000..0aba63457 --- /dev/null +++ b/pkg/db/errors_test.go @@ -0,0 +1,25 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" + "testing" +) + +func TestWrapErrorBadNoRows(t *testing.T) { + for _, e := range []error{ + fmt.Errorf("foo"), + errors.New("bar"), + } { + if err := WrapError(e); err != e { + t.Errorf("WrapError(%v) => %v, want %v", e, err, e) + } + } +} + +func TestWrapErrorGoodNoRows(t *testing.T) { + if err := WrapError(sql.ErrNoRows); err != ErrRecordNotFound { + t.Errorf("WrapError(sql.ErrNoRows) => %v, want %v", err, ErrRecordNotFound) + } +} diff --git a/pkg/db/internal/test/test.go b/pkg/db/internal/test/test.go new file mode 100644 index 000000000..b40e2ad90 --- /dev/null +++ b/pkg/db/internal/test/test.go @@ -0,0 +1,29 @@ +package test + +import ( + "context" + "path/filepath" + "testing" + + "github.com/charmbracelet/soft-serve/pkg/db" +) + +// OpenSqlite opens a new temp SQLite database for testing. +// It removes the database file when the test is done using tb.Cleanup. +// If ctx is nil, context.TODO() is used. +func OpenSqlite(ctx context.Context, tb testing.TB) (*db.DB, error) { + if ctx == nil { + ctx = context.TODO() + } + dbpath := filepath.Join(tb.TempDir(), "test.db") + dbx, err := db.Open(ctx, "sqlite", dbpath) + if err != nil { + return nil, err + } + tb.Cleanup(func() { + if err := dbx.Close(); err != nil { + tb.Error(err) + } + }) + return dbx, nil +} diff --git a/pkg/db/migrate/migrate_test.go b/pkg/db/migrate/migrate_test.go new file mode 100644 index 000000000..bfc9d20f2 --- /dev/null +++ b/pkg/db/migrate/migrate_test.go @@ -0,0 +1,22 @@ +package migrate + +import ( + "context" + "testing" + + "github.com/charmbracelet/soft-serve/pkg/config" + "github.com/charmbracelet/soft-serve/pkg/db/internal/test" +) + +func TestMigrate(t *testing.T) { + // XXX: we need a config.Config in the context for the migrations to run + // properly. Some migrations depend on the config being present. + ctx := config.WithContext(context.TODO(), config.DefaultConfig()) + dbx, err := test.OpenSqlite(ctx, t) + if err != nil { + t.Fatal(err) + } + if err := Migrate(ctx, dbx); err != nil { + t.Errorf("Migrate() => %v, want nil error", err) + } +} diff --git a/pkg/git/git.go b/pkg/git/git.go index e523b90a2..d6c014296 100644 --- a/pkg/git/git.go +++ b/pkg/git/git.go @@ -2,6 +2,7 @@ package git import ( "context" + "errors" "fmt" "io" "path/filepath" @@ -13,6 +14,11 @@ import ( "github.com/go-git/go-git/v5/plumbing/format/pktline" ) +var ( + // ErrNoBranches is returned when a repo has no branches. + ErrNoBranches = errors.New("no branches found") +) + // WritePktline encodes and writes a pktline to the given writer. func WritePktline(w io.Writer, v ...interface{}) error { msg := fmt.Sprintln(v...) @@ -57,18 +63,18 @@ func EnsureWithin(reposDir string, repo string) error { // EnsureDefaultBranch ensures the repo has a default branch. // It will prefer choosing "main" or "master" if available. -func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error { - r, err := git.Open(scmd.Dir) +func EnsureDefaultBranch(ctx context.Context, repoPath string) error { + r, err := git.Open(repoPath) if err != nil { return err } brs, err := r.Branches() + if len(brs) == 0 { + return ErrNoBranches + } if err != nil { return err } - if len(brs) == 0 { - return fmt.Errorf("no branches found") - } // Rename the default branch to the first branch available _, err = r.HEAD() if err == git.ErrReferenceNotExist { diff --git a/pkg/git/git_test.go b/pkg/git/git_test.go index d95cb6497..4e4a476d2 100644 --- a/pkg/git/git_test.go +++ b/pkg/git/git_test.go @@ -2,8 +2,12 @@ package git import ( "bytes" + "context" + "errors" "fmt" "testing" + + "github.com/charmbracelet/soft-serve/git" ) func TestPktline(t *testing.T) { @@ -54,3 +58,40 @@ func TestPktline(t *testing.T) { }) } } + +func TestEnsureWithinBad(t *testing.T) { + tmp := t.TempDir() + for _, f := range []string{ + "..", + "../../../", + } { + if err := EnsureWithin(tmp, f); err == nil { + t.Errorf("EnsureWithin(%q, %q) => nil, want non-nil error", tmp, f) + } + } +} + +func TestEnsureWithinGood(t *testing.T) { + tmp := t.TempDir() + for _, f := range []string{ + tmp, + tmp + "/foo", + tmp + "/foo/bar", + } { + if err := EnsureWithin(tmp, f); err != nil { + t.Errorf("EnsureWithin(%q, %q) => %v, want nil error", tmp, f, err) + } + } +} + +func TestEnsureDefaultBranchEmpty(t *testing.T) { + tmp := t.TempDir() + r, err := git.Init(tmp, false) + if err != nil { + t.Fatal(err) + } + + if err := EnsureDefaultBranch(context.TODO(), r.Path); !errors.Is(err, ErrNoBranches) { + t.Errorf("EnsureDefaultBranch(%q) => %v, want ErrNoBranches", tmp, err) + } +} diff --git a/pkg/git/lfs.go b/pkg/git/lfs.go index 5aae027e5..7dc4b8b35 100644 --- a/pkg/git/lfs.go +++ b/pkg/git/lfs.go @@ -21,17 +21,8 @@ import ( "github.com/charmbracelet/soft-serve/pkg/proto" "github.com/charmbracelet/soft-serve/pkg/storage" "github.com/charmbracelet/soft-serve/pkg/store" - "github.com/rubyist/tracerx" ) -func init() { - // git-lfs-transfer uses tracerx for logging. - // use a custom key to avoid conflicts - // SOFT_SERVE_TRACE=1 to enable tracing git-lfs-transfer in soft-serve - tracerx.DefaultKey = "SOFT_SERVE" - tracerx.Prefix = "trace soft-serve-lfs-transfer: " -} - // lfsTransfer implements transfer.Backend. type lfsTransfer struct { ctx context.Context diff --git a/pkg/hooks/gen.go b/pkg/hooks/gen.go index 5eb16ebc2..467b2f263 100644 --- a/pkg/hooks/gen.go +++ b/pkg/hooks/gen.go @@ -3,7 +3,6 @@ package hooks import ( "bytes" "context" - "flag" "os" "path/filepath" "text/template" @@ -30,11 +29,6 @@ const ( // This function should be called by the backend when a repository is created. // TODO: support context. func GenerateHooks(_ context.Context, cfg *config.Config, repo string) error { - // TODO: support git hook tests. - if flag.Lookup("test.v") != nil { - log.WithPrefix("backend.hooks").Warn("refusing to set up hooks when in test") - return nil - } repo = utils.SanitizeRepo(repo) + ".git" hooksPath := filepath.Join(cfg.DataPath, "repos", repo, "hooks") if err := os.MkdirAll(hooksPath, os.ModePerm); err != nil { diff --git a/pkg/hooks/gen_test.go b/pkg/hooks/gen_test.go new file mode 100644 index 000000000..d9f03ee39 --- /dev/null +++ b/pkg/hooks/gen_test.go @@ -0,0 +1,40 @@ +package hooks + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/charmbracelet/soft-serve/git" + "github.com/charmbracelet/soft-serve/pkg/config" +) + +func TestGenerateHooks(t *testing.T) { + tmp := t.TempDir() + cfg := config.DefaultConfig() + cfg.DataPath = tmp + repoPath := filepath.Join(tmp, "repos", "test.git") + _, err := git.Init(repoPath, true) + if err != nil { + t.Fatal(err) + } + + if err := GenerateHooks(context.TODO(), cfg, "test.git"); err != nil { + t.Fatal(err) + } + + for _, hn := range []string{ + PreReceiveHook, + UpdateHook, + PostReceiveHook, + PostUpdateHook, + } { + if _, err := os.Stat(filepath.Join(repoPath, "hooks", hn)); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(filepath.Join(repoPath, "hooks", hn+".d", "soft-serve")); err != nil { + t.Fatal(err) + } + } +} diff --git a/pkg/jwk/jwk.go b/pkg/jwk/jwk.go index b758f88aa..f7fe92404 100644 --- a/pkg/jwk/jwk.go +++ b/pkg/jwk/jwk.go @@ -32,7 +32,7 @@ func (p Pair) JWK() jose.JSONWebKey { // NewPair creates a new JSON Web Key pair. func NewPair(cfg *config.Config) (Pair, error) { - kp, err := cfg.SSH.KeyPair() + kp, err := config.KeyPair(cfg) if err != nil { return Pair{}, err } diff --git a/pkg/jwk/jwk_test.go b/pkg/jwk/jwk_test.go new file mode 100644 index 000000000..a8d95fbdb --- /dev/null +++ b/pkg/jwk/jwk_test.go @@ -0,0 +1,22 @@ +package jwk + +import ( + "errors" + "testing" + + "github.com/charmbracelet/soft-serve/pkg/config" +) + +func TestBadNewPair(t *testing.T) { + _, err := NewPair(nil) + if !errors.Is(err, config.ErrNilConfig) { + t.Errorf("NewPair(nil) => %v, want %v", err, config.ErrNilConfig) + } +} + +func TestGoodNewPair(t *testing.T) { + cfg := config.DefaultConfig() + if _, err := NewPair(cfg); err != nil { + t.Errorf("NewPair(cfg) => _, %v, want nil error", err) + } +} diff --git a/pkg/log/log.go b/pkg/log/log.go index aa9a6b418..3f7447e10 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -11,6 +11,9 @@ import ( // NewLogger returns a new logger with default settings. func NewLogger(cfg *config.Config) (*log.Logger, *os.File, error) { + if cfg == nil { + return nil, nil, config.ErrNilConfig + } logger := log.NewWithOptions(os.Stderr, log.Options{ ReportTimestamp: true, TimeFormat: time.DateOnly, diff --git a/pkg/log/log_test.go b/pkg/log/log_test.go new file mode 100644 index 000000000..bb02f04c0 --- /dev/null +++ b/pkg/log/log_test.go @@ -0,0 +1,39 @@ +package log + +import ( + "path/filepath" + "testing" + + "github.com/charmbracelet/soft-serve/pkg/config" +) + +func TestGoodNewLogger(t *testing.T) { + for _, c := range []*config.Config{ + config.DefaultConfig(), + {}, + {Log: config.LogConfig{Path: filepath.Join(t.TempDir(), "logfile.txt")}}, + } { + _, f, err := NewLogger(c) + if err != nil { + t.Errorf("NewLogger(%v) => _, _, %v, want _, _, nil", c, err) + } + if f != nil { + f.Close() + } + } +} + +func TestBadNewLogger(t *testing.T) { + for _, c := range []*config.Config{ + nil, + {Log: config.LogConfig{Path: "\x00"}}, + } { + _, f, err := NewLogger(c) + if err == nil { + t.Errorf("NewLogger(%v) => _, _, nil, want _, _, %v", c, err) + } + if f != nil { + f.Close() + } + } +} diff --git a/pkg/ssh/cmd/git.go b/pkg/ssh/cmd/git.go index 355d235d4..f6aea60ab 100644 --- a/pkg/ssh/cmd/git.go +++ b/pkg/ssh/cmd/git.go @@ -249,7 +249,7 @@ func gitRunE(cmd *cobra.Command, args []string) error { return git.ErrSystemMalfunction } - if err := git.EnsureDefaultBranch(ctx, scmd); err != nil { + if err := git.EnsureDefaultBranch(ctx, scmd.Dir); err != nil { logger.Error("failed to ensure default branch", "err", err, "repo", name) return git.ErrSystemMalfunction } diff --git a/pkg/ssh/session_test.go b/pkg/ssh/session_test.go index 845e17865..792759c16 100644 --- a/pkg/ssh/session_test.go +++ b/pkg/ssh/session_test.go @@ -76,7 +76,7 @@ func setup(tb testing.TB) (*gossh.Session, func() error) { } dbstore := database.New(ctx, dbx) ctx = store.WithContext(ctx, dbstore) - be := backend.New(ctx, cfg, dbx) + be := backend.New(ctx, cfg, dbx, dbstore) ctx = backend.WithContext(ctx, be) return testsession.New(tb, &ssh.Server{ Handler: ContextMiddleware(cfg, dbx, dbstore, be, log.Default())(bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256)(func(s ssh.Session) { diff --git a/pkg/web/auth.go b/pkg/web/auth.go index 631660bd3..96b2bab85 100644 --- a/pkg/web/auth.go +++ b/pkg/web/auth.go @@ -136,7 +136,7 @@ var ErrInvalidToken = errors.New("invalid token") func parseJWT(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) { cfg := config.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("http.auth") - kp, err := cfg.SSH.KeyPair() + kp, err := config.KeyPair(cfg) if err != nil { return nil, err } diff --git a/pkg/web/git.go b/pkg/web/git.go index c03e56afd..8f1170bce 100644 --- a/pkg/web/git.go +++ b/pkg/web/git.go @@ -441,7 +441,7 @@ func serviceRpc(w http.ResponseWriter, r *http.Request) { } if service == git.ReceivePackService { - if err := git.EnsureDefaultBranch(ctx, cmd); err != nil { + if err := git.EnsureDefaultBranch(ctx, cmd.Dir); err != nil { logger.Errorf("failed to ensure default branch: %s", err) } } diff --git a/testscript/script_test.go b/testscript/script_test.go index 94c6408bc..1adf92be6 100644 --- a/testscript/script_test.go +++ b/testscript/script_test.go @@ -149,7 +149,7 @@ func TestScript(t *testing.T) { ctx = db.WithContext(ctx, dbx) datastore := database.New(ctx, dbx) ctx = store.WithContext(ctx, datastore) - be := backend.New(ctx, cfg, dbx) + be := backend.New(ctx, cfg, dbx, datastore) ctx = backend.WithContext(ctx, be) lock.Lock()