diff --git a/immutable/LICENSE b/immutable/LICENSE new file mode 100644 index 0000000..6a66aea --- /dev/null +++ b/immutable/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/immutable/README b/immutable/README new file mode 100644 index 0000000..1aa821b --- /dev/null +++ b/immutable/README @@ -0,0 +1,5 @@ +The immutable package contains a copy of the persistent map and sets implementations used by +gopls: https://cs.opensource.google/go/x/tools/+/master:gopls/internal/util/persistent/ + +These files are licensed under the BSD license as set out in the LICENSE file in +this directory. diff --git a/immutable/map.go b/immutable/map.go new file mode 100644 index 0000000..cbda601 --- /dev/null +++ b/immutable/map.go @@ -0,0 +1,326 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The persistent package defines various persistent data structures; +// that is, data structures that can be efficiently copied and modified +// in sublinear time. +package immutable + +import ( + "cmp" + "fmt" + "math/rand" + "strings" + "sync/atomic" +) + +// Implementation details: +// * Each value is reference counted by nodes which hold it. +// * Each node is reference counted by its parent nodes. +// * Each map is considered a top-level parent node from reference counting perspective. +// * Each change does always effectively produce a new top level node. +// +// Functions which operate directly with nodes do have a notation in form of +// `foo(arg1:+n1, arg2:+n2) (ret1:+n3)`. +// Each argument is followed by a delta change to its reference counter. +// In case if no change is expected, the delta will be `-0`. + +// Map is an associative mapping from keys to values. +// +// Maps can be Cloned in constant time. +// Get, Set, and Delete operations are done on average in logarithmic time. +// Maps can be merged (via SetAll) in O(m log(n/m)) time for maps of size n and m, where m < n. +// +// Values are reference counted, and a client-supplied release function +// is called when a value is no longer referenced by a map or any clone. +// +// Internally the implementation is based on a randomized persistent treap: +// https://en.wikipedia.org/wiki/Treap. +// +// The zero value is ready to use. +type Map[K cmp.Ordered, V any] struct { + // Map is a generic wrapper around a non-generic implementation to avoid a + // significant increase in the size of the executable. + root *mapNode +} + +func (*Map[K, V]) less(l, r any) bool { + return l.(K) < r.(K) +} + +func (m *Map[K, V]) String() string { + var buf strings.Builder + buf.WriteByte('{') + var sep string + m.Range(func(k K, v V) { + fmt.Fprintf(&buf, "%s%v: %v", sep, k, v) + sep = ", " + }) + buf.WriteByte('}') + return buf.String() +} + +type mapNode struct { + key any + value *refValue + weight uint64 + refCount int32 + left, right *mapNode +} + +type refValue struct { + refCount int32 + value any + release func(key, value any) +} + +func newNodeWithRef[K cmp.Ordered, V any](key K, value V, release func(key, value any)) *mapNode { + return &mapNode{ + key: key, + value: &refValue{ + value: value, + release: release, + refCount: 1, + }, + refCount: 1, + weight: rand.Uint64(), + } +} + +func (node *mapNode) shallowCloneWithRef() *mapNode { + atomic.AddInt32(&node.value.refCount, 1) + return &mapNode{ + key: node.key, + value: node.value, + weight: node.weight, + refCount: 1, + } +} + +func (node *mapNode) incref() *mapNode { + if node != nil { + atomic.AddInt32(&node.refCount, 1) + } + return node +} + +func (node *mapNode) decref() { + if node == nil { + return + } + if atomic.AddInt32(&node.refCount, -1) == 0 { + if atomic.AddInt32(&node.value.refCount, -1) == 0 { + if node.value.release != nil { + node.value.release(node.key, node.value.value) + } + node.value.value = nil + node.value.release = nil + } + node.left.decref() + node.right.decref() + } +} + +// Clone returns a copy of the given map. It is a responsibility of the caller +// to Destroy it at later time. +func (pm *Map[K, V]) Clone() *Map[K, V] { + return &Map[K, V]{ + root: pm.root.incref(), + } +} + +// Destroy destroys the map. +// +// After Destroy, the Map should not be used again. +func (pm *Map[K, V]) Destroy() { + // The implementation of these two functions is the same, + // but their intent is different. + pm.Clear() +} + +// Clear removes all entries from the map. +func (pm *Map[K, V]) Clear() { + pm.root.decref() + pm.root = nil +} + +// Keys returns all keys present in the map. +func (pm *Map[K, V]) Keys() []K { + var keys []K + pm.root.forEach(func(k, _ any) { + keys = append(keys, k.(K)) + }) + return keys +} + +// Range calls f sequentially in ascending key order for all entries in the map. +func (pm *Map[K, V]) Range(f func(key K, value V)) { + pm.root.forEach(func(k, v any) { + f(k.(K), v.(V)) + }) +} + +func (node *mapNode) forEach(f func(key, value any)) { + if node == nil { + return + } + node.left.forEach(f) + f(node.key, node.value.value) + node.right.forEach(f) +} + +// Get returns the map value associated with the specified key. +// The ok result indicates whether an entry was found in the map. +func (pm *Map[K, V]) Get(key K) (V, bool) { + node := pm.root + for node != nil { + if key < node.key.(K) { + node = node.left + } else if node.key.(K) < key { + node = node.right + } else { + return node.value.value.(V), true + } + } + var zero V + return zero, false +} + +// SetAll updates the map with key/value pairs from the other map, overwriting existing keys. +// It is equivalent to calling Set for each entry in the other map but is more efficient. +func (pm *Map[K, V]) SetAll(other *Map[K, V]) { + root := pm.root + pm.root = union(root, other.root, pm.less, true) + root.decref() +} + +// Set updates the value associated with the specified key. +func (pm *Map[K, V]) Set(key K, value V) { + pm.SetWithRelease(key, value, nil) +} + +// Set updates the value associated with the specified key. +// If release is non-nil, it will be called with entry's key and value once the +// key is no longer contained in the map or any clone. +func (pm *Map[K, V]) SetWithRelease(key K, value V, release func(key, value any)) { + first := pm.root + second := newNodeWithRef(key, value, release) + pm.root = union(first, second, pm.less, true) + first.decref() + second.decref() +} + +// union returns a new tree which is a union of first and second one. +// If overwrite is set to true, second one would override a value for any duplicate keys. +// +// union(first:-0, second:-0) (result:+1) +// Union borrows both subtrees without affecting their refcount and returns a +// new reference that the caller is expected to call decref. +func union(first, second *mapNode, less func(any, any) bool, overwrite bool) *mapNode { + if first == nil { + return second.incref() + } + if second == nil { + return first.incref() + } + + if first.weight < second.weight { + second, first, overwrite = first, second, !overwrite + } + + left, mid, right := split(second, first.key, less, false) + var result *mapNode + if overwrite && mid != nil { + result = mid.shallowCloneWithRef() + } else { + result = first.shallowCloneWithRef() + } + result.weight = first.weight + result.left = union(first.left, left, less, overwrite) + result.right = union(first.right, right, less, overwrite) + left.decref() + mid.decref() + right.decref() + return result +} + +// split the tree midway by the key into three different ones. +// Return three new trees: left with all nodes with smaller than key, mid with +// the node matching the key, right with all nodes larger than key. +// If there are no nodes in one of trees, return nil instead of it. +// If requireMid is set (such as during deletion), then all return arguments +// are nil if mid is not found. +// +// split(n:-0) (left:+1, mid:+1, right:+1) +// Split borrows n without affecting its refcount, and returns three +// new references that the caller is expected to call decref. +func split(n *mapNode, key any, less func(any, any) bool, requireMid bool) (left, mid, right *mapNode) { + if n == nil { + return nil, nil, nil + } + + if less(n.key, key) { + left, mid, right := split(n.right, key, less, requireMid) + if requireMid && mid == nil { + return nil, nil, nil + } + newN := n.shallowCloneWithRef() + newN.left = n.left.incref() + newN.right = left + return newN, mid, right + } else if less(key, n.key) { + left, mid, right := split(n.left, key, less, requireMid) + if requireMid && mid == nil { + return nil, nil, nil + } + newN := n.shallowCloneWithRef() + newN.left = right + newN.right = n.right.incref() + return left, mid, newN + } + mid = n.shallowCloneWithRef() + return n.left.incref(), mid, n.right.incref() +} + +// Delete deletes the value for a key. +// +// The result reports whether the key was present in the map. +func (pm *Map[K, V]) Delete(key K) bool { + root := pm.root + left, mid, right := split(root, key, pm.less, true) + if mid == nil { + return false + } + pm.root = merge(left, right) + left.decref() + mid.decref() + right.decref() + root.decref() + return true +} + +// merge two trees while preserving the weight invariant. +// All nodes in left must have smaller keys than any node in right. +// +// merge(left:-0, right:-0) (result:+1) +// Merge borrows its arguments without affecting their refcount +// and returns a new reference that the caller is expected to call decref. +func merge(left, right *mapNode) *mapNode { + switch { + case left == nil: + return right.incref() + case right == nil: + return left.incref() + case left.weight > right.weight: + root := left.shallowCloneWithRef() + root.left = left.left.incref() + root.right = merge(left.right, right) + return root + default: + root := right.shallowCloneWithRef() + root.left = merge(left, right.left) + root.right = right.right.incref() + return root + } +} diff --git a/immutable/map_test.go b/immutable/map_test.go new file mode 100644 index 0000000..297231e --- /dev/null +++ b/immutable/map_test.go @@ -0,0 +1,352 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package immutable + +import ( + "fmt" + "math/rand" + "reflect" + "sync/atomic" + "testing" +) + +type mapEntry struct { + key int + value int +} + +type validatedMap struct { + impl *Map[int, int] + expected map[int]int // current key-value mapping. + deleted map[mapEntry]int // maps deleted entries to their clock time of last deletion + seen map[mapEntry]int // maps seen entries to their clock time of last insertion + clock int +} + +func TestSimpleMap(t *testing.T) { + deletedEntries := make(map[mapEntry]int) + seenEntries := make(map[mapEntry]int) + + m1 := &validatedMap{ + impl: new(Map[int, int]), + expected: make(map[int]int), + deleted: deletedEntries, + seen: seenEntries, + } + + m3 := m1.clone() + validateRef(t, m1, m3) + m3.set(t, 8, 8) + validateRef(t, m1, m3) + m3.destroy() + + assertSameMap(t, entrySet(deletedEntries), map[mapEntry]struct{}{ + {key: 8, value: 8}: {}, + }) + + validateRef(t, m1) + m1.set(t, 1, 1) + validateRef(t, m1) + m1.set(t, 2, 2) + validateRef(t, m1) + m1.set(t, 3, 3) + validateRef(t, m1) + m1.remove(t, 2) + validateRef(t, m1) + m1.set(t, 6, 6) + validateRef(t, m1) + + assertSameMap(t, entrySet(deletedEntries), map[mapEntry]struct{}{ + {key: 2, value: 2}: {}, + {key: 8, value: 8}: {}, + }) + + m2 := m1.clone() + validateRef(t, m1, m2) + m1.set(t, 6, 60) + validateRef(t, m1, m2) + m1.remove(t, 1) + validateRef(t, m1, m2) + + gotAllocs := int(testing.AllocsPerRun(10, func() { + m1.impl.Delete(100) + m1.impl.Delete(1) + })) + wantAllocs := 0 + if gotAllocs != wantAllocs { + t.Errorf("wanted %d allocs, got %d", wantAllocs, gotAllocs) + } + + for i := 10; i < 14; i++ { + m1.set(t, i, i) + validateRef(t, m1, m2) + } + + m1.set(t, 10, 100) + validateRef(t, m1, m2) + + m1.remove(t, 12) + validateRef(t, m1, m2) + + m2.set(t, 4, 4) + validateRef(t, m1, m2) + m2.set(t, 5, 5) + validateRef(t, m1, m2) + + m1.destroy() + + assertSameMap(t, entrySet(deletedEntries), map[mapEntry]struct{}{ + {key: 2, value: 2}: {}, + {key: 6, value: 60}: {}, + {key: 8, value: 8}: {}, + {key: 10, value: 10}: {}, + {key: 10, value: 100}: {}, + {key: 11, value: 11}: {}, + {key: 12, value: 12}: {}, + {key: 13, value: 13}: {}, + }) + + m2.set(t, 7, 7) + validateRef(t, m2) + + m2.destroy() + + assertSameMap(t, entrySet(seenEntries), entrySet(deletedEntries)) +} + +func TestRandomMap(t *testing.T) { + deletedEntries := make(map[mapEntry]int) + seenEntries := make(map[mapEntry]int) + + m := &validatedMap{ + impl: new(Map[int, int]), + expected: make(map[int]int), + deleted: deletedEntries, + seen: seenEntries, + } + + keys := make([]int, 0, 1000) + for i := 0; i < 1000; i++ { + key := rand.Intn(10000) + m.set(t, key, key) + keys = append(keys, key) + + if i%10 == 1 { + index := rand.Intn(len(keys)) + last := len(keys) - 1 + key = keys[index] + keys[index], keys[last] = keys[last], keys[index] + keys = keys[:last] + + m.remove(t, key) + } + } + + m.destroy() + assertSameMap(t, entrySet(seenEntries), entrySet(deletedEntries)) +} + +func entrySet(m map[mapEntry]int) map[mapEntry]struct{} { + set := make(map[mapEntry]struct{}) + for k := range m { + set[k] = struct{}{} + } + return set +} + +func TestUpdate(t *testing.T) { + deletedEntries := make(map[mapEntry]int) + seenEntries := make(map[mapEntry]int) + + m1 := &validatedMap{ + impl: new(Map[int, int]), + expected: make(map[int]int), + deleted: deletedEntries, + seen: seenEntries, + } + m2 := m1.clone() + + m1.set(t, 1, 1) + m1.set(t, 2, 2) + m2.set(t, 2, 20) + m2.set(t, 3, 3) + m1.setAll(t, m2) + + m1.destroy() + m2.destroy() + assertSameMap(t, entrySet(seenEntries), entrySet(deletedEntries)) +} + +func validateRef(t *testing.T, maps ...*validatedMap) { + t.Helper() + + actualCountByEntry := make(map[mapEntry]int32) + nodesByEntry := make(map[mapEntry]map[*mapNode]struct{}) + expectedCountByEntry := make(map[mapEntry]int32) + for i, m := range maps { + dfsRef(m.impl.root, actualCountByEntry, nodesByEntry) + dumpMap(t, fmt.Sprintf("%d:", i), m.impl.root) + } + for entry, nodes := range nodesByEntry { + expectedCountByEntry[entry] = int32(len(nodes)) + } + assertSameMap(t, expectedCountByEntry, actualCountByEntry) +} + +func dfsRef(node *mapNode, countByEntry map[mapEntry]int32, nodesByEntry map[mapEntry]map[*mapNode]struct{}) { + if node == nil { + return + } + + entry := mapEntry{key: node.key.(int), value: node.value.value.(int)} + countByEntry[entry] = atomic.LoadInt32(&node.value.refCount) + + nodes, ok := nodesByEntry[entry] + if !ok { + nodes = make(map[*mapNode]struct{}) + nodesByEntry[entry] = nodes + } + nodes[node] = struct{}{} + + dfsRef(node.left, countByEntry, nodesByEntry) + dfsRef(node.right, countByEntry, nodesByEntry) +} + +func dumpMap(t *testing.T, prefix string, n *mapNode) { + if n == nil { + t.Logf("%s nil", prefix) + return + } + t.Logf("%s {key: %v, value: %v (ref: %v), ref: %v, weight: %v}", prefix, n.key, n.value.value, n.value.refCount, n.refCount, n.weight) + dumpMap(t, prefix+"l", n.left) + dumpMap(t, prefix+"r", n.right) +} + +func (vm *validatedMap) validate(t *testing.T) { + t.Helper() + + validateNode(t, vm.impl.root) + + // Note: this validation may not make sense if maps were constructed using + // SetAll operations. If this proves to be problematic, remove the clock, + // deleted, and seen fields. + for key, value := range vm.expected { + entry := mapEntry{key: key, value: value} + if deleteAt := vm.deleted[entry]; deleteAt > vm.seen[entry] { + t.Fatalf("entry is deleted prematurely, key: %d, value: %d", key, value) + } + } + + actualMap := make(map[int]int, len(vm.expected)) + vm.impl.Range(func(key, value int) { + if other, ok := actualMap[key]; ok { + t.Fatalf("key is present twice, key: %d, first value: %d, second value: %d", key, value, other) + } + actualMap[key] = value + }) + + assertSameMap(t, actualMap, vm.expected) +} + +func validateNode(t *testing.T, node *mapNode) { + if node == nil { + return + } + + if node.left != nil { + if node.key.(int) < node.left.key.(int) { + t.Fatalf("left child has larger key: %v vs %v", node.left.key, node.key) + } + if node.left.weight > node.weight { + t.Fatalf("left child has larger weight: %v vs %v", node.left.weight, node.weight) + } + } + + if node.right != nil { + if node.right.key.(int) < node.key.(int) { + t.Fatalf("right child has smaller key: %v vs %v", node.right.key, node.key) + } + if node.right.weight > node.weight { + t.Fatalf("right child has larger weight: %v vs %v", node.right.weight, node.weight) + } + } + + validateNode(t, node.left) + validateNode(t, node.right) +} + +func (vm *validatedMap) setAll(t *testing.T, other *validatedMap) { + vm.impl.SetAll(other.impl) + + // Note: this is buggy because we are not updating vm.clock, vm.deleted, or + // vm.seen. + for key, value := range other.expected { + vm.expected[key] = value + } + vm.validate(t) +} + +func (vm *validatedMap) set(t *testing.T, key, value int) { + entry := mapEntry{key: key, value: value} + + vm.clock++ + vm.seen[entry] = vm.clock + + vm.impl.SetWithRelease(key, value, func(deletedKey, deletedValue any) { + if deletedKey != key || deletedValue != value { + t.Fatalf("unexpected passed in deleted entry: %v/%v, expected: %v/%v", deletedKey, deletedValue, key, value) + } + // Not safe if closure shared between two validatedMaps. + vm.deleted[entry] = vm.clock + }) + vm.expected[key] = value + vm.validate(t) + + gotValue, ok := vm.impl.Get(key) + if !ok || gotValue != value { + t.Fatalf("unexpected get result after insertion, key: %v, expected: %v, got: %v (%v)", key, value, gotValue, ok) + } +} + +func (vm *validatedMap) remove(t *testing.T, key int) { + vm.clock++ + deleted := vm.impl.Delete(key) + if _, ok := vm.expected[key]; ok != deleted { + t.Fatalf("Delete(%d) = %t, want %t", key, deleted, ok) + } + delete(vm.expected, key) + vm.validate(t) + + gotValue, ok := vm.impl.Get(key) + if ok { + t.Fatalf("unexpected get result after removal, key: %v, got: %v", key, gotValue) + } +} + +func (vm *validatedMap) clone() *validatedMap { + expected := make(map[int]int, len(vm.expected)) + for key, value := range vm.expected { + expected[key] = value + } + + return &validatedMap{ + impl: vm.impl.Clone(), + expected: expected, + deleted: vm.deleted, + seen: vm.seen, + } +} + +func (vm *validatedMap) destroy() { + vm.impl.Destroy() +} + +func assertSameMap(t *testing.T, map1, map2 any) { + t.Helper() + + if !reflect.DeepEqual(map1, map2) { + t.Fatalf("different maps:\n%v\nvs\n%v", map1, map2) + } +} diff --git a/immutable/set.go b/immutable/set.go new file mode 100644 index 0000000..55add52 --- /dev/null +++ b/immutable/set.go @@ -0,0 +1,78 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package immutable + +import "cmp" + +// Set is a collection of elements of type K. +// +// It uses immutable data structures internally, so that sets can be cloned in +// constant time. +// +// The zero value is a valid empty set. +type Set[K cmp.Ordered] struct { + impl *Map[K, struct{}] +} + +// Clone creates a copy of the receiver. +func (s *Set[K]) Clone() *Set[K] { + clone := new(Set[K]) + if s.impl != nil { + clone.impl = s.impl.Clone() + } + return clone +} + +// Destroy destroys the set. +// +// After Destroy, the Set should not be used again. +func (s *Set[K]) Destroy() { + if s.impl != nil { + s.impl.Destroy() + } +} + +// Contains reports whether s contains the given key. +func (s *Set[K]) Contains(key K) bool { + if s.impl == nil { + return false + } + _, ok := s.impl.Get(key) + return ok +} + +// Range calls f sequentially in ascending key order for all entries in the set. +func (s *Set[K]) Range(f func(key K)) { + if s.impl != nil { + s.impl.Range(func(key K, _ struct{}) { + f(key) + }) + } +} + +// AddAll adds all elements from other to the receiver set. +func (s *Set[K]) AddAll(other *Set[K]) { + if other.impl != nil { + if s.impl == nil { + s.impl = new(Map[K, struct{}]) + } + s.impl.SetAll(other.impl) + } +} + +// Add adds an element to the set. +func (s *Set[K]) Add(key K) { + if s.impl == nil { + s.impl = new(Map[K, struct{}]) + } + s.impl.Set(key, struct{}{}) +} + +// Remove removes an element from the set. +func (s *Set[K]) Remove(key K) { + if s.impl != nil { + s.impl.Delete(key) + } +} diff --git a/immutable/set_test.go b/immutable/set_test.go new file mode 100644 index 0000000..eb041cf --- /dev/null +++ b/immutable/set_test.go @@ -0,0 +1,132 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package immutable_test + +import ( + "fmt" + "strings" + "testing" + + "github.com/cilium/statedb/immutable" + "golang.org/x/exp/constraints" +) + +func TestSet(t *testing.T) { + const ( + add = iota + remove + ) + type op struct { + op int + v int + } + + tests := []struct { + label string + ops []op + want []int + }{ + {"empty", nil, nil}, + {"singleton", []op{{add, 1}}, []int{1}}, + {"add and remove", []op{ + {add, 1}, + {remove, 1}, + }, nil}, + {"interleaved and remove", []op{ + {add, 1}, + {add, 2}, + {remove, 1}, + {add, 3}, + }, []int{2, 3}}, + } + + for _, test := range tests { + t.Run(test.label, func(t *testing.T) { + var s immutable.Set[int] + for _, op := range test.ops { + switch op.op { + case add: + s.Add(op.v) + case remove: + s.Remove(op.v) + } + } + + if d := diff(&s, test.want); d != "" { + t.Errorf("unexpected diff:\n%s", d) + } + }) + } +} + +func TestSet_Clone(t *testing.T) { + s1 := new(immutable.Set[int]) + s1.Add(1) + s1.Add(2) + s2 := s1.Clone() + s1.Add(3) + s2.Add(4) + if d := diff(s1, []int{1, 2, 3}); d != "" { + t.Errorf("s1: unexpected diff:\n%s", d) + } + if d := diff(s2, []int{1, 2, 4}); d != "" { + t.Errorf("s2: unexpected diff:\n%s", d) + } +} + +func TestSet_AddAll(t *testing.T) { + s1 := new(immutable.Set[int]) + s1.Add(1) + s1.Add(2) + s2 := new(immutable.Set[int]) + s2.Add(2) + s2.Add(3) + s2.Add(4) + s3 := new(immutable.Set[int]) + + s := new(immutable.Set[int]) + s.AddAll(s1) + s.AddAll(s2) + s.AddAll(s3) + + if d := diff(s1, []int{1, 2}); d != "" { + t.Errorf("s1: unexpected diff:\n%s", d) + } + if d := diff(s2, []int{2, 3, 4}); d != "" { + t.Errorf("s2: unexpected diff:\n%s", d) + } + if d := diff(s3, nil); d != "" { + t.Errorf("s3: unexpected diff:\n%s", d) + } + if d := diff(s, []int{1, 2, 3, 4}); d != "" { + t.Errorf("s: unexpected diff:\n%s", d) + } +} + +func diff[K constraints.Ordered](got *immutable.Set[K], want []K) string { + wantSet := make(map[K]struct{}) + for _, w := range want { + wantSet[w] = struct{}{} + } + var diff []string + got.Range(func(key K) { + if _, ok := wantSet[key]; !ok { + diff = append(diff, fmt.Sprintf("+%v", key)) + } + }) + for key := range wantSet { + if !got.Contains(key) { + diff = append(diff, fmt.Sprintf("-%v", key)) + } + } + if len(diff) > 0 { + d := new(strings.Builder) + for _, l := range diff { + fmt.Fprintln(d, l) + } + return d.String() + } + return "" +}