diff --git a/vms/platformvm/api/height.go b/vms/platformvm/api/height.go index 0a4a2660dcc4..68c82b4f1cae 100644 --- a/vms/platformvm/api/height.go +++ b/vms/platformvm/api/height.go @@ -7,13 +7,13 @@ import ( "errors" "math" - avajson "github.com/ava-labs/avalanchego/utils/json" + "github.com/ava-labs/avalanchego/utils/json" ) -type Height avajson.Uint64 +type Height json.Uint64 const ( - ProposedHeightFlag = "proposed" + ProposedHeightJSON = `"proposed"` ProposedHeight = math.MaxUint64 ) @@ -21,22 +21,28 @@ var errInvalidHeight = errors.New("invalid height") func (h Height) MarshalJSON() ([]byte, error) { if h == ProposedHeight { - return []byte(`"` + ProposedHeightFlag + `"`), nil + return []byte(ProposedHeightJSON), nil } - return avajson.Uint64(h).MarshalJSON() + return json.Uint64(h).MarshalJSON() } func (h *Height) UnmarshalJSON(b []byte) error { - if string(b) == ProposedHeightFlag { + // First check for known string values + switch string(b) { + case json.Null: + return nil + case ProposedHeightJSON: *h = ProposedHeight return nil } - err := (*avajson.Uint64)(h).UnmarshalJSON(b) - if err != nil { + + // Otherwise, unmarshal as a uint64 + if err := (*json.Uint64)(h).UnmarshalJSON(b); err != nil { return errInvalidHeight } - // MaxUint64 is reserved for proposed height - // return an error if supplied numerically + + // MaxUint64 is reserved for proposed height, so return an error if supplied + // numerically. if uint64(*h) == ProposedHeight { *h = 0 return errInvalidHeight diff --git a/vms/platformvm/api/height_test.go b/vms/platformvm/api/height_test.go index c20a9a5a1efc..b69982954f6d 100644 --- a/vms/platformvm/api/height_test.go +++ b/vms/platformvm/api/height_test.go @@ -4,42 +4,102 @@ package api import ( - "strconv" "testing" "github.com/stretchr/testify/require" ) -func TestMarshallHeight(t *testing.T) { - require := require.New(t) - h := Height(56) - bytes, err := h.MarshalJSON() - require.NoError(err) - require.Equal(`"56"`, string(bytes)) +func TestHeightMarshalJSON(t *testing.T) { + tests := []struct { + name string + height Height + expected string + }{ + { + name: "0", + height: 0, + expected: `"0"`, + }, + { + name: "56", + height: 56, + expected: `"56"`, + }, + { + name: "proposed", + height: ProposedHeight, + expected: `"proposed"`, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) - h = Height(ProposedHeight) - bytes, err = h.MarshalJSON() - require.NoError(err) - require.Equal(`"proposed"`, string(bytes)) + bytes, err := test.height.MarshalJSON() + require.NoError(err) + require.Equal(test.expected, string(bytes)) + }) + } } -func TestUnmarshallHeight(t *testing.T) { - require := require.New(t) +func TestHeightUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + initial Height + json string + expected Height + expectedErr error + }{ + { + name: "null 56", + initial: 56, + json: "null", + expected: 56, + }, + { + name: "null proposed", + initial: ProposedHeight, + json: "null", + expected: ProposedHeight, + }, + { + name: "proposed", + json: `"proposed"`, + expected: ProposedHeight, + }, + { + name: "not a number", + json: `"not a number"`, + expectedErr: errInvalidHeight, + }, + { + name: "56", + json: `56`, + expected: 56, + }, + { + name: `"56"`, + json: `"56"`, + expected: 56, + }, + { + name: "max uint64", + json: "18446744073709551615", + expectedErr: errInvalidHeight, + }, + { + name: `"max uint64"`, + json: `"18446744073709551615"`, + expectedErr: errInvalidHeight, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) - var h Height - - require.NoError(h.UnmarshalJSON([]byte("56"))) - require.Equal(Height(56), h) - - require.NoError(h.UnmarshalJSON([]byte(ProposedHeightFlag))) - require.Equal(Height(ProposedHeight), h) - require.True(h.IsProposed()) - - err := h.UnmarshalJSON([]byte("invalid")) - require.ErrorIs(err, errInvalidHeight) - require.Equal(Height(0), h) - - err = h.UnmarshalJSON([]byte(`"` + strconv.FormatUint(uint64(ProposedHeight), 10) + `"`)) - require.ErrorIs(err, errInvalidHeight) - require.Equal(Height(0), h) + err := test.initial.UnmarshalJSON([]byte(test.json)) + require.ErrorIs(err, test.expectedErr) + require.Equal(test.expected, test.initial) + }) + } } diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index 19db953251b7..361623e27e78 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -906,12 +906,55 @@ func TestGetValidatorsAt(t *testing.T) { service.vm.clock.Set(newLastAcceptedBlk.Timestamp().Add(40 * time.Second)) service.vm.ctx.Lock.Unlock() - // Resending the same request with [IsProposed] set to true should now + // Resending the same request with [Height] set to [platformapi.ProposedHeight] should now // include the new validator require.NoError(service.GetValidatorsAt(&http.Request{}, &args, &response)) require.Len(response.Validators, len(genesis.Validators)+1) } +func TestGetValidatorsAtArgsMarshalling(t *testing.T) { + subnetID, err := ids.FromString("u3Jjpzzj95827jdENvR1uc76f4zvvVQjGshbVWaSr2Ce5WV1H") + require.NoError(t, err) + + tests := []struct { + name string + args GetValidatorsAtArgs + json string + }{ + { + name: "specific height", + args: GetValidatorsAtArgs{ + Height: pchainapi.Height(12345), + SubnetID: subnetID, + }, + json: `{"height":"12345","subnetID":"u3Jjpzzj95827jdENvR1uc76f4zvvVQjGshbVWaSr2Ce5WV1H"}`, + }, + { + name: "proposed height", + args: GetValidatorsAtArgs{ + Height: pchainapi.ProposedHeight, + SubnetID: subnetID, + }, + json: `{"height":"proposed","subnetID":"u3Jjpzzj95827jdENvR1uc76f4zvvVQjGshbVWaSr2Ce5WV1H"}`, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + // Test that marshalling produces the expected JSON + argsJSON, err := json.Marshal(test.args) + require.NoError(err) + require.JSONEq(test.json, string(argsJSON)) + + // Test that unmarshalling produces the expected args + var parsedArgs GetValidatorsAtArgs + require.NoError(json.Unmarshal(argsJSON, &parsedArgs)) + require.Equal(test.args, parsedArgs) + }) + } +} + func TestGetTimestamp(t *testing.T) { require := require.New(t) service, _ := defaultService(t, upgradetest.Latest)