diff --git a/flag_test.go b/flag_test.go index de5c232b80..265c9b5a81 100644 --- a/flag_test.go +++ b/flag_test.go @@ -2,11 +2,13 @@ package cli import ( "context" + "errors" "flag" "fmt" "io" "os" "reflect" + "regexp" "strings" "testing" "time" @@ -2285,30 +2287,201 @@ func TestTimestampFlagApply_SingleFormat(t *testing.T) { } func TestTimestampFlagApply_MultipleFormats(t *testing.T) { + now := time.Now().UTC() + + testCases := []struct { + caseName string + layoutsPrecisions map[string]time.Duration + expRes time.Time + expErrValidation func(err error) (validation error) + }{ + { + caseName: "all_valid_layouts", + layoutsPrecisions: map[string]time.Duration{ + time.RFC3339: time.Second, + time.DateTime: time.Second, + time.RFC1123: time.Second, + }, + expRes: now.Truncate(time.Second), + }, + { + caseName: "one_invalid_layout", + layoutsPrecisions: map[string]time.Duration{ + time.RFC3339: time.Second, + time.DateTime: time.Second, + "foo": 0, + }, + expRes: now.Truncate(time.Second), + }, + { + caseName: "multiple_invalid_layouts", + layoutsPrecisions: map[string]time.Duration{ + time.RFC3339: time.Second, + "foo": 0, + time.DateTime: time.Second, + "bar": 0, + }, + expRes: now.Truncate(time.Second), + }, + { + caseName: "all_invalid_layouts", + layoutsPrecisions: map[string]time.Duration{ + "foo": 0, + "2024-08-07 74:01:82Z-100": 0, + "25:70": 0, + "": 0, + }, + expErrValidation: func(err error) error { + if err == nil { + return errors.New("got nil err") + } + + found := regexp.MustCompile(`(cannot parse ".+" as ".*")|(extra text: ".+")`).Match([]byte(err.Error())) + if !found { + return fmt.Errorf("given error does not satisfy pattern: %w", err) + } + + return nil + }, + }, + { + caseName: "empty_layout", + layoutsPrecisions: map[string]time.Duration{ + "": 0, + }, + expErrValidation: func(err error) error { + if err == nil { + return errors.New("got nil err") + } + + found := regexp.MustCompile(`extra text: ".+"`).Match([]byte(err.Error())) + if !found { + return fmt.Errorf("given error does not satisfy pattern: %w", err) + } + + return nil + }, + }, + { + caseName: "nil_layouts_slice", + expErrValidation: func(err error) error { + if err == nil { + return errors.New("got nil err") + } + + found := regexp.MustCompile(`got nil/empty layouts slice"`).Match([]byte(err.Error())) + if !found { + return fmt.Errorf("given error does not satisfy pattern: %w", err) + } + + return nil + }, + }, + { + caseName: "empty_layouts_slice", + layoutsPrecisions: map[string]time.Duration{}, + expErrValidation: func(err error) error { + if err == nil { + return errors.New("got nil err") + } + + found := regexp.MustCompile(`got nil/empty layouts slice"`).Match([]byte(err.Error())) + if !found { + return fmt.Errorf("given error does not satisfy pattern: %w", err) + } + + return nil + }, + }, + } + + getKeys := func(m map[string]time.Duration) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys + } + + for idx := range testCases { + testCase := testCases[idx] + t.Run(testCase.caseName, func(t *testing.T) { + t.Parallel() + fl := TimestampFlag{ + Name: "time", + Config: TimestampConfig{ + AvailableLayouts: getKeys(testCase.layoutsPrecisions), + }, + } + + set := flag.NewFlagSet("test", 0) + _ = fl.Apply(set) + + validLayouts := make([]string, 0, len(testCase.layoutsPrecisions)) + invalidLayouts := make([]string, 0, len(testCase.layoutsPrecisions)) + + for layout, prec := range testCase.layoutsPrecisions { + v, err := time.Parse(layout, now.Format(layout)) + if err != nil || prec == 0 || now.Truncate(prec).UnixNano() != v.Truncate(prec).UnixNano() { + invalidLayouts = append(invalidLayouts, layout) + continue + } + validLayouts = append(validLayouts, layout) + } + + for _, layout := range validLayouts { + err := set.Parse([]string{"--time", now.Format(layout)}) + assert.NoError(t, err) + if !testCase.expRes.IsZero() { + assert.Equal(t, testCase.expRes, set.Lookup("time").Value.(flag.Getter).Get()) + } + } + + for range invalidLayouts { + err := set.Parse([]string{"--time", now.Format(time.RFC3339)}) + if testCase.expErrValidation != nil { + assert.NoError(t, testCase.expErrValidation(err)) + } + } + }) + } +} + +func TestTimestampFlagApply_ShortenedLayouts(t *testing.T) { + now := time.Now().UTC() + + shortenedLayoutsPrecisions := map[string]time.Duration{ + time.Kitchen: time.Minute, + time.Stamp: time.Second, + time.StampMilli: time.Millisecond, + time.StampMicro: time.Microsecond, + time.StampNano: time.Nanosecond, + time.TimeOnly: time.Second, + "15:04": time.Minute, + } + + getKeys := func(m map[string]time.Duration) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys + } + fl := TimestampFlag{ - Name: "time", - Aliases: []string{"t"}, + Name: "time", Config: TimestampConfig{ - Timezone: time.Local, - AvailableLayouts: []string{ - time.DateTime, - time.TimeOnly, - time.Kitchen, - }, + AvailableLayouts: getKeys(shortenedLayoutsPrecisions), }, } + set := flag.NewFlagSet("test", 0) _ = fl.Apply(set) - now := time.Now() - for timeStr, expectedRes := range map[string]time.Time{ - now.Format(time.DateTime): now.Truncate(time.Second), - now.Format(time.TimeOnly): now.Truncate(time.Second), - now.Format(time.Kitchen): now.Truncate(time.Minute), - } { - err := set.Parse([]string{"--time", timeStr}) + for layout, prec := range shortenedLayoutsPrecisions { + err := set.Parse([]string{"--time", now.Format(layout)}) assert.NoError(t, err) - assert.Equal(t, expectedRes, set.Lookup("time").Value.(flag.Getter).Get()) + assert.Equal(t, now.Truncate(prec), set.Lookup("time").Value.(flag.Getter).Get()) } } diff --git a/flag_timestamp.go b/flag_timestamp.go index 79d3cb086d..090f03fc83 100644 --- a/flag_timestamp.go +++ b/flag_timestamp.go @@ -1,6 +1,7 @@ package cli import ( + "errors" "fmt" "time" ) @@ -57,6 +58,10 @@ func (t *timestampValue) Set(value string) error { t.location = time.UTC } + if len(t.availableLayouts) == 0 { + return errors.New("got nil/empty layouts slice") + } + for _, layout := range t.availableLayouts { var locErr error @@ -83,9 +88,10 @@ func (t *timestampValue) Set(value string) error { defaultTS, _ := time.ParseInLocation(time.TimeOnly, time.TimeOnly, timestamp.Location()) - // If format is missing date, set it explicitly to current + n := time.Now() + + // If format is missing date (or year only), set it explicitly to current if timestamp.Truncate(time.Hour*24).UnixNano() == defaultTS.Truncate(time.Hour*24).UnixNano() { - n := time.Now() timestamp = time.Date( n.Year(), n.Month(), @@ -96,6 +102,17 @@ func (t *timestampValue) Set(value string) error { timestamp.Nanosecond(), timestamp.Location(), ) + } else if timestamp.Year() == 0 { + timestamp = time.Date( + n.Year(), + timestamp.Month(), + timestamp.Day(), + timestamp.Hour(), + timestamp.Minute(), + timestamp.Second(), + timestamp.Nanosecond(), + timestamp.Location(), + ) } if t.timestamp != nil {