Skip to content

Commit

Permalink
fix: perform bhe client checkin in seperate goroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
ddlees committed Jan 26, 2023
1 parent 7116173 commit 922561a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 38 deletions.
93 changes: 56 additions & 37 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,51 +97,56 @@ func start(ctx context.Context) {
select {
case <-ticker.C:
if currentTask != nil {
log.V(1).Info("currently performing collection; continuing...")
log.V(1).Info("collection in progress...")
if err := checkin(ctx, *bheInstance, bheClient); err != nil {
log.Error(err, "bloodhound enterprise service checkin failed")
}
} else {
log.V(2).Info("checking for available collection tasks")
if availableTasks, err := getAvailableTasks(ctx, *bheInstance, bheClient); err != nil {
log.Error(err, "unable to fetch available tasks for azurehound")
} else {

// Get only the tasks that have reached their execution time
executableTasks := []models.ClientTask{}
now := time.Now()
for _, task := range availableTasks {
if task.ExectionTime.Before(now) || task.ExectionTime.Equal(now) {
executableTasks = append(executableTasks, task)
}
}

// Sort tasks in ascending order by execution time
sort.Slice(executableTasks, func(i, j int) bool {
return executableTasks[i].ExectionTime.Before(executableTasks[j].ExectionTime)
})

if len(executableTasks) == 0 {
log.V(2).Info("there are no tasks for azurehound to complete at this time")
go func() {
log.V(2).Info("checking for available collection tasks")
if availableTasks, err := getAvailableTasks(ctx, *bheInstance, bheClient); err != nil {
log.Error(err, "unable to fetch available tasks for azurehound")
} else {

// Notify BHE instance of task start
currentTask = &executableTasks[0]
startTask(ctx, *bheInstance, bheClient, currentTask.Id)
start := time.Now()
// Get only the tasks that have reached their execution time
executableTasks := []models.ClientTask{}
now := time.Now()
for _, task := range availableTasks {
if task.ExectionTime.Before(now) || task.ExectionTime.Equal(now) {
executableTasks = append(executableTasks, task)
}
}

// Sort tasks in ascending order by execution time
sort.Slice(executableTasks, func(i, j int) bool {
return executableTasks[i].ExectionTime.Before(executableTasks[j].ExectionTime)
})

// Batch data out for ingestion
stream := listAll(ctx, azClient)
batches := pipeline.Batch(ctx.Done(), stream, 999, 10*time.Second)
if err := ingest(ctx, *bheInstance, bheClient, batches); err != nil {
log.Error(err, "ingestion failed; collection will be re-attempted")
if len(executableTasks) == 0 {
log.V(2).Info("there are no tasks for azurehound to complete at this time")
} else {
// Notify BHE instance of task end
duration := time.Since(start)
endTask(ctx, *bheInstance, bheClient)
log.Info("finished collection task", "id", currentTask.Id, "duration", duration.String())

currentTask = nil
// Notify BHE instance of task start
currentTask = &executableTasks[0]
startTask(ctx, *bheInstance, bheClient, currentTask.Id)
start := time.Now()

// Batch data out for ingestion
stream := listAll(ctx, azClient)
batches := pipeline.Batch(ctx.Done(), stream, 999, 10*time.Second)
if err := ingest(ctx, *bheInstance, bheClient, batches); err != nil {
log.Error(err, "ingestion failed; collection will be re-attempted")
} else {
// Notify BHE instance of task end
duration := time.Since(start)
endTask(ctx, *bheInstance, bheClient)
log.Info("finished collection task", "id", currentTask.Id, "duration", duration.String())

currentTask = nil
}
}
}
}
}()
}
case <-ctx.Done():
return
Expand Down Expand Up @@ -193,6 +198,20 @@ func getAvailableTasks(ctx context.Context, bheUrl url.URL, bheClient *http.Clie
}
}

func checkin(ctx context.Context, bheUrl url.URL, bheClient *http.Client) error {
endpoint := bheUrl.ResolveReference(&url.URL{Path: "/api/v2/jobs/current"})

if req, err := rest.NewRequest(ctx, "GET", endpoint, nil, nil, nil); err != nil {
return err
} else if res, err := bheClient.Do(req); err != nil {
return err
} else if !contains([]int{http.StatusOK, http.StatusNotFound}, res.StatusCode) {
return fmt.Errorf("unexpected response code %s", res.Status)
} else {
return nil
}
}

func startTask(ctx context.Context, bheUrl url.URL, bheClient *http.Client, taskId int) error {
log.Info("beginning collection task", "id", taskId)
var (
Expand Down
2 changes: 1 addition & 1 deletion cmd/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (s signingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return s.base.RoundTrip(clone)
}

func contains(collection []string, value string) bool {
func contains[T comparable](collection []T, value T) bool {
for _, item := range collection {
if item == value {
return true
Expand Down

0 comments on commit 922561a

Please sign in to comment.