Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of binary/null content type #136

Merged
merged 3 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions common/v1/payload_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
payloadMetadataKey = "metadata"
payloadDataKey = "data"
shorthandMessageTypeKey = "_protoMessageType"
binaryNullBase64 = "YmluYXJ5L251bGw="
)

var _ protojson.ProtoJSONMaybeMarshaler = (*Payload)(nil)
Expand All @@ -55,7 +56,7 @@ var _ protojson.ProtoJSONMaybeUnmarshaler = (*Payloads)(nil)
func marshalSingular(enc *json.Encoder, value interface{}) error {
switch vv := value.(type) {
case string:
enc.WriteString(vv)
return enc.WriteString(vv)
case bool:
enc.WriteBool(vv)
case int:
Expand Down Expand Up @@ -221,13 +222,29 @@ func marshalValue(enc *json.Encoder, vv reflect.Value) error {
}

func marshal(enc *json.Encoder, value interface{}) error {
if value == nil {
// nil data means we send the binary/null encoding
enc.StartObject()
defer enc.EndObject()
if err := enc.WriteName("metadata"); err != nil {
return err
}

enc.StartObject()
defer enc.EndObject()
if err := enc.WriteName("encoding"); err != nil {
return err
}
// base64(binary/null)
return enc.WriteString(binaryNullBase64)
}
return marshalValue(enc, reflect.ValueOf(value))
}

// Key on the marshaler metadata specifying whether shorthand is disabled.
// Key on the marshaler metadata specifying whether shorthand is enabled.
//
// WARNING: This is internal API and should not be called externally.
const DisablePayloadShorthandMetadataKey = "__temporal_disable_payload_shorthand"
const EnablePayloadShorthandMetadataKey = "__temporal_enable_payload_shorthand"

// MaybeMarshalProtoJSON implements
// [go.temporal.io/api/internal/temporaljsonpb.ProtoJSONMaybeMarshaler.MaybeMarshalProtoJSON].
Expand All @@ -238,8 +255,9 @@ func (p *Payloads) MaybeMarshalProtoJSON(meta map[string]interface{}, enc *json.
if p == nil {
return false, nil
}
// If shorthand is disabled, ignore
if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled {

// Skip unless explicitly enabled
if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled {
return false, nil
}

Expand Down Expand Up @@ -272,8 +290,8 @@ func (p *Payloads) MaybeUnmarshalProtoJSON(meta map[string]interface{}, dec *jso
if p == nil {
return false, nil
}
// If shorthand is disabled, ignore
if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled {
// Skip unless explicitly enabled
if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled {
return false, nil
}
tok, err := dec.Peek()
Expand Down Expand Up @@ -314,8 +332,8 @@ func (p *Payload) MaybeMarshalProtoJSON(meta map[string]interface{}, enc *json.E
if p == nil {
return false, nil
}
// If shorthand is disabled, ignore
if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled {
// Skip unless explicitly enabled
if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled {
return false, nil
}
// If any are not handled or there is an error, return
Expand All @@ -335,8 +353,8 @@ func (p *Payload) MaybeUnmarshalProtoJSON(meta map[string]interface{}, dec *json
if p == nil {
return false, nil
}
// If shorthand is disabled, ignore
if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled {
// Skip unless explicitly enabled
if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled {
return false, nil
}
// Always considered handled, unmarshaler ignored (unknown fields always
Expand Down Expand Up @@ -376,6 +394,8 @@ func (p *Payload) toJSONShorthand() (handled bool, value interface{}, err error)
valueMap[shorthandMessageTypeKey] = msgType
}
value = valueMap
default:
return false, nil, fmt.Errorf("unsupported encoding %s", string(p.Metadata["encoding"]))
}
return
}
Expand Down Expand Up @@ -493,7 +513,10 @@ func (p *Payload) unmarshalPayload(valueMap map[string]interface{}) bool {

d, dataOk := valueMap[payloadDataKey]
// It's ok to have no data key if the encoding is binary/null
if mdOk && !dataOk && enc == "binary/null" {
if mdOk && !dataOk && enc == binaryNullBase64 {
p.Metadata = map[string][]byte{
"encoding": []byte("binary/null"),
}
return true
} else if !mdOk && !dataOk {
return false
Expand Down
47 changes: 35 additions & 12 deletions internal/temporalcommonv1/payload_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
payloadMetadataKey = "metadata"
payloadDataKey = "data"
shorthandMessageTypeKey = "_protoMessageType"
binaryNullBase64 = "YmluYXJ5L251bGw="
)

var _ protojson.ProtoJSONMaybeMarshaler = (*Payload)(nil)
Expand All @@ -55,7 +56,7 @@ var _ protojson.ProtoJSONMaybeUnmarshaler = (*Payloads)(nil)
func marshalSingular(enc *json.Encoder, value interface{}) error {
switch vv := value.(type) {
case string:
enc.WriteString(vv)
return enc.WriteString(vv)
case bool:
enc.WriteBool(vv)
case int:
Expand Down Expand Up @@ -221,13 +222,29 @@ func marshalValue(enc *json.Encoder, vv reflect.Value) error {
}

func marshal(enc *json.Encoder, value interface{}) error {
if value == nil {
// nil data means we send the binary/null encoding
enc.StartObject()
defer enc.EndObject()
if err := enc.WriteName("metadata"); err != nil {
return err
}

enc.StartObject()
defer enc.EndObject()
if err := enc.WriteName("encoding"); err != nil {
return err
}
// base64(binary/null)
return enc.WriteString(binaryNullBase64)
}
return marshalValue(enc, reflect.ValueOf(value))
}

// Key on the marshaler metadata specifying whether shorthand is disabled.
// Key on the marshaler metadata specifying whether shorthand is enabled.
//
// WARNING: This is internal API and should not be called externally.
const DisablePayloadShorthandMetadataKey = "__temporal_disable_payload_shorthand"
const EnablePayloadShorthandMetadataKey = "__temporal_enable_payload_shorthand"

// MaybeMarshalProtoJSON implements
// [go.temporal.io/api/internal/temporaljsonpb.ProtoJSONMaybeMarshaler.MaybeMarshalProtoJSON].
Expand All @@ -238,8 +255,9 @@ func (p *Payloads) MaybeMarshalProtoJSON(meta map[string]interface{}, enc *json.
if p == nil {
return false, nil
}
// If shorthand is disabled, ignore
if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled {

// Skip unless explicitly enabled
if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled {
return false, nil
}

Expand Down Expand Up @@ -272,8 +290,8 @@ func (p *Payloads) MaybeUnmarshalProtoJSON(meta map[string]interface{}, dec *jso
if p == nil {
return false, nil
}
// If shorthand is disabled, ignore
if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled {
// Skip unless explicitly enabled
if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled {
return false, nil
}
tok, err := dec.Peek()
Expand Down Expand Up @@ -314,8 +332,8 @@ func (p *Payload) MaybeMarshalProtoJSON(meta map[string]interface{}, enc *json.E
if p == nil {
return false, nil
}
// If shorthand is disabled, ignore
if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled {
// Skip unless explicitly enabled
if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled {
return false, nil
}
// If any are not handled or there is an error, return
Expand All @@ -335,8 +353,8 @@ func (p *Payload) MaybeUnmarshalProtoJSON(meta map[string]interface{}, dec *json
if p == nil {
return false, nil
}
// If shorthand is disabled, ignore
if disabled, _ := meta[DisablePayloadShorthandMetadataKey].(bool); disabled {
// Skip unless explicitly enabled
if _, enabled := meta[EnablePayloadShorthandMetadataKey].(bool); !enabled {
return false, nil
}
// Always considered handled, unmarshaler ignored (unknown fields always
Expand Down Expand Up @@ -376,6 +394,8 @@ func (p *Payload) toJSONShorthand() (handled bool, value interface{}, err error)
valueMap[shorthandMessageTypeKey] = msgType
}
value = valueMap
default:
return false, nil, fmt.Errorf("unsupported encoding %s", string(p.Metadata["encoding"]))
}
return
}
Expand Down Expand Up @@ -493,7 +513,10 @@ func (p *Payload) unmarshalPayload(valueMap map[string]interface{}) bool {

d, dataOk := valueMap[payloadDataKey]
// It's ok to have no data key if the encoding is binary/null
if mdOk && !dataOk && enc == "binary/null" {
if mdOk && !dataOk && enc == binaryNullBase64 {
p.Metadata = map[string][]byte{
"encoding": []byte("binary/null"),
}
return true
} else if !mdOk && !dataOk {
return false
Expand Down
81 changes: 59 additions & 22 deletions internal/temporalcommonv1/payload_json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,41 +33,47 @@ import (
)

var tests = []struct {
name string
pb proto.Message
json string
name string
pb proto.Message
longformJSON string
shorthandJSON string
}{{
name: "binary/plain",
name: "json/plain",
longformJSON: `{"metadata":{"encoding":"anNvbi9wcm90b2J1Zg==","messageType":"dGVtcG9yYWwuYXBpLmNvbW1vbi52MS5Xb3JrZmxvd1R5cGU="},"data":"eyJuYW1lIjoid29ya2Zsb3ctbmFtZSJ9"}`,
shorthandJSON: `{"_protoMessageType": "temporal.api.common.v1.WorkflowType", "name": "workflow-name"}`,
pb: &common.Payload{
Metadata: map[string][]byte{
"encoding": []byte("binary/plain"),
"encoding": []byte("json/protobuf"),
"messageType": []byte("temporal.api.common.v1.WorkflowType"),
},
Data: []byte("bytes")},
json: `{"metadata": {"encoding": "YmluYXJ5L3BsYWlu"}, "data": "Ynl0ZXM="}`,
Data: []byte(`{"name":"workflow-name"}`),
},
}, {
name: "json/plain",
json: `{"_protoMessageType": "temporal.api.common.v1.WorkflowType", "name": "workflow-name"}`,
name: "binary/null",
longformJSON: `{"metadata":{"encoding":"YmluYXJ5L251bGw="}}`,
shorthandJSON: `{"metadata":{"encoding":"YmluYXJ5L251bGw="}}`,
pb: &common.Payload{
Metadata: map[string][]byte{
"encoding": []byte("json/protobuf"),
"messageType": []byte("temporal.api.common.v1.WorkflowType"),
"encoding": []byte("binary/null"),
},
Data: []byte(`{"name":"workflow-name"}`),
},
}, {
name: "memo with varying fields",
json: `{"fields": {
longformJSON: `{"fields":{
"some-string":{"metadata":{"encoding":"anNvbi9wbGFpbg=="},"data":"InN0cmluZyI="},
"some-array":{"metadata":{"encoding":"anNvbi9wbGFpbg=="},"data":"WyJmb28iLDEyMyxmYWxzZV0="}}}`,
shorthandJSON: `{"fields": {
"some-string": "string",
"some-array": ["foo", 123, false]}}`,
pb: &common.Memo{
Fields: map[string]*common.Payload{
"some-string": &common.Payload{
"some-string": {
Metadata: map[string][]byte{
"encoding": []byte("json/plain"),
},
Data: []byte(`"string"`),
},
"some-array": &common.Payload{
"some-array": {
Metadata: map[string][]byte{
"encoding": []byte("json/plain"),
},
Expand All @@ -78,27 +84,58 @@ var tests = []struct {
},
}}

func TestMaybeMarshal(t *testing.T) {
func TestMaybeMarshal_ShorthandEnabled(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := temporalproto.CustomJSONMarshalOptions{
Metadata: map[string]interface{}{
common.EnablePayloadShorthandMetadataKey: true,
},
}
got, err := opts.Marshal(tt.pb)
require.NoError(t, err)
t.Logf("%s", string(got))
require.JSONEq(t, tt.shorthandJSON, string(got))
})
}
}

func TestMaybeMarshal_ShorthandDisabled(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var opts temporalproto.CustomJSONMarshalOptions
got, err := opts.Marshal(tt.pb)
require.NoError(t, err)
require.JSONEq(t, tt.json, string(got), tt.pb)
t.Logf("%s", string(got))
require.JSONEq(t, tt.longformJSON, string(got))
})
}
}

func TestMaybeUnmarshal_Shorthand(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var out = tt.pb.ProtoReflect().New().Interface()
opts := temporalproto.CustomJSONUnmarshalOptions{
Metadata: map[string]interface{}{
common.EnablePayloadShorthandMetadataKey: true,
},
}
err := opts.Unmarshal([]byte(tt.shorthandJSON), out)
require.NoError(t, err)
require.True(t, proto.Equal(tt.pb, out))
})
}
}

func TestMaybeUnmarshal(t *testing.T) {
func TestMaybeUnmarshal_Longform(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var out = tt.pb.ProtoReflect().New().Interface()
var opts temporalproto.CustomJSONUnmarshalOptions
err := opts.Unmarshal([]byte(tt.json), out)
err := opts.Unmarshal([]byte(tt.longformJSON), out)
require.NoError(t, err)
if !proto.Equal(tt.pb, out) {
t.Errorf("protos mismatched\n%#v\n%#v", tt.pb, out)
}
require.True(t, proto.Equal(tt.pb, out))
})
}
}
Loading