Skip to content

Commit

Permalink
Merge pull request #246 from quentinmit/duration
Browse files Browse the repository at this point in the history
Encode and decode time.Duration fields (#201)
  • Loading branch information
goccy authored Aug 25, 2021
2 parents e2008a9 + 0206999 commit 864ce75
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 11 deletions.
36 changes: 35 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,8 @@ func (d *Decoder) canDecodeByUnmarshaler(dst reflect.Value) bool {
return true
case *time.Time:
return true
case *time.Duration:
return true
case encoding.TextUnmarshaler:
return true
case jsonUnmarshaler:
Expand Down Expand Up @@ -576,6 +578,10 @@ func (d *Decoder) decodeByUnmarshaler(ctx context.Context, dst reflect.Value, sr
return d.decodeTime(ctx, dst, src)
}

if _, ok := iface.(*time.Duration); ok {
return d.decodeDuration(ctx, dst, src)
}

if unmarshaler, isText := iface.(encoding.TextUnmarshaler); isText {
b, ok := d.unmarshalableText(src)
if ok {
Expand Down Expand Up @@ -882,7 +888,35 @@ func (d *Decoder) castToTime(src ast.Node) (time.Time, error) {
func (d *Decoder) decodeTime(ctx context.Context, dst reflect.Value, src ast.Node) error {
t, err := d.castToTime(src)
if err != nil {
return err
return errors.Wrapf(err, "failed to convert to time")
}
dst.Set(reflect.ValueOf(t))
return nil
}

func (d *Decoder) castToDuration(src ast.Node) (time.Duration, error) {
if src == nil {
return 0, nil
}
v := d.nodeToValue(src)
if t, ok := v.(time.Duration); ok {
return t, nil
}
s, ok := v.(string)
if !ok {
return 0, errTypeMismatch(reflect.TypeOf(time.Duration(0)), reflect.TypeOf(v))
}
t, err := time.ParseDuration(s)
if err != nil {
return 0, errors.Wrapf(err, "failed to parse duration")
}
return t, nil
}

func (d *Decoder) decodeDuration(ctx context.Context, dst reflect.Value, src ast.Node) error {
t, err := d.castToDuration(src)
if err != nil {
return errors.Wrapf(err, "failed to convert to duration")
}
dst.Set(reflect.ValueOf(t))
return nil
Expand Down
47 changes: 47 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,14 @@ func TestDecoder(t *testing.T) {
"v: 2015-02-24 18:19:39\n",
map[string]time.Time{"v": time.Date(2015, 2, 24, 18, 19, 39, 0, time.UTC)},
},
{
"v: 60s\n",
map[string]time.Duration{"v": time.Minute},
},
{
"v: -0.5h\n",
map[string]time.Duration{"v": -30 * time.Minute},
},

// Single Quoted values.
{
Expand Down Expand Up @@ -1187,6 +1195,45 @@ func TestDecoder_TypeConversionError(t *testing.T) {
}
})
})
t.Run("type conversion for time", func(t *testing.T) {
type T struct {
A time.Time
B time.Duration
}
t.Run("int to time", func(t *testing.T) {
var v T
err := yaml.Unmarshal([]byte(`a: 123`), &v)
if err == nil {
t.Fatal("expected to error")
}
msg := "cannot unmarshal uint64 into Go struct field T.A of type time.Time"
if err.Error() != msg {
t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg)
}
})
t.Run("string to duration", func(t *testing.T) {
var v T
err := yaml.Unmarshal([]byte(`b: str`), &v)
if err == nil {
t.Fatal("expected to error")
}
msg := `time: invalid duration "str"`
if err.Error() != msg {
t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg)
}
})
t.Run("int to duration", func(t *testing.T) {
var v T
err := yaml.Unmarshal([]byte(`b: 10`), &v)
if err == nil {
t.Fatal("expected to error")
}
msg := "cannot unmarshal uint64 into Go struct field T.B of type time.Duration"
if err.Error() != msg {
t.Fatalf("unexpected error message: %s. expect: %s", err.Error(), msg)
}
})
})
}

func TestDecoder_AnchorReferenceDirs(t *testing.T) {
Expand Down
14 changes: 14 additions & 0 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ func (e *Encoder) canEncodeByMarshaler(v reflect.Value) bool {
return true
case time.Time:
return true
case time.Duration:
return true
case encoding.TextMarshaler:
return true
case jsonMarshaler:
Expand Down Expand Up @@ -254,6 +256,10 @@ func (e *Encoder) encodeByMarshaler(ctx context.Context, v reflect.Value, column
return e.encodeTime(t, column), nil
}

if t, ok := iface.(time.Duration); ok {
return e.encodeDuration(t, column), nil
}

if marshaler, ok := iface.(encoding.TextMarshaler); ok {
doc, err := marshaler.MarshalText()
if err != nil {
Expand Down Expand Up @@ -566,6 +572,14 @@ func (e *Encoder) encodeTime(v time.Time, column int) ast.Node {
return ast.String(token.New(value, value, e.pos(column)))
}

func (e *Encoder) encodeDuration(v time.Duration, column int) ast.Node {
value := v.String()
if e.isJSONStyle {
value = strconv.Quote(value)
}
return ast.String(token.New(value, value, e.pos(column)))
}

func (e *Encoder) encodeAnchor(anchorName string, value ast.Node, fieldValue reflect.Value, column int) (ast.Node, error) {
anchorNode := ast.Anchor(token.New("&", "&", e.pos(column)))
anchorNode.Name = ast.String(token.New(anchorName, anchorName, e.pos(column)))
Expand Down
27 changes: 17 additions & 10 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,11 @@ func TestEncoder(t *testing.T) {
map[string]*time.Time{"v": nil},
nil,
},
{
"v: 30s\n",
map[string]time.Duration{"v": 30 * time.Second},
nil,
},
}
for _, test := range tests {
var buf bytes.Buffer
Expand Down Expand Up @@ -934,14 +939,15 @@ func TestEncoder_JSON(t *testing.T) {
F float32
}
if err := enc.Encode(struct {
I int
U uint
S string
F float64
Struct *st
Slice []int
Map map[string]interface{}
Time time.Time
I int
U uint
S string
F float64
Struct *st
Slice []int
Map map[string]interface{}
Time time.Time
Duration time.Duration
}{
I: -10,
U: 10,
Expand All @@ -958,12 +964,13 @@ func TestEncoder_JSON(t *testing.T) {
"b": 1.23,
"c": "json",
},
Time: time.Time{},
Time: time.Time{},
Duration: 5 * time.Minute,
}); err != nil {
t.Fatalf("%+v", err)
}
expect := `
{"i": -10, "u": 10, "s": "hello", "f": 3.14, "struct": {"i": 2, "s": "world", "f": 1.23}, "slice": [1, 2, 3, 4, 5], "map": {"a": 1, "b": 1.23, "c": "json"}, "time": "0001-01-01T00:00:00Z"}
{"i": -10, "u": 10, "s": "hello", "f": 3.14, "struct": {"i": 2, "s": "world", "f": 1.23}, "slice": [1, 2, 3, 4, 5], "map": {"a": 1, "b": 1.23, "c": "json"}, "time": "0001-01-01T00:00:00Z", "duration": "5m0s"}
`
actual := "\n" + buf.String()
if expect != actual {
Expand Down

0 comments on commit 864ce75

Please sign in to comment.