Skip to content

Commit

Permalink
feat(waitn): added WaitN
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi committed Mar 11, 2024
1 parent a04de99 commit aae6d93
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 59 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,9 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [1.0.1] - 2024-03-11
- added `WaitN` (#1)

## [1.0.0] - 2024-03-08
- 1st release
9 changes: 8 additions & 1 deletion async.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package async

import "context"
import (
"context"
"errors"
)

var (
ErrTooLessDone = errors.New("async: too less tasks to completed without error")
)

func New[T any](tasks ...func(ctx context.Context) (T, error)) Awaiter[T] {
return &awaiter[T]{
Expand Down
72 changes: 41 additions & 31 deletions awaiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ import (
)

type Awaiter[T any] interface {
// Add add a task
Add(task func(context.Context) (T, error))
Wait(context.Context) ([]T, error)
WaitAny(context.Context) (T, error)
// Wait wail for all tasks to completed
Wait(context.Context) ([]T, error, []error)
// WaitAny wait for any task to completed without error, can cancel other tasks
WaitAny(context.Context) (T, error, []error)
// WaitN wait for N tasks to completed without error
WaitN(context.Context, int) ([]T, error, []error)
}

type awaiter[T any] struct {
Expand All @@ -18,10 +23,9 @@ func (a *awaiter[T]) Add(task func(ctx context.Context) (T, error)) {
a.tasks = append(a.tasks, task)
}

func (a *awaiter[T]) Wait(ctx context.Context) ([]T, error) {
func (a *awaiter[T]) Wait(ctx context.Context) ([]T, error, []error) {
wait := make(chan Result[T])

n := len(a.tasks)
for _, task := range a.tasks {
go func(task func(context.Context) (T, error)) {
r, err := task(ctx)
Expand All @@ -33,35 +37,29 @@ func (a *awaiter[T]) Wait(ctx context.Context) ([]T, error) {
}

var r Result[T]
var es Errors
var taskErrs []error
var items []T

for i := 0; i < n; i++ {
tt := len(a.tasks)
for i := 0; i < tt; i++ {
select {
case r = <-wait:
if r.Error != nil {
es = append(es, r.Error)
taskErrs = append(taskErrs, r.Error)
} else {
items = append(items, r.Data)
}
case <-ctx.Done():
return items, ctx.Err()
return items, ctx.Err(), taskErrs
}

}

if len(es) > 0 {
return items, es
}

return items, nil
return items, nil, taskErrs
}

func (a *awaiter[T]) WaitAny(ctx context.Context) (T, error) {

n := len(a.tasks)

func (a *awaiter[T]) WaitN(ctx context.Context, n int) ([]T, error, []error) {
wait := make(chan Result[T])

cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()

Expand All @@ -75,27 +73,39 @@ func (a *awaiter[T]) WaitAny(ctx context.Context) (T, error) {
}(task)
}

var t T

var r Result[T]
var es Errors

for i := 0; i < n; i++ {
var taskErrs []error
var items []T
tt := len(a.tasks)
var done int
for i := 0; i < tt; i++ {
select {
case r = <-wait:
if r.Error == nil {
return r.Data, nil
if r.Error != nil {
taskErrs = append(taskErrs, r.Error)
} else {
items = append(items, r.Data)
done++
if done == n {
return items, nil, taskErrs
}
}

es = append(es, r.Error)
case <-ctx.Done():
return t, ctx.Err()
return items, ctx.Err(), taskErrs
}

}

if len(es) > 0 {
return t, es
return items, ErrTooLessDone, taskErrs
}

func (a *awaiter[T]) WaitAny(ctx context.Context) (T, error, []error) {
var t T
result, err, taskErrs := a.WaitN(ctx, 1)

if len(result) == 1 {
t = result[0]
}

return t, nil
return t, err, taskErrs
}
Loading

0 comments on commit aae6d93

Please sign in to comment.