Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Change kubeflow plugins to allow settings specs for different replica #345

Merged
merged 11 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v1.3.19
github.com/flyteorg/flyteidl v1.5.2
github.com/flyteorg/flytestdlib v1.0.15
github.com/go-test/deep v1.0.7
github.com/golang/protobuf v1.5.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/flyteorg/flyteidl v1.3.19 h1:i79Dh7UoP8Z4LEJ2ox6jlfZVJtFZ+r4g84CJj1gh22Y=
github.com/flyteorg/flyteidl v1.3.19/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM=
github.com/flyteorg/flyteidl v1.5.2 h1:DZPzYkTg92qA4e17fd0ZW1M+gh1gJKh/VOK+F4bYgM8=
github.com/flyteorg/flyteidl v1.5.2/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I=
github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=
github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s=
github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk=
Expand Down
77 changes: 77 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"sort"
"time"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/tasklog"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
kfplugins "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins/kubeflow"
flyteerr "github.com/flyteorg/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyteplugins/go/tasks/logs"
pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
Expand All @@ -21,6 +23,12 @@ const (
PytorchTaskType = "pytorch"
)

// ReplicaEntry bundles a pod spec, a replica count and a restart policy.
// NOTE(review): presumably one entry describes a single kubeflow replica group
// (e.g. worker or launcher) — confirm against the plugins that build it.
type ReplicaEntry struct {
// PodSpec is the pod spec associated with this entry.
PodSpec *v1.PodSpec
// ReplicaNum is the replica count associated with this entry.
ReplicaNum int32
// RestartPolicy is the kubeflow common operator restart policy associated with this entry.
RestartPolicy commonOp.RestartPolicy
}

// ExtractMPICurrentCondition will return the first job condition for MPI
func ExtractMPICurrentCondition(jobConditions []commonOp.JobCondition) (commonOp.JobCondition, error) {
if jobConditions != nil {
Expand Down Expand Up @@ -180,3 +188,72 @@ func OverridePrimaryContainerName(podSpec *v1.PodSpec, primaryContainerName stri
}
}
}

// ParseRunPolicy translates a Flyte kubeflow plugin RunPolicy into a kubeflow common operator RunPolicy.
//
// CleanPodPolicy is always populated with the mapped clean pod policy. BackoffLimit,
// ActiveDeadlineSeconds and TTLSecondsAfterFinished are only populated when the corresponding
// Flyte value is non-zero; otherwise they are left nil.
func ParseRunPolicy(flyteRunPolicy kfplugins.RunPolicy) commonOp.RunPolicy {
	cleanPolicy := ParseCleanPodPolicy(flyteRunPolicy.GetCleanPodPolicy())
	result := commonOp.RunPolicy{CleanPodPolicy: &cleanPolicy}

	if limit := flyteRunPolicy.GetBackoffLimit(); limit != 0 {
		result.BackoffLimit = &limit
	}
	// The operator field is an int64 pointer, so the Flyte value is widened before its address is taken.
	if deadline := int64(flyteRunPolicy.GetActiveDeadlineSeconds()); deadline != 0 {
		result.ActiveDeadlineSeconds = &deadline
	}
	if ttlSeconds := flyteRunPolicy.GetTtlSecondsAfterFinished(); ttlSeconds != 0 {
		result.TTLSecondsAfterFinished = &ttlSeconds
	}

	return result
}

// ParseCleanPodPolicy maps a Flyte kubeflow plugin CleanPodPolicy onto the kubeflow common operator
// CleanPodPolicy. A value without a mapping yields the zero value of commonOp.CleanPodPolicy.
func ParseCleanPodPolicy(flyteCleanPodPolicy kfplugins.CleanPodPolicy) commonOp.CleanPodPolicy {
	switch flyteCleanPodPolicy {
	case kfplugins.CleanPodPolicy_CLEANPOD_POLICY_NONE:
		return commonOp.CleanPodPolicyNone
	case kfplugins.CleanPodPolicy_CLEANPOD_POLICY_ALL:
		return commonOp.CleanPodPolicyAll
	case kfplugins.CleanPodPolicy_CLEANPOD_POLICY_RUNNING:
		return commonOp.CleanPodPolicyRunning
	}
	// No mapping: hand back the zero value.
	var unmapped commonOp.CleanPodPolicy
	return unmapped
}

// ParseRestartPolicy maps a Flyte kubeflow plugin RestartPolicy onto the kubeflow common operator
// RestartPolicy. A value without a mapping yields the zero value of commonOp.RestartPolicy.
func ParseRestartPolicy(flyteRestartPolicy kfplugins.RestartPolicy) commonOp.RestartPolicy {
	switch flyteRestartPolicy {
	case kfplugins.RestartPolicy_RESTART_POLICY_NEVER:
		return commonOp.RestartPolicyNever
	case kfplugins.RestartPolicy_RESTART_POLICY_ON_FAILURE:
		return commonOp.RestartPolicyOnFailure
	case kfplugins.RestartPolicy_RESTART_POLICY_ALWAYS:
		return commonOp.RestartPolicyAlways
	}
	// No mapping: hand back the zero value.
	var unmapped commonOp.RestartPolicy
	return unmapped
}

// OverrideContainerSpec overrides the specified container's properties in the given podSpec. The function
// updates the image, resources and command arguments of the container that matches the given containerName.
func OverrideContainerSpec(podSpec *v1.PodSpec, containerName string, image string, resources *core.Resources, args []string) (*v1.PodSpec, error) {
for idx, c := range podSpec.Containers {
if c.Name == containerName {
if image != "" {
podSpec.Containers[idx].Image = image
}
if resources != nil {
// if resources requests and limits both not set, we will not override the resources
if len(resources.Requests) >= 1 || len(resources.Limits) >= 1 {
resources, err := flytek8s.ToK8sResourceRequirements(resources)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecificat ion on Resources [%v], Err: [%v]", resources, err.Error())
}
podSpec.Containers[idx].Resources = *resources
hamersaw marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
podSpec.Containers[idx].Resources = v1.ResourceRequirements{}
}
if len(args) != 0 {
podSpec.Containers[idx].Args = args
}
}
}
return podSpec, nil
}
101 changes: 101 additions & 0 deletions go/tasks/plugins/k8s/kfoperators/common/common_operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ import (
"testing"
"time"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyteplugins/go/tasks/logs"

pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
commonOp "github.com/kubeflow/common/pkg/apis/common/v1"
"github.com/stretchr/testify/assert"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)

func TestExtractMPICurrentCondition(t *testing.T) {
Expand Down Expand Up @@ -183,3 +186,101 @@ func TestGetLogs(t *testing.T) {
assert.Equal(t, fmt.Sprintf("k8s.com/#!/log/%s/%s-chiefReplica-0/pod?namespace=tensorflow-namespace", "tensorflow-namespace", "test"), jobLogs[2].Uri)

}

// dummyPodSpec returns a freshly built pod spec fixture with two containers ("primary container"
// and "secondary container"), one volume and one toleration.
func dummyPodSpec() v1.PodSpec {
	primary := v1.Container{
		Name: "primary container",
		Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
		Resources: v1.ResourceRequirements{
			Limits: v1.ResourceList{
				"cpu":    resource.MustParse("2"),
				"memory": resource.MustParse("200Mi"),
				"gpu":    resource.MustParse("1"),
			},
			Requests: v1.ResourceList{
				"cpu":    resource.MustParse("1"),
				"memory": resource.MustParse("100Mi"),
				"gpu":    resource.MustParse("1"),
			},
		},
		VolumeMounts: []v1.VolumeMount{
			{Name: "volume mount"},
		},
	}

	secondary := v1.Container{
		Name: "secondary container",
		Resources: v1.ResourceRequirements{
			Limits:   v1.ResourceList{"gpu": resource.MustParse("2")},
			Requests: v1.ResourceList{"gpu": resource.MustParse("2")},
		},
	}

	return v1.PodSpec{
		Containers: []v1.Container{primary, secondary},
		Volumes: []v1.Volume{
			{Name: "dshm"},
		},
		Tolerations: []v1.Toleration{
			{Key: "my toleration key", Value: "my toleration value"},
		},
	}
}

// TestOverrideContainerSpec verifies that image, resources and args overrides are all applied to the
// container matching the given name, that the pod spec is modified in place and returned, and that
// the non-matching container is left alone.
func TestOverrideContainerSpec(t *testing.T) {
	podSpec := dummyPodSpec()
	updated, err := OverrideContainerSpec(
		&podSpec, "primary container", "testing-image",
		&core.Resources{
			Requests: []*core.Resources_ResourceEntry{
				{Name: core.Resources_CPU, Value: "250m"},
			},
			Limits: []*core.Resources_ResourceEntry{
				{Name: core.Resources_CPU, Value: "500m"},
			},
		},
		[]string{"python", "-m", "run.py"},
	)
	assert.NoError(t, err)
	// the pod spec is modified in place and the same pointer is handed back
	assert.True(t, updated == &podSpec)
	assert.Equal(t, 2, len(podSpec.Containers))
	assert.Equal(t, "testing-image", podSpec.Containers[0].Image)
	assert.NotNil(t, podSpec.Containers[0].Resources.Limits)
	assert.NotNil(t, podSpec.Containers[0].Resources.Requests)
	// verify the resources were replaced by the supplied overrides
	assert.True(t, podSpec.Containers[0].Resources.Requests.Cpu().Equal(resource.MustParse("250m")))
	assert.True(t, podSpec.Containers[0].Resources.Limits.Cpu().Equal(resource.MustParse("500m")))
	assert.Equal(t, []string{"python", "-m", "run.py"}, podSpec.Containers[0].Args)
	// verify the container with a different name was not touched
	assert.Equal(t, "", podSpec.Containers[1].Image)
	assert.Empty(t, podSpec.Containers[1].Args)
}

// TestOverrideContainerSpecEmptyFields verifies that an empty image, an empty (non-nil) resources
// message and empty args leave the matching container's resources as they were in the fixture.
func TestOverrideContainerSpecEmptyFields(t *testing.T) {
	podSpec := dummyPodSpec()
	_, err := OverrideContainerSpec(&podSpec, "primary container", "", &core.Resources{}, []string{})
	assert.NoError(t, err)
	assert.Equal(t, 2, len(podSpec.Containers))

	primary := &podSpec.Containers[0]
	requests, limits := primary.Resources.Requests, primary.Resources.Limits
	assert.NotNil(t, limits)
	assert.NotNil(t, requests)
	// verify resources are not overridden when the supplied resources are empty
	assert.True(t, requests.Cpu().Equal(resource.MustParse("1")))
	assert.True(t, requests.Memory().Equal(resource.MustParse("100Mi")))
	assert.True(t, limits.Cpu().Equal(resource.MustParse("2")))
	assert.True(t, limits.Memory().Equal(resource.MustParse("200Mi")))
}

// TestOverrideContainerNilResources verifies that passing nil resources clears both the limits and
// the requests of the matching container.
func TestOverrideContainerNilResources(t *testing.T) {
	podSpec := dummyPodSpec()
	_, err := OverrideContainerSpec(&podSpec, "primary container", "", nil, []string{})
	assert.NoError(t, err)
	assert.Equal(t, 2, len(podSpec.Containers))

	primaryResources := podSpec.Containers[0].Resources
	assert.Nil(t, primaryResources.Limits)
	assert.Nil(t, primaryResources.Requests)
}
Loading