Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Apply pod template to pytorch job pod spec
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Grätz <[email protected]>
  • Loading branch information
Fabio Grätz committed Nov 24, 2022
1 parent 75dc19c commit 2ec7da8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
11 changes: 11 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,17 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace())

if podTemplate != nil {
basePodSpec := podTemplate.Template.Spec.DeepCopy()
mergedPodSpec, err := flytek8s.MergePodSpecs(basePodSpec, podSpec, kubeflowv1.PytorchJobDefaultContainerName)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error())
}
podSpec = mergedPodSpec
}

common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName)

workers := pytorchTaskExtraArgs.GetWorkers()
Expand Down
6 changes: 3 additions & 3 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func dummyPytorchCustomObj(workers int32) *plugins.DistributedPyTorchTrainingTas
}
}

func dummySparkTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate {
func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate {

ptObjJSON, err := utils.MarshalToString(pytorchCustomObj)
if err != nil {
Expand Down Expand Up @@ -260,7 +260,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl
}

ptObj := dummyPytorchCustomObj(workers)
taskTemplate := dummySparkTaskTemplate("the job", ptObj)
taskTemplate := dummyPytorchTaskTemplate("the job", ptObj)
resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
if err != nil {
panic(err)
Expand All @@ -286,7 +286,7 @@ func TestBuildResourcePytorch(t *testing.T) {
pytorchResourceHandler := pytorchOperatorResourceHandler{}

ptObj := dummyPytorchCustomObj(100)
taskTemplate := dummySparkTaskTemplate("the job", ptObj)
taskTemplate := dummyPytorchTaskTemplate("the job", ptObj)

resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate))
assert.NoError(t, err)
Expand Down

0 comments on commit 2ec7da8

Please sign in to comment.