Skip to content

Commit

Permalink
support custom unmarshalling for map keys
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Jun 19, 2024
1 parent 4653a1b commit cbf5617
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
18 changes: 14 additions & 4 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ import (
"strconv"
"time"

"golang.org/x/xerrors"

"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/internal/errors"
"github.com/goccy/go-yaml/parser"
"github.com/goccy/go-yaml/token"
"golang.org/x/xerrors"
)

// Decoder reads and decodes YAML values from an input stream.
Expand Down Expand Up @@ -1500,10 +1501,19 @@ func (d *Decoder) decodeMap(ctx context.Context, dst reflect.Value, src ast.Node
}
continue
}
k := reflect.ValueOf(d.nodeToValue(key))
if k.IsValid() && k.Type().ConvertibleTo(keyType) {
k = k.Convert(keyType)

k := d.createDecodableValue(keyType)
if d.canDecodeByUnmarshaler(k) {
if err := d.decodeByUnmarshaler(ctx, k, key); err != nil {
return errors.Wrapf(err, "failed to decode by unmarshaler")
}
} else {
k = reflect.ValueOf(d.nodeToValue(key))
if k.IsValid() && k.Type().ConvertibleTo(keyType) {
k = k.Convert(keyType)
}
}

if k.IsValid() {
if err := d.validateDuplicateKey(keyMap, k.Interface(), key); err != nil {
return errors.Wrapf(err, "invalid map key")
Expand Down
29 changes: 28 additions & 1 deletion decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ import (
"testing"
"time"

"golang.org/x/xerrors"

"github.com/goccy/go-yaml"
"github.com/goccy/go-yaml/ast"
"github.com/goccy/go-yaml/internal/errors"
"github.com/goccy/go-yaml/parser"
"golang.org/x/xerrors"
)

type Child struct {
Expand Down Expand Up @@ -2906,3 +2907,29 @@ func TestSameNameInineStruct(t *testing.T) {
t.Fatalf("failed to decode")
}
}

type unmarshableMapKey struct {
Key string
}

func (mk *unmarshableMapKey) UnmarshalYAML(b []byte) error {
mk.Key = string(b)
return nil
}

func TestMapKeyCustomUnmarshaler(t *testing.T) {
var m map[unmarshableMapKey]string
if err := yaml.Unmarshal([]byte(`key: value`), &m); err != nil {
t.Fatalf("failed to unmarshal %v", err)
}
if len(m) != 1 {
t.Fatalf("expected 1 element in map, but got %d", len(m))
}
val, ok := m[unmarshableMapKey{Key: "key"}]
if !ok {
t.Fatal("expected to have element 'key' in map")
}
if val != "value" {
t.Fatalf("expected to have value \"value\", but got %q", val)
}
}

0 comments on commit cbf5617

Please sign in to comment.