diff --git a/sets.go b/sets.go index 0fe550c..0d22436 100644 --- a/sets.go +++ b/sets.go @@ -4,22 +4,34 @@ type Set[T comparable] struct { m *Map[T, struct{}] } -func NewSet[T comparable](hasher Hasher[T]) Set[T] { - return Set[T]{ +func NewSet[T comparable](hasher Hasher[T], values ...T) Set[T] { + s := Set[T]{ m: NewMap[T, struct{}](hasher), } + for _, value := range values { + s.m.set(value, struct{}{}, true) + } + return s } -func (s Set[T]) Set(val T) Set[T] { - return Set[T]{ - m: s.m.Set(val, struct{}{}), +func (s Set[T]) Set(values ...T) Set[T] { + n := Set[T]{ + m: s.m.clone(), + } + for _, value := range values { + n.m.set(value, struct{}{}, true) } + return n } -func (s Set[T]) Delete(val T) Set[T] { - return Set[T]{ - m: s.m.Delete(val), +func (s Set[T]) Delete(values ...T) Set[T] { + n := Set[T]{ + m: s.m.clone(), + } + for _, value := range values { + n.m.delete(value, true) } + return n } func (s Set[T]) Has(val T) bool { @@ -82,22 +94,34 @@ type SortedSet[T comparable] struct { m *SortedMap[T, struct{}] } -func NewSortedSet[T comparable](comparer Comparer[T]) SortedSet[T] { - return SortedSet[T]{ +func NewSortedSet[T comparable](comparer Comparer[T], values ...T) SortedSet[T] { + s := SortedSet[T]{ m: NewSortedMap[T, struct{}](comparer), } + for _, value := range values { + s.m.set(value, struct{}{}, true) + } + return s } -func (s SortedSet[T]) Put(val T) SortedSet[T] { - return SortedSet[T]{ - m: s.m.Set(val, struct{}{}), +func (s SortedSet[T]) Set(values ...T) SortedSet[T] { + n := SortedSet[T]{ + m: s.m.clone(), + } + for _, value := range values { + n.m.set(value, struct{}{}, true) } + return n } -func (s SortedSet[T]) Delete(val T) SortedSet[T] { - return SortedSet[T]{ - m: s.m.Delete(val), +func (s SortedSet[T]) Delete(values ...T) SortedSet[T] { + n := SortedSet[T]{ + m: s.m.clone(), + } + for _, value := range values { + n.m.delete(value, true) } + return n } func (s SortedSet[T]) Has(val T) bool { diff --git a/sets_test.go b/sets_test.go index d2e7f32..b3900fc 100644 --- a/sets_test.go +++ b/sets_test.go @@ -51,7 +51,7 @@ func TestSetsDelete(t *testing.T) { func TestSortedSetsPut(t *testing.T) { s := NewSortedSet[string](nil) - s2 := s.Put("1").Put("1").Put("0") + s2 := s.Set("1").Set("1").Set("0") if s.Len() != 0 { t.Fatalf("Unexpected mutation of set") } @@ -85,7 +85,7 @@ func TestSortedSetsPut(t *testing.T) { func TestSortedSetsDelete(t *testing.T) { s := NewSortedSet[string](nil) - s2 := s.Put("1") + s2 := s.Set("1") s3 := s.Delete("1") if s2.Len() != 1 { t.Fatalf("Unexpected non-mutation of set")