diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a1a8b54a..230ee1d6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ * [CHANGE] grpcutil: Convert Resolver into concrete type. #105 * [CHANGE] grpcutil.Resolver.Resolve: Take a service parameter. #102 * [CHANGE] grpcutil.Update: Remove gRPC LB related metadata. #102 +* [CHANGE] concurrency.ForEach: deprecated and reimplemented by new `concurrency.ForEachJob`. #113 * [ENHANCEMENT] Add middleware package. #38 * [ENHANCEMENT] Add the ring package #45 * [ENHANCEMENT] Add limiter package. #41 diff --git a/concurrency/runner.go b/concurrency/runner.go index a6740f3ac..023be10d7 100644 --- a/concurrency/runner.go +++ b/concurrency/runner.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "go.uber.org/atomic" "golang.org/x/sync/errgroup" "github.com/grafana/dskit/internal/math" @@ -62,45 +63,53 @@ func ForEachUser(ctx context.Context, userIDs []string, concurrency int, userFun // ForEach runs the provided jobFunc for each job up to concurrency concurrent workers. // The execution breaks on first error encountered. +// +// Deprecated: use ForEachJob instead. func ForEach(ctx context.Context, jobs []interface{}, concurrency int, jobFunc func(ctx context.Context, job interface{}) error) error { - if len(jobs) == 0 { - return nil + return ForEachJob(ctx, len(jobs), concurrency, func(ctx context.Context, idx int) error { + return jobFunc(ctx, jobs[idx]) + }) +} + +// CreateJobsFromStrings is an utility to create jobs from an slice of strings. +// +// Deprecated: will be removed as it's not needed when using ForEachJob. +func CreateJobsFromStrings(values []string) []interface{} { + jobs := make([]interface{}, len(values)) + for i := 0; i < len(values); i++ { + jobs[i] = values[i] } + return jobs +} - // Push all jobs to a channel. - ch := make(chan interface{}, len(jobs)) - for _, job := range jobs { - ch <- job +// ForEachJob runs the provided jobFunc for each job index in [0, jobs) up to concurrency concurrent workers. +// The execution breaks on first error encountered. +func ForEachJob(ctx context.Context, jobs int, concurrency int, jobFunc func(ctx context.Context, idx int) error) error { + if jobs == 0 { + return nil } - close(ch) + + // Initialise indexes with -1 so first Inc() returns index 0. + indexes := atomic.NewInt64(-1) // Start workers to process jobs. g, ctx := errgroup.WithContext(ctx) - for ix := 0; ix < math.Min(concurrency, len(jobs)); ix++ { + for ix := 0; ix < math.Min(concurrency, jobs); ix++ { g.Go(func() error { - for job := range ch { - if err := ctx.Err(); err != nil { - return err + for ctx.Err() == nil { + idx := int(indexes.Inc()) + if idx >= jobs { + return nil } - if err := jobFunc(ctx, job); err != nil { + if err := jobFunc(ctx, idx); err != nil { return err } } - - return nil + return ctx.Err() }) } // Wait until done (or context has canceled). return g.Wait() } - -// CreateJobsFromStrings is an utility to create jobs from an slice of strings. -func CreateJobsFromStrings(values []string) []interface{} { - jobs := make([]interface{}, len(values)) - for i := 0; i < len(values); i++ { - jobs[i] = values[i] - } - return jobs -} diff --git a/concurrency/runner_test.go b/concurrency/runner_test.go index 1dec972c4..142705313 100644 --- a/concurrency/runner_test.go +++ b/concurrency/runner_test.go @@ -14,8 +14,6 @@ import ( func TestForEachUser(t *testing.T) { var ( - ctx = context.Background() - // Keep track of processed users. processedMx sync.Mutex processed []string @@ -23,7 +21,7 @@ func TestForEachUser(t *testing.T) { input := []string{"a", "b", "c"} - err := ForEachUser(ctx, input, 2, func(ctx context.Context, user string) error { + err := ForEachUser(context.Background(), input, 2, func(ctx context.Context, user string) error { processedMx.Lock() defer processedMx.Unlock() processed = append(processed, user) @@ -35,16 +33,12 @@ func TestForEachUser(t *testing.T) { } func TestForEachUser_ShouldContinueOnErrorButReturnIt(t *testing.T) { - var ( - ctx = context.Background() - - // Keep the processed users count. - processed atomic.Int32 - ) + // Keep the processed users count. + var processed atomic.Int32 input := []string{"a", "b", "c"} - err := ForEachUser(ctx, input, 2, func(ctx context.Context, user string) error { + err := ForEachUser(context.Background(), input, 2, func(ctx context.Context, user string) error { if processed.CAS(0, 1) { return errors.New("the first request is failing") } @@ -72,10 +66,90 @@ func TestForEachUser_ShouldReturnImmediatelyOnNoUsersProvided(t *testing.T) { })) } +func TestForEachJob(t *testing.T) { + jobs := []string{"a", "b", "c"} + processed := make([]string, len(jobs)) + + err := ForEachJob(context.Background(), len(jobs), 2, func(ctx context.Context, idx int) error { + processed[idx] = jobs[idx] + return nil + }) + + require.NoError(t, err) + assert.ElementsMatch(t, jobs, processed) +} + +func TestForEachJob_ShouldBreakOnFirstError_ContextCancellationHandled(t *testing.T) { + // Keep the processed jobs count. + var processed atomic.Int32 + + err := ForEachJob(context.Background(), 3, 2, func(ctx context.Context, idx int) error { + if processed.CAS(0, 1) { + return errors.New("the first request is failing") + } + + // Wait 1s and increase the number of processed jobs, unless the context get canceled earlier. + select { + case <-time.After(time.Second): + processed.Add(1) + case <-ctx.Done(): + return ctx.Err() + } + + return nil + }) + + require.EqualError(t, err, "the first request is failing") + + // Since we expect the first error interrupts the workers, we should only see + // 1 job processed (the one which immediately returned error). + assert.Equal(t, int32(1), processed.Load()) +} + +func TestForEachJob_ShouldBreakOnFirstError_ContextCancellationUnhandled(t *testing.T) { + // Keep the processed jobs count. + var processed atomic.Int32 + + // waitGroup to await the start of the first two jobs + var wg sync.WaitGroup + wg.Add(2) + + err := ForEachJob(context.Background(), 3, 2, func(ctx context.Context, idx int) error { + wg.Done() + + if processed.CAS(0, 1) { + // wait till two jobs have been started + wg.Wait() + return errors.New("the first request is failing") + } + + // Wait till context is cancelled to add processed jobs. + <-ctx.Done() + processed.Add(1) + + return nil + }) + + require.EqualError(t, err, "the first request is failing") + + // Since we expect the first error interrupts the workers, we should only + // see 2 job processed (the one which immediately returned error and the + // job with "b"). + assert.Equal(t, int32(2), processed.Load()) +} + +func TestForEachJob_ShouldReturnImmediatelyOnNoJobsProvided(t *testing.T) { + // Keep the processed jobs count. + var processed atomic.Int32 + require.NoError(t, ForEachJob(context.Background(), 0, 2, func(ctx context.Context, idx int) error { + processed.Inc() + return nil + })) + require.Zero(t, processed.Load()) +} + func TestForEach(t *testing.T) { var ( - ctx = context.Background() - // Keep track of processed jobs. processedMx sync.Mutex processed []string @@ -83,7 +157,7 @@ func TestForEach(t *testing.T) { jobs := []string{"a", "b", "c"} - err := ForEach(ctx, CreateJobsFromStrings(jobs), 2, func(ctx context.Context, job interface{}) error { + err := ForEach(context.Background(), CreateJobsFromStrings(jobs), 2, func(ctx context.Context, job interface{}) error { processedMx.Lock() defer processedMx.Unlock() processed = append(processed, job.(string)) @@ -126,18 +200,14 @@ func TestForEach_ShouldBreakOnFirstError_ContextCancellationHandled(t *testing.T } func TestForEach_ShouldBreakOnFirstError_ContextCancellationUnhandled(t *testing.T) { - var ( - ctx = context.Background() - - // Keep the processed jobs count. - processed atomic.Int32 - ) + // Keep the processed jobs count. + var processed atomic.Int32 // waitGroup to await the start of the first two jobs var wg sync.WaitGroup wg.Add(2) - err := ForEach(ctx, []interface{}{"a", "b", "c"}, 2, func(ctx context.Context, job interface{}) error { + err := ForEach(context.Background(), []interface{}{"a", "b", "c"}, 2, func(ctx context.Context, job interface{}) error { wg.Done() if processed.CAS(0, 1) {