From aae6d931144e3fc95e40797155ccd9e502d4c061 Mon Sep 17 00:00:00 2001 From: Lz Date: Mon, 11 Mar 2024 10:24:26 +0800 Subject: [PATCH] feat(waitn): added WaitN --- CHANGELOG.md | 4 ++ async.go | 9 ++- awaiter.go | 72 ++++++++++++--------- awaiter_test.go | 168 ++++++++++++++++++++++++++++++++++++++++++------ errors.go | 9 --- 5 files changed, 203 insertions(+), 59 deletions(-) delete mode 100644 errors.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b24c0e..5abeac1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,5 +5,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [1.0.1] - 2024-03-11 +- added `WaitN` (#1) + ## [1.0.0] - 2024-03-08 - 1st release diff --git a/async.go b/async.go index 663b23a..0da538f 100644 --- a/async.go +++ b/async.go @@ -1,6 +1,13 @@ package async -import "context" +import ( + "context" + "errors" +) + +var ( + ErrTooLessDone = errors.New("async: too less tasks to completed without error") +) func New[T any](tasks ...func(ctx context.Context) (T, error)) Awaiter[T] { return &awaiter[T]{ diff --git a/awaiter.go b/awaiter.go index 3f69677..c085182 100644 --- a/awaiter.go +++ b/awaiter.go @@ -5,9 +5,14 @@ import ( ) type Awaiter[T any] interface { + // Add add a task Add(task func(context.Context) (T, error)) - Wait(context.Context) ([]T, error) - WaitAny(context.Context) (T, error) + // Wait wail for all tasks to completed + Wait(context.Context) ([]T, error, []error) + // WaitAny wait for any task to completed without error, can cancel other tasks + WaitAny(context.Context) (T, error, []error) + // WaitN wait for N tasks to completed without error + WaitN(context.Context, int) ([]T, error, []error) } type awaiter[T any] struct { @@ -18,10 +23,9 @@ func (a *awaiter[T]) Add(task func(ctx context.Context) (T, error)) { a.tasks = append(a.tasks, task) } -func (a *awaiter[T]) Wait(ctx context.Context) ([]T, error) { +func (a *awaiter[T]) Wait(ctx context.Context) ([]T, error, []error) { wait := make(chan Result[T]) - n := len(a.tasks) for _, task := range a.tasks { go func(task func(context.Context) (T, error)) { r, err := task(ctx) @@ -33,35 +37,29 @@ func (a *awaiter[T]) Wait(ctx context.Context) ([]T, error) { } var r Result[T] - var es Errors + var taskErrs []error var items []T - for i := 0; i < n; i++ { + tt := len(a.tasks) + for i := 0; i < tt; i++ { select { case r = <-wait: if r.Error != nil { - es = append(es, r.Error) + taskErrs = append(taskErrs, r.Error) } else { items = append(items, r.Data) } case <-ctx.Done(): - return items, ctx.Err() + return items, ctx.Err(), taskErrs } - - } - - if len(es) > 0 { - return items, es } - return items, nil + return items, nil, taskErrs } -func (a *awaiter[T]) WaitAny(ctx context.Context) (T, error) { - - n := len(a.tasks) - +func (a *awaiter[T]) WaitN(ctx context.Context, n int) ([]T, error, []error) { wait := make(chan Result[T]) + cancelCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -75,27 +73,39 @@ func (a *awaiter[T]) WaitAny(ctx context.Context) (T, error) { }(task) } - var t T - var r Result[T] - var es Errors - - for i := 0; i < n; i++ { + var taskErrs []error + var items []T + tt := len(a.tasks) + var done int + for i := 0; i < tt; i++ { select { case r = <-wait: - if r.Error == nil { - return r.Data, nil + if r.Error != nil { + taskErrs = append(taskErrs, r.Error) + } else { + items = append(items, r.Data) + done++ + if done == n { + return items, nil, taskErrs + } } - - es = append(es, r.Error) case <-ctx.Done(): - return t, ctx.Err() + return items, ctx.Err(), taskErrs } + } - if len(es) > 0 { - return t, es + return items, ErrTooLessDone, taskErrs +} + +func (a *awaiter[T]) WaitAny(ctx context.Context) (T, error, []error) { + var t T + result, err, taskErrs := a.WaitN(ctx, 1) + + if len(result) == 1 { + t = result[0] } - return t, nil + return t, err, taskErrs } diff --git a/awaiter_test.go b/awaiter_test.go index 1a70893..645bf8c 100644 --- a/awaiter_test.go +++ b/awaiter_test.go @@ -13,7 +13,7 @@ import ( func TestWait(t *testing.T) { wantedErr := errors.New("wanted") - var wantedErrs error = Errors([]error{wantedErr}) + wantedErrs := []error{wantedErr} tests := []struct { name string @@ -21,7 +21,8 @@ func TestWait(t *testing.T) { withCancel bool setup func() Awaiter[int] wantedResult []int - wantedError error + wantedErr error + wantedErrs []error }{ { name: "wait_should_work", @@ -40,7 +41,7 @@ func TestWait(t *testing.T) { return a }, wantedResult: []int{1, 2, 3}, - wantedError: nil, + wantedErr: nil, }, { name: "error_should_work", @@ -55,7 +56,7 @@ func TestWait(t *testing.T) { }) }, wantedResult: []int{1, 2}, - wantedError: wantedErrs, + wantedErrs: wantedErrs, }, { name: "context_should_work", @@ -75,7 +76,7 @@ func TestWait(t *testing.T) { }) }, wantedResult: []int{3}, - wantedError: context.DeadlineExceeded, + wantedErr: context.DeadlineExceeded, }, { name: "cancel_should_work", @@ -95,7 +96,7 @@ func TestWait(t *testing.T) { }, withCancel: true, wantedResult: []int{3}, - wantedError: context.Canceled, + wantedErr: context.Canceled, }, } @@ -104,6 +105,7 @@ func TestWait(t *testing.T) { a := test.setup() var result []int var err error + var taskErrs []error if test.withCancel { ctx, cancel := context.WithCancel(context.Background()) @@ -111,15 +113,16 @@ func TestWait(t *testing.T) { time.Sleep(1 * time.Second) cancel() }() - result, err = a.Wait(ctx) + result, err, taskErrs = a.Wait(ctx) } else { - result, err = a.Wait(test.ctx()) + result, err, taskErrs = a.Wait(test.ctx()) } slices.Sort(result) require.Equal(t, test.wantedResult, result) - require.Equal(t, test.wantedError, err) + require.Equal(t, test.wantedErr, err) + require.Equal(t, test.wantedErrs, taskErrs) }) @@ -136,7 +139,8 @@ func TestWaitAny(t *testing.T) { withCancel bool setup func() Awaiter[int] wantedResult int - wantedError error + wantedErr error + wantedErrs []error }{ { name: "1st_should_work", @@ -211,6 +215,7 @@ func TestWaitAny(t *testing.T) { }) }, wantedResult: 1, + wantedErrs: []error{wantedErr, wantedErr}, }, { name: "fastest_should_work", @@ -240,7 +245,8 @@ func TestWaitAny(t *testing.T) { return 0, wantedErr }) }, - wantedError: Errors([]error{wantedErr, wantedErr, wantedErr}), + wantedErr: ErrTooLessDone, + wantedErrs: []error{wantedErr, wantedErr, wantedErr}, }, { name: "error_should_work", @@ -250,7 +256,8 @@ func TestWaitAny(t *testing.T) { return 0, wantedErr }) }, - wantedError: Errors([]error{wantedErr}), + wantedErr: ErrTooLessDone, + wantedErrs: []error{wantedErr}, }, { name: "context_should_work", @@ -270,7 +277,7 @@ func TestWaitAny(t *testing.T) { return 3, nil }) }, - wantedError: context.DeadlineExceeded, + wantedErr: context.DeadlineExceeded, }, { name: "cancel_should_work", @@ -289,8 +296,8 @@ func TestWaitAny(t *testing.T) { return 3, nil }) }, - withCancel: true, - wantedError: context.Canceled, + withCancel: true, + wantedErr: context.Canceled, }, } @@ -299,6 +306,7 @@ func TestWaitAny(t *testing.T) { a := test.setup() var result int var err error + var taskErrs []error if test.withCancel { ctx, cancel := context.WithCancel(context.Background()) @@ -306,13 +314,137 @@ func TestWaitAny(t *testing.T) { time.Sleep(1 * time.Second) cancel() }() - result, err = a.WaitAny(ctx) + result, err, taskErrs = a.WaitAny(ctx) } else { - result, err = a.WaitAny(test.ctx()) + result, err, taskErrs = a.WaitAny(test.ctx()) } require.Equal(t, test.wantedResult, result) - require.Equal(t, test.wantedError, err) + require.Equal(t, test.wantedErr, err) + require.Equal(t, test.wantedErrs, taskErrs) + }) + + } +} + +func TestWaitN(t *testing.T) { + + wantedErr := errors.New("wanted") + + tests := []struct { + name string + ctx func() context.Context + withCancel bool + setup func() Awaiter[int] + wantedN int + wantedResult []int + wantedErr error + wantedErrs []error + }{ + { + name: "wait_n_should_work", + ctx: func() context.Context { return context.Background() }, + setup: func() Awaiter[int] { + a := New[int](func(ctx context.Context) (int, error) { + return 1, nil + }, func(ctx context.Context) (int, error) { + return 2, nil + }) + + a.Add(func(ctx context.Context) (int, error) { + time.Sleep(1 * time.Second) + return 3, nil + }) + + return a + }, + wantedN: 2, + wantedResult: []int{1, 2}, + }, + { + name: "error_n_should_work", + ctx: func() context.Context { return context.Background() }, + setup: func() Awaiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(1 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(1 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }) + }, + wantedN: 2, + wantedResult: []int{1, 2}, + wantedErrs: []error{wantedErr}, + }, + { + name: "context_should_work", + ctx: func() context.Context { + ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) //nolint + return ctx + }, + setup: func() Awaiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + return 3, nil + }) + }, + wantedResult: []int{3}, + wantedErr: context.DeadlineExceeded, + }, + { + name: "cancel_should_work", + ctx: func() context.Context { + return context.Background() + }, + setup: func() Awaiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + return 3, nil + }) + }, + withCancel: true, + wantedResult: []int{3}, + wantedErr: context.Canceled, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + a := test.setup() + var result []int + var err error + var taskErrs []error + + if test.withCancel { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + result, err, taskErrs = a.WaitN(ctx, test.wantedN) + } else { + result, err, taskErrs = a.WaitN(test.ctx(), test.wantedN) + } + + slices.Sort(result) + + require.Equal(t, test.wantedResult, result) + require.Equal(t, test.wantedErr, err) + require.Equal(t, test.wantedErrs, taskErrs) + }) } diff --git a/errors.go b/errors.go deleted file mode 100644 index 3399ce6..0000000 --- a/errors.go +++ /dev/null @@ -1,9 +0,0 @@ -package async - -import "fmt" - -type Errors []error - -func (es Errors) Error() string { - return fmt.Sprint([]error(es)) -}