Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sequence operations #8

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/arg0net/collections

go 1.22.3
go 1.23.0

require (
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24
Expand Down
25 changes: 24 additions & 1 deletion notifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package collections

import (
"context"
"iter"
"sync"
)

Expand All @@ -16,7 +17,7 @@ type StatefulNotifier[T any] struct {

func NewStatefulNotifier[T any](initial T) *StatefulNotifier[T] {
return &StatefulNotifier[T]{
value: initial,
value: initial,
}
}

Expand Down Expand Up @@ -80,3 +81,25 @@ func (n *StatefulNotifier[T]) Wait(ctx context.Context, fn func(T) bool) (T, err
}
}
}

// Watch returns an iterator which will yield the current value and any updates.
// Note that updates may be missed if multiple updates occur quickly.
// If all updates should be processed, use a Channel instead.
// If the context is cancelled, then the iterator terminates.
func (n *StatefulNotifier[T]) Watch(ctx context.Context) iter.Seq[T] {
v, ch := n.Load()
return func(yield func(T) bool) {
for {
if !yield(v) {
return
}

select {
case <-ctx.Done():
return
case <-ch:
v, ch = n.Load()
}
}
}
}
36 changes: 34 additions & 2 deletions notifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package collections_test
import (
"context"
"math/rand"
"testing"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -35,7 +36,7 @@ func TestNotifierUpdate(t *testing.T) {
start := make(chan struct{})

incr := func(in int) int {
return in+1
return in + 1
}
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
Expand Down Expand Up @@ -99,6 +100,37 @@ func TestWaitCancel(t *testing.T) {
require.ErrorIs(t, err, context.Canceled)
}

func TestWatch(t *testing.T) {
sn := collections.NewStatefulNotifier(0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var lastValue atomic.Int32
var done atomic.Bool
recv := sn.Watch(ctx)
go func() {
for v := range recv {
lastValue.Store(int32(v))
}
done.Store(true)
}()

sn.Store(42)
require.Eventually(t, func() bool {
return lastValue.Load() == 42
}, 2*time.Second, 10*time.Millisecond)

sn.Store(999)
require.Eventually(t, func() bool {
return lastValue.Load() == 999
}, 2*time.Second, 10*time.Millisecond)

cancel()
require.Eventually(t, func() bool {
return done.Load()
}, 2*time.Second, 10*time.Millisecond)
}

func TestNotifierWaitAny(t *testing.T) {
ctx := context.Background()

Expand Down
66 changes: 58 additions & 8 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,29 @@ package collections

import (
"context"
"iter"
"sync"
)

// Channel is a publish/subscribe channel. It is similar to an infinitely
// buffered Go channel, but where each value is sent to all subscribers.
// Channel is a publish/subscribe channel. It is similar to a Go channel with
// infinite capacity, with a couple important differences.
//
// The zero value of a Channel is ready to use.
// 1. Multiple receivers. There may be multiple receivers (or publishers), and
// all receivers get all messages.
//
// 2. Persistence. Messages are not persisted. If no receivers are listening when
// a message is published, it will be lost. When a receiver subscribes, it will
// only receive messages published after the subscription is created.
type Channel[T any] struct {
mu sync.Mutex // for reading `next` and for writes.
next *message[T]
}

type message[T any] struct {
value T
next *message[T]
final chan struct{}
value T
next *message[T]
final chan struct{}
closed bool
}

// Publish a new value to the channel. This value will be sent to all subscribers.
Expand All @@ -27,8 +34,8 @@ func (c *Channel[T]) Publish(value T) {
c.mu.Lock()
defer c.mu.Unlock()

if c.next == nil {
// no subscribers, can drop message.
if c.next == nil || c.next.closed {
// drop message.
return
}

Expand All @@ -40,6 +47,23 @@ func (c *Channel[T]) Publish(value T) {
close(old.final)
}

// Close the channel. This will prevent any new values from being published, and
// will cause all subscribers to stop receiving values after the last message.
// For receive iterators, this will cause the iterator to terminate.
func (c *Channel[T]) Close() {
c.mu.Lock()
defer c.mu.Unlock()

if c.next == nil {
c.next = &message[T]{final: make(chan struct{})}
}
if c.next.closed {
return
}
c.next.closed = true
close(c.next.final)
}

func (c *Channel[T]) head() *message[T] {
c.mu.Lock()
defer c.mu.Unlock()
Expand All @@ -52,6 +76,7 @@ func (c *Channel[T]) head() *message[T] {
// Watch updates on the channel. The function will be called with each new value
// sent to the channel. If the function returns an error, the subscription will
// be canceled and the error will be returned.
// If the channel is closed, Watch will return nil.
func (c *Channel[T]) Watch(ctx context.Context, fn func(T) error) error {
next := c.head()
for {
Expand All @@ -60,6 +85,9 @@ func (c *Channel[T]) Watch(ctx context.Context, fn func(T) error) error {
return ctx.Err()

case <-next.final:
if next.closed {
return nil
}
if err := fn(next.value); err != nil {
return err
}
Expand All @@ -68,6 +96,25 @@ func (c *Channel[T]) Watch(ctx context.Context, fn func(T) error) error {
}
}

// Receive subscribes to updates on the channel and returns a sequence of values.
// The subscription is setup before the function returns, so it is safe to publish
// values immediately after calling Receive.
// The sequence may be infinite, it will only terminate if the channel is closed.
func (c *Channel[T]) Receive() iter.Seq[T] {
next := c.head()
return func(yield func(T) bool) {
for {
select {
case <-next.final:
if next.closed || !yield(next.value) {
return
}
next = next.next
}
}
}
}

// Subscribe is like Watch, but without the context. The subscription will run
// until it is canceled.
// The subscription is setup before the function returns, so it is safe to
Expand Down Expand Up @@ -98,6 +145,9 @@ func (s *Subscription[T]) loop(next *message[T], fn func(T)) {
return

case <-next.final:
if next.closed {
return
}
fn(next.value)
next = next.next
}
Expand Down
30 changes: 30 additions & 0 deletions pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package collections_test
import (
"context"
"fmt"
"iter"
"math/rand"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -71,6 +72,35 @@ func TestPubSub_Watch(t *testing.T) {
require.Error(t, err)
}

func TestPubSub_Receive(t *testing.T) {
var c collections.Channel[int]

recv1 := c.Receive()
recv2 := c.Receive()

// Publish to the channel.
go func() {
for _, i := range rand.Perm(64) {
c.Publish(i)
}
c.Close()
}()

sum := func(recv iter.Seq[int]) int {
var sum int
for v := range recv {
sum += v
}
return sum
}

sum1 := sum(recv1)
sum2 := sum(recv2)

require.Equal(t, 2016, sum1)
require.Equal(t, 2016, sum2)
}

func BenchmarkPubSub(b *testing.B) {
for _, n := range []int{0, 1, 10, 100, 1000} {
b.Run(fmt.Sprintf("PubSub-%d", n), func(b *testing.B) {
Expand Down
18 changes: 18 additions & 0 deletions ring.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package collections

import "iter"

// Ring is a fixed-size ring buffer that supports pushing and popping elements,
// as well as copying elements into a slice, and removing an element by index.
// The ring is implemented as a single slice, which is never reallocated.
Expand Down Expand Up @@ -167,3 +169,19 @@ func (r *Ring[T]) Scan(fn func(T) bool) (T, int) {
var zero T
return zero, -1
}

// All returns a sequence of all elements in the ring.
func (r *Ring[T]) All() iter.Seq[T] {
return func(yield func(T) bool) {
for _, e := range r.right {
if !yield(e) {
return
}
}
for _, e := range r.left {
if !yield(e) {
return
}
}
}
}