From 0b24e90b8499ee98600d6254ed4bbf5ca5c3bef4 Mon Sep 17 00:00:00 2001 From: Tim Deeb-Swihart Date: Tue, 28 Nov 2023 12:49:27 -0800 Subject: [PATCH 1/3] Fix handling of binary/null content type This was completely broken --- common/v1/payload_json.go | 26 ++++++++++++++++-- go.mod | 1 + go.sum | 1 + internal/temporalcommonv1/payload_json.go | 26 ++++++++++++++++-- .../temporalcommonv1/payload_json_test.go | 27 ++++++++++--------- 5 files changed, 65 insertions(+), 16 deletions(-) diff --git a/common/v1/payload_json.go b/common/v1/payload_json.go index b4ab578a..b72b91ba 100644 --- a/common/v1/payload_json.go +++ b/common/v1/payload_json.go @@ -43,6 +43,7 @@ const ( payloadMetadataKey = "metadata" payloadDataKey = "data" shorthandMessageTypeKey = "_protoMessageType" + binaryNullBase64 = "YmluYXJ5L251bGw=" ) var _ protojson.ProtoJSONMaybeMarshaler = (*Payload)(nil) @@ -55,7 +56,7 @@ var _ protojson.ProtoJSONMaybeUnmarshaler = (*Payloads)(nil) func marshalSingular(enc *json.Encoder, value interface{}) error { switch vv := value.(type) { case string: - enc.WriteString(vv) + return enc.WriteString(vv) case bool: enc.WriteBool(vv) case int: @@ -221,6 +222,22 @@ func marshalValue(enc *json.Encoder, vv reflect.Value) error { } func marshal(enc *json.Encoder, value interface{}) error { + if value == nil { + // nil data means we send the binary/null encoding + enc.StartObject() + defer enc.EndObject() + if err := enc.WriteName("metadata"); err != nil { + return err + } + + enc.StartObject() + defer enc.EndObject() + if err := enc.WriteName("encoding"); err != nil { + return err + } + // base64(binary/null) + return enc.WriteString(binaryNullBase64) + } return marshalValue(enc, reflect.ValueOf(value)) } @@ -376,6 +393,8 @@ func (p *Payload) toJSONShorthand() (handled bool, value interface{}, err error) valueMap[shorthandMessageTypeKey] = msgType } value = valueMap + default: + return false, nil, fmt.Errorf("unsupported encoding %s", string(p.Metadata["encoding"])) } return } @@ -493,7 +512,10 @@ func (p *Payload) unmarshalPayload(valueMap map[string]interface{}) bool { d, dataOk := valueMap[payloadDataKey] // It's ok to have no data key if the encoding is binary/null - if mdOk && !dataOk && enc == "binary/null" { + if mdOk && !dataOk && enc == binaryNullBase64 { + p.Metadata = map[string][]byte{ + "encoding": []byte("binary/null"), + } return true } else if !mdOk && !dataOk { return false diff --git a/go.mod b/go.mod index 90cc5137..c8d4522c 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( github.com/golang/mock v1.6.0 + github.com/google/go-cmp v0.6.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 github.com/stretchr/testify v1.8.4 google.golang.org/genproto/googleapis/api v0.0.0-20231127180814-3a041ad873d4 diff --git a/go.sum b/go.sum index c6c4b94f..5a7b338c 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,7 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 h1:6UKoz5ujsI55KNpsJH3UwCq3T8kKbZwNZBNPuTTje8U= github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1/go.mod h1:YvJ2f6MplWDhfxiUC3KpyTy76kYUZA4W3pTv/wdKQ9Y= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/temporalcommonv1/payload_json.go b/internal/temporalcommonv1/payload_json.go index b4ab578a..b72b91ba 100644 --- a/internal/temporalcommonv1/payload_json.go +++ b/internal/temporalcommonv1/payload_json.go @@ -43,6 +43,7 @@ const ( payloadMetadataKey = "metadata" payloadDataKey = "data" shorthandMessageTypeKey = "_protoMessageType" + binaryNullBase64 = "YmluYXJ5L251bGw=" ) var _ protojson.ProtoJSONMaybeMarshaler = (*Payload)(nil) @@ -55,7 +56,7 @@ var _ protojson.ProtoJSONMaybeUnmarshaler = (*Payloads)(nil) func marshalSingular(enc *json.Encoder, value interface{}) error { switch vv := value.(type) { case string: - enc.WriteString(vv) + return enc.WriteString(vv) case bool: enc.WriteBool(vv) case int: @@ -221,6 +222,22 @@ func marshalValue(enc *json.Encoder, vv reflect.Value) error { } func marshal(enc *json.Encoder, value interface{}) error { + if value == nil { + // nil data means we send the binary/null encoding + enc.StartObject() + defer enc.EndObject() + if err := enc.WriteName("metadata"); err != nil { + return err + } + + enc.StartObject() + defer enc.EndObject() + if err := enc.WriteName("encoding"); err != nil { + return err + } + // base64(binary/null) + return enc.WriteString(binaryNullBase64) + } return marshalValue(enc, reflect.ValueOf(value)) } @@ -376,6 +393,8 @@ func (p *Payload) toJSONShorthand() (handled bool, value interface{}, err error) valueMap[shorthandMessageTypeKey] = msgType } value = valueMap + default: + return false, nil, fmt.Errorf("unsupported encoding %s", string(p.Metadata["encoding"])) } return } @@ -493,7 +512,10 @@ func (p *Payload) unmarshalPayload(valueMap map[string]interface{}) bool { d, dataOk := valueMap[payloadDataKey] // It's ok to have no data key if the encoding is binary/null - if mdOk && !dataOk && enc == "binary/null" { + if mdOk && !dataOk && enc == binaryNullBase64 { + p.Metadata = map[string][]byte{ + "encoding": []byte("binary/null"), + } return true } else if !mdOk && !dataOk { return false diff --git a/internal/temporalcommonv1/payload_json_test.go b/internal/temporalcommonv1/payload_json_test.go index f251b1bc..e162c247 100644 --- a/internal/temporalcommonv1/payload_json_test.go +++ b/internal/temporalcommonv1/payload_json_test.go @@ -25,8 +25,10 @@ package common_test import ( "testing" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/testing/protocmp" "go.temporal.io/api/common/v1" "go.temporal.io/api/temporalproto" @@ -37,14 +39,6 @@ var tests = []struct { pb proto.Message json string }{{ - name: "binary/plain", - pb: &common.Payload{ - Metadata: map[string][]byte{ - "encoding": []byte("binary/plain"), - }, - Data: []byte("bytes")}, - json: `{"metadata": {"encoding": "YmluYXJ5L3BsYWlu"}, "data": "Ynl0ZXM="}`, -}, { name: "json/plain", json: `{"_protoMessageType": "temporal.api.common.v1.WorkflowType", "name": "workflow-name"}`, pb: &common.Payload{ @@ -54,6 +48,14 @@ var tests = []struct { }, Data: []byte(`{"name":"workflow-name"}`), }, +}, { + name: "binary/null", + json: `{"metadata":{"encoding":"YmluYXJ5L251bGw="}}`, + pb: &common.Payload{ + Metadata: map[string][]byte{ + "encoding": []byte("binary/null"), + }, + }, }, { name: "memo with varying fields", json: `{"fields": { @@ -61,13 +63,13 @@ var tests = []struct { "some-array": ["foo", 123, false]}}`, pb: &common.Memo{ Fields: map[string]*common.Payload{ - "some-string": &common.Payload{ + "some-string": { Metadata: map[string][]byte{ "encoding": []byte("json/plain"), }, Data: []byte(`"string"`), }, - "some-array": &common.Payload{ + "some-array": { Metadata: map[string][]byte{ "encoding": []byte("json/plain"), }, @@ -84,6 +86,7 @@ func TestMaybeMarshal(t *testing.T) { var opts temporalproto.CustomJSONMarshalOptions got, err := opts.Marshal(tt.pb) require.NoError(t, err) + t.Logf("%s", string(got)) require.JSONEq(t, tt.json, string(got), tt.pb) }) } @@ -96,8 +99,8 @@ func TestMaybeUnmarshal(t *testing.T) { var opts temporalproto.CustomJSONUnmarshalOptions err := opts.Unmarshal([]byte(tt.json), out) require.NoError(t, err) - if !proto.Equal(tt.pb, out) { - t.Errorf("protos mismatched\n%#v\n%#v", tt.pb, out) + if diff := cmp.Diff(tt.pb, out, protocmp.Transform()); diff != "" { + t.Errorf("Mismatch (-want, +got)\n%s", diff) } }) } From 7885765e5149c161fbf43d8b25a388c74375663a Mon Sep 17 00:00:00 2001 From: Tim Deeb-Swihart Date: Tue, 28 Nov 2023 12:52:15 -0800 Subject: [PATCH 2/3] Remove go-cmp again Ugh --- go.mod | 1 - go.sum | 1 - internal/temporalcommonv1/payload_json_test.go | 6 +----- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/go.mod b/go.mod index c8d4522c..90cc5137 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.20 require ( github.com/golang/mock v1.6.0 - github.com/google/go-cmp v0.6.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 github.com/stretchr/testify v1.8.4 google.golang.org/genproto/googleapis/api v0.0.0-20231127180814-3a041ad873d4 diff --git a/go.sum b/go.sum index 5a7b338c..c6c4b94f 100644 --- a/go.sum +++ b/go.sum @@ -8,7 +8,6 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 h1:6UKoz5ujsI55KNpsJH3UwCq3T8kKbZwNZBNPuTTje8U= github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1/go.mod h1:YvJ2f6MplWDhfxiUC3KpyTy76kYUZA4W3pTv/wdKQ9Y= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= diff --git a/internal/temporalcommonv1/payload_json_test.go b/internal/temporalcommonv1/payload_json_test.go index e162c247..2b200efa 100644 --- a/internal/temporalcommonv1/payload_json_test.go +++ b/internal/temporalcommonv1/payload_json_test.go @@ -25,10 +25,8 @@ package common_test import ( "testing" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/testing/protocmp" "go.temporal.io/api/common/v1" "go.temporal.io/api/temporalproto" @@ -99,9 +97,7 @@ func TestMaybeUnmarshal(t *testing.T) { var opts temporalproto.CustomJSONUnmarshalOptions err := opts.Unmarshal([]byte(tt.json), out) require.NoError(t, err) - if diff := cmp.Diff(tt.pb, out, protocmp.Transform()); diff != "" { - t.Errorf("Mismatch (-want, +got)\n%s", diff) - } + require.True(t, proto.Equal(tt.pb, out)) }) } } From cf302dbc6d9932afbc72567209df044cbcf3557d Mon Sep 17 00:00:00 2001 From: Tim Deeb-Swihart Date: Tue, 28 Nov 2023 13:06:47 -0800 Subject: [PATCH 3/3] Switch to explicitly enabling shorthand and improve tests I never tested longform JSON compatibility. This fixes that --- common/v1/payload_json.go | 21 ++++--- internal/temporalcommonv1/payload_json.go | 21 ++++--- .../temporalcommonv1/payload_json_test.go | 62 +++++++++++++++---- 3 files changed, 72 insertions(+), 32 deletions(-) diff --git a/common/v1/payload_json.go b/common/v1/payload_json.go index b72b91ba..637faf2b 100644 --- a/common/v1/payload_json.go +++ b/common/v1/payload_json.go @@ -241,10 +241,10 @@ func marshal(enc *json.Encoder, value interface{}) error { return marshalValue(enc, reflect.ValueOf(value)) } -// Key on the marshaler metadata specifying whether shorthand is disabled. +// Key on the marshaler metadata specifying whether shorthand is enabled. // // WARNING: This is internal API and should not be called externally. -const DisablePayloadShorthandMetadataKey = "__temporal_disable_payload_shorthand" +const EnablePayloadShorthandMetadataKey = "__temporal_enable_payload_shorthand" // MaybeMarshalProtoJSON implements // [go.temporal.io/api/internal/temporaljsonpb.ProtoJSONMaybeMarshaler.MaybeMarshalProtoJSON]. @@ -255,8 +255,9 @@ func (p *Payloads) MaybeMarshalProtoJSON(meta map[string]interface{}, enc *json. if p == nil { return false, nil } - // If shorthand is disabled, ignore - if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled { + + // Skip unless explicitly enabled + if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled { return false, nil } @@ -289,8 +290,8 @@ func (p *Payloads) MaybeUnmarshalProtoJSON(meta map[string]interface{}, dec *jso if p == nil { return false, nil } - // If shorthand is disabled, ignore - if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled { + // Skip unless explicitly enabled + if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled { return false, nil } tok, err := dec.Peek() @@ -331,8 +332,8 @@ func (p *Payload) MaybeMarshalProtoJSON(meta map[string]interface{}, enc *json.E if p == nil { return false, nil } - // If shorthand is disabled, ignore - if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled { + // Skip unless explicitly enabled + if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled { return false, nil } // If any are not handled or there is an error, return @@ -352,8 +353,8 @@ func (p *Payload) MaybeUnmarshalProtoJSON(meta map[string]interface{}, dec *json if p == nil { return false, nil } - // If shorthand is disabled, ignore - if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled { + // Skip unless explicitly enabled + if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled { return false, nil } // Always considered handled, unmarshaler ignored (unknown fields always diff --git a/internal/temporalcommonv1/payload_json.go b/internal/temporalcommonv1/payload_json.go index b72b91ba..637faf2b 100644 --- a/internal/temporalcommonv1/payload_json.go +++ b/internal/temporalcommonv1/payload_json.go @@ -241,10 +241,10 @@ func marshal(enc *json.Encoder, value interface{}) error { return marshalValue(enc, reflect.ValueOf(value)) } -// Key on the marshaler metadata specifying whether shorthand is disabled. +// Key on the marshaler metadata specifying whether shorthand is enabled. // // WARNING: This is internal API and should not be called externally. -const DisablePayloadShorthandMetadataKey = "__temporal_disable_payload_shorthand" +const EnablePayloadShorthandMetadataKey = "__temporal_enable_payload_shorthand" // MaybeMarshalProtoJSON implements // [go.temporal.io/api/internal/temporaljsonpb.ProtoJSONMaybeMarshaler.MaybeMarshalProtoJSON]. @@ -255,8 +255,9 @@ func (p *Payloads) MaybeMarshalProtoJSON(meta map[string]interface{}, enc *json. if p == nil { return false, nil } - // If shorthand is disabled, ignore - if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled { + + // Skip unless explicitly enabled + if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled { return false, nil } @@ -289,8 +290,8 @@ func (p *Payloads) MaybeUnmarshalProtoJSON(meta map[string]interface{}, dec *jso if p == nil { return false, nil } - // If shorthand is disabled, ignore - if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled { + // Skip unless explicitly enabled + if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled { return false, nil } tok, err := dec.Peek() @@ -331,8 +332,8 @@ func (p *Payload) MaybeMarshalProtoJSON(meta map[string]interface{}, enc *json.E if p == nil { return false, nil } - // If shorthand is disabled, ignore - if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled { + // Skip unless explicitly enabled + if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled { return false, nil } // If any are not handled or there is an error, return @@ -352,8 +353,8 @@ func (p *Payload) MaybeUnmarshalProtoJSON(meta map[string]interface{}, dec *json if p == nil { return false, nil } - // If shorthand is disabled, ignore - if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled { + // Skip unless explicitly enabled + if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled { return false, nil } // Always considered handled, unmarshaler ignored (unknown fields always diff --git a/internal/temporalcommonv1/payload_json_test.go b/internal/temporalcommonv1/payload_json_test.go index 2b200efa..3e6d5c74 100644 --- a/internal/temporalcommonv1/payload_json_test.go +++ b/internal/temporalcommonv1/payload_json_test.go @@ -33,12 +33,14 @@ import ( ) var tests = []struct { - name string - pb proto.Message - json string + name string + pb proto.Message + longformJSON string + shorthandJSON string }{{ - name: "json/plain", - json: `{"_protoMessageType": "temporal.api.common.v1.WorkflowType", "name": "workflow-name"}`, + name: "json/plain", + longformJSON: `{"metadata":{"encoding":"anNvbi9wcm90b2J1Zg==","messageType":"dGVtcG9yYWwuYXBpLmNvbW1vbi52MS5Xb3JrZmxvd1R5cGU="},"data":"eyJuYW1lIjoid29ya2Zsb3ctbmFtZSJ9"}`, + shorthandJSON: `{"_protoMessageType": "temporal.api.common.v1.WorkflowType", "name": "workflow-name"}`, pb: &common.Payload{ Metadata: map[string][]byte{ "encoding": []byte("json/protobuf"), @@ -47,8 +49,9 @@ var tests = []struct { Data: []byte(`{"name":"workflow-name"}`), }, }, { - name: "binary/null", - json: `{"metadata":{"encoding":"YmluYXJ5L251bGw="}}`, + name: "binary/null", + longformJSON: `{"metadata":{"encoding":"YmluYXJ5L251bGw="}}`, + shorthandJSON: `{"metadata":{"encoding":"YmluYXJ5L251bGw="}}`, pb: &common.Payload{ Metadata: map[string][]byte{ "encoding": []byte("binary/null"), @@ -56,7 +59,10 @@ var tests = []struct { }, }, { name: "memo with varying fields", - json: `{"fields": { + longformJSON: `{"fields":{ + "some-string":{"metadata":{"encoding":"anNvbi9wbGFpbg=="},"data":"InN0cmluZyI="}, + "some-array":{"metadata":{"encoding":"anNvbi9wbGFpbg=="},"data":"WyJmb28iLDEyMyxmYWxzZV0="}}}`, + shorthandJSON: `{"fields": { "some-string": "string", "some-array": ["foo", 123, false]}}`, pb: &common.Memo{ @@ -78,24 +84,56 @@ var tests = []struct { }, }} -func TestMaybeMarshal(t *testing.T) { +func TestMaybeMarshal_ShorthandEnabled(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := temporalproto.CustomJSONMarshalOptions{ + Metadata: map[string]interface{}{ + common.EnablePayloadShorthandMetadataKey: true, + }, + } + got, err := opts.Marshal(tt.pb) + require.NoError(t, err) + t.Logf("%s", string(got)) + require.JSONEq(t, tt.shorthandJSON, string(got)) + }) + } +} + +func TestMaybeMarshal_ShorthandDisabled(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var opts temporalproto.CustomJSONMarshalOptions got, err := opts.Marshal(tt.pb) require.NoError(t, err) t.Logf("%s", string(got)) - require.JSONEq(t, tt.json, string(got), tt.pb) + require.JSONEq(t, tt.longformJSON, string(got)) + }) + } +} + +func TestMaybeUnmarshal_Shorthand(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out = tt.pb.ProtoReflect().New().Interface() + opts := temporalproto.CustomJSONUnmarshalOptions{ + Metadata: map[string]interface{}{ + common.EnablePayloadShorthandMetadataKey: true, + }, + } + err := opts.Unmarshal([]byte(tt.shorthandJSON), out) + require.NoError(t, err) + require.True(t, proto.Equal(tt.pb, out)) }) } } -func TestMaybeUnmarshal(t *testing.T) { +func TestMaybeUnmarshal_Longform(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var out = tt.pb.ProtoReflect().New().Interface() var opts temporalproto.CustomJSONUnmarshalOptions - err := opts.Unmarshal([]byte(tt.json), out) + err := opts.Unmarshal([]byte(tt.longformJSON), out) require.NoError(t, err) require.True(t, proto.Equal(tt.pb, out)) })