diff --git a/adapters/backend/v1/adapter.go b/adapters/backend/v1/adapter.go index 216ddd9..c07cc73 100644 --- a/adapters/backend/v1/adapter.go +++ b/adapters/backend/v1/adapter.go @@ -193,24 +193,39 @@ func (b *Adapter) Batch(ctx context.Context, kind domain.Kind, batchType domain. // startReconciliationPeriodicTask starts a periodic task that sends reconciliation request messages to connected clients // every configurable minutes (interval). If interval is 0 (not set), the task is disabled. +// when cron schedule is set, the task will be executed according to the cron schedule. // intervalFromConnection is the minimum interval time in minutes from the connection time that the reconciliation task will be sent. func (a *Adapter) startReconciliationPeriodicTask(mainCtx context.Context, cfg *config.ReconciliationTaskConfig) { - if cfg == nil || cfg.TaskIntervalSeconds == 0 || cfg.IntervalFromConnectionSeconds == 0 { + if cfg == nil || (cfg.TaskIntervalSeconds == 0 && cfg.CronSchedule == "") || cfg.IntervalFromConnectionSeconds == 0 { logger.L().Warning("reconciliation task is disabled (intervals are not set)") return } go func() { - logger.L().Info("starting reconciliation periodic task", - helpers.Int("TaskIntervalSeconds", cfg.TaskIntervalSeconds), - helpers.Int("IntervalFromConnectionSeconds", cfg.IntervalFromConnectionSeconds)) - ticker := time.NewTicker(time.Duration(cfg.TaskIntervalSeconds) * time.Second) + var ticker utils.Ticker + if cfg.CronSchedule != "" { + var err error + ticker, err = utils.NewCronTicker(cfg.CronSchedule) + if err != nil { + logger.L().Warning("failed to create cron ticker", helpers.String("error", err.Error())) + } else { + logger.L().Info("starting reconciliation periodic task with cron schedule", + helpers.String("CronSchedule", cfg.CronSchedule), + helpers.Int("IntervalFromConnectionSeconds", cfg.IntervalFromConnectionSeconds)) + } + } + if ticker == nil { + logger.L().Info("starting reconciliation periodic task with interval", + helpers.Int("TaskIntervalSeconds", cfg.TaskIntervalSeconds), + helpers.Int("IntervalFromConnectionSeconds", cfg.IntervalFromConnectionSeconds)) + ticker = utils.NewStdTicker(time.Duration(cfg.TaskIntervalSeconds) * time.Second) + } for { select { case <-mainCtx.Done(): ticker.Stop() return - case <-ticker.C: + case <-ticker.Chan(): a.connMapMutex.Lock() logger.L().Info("running reconciliation task for connected clients", helpers.Int("clients", a.clientsMap.Len())) for connId, clientId := range a.connectionMap { diff --git a/config/config.go b/config/config.go index fe48e3c..67c5d0d 100644 --- a/config/config.go +++ b/config/config.go @@ -70,8 +70,9 @@ type PrometheusConfig struct { } type ReconciliationTaskConfig struct { - TaskIntervalSeconds int `mapstructure:"taskIntervalSeconds"` - IntervalFromConnectionSeconds int `mapstructure:"intervalFromConnectionSeconds"` + CronSchedule string `mapstructure:"cronSchedule"` // when this is set, taskIntervalSeconds is ignored + TaskIntervalSeconds int `mapstructure:"taskIntervalSeconds"` + IntervalFromConnectionSeconds int `mapstructure:"intervalFromConnectionSeconds"` } type KeepAliveTaskConfig struct { diff --git a/go.mod b/go.mod index eec2d63..c6ab90a 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/panjf2000/ants/v2 v2.9.1 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 github.com/prometheus/client_golang v1.19.0 + github.com/robfig/cron/v3 v3.0.1 github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.30.0 diff --git a/go.sum b/go.sum index a477385..00ce753 100644 --- a/go.sum +++ b/go.sum @@ -634,6 +634,8 @@ github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsT github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= diff --git a/tests/go.mod b/tests/go.mod index eb345bc..de24f06 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -283,6 +283,7 @@ require ( github.com/prometheus/procfs v0.13.0 // indirect github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect github.com/saferwall/pe v1.5.2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect diff --git a/tests/go.sum b/tests/go.sum index 300b359..0e5e011 100644 --- a/tests/go.sum +++ b/tests/go.sum @@ -1022,6 +1022,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= diff --git a/utils/ticker.go b/utils/ticker.go new file mode 100644 index 0000000..5050892 --- /dev/null +++ b/utils/ticker.go @@ -0,0 +1,141 @@ +package utils + +import ( + "fmt" + "strings" + "time" + + "github.com/robfig/cron/v3" +) + +// inspired by https://github.com/krayzpipes/cronticker/blob/main/cronticker/ticker.go + +type Ticker interface { + Chan() <-chan time.Time + Stop() +} + +type StdTicker struct { + *time.Ticker +} + +func (g *StdTicker) Chan() <-chan time.Time { + return g.C +} + +func NewStdTicker(d time.Duration) *StdTicker { + return &StdTicker{time.NewTicker(d)} +} + +var _ Ticker = (*StdTicker)(nil) + +// CronTicker is the struct returned to the user as a proxy +// to the ticker. The user can check the ticker channel for the next +// 'tick' via CronTicker.C (similar to the user of time.Timer). +type CronTicker struct { + C chan time.Time + k chan bool +} + +var _ Ticker = (*CronTicker)(nil) + +// Stop sends the appropriate message on the control channel to +// kill the CronTicker goroutines. It's good practice to use `defer CronTicker.Stop()`. +func (c *CronTicker) Stop() { + c.k <- true +} + +func (c *CronTicker) Chan() <-chan time.Time { + return c.C +} + +// NewCronTicker returns a CronTicker struct. +// You can check the ticker channel for the next tick by +// `CronTicker.Chan()`. +func NewCronTicker(schedule string) (*CronTicker, error) { + var cronTicker CronTicker + var err error + + cronTicker.C = make(chan time.Time, 1) + cronTicker.k = make(chan bool, 1) + + err = newCronTicker(schedule, cronTicker.C, cronTicker.k) + if err != nil { + return nil, err + } + return &cronTicker, nil +} + +// newCronTicker prepares the channels, parses the schedule, and kicks off +// the goroutine that handles scheduling of each 'tick'. +func newCronTicker(schedule string, c chan time.Time, k <-chan bool) error { + var err error + + scheduleWithTZ, loc, err := guaranteeTimeZone(schedule) + if err != nil { + return err + } + parser := getScheduleParser() + + cronSchedule, err := parser.Parse(scheduleWithTZ) + if err != nil { + return err + } + + go cronRunner(cronSchedule, loc, c, k) + + return nil + +} + +// getScheduleParser returns a new parser that allows the use of the 'seconds' field +// like in the Quarts cron format, as well as descriptors such as '@weekly'. +func getScheduleParser() cron.Parser { + parser := cron.NewParser(cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor) + return parser +} + +// guaranteeTimeZone sets the `TZ=` value to `UTC` if there is none +// already in the cron schedule string. +func guaranteeTimeZone(schedule string) (string, *time.Location, error) { + var loc *time.Location + + // If time zone is not included, set default to UTC + if !strings.HasPrefix(schedule, "TZ=") { + schedule = fmt.Sprintf("TZ=%s %s", "UTC", schedule) + } + + tz := extractTZ(schedule) + + loc, err := time.LoadLocation(tz) + if err != nil { + return schedule, loc, err + } + + return schedule, loc, nil +} + +func extractTZ(schedule string) string { + end := strings.Index(schedule, " ") + eq := strings.Index(schedule, "=") + return schedule[eq+1 : end] +} + +// cronRunner handles calculating the next 'tick'. It communicates to +// the CronTicker via a channel and will stop/return whenever it receives +// a bool on the `k` channel. +func cronRunner(schedule cron.Schedule, loc *time.Location, c chan time.Time, k <-chan bool) { + nextTick := schedule.Next(time.Now().In(loc)) + timer := time.NewTimer(time.Until(nextTick)) + for { + select { + case <-k: + timer.Stop() + return + case tickTime := <-timer.C: + c <- tickTime + nextTick = schedule.Next(tickTime.In(loc)) + timer.Reset(time.Until(nextTick)) + } + } +}