Skip to content

Commit

Permalink
part: Make Map and Set JSON marshable
Browse files Browse the repository at this point in the history
Implement MarshalJSON and UnmarshalJSON for Map and Set.

The Map is marshalled into a JSON array of key-value pairs
(it cannot be an object since the key type may not be a valid
JSON object key). The Set type is marshalled into a JSON
array.

Signed-off-by: Jussi Maki <[email protected]>
  • Loading branch information
joamaki committed May 23, 2024
1 parent 9a366e0 commit 0222c76
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 35 deletions.
179 changes: 149 additions & 30 deletions part/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,26 @@

package part

// MapIterator iterates over key and value pairs.
type MapIterator[K, V any] struct {
iter *Iterator[V]
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
)

// Map of key-value pairs. The zero value is ready for use, provided
// that the key type has been registered with RegisterKeyType.
//
// Map is a typed wrapper around Tree[T] for working with
// keys that are not []byte.
type Map[K, V any] struct {
bytesFromKey func(K) []byte
tree *Tree[mapKVPair[K, V]]
}

// Next returns the next key (as bytes) and value. If the iterator
// is exhausted it returns false.
func (it MapIterator[K, V]) Next() (k []byte, v V, ok bool) {
if it.iter == nil {
return
}
k, v, ok = it.iter.Next()
return
type mapKVPair[K, V any] struct {
Key K `json:"k"`
Value V `json:"v"`
}

// FromMap copies values from the hash map into the given Map.
Expand All @@ -25,32 +32,18 @@ func FromMap[K comparable, V any](m Map[K, V], hm map[K]V) Map[K, V] {
m.ensureTree()
txn := m.tree.Txn()
for k, v := range hm {
txn.Insert(m.bytesFromKey(k), v)
txn.Insert(m.bytesFromKey(k), mapKVPair[K, V]{k, v})
}
m.tree = txn.CommitOnly()
return m
}

// Map of key-value pairs. The zero value is ready for use, provided
// that the key type has been registered with RegisterKeyType.
//
// Map is a typed wrapper around Tree[T] for working with
// keys that are not []byte.
//
// The iteration over the map returns the keys in the []byte form as
// the conversion from bytes to K is not always a desired nor necessary
// operation.
type Map[K, V any] struct {
bytesFromKey func(K) []byte
tree *Tree[V]
}

// ensureTree checks that the tree is not nil and allocates it if
// it is. The whole nil tree thing is to make sure that creating
// an empty map does not allocate anything.
func (m *Map[K, V]) ensureTree() {
if m.tree == nil {
m.tree = New[V](RootOnlyWatch)
m.tree = New[mapKVPair[K, V]](RootOnlyWatch)
}
m.bytesFromKey = lookupKeyType[K]()
}
Expand All @@ -60,16 +53,16 @@ func (m Map[K, V]) Get(key K) (value V, found bool) {
if m.tree == nil {
return
}
value, _, found = m.tree.Get(m.bytesFromKey(key))
return
kv, _, found := m.tree.Get(m.bytesFromKey(key))
return kv.Value, found
}

// Set a value. Returns a new map with the value set.
// Original map is unchanged.
func (m Map[K, V]) Set(key K, value V) Map[K, V] {
m.ensureTree()
txn := m.tree.Txn()
txn.Insert(m.bytesFromKey(key), value)
txn.Insert(m.bytesFromKey(key), mapKVPair[K, V]{key, value})
m.tree = txn.CommitOnly()
return m
}
Expand All @@ -86,6 +79,21 @@ func (m Map[K, V]) Delete(key K) Map[K, V] {
return m
}

// MapIterator iterates over key and value pairs.
type MapIterator[K, V any] struct {
iter *Iterator[mapKVPair[K, V]]
}

// Next returns the next key (as bytes) and value. If the iterator
// is exhausted it returns false.
func (it MapIterator[K, V]) Next() (k K, v V, ok bool) {
if it.iter == nil {
return
}
_, kv, ok := it.iter.Next()
return kv.Key, kv.Value, ok
}

// LowerBound iterates over all keys in order with value equal
// to or greater than [from].
func (m Map[K, V]) LowerBound(from K) MapIterator[K, V] {
Expand Down Expand Up @@ -121,10 +129,121 @@ func (m Map[K, V]) All() MapIterator[K, V] {
}
}

// EqualKeys returns true if both maps contain the same keys.
func (m Map[K, V]) EqualKeys(other Map[K, V]) bool {
switch {
case m.tree == nil:
return other.tree == nil
case other.tree == nil:
return m.tree == nil
case m.Len() != other.Len():
return false
default:
iter1 := m.tree.Iterator()
iter2 := other.tree.Iterator()
for {
k1, _, ok := iter1.Next()
if !ok {
break
}
k2, _, _ := iter2.Next()
// Equal lengths, no need to check 'ok' for 'iter2'.
if !bytes.Equal(k1, k2) {
return false
}
}
return true
}
}

// SlowEqual returns true if the two maps contain the same keys and values.
// Value comparison is implemented with reflect.DeepEqual which makes this
// slow and mostly useful for testing.
func (m Map[K, V]) SlowEqual(other Map[K, V]) bool {
switch {
case m.tree == nil:
return other.tree == nil
case other.tree == nil:
return m.tree == nil
case m.Len() != other.Len():
return false
default:
iter1 := m.tree.Iterator()
iter2 := other.tree.Iterator()
for {
k1, v1, ok := iter1.Next()
if !ok {
break
}
k2, v2, _ := iter2.Next()
// Equal lengths, no need to check 'ok' for 'iter2'.
if !bytes.Equal(k1, k2) || !reflect.DeepEqual(v1, v2) {
return false
}
}
return true
}
}

// Len returns the number of elements in the map.
func (m Map[K, V]) Len() int {
if m.tree == nil {
return 0
}
return m.tree.size
}

func (m Map[K, V]) MarshalJSON() ([]byte, error) {
if m.tree == nil {
return []byte("[]"), nil
}

var b bytes.Buffer
b.WriteRune('[')
iter := m.tree.Iterator()
_, kv, ok := iter.Next()
for ok {
bs, err := json.Marshal(kv)
if err != nil {
return nil, err
}
b.Write(bs)
_, kv, ok = iter.Next()
if ok {
b.WriteRune(',')
}
}
b.WriteRune(']')
return b.Bytes(), nil
}

func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
dec := json.NewDecoder(bytes.NewReader(data))
t, err := dec.Token()
if err != nil {
return err
}
if d, ok := t.(json.Delim); !ok || d != '[' {
return fmt.Errorf("%T.UnmarshalJSON: expected '[' got %v", m, t)
}
m.ensureTree()
txn := m.tree.Txn()
for dec.More() {
var kv mapKVPair[K, V]
err := dec.Decode(&kv)
if err != nil {
return err
}
txn.Insert(m.bytesFromKey(kv.Key), mapKVPair[K, V]{kv.Key, kv.Value})
}

t, err = dec.Token()
if err != nil {
return err
}
if d, ok := t.(json.Delim); !ok || d != ']' {
return fmt.Errorf("%T.UnmarshalJSON: expected ']' got %v", m, t)
}
m.tree = txn.CommitOnly()
return nil
}
24 changes: 19 additions & 5 deletions part/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
package part_test

import (
"encoding/binary"
"encoding/json"
"math/rand/v2"
"testing"

"github.com/cilium/statedb/part"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestStringMap(t *testing.T) {
Expand All @@ -27,7 +28,7 @@ func TestStringMap(t *testing.T) {
t.Helper()
k, v, ok := iter.Next()
assert.False(t, ok, "expected empty iterator")
assert.Nil(t, k, "empty key")
assert.Empty(t, k, "empty key")
assert.Equal(t, 0, v)
}
assertIterEmpty(m.LowerBound(""))
Expand Down Expand Up @@ -119,12 +120,12 @@ func TestUint64Map(t *testing.T) {
iter := m.LowerBound(55)
k, v, ok := iter.Next()
assert.True(t, ok, "Next")
assert.EqualValues(t, 55, binary.BigEndian.Uint64(k))
assert.EqualValues(t, 55, k)
assert.EqualValues(t, 55, v)

k, v, ok = iter.Next()
assert.True(t, ok, "Next")
assert.EqualValues(t, 72, binary.BigEndian.Uint64(k))
assert.EqualValues(t, 72, k)
assert.EqualValues(t, 72, v)

_, _, ok = iter.Next()
Expand All @@ -147,14 +148,27 @@ func TestRegisterKeyType(t *testing.T) {
iter := m.All()
k, v, ok := iter.Next()
assert.True(t, ok, "Next")
assert.Equal(t, "hello", string(k))
assert.Equal(t, testKey{"hello"}, k)
assert.Equal(t, 123, v)

_, _, ok = iter.Next()
assert.False(t, ok, "Next")

}

func TestMapJSON(t *testing.T) {
var m part.Map[string, int]
m = m.Set("foo", 1).Set("bar", 2).Set("baz", 3)

bs, err := json.Marshal(m)
require.NoError(t, err, "Marshal")

var m2 part.Map[string, int]
err = json.Unmarshal(bs, &m2)
require.NoError(t, err, "Unmarshal")
require.True(t, m.SlowEqual(m2), "SlowEqual")
}

func Benchmark_Uint64Map_Random(b *testing.B) {
numItems := 1000
keys := map[uint64]int{}
Expand Down
Loading

0 comments on commit 0222c76

Please sign in to comment.