Skip to content

Commit

Permalink
scheduled_task: fix memory leaks (#1649)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkroepke authored Sep 28, 2024
1 parent 798bf32 commit 01e8093
Showing 1 changed file with 98 additions and 53 deletions.
151 changes: 98 additions & 53 deletions pkg/collector/scheduled_task/scheduled_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ var ConfigDefaults = Config{
type Collector struct {
config Config

scheduledTasksCh chan *scheduledTaskResults

lastResult *prometheus.Desc
missedRuns *prometheus.Desc
state *prometheus.Desc
Expand All @@ -57,7 +59,9 @@ const (
SCHED_S_TASK_HAS_NOT_RUN TaskResult = 0x00041303
)

type ScheduledTask struct {
var taskStates = []string{"disabled", "queued", "ready", "running", "unknown"}

type scheduledTask struct {
Name string
Path string
Enabled bool
Expand All @@ -66,7 +70,10 @@ type ScheduledTask struct {
LastTaskResult TaskResult
}

type ScheduledTasks []ScheduledTask
type scheduledTaskResults struct {
scheduledTasks []scheduledTask
err error
}

func New(config *Config) *Collector {
if config == nil {
Expand Down Expand Up @@ -133,10 +140,23 @@ func (c *Collector) GetPerfCounter(_ *slog.Logger) ([]string, error) {
}

func (c *Collector) Close(_ *slog.Logger) error {
close(c.scheduledTasksCh)

c.scheduledTasksCh = nil

return nil
}

func (c *Collector) Build(_ *slog.Logger, _ *wmi.Client) error {
initErrCh := make(chan error)
c.scheduledTasksCh = make(chan *scheduledTaskResults)

go c.initializeScheduleService(initErrCh)

if err := <-initErrCh; err != nil {
return fmt.Errorf("initialize schedule service: %w", err)
}

c.lastResult = prometheus.NewDesc(
prometheus.BuildFQName(types.Namespace, Name, "last_result"),
"The result that was returned the last time the registered task was run",
Expand Down Expand Up @@ -174,12 +194,10 @@ func (c *Collector) Collect(_ *types.ScrapeContext, logger *slog.Logger, ch chan
return nil
}

var TASK_STATES = []string{"disabled", "queued", "ready", "running", "unknown"}

func (c *Collector) collect(ch chan<- prometheus.Metric) error {
scheduledTasks, err := getScheduledTasks()
scheduledTasks, err := c.getScheduledTasks()
if err != nil {
return err
return fmt.Errorf("get scheduled tasks: %w", err)
}

for _, task := range scheduledTasks {
Expand All @@ -188,7 +206,7 @@ func (c *Collector) collect(ch chan<- prometheus.Metric) error {
continue
}

for _, state := range TASK_STATES {
for _, state := range taskStates {
var stateValue float64

if strings.ToLower(task.State.String()) == state {
Expand Down Expand Up @@ -231,14 +249,15 @@ func (c *Collector) collect(ch chan<- prometheus.Metric) error {
return nil
}

const SCHEDULED_TASK_PROGRAM_ID = "Schedule.Service.1"
func (c *Collector) getScheduledTasks() ([]scheduledTask, error) {
c.scheduledTasksCh <- nil

// S_FALSE is returned by CoInitialize if it was already called on this thread.
const S_FALSE = 0x00000001
scheduledTasks := <-c.scheduledTasksCh

func getScheduledTasks() (ScheduledTasks, error) {
var scheduledTasks ScheduledTasks
return scheduledTasks.scheduledTasks, scheduledTasks.err
}

func (c *Collector) initializeScheduleService(initErrCh chan<- error) {
// The only way to run WMI queries in parallel while being thread-safe is to
// ensure the CoInitialize[Ex]() call is bound to its current OS thread.
// Otherwise, attempting to initialize and run parallel queries across
Expand All @@ -248,72 +267,72 @@ func getScheduledTasks() (ScheduledTasks, error) {

if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
var oleCode *ole.OleError
if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != S_FALSE {
return nil, err
if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != wmi.S_FALSE {
initErrCh <- err

return
}
}

defer ole.CoUninitialize()

schedClassID, err := ole.ClassIDFrom(SCHEDULED_TASK_PROGRAM_ID)
scheduleClassID, err := ole.ClassIDFrom("Schedule.Service.1")
if err != nil {
return scheduledTasks, err
initErrCh <- err

return
}

taskSchedulerObj, err := ole.CreateInstance(schedClassID, nil)
taskSchedulerObj, err := ole.CreateInstance(scheduleClassID, nil)
if err != nil || taskSchedulerObj == nil {
return scheduledTasks, err
initErrCh <- err

return
}
defer taskSchedulerObj.Release()

taskServiceObj := taskSchedulerObj.MustQueryInterface(ole.IID_IDispatch)

_, err = oleutil.CallMethod(taskServiceObj, "Connect")
if err != nil {
return scheduledTasks, err
}

defer taskServiceObj.Release()

res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`)
taskService, err := oleutil.CallMethod(taskServiceObj, "Connect")
if err != nil {
return scheduledTasks, err
initErrCh <- err

return
}

rootFolderObj := res.ToIDispatch()
defer rootFolderObj.Release()
defer func(taskService *ole.VARIANT) {
_ = taskService.Clear()
}(taskService)

err = fetchTasksRecursively(rootFolderObj, &scheduledTasks)
close(initErrCh)

return scheduledTasks, err
}
scheduledTasks := make([]scheduledTask, 0, 100)

func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error {
res, err := oleutil.CallMethod(folder, "GetTasks", 1)
if err != nil {
return err
}
for range c.scheduledTasksCh {
func() {
// Clear the slice to avoid memory leaks
clear(scheduledTasks)
scheduledTasks = scheduledTasks[:0]

tasks := res.ToIDispatch()
defer tasks.Release()

err = oleutil.ForEach(tasks, func(v *ole.VARIANT) error {
task := v.ToIDispatch()
defer task.Release()
res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`)
if err != nil {
c.scheduledTasksCh <- &scheduledTaskResults{err: err}

parsedTask, err := parseTask(task)
if err != nil {
return err
}
return
}

*scheduledTasks = append(*scheduledTasks, parsedTask)
rootFolderObj := res.ToIDispatch()
defer rootFolderObj.Release()

return nil
})
err = fetchTasksRecursively(rootFolderObj, &scheduledTasks)

return err
c.scheduledTasksCh <- &scheduledTaskResults{scheduledTasks: scheduledTasks, err: err}
}()
}
}

func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error {
func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *[]scheduledTask) error {
if err := fetchTasksInFolder(folder, scheduledTasks); err != nil {
return err
}
Expand All @@ -336,8 +355,34 @@ func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *ScheduledTasks
return err
}

func parseTask(task *ole.IDispatch) (ScheduledTask, error) {
var scheduledTask ScheduledTask
func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *[]scheduledTask) error {
res, err := oleutil.CallMethod(folder, "GetTasks", 1)
if err != nil {
return err
}

tasks := res.ToIDispatch()
defer tasks.Release()

err = oleutil.ForEach(tasks, func(v *ole.VARIANT) error {
task := v.ToIDispatch()
defer task.Release()

parsedTask, err := parseTask(task)
if err != nil {
return err
}

*scheduledTasks = append(*scheduledTasks, parsedTask)

return nil
})

return err
}

func parseTask(task *ole.IDispatch) (scheduledTask, error) {
var scheduledTask scheduledTask

taskNameVar, err := oleutil.GetProperty(task, "Name")
if err != nil {
Expand Down

0 comments on commit 01e8093

Please sign in to comment.