From 97206851ec8885fc5c989173d16c8cee584c35a1 Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Mon, 18 Dec 2023 20:22:39 +0800 Subject: [PATCH] fix: channel min size is 1 and change Input interface --- lang/channel/channel.go | 133 +++++++++++------------ lang/channel/channel_example_test.go | 2 +- lang/channel/channel_test.go | 155 ++++++++++++++++----------- 3 files changed, 163 insertions(+), 127 deletions(-) diff --git a/lang/channel/channel.go b/lang/channel/channel.go index bc813628..b9c75f43 100644 --- a/lang/channel/channel.go +++ b/lang/channel/channel.go @@ -16,6 +16,7 @@ package channel import ( "container/list" + "runtime" "sync" "sync/atomic" "time" @@ -23,7 +24,7 @@ import ( const ( defaultThrottleWindow = time.Millisecond * 100 - defaultSize = 0 + defaultMinSize = 1 ) type item struct { @@ -45,12 +46,12 @@ type Option func(c *channel) // Throttle define channel Throttle function type Throttle func(c Channel) bool -// WithSize define the size of channel. +// WithSize define the size of channel. If channel is full, it will block. // It conflicts with WithNonBlock option. func WithSize(size int) Option { return func(c *channel) { // with non block mode, no need to change size - if !c.nonblock { + if size >= defaultMinSize && !c.nonblock { c.size = size } } @@ -61,7 +62,6 @@ func WithSize(size int) Option { func WithNonBlock() Option { return func(c *channel) { c.nonblock = true - c.size = 1024 } } @@ -146,37 +146,43 @@ func WithRateThrottle(produceRate, consumeRate int) Option { }) } -var _ Channel = (*channel)(nil) +var ( + _ Channel = (*channel)(nil) +) +// Channel is a safe and feature-rich alternative for Go chan struct type Channel interface { - // Input return a native chan for produce task - Input() chan interface{} - // Output return a native chan for consume task - Output() chan interface{} - // Len return the count of un-consumed tasks + // Input send value to Output channel. If channel is closed, do nothing and will not panic. + Input(v interface{}) + // Output return a read-only native chan for consumer. + Output() <-chan interface{} + // Len return the count of un-consumed items. Len() int - // Stats return the produced and consumed count + // Stats return the produced and consumed count. Stats() (produced uint64, consumed uint64) - // Close will close the producer and consumer goroutines gracefully + // Close closed the output chan. If channel is not closed explicitly, it will be closed when it's finalized. Close() } +// channelWrapper use to detect user never hold the reference of Channel object, and runtime will help to close channel implicitly. +type channelWrapper struct { + Channel +} + // channel implements a safe and feature-rich channel struct for the real world. type channel struct { size int state int32 - producer chan interface{} consumer chan interface{} + nonblock bool // non blocking mode timeout time.Duration timeoutCallback func(interface{}) producerThrottle Throttle consumerThrottle Throttle throttleWindow time.Duration // statistics - produced uint64 - consumed uint64 - // non blocking mode - nonblock bool + produced uint64 // item already been insert into buffer + consumed uint64 // item already been sent into Output chan // buffer buffer *list.List // TODO: use high perf queue to reduce GC here bufferCond *sync.Cond @@ -186,18 +192,23 @@ type channel struct { // New create a new channel. func New(opts ...Option) Channel { c := new(channel) - c.size = defaultSize + c.size = defaultMinSize c.throttleWindow = defaultThrottleWindow c.bufferCond = sync.NewCond(&c.bufferLock) for _, opt := range opts { opt(c) } - c.producer = make(chan interface{}, c.size) c.consumer = make(chan interface{}) c.buffer = list.New() - go c.produce() go c.consume() - return c + + // register finalizer for wrapper of channel + cw := &channelWrapper{c} + runtime.SetFinalizer(cw, func(obj *channelWrapper) { + // it's ok to call Close again if user already closed the channel + obj.Close() + }) + return cw } // Close will close the producer and consumer goroutines gracefully @@ -205,8 +216,6 @@ func (c *channel) Close() { if !atomic.CompareAndSwapInt32(&c.state, 0, -1) { return } - // stop producer - close(c.producer) // stop consumer c.bufferLock.Lock() c.buffer.Init() // clear buffer @@ -218,17 +227,44 @@ func (c *channel) isClosed() bool { return atomic.LoadInt32(&c.state) < 0 } -// Input return a native chan for produce task -func (c *channel) Input() chan interface{} { - return c.producer +func (c *channel) Input(v interface{}) { + if c.isClosed() { + return + } + + // prepare item + it := item{value: v} + if c.timeout > 0 { + it.deadline = time.Now().Add(c.timeout) + } + + // only check throttle function in blocking mode + if !c.nonblock { + if c.throttling(c.producerThrottle) { + // closed + return + } + } + + // enqueue buffer + c.bufferLock.Lock() + if !c.nonblock { + // only check length with blocking mode + for c.buffer.Len() >= c.size { + // wait for consuming + c.bufferCond.Wait() + } + } + c.enqueueBuffer(it) + atomic.AddUint64(&c.produced, 1) + c.bufferLock.Unlock() + c.bufferCond.Signal() // use Signal because only 1 goroutine wait for cond } -// Output return a native chan for consume task -func (c *channel) Output() chan interface{} { +func (c *channel) Output() <-chan interface{} { return c.consumer } -// Len return the count of un-consumed tasks. func (c *channel) Len() int { produced, consumed := c.Stats() l := produced - consumed @@ -240,42 +276,7 @@ func (c *channel) Stats() (uint64, uint64) { return produced, consumed } -// produce used to process input channel -func (c *channel) produce() { - capacity := c.size - if c.size == 0 { - capacity = 1 - } - for p := range c.producer { - // only check throttle function in blocking mode - if !c.nonblock { - if c.throttling(c.producerThrottle) { - // closed - return - } - } - - // produced - atomic.AddUint64(&c.produced, 1) - // prepare item - it := item{value: p} - if c.timeout > 0 { - it.deadline = time.Now().Add(c.timeout) - } - // enqueue buffer - c.bufferLock.Lock() - c.enqueueBuffer(it) - c.bufferCond.Signal() - if !c.nonblock { - for c.buffer.Len() >= capacity && !c.isClosed() { - c.bufferCond.Wait() - } - } - c.bufferLock.Unlock() - } -} - -// consume used to process output channel +// consume used to process input buffer func (c *channel) consume() { for { // check throttle @@ -297,7 +298,7 @@ func (c *channel) consume() { } it, ok := c.dequeueBuffer() c.bufferLock.Unlock() - c.bufferCond.Signal() + c.bufferCond.Broadcast() // use Broadcast because there will be more than 1 goroutines wait for cond if !ok { // in fact, this case will never happen continue diff --git a/lang/channel/channel_example_test.go b/lang/channel/channel_example_test.go index 49b26d85..a68eccf1 100644 --- a/lang/channel/channel_example_test.go +++ b/lang/channel/channel_example_test.go @@ -36,7 +36,7 @@ type response struct { var taskPool Channel func Service1(req *request) { - taskPool.Input() <- req // async run + taskPool.Input(req) return } diff --git a/lang/channel/channel_test.go b/lang/channel/channel_test.go index 1c9187cc..b0258bb5 100644 --- a/lang/channel/channel_test.go +++ b/lang/channel/channel_test.go @@ -44,7 +44,7 @@ func BenchmarkNativeChan(b *testing.B) { if size < 0 { continue } - b.Run(fmt.Sprintf("Size-%d", size), func(b *testing.B) { + b.Run(fmt.Sprintf("Size-[%d]", size), func(b *testing.B) { ch := make(chan interface{}, size) b.RunParallel(func(pb *testing.PB) { n := 0 @@ -60,7 +60,7 @@ func BenchmarkNativeChan(b *testing.B) { func BenchmarkChannel(b *testing.B) { for _, size := range benchSizes { - b.Run(fmt.Sprintf("Size-%d", size), func(b *testing.B) { + b.Run(fmt.Sprintf("Size-[%d]", size), func(b *testing.B) { var ch Channel if size < 0 { ch = New(WithNonBlock()) @@ -72,7 +72,7 @@ func BenchmarkChannel(b *testing.B) { n := 0 for pb.Next() { n++ - ch.Input() <- n + ch.Input(n) <-ch.Output() } }) @@ -84,39 +84,37 @@ func TestChannelDefaultSize(t *testing.T) { ch := New() defer ch.Close() - for i := 1; i <= 10; i++ { - ch.Input() <- i - t.Logf("put %d", i) - x := <-ch.Output() - t.Logf("get %d", x) - assert.Equal(t, i, x) - } - ch.Input() <- 0 // wait for be consumed - t.Logf("put 0-1") - ch.Input() <- 0 // wait for be buffered - t.Logf("put 0-2") - timeout := false - select { - case ch.Input() <- 0: // block - case <-time.After(time.Millisecond * 10): - timeout = true - } - assert.True(t, timeout) + ch.Input(0) + ch.Input(0) + var timeouted uint32 + go func() { + ch.Input(0) // block + atomic.AddUint32(&timeouted, 1) + }() + go func() { + ch.Input(0) // block + atomic.AddUint32(&timeouted, 1) + }() + time.Sleep(time.Millisecond * 100) + assert.Equal(t, atomic.LoadUint32(&timeouted), uint32(0)) } func TestChannelClose(t *testing.T) { beginGs := runtime.NumGoroutine() ch := New() afterGs := runtime.NumGoroutine() - assert.Equal(t, 2, afterGs-beginGs) + assert.Equal(t, 1, afterGs-beginGs) var exit int32 go func() { - for _ = range ch.Output() { + for v := range ch.Output() { + id := v.(int) + tlogf(t, "consumer=%d started", id) } atomic.AddInt32(&exit, 1) }() for i := 1; i <= 20; i++ { - ch.Input() <- i + ch.Input(i) + tlogf(t, "producer=%d started", i) } ch.Close() for runtime.NumGoroutine() > beginGs { @@ -127,6 +125,35 @@ func TestChannelClose(t *testing.T) { assert.Equal(t, int32(1), atomic.LoadInt32(&exit)) } +func TestChannelGCClose(t *testing.T) { + beginGs := runtime.NumGoroutine() + // close implicitly + go func() { + _ = New() + }() + go func() { + ch := New() + ch.Input(1) + _ = <-ch.Output() + tlogf(t, "channel finished") + }() + for i := 0; i < 3; i++ { + time.Sleep(time.Millisecond * 10) + runtime.GC() + } + // close explicitly + go func() { + ch := New() + ch.Close() + }() + for i := 0; i < 3; i++ { + time.Sleep(time.Millisecond * 10) + runtime.GC() + } + afterGs := runtime.NumGoroutine() + assert.Equal(t, beginGs, afterGs) +} + func TestChannelTimeout(t *testing.T) { ch := New( WithTimeout(time.Millisecond*50), @@ -136,7 +163,7 @@ func TestChannelTimeout(t *testing.T) { go func() { for i := 1; i <= 20; i++ { - ch.Input() <- i + ch.Input(i) } }() var total int32 @@ -185,7 +212,7 @@ func TestChannelConsumerInflightLimit(t *testing.T) { for i := 1; i <= total; i++ { wg.Add(1) id := i - ch.Input() <- id + ch.Input(id) tlogf(t, "producer=%d finished", id) time.Sleep(time.Millisecond * 10) } @@ -210,7 +237,7 @@ func TestChannelProducerSpeedLimit(t *testing.T) { now := time.Now() for i := 1; i <= total; i++ { id := i - ch.Input() <- id + ch.Input(id) tlogf(t, "producer=%d finished", id) } duration := time.Now().Sub(now) @@ -233,7 +260,7 @@ func TestChannelProducerNoLimit(t *testing.T) { now := time.Now() for i := 1; i <= total; i++ { id := i - ch.Input() <- id + ch.Input(id) } duration := time.Now().Sub(now) assert.Equal(t, 0, int(duration.Seconds())) @@ -264,7 +291,7 @@ func TestChannelGoroutinesThrottle(t *testing.T) { for i := 1; i <= total; i++ { wg.Add(1) id := i - ch.Input() <- id + ch.Input(id) tlogf(t, "producer=%d finished", id) runtime.Gosched() } @@ -272,22 +299,48 @@ func TestChannelGoroutinesThrottle(t *testing.T) { } func TestChannelNoConsumer(t *testing.T) { + // zero size channel ch := New() - var sum int32 go func() { for i := 1; i <= 20; i++ { - ch.Input() <- i + ch.Input(i) tlogf(t, "producer=%d finished", i) atomic.AddInt32(&sum, 1) } }() time.Sleep(time.Millisecond * 100) assert.Equal(t, int32(2), atomic.LoadInt32(&sum)) + + // 1 size channel + ch = New(WithSize(1)) + atomic.StoreInt32(&sum, 0) + go func() { + for i := 1; i <= 20; i++ { + ch.Input(i) + tlogf(t, "producer=%d finished", i) + atomic.AddInt32(&sum, 1) + } + }() + time.Sleep(time.Millisecond * 100) + assert.Equal(t, int32(2), atomic.LoadInt32(&sum)) + + // 10 size channel + ch = New(WithSize(10)) + atomic.StoreInt32(&sum, 0) + go func() { + for i := 1; i <= 20; i++ { + ch.Input(i) + tlogf(t, "producer=%d finished", i) + atomic.AddInt32(&sum, 1) + } + }() + time.Sleep(time.Millisecond * 100) + assert.Equal(t, int32(11), atomic.LoadInt32(&sum)) } func TestChannelOneSlowTask(t *testing.T) { - ch := New(WithTimeout(time.Millisecond*500), WithSize(0)) + ch := New(WithTimeout(time.Millisecond*100), WithSize(20)) defer ch.Close() var total int32 @@ -295,18 +348,19 @@ func TestChannelOneSlowTask(t *testing.T) { for c := range ch.Output() { id := c.(int) if id == 10 { - time.Sleep(time.Second) + time.Sleep(time.Millisecond * 200) } atomic.AddInt32(&total, 1) + tlogf(t, "consumer=%d finished", id) } }() for i := 1; i <= 20; i++ { - ch.Input() <- i + ch.Input(i) tlogf(t, "producer=%d finished", i) } - time.Sleep(time.Second) - assert.Equal(t, int32(19), atomic.LoadInt32(&total)) + time.Sleep(time.Millisecond * 300) + assert.Equal(t, int32(11), atomic.LoadInt32(&total)) } func TestChannelProduceRateControl(t *testing.T) { @@ -324,7 +378,7 @@ func TestChannelProduceRateControl(t *testing.T) { }() begin := time.Now() for i := 1; i <= 500; i++ { - ch.Input() <- i + ch.Input(i) } cost := time.Now().Sub(begin) tlogf(t, "Cost %dms", cost.Milliseconds()) @@ -344,7 +398,7 @@ func TestChannelConsumeRateControl(t *testing.T) { }() begin := time.Now() for i := 1; i <= 500; i++ { - ch.Input() <- i + ch.Input(i) } cost := time.Now().Sub(begin) tlogf(t, "Cost %dms", cost.Milliseconds()) @@ -356,32 +410,13 @@ func TestChannelNonBlock(t *testing.T) { begin := time.Now() for i := 1; i <= 10000; i++ { - ch.Input() <- i + ch.Input(i) tlogf(t, "producer=%d finished", i) } cost := time.Now().Sub(begin) tlogf(t, "Cost %dms", cost.Milliseconds()) } -func TestAvoidGoroutineLeak(t *testing.T) { - // Default channel is safe - recvCh := New() - var wg sync.WaitGroup - wg.Add(1) - // producer - go func() { - time.Sleep(time.Millisecond * 100) // RPC Call - recvCh.Input() <- 1 - wg.Done() - }() - // consumer - select { - case <-recvCh.Output(): - case <-time.After(time.Millisecond * 50): - } - wg.Wait() // goroutine exit -} - func TestFastRecoverConsumer(t *testing.T) { var consumed int32 var aborted int32 @@ -408,7 +443,7 @@ func TestFastRecoverConsumer(t *testing.T) { // producer // faster than consumer's ability for i := 1; i <= 20; i++ { - ch.Input() <- i + ch.Input(i) time.Sleep(time.Millisecond * 10) } for (atomic.LoadInt32(&consumed) + atomic.LoadInt32(&aborted)) != 20 { @@ -419,7 +454,7 @@ func TestFastRecoverConsumer(t *testing.T) { aborted = 0 // quick recover consumer for i := 1; i <= 10; i++ { - ch.Input() <- i + ch.Input(i) time.Sleep(time.Millisecond * 10) } for atomic.LoadInt32(&consumed) != 10 {