diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 0ad038f64..19a678e27 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -66,6 +66,17 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } + podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace()) + + if podTemplate != nil { + basePodSpec := podTemplate.Template.Spec.DeepCopy() + mergedPodSpec, err := flytek8s.MergePodSpecs(basePodSpec, podSpec, kubeflowv1.PytorchJobDefaultContainerName) + if err != nil { + return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error()) + } + podSpec = mergedPodSpec + } + common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName) workers := pytorchTaskExtraArgs.GetWorkers() diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 5185d150d..0c4dcb3b1 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -69,7 +69,7 @@ func dummyPytorchCustomObj(workers int32) *plugins.DistributedPyTorchTrainingTas } } -func dummySparkTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { +func dummyPytorchTaskTemplate(id string, pytorchCustomObj *plugins.DistributedPyTorchTrainingTask) *core.TaskTemplate { ptObjJSON, err := utils.MarshalToString(pytorchCustomObj) if err != nil { @@ -260,7 +260,7 @@ func dummyPytorchJobResource(pytorchResourceHandler pytorchOperatorResourceHandl } ptObj := dummyPytorchCustomObj(workers) - taskTemplate := dummySparkTaskTemplate("the job", ptObj) + taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) if err != nil { panic(err) @@ -286,7 +286,7 @@ func TestBuildResourcePytorch(t *testing.T) { pytorchResourceHandler := pytorchOperatorResourceHandler{} ptObj := dummyPytorchCustomObj(100) - taskTemplate := dummySparkTaskTemplate("the job", ptObj) + taskTemplate := dummyPytorchTaskTemplate("the job", ptObj) resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate)) assert.NoError(t, err)