From 5698cc626593d4e10aa0883661f43bc2fcd873f0 Mon Sep 17 00:00:00 2001 From: Tim Liu Date: Mon, 14 Oct 2024 03:55:14 +0000 Subject: [PATCH] feat: add safe_channel pkg --- safe_channel/safe_channel.go | 32 +++++++++++++++++++++++ safe_channel/safe_channel_test.go | 42 +++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 safe_channel/safe_channel.go create mode 100644 safe_channel/safe_channel_test.go diff --git a/safe_channel/safe_channel.go b/safe_channel/safe_channel.go new file mode 100644 index 0000000..775aee4 --- /dev/null +++ b/safe_channel/safe_channel.go @@ -0,0 +1,32 @@ +package safe_channel + +import ( + "sync/atomic" +) + +type SafeCh[T any] struct { + closed atomic.Bool + ch chan T +} + +func NewSafeCh[T any](size int) *SafeCh[T] { + return &SafeCh[T]{ + ch: make(chan T, size), + } +} + +func (c *SafeCh[T]) Send(e T) { + if c.closed.CompareAndSwap(false, false) { + c.ch <- e + } +} + +func (c *SafeCh[T]) GetRcvChan() <-chan T { + return c.ch +} + +func (c *SafeCh[T]) Close() { + if c.closed.CompareAndSwap(false, true) { + close(c.ch) + } +} diff --git a/safe_channel/safe_channel_test.go b/safe_channel/safe_channel_test.go new file mode 100644 index 0000000..8d006f1 --- /dev/null +++ b/safe_channel/safe_channel_test.go @@ -0,0 +1,42 @@ +package safe_channel + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// Simulate 1 receiver and N(N=2) senders situation +func TestSafeChannel(t *testing.T) { + sCh := NewSafeCh[int](1) + wg := sync.WaitGroup{} + + // Two senders + wg.Add(2) + for i := 0; i < 2; i++ { + go func(i int) { + if i == 0 { + // Case: send after sCh closed + time.Sleep(1 * time.Second) + + require.Equal(t, sCh.closed.Load(), true) + sCh.Send(1) // No panic + require.Equal(t, sCh.closed.Load(), true) + } else { + // Case: send success + sCh.Send(1) + require.Equal(t, sCh.closed.Load(), false) + } + wg.Done() + }(i) + } + + // One receiver + <-sCh.GetRcvChan() + sCh.Close() + require.Equal(t, sCh.closed.Load(), true) + + wg.Wait() +}