Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scheduled_task: fix memory leaks #1649

Merged
merged 2 commits into from
Sep 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 98 additions & 53 deletions pkg/collector/scheduled_task/scheduled_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ var ConfigDefaults = Config{
type Collector struct {
config Config

scheduledTasksCh chan *scheduledTaskResults

lastResult *prometheus.Desc
missedRuns *prometheus.Desc
state *prometheus.Desc
Expand All @@ -57,7 +59,9 @@ const (
SCHED_S_TASK_HAS_NOT_RUN TaskResult = 0x00041303
)

type ScheduledTask struct {
var taskStates = []string{"disabled", "queued", "ready", "running", "unknown"}

type scheduledTask struct {
Name string
Path string
Enabled bool
Expand All @@ -66,7 +70,10 @@ type ScheduledTask struct {
LastTaskResult TaskResult
}

type ScheduledTasks []ScheduledTask
type scheduledTaskResults struct {
scheduledTasks []scheduledTask
err error
}

func New(config *Config) *Collector {
if config == nil {
Expand Down Expand Up @@ -133,10 +140,23 @@ func (c *Collector) GetPerfCounter(_ *slog.Logger) ([]string, error) {
}

func (c *Collector) Close(_ *slog.Logger) error {
close(c.scheduledTasksCh)

c.scheduledTasksCh = nil

return nil
}

func (c *Collector) Build(_ *slog.Logger, _ *wmi.Client) error {
initErrCh := make(chan error)
c.scheduledTasksCh = make(chan *scheduledTaskResults)

go c.initializeScheduleService(initErrCh)

if err := <-initErrCh; err != nil {
return fmt.Errorf("initialize schedule service: %w", err)
}

c.lastResult = prometheus.NewDesc(
prometheus.BuildFQName(types.Namespace, Name, "last_result"),
"The result that was returned the last time the registered task was run",
Expand Down Expand Up @@ -174,12 +194,10 @@ func (c *Collector) Collect(_ *types.ScrapeContext, logger *slog.Logger, ch chan
return nil
}

var TASK_STATES = []string{"disabled", "queued", "ready", "running", "unknown"}

func (c *Collector) collect(ch chan<- prometheus.Metric) error {
scheduledTasks, err := getScheduledTasks()
scheduledTasks, err := c.getScheduledTasks()
if err != nil {
return err
return fmt.Errorf("get scheduled tasks: %w", err)
}

for _, task := range scheduledTasks {
Expand All @@ -188,7 +206,7 @@ func (c *Collector) collect(ch chan<- prometheus.Metric) error {
continue
}

for _, state := range TASK_STATES {
for _, state := range taskStates {
var stateValue float64

if strings.ToLower(task.State.String()) == state {
Expand Down Expand Up @@ -231,14 +249,15 @@ func (c *Collector) collect(ch chan<- prometheus.Metric) error {
return nil
}

const SCHEDULED_TASK_PROGRAM_ID = "Schedule.Service.1"
func (c *Collector) getScheduledTasks() ([]scheduledTask, error) {
c.scheduledTasksCh <- nil

// S_FALSE is returned by CoInitialize if it was already called on this thread.
const S_FALSE = 0x00000001
scheduledTasks := <-c.scheduledTasksCh

func getScheduledTasks() (ScheduledTasks, error) {
var scheduledTasks ScheduledTasks
return scheduledTasks.scheduledTasks, scheduledTasks.err
}

func (c *Collector) initializeScheduleService(initErrCh chan<- error) {
// The only way to run WMI queries in parallel while being thread-safe is to
// ensure the CoInitialize[Ex]() call is bound to its current OS thread.
// Otherwise, attempting to initialize and run parallel queries across
Expand All @@ -248,72 +267,72 @@ func getScheduledTasks() (ScheduledTasks, error) {

if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
var oleCode *ole.OleError
if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != S_FALSE {
return nil, err
if errors.As(err, &oleCode) && oleCode.Code() != ole.S_OK && oleCode.Code() != wmi.S_FALSE {
initErrCh <- err

return
}
}

defer ole.CoUninitialize()

schedClassID, err := ole.ClassIDFrom(SCHEDULED_TASK_PROGRAM_ID)
scheduleClassID, err := ole.ClassIDFrom("Schedule.Service.1")
if err != nil {
return scheduledTasks, err
initErrCh <- err

return
}

taskSchedulerObj, err := ole.CreateInstance(schedClassID, nil)
taskSchedulerObj, err := ole.CreateInstance(scheduleClassID, nil)
if err != nil || taskSchedulerObj == nil {
return scheduledTasks, err
initErrCh <- err

return
}
defer taskSchedulerObj.Release()

taskServiceObj := taskSchedulerObj.MustQueryInterface(ole.IID_IDispatch)

_, err = oleutil.CallMethod(taskServiceObj, "Connect")
if err != nil {
return scheduledTasks, err
}

defer taskServiceObj.Release()

res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`)
taskService, err := oleutil.CallMethod(taskServiceObj, "Connect")
if err != nil {
return scheduledTasks, err
initErrCh <- err

return
}

rootFolderObj := res.ToIDispatch()
defer rootFolderObj.Release()
defer func(taskService *ole.VARIANT) {
_ = taskService.Clear()
}(taskService)

err = fetchTasksRecursively(rootFolderObj, &scheduledTasks)
close(initErrCh)

return scheduledTasks, err
}
scheduledTasks := make([]scheduledTask, 0, 100)

func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error {
res, err := oleutil.CallMethod(folder, "GetTasks", 1)
if err != nil {
return err
}
for range c.scheduledTasksCh {
func() {
// Clear the slice to avoid memory leaks
clear(scheduledTasks)
scheduledTasks = scheduledTasks[:0]

tasks := res.ToIDispatch()
defer tasks.Release()

err = oleutil.ForEach(tasks, func(v *ole.VARIANT) error {
task := v.ToIDispatch()
defer task.Release()
res, err := oleutil.CallMethod(taskServiceObj, "GetFolder", `\`)
if err != nil {
c.scheduledTasksCh <- &scheduledTaskResults{err: err}

parsedTask, err := parseTask(task)
if err != nil {
return err
}
return
}

*scheduledTasks = append(*scheduledTasks, parsedTask)
rootFolderObj := res.ToIDispatch()
defer rootFolderObj.Release()

return nil
})
err = fetchTasksRecursively(rootFolderObj, &scheduledTasks)

return err
c.scheduledTasksCh <- &scheduledTaskResults{scheduledTasks: scheduledTasks, err: err}
}()
}
}

func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *ScheduledTasks) error {
func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *[]scheduledTask) error {
if err := fetchTasksInFolder(folder, scheduledTasks); err != nil {
return err
}
Expand All @@ -336,8 +355,34 @@ func fetchTasksRecursively(folder *ole.IDispatch, scheduledTasks *ScheduledTasks
return err
}

func parseTask(task *ole.IDispatch) (ScheduledTask, error) {
var scheduledTask ScheduledTask
func fetchTasksInFolder(folder *ole.IDispatch, scheduledTasks *[]scheduledTask) error {
res, err := oleutil.CallMethod(folder, "GetTasks", 1)
if err != nil {
return err
}

tasks := res.ToIDispatch()
defer tasks.Release()

err = oleutil.ForEach(tasks, func(v *ole.VARIANT) error {
task := v.ToIDispatch()
defer task.Release()

parsedTask, err := parseTask(task)
if err != nil {
return err
}

*scheduledTasks = append(*scheduledTasks, parsedTask)

return nil
})

return err
}

func parseTask(task *ole.IDispatch) (scheduledTask, error) {
var scheduledTask scheduledTask

taskNameVar, err := oleutil.GetProperty(task, "Name")
if err != nil {
Expand Down