Skip to content

Commit

Permalink
TAS: Support Kubeflow.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbobrovskyi committed Nov 1, 2024
1 parent 6047afe commit 3ac1ad4
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 32 deletions.
7 changes: 6 additions & 1 deletion pkg/controller/jobs/kubeflow/jobs/mxjob/mxjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package mxjob
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/validation/field"
"strings"

kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
Expand All @@ -37,7 +38,7 @@ var (
gvk = kftraining.SchemeGroupVersion.WithKind(kftraining.MXJobKind)
FrameworkName = "kubeflow.org/mxjob"

SetupMXJobWebhook = jobframework.BaseWebhookFactory(
SetupMXJobWebhook = kubeflowjob.KubeflowJobWebhookFactory(
NewJob(),
func(o runtime.Object) jobframework.GenericJob {
return fromObject(o)
Expand Down Expand Up @@ -109,6 +110,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.MXReplicaSpecs
}

// ReplicaSpecsValidationPath returns the field path of the MXJob replica
// specs, i.e. "spec.mxReplicaSpecs".
func (j *JobControl) ReplicaSpecsValidationPath() *field.Path {
	specPath := field.NewPath("spec")
	return specPath.Child("mxReplicaSpecs")
}

// JobStatus returns a pointer to the Status field of the wrapped MXJob,
// not a copy of it.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
	status := &j.Status
	return status
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package paddlejob
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/validation/field"
"strings"

kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
Expand All @@ -37,7 +38,7 @@ var (
gvk = kftraining.SchemeGroupVersion.WithKind(kftraining.PaddleJobKind)
FrameworkName = "kubeflow.org/paddlejob"

SetupPaddleJobWebhook = jobframework.BaseWebhookFactory(
SetupPaddleJobWebhook = kubeflowjob.KubeflowJobWebhookFactory(
NewJob(),
func(o runtime.Object) jobframework.GenericJob {
return fromObject(o)
Expand Down Expand Up @@ -112,6 +113,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.PaddleReplicaSpecs
}

// ReplicaSpecsValidationPath returns the field path of the PaddleJob replica
// specs, i.e. "spec.paddleReplicaSpecs".
func (j *JobControl) ReplicaSpecsValidationPath() *field.Path {
	specPath := field.NewPath("spec")
	return specPath.Child("paddleReplicaSpecs")
}

// JobStatus returns a pointer to the Status field of the wrapped PaddleJob,
// not a copy of it.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
	status := &j.Status
	return status
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package pytorchjob
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/validation/field"
"strings"

kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
Expand All @@ -37,7 +38,7 @@ var (
gvk = kftraining.SchemeGroupVersion.WithKind(kftraining.PyTorchJobKind)
FrameworkName = "kubeflow.org/pytorchjob"

SetupPyTorchJobWebhook = jobframework.BaseWebhookFactory(
SetupPyTorchJobWebhook = kubeflowjob.KubeflowJobWebhookFactory(
NewJob(),
func(o runtime.Object) jobframework.GenericJob {
return fromObject(o)
Expand Down Expand Up @@ -112,6 +113,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.PyTorchReplicaSpecs
}

// ReplicaSpecsValidationPath returns the field path of the PyTorchJob replica
// specs, i.e. "spec.pytorchReplicaSpecs".
func (j *JobControl) ReplicaSpecsValidationPath() *field.Path {
	specPath := field.NewPath("spec")
	return specPath.Child("pytorchReplicaSpecs")
}

// JobStatus returns a pointer to the Status field of the wrapped PyTorchJob,
// not a copy of it.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
	status := &j.Status
	return status
}
Expand Down
7 changes: 6 additions & 1 deletion pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package tfjob
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/validation/field"
"strings"

kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
Expand All @@ -37,7 +38,7 @@ var (
gvk = kftraining.SchemeGroupVersion.WithKind(kftraining.TFJobKind)
FrameworkName = "kubeflow.org/tfjob"

SetupTFJobWebhook = jobframework.BaseWebhookFactory(
SetupTFJobWebhook = kubeflowjob.KubeflowJobWebhookFactory(
NewJob(),
func(o runtime.Object) jobframework.GenericJob {
return fromObject(o)
Expand Down Expand Up @@ -112,6 +113,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.TFReplicaSpecs
}

// ReplicaSpecsValidationPath returns the field path of the TFJob replica
// specs, i.e. "spec.tfReplicaSpecs".
func (j *JobControl) ReplicaSpecsValidationPath() *field.Path {
	specPath := field.NewPath("spec")
	return specPath.Child("tfReplicaSpecs")
}

// JobStatus returns a pointer to the Status field of the wrapped TFJob,
// not a copy of it.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
	status := &j.Status
	return status
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package xgboostjob
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/validation/field"
"strings"

kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
Expand All @@ -37,7 +38,7 @@ var (
gvk = kftraining.SchemeGroupVersion.WithKind(kftraining.XGBoostJobKind)
FrameworkName = "kubeflow.org/xgboostjob"

SetupXGBoostJobWebhook = jobframework.BaseWebhookFactory(
SetupXGBoostJobWebhook = kubeflowjob.KubeflowJobWebhookFactory(
NewJob(),
func(o runtime.Object) jobframework.GenericJob {
return fromObject(o)
Expand Down Expand Up @@ -112,6 +113,10 @@ func (j *JobControl) ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.Repli
return j.Spec.XGBReplicaSpecs
}

// ReplicaSpecsValidationPath returns the field path of the XGBoostJob replica
// specs, i.e. "spec.xgbReplicaSpecs".
func (j *JobControl) ReplicaSpecsValidationPath() *field.Path {
	specPath := field.NewPath("spec")
	return specPath.Child("xgbReplicaSpecs")
}

// JobStatus returns a pointer to the Status field of the wrapped XGBoostJob,
// not a copy of it.
func (j *JobControl) JobStatus() *kftraining.JobStatus {
	status := &j.Status
	return status
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/controller/jobs/kubeflow/kubeflowjob/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package kubeflowjob
import (
kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
)

Expand All @@ -31,6 +32,8 @@ type KFJobControl interface {
RunPolicy() *kftraining.RunPolicy
// ReplicaSpecs returns the ReplicaSpecs for the KFJob.
ReplicaSpecs() map[kftraining.ReplicaType]*kftraining.ReplicaSpec
// ReplicaSpecsValidationPath returns the field.Path for the ReplicaSpecs.
ReplicaSpecsValidationPath() *field.Path
// JobStatus returns the JobStatus for the KFJob.
JobStatus() *kftraining.JobStatus
// OrderedReplicaTypes returns the ordered list of ReplicaTypes for the KFJob.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ func (j *KubeflowJob) PodSets() []kueue.PodSet {
podSets := make([]kueue.PodSet, len(replicaTypes))
for index, replicaType := range replicaTypes {
podSets[index] = kueue.PodSet{
Name: strings.ToLower(string(replicaType)),
Template: *j.KFJobControl.ReplicaSpecs()[replicaType].Template.DeepCopy(),
Count: podsCount(j.KFJobControl.ReplicaSpecs(), replicaType),
Name: strings.ToLower(string(replicaType)),
Template: *j.KFJobControl.ReplicaSpecs()[replicaType].Template.DeepCopy(),
Count: podsCount(j.KFJobControl.ReplicaSpecs(), replicaType),
TopologyRequest: jobframework.PodSetTopologyRequest(&j.KFJobControl.ReplicaSpecs()[replicaType].Template.ObjectMeta),
}
}
return podSets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package jobframework
package kubeflowjob

import (
"context"
Expand All @@ -23,21 +23,22 @@ import (
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/klog/v2"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/controller/jobframework/webhook"
)

// BaseWebhook applies basic defaulting and validation for jobs.
type BaseWebhook struct {
// KubeflowJobWebhook applies basic defaulting and validation for jobs.
type KubeflowJobWebhook struct {
ManageJobsWithoutQueueName bool
FromObject func(runtime.Object) GenericJob
FromObject func(runtime.Object) jobframework.GenericJob
}

func BaseWebhookFactory(job GenericJob, fromObject func(runtime.Object) GenericJob) func(ctrl.Manager, ...Option) error {
return func(mgr ctrl.Manager, opts ...Option) error {
options := ProcessOptions(opts...)
wh := &BaseWebhook{
func KubeflowJobWebhookFactory(job jobframework.GenericJob, fromObject func(runtime.Object) jobframework.GenericJob) func(ctrl.Manager, ...jobframework.Option) error {
return func(mgr ctrl.Manager, opts ...jobframework.Option) error {
options := jobframework.ProcessOptions(opts...)
wh := &KubeflowJobWebhook{
ManageJobsWithoutQueueName: options.ManageJobsWithoutQueueName,
FromObject: fromObject,
}
Expand All @@ -49,42 +50,55 @@ func BaseWebhookFactory(job GenericJob, fromObject func(runtime.Object) GenericJ
}
}

var _ admission.CustomDefaulter = &BaseWebhook{}
var _ admission.CustomDefaulter = &KubeflowJobWebhook{}

// Default implements webhook.CustomDefaulter so a webhook will be registered for the type
func (w *BaseWebhook) Default(ctx context.Context, obj runtime.Object) error {
func (w *KubeflowJobWebhook) Default(ctx context.Context, obj runtime.Object) error {
job := w.FromObject(obj)
log := ctrl.LoggerFrom(ctx)
log.V(5).Info("Applying defaults", "job", klog.KObj(job.Object()))
ApplyDefaultForSuspend(job, w.ManageJobsWithoutQueueName)
jobframework.ApplyDefaultForSuspend(job, w.ManageJobsWithoutQueueName)
return nil
}

var _ admission.CustomValidator = &BaseWebhook{}
var _ admission.CustomValidator = &KubeflowJobWebhook{}

// ValidateCreate implements webhook.CustomValidator so a webhook will be registered for the type
func (w *BaseWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
func (w *KubeflowJobWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
job := w.FromObject(obj)
log := ctrl.LoggerFrom(ctx)
log.V(5).Info("Validating create", "job", klog.KObj(job.Object()))
return nil, validateCreate(job).ToAggregate()
}

func validateCreate(job GenericJob) field.ErrorList {
return ValidateJobOnCreate(job)
allErrs := jobframework.ValidateJobOnCreate(job)
allErrs = append(allErrs, validateTopologyRequest(job)...)
return nil, allErrs.ToAggregate()
}

// ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type
func (w *BaseWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
func (w *KubeflowJobWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
oldJob := w.FromObject(oldObj)
newJob := w.FromObject(newObj)
log := ctrl.LoggerFrom(ctx).WithName("mxjob-webhook")
log.Info("Validating update", "job", klog.KObj(newJob.Object()))
allErrs := ValidateJobOnUpdate(oldJob, newJob)
allErrs := jobframework.ValidateJobOnUpdate(oldJob, newJob)
allErrs = append(allErrs, validateTopologyRequest(oldJob)...)
return nil, allErrs.ToAggregate()
}

// validateTopologyRequest validates the topology (TAS) request carried by the
// pod template metadata of each replica spec of a Kubeflow job.
//
// Jobs that are not a *KubeflowJob are not validated and yield no errors.
// Replica types are visited in the order given by OrderedReplicaTypes, and
// each replica's errors are reported against that replica's own path,
// <replica specs path>[<replica type>].template.metadata, so the offending
// replica can be identified.
func validateTopologyRequest(job jobframework.GenericJob) field.ErrorList {
	kfjob, ok := job.(*KubeflowJob)
	if !ok {
		return nil
	}

	var allErrs field.ErrorList
	// Neither value depends on the replica type, so fetch them once.
	replicaSpecs := kfjob.KFJobControl.ReplicaSpecs()
	replicaSpecsPath := kfjob.KFJobControl.ReplicaSpecsValidationPath()
	for _, replicaType := range kfjob.OrderedReplicaTypes() {
		replicaSpec := replicaSpecs[replicaType]
		if replicaSpec == nil {
			// A replica type without a spec has nothing to validate;
			// dereferencing it would panic inside the admission webhook.
			continue
		}
		// NOTE(review): assumes ValidateTASPodSetRequest expects the path of
		// the pod template metadata it is given — confirm against jobframework.
		replicaMetaPath := replicaSpecsPath.Key(string(replicaType)).Child("template", "metadata")
		allErrs = append(allErrs, jobframework.ValidateTASPodSetRequest(replicaMetaPath, &replicaSpec.Template.ObjectMeta)...)
	}
	return allErrs
}

// ValidateDelete implements webhook.CustomValidator so a webhook will be registered for the type
func (w *BaseWebhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) {
func (w *KubeflowJobWebhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) {
return nil, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package jobframework_test
package kubeflowjob_test

import (
"context"
"sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -51,7 +52,7 @@ func TestBaseWebhookDefault(t *testing.T) {
}
for name, tc := range testcases {
t.Run(name, func(t *testing.T) {
w := &jobframework.BaseWebhook{
w := &kubeflowjob.KubeflowJobWebhook{
ManageJobsWithoutQueueName: tc.manageJobsWithoutQueueName,
FromObject: toMPIJob,
}
Expand Down

0 comments on commit 3ac1ad4

Please sign in to comment.