diff --git a/state/keys.go b/state/keys.go index 13cb8d934c..6772a104a0 100644 --- a/state/keys.go +++ b/state/keys.go @@ -6,11 +6,16 @@ package state import ( "encoding/hex" "encoding/json" + "errors" "fmt" + "strings" + "github.com/ava-labs/hypersdk/codec" "github.com/ava-labs/hypersdk/keys" ) +var errInvalidHexadecimalString = errors.New("invalid hexadecimal string") + const ( Read Permissions = 1 Allocate = 1<<1 | Read @@ -56,40 +61,66 @@ func (k Keys) ChunkSizes() ([]uint16, bool) { return chunks, true } -type permsJSON []string - -type keysJSON struct { - Perms [8]permsJSON -} +type keysJSON map[string]Permissions +// MarshalJSON marshals Keys as readable JSON. +// Keys are hex encoded strings and permissions +// are either valid named strings or unknown hex encoded strings. func (k Keys) MarshalJSON() ([]byte, error) { - var keysJSON keysJSON + kJSON := make(keysJSON) for key, perm := range k { - keysJSON.Perms[perm] = append(keysJSON.Perms[perm], hex.EncodeToString([]byte(key))) + hexKey, err := codec.Bytes(key).MarshalText() + if err != nil { + return nil, err + } + kJSON[string(hexKey)] = perm } - return json.Marshal(keysJSON) + return json.Marshal(kJSON) } +// UnmarshalJSON unmarshals readable JSON. func (k *Keys) UnmarshalJSON(b []byte) error { var keysJSON keysJSON if err := json.Unmarshal(b, &keysJSON); err != nil { return err } - for perm, keyList := range keysJSON.Perms { - if perm < int(None) || perm > int(All) { - return fmt.Errorf("invalid permission encoded in json %d", perm) + for hexKey, perm := range keysJSON { + var key codec.Bytes + err := key.UnmarshalText([]byte(hexKey)) + if err != nil { + return err } - for _, encodedKey := range keyList { - key, err := hex.DecodeString(encodedKey) - if err != nil { - return err - } - (*k)[string(key)] = Permissions(perm) + (*k)[string(key)] = perm + } + return nil +} + +func (p *Permissions) UnmarshalText(in []byte) error { + switch str := strings.ToLower(string(in)); str { + case "read": + *p = Read + case "write": + *p = Write + case "allocate": + *p = Allocate + case "all": + *p = All + case "none": + *p = None + default: + res, err := hex.DecodeString(str) + if err != nil || len(res) != 1 { + return fmt.Errorf("permission %s: %w", str, errInvalidHexadecimalString) } + *p = Permissions(res[0]) } return nil } +func (p Permissions) MarshalText() ([]byte, error) { + return []byte(p.String()), nil +} + // Has returns true if [p] has all the permissions that are contained in require func (p Permissions) Has(require Permissions) bool { return require&^p == 0 @@ -108,6 +139,6 @@ func (p Permissions) String() string { case None: return "none" default: - return "unknown" + return hex.EncodeToString([]byte{byte(p)}) } } diff --git a/state/keys_test.go b/state/keys_test.go index 8e819f836a..16094bceff 100644 --- a/state/keys_test.go +++ b/state/keys_test.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "math/rand" "slices" + "strconv" "testing" "github.com/stretchr/testify/require" @@ -135,7 +136,7 @@ func TestKeysMarshalingSimple(t *testing.T) { require.True(keys.Add("key1", Read)) bytes, err := keys.MarshalJSON() require.NoError(err) - require.Equal([]byte{0x7b, 0x22, 0x50, 0x65, 0x72, 0x6d, 0x73, 0x22, 0x3a, 0x5b, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x5b, 0x22, 0x36, 0x62, 0x36, 0x35, 0x37, 0x39, 0x33, 0x31, 0x22, 0x5d, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x5d, 0x7d}, bytes) + require.Equal(`{"6b657931":"read"}`, string(bytes)) keys = Keys{} require.NoError(keys.UnmarshalJSON(bytes)) require.Len(keys, 1) @@ -146,7 +147,7 @@ func TestKeysMarshalingSimple(t *testing.T) { require.True(keys.Add("key2", Read|Write)) bytes, err = keys.MarshalJSON() require.NoError(err) - require.Equal([]byte{0x7b, 0x22, 0x50, 0x65, 0x72, 0x6d, 0x73, 0x22, 0x3a, 0x5b, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x5b, 0x22, 0x36, 0x62, 0x36, 0x35, 0x37, 0x39, 0x33, 0x32, 0x22, 0x5d, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x2c, 0x6e, 0x75, 0x6c, 0x6c, 0x5d, 0x7d}, bytes) + require.Equal(`{"6b657932":"write"}`, string(bytes)) keys = Keys{} require.NoError(keys.UnmarshalJSON(bytes)) require.Len(keys, 1) @@ -181,3 +182,64 @@ func TestKeysMarshalingFuzz(t *testing.T) { require.True(keys.compare(decodedKeys)) } } + +func TestNewPermissionFromString(t *testing.T) { + tests := []struct { + strPerm string + perm Permissions + expectedErr error + }{ + { + strPerm: "read", + perm: Read, + }, + { + strPerm: "write", + perm: Write, + }, + { + strPerm: "allocate", + perm: Allocate, + }, + { + strPerm: "all", + perm: All, + }, + { + strPerm: "none", + perm: None, + }, + { + strPerm: "09", + perm: Permissions(9), + }, + { + strPerm: "010A", + expectedErr: errInvalidHexadecimalString, + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + require := require.New(t) + var perm Permissions + err := perm.UnmarshalText([]byte(test.strPerm)) + if test.expectedErr != nil { + require.ErrorIs(err, test.expectedErr) + } else { + require.NoError(err) + require.Equal(test.perm, perm) + } + }) + } +} + +func TestPermissionStringer(t *testing.T) { + require := require.New(t) + require.Equal("read", Read.String()) + require.Equal("write", Write.String()) + require.Equal("allocate", Allocate.String()) + require.Equal("all", All.String()) + require.Equal("none", None.String()) + require.Equal("09", Permissions(9).String()) +}