From 01e809315cabd5f6f85bb120f4eb18497425a940 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan-Otto=20Kr=C3=B6pke?= Date: Sat, 28 Sep 2024 15:15:15 +0200 Subject: [PATCH] scheduled_task: fix memory leaks (#1649) --- .../scheduled_task/scheduled_task.go | 151 ++++++++++++------ 1 file changed, 98 insertions(+), 53 deletions(-) diff --git a/pkg/collector/scheduled_task/scheduled_task.go b/pkg/collector/scheduled_task/scheduled_task.go index 861bafe45..e38f6081c 100644 --- a/pkg/collector/scheduled_task/scheduled_task.go +++ b/pkg/collector/scheduled_task/scheduled_task.go @@ -33,6 +33,8 @@ var ConfigDefaults = Config{ type Collector struct { config Config + scheduledTasksCh chan *scheduledTaskResults + lastResult *prometheus.Desc missedRuns *prometheus.Desc state *prometheus.Desc @@ -57,7 +59,9 @@ const ( SCHED_S_TASK_HAS_NOT_RUN TaskResult = 0x00041303 ) -type ScheduledTask struct { +var taskStates = []string{"disabled", "queued", "ready", "running", "unknown"} + +type scheduledTask struct { Name string Path string Enabled bool @@ -66,7 +70,10 @@ type ScheduledTask struct { LastTaskResult TaskResult } -type ScheduledTasks []ScheduledTask +type scheduledTaskResults struct { + scheduledTasks []scheduledTask + err error +} func New(config *Config) *Collector { if config == nil { @@ -133,10 +140,23 @@ func (c *Collector) GetPerfCounter(_ *slog.Logger) ([]string, error) { } func (c *Collector) Close(_ *slog.Logger) error { + close(c.scheduledTasksCh) + + c.scheduledTasksCh = nil + return nil } func (c *Collector) Build(_ *slog.Logger, _ *wmi.Client) error { + initErrCh := make(chan error) + c.scheduledTasksCh = make(chan *scheduledTaskResults) + + go c.initializeScheduleService(initErrCh) + + if err := <-initErrCh; err != nil { + return fmt.Errorf("initialize schedule service: %w", err) + } + c.lastResult = prometheus.NewDesc( prometheus.BuildFQName(types.Namespace, Name, "last_result"), "The result that was returned the last time the registered task was run", @@ -174,12 +194,10 @@ func (c *Collector) Collect(_ *types.ScrapeContext, logger *slog.Logger, ch chan return nil } -var TASK_STATES = []string{"disabled", "queued", "ready", "running", "unknown"} - func (c *Collector) collect(ch chan<- prometheus.Metric) error { - scheduledTasks, err := getScheduledTasks() + scheduledTasks, err := c.getScheduledTasks() if err != nil { - return err + return fmt.Errorf("get scheduled tasks: %w", err) } for _, task := range scheduledTasks { @@ -188,7 +206,7 @@ func (c *Collector) collect(ch chan<- prometheus.Metric) error { continue } - for _, state := range TASK_STATES { + for _, state := range taskStates { var stateValue float64 if strings.ToLower(task.State.String()) == state { @@ -231,14 +249,15 @@ func (c *Collector) collect(ch chan<- prometheus.Metric) error { return nil } -const SCHEDULED_TASK_PROGRAM_ID = "Schedule.Service.1" +func (c *Collector) getScheduledTasks() ([]scheduledTask, error) { + c.scheduledTasksCh <- nil -// S_FALSE is returned by CoInitialize if it was already called on this thread. -const S_FALSE = 0x00000001 + scheduledTasks := <-c.scheduledTasksCh -func getScheduledTasks() (ScheduledTasks, error) { - var scheduledTasks ScheduledTasks + return scheduledTasks.scheduledTasks, scheduledTasks.err +} +func (c *Collector) initializeScheduleService(initErrCh chan<- error) { // The only way to run WMI queries in parallel while being thread-safe is to // ensure the CoInitialize[Ex]() call is bound to its current OS thread. // Otherwise, attempting to initialize and run parallel queries across @@ -248,72 +267,72 @@ func getScheduledTasks() (ScheduledTasks, error) { if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil { var oleCode *ole.OleError - if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != S_FALSE { - return nil, err + if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != wmi.S_FALSE { + initErrCh <- err + + return } } + defer ole.CoUninitialize() - schedClassID, err := ole.ClassIDFrom(SCHEDULED_TASK_PROGRAM_ID) + scheduleClassID, err := ole.ClassIDFrom("Schedule.Service.1") if err != nil { - return scheduledTasks, err + initErrCh <- err + + return } - taskSchedulerObj, err := ole.CreateInstance(schedClassID, nil) + taskSchedulerObj, err := ole.CreateInstance(scheduleClassID, nil) if err != nil || taskSchedulerObj == nil { - return scheduledTasks, err + initErrCh <- err + + return } defer taskSchedulerObj.Release() taskServiceObj := taskSchedulerObj.MustQueryInterface(ole.IID_IDispatch) - - _, err = oleutil.CallMethod(taskServiceObj, "Connect") - if err != nil { - return scheduledTasks, err - } - defer taskServiceObj.Release() - res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`) + taskService, err := oleutil.CallMethod(taskServiceObj, "Connect") if err != nil { - return scheduledTasks, err + initErrCh <- err + + return } - rootFolderObj := res.ToIDispatch() - defer rootFolderObj.Release() + defer func(taskService *ole.VARIANT) { + _ = taskService.Clear() + }(taskService) - err = fetchTasksRecursively(rootFolderObj, &scheduledTasks) + close(initErrCh) - return scheduledTasks, err -} + scheduledTasks := make([]scheduledTask, 0, 100) -func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error { - res, err := oleutil.CallMethod(folder, "GetTasks", 1) - if err != nil { - return err - } + for range c.scheduledTasksCh { + func() { + // Clear the slice to avoid memory leaks + clear(scheduledTasks) + scheduledTasks = scheduledTasks[:0] - tasks := res.ToIDispatch() - defer tasks.Release() - - err = oleutil.ForEach(tasks, func(v *ole.VARIANT) error { - task := v.ToIDispatch() - defer task.Release() + res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`) + if err != nil { + c.scheduledTasksCh <- &scheduledTaskResults{err: err} - parsedTask, err := parseTask(task) - if err != nil { - return err - } + return + } - *scheduledTasks = append(*scheduledTasks, parsedTask) + rootFolderObj := res.ToIDispatch() + defer rootFolderObj.Release() - return nil - }) + err = fetchTasksRecursively(rootFolderObj, &scheduledTasks) - return err + c.scheduledTasksCh <- &scheduledTaskResults{scheduledTasks: scheduledTasks, err: err} + }() + } } -func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error { +func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *[]scheduledTask) error { if err := fetchTasksInFolder(folder, scheduledTasks); err != nil { return err } @@ -336,8 +355,34 @@ func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *ScheduledTasks return err } -func parseTask(task *ole.IDispatch) (ScheduledTask, error) { - var scheduledTask ScheduledTask +func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *[]scheduledTask) error { + res, err := oleutil.CallMethod(folder, "GetTasks", 1) + if err != nil { + return err + } + + tasks := res.ToIDispatch() + defer tasks.Release() + + err = oleutil.ForEach(tasks, func(v *ole.VARIANT) error { + task := v.ToIDispatch() + defer task.Release() + + parsedTask, err := parseTask(task) + if err != nil { + return err + } + + *scheduledTasks = append(*scheduledTasks, parsedTask) + + return nil + }) + + return err +} + +func parseTask(task *ole.IDispatch) (scheduledTask, error) { + var scheduledTask scheduledTask taskNameVar, err := oleutil.GetProperty(task, "Name") if err != nil {