Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Apply default pod template to PytorchJob pods #297

Merged
merged 13 commits into from
Dec 16, 2022
125 changes: 73 additions & 52 deletions go/tasks/pluginmachinery/flytek8s/pod_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package flytek8s

import (
"context"
"errors"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -144,70 +145,87 @@ func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*
return pod, nil
}

func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) {
pod := v1.Pod{
TypeMeta: v12.TypeMeta{
Kind: PodKind,
APIVersion: v1.SchemeGroupVersion.String(),
},
func MergePodSpecs(podTemplatePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string) (*v1.PodSpec, error) {
fg91 marked this conversation as resolved.
Show resolved Hide resolved
if podTemplatePodSpec == nil || podSpec == nil {
return nil, errors.New("podTemplatePodSpec and podSpec cannot be nil")
}

if podTemplate != nil {
// merge template PodSpec
basePodSpec := podTemplate.Template.Spec.DeepCopy()
err := mergo.Merge(basePodSpec, podSpec, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}
var podTemplatePodSpecCopy *v1.PodSpec = podTemplatePodSpec.DeepCopy()

// merge template Containers
var mergedContainers []v1.Container
var defaultContainerTemplate, primaryContainerTemplate *v1.Container
for i := 0; i < len(podTemplate.Template.Spec.Containers); i++ {
if podTemplate.Template.Spec.Containers[i].Name == defaultContainerTemplateName {
defaultContainerTemplate = &podTemplate.Template.Spec.Containers[i]
} else if podTemplate.Template.Spec.Containers[i].Name == primaryContainerTemplateName {
primaryContainerTemplate = &podTemplate.Template.Spec.Containers[i]
}
}
err := mergo.Merge(podTemplatePodSpecCopy, podSpec, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}

for _, container := range podSpec.Containers {
// if applicable start with defaultContainerTemplate
var mergedContainer *v1.Container
if defaultContainerTemplate != nil {
mergedContainer = defaultContainerTemplate.DeepCopy()
}
// merge template Containers
var mergedContainers []v1.Container
var defaultContainerTemplate, primaryContainerTemplate *v1.Container
for i := 0; i < len(podTemplatePodSpecCopy.Containers); i++ {
if podTemplatePodSpecCopy.Containers[i].Name == defaultContainerTemplateName {
defaultContainerTemplate = &podTemplatePodSpecCopy.Containers[i]
} else if podTemplatePodSpecCopy.Containers[i].Name == primaryContainerTemplateName {
primaryContainerTemplate = &podTemplatePodSpecCopy.Containers[i]
}
}

// if applicable merge with primaryContainerTemplate
if container.Name == primaryContainerName && primaryContainerTemplate != nil {
if mergedContainer == nil {
mergedContainer = primaryContainerTemplate.DeepCopy()
} else {
err := mergo.Merge(mergedContainer, primaryContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}
}
}
for _, container := range podSpec.Containers {
// if applicable start with defaultContainerTemplate
var mergedContainer *v1.Container
if defaultContainerTemplate != nil {
mergedContainer = defaultContainerTemplate.DeepCopy()
}

// if applicable merge with existing container
// if applicable merge with primaryContainerTemplate
if container.Name == primaryContainerName && primaryContainerTemplate != nil {
if mergedContainer == nil {
mergedContainers = append(mergedContainers, container)
mergedContainer = primaryContainerTemplate.DeepCopy()
} else {
err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice)
err := mergo.Merge(mergedContainer, primaryContainerTemplate, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}
}
}

// if applicable merge with existing container # TODO test
if mergedContainer == nil {
mergedContainers = append(mergedContainers, container)

mergedContainers = append(mergedContainers, *mergedContainer)
} else {
err := mergo.Merge(mergedContainer, container, mergo.WithOverride, mergo.WithAppendSlice)
if err != nil {
return nil, err
}

mergedContainers = append(mergedContainers, *mergedContainer)
}

}

// update Pod fields
podTemplatePodSpecCopy.Containers = mergedContainers

return podTemplatePodSpecCopy, nil
}

func BuildPodWithSpec(podTemplate *v1.PodTemplate, podSpec *v1.PodSpec, primaryContainerName string) (*v1.Pod, error) {
pod := v1.Pod{
TypeMeta: v12.TypeMeta{
Kind: PodKind,
APIVersion: v1.SchemeGroupVersion.String(),
},
}

if podTemplate != nil {
// merge template PodSpec
mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName)
if err != nil {
return nil, err
}

// update Pod fields
basePodSpec.Containers = mergedContainers
pod.ObjectMeta = podTemplate.Template.ObjectMeta
pod.Spec = *basePodSpec
pod.Spec = *mergedPodSpec

} else {
pod.Spec = *podSpec
}
Expand All @@ -231,12 +249,15 @@ func BuildIdentityPod() *v1.Pod {
// Important considerations.
// Pending Status in Pod could be for various reasons and sometimes could signal a problem
// Case I: Pending because the Image pull is failing and it is backing off
// This could be transient. So we can actually rely on the failure reason.
// The failure transitions from ErrImagePull -> ImagePullBackoff
//
// This could be transient. So we can actually rely on the failure reason.
// The failure transitions from ErrImagePull -> ImagePullBackoff
//
// Case II: Not enough resources are available. This is tricky. It could be that the total number of
// resources requested is beyond the capability of the system. for this we will rely on configuration
// and hence input gates. We should not allow bad requests that Request for large number of resource through.
// In the case it makes through, we will fail after timeout
//
// resources requested is beyond the capability of the system. for this we will rely on configuration
// and hence input gates. We should not allow bad requests that Request for large number of resource through.
// In the case it makes through, we will fail after timeout
func DemystifyPending(status v1.PodStatus) (pluginsCore.PhaseInfo, error) {
// Search over the difference conditions in the status object. Note that the 'Pending' this function is
// demystifying is the 'phase' of the pod status. This is different than the PodReady condition type also used below
Expand Down
139 changes: 139 additions & 0 deletions go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
pluginsCoreMock "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
Expand Down Expand Up @@ -1013,6 +1014,144 @@ func TestDeterminePrimaryContainerPhase(t *testing.T) {
})
}

func TestMergePodSpecs(t *testing.T) {
var priority int32 = 1

podSpec1, _ := MergePodSpecs(nil, nil, "foo")
assert.Nil(t, podSpec1)

podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "foo")
assert.Nil(t, podSpec2)

podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "foo")
assert.Nil(t, podSpec3)

podSpec := v1.PodSpec{
Containers: []v1.Container{
v1.Container{
Name: "foo",
},
v1.Container{
Name: "bar",
},
},
NodeSelector: map[string]string{
"baz": "bar",
},
Priority: &priority,
SchedulerName: "overrideScheduler",
Tolerations: []v1.Toleration{
v1.Toleration{
Key: "bar",
},
v1.Toleration{
Key: "baz",
},
},
}

defaultContainerTemplate := v1.Container{
Name: defaultContainerTemplateName,
TerminationMessagePath: "/dev/default-termination-log",
}

primaryContainerTemplate := v1.Container{
Name: primaryContainerTemplateName,
TerminationMessagePath: "/dev/primary-termination-log",
}

podTemplateSpec := v1.PodSpec{
Containers: []v1.Container{
defaultContainerTemplate,
primaryContainerTemplate,
},
HostNetwork: true,
NodeSelector: map[string]string{
"foo": "bar",
},
SchedulerName: "defaultScheduler",
Tolerations: []v1.Toleration{
v1.Toleration{
Key: "foo",
},
},
}

mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "foo")
assert.Nil(t, err)

// validate a PodTemplate-only field
assert.Equal(t, podTemplateSpec.HostNetwork, mergedPodSpec.HostNetwork)
// validate a PodSpec-only field
assert.Equal(t, podSpec.Priority, mergedPodSpec.Priority)
// validate an overwritten PodTemplate field
assert.Equal(t, podSpec.SchedulerName, mergedPodSpec.SchedulerName)
// validate a merged map
assert.Equal(t, len(podTemplateSpec.NodeSelector)+len(podSpec.NodeSelector), len(mergedPodSpec.NodeSelector))
// validate an appended array
assert.Equal(t, len(podTemplateSpec.Tolerations)+len(podSpec.Tolerations), len(mergedPodSpec.Tolerations))

// validate primary container
primaryContainer := mergedPodSpec.Containers[0]
assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name)
assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath)

// validate default container
defaultContainer := mergedPodSpec.Containers[1]
assert.Equal(t, podSpec.Containers[1].Name, defaultContainer.Name)
assert.Equal(t, defaultContainerTemplate.TerminationMessagePath, defaultContainer.TerminationMessagePath)

}

func TestBuildPodWithSpec2(t *testing.T) {
podSpec := v1.PodSpec{
Containers: []v1.Container{
v1.Container{
Name: "foo",
},
v1.Container{
Name: "bar",
},
},
}

pod, err := BuildPodWithSpec(nil, &podSpec, "foo")
assert.Nil(t, err)
assert.True(t, reflect.DeepEqual(pod.Spec, podSpec))

primaryContainerTemplate := v1.Container{
Name: primaryContainerTemplateName,
TerminationMessagePath: "/dev/primary-termination-log",
}

podTemplate := v1.PodTemplate{
Template: v1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"fooKey": "barVal",
},
},
Spec: v1.PodSpec{
Containers: []v1.Container{
primaryContainerTemplate,
},
},
},
}

pod, err = BuildPodWithSpec(&podTemplate, &podSpec, "foo")
assert.Nil(t, err)

// Test that template podSpec is merged
primaryContainer := pod.Spec.Containers[0]
assert.Equal(t, podSpec.Containers[0].Name, primaryContainer.Name)
assert.Equal(t, primaryContainerTemplate.TerminationMessagePath, primaryContainer.TerminationMessagePath)

// Test that template object metadata is copied
assert.Contains(t, pod.ObjectMeta.Labels, "fooKey")
assert.Equal(t, pod.ObjectMeta.Labels["fooKey"], "barVal")
}

func TestBuildPodWithSpec(t *testing.T) {
hamersaw marked this conversation as resolved.
Show resolved Hide resolved
var priority int32 = 1
podSpec := v1.PodSpec{
Expand Down
21 changes: 19 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}

common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.MPIJobDefaultContainerName)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hamersaw Do you think adding this line is incorrect?
As opposed to the pytorch and tfjob plugins, the container name is currently not overridden with mpi even though this is the default container name (see here and here).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting. I'm not seeing the GetDefaultContainerName function being used anywhere. I'm wonder if either, it doesn't care about the container name or if the launcher pod automatically updates container names to reflect this.

cc @bimtauer is this something you know anything about? or can review this update?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hamersaw I will try and have a look by the end of the week!


podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace())

objectMeta := metav1.ObjectMeta{}
hamersaw marked this conversation as resolved.
Show resolved Hide resolved

if podTemplate != nil {
mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.MPIJobDefaultContainerName)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error())
}
podSpec = mergedPodSpec
objectMeta = podTemplate.Template.ObjectMeta
}

// workersPodSpec is deepCopy of podSpec submitted by flyte
// WorkerPodSpec doesn't need any Argument & command. It will be trigger from launcher pod
workersPodSpec := podSpec.DeepCopy()
Expand All @@ -89,14 +104,16 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu
kubeflowv1.MPIJobReplicaTypeLauncher: {
Replicas: &launcherReplicas,
Template: v1.PodTemplateSpec{
Spec: *podSpec,
ObjectMeta: objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonKf.RestartPolicyNever,
},
kubeflowv1.MPIJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
Spec: *workersPodSpec,
ObjectMeta: objectMeta,
Spec: *workersPodSpec,
},
RestartPolicy: commonKf.RestartPolicyNever,
},
Expand Down
19 changes: 17 additions & 2 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,35 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx

common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName)
hamersaw marked this conversation as resolved.
Show resolved Hide resolved

podTemplate := flytek8s.DefaultPodTemplateStore.LoadOrDefault(taskCtx.TaskExecutionMetadata().GetNamespace())

objectMeta := metav1.ObjectMeta{}
hamersaw marked this conversation as resolved.
Show resolved Hide resolved

if podTemplate != nil {
mergedPodSpec, err := flytek8s.MergePodSpecs(&podTemplate.Template.Spec, podSpec, kubeflowv1.PytorchJobDefaultContainerName)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to merge default pod template: [%v]", err.Error())
}
podSpec = mergedPodSpec
objectMeta = podTemplate.Template.ObjectMeta
}

workers := pytorchTaskExtraArgs.GetWorkers()

jobSpec := kubeflowv1.PyTorchJobSpec{
PyTorchReplicaSpecs: map[commonOp.ReplicaType]*commonOp.ReplicaSpec{
kubeflowv1.PyTorchJobReplicaTypeMaster: {
Template: v1.PodTemplateSpec{
Spec: *podSpec,
ObjectMeta: objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
kubeflowv1.PyTorchJobReplicaTypeWorker: {
Replicas: &workers,
Template: v1.PodTemplateSpec{
Spec: *podSpec,
ObjectMeta: objectMeta,
Spec: *podSpec,
},
RestartPolicy: commonOp.RestartPolicyNever,
},
Expand Down
Loading