From 3b3fc828ebf4c11413693abcc4388f90ebd7f0f7 Mon Sep 17 00:00:00 2001 From: "Alan D. Cabrera" Date: Sun, 21 Mar 2021 22:13:33 -0700 Subject: [PATCH] Fix BufferWithTimeOrCount buffer emission The observable created with BufferWithTimeOrCount() should emit buffers that are at most length count. The previous code erroneously used count simply as a trigger for emission, resulting in the possibility of emitting buffers larger than count. --- observable_operator.go | 47 ++++++--- observable_operator_test.go | 188 +++++++++++++++++++++++++++++++++++- 2 files changed, 219 insertions(+), 16 deletions(-) diff --git a/observable_operator.go b/observable_operator.go index 5277b4a6..6ee093df 100644 --- a/observable_operator.go +++ b/observable_operator.go @@ -511,8 +511,10 @@ func (o *ObservableImpl) BufferWithTime(timespan Duration, opts ...Option) Obser return customObservableOperator(o.parent, f, opts...) } -// BufferWithTimeOrCount returns an Observable that emits buffers of items it collects from the source -// Observable either from a given count or at a given time interval. +// BufferWithTimeOrCount returns an Observable that emits buffers, of max size +// count, of items it collects from the source Observable, or if the timespan +// has elapsed, whatever was collected in the buffer since the last emitted +// buffer. func (o *ObservableImpl) BufferWithTimeOrCount(timespan Duration, count int, opts ...Option) Observable { if timespan == nil { return Thrown(IllegalInputError{error: "timespan must no be nil"}) @@ -528,16 +530,37 @@ func (o *ObservableImpl) BufferWithTimeOrCount(timespan Duration, count int, opt send := make(chan struct{}) mutex := sync.Mutex{} - checkBuffer := func() { + // checkBuffer will send buffered units of at most size count, unless + // flush is true. + checkBuffer := func(flush bool) { mutex.Lock() - if len(buffer) != 0 { - if !Of(buffer).SendContext(ctx, next) { - mutex.Unlock() - return + defer mutex.Unlock() + + length := len(buffer) + if length != 0 { + var last int + defer func() { + // create a copy of buffer, less whatever was already sent + t := make([]interface{}, length-(last+1)) + copy(t, buffer[last+1:length]) + buffer = t + }() + + for i := 0; i < length; i += count { + high := i + count + if high > length { + if flush { + high = length + } else { + return + } + } + if !Of(buffer[i:high]).SendContext(ctx, next) { + return + } + last = high - 1 } - buffer = make([]interface{}, 0) } - mutex.Unlock() } go func() { @@ -546,14 +569,14 @@ func (o *ObservableImpl) BufferWithTimeOrCount(timespan Duration, count int, opt for { select { case <-send: - checkBuffer() + checkBuffer(false) case <-stop: - checkBuffer() + checkBuffer(true) return case <-ctx.Done(): return case <-time.After(duration): - checkBuffer() + checkBuffer(true) } } }() diff --git a/observable_operator_test.go b/observable_operator_test.go index fa9ac9cb..b6980994 100644 --- a/observable_operator_test.go +++ b/observable_operator_test.go @@ -323,6 +323,186 @@ func Test_Observable_BufferWithTimeOrCount(t *testing.T) { })) } +func Test_Observable_BufferWithTimeOrCount_DoesNotExceedBufferSize(t *testing.T) { + observable := Range(1, 5) + buffers := observable.BufferWithTimeOrCount(WithDuration(time.Millisecond*500), 2) + var seen int + Assert(context.Background(), t, buffers, CustomPredicate(func(items []interface{}) error { + for _, item := range items { + buffer := item.([]interface{}) + if len(buffer) > 2 { + return errors.New("items should not be greater than two") + } + for _, entry := range buffer { + value := entry.(int) + if value-1 != seen { + return fmt.Errorf("items should consecutive, %d does not follow %d", value, seen) + } + seen = value + } + } + return nil + })) +} + +func delayItem(ctx context.Context, item interface{}) (interface{}, error) { + entry := item.(struct { + name string + msDelay int + }) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(entry.msDelay) * time.Millisecond): + return entry.name, nil + } +} + +func Test_Observable_BufferWithTimeOrCount_ElapsedCapturesPartial(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + observable := Just([]struct { + name string + msDelay int + }{ + {"a", 5}, + {"b", 1000}, + })().Map(delayItem, WithContext(ctx)) + + buffers := observable.BufferWithTimeOrCount(WithDuration(time.Millisecond*10), 2, WithContext(ctx)) + + Assert(ctx, t, buffers, CustomPredicate(func(items []interface{}) error { + partialEncountered := false + for _, item := range items { + buffer := item.([]interface{}) + if len(buffer) == 1 { + partialEncountered = true + } + } + if !partialEncountered { + return errors.New("at least one partial observation should have occurred") + } + return nil + })) +} + +func Test_Observable_BufferWithTimeOrCount_EmitsCompleteBuffers(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx, cancel1 := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel1() + + // straggler not emitted before context timeout + observable := Just([]struct { + name string + msDelay int + }{ + {"a", 1}, + {"b", 1}, + {"c", 1}, + {"d", 1}, + {"e", 1000}, + })().Map(delayItem, WithContext(ctx)) + + buffers := observable.BufferWithTimeOrCount(WithDuration(time.Millisecond*20), 2, WithContext(ctx)) + + ctx, cancel2 := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel2() + + Assert(ctx, t, buffers, CustomPredicate(func(items []interface{}) error { + for _, item := range items { + buffer := item.([]interface{}) + if len(buffer) != 2 { + return errors.New("items should be bundles of two") + } + } + return nil + })) +} + +func Test_Observable_BufferWithTimeOrCount_CapturesStraggler(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx, cancel1 := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel1() + + // straggler will stay in buffer until buffer timeout + observable := Just([]struct { + name string + msDelay int + }{ + {"a", 1}, + {"b", 1}, + {"c", 1}, + {"d", 1}, + {"straggler", 1}, + })().Map(delayItem, WithContext(ctx)) + + buffers := observable.BufferWithTimeOrCount(WithDuration(time.Millisecond*10), 2, WithContext(ctx)) + + ctx, cancel2 := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel2() + + Assert(ctx, t, buffers, CustomPredicate(func(items []interface{}) error { + seen := false + for _, item := range items { + buffer := item.([]interface{}) + for _, b := range buffer { + if b.(string) == "straggler" { + seen = true + } + } + } + if !seen { + return errors.New("straggler item not emitted") + } + return nil + })) +} + +func Test_Observable_BufferWithTimeOrCount_DoneEmitsStraggler(t *testing.T) { + defer goleak.VerifyNone(t) + + ctx, cancel1 := context.WithTimeout(context.Background(), time.Millisecond*50) + defer cancel1() + + // straggler will stay in buffer until context times out + observable := Just([]struct { + name string + msDelay int + }{ + {"a", 1}, + {"b", 1}, + {"c", 1}, + {"d", 1}, + {"straggler", 1}, + })().Map(delayItem, WithContext(ctx)) + + buffers := observable.BufferWithTimeOrCount(WithDuration(time.Second*10), 2, WithContext(ctx)) + + ctx, cancel2 := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel2() + + Assert(ctx, t, buffers, CustomPredicate(func(items []interface{}) error { + seen := false + for _, item := range items { + buffer := item.([]interface{}) + for _, b := range buffer { + if b.(string) == "straggler" { + seen = true + } + } + } + if !seen { + return errors.New("straggler item not emitted") + } + return nil + })) +} + func Test_Observable_Contain(t *testing.T) { defer goleak.VerifyNone(t) ctx, cancel := context.WithCancel(context.Background()) @@ -382,12 +562,12 @@ func Test_Observable_Count_Parallel(t *testing.T) { } // FIXME -//func Test_Observable_Debounce(t *testing.T) { +// func Test_Observable_Debounce(t *testing.T) { // defer goleak.VerifyNone(t) // ctx, obs, d := timeCausality(1, tick, 2, tick, 3, 4, 5, tick, 6, tick) // Assert(ctx, t, obs.Debounce(d, WithBufferedChannel(10), WithContext(ctx)), // HasItems(1, 2, 5, 6)) -//} +// } func Test_Observable_Debounce_Error(t *testing.T) { defer goleak.VerifyNone(t) @@ -2300,7 +2480,7 @@ func Test_Observable_WindowWithCount_InputError(t *testing.T) { } // FIXME -//func Test_Observable_WindowWithTime(t *testing.T) { +// func Test_Observable_WindowWithTime(t *testing.T) { // defer goleak.VerifyNone(t) // ctx, cancel := context.WithCancel(context.Background()) // defer cancel() @@ -2317,7 +2497,7 @@ func Test_Observable_WindowWithCount_InputError(t *testing.T) { // observe := obs.WindowWithTime(WithDuration(10*time.Millisecond), WithBufferedChannel(10)).Observe() // Assert(ctx, t, (<-observe).V.(Observable), HasItems(1, 2)) // Assert(ctx, t, (<-observe).V.(Observable), HasItems(3)) -//} +// } func Test_Observable_WindowWithTimeOrCount(t *testing.T) { defer goleak.VerifyNone(t)