diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go
index c9685e41b..dd56cf1b0 100644
--- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go
+++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go
@@ -227,7 +227,7 @@ func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.Res
 	return restartPolicyMap[flyteRestartPolicy]
 }
 
-func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources) (*v1.PodSpec, error) {
+func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) (*v1.PodSpec, error) {
 	for idx, c := range podSpec.Containers {
 		if c.Name == containerName {
 			if image != "" {
@@ -240,6 +240,9 @@ func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image stri
 				}
 				podSpec.Containers[idx].Resources = *resources
 			}
+			if args != nil && len(args) != 0 {
+				podSpec.Containers[idx].Args = args
+			}
 		}
 	}
 	return podSpec, nil
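A minimal usage sketch of the widened OverrideContainerSpec signature (illustrative only, not part of this change; the pod spec, the "mpi" container name, and the args values below are assumptions):

package main

import (
	"fmt"

	"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
	v1 "k8s.io/api/core/v1"
)

func main() {
	// "mpi" is assumed to match the primary container name that
	// OverridePrimaryContainerName sets (kubeflowv1.MPIJobDefaultContainerName).
	podSpec := &v1.PodSpec{
		Containers: []v1.Container{{Name: "mpi", Image: "base:latest"}},
	}
	// An empty image and nil resources leave those fields untouched; a nil or
	// empty args slice (as the pytorch and tensorflow call sites pass) is a no-op.
	updated, err := common.OverrideContainerSpec(podSpec, "mpi", "", nil, []string{"python", "launcher.py"})
	if err != nil {
		panic(err)
	}
	fmt.Println(updated.Containers[0].Args) // [python launcher.py]
}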
diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
index 7837f5f4a..241297540 100644
--- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
+++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
@@ -7,6 +7,8 @@ import (
 	"time"
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
+	kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"
+
 	flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery"
 	pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
@@ -48,7 +50,6 @@ func (mpiOperatorResourceHandler) BuildIdentityResource(ctx context.Context, tas
 // Defines a func to create the full resource object that will be posted to k8s.
 func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
 	taskTemplate, err := taskCtx.TaskReader().Read(ctx)
-	taskTemplateConfig := taskTemplate.GetConfig()
 
 	if err != nil {
 		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error())
@@ -56,69 +57,127 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
 		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification")
 	}
 
-	mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{}
-	err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs)
-	if err != nil {
-		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
-	}
-
-	workers := mpiTaskExtraArgs.GetNumWorkers()
-	launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas()
-	slots := mpiTaskExtraArgs.GetSlots()
-
 	podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
 	if err != nil {
 		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
 	}
 	common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.MPIJobDefaultContainerName)
 
-	// workersPodSpec is deepCopy of podSpec submitted by flyte
-	workersPodSpec := podSpec.DeepCopy()
+	var launcherReplica = common.ReplicaEntry{
+		ReplicaNum:    int32(1),
+		PodSpec:       podSpec.DeepCopy(),
+		RestartPolicy: commonOp.RestartPolicyNever,
+	}
+	var workerReplica = common.ReplicaEntry{
+		ReplicaNum:    int32(0),
+		PodSpec:       podSpec.DeepCopy(),
+		RestartPolicy: commonOp.RestartPolicyNever,
+	}
+	slots := int32(1)
+	runPolicy := commonOp.RunPolicy{}
+
+	if taskTemplate.TaskTypeVersion == 0 {
+		mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{}
+		err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs)
+		if err != nil {
+			return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
+		}
 
-	// If users don't specify "worker_spec_command" in the task config, the command/args are empty.
-	// However, in some cases, the workers need command/args.
-	// For example, in horovod tasks, each worker runs a command launching ssh daemon.
+		workerReplica.ReplicaNum = mpiTaskExtraArgs.GetNumWorkers()
+		launcherReplica.ReplicaNum = mpiTaskExtraArgs.GetNumLauncherReplicas()
+		slots = mpiTaskExtraArgs.GetSlots()
 
-	workerSpecCommand := []string{}
-	if val, ok := taskTemplateConfig[workerSpecCommandKey]; ok {
-		workerSpecCommand = strings.Split(val, " ")
-	}
+		// V1 requires passing worker command as template config parameter
+		taskTemplateConfig := taskTemplate.GetConfig()
+		workerSpecCommand := []string{}
+		if val, ok := taskTemplateConfig[workerSpecCommandKey]; ok {
+			workerSpecCommand = strings.Split(val, " ")
+		}
+
+		for k := range workerReplica.PodSpec.Containers {
+			if workerReplica.PodSpec.Containers[k].Name == kubeflowv1.MPIJobDefaultContainerName {
+				workerReplica.PodSpec.Containers[k].Args = workerSpecCommand
+				workerReplica.PodSpec.Containers[k].Command = []string{}
+			}
+		}
+
+	} else if taskTemplate.TaskTypeVersion == 1 {
+		kfMPITaskExtraArgs := kfplugins.DistributedMPITrainingTask{}
+
+		err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfMPITaskExtraArgs)
+		if err != nil {
+			return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error())
+		}
+
+		launcherReplicaSpec := kfMPITaskExtraArgs.GetLauncherReplicas()
+		if launcherReplicaSpec != nil {
+			// flyte commands will be passed as args to the container
+			common.OverrideContainerSpec(
+				launcherReplica.PodSpec,
+				kubeflowv1.MPIJobDefaultContainerName,
+				launcherReplicaSpec.GetImage(),
+				launcherReplicaSpec.GetResources(),
+				launcherReplicaSpec.GetCommand(),
+			)
+			launcherReplica.RestartPolicy =
+				commonOp.RestartPolicy(
+					common.ParseRestartPolicy(launcherReplicaSpec.GetRestartPolicy()),
+				)
+		}
+
+		workerReplicaSpec := kfMPITaskExtraArgs.GetWorkerReplicas()
+		if workerReplicaSpec != nil {
+			common.OverrideContainerSpec(
+				workerReplica.PodSpec,
+				kubeflowv1.MPIJobDefaultContainerName,
+				workerReplicaSpec.GetImage(),
+				workerReplicaSpec.GetResources(),
+				workerReplicaSpec.GetCommand(),
+			)
+			workerReplica.RestartPolicy =
+				commonOp.RestartPolicy(
+					common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()),
+				)
+			workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas()
+		}
+
+		if kfMPITaskExtraArgs.GetRunPolicy() != nil {
+			runPolicy = common.ParseRunPolicy(*kfMPITaskExtraArgs.GetRunPolicy())
+		}
 
-	for k := range workersPodSpec.Containers {
-		workersPodSpec.Containers[k].Args = workerSpecCommand
-		workersPodSpec.Containers[k].Command = []string{}
+	} else {
+		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification,
+			"Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion)
 	}
 
-	if workers == 0 {
+	if workerReplica.ReplicaNum == 0 {
 		return nil, fmt.Errorf("number of worker should be more then 0")
 	}
-	if launcherReplicas == 0 {
+	if launcherReplica.ReplicaNum == 0 {
 		return nil, fmt.Errorf("number of launch worker should be more then 0")
 	}
 
 	jobSpec := kubeflowv1.MPIJobSpec{
-		SlotsPerWorker:  &slots,
-		MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{},
-	}
-
-	for _, t := range []struct {
-		podSpec     v1.PodSpec
-		replicaNum  *int32
-		replicaType commonOp.ReplicaType
-	}{
-		{*podSpec, &launcherReplicas, kubeflowv1.MPIJobReplicaTypeLauncher},
-		{*workersPodSpec, &workers, kubeflowv1.MPIJobReplicaTypeWorker},
-	} {
-		if *t.replicaNum > 0 {
-			jobSpec.MPIReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{
-				Replicas: t.replicaNum,
+		SlotsPerWorker:  &slots,
+		RunPolicy:       runPolicy,
+		MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
+			kubeflowv1.MPIJobReplicaTypeLauncher: {
+				Replicas: &launcherReplica.ReplicaNum,
 				Template: v1.PodTemplateSpec{
 					ObjectMeta: *objectMeta,
-					Spec:       t.podSpec,
+					Spec:       *launcherReplica.PodSpec,
 				},
-				RestartPolicy: commonOp.RestartPolicyNever,
-			}
-		}
+				RestartPolicy: launcherReplica.RestartPolicy,
+			},
+			kubeflowv1.MPIJobReplicaTypeWorker: {
+				Replicas: &workerReplica.ReplicaNum,
+				Template: v1.PodTemplateSpec{
+					ObjectMeta: *objectMeta,
+					Spec:       *workerReplica.PodSpec,
+				},
+				RestartPolicy: workerReplica.RestartPolicy,
+			},
+		},
 	}
 
 	job := &kubeflowv1.MPIJob{
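The mpi.go changes above lean on common.ReplicaEntry, whose definition is not part of this diff. Inferred from the call sites, it presumably looks roughly like the sketch below; the field types and import paths are assumptions, not the common package's actual source:

// Presumed shape of common.ReplicaEntry, inferred from the mpi.go usage above.
// The real definition lives in go/tasks/plugins/k8s/kfoperators/common and may differ.
package common

import (
	commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
	v1 "k8s.io/api/core/v1"
)

// ReplicaEntry groups the per-replica-type settings BuildResource assembles
// before mapping them onto MPIJobSpec.MPIReplicaSpecs.
type ReplicaEntry struct {
	ReplicaNum    int32                  // replica count for this replica type
	PodSpec       *v1.PodSpec            // deep copy of the task pod spec, mutated per replica type
	RestartPolicy commonOp.RestartPolicy // defaults to RestartPolicyNever unless overridden
}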
diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
index 29fefd9ca..8af67fa36 100644
--- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
+++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
@@ -8,6 +8,7 @@ import (
 
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
 	"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
+	kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"
 
 	"github.com/flyteorg/flyteplugins/go/tasks/logs"
 	pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
@@ -68,9 +69,24 @@ func dummyMPICustomObj(workers int32, launcher int32, slots int32) *plugins.Dist
 	}
 }
 
-func dummyMPITaskTemplate(id string, mpiCustomObj *plugins.DistributedMPITrainingTask) *core.TaskTemplate {
+func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate {
+
+	var mpiObjJSON string
+	var err error
+
+	for _, arg := range args {
+		switch t := arg.(type) {
+		case *kfplugins.DistributedMPITrainingTask:
+			var mpiCustomObj = t
+			mpiObjJSON, err = utils.MarshalToString(mpiCustomObj)
+		case *plugins.DistributedMPITrainingTask:
+			var mpiCustomObj = t
+			mpiObjJSON, err = utils.MarshalToString(mpiCustomObj)
+		default:
+			err = fmt.Errorf("unknown input type %T", t)
+		}
+	}
 
-	mpiObjJSON, err := utils.MarshalToString(mpiCustomObj)
 	if err != nil {
 		panic(err)
 	}
@@ -427,3 +443,119 @@ func TestReplicaCounts(t *testing.T) {
 		})
 	}
 }
+
+func TestBuildResourceMPIV1(t *testing.T) {
+	launcherCommand := []string{"python", "launcher.py"}
+	workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"}
+	taskConfig := &kfplugins.DistributedMPITrainingTask{
+		LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
+			Image: testImage,
+			Resources: &core.Resources{
+				Requests: []*core.Resources_ResourceEntry{
+					{Name: core.Resources_CPU, Value: "250m"},
+				},
+				Limits: []*core.Resources_ResourceEntry{
+					{Name: core.Resources_CPU, Value: "500m"},
+				},
+			},
+			Command: launcherCommand,
+		},
+		WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
+			Replicas: 100,
+			Resources: &core.Resources{
+				Requests: []*core.Resources_ResourceEntry{
+					{Name: core.Resources_CPU, Value: "1024m"},
+				},
+				Limits: []*core.Resources_ResourceEntry{
+					{Name: core.Resources_CPU, Value: "2048m"},
+				},
+			},
+			Command: workerCommand,
+		},
+		Slots: int32(1),
+	}
+
+	launcherResourceRequirements := &corev1.ResourceRequirements{
+		Requests: corev1.ResourceList{
+			corev1.ResourceCPU: resource.MustParse("250m"),
+		},
+		Limits: corev1.ResourceList{
+			corev1.ResourceCPU: resource.MustParse("500m"),
+		},
+	}
+
+	workerResourceRequirements := &corev1.ResourceRequirements{
+		Requests: corev1.ResourceList{
+			corev1.ResourceCPU: resource.MustParse("1024m"),
+		},
+		Limits: corev1.ResourceList{
+			corev1.ResourceCPU: resource.MustParse("2048m"),
+		},
+	}
+
+	mpiResourceHandler := mpiOperatorResourceHandler{}
+
+	taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
+	taskTemplate.TaskTypeVersion = 1
+
+	resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate))
+	assert.NoError(t, err)
+	assert.NotNil(t, resource)
+
+	mpiJob, ok := resource.(*kubeflowv1.MPIJob)
+	assert.True(t, ok)
+	assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas)
+	assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas)
+	assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker)
+	assert.Equal(t, *launcherResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Resources)
+	assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources)
+	assert.Equal(t, launcherCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args)
+	assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args)
+}
+
+func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) {
+	workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"}
+
+	taskConfig := &kfplugins.DistributedMPITrainingTask{
+		WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{
+			Replicas: 100,
+			Resources: &core.Resources{
+				Requests: []*core.Resources_ResourceEntry{
+					{Name: core.Resources_CPU, Value: "1024m"},
+				},
+				Limits: []*core.Resources_ResourceEntry{
+					{Name: core.Resources_CPU, Value: "2048m"},
+				},
+			},
+			Command: workerCommand,
+		},
+		Slots: int32(1),
+	}
+
+	workerResourceRequirements := &corev1.ResourceRequirements{
+		Requests: corev1.ResourceList{
+			corev1.ResourceCPU: resource.MustParse("1024m"),
+		},
+		Limits: corev1.ResourceList{
+			corev1.ResourceCPU: resource.MustParse("2048m"),
+		},
+	}
+
+	mpiResourceHandler := mpiOperatorResourceHandler{}
+
+	taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig)
+	taskTemplate.TaskTypeVersion = 1
+
+	resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate))
+	assert.NoError(t, err)
+	assert.NotNil(t, resource)
+
+	mpiJob, ok := resource.(*kubeflowv1.MPIJob)
+	assert.True(t, ok)
+	assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas)
+	assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas)
+	assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker)
+	assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources)
+	assert.Equal(t, testArgs, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args)
+	assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args)
+}
diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
index 27d4bb0a6..1d46b9ef3 100644
--- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
+++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
@@ -106,6 +106,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
 			kubeflowv1.PytorchJobDefaultContainerName,
 			masterReplicaSpec.GetImage(),
 			masterReplicaSpec.GetResources(),
+			nil,
 		)
 		masterReplica.RestartPolicy =
 			commonOp.RestartPolicy(
@@ -121,6 +122,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
 			kubeflowv1.PytorchJobDefaultContainerName,
 			workerReplicaSpec.GetImage(),
 			workerReplicaSpec.GetResources(),
+			nil,
 		)
 		workerReplica.RestartPolicy =
 			commonOp.RestartPolicy(
diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
index 55900d46e..980062201 100644
--- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
+++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go
@@ -109,6 +109,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
 			kubeflowv1.TFJobDefaultContainerName,
 			chiefReplicaSpec.GetImage(),
 			chiefReplicaSpec.GetResources(),
+			nil,
 		)
 		replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].RestartPolicy =
 			commonOp.RestartPolicy(
@@ -124,6 +125,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
 			kubeflowv1.TFJobDefaultContainerName,
 			workerReplicaSpec.GetImage(),
 			workerReplicaSpec.GetResources(),
+			nil,
 		)
 		replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy =
 			commonOp.RestartPolicy(
@@ -139,6 +141,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task
 			kubeflowv1.TFJobDefaultContainerName,
 			psReplicaSpec.GetImage(),
 			psReplicaSpec.GetResources(),
+			nil,
 		)
 		replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].RestartPolicy =
 			commonOp.RestartPolicy(