Skip to content

Commit

Permalink
only build map of custom unmarshalers once instead of every call to u…
Browse files Browse the repository at this point in the history
…nmarshal
  • Loading branch information
jhump committed Aug 28, 2024
1 parent cb54552 commit 89140b1
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ import (
"gopkg.in/yaml.v3"
)

var (
// We have to initialize this from an init() function below
// instead of via initializer expression here to avoid the Go
// compiler complaining about a potential initialization cycle
// (the initializer expression refers to the function
// unmarshalAnyMsg, which indirectly refers back to this var).
wktUnmarshalers map[protoreflect.FullName]customUnmarshaler
)

// Validator is an interface for validating a Protobuf message produced from a given YAML node.
type Validator interface {
// Validate the given message.
Expand Down Expand Up @@ -139,13 +148,10 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message,
}
unm := &unmarshaler{
options: o,
custom: make(map[protoreflect.FullName]customUnmarshaler),
validator: o.Validator,
lines: strings.Split(string(data), "\n"),
}

addWktUnmarshalers(unm.custom)

// Unwrap the document node
if node.Kind == yaml.DocumentNode {
if len(node.Content) != 1 {
Expand Down Expand Up @@ -188,7 +194,6 @@ type protoResolver interface {
type unmarshaler struct {
options UnmarshalOptions
errors []error
custom map[protoreflect.FullName]customUnmarshaler
validator Validator
lines []string
}
Expand Down Expand Up @@ -657,7 +662,7 @@ func (u *unmarshaler) findNodeForCustom(node *yaml.Node, forAny bool) *yaml.Node
// Unmarshal the given yaml node into the given proto.Message.
func (u *unmarshaler) unmarshalMessage(node *yaml.Node, message proto.Message, forAny bool) {
// Check for a custom unmarshaler
if custom, ok := u.custom[message.ProtoReflect().Descriptor().FullName()]; ok {
if custom, ok := wktUnmarshalers[message.ProtoReflect().Descriptor().FullName()]; ok {
valueNode := u.findNodeForCustom(node, forAny)
if valueNode == nil {
return // Error already added.
Expand Down Expand Up @@ -718,28 +723,6 @@ func (u *unmarshaler) unmarshalMessageFields(node *yaml.Node, message proto.Mess

type customUnmarshaler func(u *unmarshaler, node *yaml.Node, message proto.Message) bool

// Add all well-known type unmarshalers to the given map (including struct unmarshalers).
func addWktUnmarshalers(custom map[protoreflect.FullName]customUnmarshaler) {
custom["google.protobuf.Any"] = unmarshalAnyMsg

custom["google.protobuf.Duration"] = unmarshalDurationMsg
custom["google.protobuf.Timestamp"] = unmarshalTimestampMsg

custom["google.protobuf.BoolValue"] = unmarshalWrapperMsg
custom["google.protobuf.BytesValue"] = unmarshalWrapperMsg
custom["google.protobuf.DoubleValue"] = unmarshalWrapperMsg
custom["google.protobuf.FloatValue"] = unmarshalWrapperMsg
custom["google.protobuf.Int32Value"] = unmarshalWrapperMsg
custom["google.protobuf.Int64Value"] = unmarshalWrapperMsg
custom["google.protobuf.UInt32Value"] = unmarshalWrapperMsg
custom["google.protobuf.UInt64Value"] = unmarshalWrapperMsg
custom["google.protobuf.StringValue"] = unmarshalWrapperMsg

custom["google.protobuf.Value"] = unmarshalValueMsg
custom["google.protobuf.ListValue"] = unmarshalListValueMsg
custom["google.protobuf.Struct"] = unmarshalStructMsg
}

func unmarshalAnyMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool {
if node.Kind != yaml.MappingNode || len(node.Content) == 0 {
return false
Expand Down Expand Up @@ -1359,3 +1342,23 @@ func leadingInt(str string) (result uint64, rem string, pre bool, err error) {
}
return result, str[i:], i > 0, nil
}

func init() { //nolint:gochecknoinits
wktUnmarshalers = map[protoreflect.FullName]customUnmarshaler{
"google.protobuf.Any": unmarshalAnyMsg,
"google.protobuf.Duration": unmarshalDurationMsg,
"google.protobuf.Timestamp": unmarshalTimestampMsg,
"google.protobuf.BoolValue": unmarshalWrapperMsg,
"google.protobuf.BytesValue": unmarshalWrapperMsg,
"google.protobuf.DoubleValue": unmarshalWrapperMsg,
"google.protobuf.FloatValue": unmarshalWrapperMsg,
"google.protobuf.Int32Value": unmarshalWrapperMsg,
"google.protobuf.Int64Value": unmarshalWrapperMsg,
"google.protobuf.UInt32Value": unmarshalWrapperMsg,
"google.protobuf.UInt64Value": unmarshalWrapperMsg,
"google.protobuf.StringValue": unmarshalWrapperMsg,
"google.protobuf.Value": unmarshalValueMsg,
"google.protobuf.ListValue": unmarshalListValueMsg,
"google.protobuf.Struct": unmarshalStructMsg,
}
}

0 comments on commit 89140b1

Please sign in to comment.