Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match CEL and Go duration literal parsing, while preserving the full range of values #38

Merged
merged 9 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 166 additions & 7 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"math"
"math/big"
"strconv"
"strings"
"time"
Expand All @@ -34,8 +35,6 @@ import (
"gopkg.in/yaml.v3"
)

const atTypeFieldName = "@type"

// Validator is an interface for validating a Protobuf message produced from a given YAML node.
type Validator interface {
// Validate the given message.
Expand All @@ -57,11 +56,6 @@ type UnmarshalOptions struct {
}
}

type protoResolver interface {
protoregistry.MessageTypeResolver
protoregistry.ExtensionTypeResolver
}

// Unmarshal a Protobuf message from the given YAML data.
func Unmarshal(data []byte, message proto.Message) error {
return (UnmarshalOptions{}).Unmarshal(data, message)
Expand All @@ -76,6 +70,106 @@ func (o UnmarshalOptions) Unmarshal(data []byte, message proto.Message) error {
return o.unmarshalNode(&yamlFile, message, data)
}

// ParseDuration parses a duration string into a durationpb.Duration.
//
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
//
// This function supports the full range of durationpb.Duration values, including
// those outside the range of time.Duration.
func ParseDuration(str string) (*durationpb.Duration, error) {

// [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+
neg := false

// Consume [-+]?
if str != "" {
c := str[0]
if c == '-' || c == '+' {
neg = c == '-'
str = str[1:]
}
}
// Special case: if all that is left is "0", this is zero.
if str == "0" {
var empty *durationpb.Duration
return empty, nil
}
if str == "" {
return nil, errors.New("invalid duration")
}
totalNanos := &big.Int{}
for str != "" {
// The next character must be [0-9.]
if !(str[0] == '.' || '0' <= str[0] && str[0] <= '9') {
return nil, errors.New("invalid duration")
}
var err error
var whole, frac uint64
var pre bool // Whether we have seen a digit before the dot.
whole, str, pre, err = leadingInt(str)
if err != nil {
return nil, err
}
var scale *big.Int
var post bool // Whether we have seen a digit after the dot.
if str != "" && str[0] == '.' {
str = str[1:]
frac, scale, str, post = leadingFrac(str)
}
if !pre && !post {
return nil, errors.New("invalid duration")
}

var end int
for ; end < len(str); end++ {
c := str[end]
if c == '.' || '0' <= c && c <= '9' || c == '-' {
break
}
}
if end == 0 {
return nil, errors.New("invalid duration: missing unit")
}
unitName := str[:end]
str = str[end:]
nanosPerUnit, ok := nanosMap[unitName]
if !ok {
return nil, fmt.Errorf("invalid duration: unknown unit, expected one of %v", unitsNames)
}

// Convert to nanos and add to total.
// totalNanos += whole * nanosPerUnit + frac * nanosPerUnit / scale
if whole > 0 {
wholeNanos := &big.Int{}
wholeNanos.SetUint64(whole)
wholeNanos.Mul(wholeNanos, nanosPerUnit)
totalNanos.Add(totalNanos, wholeNanos)
}
if frac > 0 {
fracNanos := &big.Int{}
fracNanos.SetUint64(frac)
fracNanos.Mul(fracNanos, nanosPerUnit)
rem := &big.Int{}
fracNanos.QuoRem(fracNanos, scale, rem)
if rem.Uint64() > 0 {
return nil, errors.New("invalid duration: fractional nanos")
}
totalNanos.Add(totalNanos, fracNanos)
}
}
if neg {
totalNanos.Neg(totalNanos)
}
result := &durationpb.Duration{}
quo, rem := totalNanos.QuoRem(totalNanos, nanosPerSecond, &big.Int{})
if !quo.IsInt64() {
return nil, errors.New("invalid duration: out of range")
}
result.Seconds = quo.Int64()
result.Nanos = int32(rem.Int64())
return result, nil
}

func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, data []byte) error {
if node.Kind == 0 {
return nil
Expand Down Expand Up @@ -121,6 +215,13 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message,
return nil
}

const atTypeFieldName = "@type"

type protoResolver interface {
protoregistry.MessageTypeResolver
protoregistry.ExtensionTypeResolver
}

type unmarshaler struct {
options UnmarshalOptions
errors []error
Expand Down Expand Up @@ -1184,3 +1285,61 @@ func findEntryByKey(cur *yaml.Node, key string) (*yaml.Node, *yaml.Node, bool) {
}
return nil, cur, false
}

var nanosPerSecond = new(big.Int).SetUint64(uint64(time.Second / time.Nanosecond))

var nanosMap = map[string]*big.Int{
"ns": new(big.Int).SetUint64(1),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to be consistent with the rest of the lookups:

Suggested change
"ns": new(big.Int).SetUint64(1),
"ns": new(big.Int).SetUint64(time.Nanosceond),

Copy link
Contributor Author

@Alfus Alfus Jul 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the map to nanoseconds, it converts all the other units to nanos. It just happens to be that time.Nanoseconds is 1, to be consistent this would be:

new(big.Int).SetUint64(uint64(time.Nanosecond / time.Nanosecond)),

but the linter doesn't like that

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know, but the random magic 1 takes away from the implied documentation. What does the linter complain about?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also suggest dropping the / time.Nanosecond as well, since it's also redundant if you know about the type.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya, but this conde isn't supposed to know that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added docs to clarify what these values represent (which is the number of nanos in each unit)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go's not going to change the unit value to pico's until at least Go2 XD

"us": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)),
"µs": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)), // U+00B5 = micro symbol
"μs": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)), // U+03BC = Greek letter mu
"ms": new(big.Int).SetUint64(uint64(time.Millisecond / time.Nanosecond)),
"s": new(big.Int).SetUint64(uint64(time.Second / time.Nanosecond)),
"m": new(big.Int).SetUint64(uint64(time.Minute / time.Nanosecond)),
"h": new(big.Int).SetUint64(uint64(time.Hour / time.Nanosecond)),
}

var unitsNames = []string{"h", "m", "s", "ms", "us", "ns"}
Alfus marked this conversation as resolved.
Show resolved Hide resolved

func leadingFrac(str string) (result uint64, scale *big.Int, rem string, post bool) {
var i int
scale = big.NewInt(1)
big10 := big.NewInt(10)
var overflow bool
for ; i < len(str); i++ {
c := str[i]
if c < '0' || c > '9' {
break
}
if overflow {
continue
}
if result > (1<<63-1)/10 {
overflow = true
continue
}
temp := result*10 + uint64(c-'0')
if temp > 1<<63 {
overflow = true
continue
}
result = temp
scale.Mul(scale, big10)
}
return result, scale, str[i:], i > 0
}

func leadingInt(str string) (result uint64, rem string, pre bool, err error) {
var i int
for ; i < len(str); i++ {
c := str[i]
if c < '0' || c > '9' {
break
}
if result >= (1<<64)/10 {
return 0, str, false, fmt.Errorf("invalid duration: integer overflow")
}
result = result*10 + uint64(c-'0')
}
return result, str[i:], i > 0, nil
}
65 changes: 65 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/durationpb"
"gotest.tools/v3/assert"
Alfus marked this conversation as resolved.
Show resolved Hide resolved
)

func TestGoldenFiles(t *testing.T) {
Expand Down Expand Up @@ -172,3 +173,67 @@ func testRunYAMLFile(t *testing.T, testFile string) {
t.Errorf("%s: Test %s failed:\nExpected:\n%s\nActual:\n%s\nDiff:\n%s", expectedFileName, testFile, expectedText, errorText, diff)
}
}

func TestToDuration(t *testing.T) {
t.Parallel()
for _, testCase := range []struct {
Literal string
Expected *durationpb.Duration
ErrMsg string
}{
{Literal: "", Expected: nil, ErrMsg: "invalid duration"},
{Literal: "-", Expected: nil, ErrMsg: "invalid duration"},
{Literal: "s", Expected: nil, ErrMsg: "invalid duration"},
{Literal: ".", Expected: nil, ErrMsg: "invalid duration"},
{Literal: "-s", Expected: nil, ErrMsg: "invalid duration"},
{Literal: ".s", Expected: nil, ErrMsg: "invalid duration"},
{Literal: "-.", Expected: nil, ErrMsg: "invalid duration"},
{Literal: "-.s", Expected: nil, ErrMsg: "invalid duration"},
{Literal: "0y", Expected: nil, ErrMsg: "unknown unit"},
{Literal: "0so", Expected: nil, ErrMsg: "unknown unit"},
{Literal: "0os", Expected: nil, ErrMsg: "unknown unit"},
{Literal: "0s-0ms", Expected: nil, ErrMsg: "invalid duration"},
{Literal: "0.5ns", Expected: nil, ErrMsg: "fractional nanos"},
{Literal: "0.0005us", Expected: nil, ErrMsg: "fractional nanos"},
{Literal: "0.0000005μs", Expected: nil, ErrMsg: "fractional nanos"},
{Literal: "0.0000000005ms", Expected: nil, ErrMsg: "fractional nanos"},
{Literal: "9223372036854775807s", Expected: &durationpb.Duration{Seconds: 9223372036854775807}},
{Literal: "9223372036854775808s", ErrMsg: "out of range"},
{Literal: "-9223372036854775808s", Expected: &durationpb.Duration{Seconds: -9223372036854775808}},
{Literal: "-9223372036854775809s", ErrMsg: "out of range"},
{Literal: "18446744073709551615s", ErrMsg: "out of range"},
{Literal: "18446744073709551616s", ErrMsg: "overflow"},
{Literal: "0"},
{Literal: "0s"},
{Literal: "-0s"},
{Literal: "1s", Expected: &durationpb.Duration{Seconds: 1}},
{Literal: "-1s", Expected: &durationpb.Duration{Seconds: -1}},
{Literal: "1.5s", Expected: &durationpb.Duration{Seconds: 1, Nanos: 500000000}},
{Literal: "-1.5s", Expected: &durationpb.Duration{Seconds: -1, Nanos: -500000000}},
{Literal: "1.000000001s", Expected: &durationpb.Duration{Seconds: 1, Nanos: 1}},
{Literal: "1.0000000001s", ErrMsg: "fractional nanos"},
{Literal: "1.000000000s", Expected: &durationpb.Duration{Seconds: 1}},
{Literal: "1.0000000010s", Expected: &durationpb.Duration{Seconds: 1, Nanos: 1}},
{Literal: "1h", Expected: &durationpb.Duration{Seconds: 3600}},
{Literal: "1m", Expected: &durationpb.Duration{Seconds: 60}},
{Literal: "1h1m", Expected: &durationpb.Duration{Seconds: 3660}},
{Literal: "1h1m1s", Expected: &durationpb.Duration{Seconds: 3661}},
{Literal: "1h1m1.5s", Expected: &durationpb.Duration{Seconds: 3661, Nanos: 500000000}},
{Literal: "1.5h1m1.5s", Expected: &durationpb.Duration{Seconds: 5461, Nanos: 500000000}},
{Literal: "1.5h1m1.5s1.5h1m1.5s", Expected: &durationpb.Duration{Seconds: 10923}},
{Literal: "1h1m1s1ms1us1μs1µs1ns", Expected: &durationpb.Duration{Seconds: 3661, Nanos: 1003001}},
} {
testCase := testCase
t.Run("", func(t *testing.T) {
t.Parallel()
actual, err := ParseDuration(testCase.Literal)
if testCase.ErrMsg != "" {
require.ErrorContains(t, err, testCase.ErrMsg)
return
}
require.NoError(t, err)
assert.Equal(t, testCase.Expected.GetSeconds(), actual.GetSeconds())
assert.Equal(t, testCase.Expected.GetNanos(), actual.GetNanos())
})
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ require (
golang.org/x/text v0.14.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect
gotest.tools/v3 v3.5.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogR
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
Loading