diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..d3ec4ee --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,16 @@ +name: Go +on: + push: + branches: [main] + pull_request: + branches: [main] +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: stable + - run: go build -v ./... + - run: go test -v ./... diff --git a/README.md b/README.md new file mode 100644 index 0000000..7168837 --- /dev/null +++ b/README.md @@ -0,0 +1,19 @@ +# egr + +**egr** is a small Go package that extends [errgroup](https://pkg.go.dev/golang.org/x/sync/errgroup) with a typed channel, allowing you to push items into a queue and process them in concurrent goroutines with error propagation. The goal of the package is to provide a standard way to use errgroup, while maintaining most of its flexibility. + +## Installation + +```console +go get github.com/invisiblefunnel/egr +``` + +## Example + +```go +# TODO +``` + +## License + +This project is licensed under the BSD 3-Clause License (to match errgroup's license). diff --git a/egr.go b/egr.go new file mode 100644 index 0000000..8aae948 --- /dev/null +++ b/egr.go @@ -0,0 +1,66 @@ +package egr + +import ( + "context" + + "golang.org/x/sync/errgroup" +) + +// Group[T] is a collection of goroutines processing +// items of type T from a shared queue. +type Group[T any] struct { + group *errgroup.Group + queue chan T +} + +// WithContext returns a new Group[T] along with a derived context.Context. +// The group's goroutines will be canceled if any goroutine returns a non-nil error. +func WithContext[T any](ctx context.Context, queueSize int) (*Group[T], context.Context) { + group, ctx := errgroup.WithContext(ctx) + queue := make(chan T, queueSize) + return &Group[T]{group, queue}, ctx +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. Any subsequent call to the Go method will +// block until it can add an active goroutine without exceeding the configured limit. +// The limit must not be modified while any goroutines in the group are active. +func (g *Group[T]) SetLimit(n int) { + g.group.SetLimit(n) +} + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// The return value reports whether the goroutine was started. +func (g *Group[T]) TryGo(f func(queue <-chan T) error) bool { + return g.group.TryGo(func() error { + return f(g.queue) + }) +} + +// Go runs a function in a new goroutine, passing a read-only channel of type T. +// If any goroutine returns an error, the context is canceled and the error is propagated. +func (g *Group[T]) Go(f func(queue <-chan T) error) { + g.group.Go(func() error { + return f(g.queue) + }) +} + +// Push sends an item of type T into the queue. +// If the provided ctx is canceled, Push returns the context's error. +// Push must not be called after Wait. +func (g *Group[T]) Push(ctx context.Context, item T) error { + select { + case g.queue <- item: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// Wait closes the queue channel and waits for all goroutines to complete, +// returning the first error encountered (if any). +func (g *Group[T]) Wait() error { + close(g.queue) + return g.group.Wait() +} diff --git a/egr_test.go b/egr_test.go new file mode 100644 index 0000000..0408508 --- /dev/null +++ b/egr_test.go @@ -0,0 +1,156 @@ +package egr_test + +import ( + "context" + "errors" + "sort" + "sync" + "testing" + "time" + + "github.com/invisiblefunnel/egr" +) + +// TestWithContext replicates errgroup_test’s approach: once a goroutine +// returns an error, the group's context should be canceled, and Wait +// should return that error. +func TestWithContext(t *testing.T) { + errDoom := errors.New("group_test: doomed") + + type testCase struct { + errs []error + want error + } + + cases := []testCase{ + {errs: []error{}, want: nil}, + {errs: []error{nil}, want: nil}, + {errs: []error{errDoom}, want: errDoom}, + {errs: []error{errDoom, nil}, want: errDoom}, + {errs: []error{nil, errDoom}, want: errDoom}, + } + + for _, tc := range cases { + ctx := context.Background() + g, ctx := egr.WithContext[int](ctx, 2) + + for _, e := range tc.errs { + e := e // capture + g.Go(func(_ <-chan int) error { return e }) + } + + got := g.Wait() + if got != tc.want { + t.Errorf("For errs=%v, Wait() = %v; want %v", tc.errs, got, tc.want) + } + + // The group’s returned context should be canceled once any error is encountered + select { + case <-ctx.Done(): + // ctx is canceled + default: + // If we expected an error (non-nil) but the context isn't canceled, that's a bug + if tc.want != nil { + t.Errorf("Context was not canceled but expected an error %v", tc.want) + } + } + } +} + +func TestPushContextDone(t *testing.T) { + ctx, cancel := context.WithTimeout( + context.Background(), + 100*time.Millisecond, + ) + defer cancel() + + g, ctx := egr.WithContext[int](ctx, 1) + + for i := 0; i < 5; i++ { + g.Go(func(queue <-chan int) error { + for range queue { + } + return nil + }) + } + + // Loop until the context deadline is exceeded + for { + if err := g.Push(ctx, 0); err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected '%v' error return from Push, got %v", context.DeadlineExceeded, err) + } + break + } + } + + err := g.Wait() + if err != nil { + t.Errorf("unexpected error return from Wait: %v", err) + } +} + +func TestGoPushWait(t *testing.T) { + ctx := context.Background() + g, ctx := egr.WithContext[int](ctx, 2) + + var ( + consumed []int + lock sync.Mutex + ) + + nRoutines := 5 + for i := 0; i < nRoutines; i++ { + g.Go(func(queue <-chan int) error { + for item := range queue { + lock.Lock() + consumed = append(consumed, item) + lock.Unlock() + } + return nil + }) + } + + n := 1000 + for i := 0; i < n; i++ { + err := g.Push(ctx, i) + if err != nil { + t.Errorf("unexpected error return from Push: %v", err) + } + } + + err := g.Wait() + if err != nil { + t.Errorf("unexpected error return from Wait: %v", err) + } + + if len(consumed) != n { + t.Errorf("expected %d items consumed, got %d", n, len(consumed)) + } + + sort.Ints(consumed) + for i := range consumed { + if i != consumed[i] { + t.Errorf("expected consumed item %d, got %d", i, consumed[i]) + } + } +} + +// BenchmarkGo measures overhead of spawning goroutines in egr.Group. +func BenchmarkGo(b *testing.B) { + ctx := context.Background() + fn := func(_ <-chan int) error { return nil } + + b.ResetTimer() + b.ReportAllocs() + + // We create a new group once, spawn b.N goroutines, then Wait. + // This is slightly different from the original which tested repeated spawns, + // but it mirrors the general overhead test for egr. + for i := 0; i < b.N; i++ { + // Each iteration of b.N spawns one goroutine + g, _ := egr.WithContext[int](ctx, 0) + g.Go(fn) + g.Wait() + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..031bb10 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/invisiblefunnel/egr + +go 1.23.3 + +require golang.org/x/sync v0.10.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..cf16d91 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=