diff --git a/.golangci.yml b/.golangci.yml index 9907f5b..cf4cf87 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -61,6 +61,10 @@ linters: - wsl # generous whitespace violates house style - exhaustive - exhaustruct + - nonamedreturns + - mnd + - err113 + - gochecknoglobals issues: exclude: # Don't ban use of fmt.Errorf to create new errors, but the remaining diff --git a/decode.go b/decode.go index 34067b4..ad69793 100644 --- a/decode.go +++ b/decode.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "math" + "math/big" "strconv" "strings" "time" @@ -34,8 +35,6 @@ import ( "gopkg.in/yaml.v3" ) -const atTypeFieldName = "@type" - // Validator is an interface for validating a Protobuf message produced from a given YAML node. type Validator interface { // Validate the given message. @@ -57,11 +56,6 @@ type UnmarshalOptions struct { } } -type protoResolver interface { - protoregistry.MessageTypeResolver - protoregistry.ExtensionTypeResolver -} - // Unmarshal a Protobuf message from the given YAML data. func Unmarshal(data []byte, message proto.Message) error { return (UnmarshalOptions{}).Unmarshal(data, message) @@ -76,6 +70,53 @@ func (o UnmarshalOptions) Unmarshal(data []byte, message proto.Message) error { return o.unmarshalNode(&yamlFile, message, data) } +// ParseDuration parses a duration string into a durationpb.Duration. +// +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +// +// This function supports the full range of durationpb.Duration values, including +// those outside the range of time.Duration. +func ParseDuration(str string) (*durationpb.Duration, error) { + // [-+]?([0-9]*(\.[0-9]*)?[a-z]+)+ + neg := false + + // Consume [-+]? + if str != "" { + c := str[0] + if c == '-' || c == '+' { + neg = c == '-' + str = str[1:] + } + } + // Special case: if all that is left is "0", this is zero. 
+ if str == "0" { + var empty *durationpb.Duration + return empty, nil + } + if str == "" { + return nil, errors.New("invalid duration") + } + totalNanos := &big.Int{} + var err error + for str != "" { + str, err = parseDurationNext(str, totalNanos) + if err != nil { + return nil, err + } + } + if neg { + totalNanos.Neg(totalNanos) + } + result := &durationpb.Duration{} + quo, rem := totalNanos.QuoRem(totalNanos, nanosPerSecond, &big.Int{}) + if !quo.IsInt64() { + return nil, errors.New("invalid duration: out of range") + } + result.Seconds = quo.Int64() + result.Nanos = int32(rem.Int64()) + return result, nil +} + func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, data []byte) error { if node.Kind == 0 { return nil @@ -121,6 +162,13 @@ func (o UnmarshalOptions) unmarshalNode(node *yaml.Node, message proto.Message, return nil } +const atTypeFieldName = "@type" + +type protoResolver interface { + protoregistry.MessageTypeResolver + protoregistry.ExtensionTypeResolver +} + type unmarshaler struct { options UnmarshalOptions errors []error @@ -683,54 +731,6 @@ const ( minTimestampSeconds = -62135596800 ) -// Format is decimal seconds with up to 9 fractional digits, followed by an 's'. -func parseDuration(txt string, duration *durationpb.Duration) error { - // Remove trailing s. - txt = strings.TrimSpace(txt) - if len(txt) == 0 || txt[len(txt)-1] != 's' { - return errors.New("missing trailing 's'") - } - value := txt[:len(txt)-1] - isNeg := strings.HasPrefix(value, "-") - - // Split into seconds and nanos. 
- parts := strings.Split(value, ".") - switch len(parts) { - case 1: - // seconds only - seconds, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil { - return err - } - duration.Seconds = seconds - duration.Nanos = 0 - case 2: - // seconds and up to 9 digits of fractional seconds - seconds, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil { - return err - } - duration.Seconds = seconds - nanos, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return err - } - power := 9 - len(parts[1]) - if power < 0 { - return errors.New("too many fractional second digits") - } - nanos *= int64(math.Pow10(power)) - if isNeg { - duration.Nanos = -int32(nanos) - } else { - duration.Nanos = int32(nanos) - } - default: - return errors.New("invalid duration: too many '.' characters") - } - return nil -} - // Format is RFC3339Nano, limited to the range 0001-01-01T00:00:00Z to // 9999-12-31T23:59:59Z inclusive. func parseTimestamp(txt string, timestamp *timestamppb.Timestamp) error { @@ -770,19 +770,21 @@ func unmarshalDurationMsg(unm *unmarshaler, node *yaml.Node, message proto.Messa if node.Kind != yaml.ScalarNode || len(node.Value) == 0 || isNull(node) { return false } - duration, ok := message.(*durationpb.Duration) - if !ok { - duration = &durationpb.Duration{} - } - err := parseDuration(node.Value, duration) + duration, err := ParseDuration(node.Value) if err != nil { - unm.addErrorf(node, "invalid duration: %v", err) - } else if !ok { - // Set the fields dynamically. - return setFieldByName(message, "seconds", protoreflect.ValueOfInt64(duration.GetSeconds())) && - setFieldByName(message, "nanos", protoreflect.ValueOfInt32(duration.GetNanos())) + unm.addError(node, err) + return true } - return true + + if value, ok := message.(*durationpb.Duration); ok { + value.Seconds = duration.GetSeconds() + value.Nanos = duration.GetNanos() + return true + } + + // Set the fields dynamically. 
// Lookup tables used to convert duration units to nanoseconds. The *big.Int
// values are shared between calls and must be treated as read-only.
var (
	// nanosPerSecond is the number of nanoseconds in a second.
	nanosPerSecond = new(big.Int).SetUint64(uint64(time.Second / time.Nanosecond))

	// nanosMap is a map of time unit names to their duration in nanoseconds.
	nanosMap = map[string]*big.Int{
		"ns": new(big.Int).SetUint64(1), // Identity for nanos.
		"us": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)),
		"µs": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)), // U+00B5 = micro symbol
		"μs": new(big.Int).SetUint64(uint64(time.Microsecond / time.Nanosecond)), // U+03BC = Greek letter mu
		"ms": new(big.Int).SetUint64(uint64(time.Millisecond / time.Nanosecond)),
		"s":  nanosPerSecond,
		"m":  new(big.Int).SetUint64(uint64(time.Minute / time.Nanosecond)),
		"h":  new(big.Int).SetUint64(uint64(time.Hour / time.Nanosecond)),
	}

	// unitsNames is the (normalized) list of time unit names, used in error
	// messages.
	unitsNames = []string{"h", "m", "s", "ms", "us", "ns"}
)

// parseDurationNext parses a single "<decimal><unit>" segment from the front
// of str, adds its value in nanoseconds to totalNanos, and returns the
// unparsed remainder of str.
//
// The segment must start with a digit or '.', contain at least one digit, and
// be followed by a known unit. A segment whose fractional part does not
// resolve to a whole number of nanoseconds is rejected. totalNanos is only
// modified when the segment is accepted.
func parseDurationNext(str string, totalNanos *big.Int) (string, error) {
	// The next character must be [0-9.]
	if str == "" || (str[0] != '.' && (str[0] < '0' || str[0] > '9')) {
		return "", errors.New("invalid duration")
	}

	// pre reports whether a digit was seen before the dot.
	whole, str, pre, err := leadingInt(str)
	if err != nil {
		return "", err
	}

	var (
		frac  uint64
		scale *big.Int
		post  bool // Whether a digit was seen after the dot.
	)
	if str != "" && str[0] == '.' {
		frac, scale, str, post = leadingFrac(str[1:])
	}
	if !pre && !post {
		// A lone "." (or nothing numeric at all) is not a number.
		return "", errors.New("invalid duration")
	}

	end := unitEnd(str)
	if end == 0 {
		return "", fmt.Errorf("invalid duration: missing unit, expected one of %v", unitsNames)
	}
	unitName := str[:end]
	str = str[end:]
	nanosPerUnit, ok := nanosMap[unitName]
	if !ok {
		return "", fmt.Errorf("invalid duration: unknown unit, expected one of %v", unitsNames)
	}

	// segment = whole * nanosPerUnit + frac * nanosPerUnit / scale
	segment := new(big.Int)
	if whole > 0 {
		segment.SetUint64(whole)
		segment.Mul(segment, nanosPerUnit)
	}
	if frac > 0 {
		fracNanos := new(big.Int).SetUint64(frac)
		fracNanos.Mul(fracNanos, nanosPerUnit)
		rem := new(big.Int)
		fracNanos.QuoRem(fracNanos, scale, rem)
		if rem.Sign() != 0 {
			// The fraction is finer than one nanosecond.
			return "", errors.New("invalid duration: fractional nanos")
		}
		segment.Add(segment, fracNanos)
	}
	totalNanos.Add(totalNanos, segment)
	return str, nil
}

// unitEnd returns the length of the unit name at the front of str, i.e. the
// index of the first byte that is '.', '-' or an ASCII digit, or len(str) if
// there is no such byte.
func unitEnd(str string) int {
	for i := 0; i < len(str); i++ {
		c := str[i]
		if c == '.' || c == '-' || ('0' <= c && c <= '9') {
			return i
		}
	}
	return len(str)
}

// leadingFrac consumes the leading run of ASCII digits of str, interpreted as
// the digits following a decimal point.
//
// The fraction's value is result/scale, where scale is a power of ten. Once
// the accumulated value would no longer fit in 63 bits, the remaining digits
// are still consumed but ignored. rem is the text after the digits and post
// reports whether at least one digit was consumed.
func leadingFrac(str string) (result uint64, scale *big.Int, rem string, post bool) {
	var i int
	scale = big.NewInt(1)
	big10 := big.NewInt(10)
	var overflow bool
	for ; i < len(str); i++ {
		chr := str[i]
		if chr < '0' || chr > '9' {
			break
		}
		if overflow {
			continue
		}
		if result > (1<<63-1)/10 {
			overflow = true
			continue
		}
		temp := result*10 + uint64(chr-'0')
		if temp > 1<<63 {
			overflow = true
			continue
		}
		result = temp
		scale.Mul(scale, big10)
	}
	return result, scale, str[i:], i > 0
}

// leadingInt consumes the leading run of ASCII digits of str and returns
// their value, the text after the digits, and whether at least one digit was
// consumed. It returns an "integer overflow" error if the value does not fit
// in a uint64.
func leadingInt(str string) (result uint64, rem string, pre bool, err error) {
	const maxUint64 = ^uint64(0)
	var i int
	for ; i < len(str); i++ {
		c := str[i]
		if c < '0' || c > '9' {
			break
		}
		digit := uint64(c - '0')
		// Check before multiplying: result*10+digit can wrap around to a
		// value that is still larger than result, so comparing the wrapped
		// value against result does not reliably detect overflow.
		if result > (maxUint64-digit)/10 {
			return 0, str, i > 0, errors.New("integer overflow")
		}
		result = result*10 + digit
	}
	return result, str[i:], i > 0, nil
}
duration"}, + {Input: ".", Expected: nil, ErrMsg: "invalid duration"}, + {Input: "-s", Expected: nil, ErrMsg: "invalid duration"}, + {Input: ".s", Expected: nil, ErrMsg: "invalid duration"}, + {Input: "-.", Expected: nil, ErrMsg: "invalid duration"}, + {Input: "-.s", Expected: nil, ErrMsg: "invalid duration"}, + {Input: "--0s", Expected: nil, ErrMsg: "invalid duration"}, + {Input: "0y", Expected: nil, ErrMsg: "unknown unit"}, + {Input: "0so", Expected: nil, ErrMsg: "unknown unit"}, + {Input: "0os", Expected: nil, ErrMsg: "unknown unit"}, + {Input: "0s-0ms", Expected: nil, ErrMsg: "invalid duration"}, + {Input: "0.5ns", Expected: nil, ErrMsg: "fractional nanos"}, + {Input: "0.0005us", Expected: nil, ErrMsg: "fractional nanos"}, + {Input: "0.0000005μs", Expected: nil, ErrMsg: "fractional nanos"}, + {Input: "0.0000000005ms", Expected: nil, ErrMsg: "fractional nanos"}, + {Input: "9223372036854775807s", Expected: &durationpb.Duration{Seconds: 9223372036854775807}}, + {Input: "9223372036854775808s", ErrMsg: "out of range"}, + {Input: "-9223372036854775808s", Expected: &durationpb.Duration{Seconds: -9223372036854775808}}, + {Input: "-9223372036854775809s", ErrMsg: "out of range"}, + {Input: "18446744073709551615s", ErrMsg: "out of range"}, + {Input: "18446744073709551616s", ErrMsg: "overflow"}, + {Input: "0"}, + {Input: "0s"}, + {Input: "-0s"}, {Input: "1s", Expected: &durationpb.Duration{Seconds: 1}}, {Input: "-1s", Expected: &durationpb.Duration{Seconds: -1}}, - {Input: "--1s", Expected: nil}, {Input: "1.5s", Expected: &durationpb.Duration{Seconds: 1, Nanos: 500000000}}, {Input: "-1.5s", Expected: &durationpb.Duration{Seconds: -1, Nanos: -500000000}}, {Input: "1.000000001s", Expected: &durationpb.Duration{Seconds: 1, Nanos: 1}}, - {Input: "1.0000000001s", Expected: nil}, + {Input: "1.0000000001s", ErrMsg: "fractional nanos"}, {Input: "1.000000000s", Expected: &durationpb.Duration{Seconds: 1}}, - {Input: "1.0000000010s", Expected: nil}, - {Input: "-1.000000001s", 
Expected: &durationpb.Duration{Seconds: -1, Nanos: -1}}, - {Input: "0s", Expected: &durationpb.Duration{}}, - {Input: "-0s", Expected: &durationpb.Duration{}}, - {Input: "0.1s", Expected: &durationpb.Duration{Nanos: 100000000}}, - {Input: "-0.1s", Expected: &durationpb.Duration{Nanos: -100000000}}, - {Input: "0.000000001s", Expected: &durationpb.Duration{Nanos: 1}}, - {Input: "0.0000000001s", Expected: nil}, - {Input: "0.000000000s", Expected: &durationpb.Duration{}}, - {Input: "0.0000000010s", Expected: nil}, - {Input: "-0.000000001s", Expected: &durationpb.Duration{Nanos: -1}}, + {Input: "1.0000000010s", Expected: &durationpb.Duration{Seconds: 1, Nanos: 1}}, + {Input: "1h", Expected: &durationpb.Duration{Seconds: 3600}}, + {Input: "1m", Expected: &durationpb.Duration{Seconds: 60}}, + {Input: "1h1m", Expected: &durationpb.Duration{Seconds: 3660}}, + {Input: "1h1m1s", Expected: &durationpb.Duration{Seconds: 3661}}, + {Input: "1h1m1.5s", Expected: &durationpb.Duration{Seconds: 3661, Nanos: 500000000}}, + {Input: "1.5h1m1.5s", Expected: &durationpb.Duration{Seconds: 5461, Nanos: 500000000}}, + {Input: "1.5h1m1.5s1.5h1m1.5s", Expected: &durationpb.Duration{Seconds: 10923}}, + {Input: "1h1m1s1ms1us1μs1µs1ns", Expected: &durationpb.Duration{Seconds: 3661, Nanos: 1003001}}, } { testCase := testCase t.Run(testCase.Input, func(t *testing.T) { t.Parallel() - actual := &durationpb.Duration{} - err := parseDuration(testCase.Input, actual) - if testCase.Expected == nil { - if err == nil { - t.Fatal("Expected error, got nil") - } - } else { - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(testCase.Expected, actual, protocmp.Transform()); diff != "" { - t.Errorf("Unexpected diff:\n%s", diff) - } + actual, err := ParseDuration(testCase.Input) + if testCase.ErrMsg != "" { + require.ErrorContains(t, err, testCase.ErrMsg) + return } + require.NoError(t, err) + assert.Equal(t, testCase.Expected.GetSeconds(), actual.GetSeconds()) + assert.Equal(t, 
testCase.Expected.GetNanos(), actual.GetNanos()) }) } } diff --git a/internal/cmd/generate-txt-testdata/main.go b/internal/cmd/generate-txt-testdata/main.go index 5b38e18..9a16a25 100644 --- a/internal/cmd/generate-txt-testdata/main.go +++ b/internal/cmd/generate-txt-testdata/main.go @@ -113,7 +113,7 @@ func tryParse(filePath string) (string, error) { return "", fmt.Errorf("unknown file type: %s", filePath) } if err != nil { - return err.Error(), nil //nolint:nilerr + return err.Error(), nil } return "", nil } diff --git a/internal/proto/buf/protoyaml/test/v1/validate.proto b/internal/proto/buf/protoyaml/test/v1/validate.proto index c051547..9ec9ae7 100644 --- a/internal/proto/buf/protoyaml/test/v1/validate.proto +++ b/internal/proto/buf/protoyaml/test/v1/validate.proto @@ -26,13 +26,13 @@ message ValidateTest { message ValidateTestCase { google.protobuf.Any dynamic = 1; float float_gt_lt = 2 [(buf.validate.field).float = { - gt: 0, + gt: 0 lt: 10 }]; map string_map = 3 [(buf.validate.field).map = { keys: { string: {pattern: "^[a-z]+$"} - }, + } values: { string: {pattern: "^[A-Z]+$"} } diff --git a/internal/testdata/basic.proto3test.txt b/internal/testdata/basic.proto3test.txt index 18330b9..585f707 100644 --- a/internal/testdata/basic.proto3test.txt +++ b/internal/testdata/basic.proto3test.txt @@ -186,27 +186,27 @@ internal/testdata/basic.proto3test.yaml:105:25 expected fields for bufext.cel.ex 105 | - standalone_message: [] 105 | ........................^ -internal/testdata/basic.proto3test.yaml:107:22 invalid duration: missing trailing 's' +internal/testdata/basic.proto3test.yaml:107:22 invalid duration: missing unit, expected one of [h m s ms us ns] 107 | - single_duration: 1 107 | .....................^ -internal/testdata/basic.proto3test.yaml:110:22 invalid duration: too many fractional second digits +internal/testdata/basic.proto3test.yaml:110:22 invalid duration: fractional nanos 110 | - single_duration: 1.0123456789s 110 | .....................^ 
-internal/testdata/basic.proto3test.yaml:111:22 invalid duration: strconv.ParseInt: parsing "A": invalid syntax +internal/testdata/basic.proto3test.yaml:111:22 invalid duration 111 | - single_duration: As 111 | .....................^ -internal/testdata/basic.proto3test.yaml:112:22 invalid duration: strconv.ParseInt: parsing "A": invalid syntax +internal/testdata/basic.proto3test.yaml:112:22 invalid duration 112 | - single_duration: A.1s 112 | .....................^ -internal/testdata/basic.proto3test.yaml:113:22 invalid duration: invalid duration: too many '.' characters +internal/testdata/basic.proto3test.yaml:113:22 invalid duration: missing unit, expected one of [h m s ms us ns] 113 | - single_duration: 1.1.1s 113 | .....................^ -internal/testdata/basic.proto3test.yaml:114:22 invalid duration: strconv.ParseInt: parsing "B": invalid syntax +internal/testdata/basic.proto3test.yaml:114:22 invalid duration: unknown unit, expected one of [h m s ms us ns] 114 | - single_duration: 1.Bs 114 | .....................^