diff --git a/.deepsource.toml b/.deepsource.toml index 3374da1..99b3aae 100644 --- a/.deepsource.toml +++ b/.deepsource.toml @@ -1,5 +1,9 @@ version = 1 +test_patterns = ["*_test.go"] + +exclude_patterns = ["*_test.go"] + [[analyzers]] name = "go" @@ -7,4 +11,6 @@ name = "go" import_root = "github.com/yaitoo/async" [[transformers]] -name = "gofmt" \ No newline at end of file +name = "gofmt" + + diff --git a/CHANGELOG.md b/CHANGELOG.md index a3b092a..d98becd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.0.4] - 2024-03-18 +- added `Action` support (#4) ## [1.0.3] - 2024-03-12 - added `WaitN` (#1) diff --git a/README.md b/README.md index 721bfc1..66cc543 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Async -Async is an asynchronous task package for Go. +Async is an async/await like task package for Go ![License](https://img.shields.io/badge/license-MIT-green.svg) [![Tests](https://github.com/yaitoo/async/actions/workflows/tests.yml/badge.svg)](https://github.com/yaitoo/async/actions/workflows/tests.yml) @@ -10,12 +10,12 @@ Async is an asynchronous task package for Go. ## Features -- Wait/WaitAny/WaitN +- Wait/WaitAny/WaitN for `Task` and `Action` - `context.Context` with `timeout`, `cancel` support - Works with generic instead of `interface{}` ## Tutorials -see more examples on [tests](./awaiter_test.go) or [go.dev](https://go.dev/play/p/7jgcRltbwts) +see more examples on [tasks](./waiter_test.go), [actions](./awaiter_test.go) or [go.dev](https://go.dev/play/p/7jgcRltbwts) ### Install async - install latest commit from `main` branch diff --git a/async.go b/async.go index 0da538f..1cb9d56 100644 --- a/async.go +++ b/async.go @@ -6,11 +6,25 @@ import ( ) var ( - ErrTooLessDone = errors.New("async: too less tasks to completed without error") + ErrTooLessDone = errors.New("async: too less tasks/actions to completed without error") ) -func New[T any](tasks ...func(ctx context.Context) (T, error)) Awaiter[T] { - return &awaiter[T]{ +// Task a task with result T +type Task[T any] func(ctx context.Context) (T, error) + +// New create a task waiter +func New[T any](tasks ...Task[T]) Waiter[T] { + return &waiter[T]{ tasks: tasks, } } + +// Action a task without result +type Action func(ctx context.Context) error + +// NewA create an action awaiter +func NewA(actions ...Action) Awaiter { + return &awaiter{ + actions: actions, + } +} diff --git a/awaiter.go b/awaiter.go index ea9a797..c168378 100644 --- a/awaiter.go +++ b/awaiter.go @@ -4,112 +4,94 @@ import ( "context" ) -type Awaiter[T any] interface { - // Add add a task - Add(task func(context.Context) (T, error)) - // Wait wail for all tasks to completed - Wait(context.Context) ([]T, []error, error) - // WaitAny wait for any task to completed without error, can cancel other tasks - WaitAny(context.Context) (T, []error, error) - // WaitN wait for N tasks to completed without error - WaitN(context.Context, int) ([]T, []error, error) +type Awaiter interface { + // Add add an action + Add(action Action) + // Wait wail for all actions to completed + Wait(context.Context) ([]error, error) + // WaitAny wait for any action to completed without error, can cancel other tasks + WaitAny(context.Context) ([]error, error) + // WaitN wait for N actions to completed without error + WaitN(context.Context, int) ([]error, error) } -type awaiter[T any] struct { - tasks []func(context.Context) (T, error) +type awaiter struct { + actions []Action } -func (a *awaiter[T]) Add(task func(ctx context.Context) (T, error)) { - a.tasks = append(a.tasks, task) +func (a *awaiter) Add(action Action) { + a.actions = append(a.actions, action) } -func (a *awaiter[T]) Wait(ctx context.Context) ([]T, []error, error) { - wait := make(chan Result[T]) +func (a *awaiter) Wait(ctx context.Context) ([]error, error) { + wait := make(chan error) - for _, task := range a.tasks { - go func(task func(context.Context) (T, error)) { - r, err := task(ctx) - wait <- Result[T]{ - Data: r, - Error: err, - } - }(task) + for _, action := range a.actions { + go func(action Action) { + + wait <- action(ctx) + }(action) } - var r Result[T] var taskErrs []error - var items []T - tt := len(a.tasks) + tt := len(a.actions) for i := 0; i < tt; i++ { select { - case r = <-wait: - if r.Error != nil { - taskErrs = append(taskErrs, r.Error) - } else { - items = append(items, r.Data) + case err := <-wait: + if err != nil { + taskErrs = append(taskErrs, err) } case <-ctx.Done(): - return items, taskErrs, ctx.Err() + return taskErrs, ctx.Err() } } - if len(items) == tt { - return items, taskErrs, nil + if len(taskErrs) > 0 { + return taskErrs, ErrTooLessDone } - return items, taskErrs, ErrTooLessDone + return taskErrs, nil } -func (a *awaiter[T]) WaitN(ctx context.Context, n int) ([]T, []error, error) { - wait := make(chan Result[T]) +func (a *awaiter) WaitN(ctx context.Context, n int) ([]error, error) { + wait := make(chan error) cancelCtx, cancel := context.WithCancel(ctx) defer cancel() - for _, task := range a.tasks { - go func(task func(context.Context) (T, error)) { - r, err := task(cancelCtx) - wait <- Result[T]{ - Data: r, - Error: err, - } - }(task) + for _, action := range a.actions { + go func(action Action) { + wait <- action(cancelCtx) + + }(action) } - var r Result[T] var taskErrs []error - var items []T - tt := len(a.tasks) + tt := len(a.actions) + var done int for i := 0; i < tt; i++ { select { - case r = <-wait: - if r.Error != nil { - taskErrs = append(taskErrs, r.Error) + case err := <-wait: + if err != nil { + taskErrs = append(taskErrs, err) } else { - items = append(items, r.Data) + done++ if done == n { - return items, taskErrs, nil + return taskErrs, nil } } case <-ctx.Done(): - return items, taskErrs, ctx.Err() + return taskErrs, ctx.Err() } } - return items, taskErrs, ErrTooLessDone + return taskErrs, ErrTooLessDone } -func (a *awaiter[T]) WaitAny(ctx context.Context) (T, []error, error) { - var t T - result, err, taskErrs := a.WaitN(ctx, 1) - - if len(result) == 1 { - t = result[0] - } - - return t, err, taskErrs +func (a *awaiter) WaitAny(ctx context.Context) ([]error, error) { + return a.WaitN(ctx, 1) } diff --git a/awaiter_test.go b/awaiter_test.go index 28fd2ce..5e37be4 100644 --- a/awaiter_test.go +++ b/awaiter_test.go @@ -3,72 +3,68 @@ package async import ( "context" "errors" - "slices" "testing" "time" "github.com/stretchr/testify/require" ) -func TestWait(t *testing.T) { +func TestAwait(t *testing.T) { wantedErr := errors.New("wanted") - wantedErrs := []error{wantedErr} tests := []struct { - name string - ctx func() context.Context - withCancel bool - setup func() Awaiter[int] - wantedResult []int - wantedErr error - wantedErrs []error + name string + ctx func() context.Context + withCancel bool + setup func() Awaiter + wantedErr error + wantedErrs []error }{ { name: "wait_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - a := New[int](func(ctx context.Context) (int, error) { - return 1, nil - }, func(ctx context.Context) (int, error) { - return 2, nil + setup: func() Awaiter { + a := NewA(func(ctx context.Context) error { + return nil + }, func(ctx context.Context) error { + return nil }) - a.Add(func(ctx context.Context) (int, error) { - return 3, nil + a.Add(func(ctx context.Context) error { + return nil }) return a }, - wantedResult: []int{1, 2, 3}, - wantedErr: nil, + + wantedErr: nil, }, { name: "error_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { - return 1, nil - }, func(ctx context.Context) (int, error) { - return 2, nil - }, func(ctx context.Context) (int, error) { - return 0, wantedErr + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { + return nil + }, func(ctx context.Context) error { + return nil + }, func(ctx context.Context) error { + return wantedErr }) }, - wantedResult: []int{1, 2}, - wantedErr: ErrTooLessDone, - wantedErrs: wantedErrs, + wantedErr: ErrTooLessDone, + wantedErrs: []error{wantedErr}, }, { name: "errors_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { - return 0, wantedErr - }, func(ctx context.Context) (int, error) { - return 0, wantedErr - }, func(ctx context.Context) (int, error) { - return 0, wantedErr + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { + return wantedErr + }, func(ctx context.Context) error { + return wantedErr + }, func(ctx context.Context) error { + return wantedErr }) }, wantedErr: ErrTooLessDone, @@ -80,44 +76,43 @@ func TestWait(t *testing.T) { ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) //nolint return ctx }, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { - return 3, nil + return nil + }, func(ctx context.Context) error { + return nil }) }, - wantedResult: []int{3}, - wantedErr: context.DeadlineExceeded, + + wantedErr: context.DeadlineExceeded, }, { name: "cancel_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { - return 3, nil + return nil + }, func(ctx context.Context) error { + return nil }) }, - withCancel: true, - wantedResult: []int{3}, - wantedErr: context.Canceled, + withCancel: true, + wantedErr: context.Canceled, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { a := test.setup() - var result []int + var err error var taskErrs []error @@ -127,14 +122,11 @@ func TestWait(t *testing.T) { time.Sleep(1 * time.Second) cancel() }() - result, taskErrs, err = a.Wait(ctx) + taskErrs, err = a.Wait(ctx) } else { - result, taskErrs, err = a.Wait(test.ctx()) + taskErrs, err = a.Wait(test.ctx()) } - slices.Sort(result) - - require.Equal(t, test.wantedResult, result) require.Equal(t, test.wantedErr, err) require.Equal(t, test.wantedErrs, taskErrs) @@ -143,114 +135,108 @@ func TestWait(t *testing.T) { } } -func TestWaitAny(t *testing.T) { +func TestAwaitAny(t *testing.T) { wantedErr := errors.New("wanted") tests := []struct { - name string - ctx func() context.Context - withCancel bool - setup func() Awaiter[int] - wantedResult int - wantedErr error - wantedErrs []error + name string + ctx func() context.Context + withCancel bool + setup func() Awaiter + wantedErr error + wantedErrs []error }{ { name: "1st_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - a := New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + a := NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 3, nil + return nil }) - a.Add(func(ctx context.Context) (int, error) { - return 1, nil + a.Add(func(ctx context.Context) error { + return nil }) return a }, - wantedResult: 1, }, { name: "2nd_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { - return 2, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 3, nil + return nil }) }, - wantedResult: 2, }, { name: "3rd_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { - return 3, nil + return nil }) }, - wantedResult: 3, }, { name: "slowest_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(3 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { - return 0, wantedErr - }, func(ctx context.Context) (int, error) { - return 0, wantedErr + return nil + }, func(ctx context.Context) error { + return wantedErr + }, func(ctx context.Context) error { + return wantedErr }) }, - wantedResult: 1, - wantedErrs: []error{wantedErr, wantedErr}, + wantedErrs: []error{wantedErr, wantedErr}, }, { name: "fastest_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { - return 1, nil - }, func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { + return nil + }, func(ctx context.Context) error { time.Sleep(2 * time.Second) - return 0, wantedErr - }, func(ctx context.Context) (int, error) { + return wantedErr + }, func(ctx context.Context) error { time.Sleep(3 * time.Second) - return 0, wantedErr + return wantedErr }) }, - wantedResult: 1, }, { name: "errors_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { - return 0, wantedErr - }, func(ctx context.Context) (int, error) { - return 0, wantedErr - }, func(ctx context.Context) (int, error) { - return 0, wantedErr + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { + return wantedErr + }, func(ctx context.Context) error { + return wantedErr + }, func(ctx context.Context) error { + return wantedErr }) }, wantedErr: ErrTooLessDone, @@ -259,9 +245,9 @@ func TestWaitAny(t *testing.T) { { name: "error_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { - return 0, wantedErr + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { + return wantedErr }) }, wantedErr: ErrTooLessDone, @@ -273,16 +259,16 @@ func TestWaitAny(t *testing.T) { ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) //nolint return ctx }, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 3, nil + return nil }) }, wantedErr: context.DeadlineExceeded, @@ -290,16 +276,16 @@ func TestWaitAny(t *testing.T) { { name: "cancel_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 3, nil + return nil }) }, withCancel: true, @@ -310,7 +296,7 @@ func TestWaitAny(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { a := test.setup() - var result int + var err error var taskErrs []error @@ -320,12 +306,11 @@ func TestWaitAny(t *testing.T) { time.Sleep(1 * time.Second) cancel() }() - result, taskErrs, err = a.WaitAny(ctx) + taskErrs, err = a.WaitAny(ctx) } else { - result, taskErrs, err = a.WaitAny(test.ctx()) + taskErrs, err = a.WaitAny(test.ctx()) } - require.Equal(t, test.wantedResult, result) require.Equal(t, test.wantedErr, err) require.Equal(t, test.wantedErrs, taskErrs) }) @@ -333,57 +318,54 @@ func TestWaitAny(t *testing.T) { } } -func TestWaitN(t *testing.T) { +func TestAwaitN(t *testing.T) { wantedErr := errors.New("wanted") tests := []struct { - name string - ctx func() context.Context - withCancel bool - setup func() Awaiter[int] - wantedN int - wantedResult []int - wantedErr error - wantedErrs []error + name string + ctx func() context.Context + withCancel bool + setup func() Awaiter + wantedN int + wantedErr error + wantedErrs []error }{ { name: "wait_n_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - a := New[int](func(ctx context.Context) (int, error) { - return 1, nil - }, func(ctx context.Context) (int, error) { - return 2, nil + setup: func() Awaiter { + a := NewA(func(ctx context.Context) error { + return nil + }, func(ctx context.Context) error { + return nil }) - a.Add(func(ctx context.Context) (int, error) { + a.Add(func(ctx context.Context) error { time.Sleep(1 * time.Second) - return 3, nil + return nil }) return a }, - wantedN: 2, - wantedResult: []int{1, 2}, + wantedN: 2, }, { name: "error_n_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(1 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(1 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { - return 0, wantedErr + return nil + }, func(ctx context.Context) error { + return wantedErr }) }, - wantedN: 2, - wantedResult: []int{1, 2}, - wantedErrs: []error{wantedErr}, + wantedN: 2, + wantedErrs: []error{wantedErr}, }, { name: "context_should_work", @@ -391,44 +373,42 @@ func TestWaitN(t *testing.T) { ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) //nolint return ctx }, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { - return 3, nil + return nil + }, func(ctx context.Context) error { + return nil }) }, - wantedResult: []int{3}, - wantedErr: context.DeadlineExceeded, + wantedErr: context.DeadlineExceeded, }, { name: "cancel_should_work", ctx: context.Background, - setup: func() Awaiter[int] { - return New[int](func(ctx context.Context) (int, error) { + setup: func() Awaiter { + return NewA(func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 1, nil - }, func(ctx context.Context) (int, error) { + return nil + }, func(ctx context.Context) error { time.Sleep(5 * time.Second) - return 2, nil - }, func(ctx context.Context) (int, error) { - return 3, nil + return nil + }, func(ctx context.Context) error { + return nil }) }, - withCancel: true, - wantedResult: []int{3}, - wantedErr: context.Canceled, + withCancel: true, + + wantedErr: context.Canceled, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { a := test.setup() - var result []int var err error var taskErrs []error @@ -438,14 +418,11 @@ func TestWaitN(t *testing.T) { time.Sleep(1 * time.Second) cancel() }() - result, taskErrs, err = a.WaitN(ctx, test.wantedN) + taskErrs, err = a.WaitN(ctx, test.wantedN) } else { - result, taskErrs, err = a.WaitN(test.ctx(), test.wantedN) + taskErrs, err = a.WaitN(test.ctx(), test.wantedN) } - slices.Sort(result) - - require.Equal(t, test.wantedResult, result) require.Equal(t, test.wantedErr, err) require.Equal(t, test.wantedErrs, taskErrs) diff --git a/waiter.go b/waiter.go new file mode 100644 index 0000000..83c337f --- /dev/null +++ b/waiter.go @@ -0,0 +1,115 @@ +package async + +import ( + "context" +) + +type Waiter[T any] interface { + // Add add a task + Add(task Task[T]) + // Wait wail for all tasks to completed + Wait(context.Context) ([]T, []error, error) + // WaitAny wait for any task to completed without error, can cancel other tasks + WaitAny(context.Context) (T, []error, error) + // WaitN wait for N tasks to completed without error + WaitN(context.Context, int) ([]T, []error, error) +} + +type waiter[T any] struct { + tasks []Task[T] +} + +func (a *waiter[T]) Add(task Task[T]) { + a.tasks = append(a.tasks, task) +} + +func (a *waiter[T]) Wait(ctx context.Context) ([]T, []error, error) { + wait := make(chan Result[T]) + + for _, task := range a.tasks { + go func(task func(context.Context) (T, error)) { + r, err := task(ctx) + wait <- Result[T]{ + Data: r, + Error: err, + } + }(task) + } + + var r Result[T] + var taskErrs []error + var items []T + + tt := len(a.tasks) + for i := 0; i < tt; i++ { + select { + case r = <-wait: + if r.Error != nil { + taskErrs = append(taskErrs, r.Error) + } else { + items = append(items, r.Data) + } + case <-ctx.Done(): + return items, taskErrs, ctx.Err() + } + } + + if len(items) == tt { + return items, taskErrs, nil + } + + return items, taskErrs, ErrTooLessDone +} + +func (a *waiter[T]) WaitN(ctx context.Context, n int) ([]T, []error, error) { + wait := make(chan Result[T]) + + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + + for _, task := range a.tasks { + go func(task func(context.Context) (T, error)) { + r, err := task(cancelCtx) + wait <- Result[T]{ + Data: r, + Error: err, + } + }(task) + } + + var r Result[T] + var taskErrs []error + var items []T + tt := len(a.tasks) + var done int + for i := 0; i < tt; i++ { + select { + case r = <-wait: + if r.Error != nil { + taskErrs = append(taskErrs, r.Error) + } else { + items = append(items, r.Data) + done++ + if done == n { + return items, taskErrs, nil + } + } + case <-ctx.Done(): + return items, taskErrs, ctx.Err() + } + + } + + return items, taskErrs, ErrTooLessDone +} + +func (a *waiter[T]) WaitAny(ctx context.Context) (T, []error, error) { + var t T + result, taskErrs, err := a.WaitN(ctx, 1) + + if len(result) == 1 { + t = result[0] + } + + return t, taskErrs, err +} diff --git a/waiter_test.go b/waiter_test.go new file mode 100644 index 0000000..1843c37 --- /dev/null +++ b/waiter_test.go @@ -0,0 +1,455 @@ +package async + +import ( + "context" + "errors" + "slices" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWait(t *testing.T) { + + wantedErr := errors.New("wanted") + wantedErrs := []error{wantedErr} + + tests := []struct { + name string + ctx func() context.Context + withCancel bool + setup func() Waiter[int] + wantedResult []int + wantedErr error + wantedErrs []error + }{ + { + name: "wait_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + a := New[int](func(ctx context.Context) (int, error) { + return 1, nil + }, func(ctx context.Context) (int, error) { + return 2, nil + }) + + a.Add(func(ctx context.Context) (int, error) { + return 3, nil + }) + + return a + }, + wantedResult: []int{1, 2, 3}, + wantedErr: nil, + }, + { + name: "error_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + return 1, nil + }, func(ctx context.Context) (int, error) { + return 2, nil + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }) + }, + wantedResult: []int{1, 2}, + wantedErr: ErrTooLessDone, + wantedErrs: wantedErrs, + }, + { + name: "errors_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + return 0, wantedErr + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }) + }, + wantedErr: ErrTooLessDone, + wantedErrs: []error{wantedErr, wantedErr, wantedErr}, + }, + { + name: "context_should_work", + ctx: func() context.Context { + ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) //nolint + return ctx + }, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + return 3, nil + }) + }, + wantedResult: []int{3}, + wantedErr: context.DeadlineExceeded, + }, + { + name: "cancel_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + return 3, nil + }) + }, + withCancel: true, + wantedResult: []int{3}, + wantedErr: context.Canceled, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + a := test.setup() + var result []int + var err error + var taskErrs []error + + if test.withCancel { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + result, taskErrs, err = a.Wait(ctx) + } else { + result, taskErrs, err = a.Wait(test.ctx()) + } + + slices.Sort(result) + + require.Equal(t, test.wantedResult, result) + require.Equal(t, test.wantedErr, err) + require.Equal(t, test.wantedErrs, taskErrs) + + }) + + } +} + +func TestWaitAny(t *testing.T) { + + wantedErr := errors.New("wanted") + + tests := []struct { + name string + ctx func() context.Context + withCancel bool + setup func() Waiter[int] + wantedResult int + wantedErr error + wantedErrs []error + }{ + { + name: "1st_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + a := New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 3, nil + }) + + a.Add(func(ctx context.Context) (int, error) { + return 1, nil + }) + + return a + }, + wantedResult: 1, + }, + { + name: "2nd_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + return 2, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 3, nil + }) + }, + wantedResult: 2, + }, + { + name: "3rd_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + + return 3, nil + }) + }, + wantedResult: 3, + }, + { + name: "slowest_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(3 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }) + }, + wantedResult: 1, + wantedErrs: []error{wantedErr, wantedErr}, + }, + { + name: "fastest_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(2 * time.Second) + return 0, wantedErr + }, func(ctx context.Context) (int, error) { + time.Sleep(3 * time.Second) + return 0, wantedErr + }) + }, + wantedResult: 1, + }, + { + name: "errors_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + return 0, wantedErr + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }) + }, + wantedErr: ErrTooLessDone, + wantedErrs: []error{wantedErr, wantedErr, wantedErr}, + }, + { + name: "error_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + return 0, wantedErr + }) + }, + wantedErr: ErrTooLessDone, + wantedErrs: []error{wantedErr}, + }, + { + name: "context_should_work", + ctx: func() context.Context { + ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) //nolint + return ctx + }, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 3, nil + }) + }, + wantedErr: context.DeadlineExceeded, + }, + { + name: "cancel_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 3, nil + }) + }, + withCancel: true, + wantedErr: context.Canceled, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + a := test.setup() + var result int + var err error + var taskErrs []error + + if test.withCancel { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + result, taskErrs, err = a.WaitAny(ctx) + } else { + result, taskErrs, err = a.WaitAny(test.ctx()) + } + + require.Equal(t, test.wantedResult, result) + require.Equal(t, test.wantedErr, err) + require.Equal(t, test.wantedErrs, taskErrs) + }) + + } +} + +func TestWaitN(t *testing.T) { + + wantedErr := errors.New("wanted") + + tests := []struct { + name string + ctx func() context.Context + withCancel bool + setup func() Waiter[int] + wantedN int + wantedResult []int + wantedErr error + wantedErrs []error + }{ + { + name: "wait_n_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + a := New[int](func(ctx context.Context) (int, error) { + return 1, nil + }, func(ctx context.Context) (int, error) { + return 2, nil + }) + + a.Add(func(ctx context.Context) (int, error) { + time.Sleep(1 * time.Second) + return 3, nil + }) + + return a + }, + wantedN: 2, + wantedResult: []int{1, 2}, + }, + { + name: "error_n_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(1 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(1 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + return 0, wantedErr + }) + }, + wantedN: 2, + wantedResult: []int{1, 2}, + wantedErrs: []error{wantedErr}, + }, + { + name: "context_should_work", + ctx: func() context.Context { + ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) //nolint + return ctx + }, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + return 3, nil + }) + }, + wantedResult: []int{3}, + wantedErr: context.DeadlineExceeded, + }, + { + name: "cancel_should_work", + ctx: context.Background, + setup: func() Waiter[int] { + return New[int](func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 1, nil + }, func(ctx context.Context) (int, error) { + time.Sleep(5 * time.Second) + return 2, nil + }, func(ctx context.Context) (int, error) { + return 3, nil + }) + }, + withCancel: true, + wantedResult: []int{3}, + wantedErr: context.Canceled, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + a := test.setup() + var result []int + var err error + var taskErrs []error + + if test.withCancel { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(1 * time.Second) + cancel() + }() + result, taskErrs, err = a.WaitN(ctx, test.wantedN) + } else { + result, taskErrs, err = a.WaitN(test.ctx(), test.wantedN) + } + + slices.Sort(result) + + require.Equal(t, test.wantedResult, result) + require.Equal(t, test.wantedErr, err) + require.Equal(t, test.wantedErrs, taskErrs) + + }) + + } +}