Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix scheduler initialisation startup race #4132

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions developer/env/docker/executor.env
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
ARMADA_EXECUTORAPICONNECTION_ARMADAURL="scheduler:50052"
ARMADA_EXECUTORAPICONNECTION_FORCENOTLS=true
ARMADA_APPLICATION_JOBLEASEREQUESTTIMEOUT=5s
1 change: 1 addition & 0 deletions internal/executor/application.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ func setupExecutorApiComponents(
clusterUtilisationService,
config.Kubernetes.PodDefaults,
config.Application.MaxLeasedJobs,
config.Application.JobLeaseRequestTimeout,
)
clusterAllocationService := service.NewClusterAllocationService(
clusterContext,
Expand Down
5 changes: 4 additions & 1 deletion internal/executor/service/job_requester.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type JobRequester struct {
podDefaults *configuration.PodDefaults
jobRunStateStore job.RunStateStore
maxLeasedJobs int
maxRequestDuration time.Duration
}

func NewJobRequester(
Expand All @@ -35,6 +36,7 @@ func NewJobRequester(
utilisationService utilisation.UtilisationService,
podDefaults *configuration.PodDefaults,
maxLeasedJobs int,
maxRequestDuration time.Duration,
) *JobRequester {
return &JobRequester{
leaseRequester: leaseRequester,
Expand All @@ -44,6 +46,7 @@ func NewJobRequester(
clusterId: clusterId,
podDefaults: podDefaults,
maxLeasedJobs: maxLeasedJobs,
maxRequestDuration: maxRequestDuration,
}
}

Expand All @@ -53,7 +56,7 @@ func (r *JobRequester) RequestJobsRuns() {
log.Errorf("Failed to create lease request because %s", err)
return
}
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 30*time.Second)
ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), r.maxRequestDuration)
defer cancel()
leaseResponse, err := r.leaseRequester.LeaseJobRuns(ctx, leaseRequest)
if err != nil {
Expand Down
16 changes: 14 additions & 2 deletions internal/executor/service/job_requester_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package service
import (
"fmt"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
Expand All @@ -23,7 +24,10 @@ import (
"github.com/armadaproject/armada/pkg/executorapi"
)

// Default arguments handed to NewJobRequester by the tests in this file.
const (
	// defaultMaxLeasedJobs is supplied as the maxLeasedJobs argument.
	defaultMaxLeasedJobs int = 5
	// defaultMaxRequestDuration is supplied as the maxRequestDuration argument.
	defaultMaxRequestDuration = 30 * time.Second
)

func TestRequestJobsRuns_HandlesLeaseRequestError(t *testing.T) {
jobRequester, eventReporter, leaseRequester, stateStore, _ := setupJobRequesterTest([]*job.RunState{})
Expand Down Expand Up @@ -257,7 +261,15 @@ func setupJobRequesterTest(initialJobRuns []*job.RunState) (*JobRequester, *mock
utilisationService.ClusterAvailableCapacityReport = &utilisation.ClusterAvailableCapacityReport{
AvailableCapacity: &armadaresource.ComputeResources{},
}
jobRequester := NewJobRequester(clusterId, eventReporter, leaseRequester, stateStore, utilisationService, podDefaults, defaultMaxLeasedJobs)
jobRequester := NewJobRequester(
clusterId,
eventReporter,
leaseRequester,
stateStore,
utilisationService,
podDefaults,
defaultMaxLeasedJobs,
defaultMaxRequestDuration)
return jobRequester, eventReporter, leaseRequester, stateStore, utilisationService
}

Expand Down
9 changes: 9 additions & 0 deletions internal/scheduler/queue/queue_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ func NewQueueCache(apiClient api.SubmitClient, updateFrequency time.Duration) *A
}
}

// Initialise performs a single synchronous queue fetch via fetchQueues.
// If the fetch fails, the error is logged at error level on ctx and then
// returned to the caller unchanged; on success it returns nil.
func (c *ApiQueueCache) Initialise(ctx *armadacontext.Context) error {
	if fetchErr := c.fetchQueues(ctx); fetchErr != nil {
		ctx.Errorf("Error initialising queue cache, failed fetching queues: %v", fetchErr)
		return fetchErr
	}
	return nil
}

func (c *ApiQueueCache) Run(ctx *armadacontext.Context) error {
if err := c.fetchQueues(ctx); err != nil {
ctx.Warnf("Error fetching queues: %v", err)
Expand Down
8 changes: 8 additions & 0 deletions internal/scheduler/schedulerapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ func Run(config schedulerconfig.Configuration) error {
}()
armadaClient := api.NewSubmitClient(conn)
queueCache := queue.NewQueueCache(armadaClient, config.QueueRefreshPeriod)
err = queueCache.Initialise(ctx)
if err != nil {
return errors.WithMessage(err, "error initialising queue cache")
}
services = append(services, func() error { return queueCache.Run(ctx) })

// ////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -241,6 +245,10 @@ func Run(config schedulerconfig.Configuration) error {
floatingResourceTypes,
resourceListFactory,
)
err = submitChecker.Initialise(ctx)
if err != nil {
return errors.WithMessage(err, "error initialising submit checker")
}
services = append(services, func() error {
return submitChecker.Run(ctx)
})
Expand Down
35 changes: 23 additions & 12 deletions internal/scheduler/submitcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,36 +73,45 @@ func NewSubmitChecker(
}
}

// Initialise runs one synchronous executor refresh via updateExecutors.
// A failure is logged at error level on ctx and returned to the caller
// unchanged; on success it returns nil.
func (srv *SubmitChecker) Initialise(ctx *armadacontext.Context) error {
	if updateErr := srv.updateExecutors(ctx); updateErr != nil {
		ctx.Errorf("Error initialising submit checker: %v", updateErr)
		return updateErr
	}
	return nil
}

// Run refreshes the executor state once immediately and then again on every
// tick of schedulingConfig.ExecutorUpdateFrequency until ctx is done.
//
// A failed refresh is logged with its stack trace and does not stop the loop;
// the next tick simply tries again. Run returns nil once ctx is done.
func (srv *SubmitChecker) Run(ctx *armadacontext.Context) error {
	ctx.Infof("Will refresh executor state every %s", srv.schedulingConfig.ExecutorUpdateFrequency)
	if err := srv.updateExecutors(ctx); err != nil {
		logging.WithStacktrace(ctx, err).Error("Error fetching executors")
	}

	ticker := time.NewTicker(srv.schedulingConfig.ExecutorUpdateFrequency)
	// Stop the ticker when the loop exits so its resources are released.
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-ticker.C:
			// Refresh exactly once per tick and surface any failure in the log.
			if err := srv.updateExecutors(ctx); err != nil {
				logging.WithStacktrace(ctx, err).Error("Error fetching executors")
			}
		}
	}
}

func (srv *SubmitChecker) updateExecutors(ctx *armadacontext.Context) {
func (srv *SubmitChecker) updateExecutors(ctx *armadacontext.Context) error {
queues, err := srv.queueCache.GetAll(ctx)
if err != nil {
logging.
WithStacktrace(ctx, err).
Error("Error fetching queues")
return
return fmt.Errorf("failed fetching queues from queue cache - %s", err)
}

executors, err := srv.executorRepository.GetExecutors(ctx)
if err != nil {
logging.
WithStacktrace(ctx, err).
Error("Error fetching executors")
return
return fmt.Errorf("failed fetching executors from db - %s", err)
}

ctx.Infof("Retrieved %d executors", len(executors))
jobSchedulingResultsCache, err := lru.New(10000)
if err != nil {
Expand Down Expand Up @@ -167,6 +176,8 @@ func (srv *SubmitChecker) updateExecutors(ctx *armadacontext.Context) {
constraintsByPool: constraintsByPool,
jobSchedulingResultsCache: jobSchedulingResultsCache,
})

return nil
}

func (srv *SubmitChecker) Check(ctx *armadacontext.Context, jobs []*jobdb.Job) (map[string]schedulingResult, error) {
Expand Down
61 changes: 60 additions & 1 deletion internal/scheduler/submitcheck_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scheduler

import (
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -298,7 +299,8 @@ func TestSubmitChecker_CheckJobDbJobs(t *testing.T) {
floatingResources,
testfixtures.TestResourceListFactory)
submitCheck.clock = fakeClock
submitCheck.updateExecutors(ctx)
err := submitCheck.Initialise(ctx)
assert.NoError(t, err)
results, err := submitCheck.Check(ctx, tc.jobs)
require.NoError(t, err)
require.Equal(t, len(tc.expectedResult), len(results))
Expand All @@ -316,6 +318,63 @@ func TestSubmitChecker_CheckJobDbJobs(t *testing.T) {
}
}

// TestSubmitChecker_Initialise verifies that SubmitChecker.Initialise succeeds
// when both the queue cache and the executor repository return data, and that
// it reports an error when either of those dependencies fails.
func TestSubmitChecker_Initialise(t *testing.T) {
	tests := map[string]struct {
		// queueCacheErr, when non-nil, is returned by the mocked queue cache's GetAll.
		queueCacheErr error
		// executorRepoErr, when non-nil, is returned by the mocked repository's GetExecutors.
		executorRepoErr error
		// expectError states whether Initialise should return an error.
		expectError bool
	}{
		"Successful initialisation": {
			expectError: false,
		},
		"error on queue cache error": {
			expectError:   true,
			queueCacheErr: fmt.Errorf("failed to get queues"),
		},
		"error on executor repo err": {
			expectError: true,
			// Must populate executorRepoErr (not queueCacheErr) so that the
			// executor repository failure path is the one being exercised.
			executorRepoErr: fmt.Errorf("failed to get executors"),
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			ctx, cancel := armadacontext.WithTimeout(armadacontext.Background(), 5*time.Second)
			defer cancel()

			queue := &api.Queue{Name: "queue"}
			executors := []*schedulerobjects.Executor{Executor(SmallNode("cpu"))}

			ctrl := gomock.NewController(t)
			mockExecutorRepo := schedulermocks.NewMockExecutorRepository(ctrl)
			if tc.executorRepoErr != nil {
				mockExecutorRepo.EXPECT().GetExecutors(ctx).Return(nil, tc.executorRepoErr).AnyTimes()
			} else {
				mockExecutorRepo.EXPECT().GetExecutors(ctx).Return(executors, nil).AnyTimes()
			}

			mockQueueCache := schedulermocks.NewMockQueueCache(ctrl)
			if tc.queueCacheErr != nil {
				mockQueueCache.EXPECT().GetAll(ctx).Return(nil, tc.queueCacheErr).AnyTimes()
			} else {
				mockQueueCache.EXPECT().GetAll(ctx).Return([]*api.Queue{queue}, nil).AnyTimes()
			}

			submitCheck := NewSubmitChecker(testfixtures.TestSchedulingConfig(),
				mockExecutorRepo,
				mockQueueCache,
				testfixtures.TestFloatingResources,
				testfixtures.TestResourceListFactory)

			err := submitCheck.Initialise(ctx)
			if tc.expectError {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func Executor(nodes ...*schedulerobjects.Node) *schedulerobjects.Executor {
executorId := uuid.NewString()
for _, node := range nodes {
Expand Down
5 changes: 5 additions & 0 deletions magefiles/ci.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ func TestSuite() error {

// Checks if Armada is ready to accept jobs.
func CheckForArmadaRunning() error {
// 30s gives time for:
// Scheduler + executor to start up
// Executor to report its state
// Scheduler to update its executor states
// TODO replace with an API call to work out when executors are loaded into scheduler (armadactl get executors?)
time.Sleep(30 * time.Second)
mg.Deps(createQueue)

Expand Down
Loading