diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 338f6cd56..94978c9eb 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -2,7 +2,6 @@ package pytorch import ( "context" - "fmt" "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" @@ -69,9 +68,6 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName) workers := pytorchTaskExtraArgs.GetWorkers() - if workers == 0 { - return nil, fmt.Errorf("number of worker should be more then 0") - } var jobSpec kubeflowv1.PyTorchJobSpec @@ -115,23 +111,27 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx }, RestartPolicy: commonOp.RestartPolicyNever, }, - kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workers, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, - }, }, } + + if workers > 0 { + jobSpec.PyTorchReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker] = &commonOp.ReplicaSpec{ + Replicas: &workers, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + } + } } job := &kubeflowv1.PyTorchJob{ TypeMeta: metav1.TypeMeta{ Kind: kubeflowv1.PytorchJobKind, APIVersion: kubeflowv1.SchemeGroupVersion.String(), }, - Spec: jobSpec, + Spec: jobSpec, + ObjectMeta: *objectMeta, } return job, nil diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 74bd3fe92..23771b772 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -425,7 +425,7 @@ func TestReplicaCounts(t *testing.T) { contains []commonOp.ReplicaType notContains []commonOp.ReplicaType }{ - {"NoWorkers", 0, true, nil, nil}, + {"NoWorkers", 0, false, []commonOp.ReplicaType{kubeflowv1.PyTorchJobReplicaTypeMaster}, nil}, {"Works", 1, false, []commonOp.ReplicaType{kubeflowv1.PyTorchJobReplicaTypeMaster, kubeflowv1.PyTorchJobReplicaTypeWorker}, []commonOp.ReplicaType{}}, } { t.Run(test.name, func(t *testing.T) {