Skip to content

Commit

Permalink
taskgroup: remove the deprecated Collector type
Browse files Browse the repository at this point in the history
Replace:
    c := taskgroup.Collect(x)
with
    c := taskgroup.Gather(g.Go, x)
  • Loading branch information
creachadair committed Oct 7, 2024
1 parent 729478c commit 4b10b56
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 186 deletions.
113 changes: 0 additions & 113 deletions collector.go

This file was deleted.

32 changes: 18 additions & 14 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,27 +148,28 @@ func ExampleSingle() {
// 2500 bytes
}

func ExampleCollector() {
var total int
c := taskgroup.Collect(func(v int) {
total += v
})

func ExampleGatherer() {
const numTasks = 25
input := rand.Perm(500)

// Start a bunch of tasks to find elements in the input...
g := taskgroup.New(nil)

var total int
c := taskgroup.Gather(g.Go, func(v int) {
total += v
})

for i := range numTasks {
target := i + 1
g.Go(c.Call(func() (int, error) {
c.Call(func() (int, error) {
for _, v := range input {
if v == target {
return v, nil
}
}
return 0, errors.New("not found")
}))
})
}

// Wait for the searchers to finish, then signal the collector to stop.
Expand All @@ -180,30 +181,33 @@ func ExampleCollector() {
// 325
}

func ExampleCollector_Report() {
func ExampleGatherer_Report() {
type val struct {
who string
v int
}
c := taskgroup.Collect(func(z val) { fmt.Println(z.who, z.v) })

g := taskgroup.New(nil)
c := taskgroup.Gather(g.Go, func(z val) {
fmt.Println(z.who, z.v)
})

// The Report method passes its argument a function to report multiple
// values to the collector.
g.Go(c.Report(func(report func(v val)) error {
c.Report(func(report func(v val)) error {
for i := range 3 {
report(val{"even", 2 * i})
}
return nil
}))
})
// Multiple reporters are fine.
g.Go(c.Report(func(report func(v val)) error {
c.Report(func(report func(v val)) error {
for i := range 3 {
report(val{"odd", 2*i + 1})
}
// An error from a reporter is propagated like any other task error.
return errors.New("no bueno")
}))
})
err := g.Wait()
if err == nil || err.Error() != "no bueno" {
log.Fatalf("Unexpected error: %v", err)
Expand Down
60 changes: 60 additions & 0 deletions gatherer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package taskgroup

import "sync"

// A Gatherer manages a group of [Task] functions that report values, and
// gathers the values they return.
type Gatherer[T any] struct {
run func(Task) // start the task in a goroutine

μ sync.Mutex
gather func(T) // handle values reported by tasks
}

func (g *Gatherer[T]) report(v T) {
g.μ.Lock()
defer g.μ.Unlock()
g.gather(v)
}

// Gather creates a new empty gatherer that uses run to execute tasks returning
// values of type T.
//
// If gather != nil, values reported by successful tasks are passed to the
// function, otherwise such values are discarded. Calls to gather are
// synchronized to a single goroutine.
//
// If run == nil, Gather will panic.
func Gather[T any](run func(Task), gather func(T)) *Gatherer[T] {
if run == nil {
panic("run function is nil")
}
if gather == nil {
gather = func(T) {}
}
return &Gatherer[T]{run: run, gather: gather}
}

// Call runs f in g. If f reports an error, the error is propagated to the
// runner; otherwise the non-error value reported by f is gathered.
func (g *Gatherer[T]) Call(f func() (T, error)) {
g.run(func() error {
v, err := f()
if err == nil {
g.report(v)
}
return err
})
}

// Run runs f in g, and gathers the value it reports.
func (g *Gatherer[T]) Run(f func() T) {
g.run(func() error { g.report(f()); return nil })
}

// Report runs f in g. Any values passed to report are gathered. If f reports
// an error, that error is propagated to the runner. Any values sent before f
// returns are still gathered, even if f reports an error.
func (g *Gatherer[T]) Report(f func(report func(T)) error) {
g.run(func() error { return f(g.report) })
}
64 changes: 5 additions & 59 deletions taskgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,12 @@ func TestSingleTask(t *testing.T) {
func TestWaitMoreTasks(t *testing.T) {
defer leaktest.Check(t)()

var g taskgroup.Group
var results int
coll := taskgroup.Collect(func(int) {
coll := taskgroup.Gather(g.Go, func(int) {
results++
})

var g taskgroup.Group

// Test that if a task spawns more tasks on its own recognizance, waiting
// correctly waits for all of them provided we do not let the group go empty
// before all the tasks are spawned.
Expand All @@ -249,14 +248,14 @@ func TestWaitMoreTasks(t *testing.T) {
if n > 1 {
// The subordinate task, if there is one, is started before this one
// exits, ensuring the group is kept "afloat".
g.Go(coll.Run(func() int {
coll.Run(func() int {
return countdown(n - 1)
}))
})
}
return n
}

g.Go(coll.Run(func() int { return countdown(15) }))
coll.Run(func() int { return countdown(15) })
g.Wait()

if results != 15 {
Expand Down Expand Up @@ -284,59 +283,6 @@ func TestSingleResult(t *testing.T) {
}
}

func TestCollector(t *testing.T) {
defer leaktest.Check(t)()

var sum int
c := taskgroup.Collect(func(v int) { sum += v })

vs := rand.Perm(15)
var g taskgroup.Group

for i, v := range vs {
v := v
if v > 10 {
// This value should not be accumulated.
g.Go(c.Call(func() (int, error) {
return -100, errors.New("don't add this")
}))
} else if i%2 == 0 {
// A function with an error.
g.Go(c.Call(func() (int, error) { return v, nil }))
} else {
// A function without an error.
g.Go(c.Run(func() int { return v }))
}
}
g.Wait() // wait for tasks to finish

if want := (10 * 11) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
}

func TestCollector_Report(t *testing.T) {
defer leaktest.Check(t)()

var sum int
c := taskgroup.Collect(func(v int) { sum += v })

var g taskgroup.Group
g.Go(c.Report(func(report func(v int)) error {
for _, v := range rand.Perm(10) {
report(v)
}
return nil
}))

if err := g.Wait(); err != nil {
t.Errorf("Unexpected error from group: %v", err)
}
if want := (9 * 10) / 2; sum != want {
t.Errorf("Final result: got %d, want %d", sum, want)
}
}

func TestGatherer(t *testing.T) {
defer leaktest.Check(t)()

Expand Down

0 comments on commit 4b10b56

Please sign in to comment.