From 969bcabf2e780de3b2de562cfae2011dcc85600a Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Tue, 3 Oct 2023 23:54:20 -0400 Subject: [PATCH] Move the indication of destination message type into Codec (#599) This is so that a custom `Codec` can choose what details to reveal to clients, in cases where emitting the destination message type is considered to leak too much information. See #584 for more context. --- codec.go | 12 ++++++++++-- envelope.go | 2 +- protocol_connect.go | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/codec.go b/codec.go index 4809fe3c..7cd4552a 100644 --- a/codec.go +++ b/codec.go @@ -118,7 +118,11 @@ func (c *protoBinaryCodec) Unmarshal(data []byte, message any) error { if !ok { return errNotProto(message) } - return proto.Unmarshal(data, protoMessage) + err := proto.Unmarshal(data, protoMessage) + if err != nil { + return fmt.Errorf("unmarshal into %T: %w", message, err) + } + return nil } func (c *protoBinaryCodec) MarshalStable(message any) ([]byte, error) { @@ -174,7 +178,11 @@ func (c *protoJSONCodec) Unmarshal(binary []byte, message any) error { // Discard unknown fields so clients and servers aren't forced to always use // exactly the same version of the schema. options := protojson.UnmarshalOptions{DiscardUnknown: true} - return options.Unmarshal(binary, protoMessage) + err := options.Unmarshal(binary, protoMessage) + if err != nil { + return fmt.Errorf("unmarshal into %T: %w", message, err) + } + return nil } func (c *protoJSONCodec) MarshalStable(message any) ([]byte, error) { diff --git a/envelope.go b/envelope.go index 5f02acfa..c207e096 100644 --- a/envelope.go +++ b/envelope.go @@ -219,7 +219,7 @@ func (r *envelopeReader) Unmarshal(message any) *Error { } if err := r.codec.Unmarshal(data.Bytes(), message); err != nil { - return errorf(CodeInvalidArgument, "unmarshal into %T: %w", message, err) + return errorf(CodeInvalidArgument, "unmarshal message: %w", err) } return nil } diff --git a/protocol_connect.go b/protocol_connect.go index e87365e3..b14eb4db 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -1099,7 +1099,7 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by data = decompressed } if err := unmarshal(data.Bytes(), message); err != nil { - return errorf(CodeInvalidArgument, "unmarshal into %T: %w", message, err) + return errorf(CodeInvalidArgument, "unmarshal message: %w", err) } return nil }