diff --git a/cmd/start.go b/cmd/start.go index d8f5e6d..0199c18 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -97,51 +97,56 @@ func start(ctx context.Context) { select { case <-ticker.C: if currentTask != nil { - log.V(1).Info("currently performing collection; continuing...") + log.V(1).Info("collection in progress...") + if err := checkin(ctx, *bheInstance, bheClient); err != nil { + log.Error(err, "bloodhound enterprise service checkin failed") + } } else { - log.V(2).Info("checking for available collection tasks") - if availableTasks, err := getAvailableTasks(ctx, *bheInstance, bheClient); err != nil { - log.Error(err, "unable to fetch available tasks for azurehound") - } else { - - // Get only the tasks that have reached their execution time - executableTasks := []models.ClientTask{} - now := time.Now() - for _, task := range availableTasks { - if task.ExectionTime.Before(now) || task.ExectionTime.Equal(now) { - executableTasks = append(executableTasks, task) - } - } - - // Sort tasks in ascending order by execution time - sort.Slice(executableTasks, func(i, j int) bool { - return executableTasks[i].ExectionTime.Before(executableTasks[j].ExectionTime) - }) - - if len(executableTasks) == 0 { - log.V(2).Info("there are no tasks for azurehound to complete at this time") + go func() { + log.V(2).Info("checking for available collection tasks") + if availableTasks, err := getAvailableTasks(ctx, *bheInstance, bheClient); err != nil { + log.Error(err, "unable to fetch available tasks for azurehound") } else { - // Notify BHE instance of task start - currentTask = &executableTasks[0] - startTask(ctx, *bheInstance, bheClient, currentTask.Id) - start := time.Now() + // Get only the tasks that have reached their execution time + executableTasks := []models.ClientTask{} + now := time.Now() + for _, task := range availableTasks { + if task.ExectionTime.Before(now) || task.ExectionTime.Equal(now) { + executableTasks = append(executableTasks, task) + } + } + + // Sort tasks in ascending order by execution time + sort.Slice(executableTasks, func(i, j int) bool { + return executableTasks[i].ExectionTime.Before(executableTasks[j].ExectionTime) + }) - // Batch data out for ingestion - stream := listAll(ctx, azClient) - batches := pipeline.Batch(ctx.Done(), stream, 999, 10*time.Second) - if err := ingest(ctx, *bheInstance, bheClient, batches); err != nil { - log.Error(err, "ingestion failed; collection will be re-attempted") + if len(executableTasks) == 0 { + log.V(2).Info("there are no tasks for azurehound to complete at this time") } else { - // Notify BHE instance of task end - duration := time.Since(start) - endTask(ctx, *bheInstance, bheClient) - log.Info("finished collection task", "id", currentTask.Id, "duration", duration.String()) - currentTask = nil + // Notify BHE instance of task start + currentTask = &executableTasks[0] + startTask(ctx, *bheInstance, bheClient, currentTask.Id) + start := time.Now() + + // Batch data out for ingestion + stream := listAll(ctx, azClient) + batches := pipeline.Batch(ctx.Done(), stream, 999, 10*time.Second) + if err := ingest(ctx, *bheInstance, bheClient, batches); err != nil { + log.Error(err, "ingestion failed; collection will be re-attempted") + } else { + // Notify BHE instance of task end + duration := time.Since(start) + endTask(ctx, *bheInstance, bheClient) + log.Info("finished collection task", "id", currentTask.Id, "duration", duration.String()) + + currentTask = nil + } } } - } + }() } case <-ctx.Done(): return @@ -193,6 +198,20 @@ func getAvailableTasks(ctx context.Context, bheUrl url.URL, bheClient *http.Clie } } +func checkin(ctx context.Context, bheUrl url.URL, bheClient *http.Client) error { + endpoint := bheUrl.ResolveReference(&url.URL{Path: "/api/v2/jobs/current"}) + + if req, err := rest.NewRequest(ctx, "GET", endpoint, nil, nil, nil); err != nil { + return err + } else if res, err := bheClient.Do(req); err != nil { + return err + } else if !contains([]int{http.StatusOK, http.StatusNotFound}, res.StatusCode) { + return fmt.Errorf("unexpected response code %s", res.Status) + } else { + return nil + } +} + func startTask(ctx context.Context, bheUrl url.URL, bheClient *http.Client, taskId int) error { log.Info("beginning collection task", "id", taskId) var ( diff --git a/cmd/utils.go b/cmd/utils.go index c1bcd92..cfe0b55 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -321,7 +321,7 @@ func (s signingTransport) RoundTrip(req *http.Request) (*http.Response, error) { return s.base.RoundTrip(clone) } -func contains(collection []string, value string) bool { +func contains[T comparable](collection []T, value T) bool { for _, item := range collection { if item == value { return true