From 793c16041c3f14d068e40ff2532eaf66cbe15ca7 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Thu, 20 Apr 2023 17:56:48 -0700 Subject: [PATCH 1/9] change pytorch plugin to accept new pytorch task idl Signed-off-by: Yubo Wang --- go.mod | 1 + .../k8s/kfoperators/common/common_operator.go | 58 +++++++ .../k8s/kfoperators/pytorch/pytorch.go | 142 ++++++++++++------ 3 files changed, 157 insertions(+), 44 deletions(-) diff --git a/go.mod b/go.mod index 6a2331949..7616dbb36 100644 --- a/go.mod +++ b/go.mod @@ -135,3 +135,4 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d +replace github.com/flyteorg/flyteidl => ../flyteidl \ No newline at end of file diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 88419b64c..1600d708f 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -5,9 +5,11 @@ import ( "sort" "time" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/logs" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -180,3 +182,59 @@ func OverridePrimaryContainerName(podSpec *v1.PodSpec, primaryContainerName stri } } } + +func ParseRunPolicy(flyteRunPolicy kfplugins.RunPolicy) commonOp.RunPolicy { + runPolicy := commonOp.RunPolicy{} + if flyteRunPolicy.GetBackoffLimit() != 0 { + var backoffLimit = flyteRunPolicy.GetBackoffLimit() + runPolicy.BackoffLimit = &backoffLimit + } + var cleanPodPolicy = ParseCleanPodPolicy(flyteRunPolicy.GetCleanPodPolicy()) + runPolicy.CleanPodPolicy = &cleanPodPolicy + if flyteRunPolicy.GetActiveDeadlineSeconds() != 0 { + var ddlSeconds = int64(flyteRunPolicy.GetActiveDeadlineSeconds()) + runPolicy.ActiveDeadlineSeconds = &ddlSeconds + } + if flyteRunPolicy.GetTtlSecondsAfterFinished() != 0 { + var ttl = flyteRunPolicy.GetTtlSecondsAfterFinished() + runPolicy.TTLSecondsAfterFinished = &ttl + } + + return runPolicy +} + +func ParseCleanPodPolicy(flyteCleanPodPolicy kfplugins.CleanPodPolicy) commonOp.CleanPodPolicy { + cleanPodPolicyMap := map[kfplugins.CleanPodPolicy]commonOp.CleanPodPolicy{ + kfplugins.CleanPodPolicy_CLEANPOD_POLICY_NONE: commonOp.CleanPodPolicyNone, + kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL: commonOp.CleanPodPolicyAll, + kfplugins.CleanPodPolicy_CLEANPOD_POLICY_RUNNING: commonOp.CleanPodPolicyRunning, + } + return cleanPodPolicyMap[flyteCleanPodPolicy] +} + +func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.RestartPolicy { + restartPolicyMap := map[kfplugins.RestartPolicy]commonOp.RestartPolicy{ + kfplugins.RestartPolicy_RESTART_POLICY_NEVER: commonOp.RestartPolicyNever, + kfplugins.RestartPolicy_RESTART_POLICY_ON_FAILURE: commonOp.RestartPolicyOnFailure, + kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS: commonOp.RestartPolicyAlways, + } + return restartPolicyMap[flyteRestartPolicy] +} + +func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources) (*v1.PodSpec, 
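+ // Usage sketch for the helpers above (assumed wiring; it mirrors how the
+ // pytorch/tensorflow/mpi handlers later in this patch series call them, and
+ // `kfArgs` / `replicaSpec` are placeholder names, not part of this change):
+ //
+ //   runPolicy := commonOp.RunPolicy{}
+ //   if kfArgs.GetRunPolicy() != nil {
+ //       runPolicy = common.ParseRunPolicy(*kfArgs.GetRunPolicy())
+ //   }
+ //   restartPolicy := commonOp.RestartPolicy(common.ParseRestartPolicy(replicaSpec.GetRestartPolicy()))
+ //
+ // ParseCleanPodPolicy and ParseRestartPolicy are plain map lookups, so an
+ // unrecognized flyteidl enum value falls through to the kubeflow type's zero value.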
error) { + for idx, c := range podSpec.Containers { + if c.Name == containerName { + if image != "" { + podSpec.Containers[idx].Image = image + } + if resources != nil { + resources, err := flytek8s.ToK8sResourceRequirements(resources) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + } + podSpec.Containers[idx].Resources = *resources + } + } + } + return podSpec, nil +} diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 338f6cd56..590db4dd9 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -6,6 +6,7 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" @@ -68,64 +69,117 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx } common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName) - workers := pytorchTaskExtraArgs.GetWorkers() + var workers int32 + var runPolicy = commonOp.RunPolicy{} + var workerPodSpec = podSpec.DeepCopy() + var masterPodSpec = podSpec.DeepCopy() + var workerRestartPolicy = commonOp.RestartPolicyNever + var masterRestartPolicy = commonOp.RestartPolicyNever + + if taskTemplate.TaskTypeVersion == 0 { + pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &pytorchTaskExtraArgs) + + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + workers = pytorchTaskExtraArgs.GetWorkers() + } else if taskTemplate.TaskTypeVersion == 1 { + kfPytorchTaskExtraArgs := kfplugins.DistributedPyTorchTrainingTask{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfPytorchTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + workerReplicaSpec := kfPytorchTaskExtraArgs.GetWorkerReplicas() + masterReplicaSpec := kfPytorchTaskExtraArgs.GetMasterReplicas() + + // Replace specs of worker replica + if workerReplicaSpec != nil { + for _, c := range workerPodSpec.Containers { + if c.Name == kubeflowv1.PytorchJobDefaultContainerName { + common.OverrideContainerSpec(workerPodSpec, c.Name, workerReplicaSpec.GetImage(), workerReplicaSpec.GetResources()) + } + } + workerRestartPolicy = commonOp.RestartPolicy(common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy())) + workers = workerReplicaSpec.GetReplicas() + } + + // Replace specs of master replica + if masterReplicaSpec != nil { + for _, c := range masterPodSpec.Containers { + if c.Name == kubeflowv1.PytorchJobDefaultContainerName { + common.OverrideContainerSpec(masterPodSpec, c.Name, masterReplicaSpec.GetImage(), masterReplicaSpec.GetResources()) + } + masterRestartPolicy = commonOp.RestartPolicy(common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy())) + } + } + + if kfPytorchTaskExtraArgs.GetRunPolicy() != nil { + runPolicy = common.ParseRunPolicy(*kfPytorchTaskExtraArgs.GetRunPolicy()) + } + } else { + return nil, 
flyteerr.Errorf(flyteerr.BadTaskSpecification, + "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) + } + if workers == 0 { return nil, fmt.Errorf("number of worker should be more then 0") } - var jobSpec kubeflowv1.PyTorchJobSpec + jobSpec := kubeflowv1.PyTorchJobSpec{ + PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ + kubeflowv1.PyTorchJobReplicaTypeMaster: { + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *workerPodSpec, + }, + RestartPolicy: workerRestartPolicy, + }, + kubeflowv1.PyTorchJobReplicaTypeWorker: { + Replicas: &workers, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *masterPodSpec, + }, + RestartPolicy: masterRestartPolicy, + }, + }, + RunPolicy: runPolicy, + } + - elasticConfig := pytorchTaskExtraArgs.GetElasticConfig() + jobSpec = kubeflowv1.PyTorchJobSpec{ + PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ + kubeflowv1.PyTorchJobReplicaTypeWorker: { + Replicas: &workers, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *podSpec, + }, + RestartPolicy: commonOp.RestartPolicyNever, + }, + }, + } + // Set elastic config + elasticConfig := pytorchTaskExtraArgs.GetElasticConfig() if elasticConfig != nil { minReplicas := elasticConfig.GetMinReplicas() maxReplicas := elasticConfig.GetMaxReplicas() nProcPerNode := elasticConfig.GetNprocPerNode() maxRestarts := elasticConfig.GetMaxRestarts() rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend()) - - jobSpec = kubeflowv1.PyTorchJobSpec{ - ElasticPolicy: &kubeflowv1.ElasticPolicy{ - MinReplicas: &minReplicas, - MaxReplicas: &maxReplicas, - RDZVBackend: &rdzvBackend, - NProcPerNode: &nProcPerNode, - MaxRestarts: &maxRestarts, - }, - PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workers, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, - }, - }, - } - - } else { - - jobSpec = kubeflowv1.PyTorchJobSpec{ - PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.PyTorchJobReplicaTypeMaster: { - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, - }, - kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workers, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, - }, - }, + var elasticPolicy := kubeflowv1.ElasticPolicy{ + MinReplicas: &minReplicas, + MaxReplicas: &maxReplicas, + RDZVBackend: &rdzvBackend, + NProcPerNode: &nProcPerNode, + MaxRestarts: &maxRestarts, } + jobSpec.elasticPolicy = &elasticPolicy } + job := &kubeflowv1.PyTorchJob{ TypeMeta: metav1.TypeMeta{ Kind: kubeflowv1.PytorchJobKind, From 5d5c22f2104f49c847229262bb26a5b38db0b0c0 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 25 Apr 2023 00:06:22 -0700 Subject: [PATCH 2/9] merge elastic config in Signed-off-by: Yubo Wang --- .../k8s/kfoperators/pytorch/pytorch.go | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 590db4dd9..fbafcfd9e 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -148,20 +148,6 @@ func (pytorchOperatorResourceHandler) 
BuildResource(ctx context.Context, taskCtx RunPolicy: runPolicy, } - - jobSpec = kubeflowv1.PyTorchJobSpec{ - PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workers, - Template: v1.PodTemplateSpec{ - ObjectMeta: *objectMeta, - Spec: *podSpec, - }, - RestartPolicy: commonOp.RestartPolicyNever, - }, - }, - } - // Set elastic config elasticConfig := pytorchTaskExtraArgs.GetElasticConfig() if elasticConfig != nil { @@ -170,14 +156,16 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx nProcPerNode := elasticConfig.GetNprocPerNode() maxRestarts := elasticConfig.GetMaxRestarts() rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend()) - var elasticPolicy := kubeflowv1.ElasticPolicy{ + var elasticPolicy = kubeflowv1.ElasticPolicy{ MinReplicas: &minReplicas, MaxReplicas: &maxReplicas, RDZVBackend: &rdzvBackend, NProcPerNode: &nProcPerNode, MaxRestarts: &maxRestarts, } - jobSpec.elasticPolicy = &elasticPolicy + jobSpec.ElasticPolicy = &elasticPolicy + // Remove master replica if elastic policy is set + delete(jobSpec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeMaster) } job := &kubeflowv1.PyTorchJob{ From 4da8c4fae39303955e88a5ba9627d61f8485e742 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 25 Apr 2023 14:41:49 -0700 Subject: [PATCH 3/9] add unit tests for pytorch Signed-off-by: Yubo Wang --- .../k8s/kfoperators/pytorch/pytorch.go | 28 +-- .../k8s/kfoperators/pytorch/pytorch_test.go | 166 +++++++++++++++++- 2 files changed, 178 insertions(+), 16 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index fbafcfd9e..ba2703401 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -94,6 +94,16 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx workerReplicaSpec := kfPytorchTaskExtraArgs.GetWorkerReplicas() masterReplicaSpec := kfPytorchTaskExtraArgs.GetMasterReplicas() + // Replace specs of master replica + if masterReplicaSpec != nil { + for _, c := range masterPodSpec.Containers { + if c.Name == kubeflowv1.PytorchJobDefaultContainerName { + common.OverrideContainerSpec(masterPodSpec, c.Name, masterReplicaSpec.GetImage(), masterReplicaSpec.GetResources()) + } + masterRestartPolicy = commonOp.RestartPolicy(common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy())) + } + } + // Replace specs of worker replica if workerReplicaSpec != nil { for _, c := range workerPodSpec.Containers { @@ -105,16 +115,6 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx workers = workerReplicaSpec.GetReplicas() } - // Replace specs of master replica - if masterReplicaSpec != nil { - for _, c := range masterPodSpec.Containers { - if c.Name == kubeflowv1.PytorchJobDefaultContainerName { - common.OverrideContainerSpec(masterPodSpec, c.Name, masterReplicaSpec.GetImage(), masterReplicaSpec.GetResources()) - } - masterRestartPolicy = commonOp.RestartPolicy(common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy())) - } - } - if kfPytorchTaskExtraArgs.GetRunPolicy() != nil { runPolicy = common.ParseRunPolicy(*kfPytorchTaskExtraArgs.GetRunPolicy()) } @@ -132,17 +132,17 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx kubeflowv1.PyTorchJobReplicaTypeMaster: { Template: v1.PodTemplateSpec{ ObjectMeta: *objectMeta, - Spec: *workerPodSpec, 
+ Spec: *masterPodSpec, }, - RestartPolicy: workerRestartPolicy, + RestartPolicy: masterRestartPolicy, }, kubeflowv1.PyTorchJobReplicaTypeWorker: { Replicas: &workers, Template: v1.PodTemplateSpec{ ObjectMeta: *objectMeta, - Spec: *masterPodSpec, + Spec: *workerPodSpec, }, - RestartPolicy: masterRestartPolicy, + RestartPolicy: workerRestartPolicy, }, }, RunPolicy: runPolicy, diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 150bdb59a..f95fbbeb8 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -26,6 +26,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" @@ -35,6 +36,7 @@ import ( ) const testImage = "image://" +const testImageMaster = "image://master" const serviceAccount = "pytorch_sa" var ( @@ -76,9 +78,24 @@ func dummyElasticPytorchCustomObj(workers int32, elasticConfig plugins.ElasticCo } } -func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { +func dummyPytorchTaskTemplate(id string, args ...interface{}) *core.TaskTemplate { + + var ptObjJSON string + var err error + + for _, arg := range args { + switch t := arg.(type) { + case *kfplugins.DistributedPyTorchTrainingTask: + var pytorchCustomObj = t + ptObjJSON, err = utils.MarshalToString(pytorchCustomObj) + case *plugins.DistributedPyTorchTrainingTask: + var pytorchCustomObj = t + ptObjJSON, err = utils.MarshalToString(pytorchCustomObj) + default: + err = fmt.Errorf("Unkonw input type %T", t) + } + } - ptObjJSON, err := utils.MarshalToString(pytorchCustomObj) if err != nil { panic(err) } @@ -456,3 +473,148 @@ func TestReplicaCounts(t *testing.T) { }) } } + +func TestBuildResourcePytorchV1(t *testing.T) { + var taskConfig = &kfplugins.DistributedPyTorchTrainingTask{ + MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Image: testImageMaster, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + }, + }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + }, + }, + }, + RunPolicy: &kfplugins.RunPolicy{ + CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, + BackoffLimit: 100, + }, + } + + masterResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + }, + } + + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + }, + } + + 
pytorchResourceHandler := pytorchOperatorResourceHandler{} + + taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, res) + + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + + assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + + assert.Equal(t, testImageMaster, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) + + assert.Equal(t, *masterResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + + assert.Equal(t, commonOp.RestartPolicyAlways, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy) + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy) + + assert.Equal(t, commonOp.CleanPodPolicyAll, *pytorchJob.Spec.RunPolicy.CleanPodPolicy) + assert.Equal(t, int32(100), *pytorchJob.Spec.RunPolicy.BackoffLimit) + assert.Nil(t, pytorchJob.Spec.RunPolicy.TTLSecondsAfterFinished) + assert.Nil(t, pytorchJob.Spec.RunPolicy.ActiveDeadlineSeconds) + + assert.Nil(t, pytorchJob.Spec.ElasticPolicy) +} + +func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { + var taskConfig = &kfplugins.DistributedPyTorchTrainingTask{ + WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + }, + }, + }, + } + // Master Replica should use resource from task override if not set + taskOverrideResourceRequirements := &corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1000m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("100m"), + corev1.ResourceMemory: resource.MustParse("512Mi"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + } + + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + }, + } + + pytorchResourceHandler := pytorchOperatorResourceHandler{} + + taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, res) + + pytorchJob, ok := res.(*kubeflowv1.PyTorchJob) + assert.True(t, ok) + + assert.Equal(t, int32(100), 
*pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas) + assert.Nil(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas) + + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image) + assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image) + + assert.Equal(t, *taskOverrideResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy) + assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy) + + assert.Nil(t, pytorchJob.Spec.ElasticPolicy) +} From 099bba02e7afb419b78d972c6889d77997a26d19 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Wed, 26 Apr 2023 12:02:51 -0700 Subject: [PATCH 4/9] add tfjob Signed-off-by: Yubo Wang --- .../k8s/kfoperators/common/common_operator.go | 6 + .../k8s/kfoperators/pytorch/pytorch.go | 76 ++++--- .../k8s/kfoperators/pytorch/pytorch_test.go | 6 +- .../k8s/kfoperators/tensorflow/tensorflow.go | 126 +++++++++--- .../kfoperators/tensorflow/tensorflow_test.go | 189 +++++++++++++++++- 5 files changed, 343 insertions(+), 60 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 1600d708f..c9685e41b 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -23,6 +23,12 @@ const ( PytorchTaskType = "pytorch" ) +type ReplicaEntry struct { + PodSpec *v1.PodSpec + ReplicaNum int32 + RestartPolicy commonOp.RestartPolicy +} + // ExtractMPICurrentCondition will return the first job condition for MPI func ExtractMPICurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) { if jobConditions != nil { diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index ba2703401..27d4bb0a6 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -69,50 +69,64 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx } common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName) - var workers int32 - var runPolicy = commonOp.RunPolicy{} - var workerPodSpec = podSpec.DeepCopy() - var masterPodSpec = podSpec.DeepCopy() - var workerRestartPolicy = commonOp.RestartPolicyNever - var masterRestartPolicy = commonOp.RestartPolicyNever + var masterReplica = common.ReplicaEntry{ + ReplicaNum: int32(1), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + } + var workerReplica = common.ReplicaEntry{ + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + } + runPolicy := commonOp.RunPolicy{} if taskTemplate.TaskTypeVersion == 0 { pytorchTaskExtraArgs := plugins.DistributedPyTorchTrainingTask{} - err = 
utils.UnmarshalStruct(taskTemplate.GetCustom(), &pytorchTaskExtraArgs) + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &pytorchTaskExtraArgs) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - workers = pytorchTaskExtraArgs.GetWorkers() + + workerReplica.ReplicaNum = pytorchTaskExtraArgs.GetWorkers() } else if taskTemplate.TaskTypeVersion == 1 { kfPytorchTaskExtraArgs := kfplugins.DistributedPyTorchTrainingTask{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfPytorchTaskExtraArgs) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - workerReplicaSpec := kfPytorchTaskExtraArgs.GetWorkerReplicas() + // Replace specs of master replica, master should always have 1 replica masterReplicaSpec := kfPytorchTaskExtraArgs.GetMasterReplicas() - - // Replace specs of master replica if masterReplicaSpec != nil { - for _, c := range masterPodSpec.Containers { - if c.Name == kubeflowv1.PytorchJobDefaultContainerName { - common.OverrideContainerSpec(masterPodSpec, c.Name, masterReplicaSpec.GetImage(), masterReplicaSpec.GetResources()) - } - masterRestartPolicy = commonOp.RestartPolicy(common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy())) - } + common.OverrideContainerSpec( + masterReplica.PodSpec, + kubeflowv1.PytorchJobDefaultContainerName, + masterReplicaSpec.GetImage(), + masterReplicaSpec.GetResources(), + ) + masterReplica.RestartPolicy = + commonOp.RestartPolicy( + common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy()), + ) } // Replace specs of worker replica + workerReplicaSpec := kfPytorchTaskExtraArgs.GetWorkerReplicas() if workerReplicaSpec != nil { - for _, c := range workerPodSpec.Containers { - if c.Name == kubeflowv1.PytorchJobDefaultContainerName { - common.OverrideContainerSpec(workerPodSpec, c.Name, workerReplicaSpec.GetImage(), workerReplicaSpec.GetResources()) - } - } - workerRestartPolicy = commonOp.RestartPolicy(common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy())) - workers = workerReplicaSpec.GetReplicas() + common.OverrideContainerSpec( + workerReplica.PodSpec, + kubeflowv1.PytorchJobDefaultContainerName, + workerReplicaSpec.GetImage(), + workerReplicaSpec.GetResources(), + ) + workerReplica.RestartPolicy = + commonOp.RestartPolicy( + common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()), + ) + workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas() } if kfPytorchTaskExtraArgs.GetRunPolicy() != nil { @@ -123,7 +137,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) } - if workers == 0 { + if workerReplica.ReplicaNum == 0 { return nil, fmt.Errorf("number of worker should be more then 0") } @@ -132,17 +146,17 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx kubeflowv1.PyTorchJobReplicaTypeMaster: { Template: v1.PodTemplateSpec{ ObjectMeta: *objectMeta, - Spec: *masterPodSpec, + Spec: *masterReplica.PodSpec, }, - RestartPolicy: masterRestartPolicy, + RestartPolicy: masterReplica.RestartPolicy, }, kubeflowv1.PyTorchJobReplicaTypeWorker: { - Replicas: &workers, + Replicas: &workerReplica.ReplicaNum, Template: v1.PodTemplateSpec{ ObjectMeta: *objectMeta, - Spec: *workerPodSpec, + Spec: *workerReplica.PodSpec, }, - RestartPolicy: 
workerRestartPolicy, + RestartPolicy: workerReplica.RestartPolicy, }, }, RunPolicy: runPolicy, @@ -156,7 +170,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx nProcPerNode := elasticConfig.GetNprocPerNode() maxRestarts := elasticConfig.GetMaxRestarts() rdzvBackend := kubeflowv1.RDZVBackend(elasticConfig.GetRdzvBackend()) - var elasticPolicy = kubeflowv1.ElasticPolicy{ + elasticPolicy := kubeflowv1.ElasticPolicy{ MinReplicas: &minReplicas, MaxReplicas: &maxReplicas, RDZVBackend: &rdzvBackend, diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index f95fbbeb8..26c847598 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -475,7 +475,7 @@ func TestReplicaCounts(t *testing.T) { } func TestBuildResourcePytorchV1(t *testing.T) { - var taskConfig = &kfplugins.DistributedPyTorchTrainingTask{ + taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ Image: testImageMaster, Resources: &core.Resources{ @@ -556,7 +556,7 @@ func TestBuildResourcePytorchV1(t *testing.T) { } func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { - var taskConfig = &kfplugins.DistributedPyTorchTrainingTask{ + taskConfig := &kfplugins.DistributedPyTorchTrainingTask{ WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{ Replicas: 100, Resources: &core.Resources{ @@ -594,7 +594,7 @@ func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} - taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig) + taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig) taskTemplate.TaskTypeVersion = 1 res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index b5a5a675f..55900d46e 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -6,6 +6,7 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" @@ -56,23 +57,106 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") } - tensorflowTaskExtraArgs := plugins.DistributedTensorflowTrainingTask{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &tensorflowTaskExtraArgs) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) - } - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.TFJobDefaultContainerName) - workers := tensorflowTaskExtraArgs.GetWorkers() - psReplicas := tensorflowTaskExtraArgs.GetPsReplicas() - chiefReplicas := 
tensorflowTaskExtraArgs.GetChiefReplicas() + replicaSpecMap := map[commonOp.ReplicaType]*common.ReplicaEntry{ + kubeflowv1.TFJobReplicaTypeChief: { + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + }, + kubeflowv1.TFJobReplicaTypeWorker: { + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + }, + kubeflowv1.TFJobReplicaTypePS: { + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + }, + } + runPolicy := commonOp.RunPolicy{} - if workers == 0 { + if taskTemplate.TaskTypeVersion == 0 { + tensorflowTaskExtraArgs := plugins.DistributedTensorflowTrainingTask{} + + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &tensorflowTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = tensorflowTaskExtraArgs.GetChiefReplicas() + replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = tensorflowTaskExtraArgs.GetWorkers() + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = tensorflowTaskExtraArgs.GetPsReplicas() + + } else if taskTemplate.TaskTypeVersion == 1 { + kfTensorflowTaskExtraArgs := kfplugins.DistributedTensorflowTrainingTask{} + + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfTensorflowTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + chiefReplicaSpec := kfTensorflowTaskExtraArgs.GetChiefReplicas() + if chiefReplicaSpec != nil { + common.OverrideContainerSpec( + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].PodSpec, + kubeflowv1.TFJobDefaultContainerName, + chiefReplicaSpec.GetImage(), + chiefReplicaSpec.GetResources(), + ) + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].RestartPolicy = + commonOp.RestartPolicy( + common.ParseRestartPolicy(chiefReplicaSpec.GetRestartPolicy()), + ) + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = chiefReplicaSpec.GetReplicas() + } + + workerReplicaSpec := kfTensorflowTaskExtraArgs.GetWorkerReplicas() + if workerReplicaSpec != nil { + common.OverrideContainerSpec( + replicaSpecMap[kubeflowv1.MPIJobReplicaTypeWorker].PodSpec, + kubeflowv1.TFJobDefaultContainerName, + workerReplicaSpec.GetImage(), + workerReplicaSpec.GetResources(), + ) + replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = + commonOp.RestartPolicy( + common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()), + ) + replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = workerReplicaSpec.GetReplicas() + } + + psReplicaSpec := kfTensorflowTaskExtraArgs.GetPsReplicas() + if psReplicaSpec != nil { + common.OverrideContainerSpec( + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].PodSpec, + kubeflowv1.TFJobDefaultContainerName, + psReplicaSpec.GetImage(), + psReplicaSpec.GetResources(), + ) + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].RestartPolicy = + commonOp.RestartPolicy( + common.ParseRestartPolicy(psReplicaSpec.GetRestartPolicy()), + ) + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = psReplicaSpec.GetReplicas() + } + + if kfTensorflowTaskExtraArgs.GetRunPolicy() != nil { + runPolicy = common.ParseRunPolicy(*kfTensorflowTaskExtraArgs.GetRunPolicy()) + } + + } else { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, + "Invalid 
TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) + } + + if replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum == 0 { return nil, fmt.Errorf("number of worker should be more then 0") } @@ -80,27 +164,21 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task TFReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{}, } - for _, t := range []struct { - podSpec v1.PodSpec - replicaNum *int32 - replicaType commonOp.ReplicaType - }{ - {*podSpec, &workers, kubeflowv1.TFJobReplicaTypeWorker}, - {*podSpec, &psReplicas, kubeflowv1.TFJobReplicaTypePS}, - {*podSpec, &chiefReplicas, kubeflowv1.TFJobReplicaTypeChief}, - } { - if *t.replicaNum > 0 { - jobSpec.TFReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{ - Replicas: t.replicaNum, + for replicaType, replicaEntry := range replicaSpecMap { + if replicaEntry.ReplicaNum > 0 { + jobSpec.TFReplicaSpecs[replicaType] = &commonOp.ReplicaSpec{ + Replicas: &replicaEntry.ReplicaNum, Template: v1.PodTemplateSpec{ ObjectMeta: *objectMeta, - Spec: t.podSpec, + Spec: *replicaEntry.PodSpec, }, - RestartPolicy: commonOp.RestartPolicyNever, + RestartPolicy: replicaEntry.RestartPolicy, } } } + jobSpec.RunPolicy = runPolicy + job := &kubeflowv1.TFJob{ TypeMeta: metav1.TypeMeta{ Kind: kubeflowv1.TFJobKind, diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index dc8d5f240..51d44ca38 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -26,6 +26,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" "github.com/golang/protobuf/jsonpb" structpb "github.com/golang/protobuf/ptypes/struct" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" @@ -34,6 +35,7 @@ import ( ) const testImage = "image://" +const testImageChief = "image://chief" const serviceAccount = "tensorflow_sa" var ( @@ -70,9 +72,24 @@ func dummyTensorFlowCustomObj(workers int32, psReplicas int32, chiefReplicas int } } -func dummyTensorFlowTaskTemplate(id string, tensorflowCustomObj *plugins.DistributedTensorflowTrainingTask) *core.TaskTemplate { +func dummyTensorFlowTaskTemplate(id string, args ...interface{}) *core.TaskTemplate { + + var tfObjJSON string + var err error + + for _, arg := range args { + switch t := arg.(type) { + case *kfplugins.DistributedTensorflowTrainingTask: + var tensorflowCustomObj = t + tfObjJSON, err = utils.MarshalToString(tensorflowCustomObj) + case *plugins.DistributedTensorflowTrainingTask: + var tensorflowCustomObj = t + tfObjJSON, err = utils.MarshalToString(tensorflowCustomObj) + default: + err = fmt.Errorf("Unkonw input type %T", t) + } + } - tfObjJSON, err := utils.MarshalToString(tensorflowCustomObj) if err != nil { panic(err) } @@ -419,3 +436,171 @@ func TestReplicaCounts(t *testing.T) { }) } } + +func TestBuildResourceTensorFlowV1(t *testing.T) { + taskConfig := &kfplugins.DistributedTensorflowTrainingTask{ + ChiefReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 1, + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + {Name: core.Resources_MEMORY, Value: "1Gi"}, + }, + 
Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + {Name: core.Resources_MEMORY, Value: "2Gi"}, + }, + }, + RestartPolicy: kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS, + }, + WorkerReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, + }, + PsReplicas: &kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 50, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + }, + }, + }, + RunPolicy: &kfplugins.RunPolicy{ + CleanPodPolicy: kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL, + ActiveDeadlineSeconds: int32(100), + }, + } + + resourceRequirementsMap := map[commonOp.ReplicaType]*corev1.ResourceRequirements{ + kubeflowv1.TFJobReplicaTypeChief: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + corev1.ResourceMemory: resource.MustParse("1Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + }, + kubeflowv1.TFJobReplicaTypeWorker: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + kubeflowv1.TFJobReplicaTypePS: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + }, + }, + } + + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + + taskTemplate := dummyTensorFlowTaskTemplate("the job", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(50), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS].Replicas) + assert.Equal(t, int32(1), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief].Replicas) + + for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { + var hasContainerWithDefaultTensorFlowName = false + + for _, container := range replicaSpec.Template.Spec.Containers { + if container.Name == kubeflowv1.TFJobDefaultContainerName { + hasContainerWithDefaultTensorFlowName = true + assert.Equal(t, *resourceRequirementsMap[replicaType], container.Resources) + } + } + + assert.True(t, hasContainerWithDefaultTensorFlowName) + } + assert.Equal(t, commonOp.CleanPodPolicyAll, *tensorflowJob.Spec.RunPolicy.CleanPodPolicy) + assert.Equal(t, int64(100), *tensorflowJob.Spec.RunPolicy.ActiveDeadlineSeconds) +} + +func TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { + taskConfig := &kfplugins.DistributedTensorflowTrainingTask{ + WorkerReplicas: 
&kfplugins.DistributedTensorflowTrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + {Name: core.Resources_GPU, Value: "1"}, + }, + }, + }, + } + + resourceRequirementsMap := map[commonOp.ReplicaType]*corev1.ResourceRequirements{ + kubeflowv1.TFJobReplicaTypeWorker: { + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + flytek8s.ResourceNvidiaGPU: resource.MustParse("1"), + }, + }, + } + + tensorflowResourceHandler := tensorflowOperatorResourceHandler{} + + taskTemplate := dummyTensorFlowTaskTemplate("the job", taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + tensorflowJob, ok := resource.(*kubeflowv1.TFJob) + assert.True(t, ok) + assert.Equal(t, int32(100), *tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Replicas) + assert.Nil(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeChief]) + assert.Nil(t, tensorflowJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypePS]) + + for replicaType, replicaSpec := range tensorflowJob.Spec.TFReplicaSpecs { + var hasContainerWithDefaultTensorFlowName = false + + for _, container := range replicaSpec.Template.Spec.Containers { + if container.Name == kubeflowv1.TFJobDefaultContainerName { + hasContainerWithDefaultTensorFlowName = true + assert.Equal(t, *resourceRequirementsMap[replicaType], container.Resources) + } + } + + assert.True(t, hasContainerWithDefaultTensorFlowName) + } +} From 9626895d6c84c7e4b9d012ed17c8c0b16fc07c8b Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Thu, 27 Apr 2023 00:06:31 -0700 Subject: [PATCH 5/9] add mpi job Signed-off-by: Yubo Wang --- .../k8s/kfoperators/common/common_operator.go | 5 +- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 147 ++++++++++++------ .../plugins/k8s/kfoperators/mpi/mpi_test.go | 136 +++++++++++++++- .../k8s/kfoperators/pytorch/pytorch.go | 2 + .../k8s/kfoperators/tensorflow/tensorflow.go | 3 + 5 files changed, 246 insertions(+), 47 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index c9685e41b..dd56cf1b0 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -227,7 +227,7 @@ func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.Res return restartPolicyMap[flyteRestartPolicy] } -func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources) (*v1.PodSpec, error) { +func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) (*v1.PodSpec, error) { for idx, c := range podSpec.Containers { if c.Name == containerName { if image != "" { @@ -240,6 +240,9 @@ func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image stri } podSpec.Containers[idx].Resources = *resources } + if args != nil && len(args) != 0 { + podSpec.Containers[idx].Args = args + } } } return 
podSpec, nil diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 7837f5f4a..241297540 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -7,6 +7,8 @@ import ( "time" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" + flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -48,7 +50,6 @@ func (mpiOperatorResourceHandler) BuildIdentityResource(ctx context.Context, tas // Defines a func to create the full resource object that will be posted to k8s. func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { taskTemplate, err := taskCtx.TaskReader().Read(ctx) - taskTemplateConfig := taskTemplate.GetConfig() if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "unable to fetch task specification [%v]", err.Error()) @@ -56,69 +57,127 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "nil task specification") } - mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) - } - - workers := mpiTaskExtraArgs.GetNumWorkers() - launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas() - slots := mpiTaskExtraArgs.GetSlots() - podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.MPIJobDefaultContainerName) - // workersPodSpec is deepCopy of podSpec submitted by flyte - workersPodSpec := podSpec.DeepCopy() + var launcherReplica = common.ReplicaEntry{ + ReplicaNum: int32(1), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + } + var workerReplica = common.ReplicaEntry{ + ReplicaNum: int32(0), + PodSpec: podSpec.DeepCopy(), + RestartPolicy: commonOp.RestartPolicyNever, + } + slots := int32(1) + runPolicy := commonOp.RunPolicy{} + + if taskTemplate.TaskTypeVersion == 0 { + mpiTaskExtraArgs := plugins.DistributedMPITrainingTask{} + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &mpiTaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } - // If users don't specify "worker_spec_command" in the task config, the command/args are empty. - // However, in some cases, the workers need command/args. - // For example, in horovod tasks, each worker runs a command launching ssh daemon. 
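+ // Illustrative note (example values reused from the tests in this patch, not
+ // new behavior): for task type version 0 the worker command still comes from
+ // the task template config handled just below in this branch, e.g. a config entry
+ //   "worker_spec_command": "/usr/sbin/sshd /.sshd_config"
+ // is split on spaces so the worker container ends up with
+ //   Args    = []string{"/usr/sbin/sshd", "/.sshd_config"}
+ //   Command = []string{}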
+ workerReplica.ReplicaNum = mpiTaskExtraArgs.GetNumWorkers() + launcherReplica.ReplicaNum = mpiTaskExtraArgs.GetNumLauncherReplicas() + slots = mpiTaskExtraArgs.GetSlots() - workerSpecCommand := []string{} - if val, ok := taskTemplateConfig[workerSpecCommandKey]; ok { - workerSpecCommand = strings.Split(val, " ") - } + // V1 requires passing worker command as template config parameter + taskTemplateConfig := taskTemplate.GetConfig() + workerSpecCommand := []string{} + if val, ok := taskTemplateConfig[workerSpecCommandKey]; ok { + workerSpecCommand = strings.Split(val, " ") + } + + for k := range workerReplica.PodSpec.Containers { + if workerReplica.PodSpec.Containers[k].Name == kubeflowv1.MPIJobDefaultContainerName { + workerReplica.PodSpec.Containers[k].Args = workerSpecCommand + workerReplica.PodSpec.Containers[k].Command = []string{} + } + } + + } else if taskTemplate.TaskTypeVersion == 1 { + kfMPITaskExtraArgs := kfplugins.DistributedMPITrainingTask{} + + err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &kfMPITaskExtraArgs) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) + } + + launcherReplicaSpec := kfMPITaskExtraArgs.GetLauncherReplicas() + if launcherReplicaSpec != nil { + // flyte commands will be passed as args to the container + common.OverrideContainerSpec( + launcherReplica.PodSpec, + kubeflowv1.MPIJobDefaultContainerName, + launcherReplicaSpec.GetImage(), + launcherReplicaSpec.GetResources(), + launcherReplicaSpec.GetCommand(), + ) + launcherReplica.RestartPolicy = + commonOp.RestartPolicy( + common.ParseRestartPolicy(launcherReplicaSpec.GetRestartPolicy()), + ) + } + + workerReplicaSpec := kfMPITaskExtraArgs.GetWorkerReplicas() + if workerReplicaSpec != nil { + common.OverrideContainerSpec( + workerReplica.PodSpec, + kubeflowv1.MPIJobDefaultContainerName, + workerReplicaSpec.GetImage(), + workerReplicaSpec.GetResources(), + workerReplicaSpec.GetCommand(), + ) + workerReplica.RestartPolicy = + commonOp.RestartPolicy( + common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()), + ) + workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas() + } + + if kfMPITaskExtraArgs.GetRunPolicy() != nil { + runPolicy = common.ParseRunPolicy(*kfMPITaskExtraArgs.GetRunPolicy()) + } - for k := range workersPodSpec.Containers { - workersPodSpec.Containers[k].Args = workerSpecCommand - workersPodSpec.Containers[k].Command = []string{} + } else { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, + "Invalid TaskSpecification, unsupported task template version [%v] key", taskTemplate.TaskTypeVersion) } - if workers == 0 { + if workerReplica.ReplicaNum == 0 { return nil, fmt.Errorf("number of worker should be more then 0") } - if launcherReplicas == 0 { + if launcherReplica.ReplicaNum == 0 { return nil, fmt.Errorf("number of launch worker should be more then 0") } jobSpec := kubeflowv1.MPIJobSpec{ - SlotsPerWorker: &slots, - MPIReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{}, - } - - for _, t := range []struct { - podSpec v1.PodSpec - replicaNum *int32 - replicaType commonOp.ReplicaType - }{ - {*podSpec, &launcherReplicas, kubeflowv1.MPIJobReplicaTypeLauncher}, - {*workersPodSpec, &workers, kubeflowv1.MPIJobReplicaTypeWorker}, - } { - if *t.replicaNum > 0 { - jobSpec.MPIReplicaSpecs[t.replicaType] = &commonOp.ReplicaSpec{ - Replicas: t.replicaNum, + SlotsPerWorker: &slots, + RunPolicy: runPolicy, + MPIReplicaSpecs: 
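+ // For task type version 1 the launcher keeps the single replica it was
+ // initialized with above; only the worker replica count is read from the new
+ // spec (the V1 tests added below assert a launcher replica count of 1).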
map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ + kubeflowv1.MPIJobReplicaTypeLauncher: { + Replicas: &launcherReplica.ReplicaNum, Template: v1.PodTemplateSpec{ ObjectMeta: *objectMeta, - Spec: t.podSpec, + Spec: *launcherReplica.PodSpec, }, - RestartPolicy: commonOp.RestartPolicyNever, - } - } + RestartPolicy: launcherReplica.RestartPolicy, + }, + kubeflowv1.MPIJobReplicaTypeWorker: { + Replicas: &workerReplica.ReplicaNum, + Template: v1.PodTemplateSpec{ + ObjectMeta: *objectMeta, + Spec: *workerReplica.PodSpec, + }, + RestartPolicy: workerReplica.RestartPolicy, + }, + }, } job := &kubeflowv1.MPIJob{ diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 29fefd9ca..8af67fa36 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -8,6 +8,7 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins" + kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow" "github.com/flyteorg/flyteplugins/go/tasks/logs" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" @@ -68,9 +69,24 @@ func dummyMPICustomObj(workers int32, launcher int32, slots int32) *plugins.Dist } } -func dummyMPITaskTemplate(id string, mpiCustomObj *plugins.DistributedMPITrainingTask) *core.TaskTemplate { +func dummyMPITaskTemplate(id string, args ...interface{}) *core.TaskTemplate { + + var mpiObjJSON string + var err error + + for _, arg := range args { + switch t := arg.(type) { + case *kfplugins.DistributedMPITrainingTask: + var mpiCustomObj = t + mpiObjJSON, err = utils.MarshalToString(mpiCustomObj) + case *plugins.DistributedMPITrainingTask: + var mpiCustomObj = t + mpiObjJSON, err = utils.MarshalToString(mpiCustomObj) + default: + err = fmt.Errorf("Unkonw input type %T", t) + } + } - mpiObjJSON, err := utils.MarshalToString(mpiCustomObj) if err != nil { panic(err) } @@ -427,3 +443,119 @@ func TestReplicaCounts(t *testing.T) { }) } } + +func TestBuildResourceMPIV1(t *testing.T) { + launcherCommand := []string{"python", "launcher.py"} + workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"} + taskConfig := &kfplugins.DistributedMPITrainingTask{ + LauncherReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Image: testImage, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + }, + }, + Command: launcherCommand, + }, + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + }, + }, + Command: workerCommand, + }, + Slots: int32(1), + } + + launcherResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("250m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("500m"), + }, + } + + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + }, + } + + mpiResourceHandler := 
mpiOperatorResourceHandler{} + + taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + mpiJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas) + assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker) + assert.Equal(t, *launcherResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Resources) + assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + assert.Equal(t, launcherCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) + assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) +} + +func TestBuildResourceMPIV1WithOnlyWorkerReplica(t *testing.T) { + workerCommand := []string{"/usr/sbin/sshd", "/.sshd_config"} + + taskConfig := &kfplugins.DistributedMPITrainingTask{ + WorkerReplicas: &kfplugins.DistributedMPITrainingReplicaSpec{ + Replicas: 100, + Resources: &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "1024m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "2048m"}, + }, + }, + Command: []string{"/usr/sbin/sshd", "/.sshd_config"}, + }, + Slots: int32(1), + } + + workerResourceRequirements := &corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1024m"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2048m"), + }, + } + + mpiResourceHandler := mpiOperatorResourceHandler{} + + taskTemplate := dummyMPITaskTemplate(mpiID2, taskConfig) + taskTemplate.TaskTypeVersion = 1 + + resource, err := mpiResourceHandler.BuildResource(context.TODO(), dummyMPITaskContext(taskTemplate)) + assert.NoError(t, err) + assert.NotNil(t, resource) + + mpiJob, ok := resource.(*kubeflowv1.MPIJob) + assert.True(t, ok) + assert.Equal(t, int32(1), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Replicas) + assert.Equal(t, int32(100), *mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Replicas) + assert.Equal(t, int32(1), *mpiJob.Spec.SlotsPerWorker) + assert.Equal(t, *workerResourceRequirements, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Resources) + assert.Equal(t, testArgs, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeLauncher].Template.Spec.Containers[0].Args) + assert.Equal(t, workerCommand, mpiJob.Spec.MPIReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker].Template.Spec.Containers[0].Args) +} diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 27d4bb0a6..1d46b9ef3 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -106,6 +106,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx kubeflowv1.PytorchJobDefaultContainerName, masterReplicaSpec.GetImage(), masterReplicaSpec.GetResources(), + 
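+ // args is nil here: PyTorch (and TensorFlow) replicas keep the command/args
+ // generated from the task spec, because OverrideContainerSpec only replaces
+ // container Args when it is passed a non-empty slice, which the MPI plugin
+ // does via the replica's GetCommand().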
nil, ) masterReplica.RestartPolicy = commonOp.RestartPolicy( @@ -121,6 +122,7 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx kubeflowv1.PytorchJobDefaultContainerName, workerReplicaSpec.GetImage(), workerReplicaSpec.GetResources(), + nil, ) workerReplica.RestartPolicy = commonOp.RestartPolicy( diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 55900d46e..980062201 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -109,6 +109,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task kubeflowv1.TFJobDefaultContainerName, chiefReplicaSpec.GetImage(), chiefReplicaSpec.GetResources(), + nil, ) replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].RestartPolicy = commonOp.RestartPolicy( @@ -124,6 +125,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task kubeflowv1.TFJobDefaultContainerName, workerReplicaSpec.GetImage(), workerReplicaSpec.GetResources(), + nil, ) replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = commonOp.RestartPolicy( @@ -139,6 +141,7 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task kubeflowv1.TFJobDefaultContainerName, psReplicaSpec.GetImage(), psReplicaSpec.GetResources(), + nil, ) replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].RestartPolicy = commonOp.RestartPolicy( From dfa5ba01472b5711cdb743da4fd7128ac66014b7 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 5 May 2023 23:06:05 -0700 Subject: [PATCH 6/9] add test to common operator Signed-off-by: Yubo Wang --- .../k8s/kfoperators/common/common_operator.go | 13 ++- .../common/common_operator_test.go | 101 ++++++++++++++++++ 2 files changed, 110 insertions(+), 4 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index dd56cf1b0..e1caa9020 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -234,11 +234,16 @@ func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image stri podSpec.Containers[idx].Image = image } if resources != nil { - resources, err := flytek8s.ToK8sResourceRequirements(resources) - if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + // if neither resource requests nor limits are set, do not override the container resources + if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 { + resources, err := flytek8s.ToK8sResourceRequirements(resources) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + } + podSpec.Containers[idx].Resources = *resources } - podSpec.Containers[idx].Resources = *resources + } else { + podSpec.Containers[idx].Resources = v1.ResourceRequirements{} } if args != nil && len(args) != 0 { podSpec.Containers[idx].Args = args diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index 1fb26f128..d65628da2 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -5,12 +5,15 @@ import ( 
"testing" "time" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flyteplugins/go/tasks/logs" pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" commonOp "github.com/kubeflow/common/pkg/apis/common/v1" "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" ) func TestExtractMPICurrentCondition(t *testing.T) { @@ -183,3 +186,101 @@ func TestGetLogs(t *testing.T) { assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[2].Uri) } + +func dummyPodSpec() v1.PodSpec { + return v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "primary container", + Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.MustParse("2"), + "memory": resource.MustParse("200Mi"), + "gpu": resource.MustParse("1"), + }, + Requests: v1.ResourceList{ + "cpu": resource.MustParse("1"), + "memory": resource.MustParse("100Mi"), + "gpu": resource.MustParse("1"), + }, + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "volume mount", + }, + }, + }, + { + Name: "secondary container", + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "gpu": resource.MustParse("2"), + }, + Requests: v1.ResourceList{ + "gpu": resource.MustParse("2"), + }, + }, + }, + }, + Volumes: []v1.Volume{ + { + Name: "dshm", + }, + }, + Tolerations: []v1.Toleration{ + { + Key: "my toleration key", + Value: "my toleration value", + }, + }, + } +} + +func TestOverrideContainerSpec(t *testing.T) { + podSpec := dummyPodSpec() + _, err := OverrideContainerSpec( + &podSpec, "primary container", "testing-image", + &core.Resources{ + Requests: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "250m"}, + }, + Limits: []*core.Resources_ResourceEntry{ + {Name: core.Resources_CPU, Value: "500m"}, + }, + }, + []string{"python", "-m", "run.py"}, + ) + assert.NoError(t, err) + assert.Equal(t, 2, len(podSpec.Containers)) + assert.Equal(t, "testing-image", podSpec.Containers[0].Image) + assert.NotNil(t, podSpec.Containers[0].Resources.Limits) + assert.NotNil(t, podSpec.Containers[0].Resources.Requests) + // verify resources not overriden if empty resources + assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m"))) + assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m"))) + assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args) +} + +func TestOverrideContainerSpecEmptyFields(t *testing.T) { + podSpec := dummyPodSpec() + _, err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{}) + assert.NoError(t, err) + assert.Equal(t, 2, len(podSpec.Containers)) + assert.NotNil(t, podSpec.Containers[0].Resources.Limits) + assert.NotNil(t, podSpec.Containers[0].Resources.Requests) + // verify resources not overriden if empty resources + assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("1"))) + assert.True(t, podSpec.Containers[0].Resources.Requests.Memory().Equal(resource.MustParse("100Mi"))) + assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("2"))) + 
assert.True(t, podSpec.Containers[0].Resources.Limits.Memory().Equal(resource.MustParse("200Mi"))) +} + +func TestOverrideContainerNilResources(t *testing.T) { + podSpec := dummyPodSpec() + _, err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{}) + assert.NoError(t, err) + assert.Equal(t, 2, len(podSpec.Containers)) + assert.Nil(t, podSpec.Containers[0].Resources.Limits) + assert.Nil(t, podSpec.Containers[0].Resources.Requests) +} From 552b4fda828a34bf397a3fa42766fa5c4375d987 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 5 May 2023 23:16:24 -0700 Subject: [PATCH 7/9] update flyteidl Signed-off-by: Yubo Wang --- go.mod | 3 +- go.sum | 4 +- go/tasks/plugins/k8s/ray/config_flags.go | 3 ++ go/tasks/plugins/k8s/ray/config_flags_test.go | 42 +++++++++++++++++++ 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 7616dbb36..7d73674eb 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v1.3.19 + github.com/flyteorg/flyteidl v1.5.2 github.com/flyteorg/flytestdlib v1.0.15 github.com/go-test/deep v1.0.7 github.com/golang/protobuf v1.5.2 @@ -135,4 +135,3 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d -replace github.com/flyteorg/flyteidl => ../flyteidl \ No newline at end of file diff --git a/go.sum b/go.sum index 70bbe278f..1af063022 100644 --- a/go.sum +++ b/go.sum @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/flyteorg/flyteidl v1.3.19 h1:i79Dh7UoP8Z4LEJ2ox6jlfZVJtFZ+r4g84CJj1gh22Y= -github.com/flyteorg/flyteidl v1.3.19/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= +github.com/flyteorg/flyteidl v1.5.2 h1:DZPzYkTg92qA4e17fd0ZW1M+gh1gJKh/VOK+F4bYgM8= +github.com/flyteorg/flyteidl v1.5.2/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= diff --git a/go/tasks/plugins/k8s/ray/config_flags.go b/go/tasks/plugins/k8s/ray/config_flags.go index 6f651a3d2..f8e983056 100755 --- a/go/tasks/plugins/k8s/ray/config_flags.go +++ b/go/tasks/plugins/k8s/ray/config_flags.go @@ -56,5 +56,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "includeDashboard"), defaultConfig.IncludeDashboard, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dashboardHost"), defaultConfig.DashboardHost, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "nodeIPAddress"), defaultConfig.NodeIPAddress, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster") + 
cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, " Boolean flag to enable or disable") return cmdFlags } diff --git a/go/tasks/plugins/k8s/ray/config_flags_test.go b/go/tasks/plugins/k8s/ray/config_flags_test.go index d5a59757c..60761b900 100755 --- a/go/tasks/plugins/k8s/ray/config_flags_test.go +++ b/go/tasks/plugins/k8s/ray/config_flags_test.go @@ -183,4 +183,46 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_remoteClusterConfig.name", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.name", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Name) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.endpoint", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.endpoint", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.enabled", testValue) + if vBool, err := cmdFlags.GetBool("remoteClusterConfig.enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.RemoteClusterConfig.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } From 4c3b0fd238931c5012ca0ecceedce5e807908a45 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Mon, 8 May 2023 16:57:22 -0700 Subject: [PATCH 8/9] add function header comments Signed-off-by: Yubo Wang --- go/tasks/plugins/k8s/kfoperators/common/common_operator.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index e1caa9020..d93905046 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -189,6 +189,7 @@ func OverridePrimaryContainerName(podSpec *v1.PodSpec, primaryContainerName stri } } +// ParseRunPolicy converts a kubeflow plugin RunPolicy object to a k8s RunPolicy object. func ParseRunPolicy(flyteRunPolicy kfplugins.RunPolicy) commonOp.RunPolicy { runPolicy := commonOp.RunPolicy{} if flyteRunPolicy.GetBackoffLimit() != 0 { @@ -209,6 +210,7 @@ func ParseRunPolicy(flyteRunPolicy kfplugins.RunPolicy) commonOp.RunPolicy { return runPolicy } +// Get k8s clean pod policy from flyte kubeflow plugins clean pod policy. func ParseCleanPodPolicy(flyteCleanPodPolicy kfplugins.CleanPodPolicy) commonOp.CleanPodPolicy { cleanPodPolicyMap := map[kfplugins.CleanPodPolicy]commonOp.CleanPodPolicy{ kfplugins.CleanPodPolicy_CLEANPOD_POLICY_NONE: commonOp.CleanPodPolicyNone, @@ -218,6 +220,7 @@ func ParseCleanPodPolicy(flyteCleanPodPolicy kfplugins.CleanPodPolicy) commonOp. return cleanPodPolicyMap[flyteCleanPodPolicy] } +// Get k8s restart policy from flyte kubeflow plugins restart policy. 
func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.RestartPolicy { restartPolicyMap := map[kfplugins.RestartPolicy]commonOp.RestartPolicy{ kfplugins.RestartPolicy_RESTART_POLICY_NEVER: commonOp.RestartPolicyNever, @@ -227,6 +230,8 @@ func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.Res return restartPolicyMap[flyteRestartPolicy] } +// OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function +// updates the image, resources and command arguments of the container that matches the given containerName. func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) (*v1.PodSpec, error) { for idx, c := range podSpec.Containers { if c.Name == containerName { @@ -245,7 +250,7 @@ func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image stri } else { podSpec.Containers[idx].Resources = v1.ResourceRequirements{} } - if args != nil && len(args) != 0 { + if len(args) != 0 { podSpec.Containers[idx].Args = args } } From 3b79d771352daf059fe8fe74953d888e686b28c3 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Tue, 9 May 2023 13:03:19 -0700 Subject: [PATCH 9/9] fix lint Signed-off-by: Yubo Wang --- .../k8s/kfoperators/common/common_operator.go | 6 ++-- .../common/common_operator_test.go | 10 +++---- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 20 ++++++------- .../k8s/kfoperators/pytorch/pytorch.go | 20 ++++++------- .../k8s/kfoperators/tensorflow/tensorflow.go | 30 +++++++++---------- .../kfoperators/tensorflow/tensorflow_test.go | 5 ++-- 6 files changed, 45 insertions(+), 46 deletions(-) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index d93905046..d86ae42df 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -232,7 +232,7 @@ func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.Res // OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function // updates the image, resources and command arguments of the container that matches the given containerName. 
-func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) (*v1.PodSpec, error) { +func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) error { for idx, c := range podSpec.Containers { if c.Name == containerName { if image != "" { @@ -243,7 +243,7 @@ func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image stri if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 { resources, err := flytek8s.ToK8sResourceRequirements(resources) if err != nil { - return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) + return flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources [%v], Err: [%v]", resources, err.Error()) } podSpec.Containers[idx].Resources = *resources } @@ -255,5 +255,5 @@ func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image stri } } } - return podSpec, nil + return nil } diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go index d65628da2..ee2dc5a94 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go @@ -239,7 +239,7 @@ func dummyPodSpec() v1.PodSpec { func TestOverrideContainerSpec(t *testing.T) { podSpec := dummyPodSpec() - _, err := OverrideContainerSpec( + err := OverrideContainerSpec( &podSpec, "primary container", "testing-image", &core.Resources{ Requests: []*core.Resources_ResourceEntry{ @@ -256,7 +256,7 @@ func TestOverrideContainerSpec(t *testing.T) { assert.Equal(t, "testing-image", podSpec.Containers[0].Image) assert.NotNil(t, podSpec.Containers[0].Resources.Limits) assert.NotNil(t, podSpec.Containers[0].Resources.Requests) - // verify resources not overriden if empty resources + // verify resources are overridden with the requested values assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m"))) assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m"))) assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args) @@ -264,12 +264,12 @@ func TestOverrideContainerSpec(t *testing.T) { func TestOverrideContainerSpecEmptyFields(t *testing.T) { podSpec := dummyPodSpec() - _, err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{}) + err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{}) assert.NoError(t, err) assert.Equal(t, 2, len(podSpec.Containers)) assert.NotNil(t, podSpec.Containers[0].Resources.Limits) assert.NotNil(t, podSpec.Containers[0].Resources.Requests) - // verify resources not overriden if empty resources + // verify resources not overridden if empty resources assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("1"))) assert.True(t, podSpec.Containers[0].Resources.Requests.Memory().Equal(resource.MustParse("100Mi"))) assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("2"))) @@ -278,7 +278,7 @@ func TestOverrideContainerSpecEmptyFields(t *testing.T) { func TestOverrideContainerNilResources(t *testing.T) { podSpec := dummyPodSpec() - _, err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{}) + err := 
OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{}) assert.NoError(t, err) assert.Equal(t, 2, len(podSpec.Containers)) assert.Nil(t, podSpec.Containers[0].Resources.Limits) diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index 241297540..d4e35a25d 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -112,32 +112,32 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu launcherReplicaSpec := kfMPITaskExtraArgs.GetLauncherReplicas() if launcherReplicaSpec != nil { // flyte commands will be passed as args to the container - common.OverrideContainerSpec( + err = common.OverrideContainerSpec( launcherReplica.PodSpec, kubeflowv1.MPIJobDefaultContainerName, launcherReplicaSpec.GetImage(), launcherReplicaSpec.GetResources(), launcherReplicaSpec.GetCommand(), ) - launcherReplica.RestartPolicy = - commonOp.RestartPolicy( - common.ParseRestartPolicy(launcherReplicaSpec.GetRestartPolicy()), - ) + if err != nil { + return nil, err + } + launcherReplica.RestartPolicy = common.ParseRestartPolicy(launcherReplicaSpec.GetRestartPolicy()) } workerReplicaSpec := kfMPITaskExtraArgs.GetWorkerReplicas() if workerReplicaSpec != nil { - common.OverrideContainerSpec( + err = common.OverrideContainerSpec( workerReplica.PodSpec, kubeflowv1.MPIJobDefaultContainerName, workerReplicaSpec.GetImage(), workerReplicaSpec.GetResources(), workerReplicaSpec.GetCommand(), ) - workerReplica.RestartPolicy = - commonOp.RestartPolicy( - common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()), - ) + if err != nil { + return nil, err + } + workerReplica.RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas() } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 1d46b9ef3..d5cd747c6 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -101,33 +101,33 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx // Replace specs of master replica, master should always have 1 replica masterReplicaSpec := kfPytorchTaskExtraArgs.GetMasterReplicas() if masterReplicaSpec != nil { - common.OverrideContainerSpec( + err := common.OverrideContainerSpec( masterReplica.PodSpec, kubeflowv1.PytorchJobDefaultContainerName, masterReplicaSpec.GetImage(), masterReplicaSpec.GetResources(), nil, ) - masterReplica.RestartPolicy = - commonOp.RestartPolicy( - common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy()), - ) + if err != nil { + return nil, err + } + masterReplica.RestartPolicy = common.ParseRestartPolicy(masterReplicaSpec.GetRestartPolicy()) } // Replace specs of worker replica workerReplicaSpec := kfPytorchTaskExtraArgs.GetWorkerReplicas() if workerReplicaSpec != nil { - common.OverrideContainerSpec( + err := common.OverrideContainerSpec( workerReplica.PodSpec, kubeflowv1.PytorchJobDefaultContainerName, workerReplicaSpec.GetImage(), workerReplicaSpec.GetResources(), nil, ) - workerReplica.RestartPolicy = - commonOp.RestartPolicy( - common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()), - ) + if err != nil { + return nil, err + } + workerReplica.RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) workerReplica.ReplicaNum = workerReplicaSpec.GetReplicas() } diff --git 
a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index 980062201..6ee3ce440 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -104,49 +104,49 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task chiefReplicaSpec := kfTensorflowTaskExtraArgs.GetChiefReplicas() if chiefReplicaSpec != nil { - common.OverrideContainerSpec( + err := common.OverrideContainerSpec( replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].PodSpec, kubeflowv1.TFJobDefaultContainerName, chiefReplicaSpec.GetImage(), chiefReplicaSpec.GetResources(), nil, ) - replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].RestartPolicy = - commonOp.RestartPolicy( - common.ParseRestartPolicy(chiefReplicaSpec.GetRestartPolicy()), - ) + if err != nil { + return nil, err + } + replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].RestartPolicy = common.ParseRestartPolicy(chiefReplicaSpec.GetRestartPolicy()) replicaSpecMap[kubeflowv1.TFJobReplicaTypeChief].ReplicaNum = chiefReplicaSpec.GetReplicas() } workerReplicaSpec := kfTensorflowTaskExtraArgs.GetWorkerReplicas() if workerReplicaSpec != nil { - common.OverrideContainerSpec( + err := common.OverrideContainerSpec( replicaSpecMap[kubeflowv1.MPIJobReplicaTypeWorker].PodSpec, kubeflowv1.TFJobDefaultContainerName, workerReplicaSpec.GetImage(), workerReplicaSpec.GetResources(), nil, ) - replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = - commonOp.RestartPolicy( - common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()), - ) + if err != nil { + return nil, err + } + replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = common.ParseRestartPolicy(workerReplicaSpec.GetRestartPolicy()) replicaSpecMap[kubeflowv1.TFJobReplicaTypeWorker].ReplicaNum = workerReplicaSpec.GetReplicas() } psReplicaSpec := kfTensorflowTaskExtraArgs.GetPsReplicas() if psReplicaSpec != nil { - common.OverrideContainerSpec( + err := common.OverrideContainerSpec( replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].PodSpec, kubeflowv1.TFJobDefaultContainerName, psReplicaSpec.GetImage(), psReplicaSpec.GetResources(), nil, ) - replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].RestartPolicy = - commonOp.RestartPolicy( - common.ParseRestartPolicy(psReplicaSpec.GetRestartPolicy()), - ) + if err != nil { + return nil, err + } + replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].RestartPolicy = common.ParseRestartPolicy(psReplicaSpec.GetRestartPolicy()) replicaSpecMap[kubeflowv1.TFJobReplicaTypePS].ReplicaNum = psReplicaSpec.GetReplicas() } diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 9046d8288..8174258e1 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -35,7 +35,6 @@ import ( ) const testImage = "image://" -const testImageChief = "image://chief" const serviceAccount = "tensorflow_sa" var ( @@ -518,7 +517,7 @@ func TestBuildResourceTensorFlowV1(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - taskTemplate := dummyTensorFlowTaskTemplate("the job", taskConfig) + taskTemplate := dummyTensorFlowTaskTemplate("v1", taskConfig) taskTemplate.TaskTypeVersion = 1 resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate)) @@ -579,7 +578,7 @@ func 
TestBuildResourceTensorFlowV1WithOnlyWorker(t *testing.T) { tensorflowResourceHandler := tensorflowOperatorResourceHandler{} - taskTemplate := dummyTensorFlowTaskTemplate("the job", taskConfig) + taskTemplate := dummyTensorFlowTaskTemplate("v1 with only worker replica", taskConfig) taskTemplate.TaskTypeVersion = 1 resource, err := tensorflowResourceHandler.BuildResource(context.TODO(), dummyTensorFlowTaskContext(taskTemplate))
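
Editorial note (illustrative sketch, not part of the patch series): the snippet below shows how the TaskTypeVersion == 1 helpers introduced in these patches are intended to be combined — copy the base pod spec, apply the per-replica image/resource/command overrides with common.OverrideContainerSpec, and translate the IDL restart policy with common.ParseRestartPolicy — mirroring the worker-replica flow in mpi.go. The import path of the common package and the literal "mpi" container name (standing in for kubeflowv1.MPIJobDefaultContainerName) are assumptions inferred from the diffs above, not verified against the repository.

package example

import (
	"fmt"

	kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"
	"github.com/flyteorg/flyteplugins/go/tasks/plugins/k8s/kfoperators/common"
	commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
	v1 "k8s.io/api/core/v1"
)

// applyWorkerOverrides copies the task's base pod spec and applies the optional
// per-replica overrides (image, resources, command) carried by the v1 kubeflow IDL.
// It returns the adjusted pod spec, the parsed restart policy, and the replica count.
func applyWorkerOverrides(basePodSpec *v1.PodSpec, spec *kfplugins.DistributedMPITrainingReplicaSpec) (*v1.PodSpec, commonOp.RestartPolicy, int32, error) {
	podSpec := basePodSpec.DeepCopy()
	restartPolicy := commonOp.RestartPolicyNever
	var replicas int32

	if spec != nil {
		// "mpi" is assumed to match kubeflowv1.MPIJobDefaultContainerName used by the plugin.
		if err := common.OverrideContainerSpec(podSpec, "mpi", spec.GetImage(), spec.GetResources(), spec.GetCommand()); err != nil {
			return nil, restartPolicy, 0, fmt.Errorf("overriding worker container: %w", err)
		}
		restartPolicy = common.ParseRestartPolicy(spec.GetRestartPolicy())
		replicas = spec.GetReplicas()
	}
	return podSpec, restartPolicy, replicas, nil
}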