diff --git a/cmd/cmd.go b/cmd/cmd.go index 9aead6369..556594e15 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -33,7 +33,7 @@ func InitBackendContext(cmd *cobra.Command, _ []string) error { ctx = db.WithContext(ctx, dbx) dbstore := database.New(ctx, dbx) ctx = store.WithContext(ctx, dbstore) - be := backend.New(ctx, cfg, dbx) + be := backend.New(ctx, cfg, dbx, dbstore) ctx = backend.WithContext(ctx, be) cmd.SetContext(ctx) diff --git a/pkg/access/context_test.go b/pkg/access/context_test.go new file mode 100644 index 000000000..c4bcf4a60 --- /dev/null +++ b/pkg/access/context_test.go @@ -0,0 +1,20 @@ +package access + +import ( + "context" + "testing" +) + +func TestGoodFromContext(t *testing.T) { + ctx := WithContext(context.TODO(), AdminAccess) + if ac := FromContext(ctx); ac != AdminAccess { + t.Errorf("FromContext(ctx) => %d, want %d", ac, AdminAccess) + } +} + +func TestBadFromContext(t *testing.T) { + ctx := context.TODO() + if ac := FromContext(ctx); ac != -1 { + t.Errorf("FromContext(ctx) => %d, want %d", ac, -1) + } +} diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index 2d9eb2a05..ba8796b54 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -23,14 +23,13 @@ type Backend struct { } // New returns a new Soft Serve backend. -func New(ctx context.Context, cfg *config.Config, db *db.DB) *Backend { - dbstore := store.FromContext(ctx) +func New(ctx context.Context, cfg *config.Config, db *db.DB, st store.Store) *Backend { logger := log.FromContext(ctx).WithPrefix("backend") b := &Backend{ ctx: ctx, cfg: cfg, db: db, - store: dbstore, + store: st, logger: logger, manager: task.NewManager(ctx), } diff --git a/pkg/config/context_test.go b/pkg/config/context_test.go new file mode 100644 index 000000000..db7f8b257 --- /dev/null +++ b/pkg/config/context_test.go @@ -0,0 +1,29 @@ +package config + +import ( + "context" + "reflect" + "testing" +) + +func TestBadFromContext(t *testing.T) { + ctx := context.TODO() + if c := FromContext(ctx); c != nil { + t.Errorf("FromContext(ctx) => %v, want %v", c, nil) + } +} + +func TestGoodFromContext(t *testing.T) { + ctx := WithContext(context.TODO(), &Config{}) + if c := FromContext(ctx); c == nil { + t.Errorf("FromContext(ctx) => %v, want %v", c, &Config{}) + } +} + +func TestGoodFromContextWithDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + ctx := WithContext(context.TODO(), cfg) + if c := FromContext(ctx); c == nil || !reflect.DeepEqual(c, cfg) { + t.Errorf("FromContext(ctx) => %v, want %v", c, cfg) + } +} diff --git a/pkg/config/file_test.go b/pkg/config/file_test.go new file mode 100644 index 000000000..81efad71c --- /dev/null +++ b/pkg/config/file_test.go @@ -0,0 +1,15 @@ +package config + +import "testing" + +func TestNewConfigFile(t *testing.T) { + for _, cfg := range []*Config{ + nil, + DefaultConfig(), + &Config{}, + } { + if s := newConfigFile(cfg); s == "" { + t.Errorf("newConfigFile(nil) => %q, want non-empty string", s) + } + } +} diff --git a/pkg/config/ssh.go b/pkg/config/ssh.go index 102b39141..f7d33604b 100644 --- a/pkg/config/ssh.go +++ b/pkg/config/ssh.go @@ -1,8 +1,33 @@ package config -import "github.com/charmbracelet/keygen" +import ( + "errors" + + "github.com/charmbracelet/keygen" +) + +var ( + // ErrNilConfig is returned when a nil config is passed to a function. + ErrNilConfig = errors.New("nil config") + + // ErrEmptySSHKeyPath is returned when the SSH key path is empty. + ErrEmptySSHKeyPath = errors.New("empty SSH key path") +) // KeyPair returns the server's SSH key pair. func (c SSHConfig) KeyPair() (*keygen.SSHKeyPair, error) { return keygen.New(c.KeyPath, keygen.WithKeyType(keygen.Ed25519)) } + +// KeyPair returns the server's SSH key pair. +func KeyPair(cfg *Config) (*keygen.SSHKeyPair, error) { + if cfg == nil { + return nil, ErrNilConfig + } + + if cfg.SSH.KeyPath == "" { + return nil, ErrEmptySSHKeyPath + } + + return keygen.New(cfg.SSH.KeyPath, keygen.WithKeyType(keygen.Ed25519)) +} diff --git a/pkg/config/ssh_test.go b/pkg/config/ssh_test.go new file mode 100644 index 000000000..4f68ec149 --- /dev/null +++ b/pkg/config/ssh_test.go @@ -0,0 +1,26 @@ +package config + +import "testing" + +func TestBadSSHKeyPair(t *testing.T) { + for _, cfg := range []*Config{ + nil, + {}, + } { + if _, err := KeyPair(cfg); err == nil { + t.Errorf("cfg.SSH.KeyPair() => _, nil, want non-nil error") + } + } +} + +func TestGoodSSHKeyPair(t *testing.T) { + cfg := &Config{ + SSH: SSHConfig{ + KeyPath: "testdata/ssh_host_ed25519_key", + }, + } + + if _, err := KeyPair(cfg); err != nil { + t.Errorf("cfg.SSH.KeyPair() => _, %v, want nil error", err) + } +} diff --git a/pkg/cron/cron_test.go b/pkg/cron/cron_test.go new file mode 100644 index 000000000..c254191b2 --- /dev/null +++ b/pkg/cron/cron_test.go @@ -0,0 +1,31 @@ +package cron + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/charmbracelet/log" +) + +func TestCronLogger(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&buf) + logger.SetLevel(log.DebugLevel) + clogger := cronLogger{logger} + clogger.Info("foo") + clogger.Error(fmt.Errorf("bar"), "test") + if buf.String() != "DEBU foo\nERRO test err=bar\n" { + t.Errorf("unexpected log output: %s", buf.String()) + } +} + +func TestSchedularAddRemove(t *testing.T) { + s := NewScheduler(context.TODO()) + id, err := s.AddFunc("* * * * *", func() {}) + if err != nil { + t.Fatal(err) + } + s.Remove(id) +} diff --git a/pkg/daemon/daemon_test.go b/pkg/daemon/daemon_test.go index 7ebe31b53..0debd24e6 100644 --- a/pkg/daemon/daemon_test.go +++ b/pkg/daemon/daemon_test.go @@ -50,7 +50,7 @@ func TestMain(m *testing.M) { } datastore := database.New(ctx, dbx) ctx = store.WithContext(ctx, datastore) - be := backend.New(ctx, cfg, dbx) + be := backend.New(ctx, cfg, dbx, datastore) ctx = backend.WithContext(ctx, be) d, err := NewGitDaemon(ctx) if err != nil { diff --git a/pkg/db/context_test.go b/pkg/db/context_test.go new file mode 100644 index 000000000..a20aeb30a --- /dev/null +++ b/pkg/db/context_test.go @@ -0,0 +1,25 @@ +package db + +import ( + "context" + "testing" +) + +func TestBadFromContext(t *testing.T) { + ctx := context.TODO() + if c := FromContext(ctx); c != nil { + t.Errorf("FromContext(ctx) => %v, want %v", c, nil) + } +} + +func TestGoodFromContext(t *testing.T) { + ctx := context.TODO() + dbx, err := testOpenSqlite(t, ctx) + if err != nil { + t.Fatal(err) + } + ctx = WithContext(ctx, dbx) + if c := FromContext(ctx); c == nil { + t.Errorf("FromContext(ctx) => %v, want %v", c, dbx) + } +} diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go new file mode 100644 index 000000000..2cfaca403 --- /dev/null +++ b/pkg/db/db_test.go @@ -0,0 +1,38 @@ +package db + +import ( + "context" + "path/filepath" + "strings" + "testing" +) + +// testOpenSqlite opens a new temp SQLite database for testing. +// It removes the database file when the test is done using tb.Cleanup. +// If ctx is nil, context.TODO() is used. +func testOpenSqlite(tb testing.TB, ctx context.Context) (*DB, error) { + if ctx == nil { + ctx = context.TODO() + } + dbpath := filepath.Join(tb.TempDir(), "test.db") + dbx, err := Open(ctx, "sqlite", dbpath) + if err != nil { + return nil, err + } + tb.Cleanup(func() { + if err := dbx.Close(); err != nil { + tb.Error(err) + } + }) + return dbx, nil +} + +func TestOpenUnknownDriver(t *testing.T) { + _, err := Open(context.TODO(), "invalid", "") + if err == nil { + t.Error("Open(invalid) => nil, want error") + } + if !strings.Contains(err.Error(), "unknown driver") { + t.Errorf("Open(invalid) => %v, want error containing 'unknown driver'", err) + } +} diff --git a/pkg/db/errors_test.go b/pkg/db/errors_test.go new file mode 100644 index 000000000..0aba63457 --- /dev/null +++ b/pkg/db/errors_test.go @@ -0,0 +1,25 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" + "testing" +) + +func TestWrapErrorBadNoRows(t *testing.T) { + for _, e := range []error{ + fmt.Errorf("foo"), + errors.New("bar"), + } { + if err := WrapError(e); err != e { + t.Errorf("WrapError(%v) => %v, want %v", e, err, e) + } + } +} + +func TestWrapErrorGoodNoRows(t *testing.T) { + if err := WrapError(sql.ErrNoRows); err != ErrRecordNotFound { + t.Errorf("WrapError(sql.ErrNoRows) => %v, want %v", err, ErrRecordNotFound) + } +} diff --git a/pkg/git/git.go b/pkg/git/git.go index e523b90a2..d6c014296 100644 --- a/pkg/git/git.go +++ b/pkg/git/git.go @@ -2,6 +2,7 @@ package git import ( "context" + "errors" "fmt" "io" "path/filepath" @@ -13,6 +14,11 @@ import ( "github.com/go-git/go-git/v5/plumbing/format/pktline" ) +var ( + // ErrNoBranches is returned when a repo has no branches. + ErrNoBranches = errors.New("no branches found") +) + // WritePktline encodes and writes a pktline to the given writer. func WritePktline(w io.Writer, v ...interface{}) error { msg := fmt.Sprintln(v...) @@ -57,18 +63,18 @@ func EnsureWithin(reposDir string, repo string) error { // EnsureDefaultBranch ensures the repo has a default branch. // It will prefer choosing "main" or "master" if available. -func EnsureDefaultBranch(ctx context.Context, scmd ServiceCommand) error { - r, err := git.Open(scmd.Dir) +func EnsureDefaultBranch(ctx context.Context, repoPath string) error { + r, err := git.Open(repoPath) if err != nil { return err } brs, err := r.Branches() + if len(brs) == 0 { + return ErrNoBranches + } if err != nil { return err } - if len(brs) == 0 { - return fmt.Errorf("no branches found") - } // Rename the default branch to the first branch available _, err = r.HEAD() if err == git.ErrReferenceNotExist { diff --git a/pkg/git/git_test.go b/pkg/git/git_test.go index d95cb6497..4e4a476d2 100644 --- a/pkg/git/git_test.go +++ b/pkg/git/git_test.go @@ -2,8 +2,12 @@ package git import ( "bytes" + "context" + "errors" "fmt" "testing" + + "github.com/charmbracelet/soft-serve/git" ) func TestPktline(t *testing.T) { @@ -54,3 +58,40 @@ func TestPktline(t *testing.T) { }) } } + +func TestEnsureWithinBad(t *testing.T) { + tmp := t.TempDir() + for _, f := range []string{ + "..", + "../../../", + } { + if err := EnsureWithin(tmp, f); err == nil { + t.Errorf("EnsureWithin(%q, %q) => nil, want non-nil error", tmp, f) + } + } +} + +func TestEnsureWithinGood(t *testing.T) { + tmp := t.TempDir() + for _, f := range []string{ + tmp, + tmp + "/foo", + tmp + "/foo/bar", + } { + if err := EnsureWithin(tmp, f); err != nil { + t.Errorf("EnsureWithin(%q, %q) => %v, want nil error", tmp, f, err) + } + } +} + +func TestEnsureDefaultBranchEmpty(t *testing.T) { + tmp := t.TempDir() + r, err := git.Init(tmp, false) + if err != nil { + t.Fatal(err) + } + + if err := EnsureDefaultBranch(context.TODO(), r.Path); !errors.Is(err, ErrNoBranches) { + t.Errorf("EnsureDefaultBranch(%q) => %v, want ErrNoBranches", tmp, err) + } +} diff --git a/pkg/git/lfs.go b/pkg/git/lfs.go index 5aae027e5..7dc4b8b35 100644 --- a/pkg/git/lfs.go +++ b/pkg/git/lfs.go @@ -21,17 +21,8 @@ import ( "github.com/charmbracelet/soft-serve/pkg/proto" "github.com/charmbracelet/soft-serve/pkg/storage" "github.com/charmbracelet/soft-serve/pkg/store" - "github.com/rubyist/tracerx" ) -func init() { - // git-lfs-transfer uses tracerx for logging. - // use a custom key to avoid conflicts - // SOFT_SERVE_TRACE=1 to enable tracing git-lfs-transfer in soft-serve - tracerx.DefaultKey = "SOFT_SERVE" - tracerx.Prefix = "trace soft-serve-lfs-transfer: " -} - // lfsTransfer implements transfer.Backend. type lfsTransfer struct { ctx context.Context diff --git a/pkg/jwk/jwk.go b/pkg/jwk/jwk.go index b758f88aa..f7fe92404 100644 --- a/pkg/jwk/jwk.go +++ b/pkg/jwk/jwk.go @@ -32,7 +32,7 @@ func (p Pair) JWK() jose.JSONWebKey { // NewPair creates a new JSON Web Key pair. func NewPair(cfg *config.Config) (Pair, error) { - kp, err := cfg.SSH.KeyPair() + kp, err := config.KeyPair(cfg) if err != nil { return Pair{}, err } diff --git a/pkg/jwk/jwk_test.go b/pkg/jwk/jwk_test.go new file mode 100644 index 000000000..a8d95fbdb --- /dev/null +++ b/pkg/jwk/jwk_test.go @@ -0,0 +1,22 @@ +package jwk + +import ( + "errors" + "testing" + + "github.com/charmbracelet/soft-serve/pkg/config" +) + +func TestBadNewPair(t *testing.T) { + _, err := NewPair(nil) + if !errors.Is(err, config.ErrNilConfig) { + t.Errorf("NewPair(nil) => %v, want %v", err, config.ErrNilConfig) + } +} + +func TestGoodNewPair(t *testing.T) { + cfg := config.DefaultConfig() + if _, err := NewPair(cfg); err != nil { + t.Errorf("NewPair(cfg) => _, %v, want nil error", err) + } +} diff --git a/pkg/ssh/cmd/git.go b/pkg/ssh/cmd/git.go index 355d235d4..f6aea60ab 100644 --- a/pkg/ssh/cmd/git.go +++ b/pkg/ssh/cmd/git.go @@ -249,7 +249,7 @@ func gitRunE(cmd *cobra.Command, args []string) error { return git.ErrSystemMalfunction } - if err := git.EnsureDefaultBranch(ctx, scmd); err != nil { + if err := git.EnsureDefaultBranch(ctx, scmd.Dir); err != nil { logger.Error("failed to ensure default branch", "err", err, "repo", name) return git.ErrSystemMalfunction } diff --git a/pkg/ssh/session_test.go b/pkg/ssh/session_test.go index 845e17865..792759c16 100644 --- a/pkg/ssh/session_test.go +++ b/pkg/ssh/session_test.go @@ -76,7 +76,7 @@ func setup(tb testing.TB) (*gossh.Session, func() error) { } dbstore := database.New(ctx, dbx) ctx = store.WithContext(ctx, dbstore) - be := backend.New(ctx, cfg, dbx) + be := backend.New(ctx, cfg, dbx, dbstore) ctx = backend.WithContext(ctx, be) return testsession.New(tb, &ssh.Server{ Handler: ContextMiddleware(cfg, dbx, dbstore, be, log.Default())(bm.MiddlewareWithProgramHandler(SessionHandler, termenv.ANSI256)(func(s ssh.Session) { diff --git a/pkg/web/auth.go b/pkg/web/auth.go index 631660bd3..96b2bab85 100644 --- a/pkg/web/auth.go +++ b/pkg/web/auth.go @@ -136,7 +136,7 @@ var ErrInvalidToken = errors.New("invalid token") func parseJWT(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) { cfg := config.FromContext(ctx) logger := log.FromContext(ctx).WithPrefix("http.auth") - kp, err := cfg.SSH.KeyPair() + kp, err := config.KeyPair(cfg) if err != nil { return nil, err } diff --git a/pkg/web/git.go b/pkg/web/git.go index c03e56afd..8f1170bce 100644 --- a/pkg/web/git.go +++ b/pkg/web/git.go @@ -441,7 +441,7 @@ func serviceRpc(w http.ResponseWriter, r *http.Request) { } if service == git.ReceivePackService { - if err := git.EnsureDefaultBranch(ctx, cmd); err != nil { + if err := git.EnsureDefaultBranch(ctx, cmd.Dir); err != nil { logger.Errorf("failed to ensure default branch: %s", err) } } diff --git a/testscript/script_test.go b/testscript/script_test.go index 94c6408bc..1adf92be6 100644 --- a/testscript/script_test.go +++ b/testscript/script_test.go @@ -149,7 +149,7 @@ func TestScript(t *testing.T) { ctx = db.WithContext(ctx, dbx) datastore := database.New(ctx, dbx) ctx = store.WithContext(ctx, datastore) - be := backend.New(ctx, cfg, dbx) + be := backend.New(ctx, cfg, dbx, datastore) ctx = backend.WithContext(ctx, be) lock.Lock()