diff --git a/decode.go b/decode.go index 25bd408..79e0412 100644 --- a/decode.go +++ b/decode.go @@ -35,6 +35,15 @@ import ( "gopkg.in/yaml.v3" ) +var ( + // We have to initialize this from an init() function below + // instead of via initializer expression here to avoid the Go + // compiler complaining about a potential initialization cycle + // (the initializer expression refers to the function + // unmarshalAnyMsg, which indirectly refers back to this var). + wktUnmarshalers map[protoreflect.FullName]customUnmarshaler +) + // Validator is an interface for validating a Protobuf message produced from a given YAML node. type Validator interface { // Validate the given message. @@ -139,13 +148,10 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, } unm := &unmarshaler{ options: o, - custom: make(map[protoreflect.FullName]customUnmarshaler), validator: o.Validator, lines: strings.Split(string(data), "\n"), } - addWktUnmarshalers(unm.custom) - // Unwrap the document node if node.Kind == yaml.DocumentNode { if len(node.Content) != 1 { @@ -188,7 +194,6 @@ type protoResolver interface { type unmarshaler struct { options UnmarshalOptions errors []error - custom map[protoreflect.FullName]customUnmarshaler validator Validator lines []string } @@ -657,7 +662,7 @@ func (u *unmarshaler) findNodeForCustom(node *yaml.Node, forAny bool) *yaml.Node // Unmarshal the given yaml node into the given proto.Message. func (u *unmarshaler) unmarshalMessage(node *yaml.Node, message proto.Message, forAny bool) { // Check for a custom unmarshaler - if custom, ok := u.custom[message.ProtoReflect().Descriptor().FullName()]; ok { + if custom, ok := wktUnmarshalers[message.ProtoReflect().Descriptor().FullName()]; ok { valueNode := u.findNodeForCustom(node, forAny) if valueNode == nil { return // Error already added. @@ -718,28 +723,6 @@ func (u *unmarshaler) unmarshalMessageFields(node *yaml.Node, message proto.Mess type customUnmarshaler func(u *unmarshaler, node *yaml.Node, message proto.Message) bool -// Add all well-known type unmarshalers to the given map (including struct unmarshalers). -func addWktUnmarshalers(custom map[protoreflect.FullName]customUnmarshaler) { - custom["google.protobuf.Any"] = unmarshalAnyMsg - - custom["google.protobuf.Duration"] = unmarshalDurationMsg - custom["google.protobuf.Timestamp"] = unmarshalTimestampMsg - - custom["google.protobuf.BoolValue"] = unmarshalWrapperMsg - custom["google.protobuf.BytesValue"] = unmarshalWrapperMsg - custom["google.protobuf.DoubleValue"] = unmarshalWrapperMsg - custom["google.protobuf.FloatValue"] = unmarshalWrapperMsg - custom["google.protobuf.Int32Value"] = unmarshalWrapperMsg - custom["google.protobuf.Int64Value"] = unmarshalWrapperMsg - custom["google.protobuf.UInt32Value"] = unmarshalWrapperMsg - custom["google.protobuf.UInt64Value"] = unmarshalWrapperMsg - custom["google.protobuf.StringValue"] = unmarshalWrapperMsg - - custom["google.protobuf.Value"] = unmarshalValueMsg - custom["google.protobuf.ListValue"] = unmarshalListValueMsg - custom["google.protobuf.Struct"] = unmarshalStructMsg -} - func unmarshalAnyMsg(unm *unmarshaler, node *yaml.Node, message proto.Message) bool { if node.Kind != yaml.MappingNode || len(node.Content) == 0 { return false @@ -1359,3 +1342,23 @@ func leadingInt(str string) (result uint64, rem string, pre bool, err error) { } return result, str[i:], i > 0, nil } + +func init() { //nolint:gochecknoinits + wktUnmarshalers = map[protoreflect.FullName]customUnmarshaler{ + "google.protobuf.Any": unmarshalAnyMsg, + "google.protobuf.Duration": unmarshalDurationMsg, + "google.protobuf.Timestamp": unmarshalTimestampMsg, + "google.protobuf.BoolValue": unmarshalWrapperMsg, + "google.protobuf.BytesValue": unmarshalWrapperMsg, + "google.protobuf.DoubleValue": unmarshalWrapperMsg, + "google.protobuf.FloatValue": unmarshalWrapperMsg, + "google.protobuf.Int32Value": unmarshalWrapperMsg, + "google.protobuf.Int64Value": unmarshalWrapperMsg, + "google.protobuf.UInt32Value": unmarshalWrapperMsg, + "google.protobuf.UInt64Value": unmarshalWrapperMsg, + "google.protobuf.StringValue": unmarshalWrapperMsg, + "google.protobuf.Value": unmarshalValueMsg, + "google.protobuf.ListValue": unmarshalListValueMsg, + "google.protobuf.Struct": unmarshalStructMsg, + } +}