Skip to content

Commit

Permalink
fix(test): race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
aymanbagabas committed Jul 12, 2023
1 parent bdd174d commit c1cb2ba
Show file tree
Hide file tree
Showing 19 changed files with 86 additions and 205 deletions.
12 changes: 9 additions & 3 deletions cmd/soft/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
var (
configPath string

logFileCtxKey = struct{}{}
dbCtxKey = struct{ string }{"db"}

hookCmd = &cobra.Command{
Use: "hook",
Expand All @@ -37,18 +37,24 @@ var (

ctx = config.WithContext(ctx, cfg)
cmd.SetContext(ctx)
db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
dbx, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
if err != nil {
return fmt.Errorf("open database: %w", err)
}

ctx = db.WithContext(ctx, dbx)

// Set up the backend
sb := backend.New(ctx, cfg, db)
sb := backend.New(ctx, cfg, dbx)
ctx = backend.WithContext(ctx, sb)
cmd.SetContext(ctx)

return nil
},
PersistentPostRunE: func(cmd *cobra.Command, _ []string) error {
db := db.FromContext(cmd.Context())
return db.Close()
},
}

hooksRunE = func(cmd *cobra.Command, args []string) error {
Expand Down
1 change: 1 addition & 0 deletions cmd/soft/migrate_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ var (
return fmt.Errorf("open database: %w", err)
}

defer db.Close() // nolint: errcheck
sb := backend.New(ctx, cfg, db)

// FIXME: Admin user gets created when the database is created.
Expand Down
1 change: 1 addition & 0 deletions cmd/soft/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
_ "github.com/lib/pq" // postgres driver
"github.com/spf13/cobra"
"go.uber.org/automaxprocs/maxprocs"

_ "modernc.org/sqlite" // sqlite driver
)

Expand Down
4 changes: 3 additions & 1 deletion cmd/soft/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ var (
return fmt.Errorf("open database: %w", err)
}

defer db.Close() // nolint: errcheck

if rollback {
if err := migrate.Rollback(ctx, db); err != nil {
return fmt.Errorf("rollback error: %w", err)
Expand All @@ -67,7 +69,7 @@ var (
}
}

s, err := server.NewServer(ctx)
s, err := server.NewServer(ctx, db)
if err != nil {
return fmt.Errorf("start server: %w", err)
}
Expand Down
74 changes: 0 additions & 74 deletions examples/setuid/main.go

This file was deleted.

2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ require (
github.com/caarlos0/env/v8 v8.0.0
github.com/charmbracelet/keygen v0.4.3
github.com/charmbracelet/log v0.2.2
github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103
github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc
github.com/gobwas/glob v0.2.3
github.com/gogs/git-module v1.8.2
github.com/hashicorp/golang-lru/v2 v2.0.4
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZ
github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c=
github.com/charmbracelet/log v0.2.2 h1:CaXgos+ikGn5tcws5Cw3paQuk9e/8bIwuYGhnkqQFjo=
github.com/charmbracelet/log v0.2.2/go.mod h1:Zs11hKpb8l+UyX4y1srwZIGW+MPCXJHIty3MB9l/sno=
github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103 h1:wpHMERIN0pQZE635jWwT1dISgfjbpUcEma+fbPKSMCU=
github.com/charmbracelet/ssh v0.0.0-20221117183211-483d43d97103/go.mod h1:0Vm2/8yBljiLDnGJHU8ehswfawrEybGk33j5ssqKQVM=
github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc h1:JUm+5HigAM5utFiThwIDX9iU0BaheKpuNVr+umi3sFg=
github.com/charmbracelet/ssh v0.0.0-20230712221603-7e03c5063afc/go.mod h1:F1vgddWsb/Yr/OZilFeRZEh5sE/qU0Dt1mKkmke6Zvg=
github.com/charmbracelet/wish v1.1.1 h1:KdICASKd2oh2JPvk1Z4CJtAi97cFErXF7NKienPICO4=
github.com/charmbracelet/wish v1.1.1/go.mod h1:xh4KZpSULw+Xqb9bcbhw92QAinVB75CVLWrFuyY6IVs=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
Expand Down Expand Up @@ -189,14 +189,14 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
Expand Down
1 change: 0 additions & 1 deletion server/backend/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ func (d *Backend) RenameRepository(ctx context.Context, oldName string, newName
func (d *Backend) Repositories(ctx context.Context) ([]store.Repository, error) {
repos := make([]store.Repository, 0)

d.logger.Debugf("get all repositories %v %v %v", ctx, d.db, d.store)
if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
ms, err := d.store.GetAllRepos(ctx, tx)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions server/backend/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (d *Backend) User(ctx context.Context, username string) (store.User, error)

var m models.User
var pks []ssh.PublicKey
if err := d.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
var err error
m, err = d.store.FindUserByUsername(ctx, tx, username)
if err != nil {
Expand All @@ -111,7 +111,7 @@ func (d *Backend) User(ctx context.Context, username string) (store.User, error)
func (d *Backend) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (store.User, error) {
var m models.User
var pks []ssh.PublicKey
if err := d.db.TransactionContext(context.Background(), func(tx *db.Tx) error {
if err := d.db.TransactionContext(ctx, func(tx *db.Tx) error {
var err error
m, err = d.store.FindUserByPublicKey(ctx, tx, pk)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions server/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (d *GitDaemon) Start() error {
default:
d.logger.Debugf("git: error accepting connection: %v", err)
}
if ne, ok := err.(net.Error); ok && ne.Temporary() {
if ne, ok := err.(net.Error); ok && ne.Temporary() { // nolint: staticcheck
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
Expand Down Expand Up @@ -147,7 +147,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
}
d.conns.Add(c)
defer func() {
d.conns.Close(c)
d.conns.Close(c) // nolint: errcheck
}()

readc := make(chan struct{}, 1)
Expand Down Expand Up @@ -303,7 +303,7 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
func (d *GitDaemon) Close() error {
d.once.Do(func() { close(d.finished) })
err := d.listener.Close()
d.conns.CloseAll()
d.conns.CloseAll() // nolint: errcheck
return err
}

Expand Down
1 change: 1 addition & 0 deletions server/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func TestMain(m *testing.M) {
if err != nil {
log.Fatal(err)
}
defer db.Close() // nolint: errcheck
if err := migrate.Migrate(ctx, db); err != nil {
log.Fatal(err)
}
Expand Down
18 changes: 18 additions & 0 deletions server/db/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package db

import "context"

var contextKey = struct{ string }{"db"}

// FromContext returns the database from the context.
func FromContext(ctx context.Context) *DB {
if db, ok := ctx.Value(contextKey).(*DB); ok {
return db
}
return nil
}

// WithContext returns a new context with the database.
func WithContext(ctx context.Context, db *DB) context.Context {
return context.WithValue(ctx, contextKey, db)
}
2 changes: 2 additions & 0 deletions server/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type Tx struct {
logger *log.Logger
}

var txContextKey = struct{ string }{"tx"}

// Transaction implements db.DB.
func (d *DB) Transaction(fn func(tx *Tx) error) error {
return d.TransactionContext(context.Background(), fn)
Expand Down
38 changes: 10 additions & 28 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,9 @@ type Server struct {
// key can be provided with authKey. If authKey is provided, access will be
// restricted to that key. If authKey is not provided, the server will be
// publicly writable until configured otherwise by cloning the `config` repo.
func NewServer(ctx context.Context) (*Server, error) {
func NewServer(ctx context.Context, db *db.DB) (*Server, error) {
var err error
cfg := config.FromContext(ctx)
db, err := db.Open(ctx, cfg.DB.Driver, cfg.DB.DataSource)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}

be := backend.New(ctx, cfg, db)
ctx = backend.WithContext(ctx, be)
srv := &Server{
Expand Down Expand Up @@ -84,47 +80,33 @@ func NewServer(ctx context.Context) (*Server, error) {
return srv, nil
}

func start(ctx context.Context, fn func() error) error {
errc := make(chan error, 1)
go func() {
errc <- fn()
}()

select {
case err := <-errc:
return err
case <-ctx.Done():
return ctx.Err()
}
}

// Start starts the SSH server.
func (s *Server) Start() error {
errg, ctx := errgroup.WithContext(s.ctx)
errg, _ := errgroup.WithContext(s.ctx)
errg.Go(func() error {
s.logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, daemon.ErrServerClosed) {
if err := s.GitDaemon.Start(); !errors.Is(err, daemon.ErrServerClosed) {
return err
}
return nil
})
errg.Go(func() error {
s.logger.Print("Starting HTTP server", "addr", s.Config.HTTP.ListenAddr)
if err := start(ctx, s.HTTPServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) {
if err := s.HTTPServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
})
errg.Go(func() error {
s.logger.Print("Starting SSH server", "addr", s.Config.SSH.ListenAddr)
if err := start(ctx, s.SSHServer.ListenAndServe); !errors.Is(err, ssh.ErrServerClosed) {
if err := s.SSHServer.ListenAndServe(); !errors.Is(err, ssh.ErrServerClosed) {
return err
}
return nil
})
errg.Go(func() error {
s.logger.Print("Starting Stats server", "addr", s.Config.Stats.ListenAddr)
if err := start(ctx, s.StatsServer.ListenAndServe); !errors.Is(err, http.ErrServerClosed) {
if err := s.StatsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
Expand All @@ -138,7 +120,7 @@ func (s *Server) Start() error {

// Shutdown lets the server gracefully shutdown.
func (s *Server) Shutdown(ctx context.Context) error {
var errg errgroup.Group
errg, ctx := errgroup.WithContext(ctx)
errg.Go(func() error {
return s.GitDaemon.Shutdown(ctx)
})
Expand All @@ -155,7 +137,7 @@ func (s *Server) Shutdown(ctx context.Context) error {
s.Cron.Stop()
return nil
})
defer s.DB.Close() // nolint: errcheck
// defer s.DB.Close() // nolint: errcheck
return errg.Wait()
}

Expand All @@ -170,6 +152,6 @@ func (s *Server) Close() error {
s.Cron.Stop()
return nil
})
defer s.DB.Close() // nolint: errcheck
// defer s.DB.Close() // nolint: errcheck
return errg.Wait()
}
Loading

0 comments on commit c1cb2ba

Please sign in to comment.