diff --git a/cmd/soft/serve/server.go b/cmd/soft/serve/server.go index 33679091f..69d5db5a8 100644 --- a/cmd/soft/serve/server.go +++ b/cmd/soft/serve/server.go @@ -95,7 +95,7 @@ func (s *Server) Start() error { errg, _ := errgroup.WithContext(s.ctx) errg.Go(func() error { s.logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr) - if err := s.GitDaemon.Start(); !errors.Is(err, daemon.ErrServerClosed) { + if err := s.GitDaemon.ListenAndServe(); !errors.Is(err, daemon.ErrServerClosed) { return err } return nil diff --git a/pkg/daemon/daemon.go b/pkg/daemon/daemon.go index de65cba75..e4e6340d9 100644 --- a/pkg/daemon/daemon.go +++ b/pkg/daemon/daemon.go @@ -8,6 +8,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "time" "github.com/charmbracelet/log" @@ -43,7 +44,6 @@ var ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed) // GitDaemon represents a Git daemon. type GitDaemon struct { ctx context.Context - listener net.Listener addr string finished chan struct{} conns connections @@ -52,6 +52,7 @@ type GitDaemon struct { wg sync.WaitGroup once sync.Once logger *log.Logger + done atomic.Bool // indicates if the server has been closed } // NewDaemon returns a new Git daemon. @@ -70,26 +71,31 @@ func NewGitDaemon(ctx context.Context) (*GitDaemon, error) { return d, nil } -// Start starts the Git TCP daemon. -func (d *GitDaemon) Start() error { - // listen on the socket - { - listener, err := net.Listen("tcp", d.addr) - if err != nil { - return err - } - d.listener = listener +// ListenAndServe starts the Git TCP daemon. +func (d *GitDaemon) ListenAndServe() error { + if d.done.Load() { + return ErrServerClosed + } + listener, err := net.Listen("tcp", d.addr) + if err != nil { + return err } + return d.Serve(listener) +} - // close eventual connections to the socket - defer d.listener.Close() // nolint: errcheck +// Serve listens on the TCP network address and serves Git requests. +func (d *GitDaemon) Serve(listener net.Listener) error { + if d.done.Load() { + return ErrServerClosed + } d.wg.Add(1) defer d.wg.Done() + defer listener.Close() //nolint:errcheck var tempDelay time.Duration for { - conn, err := d.listener.Accept() + conn, err := listener.Accept() if err != nil { select { case <-d.finished: @@ -305,21 +311,30 @@ func (d *GitDaemon) handleClient(conn net.Conn) { // Close closes the underlying listener. func (d *GitDaemon) Close() error { - d.once.Do(func() { close(d.finished) }) - err := d.listener.Close() + err := d.closeListener() d.conns.CloseAll() // nolint: errcheck return err } +// closeListener closes the listener and the finished channel. +func (d *GitDaemon) closeListener() error { + if d.done.Load() { + return ErrServerClosed + } + d.once.Do(func() { + close(d.finished) + d.done.Store(true) + }) + return nil +} + // Shutdown gracefully shuts down the daemon. func (d *GitDaemon) Shutdown(ctx context.Context) error { - // in the case when git daemon was never started - if d.listener == nil { - return nil + if d.done.Load() { + return ErrServerClosed } - d.once.Do(func() { close(d.finished) }) - err := d.listener.Close() + err := d.closeListener() finished := make(chan struct{}, 1) go func() { d.wg.Wait() diff --git a/pkg/daemon/daemon_test.go b/pkg/daemon/daemon_test.go index 88b4fdec7..b9cb77c64 100644 --- a/pkg/daemon/daemon_test.go +++ b/pkg/daemon/daemon_test.go @@ -59,7 +59,7 @@ func TestMain(m *testing.M) { } testDaemon = d go func() { - if err := d.Start(); err != ErrServerClosed { + if err := d.ListenAndServe(); err != ErrServerClosed { log.Fatal(err) } }() @@ -75,11 +75,21 @@ func TestMain(m *testing.M) { } func TestIdleTimeout(t *testing.T) { - c, err := net.Dial("tcp", testDaemon.addr) - if err != nil { - t.Fatal(err) + var err error + var c net.Conn + var tries int + for { + c, err = net.Dial("tcp", testDaemon.addr) + if err != nil && tries >= 3 { + t.Fatal(err) + } + tries++ + if testDaemon.conns.Size() != 0 { + break + } + time.Sleep(10 * time.Millisecond) } - time.Sleep(time.Second) + time.Sleep(2 * time.Second) _, err = readPktline(c) if err == nil { t.Errorf("expected error, got nil")