Skip to content

Commit

Permalink
fix: send terminal msg asynchronously and rm close implicitly
Browse files Browse the repository at this point in the history
  • Loading branch information
joway committed Dec 18, 2023
1 parent fb2d2c5 commit 3b6eae0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 62 deletions.
24 changes: 9 additions & 15 deletions lang/channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package channel

import (
"container/list"
"runtime"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -165,11 +164,6 @@ type Channel interface {
Close()
}

// channelWrapper use to detect user never hold the reference of channel object, and we need to close channel implicitly.
type channelWrapper struct {
Channel
}

// channel implements a safe and feature-rich channel struct for the real world.
type channel struct {
size int
Expand Down Expand Up @@ -206,14 +200,7 @@ func New(opts ...Option) Channel {
c.buffer = list.New()
go c.produce()
go c.consume()

// register finalizer for wrapper of channel
cw := &channelWrapper{c}
runtime.SetFinalizer(cw, func(obj *channelWrapper) {
// it's ok to call Close again if user already closed the channel
obj.Close()
})
return cw
return c
}

// Close will close the producer and consumer goroutines gracefully
Expand All @@ -226,7 +213,14 @@ func (c *channel) Close() {
c.buffer.Init() // clear
c.bufferLock.Unlock()
c.bufferCond.Broadcast()
c.producer <- terminalSig
select {
case c.producer <- terminalSig:
default:
// producer channel is full, so create a new goroutine to send terminal msg asynchronously
go func() {
c.producer <- terminalSig
}()
}
}

func (c *channel) isClosed() bool {
Expand Down
2 changes: 2 additions & 0 deletions lang/channel/channel_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func TestNetworkIsolationOrDownstreamBlock(t *testing.T) {
WithNonBlock(),
WithTimeout(time.Millisecond*10),
)
defer taskPool.Close()
var responded int32
go func() {
// task worker
Expand Down Expand Up @@ -93,6 +94,7 @@ func TestCPUHeavy(t *testing.T) {
return atomic.LoadInt32(&concurrency) > 10
}),
)
defer taskPool.Close()
var responded int32
go func() {
// task worker
Expand Down
82 changes: 35 additions & 47 deletions lang/channel/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ func BenchmarkChannel(b *testing.B) {
for _, size := range benchSizes {
b.Run(fmt.Sprintf("Size-%d", size), func(b *testing.B) {
var ch Channel
defer ch.Close()
if size < 0 {
ch = New(WithNonBlock())
} else {
Expand Down Expand Up @@ -105,58 +106,33 @@ func TestChannelDefaultSize(t *testing.T) {

func TestChannelClose(t *testing.T) {
beginGs := runtime.NumGoroutine()
channel := New()
ch := New()
afterGs := runtime.NumGoroutine()
assert.Equal(t, 2, afterGs-beginGs)
var exit int32
go func() {
for _ = range channel.Output() {
for _ = range ch.Output() {
}
atomic.AddInt32(&exit, 1)
}()
for i := 1; i <= 20; i++ {
channel.Input() <- i
ch.Input() <- i
}
channel.Close()
ch.Close()
for runtime.NumGoroutine() > beginGs {
runtime.Gosched()
}
<-channel.Output() // never block
<-ch.Output() // never block
assert.Equal(t, int32(1), atomic.LoadInt32(&exit))
}

func TestChannelGCClose(t *testing.T) {
// close implicitly
go func() {
_ = New()
}()
go func() {
ch := New()
ch.Input() <- 1
_ = <-ch.Output()
tlogf(t, "channel finished")
}()
for i := 0; i < 3; i++ {
time.Sleep(time.Millisecond * 10)
runtime.GC()
}

// close explicitly
go func() {
ch := New()
ch.Close()
}()
for i := 0; i < 3; i++ {
time.Sleep(time.Millisecond * 10)
runtime.GC()
}
}

func TestChannelTimeout(t *testing.T) {
ch := New(
WithTimeout(time.Millisecond*50),
WithSize(1024),
)
defer ch.Close()

go func() {
for i := 1; i <= 20; i++ {
ch.Input() <- i
Expand Down Expand Up @@ -187,6 +163,8 @@ func TestChannelConsumerInflightLimit(t *testing.T) {
return atomic.LoadInt32(&inflight) >= limit
}),
)
defer ch.Close()

var wg sync.WaitGroup
go func() {
for c := range ch.Output() {
Expand Down Expand Up @@ -218,6 +196,8 @@ func TestChannelConsumerInflightLimit(t *testing.T) {
func TestChannelProducerSpeedLimit(t *testing.T) {
var total = 15
ch := New(WithSize(0))
defer ch.Close()

go func() {
for c := range ch.Output() {
id := c.(int)
Expand All @@ -239,6 +219,8 @@ func TestChannelProducerSpeedLimit(t *testing.T) {
func TestChannelProducerNoLimit(t *testing.T) {
var total = 100
ch := New(WithSize(1000))
defer ch.Close()

go func() {
for c := range ch.Output() {
id := c.(int)
Expand Down Expand Up @@ -290,6 +272,8 @@ func TestChannelGoroutinesThrottle(t *testing.T) {

func TestChannelNoConsumer(t *testing.T) {
ch := New()
defer ch.Close()

var sum int32
go func() {
for i := 1; i <= 20; i++ {
Expand All @@ -298,12 +282,13 @@ func TestChannelNoConsumer(t *testing.T) {
atomic.AddInt32(&sum, 1)
}
}()
time.Sleep(time.Second)
time.Sleep(time.Millisecond * 100)
assert.Equal(t, int32(2), atomic.LoadInt32(&sum))
}

func TestChannelOneSlowTask(t *testing.T) {
ch := New(WithTimeout(time.Millisecond*500), WithSize(0))
defer ch.Close()

var total int32
go func() {
Expand All @@ -326,48 +311,52 @@ func TestChannelOneSlowTask(t *testing.T) {

func TestChannelProduceRateControl(t *testing.T) {
produceMaxRate := 100
channel := New(
ch := New(
WithRateThrottle(produceMaxRate, 0),
)
defer ch.Close()

go func() {
for c := range channel.Output() {
for c := range ch.Output() {
id := c.(int)
tlogf(t, "consumed: %d", id)
}
}()
begin := time.Now()
for i := 1; i <= 500; i++ {
channel.Input() <- i
ch.Input() <- i
}
cost := time.Now().Sub(begin)
tlogf(t, "Cost %dms", cost.Milliseconds())
}

func TestChannelConsumeRateControl(t *testing.T) {
channel := New(
ch := New(
WithRateThrottle(0, 100),
)
defer ch.Close()

go func() {
for c := range channel.Output() {
for c := range ch.Output() {
id := c.(int)
tlogf(t, "consumed: %d", id)
}
}()
begin := time.Now()
for i := 1; i <= 500; i++ {
channel.Input() <- i
ch.Input() <- i
}
cost := time.Now().Sub(begin)
tlogf(t, "Cost %dms", cost.Milliseconds())
}

func TestChannelNonBlock(t *testing.T) {
channel := New(WithNonBlock())
ch := New(WithNonBlock())
defer ch.Close()

begin := time.Now()
for i := 1; i <= 10000; i++ {
channel.Input() <- i
ch.Input() <- i
tlogf(t, "producer=%d finished", i)
}
cost := time.Now().Sub(begin)
Expand Down Expand Up @@ -397,19 +386,18 @@ func TestFastRecoverConsumer(t *testing.T) {
var consumed int32
var aborted int32
timeout := time.Second * 1

channel := New(
ch := New(
WithNonBlock(),
WithTimeout(timeout),
WithTimeoutCallback(func(i interface{}) {
atomic.AddInt32(&aborted, 1)
}),
)
defer channel.Close()
defer ch.Close()

// consumer
go func() {
for c := range channel.Output() {
for c := range ch.Output() {
id := c.(int)
t.Logf("consumed: %d", id)
time.Sleep(time.Millisecond * 100)
Expand All @@ -420,7 +408,7 @@ func TestFastRecoverConsumer(t *testing.T) {
// producer
// faster than consumer's ability
for i := 1; i <= 20; i++ {
channel.Input() <- i
ch.Input() <- i
time.Sleep(time.Millisecond * 10)
}
for (atomic.LoadInt32(&consumed) + atomic.LoadInt32(&aborted)) != 20 {
Expand All @@ -431,7 +419,7 @@ func TestFastRecoverConsumer(t *testing.T) {
aborted = 0
// quick recover consumer
for i := 1; i <= 10; i++ {
channel.Input() <- i
ch.Input() <- i
time.Sleep(time.Millisecond * 10)
}
for atomic.LoadInt32(&consumed) != 10 {
Expand Down

0 comments on commit 3b6eae0

Please sign in to comment.