Skip to content

Commit

Permalink
Kf operators use GetReplicaFunc (Error Handling) (#4471)
Browse files Browse the repository at this point in the history
* fix bug

Signed-off-by: Future Outlier <[email protected]>

* rename

Signed-off-by: Future Outlier <[email protected]>

* rename

Signed-off-by: Future Outlier <[email protected]>

* add test

Signed-off-by: Future Outlier <[email protected]>

* move GetReplicaCount to common package

Signed-off-by: Future Outlier <[email protected]>

* merge

Signed-off-by: Future Outlier <[email protected]>

* merge

Signed-off-by: Future Outlier <[email protected]>

---------

Signed-off-by: Future Outlier <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2023
1 parent 9df8677 commit ad0597c
Show file tree
Hide file tree
Showing 7 changed files with 54 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,11 @@ func ToReplicaSpecWithOverrides(ctx context.Context, taskCtx pluginsCore.TaskExe

return replicaSpec, nil
}

func GetReplicaCount(specs map[commonOp.ReplicaType]*commonOp.ReplicaSpec, replicaType commonOp.ReplicaType) *int32 {
if spec, ok := specs[replicaType]; ok && spec.Replicas != nil {
return spec.Replicas
}

return new(int32) // return 0 as default value
}
4 changes: 2 additions & 2 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ func (mpiOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
}

numWorkers = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas
numLauncherReplicas = app.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas
numWorkers = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeWorker)
numLauncherReplicas = common.GetReplicaCount(app.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeLauncher)

taskLogs, err := common.GetLogs(pluginContext, common.MPITaskType, app.ObjectMeta, false,
*numWorkers, *numLauncherReplicas, 0, 0)
Expand Down
14 changes: 14 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,3 +776,17 @@ func TestBuildResourceMPIV1ResourceTolerations(t *testing.T) {
assert.NotContains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Tolerations, gpuToleration)
assert.Contains(t, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration)
}

func TestGetReplicaCount(t *testing.T) {
mpiResourceHandler := mpiOperatorResourceHandler{}
tfObj := dummyMPICustomObj(1, 1, 0)
taskTemplate := dummyMPITaskTemplate("the job", tfObj)
resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate, resourceRequirements, nil))
assert.NoError(t, err)
assert.NotNil(t, resource)
MPIJob, ok := resource.(*kubeflowv1.MPIJob)
assert.True(t, ok)

assert.NotNil(t, common.GetReplicaCount(MPIJob.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeWorker))
assert.NotNil(t, common.GetReplicaCount(MPIJob.Spec.MPIReplicaSpecs, kubeflowv1.MPIJobReplicaTypeLauncher))
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,18 @@ func ParseElasticConfig(elasticConfig ElasticConfig) *kubeflowv1.ElasticPolicy {
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (pytorchOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app := resource.(*kubeflowv1.PyTorchJob)
app, ok := resource.(*kubeflowv1.PyTorchJob)
if !ok {
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
}

// Elastic PytorchJobs don't use master replicas
hasMaster := false
if _, ok := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok {
hasMaster = true
}

workersCount := app.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas
workersCount := common.GetReplicaCount(app.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)

taskLogs, err := common.GetLogs(pluginContext, common.PytorchTaskType, app.ObjectMeta, hasMaster, *workersCount, 0, 0, 0)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -959,3 +959,16 @@ func TestParseElasticConfig(t *testing.T) {
assert.Equal(t, int32(4), *elasticPolicy.NProcPerNode)
assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *elasticPolicy.RDZVBackend)
}

func TestGetReplicaCount(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}
tfObj := dummyPytorchCustomObj(1)
taskTemplate := dummyPytorchTaskTemplate("the job", tfObj)
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil))
assert.NoError(t, err)
assert.NotNil(t, resource)
PytorchJob, ok := resource.(*kubeflowv1.PyTorchJob)
assert.True(t, ok)

assert.NotNil(t, common.GetReplicaCount(PytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker))
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,24 +139,19 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
return job, nil
}

func getReplicaCount(specs map[commonOp.ReplicaType]*commonOp.ReplicaSpec, replicaType commonOp.ReplicaType) *int32 {
if spec, ok := specs[replicaType]; ok && spec.Replicas != nil {
return spec.Replicas
}

return new(int32) // return 0 as default value
}

// Analyses the k8s resource and reports the status as TaskPhase. This call is expected to be relatively fast,
// any operations that might take a long time (limits are configured system-wide) should be offloaded to the
// background.
func (tensorflowOperatorResourceHandler) GetTaskPhase(_ context.Context, pluginContext k8s.PluginContext, resource client.Object) (pluginsCore.PhaseInfo, error) {
app := resource.(*kubeflowv1.TFJob)
app, ok := resource.(*kubeflowv1.TFJob)
if !ok {
return pluginsCore.PhaseInfoUndefined, fmt.Errorf("failed to convert resource data type")
}

workersCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker)
psReplicasCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)
chiefCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)
evaluatorReplicasCount := getReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval)
workersCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker)
psReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS)
chiefCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief)
evaluatorReplicasCount := common.GetReplicaCount(app.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval)

taskLogs, err := common.GetLogs(pluginContext, common.TensorflowTaskType, app.ObjectMeta, false,
*workersCount, *psReplicasCount, *chiefCount, *evaluatorReplicasCount)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,10 @@ func TestGetReplicaCount(t *testing.T) {
tensorflowJob, ok := resource.(*kubeflowv1.TFJob)
assert.True(t, ok)

assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief))
assert.NotNil(t, getReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval))
assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeWorker))
assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypePS))
assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeChief))
assert.NotNil(t, common.GetReplicaCount(tensorflowJob.Spec.TFReplicaSpecs, kubeflowv1.TFJobReplicaTypeEval))
}

func TestBuildResourceTensorFlow(t *testing.T) {
Expand Down

0 comments on commit ad0597c

Please sign in to comment.