Skip to content

Commit

Permalink
TAS: Support Kubeflow.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbobrovskyi committed Nov 4, 2024
1 parent 6047afe commit 1b3f079
Show file tree
Hide file tree
Showing 15 changed files with 434 additions and 19 deletions.
14 changes: 8 additions & 6 deletions pkg/controller/jobframework/base_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"

"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
Expand Down Expand Up @@ -67,11 +66,11 @@ func (w *BaseWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (a
job := w.FromObject(obj)
log := ctrl.LoggerFrom(ctx)
log.V(5).Info("Validating create", "job", klog.KObj(job.Object()))
return nil, validateCreate(job).ToAggregate()
}

func validateCreate(job GenericJob) field.ErrorList {
return ValidateJobOnCreate(job)
allErrs := ValidateJobOnCreate(job)
if jobWithValidation, ok := job.(JobWithValidation); ok {
allErrs = append(allErrs, jobWithValidation.ValidateOnCreate()...)
}
return nil, allErrs.ToAggregate()
}

// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type
Expand All @@ -81,6 +80,9 @@ func (w *BaseWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime
log := ctrl.LoggerFrom(ctx).WithName("mxjob-webhook")
log.Info("Validating update", "job", klog.KObj(newJob.Object()))
allErrs := ValidateJobOnUpdate(oldJob, newJob)
if jobWithValidation, ok := newJob.(JobWithValidation); ok {
allErrs = append(allErrs, jobWithValidation.ValidateOnUpdate(oldJob)...)
}
return nil, allErrs.ToAggregate()
}

Expand Down
10 changes: 10 additions & 0 deletions pkg/controller/jobframework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/client-go/tools/record"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -107,6 +108,15 @@ type JobWithPriorityClass interface {
PriorityClass() string
}

// JobWithValidation is an optional interface that allows custom webhook
// validation for Jobs that use BaseWebhook. BaseWebhook detects it with a type
// assertion on the job and appends the returned errors to the errors produced
// by its own create/update validation.
type JobWithValidation interface {
// ValidateOnCreate returns the list of webhook create validation errors.
ValidateOnCreate() field.ErrorList
// ValidateOnUpdate returns the list of webhook update validation errors.
// It is called on the new job; oldJob is the job as it was before the update.
ValidateOnUpdate(oldJob GenericJob) field.ErrorList
}

// ComposableJob interface should be implemented by generic jobs that
// are composed out of multiple API objects.
type ComposableJob interface {
Expand Down
4 changes: 4 additions & 0 deletions pkg/controller/jobs/kubeflow/jobs/mxjob/mxjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.MXReplicaSpecs
}

// ReplicaSpecsFieldName returns "mxReplicaSpecs", the name of the spec field
// that holds the replica specs. It is used to build field paths for
// validation errors.
func (j *JobControl) ReplicaSpecsFieldName() string {
return "mxReplicaSpecs"
}

// JobStatus returns a pointer to the job's Status field.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
return &j.Status
}
Expand Down
112 changes: 112 additions & 0 deletions pkg/controller/jobs/kubeflow/jobs/mxjob/mxjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@ limitations under the License.
package mxjob

import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/client-go/tools/record"
"k8s.io/utils/ptr"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/reconcile"

kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
controllerconsts "sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
Expand Down Expand Up @@ -295,6 +298,115 @@ func TestOrderedReplicaTypes(t *testing.T) {
}
}

// TestPodSets verifies that PodSets produces one pod set per MXJob replica
// type (Scheduler, Server, Worker) and that topology annotations set on a
// replica's pod template are reflected in that pod set's TopologyRequest.
func TestPodSets(t *testing.T) {
	// expectedPodSet builds the pod set expected for a single replica type.
	// A nil topologyRequest means no topology request is expected.
	expectedPodSet := func(
		job *kftraining.MXJob,
		replicaType kftraining.ReplicaType,
		topologyRequest *kueue.PodSetTopologyRequest,
	) kueue.PodSet {
		return kueue.PodSet{
			Name:            strings.ToLower(string(replicaType)),
			Template:        job.Spec.MXReplicaSpecs[replicaType].Template,
			Count:           1,
			TopologyRequest: topologyRequest,
		}
	}

	cases := map[string]struct {
		job         *kftraining.MXJob
		wantPodSets func(job *kftraining.MXJob) []kueue.PodSet
	}{
		"no annotations": {
			job: testingmxjob.MakeMXJob("mxjob", "ns").Obj(),
			wantPodSets: func(job *kftraining.MXJob) []kueue.PodSet {
				return []kueue.PodSet{
					expectedPodSet(job, kftraining.MXJobReplicaTypeScheduler, nil),
					expectedPodSet(job, kftraining.MXJobReplicaTypeServer, nil),
					expectedPodSet(job, kftraining.MXJobReplicaTypeWorker, nil),
				}
			},
		},
		"with required and preferred topology annotation": {
			job: testingmxjob.MakeMXJob("mxjob", "ns").
				PodAnnotation(kftraining.MXJobReplicaTypeScheduler, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/rack").
				PodAnnotation(kftraining.MXJobReplicaTypeServer, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block").
				Obj(),
			wantPodSets: func(job *kftraining.MXJob) []kueue.PodSet {
				return []kueue.PodSet{
					expectedPodSet(job, kftraining.MXJobReplicaTypeScheduler,
						&kueue.PodSetTopologyRequest{Required: ptr.To("cloud.com/rack")}),
					expectedPodSet(job, kftraining.MXJobReplicaTypeServer,
						&kueue.PodSetTopologyRequest{Preferred: ptr.To("cloud.com/block")}),
					expectedPodSet(job, kftraining.MXJobReplicaTypeWorker, nil),
				}
			},
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			got := fromObject(tc.job).PodSets()
			want := tc.wantPodSets(tc.job)
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("pod sets mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

// TestValidate checks the webhook validation hooks of the MXJob wrapper.
// Each case is run through both ValidateOnCreate and ValidateOnUpdate and
// must yield the same expected error list.
func TestValidate(t *testing.T) {
	// Path under which errors for the Scheduler replica's pod template
	// annotations are expected to be reported.
	annotationsPath := field.NewPath("spec", "mxReplicaSpecs").
		Key("Scheduler").
		Child("template", "metadata", "annotations")

	testCases := map[string]struct {
		job      *kftraining.MXJob
		wantErrs field.ErrorList
	}{
		"no annotations": {
			job:      testingmxjob.MakeMXJob("mxjob", "ns").Obj(),
			wantErrs: nil,
		},
		"valid TAS request": {
			// Required and preferred annotations are on different replica types.
			job: testingmxjob.MakeMXJob("mxjob", "ns").
				PodAnnotation(kftraining.MXJobReplicaTypeScheduler, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/rack").
				PodAnnotation(kftraining.MXJobReplicaTypeServer, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block").
				Obj(),
			wantErrs: nil,
		},
		"invalid TAS request": {
			// Both annotations are set on the same (Scheduler) pod template.
			job: testingmxjob.MakeMXJob("mxjob", "ns").
				PodAnnotation(kftraining.MXJobReplicaTypeScheduler, kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/rack").
				PodAnnotation(kftraining.MXJobReplicaTypeScheduler, kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block").
				Obj(),
			wantErrs: field.ErrorList{
				field.Invalid(annotationsPath, field.OmitValueType{},
					`must not contain both "kueue.x-k8s.io/podset-required-topology" and "kueue.x-k8s.io/podset-preferred-topology"`),
			},
		},
	}
	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			if diff := cmp.Diff(tc.wantErrs, fromObject(tc.job).ValidateOnCreate()); diff != "" {
				t.Errorf("validate create error list mismatch (-want +got):\n%s", diff)
			}

			// nil is passed as the old job for the update check.
			if diff := cmp.Diff(tc.wantErrs, fromObject(tc.job).ValidateOnUpdate(nil)); diff != "" {
				t.Errorf("validate update error list mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

var (
jobCmpOpts = cmp.Options{
cmpopts.EquateEmpty(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.PaddleReplicaSpecs
}

// ReplicaSpecsFieldName returns "paddleReplicaSpecs", the name of the spec
// field that holds the replica specs.
func (j *JobControl) ReplicaSpecsFieldName() string {
return "paddleReplicaSpecs"
}

// JobStatus returns a pointer to the job's Status field.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
return &j.Status
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.PyTorchReplicaSpecs
}

// ReplicaSpecsFieldName returns "pytorchReplicaSpecs", the name of the spec
// field that holds the replica specs.
func (j *JobControl) ReplicaSpecsFieldName() string {
return "pytorchReplicaSpecs"
}

// JobStatus returns a pointer to the job's Status field.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
return &j.Status
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.TFReplicaSpecs
}

// ReplicaSpecsFieldName returns "tfReplicaSpecs", the name of the spec field
// that holds the replica specs.
func (j *JobControl) ReplicaSpecsFieldName() string {
return "tfReplicaSpecs"
}

// JobStatus returns a pointer to the job's Status field.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
return &j.Status
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.XGBReplicaSpecs
}

// ReplicaSpecsFieldName returns "xgbReplicaSpecs", the name of the spec field
// that holds the replica specs.
func (j *JobControl) ReplicaSpecsFieldName() string {
return "xgbReplicaSpecs"
}

// JobStatus returns a pointer to the job's Status field.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
return &j.Status
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/controller/jobs/kubeflow/kubeflowjob/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type KFJobControl interface {
RunPolicy() *kftraining.RunPolicy
// ReplicaSpecs returns the ReplicaSpecs for the KFJob.
ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.ReplicaSpec
// ReplicaSpecsFieldName returns the field name of the ReplicaSpecs.
ReplicaSpecsFieldName() string
// JobStatus returns the JobStatus for the KFJob.
JobStatus() *kftraining.JobStatus
// OrderedReplicaTypes returns the ordered list of ReplicaTypes for the KFJob.
Expand Down
26 changes: 23 additions & 3 deletions pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"

Expand All @@ -36,6 +37,7 @@ type KubeflowJob struct {

var _ jobframework.GenericJob = (*KubeflowJob)(nil)
var _ jobframework.JobWithPriorityClass = (*KubeflowJob)(nil)
var _ jobframework.JobWithValidation = (*KubeflowJob)(nil)

func (j *KubeflowJob) Object() client.Object {
return j.KFJobControl.Object()
Expand Down Expand Up @@ -99,9 +101,10 @@ func (j *KubeflowJob) PodSets() []kueue.PodSet {
podSets := make([]kueue.PodSet, len(replicaTypes))
for index, replicaType := range replicaTypes {
podSets[index] = kueue.PodSet{
Name: strings.ToLower(string(replicaType)),
Template: *j.KFJobControl.ReplicaSpecs()[replicaType].Template.DeepCopy(),
Count: podsCount(j.KFJobControl.ReplicaSpecs(), replicaType),
Name: strings.ToLower(string(replicaType)),
Template: *j.KFJobControl.ReplicaSpecs()[replicaType].Template.DeepCopy(),
Count: podsCount(j.KFJobControl.ReplicaSpecs(), replicaType),
TopologyRequest: jobframework.PodSetTopologyRequest(&j.KFJobControl.ReplicaSpecs()[replicaType].Template.ObjectMeta),
}
}
return podSets
Expand Down Expand Up @@ -168,6 +171,23 @@ func (j *KubeflowJob) OrderedReplicaTypes() []kftraining.ReplicaType {
return result
}

// ValidateOnCreate implements jobframework.JobWithValidation.
//
// For every replica type returned by OrderedReplicaTypes it runs
// jobframework.ValidateTASPodSetRequest against that replica's pod template
// metadata and returns the combined errors. Error paths are rooted at
// spec.<ReplicaSpecsFieldName()>[<replicaType>].template.metadata.
func (j *KubeflowJob) ValidateOnCreate() field.ErrorList {
	// The replica specs and the base path are identical for every replica
	// type, so they are resolved once instead of on each iteration.
	replicaSpecs := j.KFJobControl.ReplicaSpecs()
	replicaSpecsPath := field.NewPath("spec", j.KFJobControl.ReplicaSpecsFieldName())

	var allErrs field.ErrorList
	for _, replicaType := range j.OrderedReplicaTypes() {
		allErrs = append(allErrs, jobframework.ValidateTASPodSetRequest(
			replicaSpecsPath.Key(string(replicaType)).Child("template", "metadata"),
			&replicaSpecs[replicaType].Template.ObjectMeta,
		)...)
	}
	return allErrs
}

// ValidateOnUpdate implements jobframework.JobWithValidation. The old job is
// ignored: the updated job is validated with the same checks as
// ValidateOnCreate.
func (j *KubeflowJob) ValidateOnUpdate(_ jobframework.GenericJob) field.ErrorList {
return j.ValidateOnCreate()
}

// podsCount returns the number of replicas configured for replicaType,
// falling back to 1 when the Replicas field of its spec is nil.
func podsCount(replicaSpecs map[kftraining.ReplicaType]*kftraining.ReplicaSpec, replicaType kftraining.ReplicaType) int32 {
	const defaultReplicas int32 = 1
	spec := replicaSpecs[replicaType]
	return ptr.Deref(spec.Replicas, defaultReplicas)
}
9 changes: 9 additions & 0 deletions pkg/util/testingjobs/mxjob/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ func (j *MXJobWrapper) Annotations(annotations map[string]string) *MXJobWrapper
return j
}

// PodAnnotation sets the annotation k=v on the pod template of the given
// replica type, creating the annotations map if the template has none yet.
// It returns the wrapper so calls can be chained.
func (j *MXJobWrapper) PodAnnotation(replicaType kftraining.ReplicaType, k, v string) *MXJobWrapper {
	podTemplate := &j.Spec.MXReplicaSpecs[replicaType].Template
	if podTemplate.Annotations == nil {
		podTemplate.Annotations = make(map[string]string)
	}
	podTemplate.Annotations[k] = v
	return j
}

// Request adds a resource request to the default container.
func (j *MXJobWrapper) Request(replicaType kftraining.ReplicaType, r corev1.ResourceName, v string) *MXJobWrapper {
j.Spec.MXReplicaSpecs[replicaType].Template.Spec.Containers[0].Resources.Requests[r] = resource.MustParse(v)
Expand Down
Loading

0 comments on commit 1b3f079

Please sign in to comment.