Skip to content

Commit

Permalink
TAS: Support Pod. (kubernetes-sigs#3402)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbobrovskyi authored and kannon92 committed Nov 19, 2024
1 parent 9a2f93c commit 6d90178
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ const (

DefaultPendingWorkloadsLimit = 1000

// Label that signalize that an object is managed by Kueue
// ManagedByKueueLabel is the label that signals that an object is managed by Kueue.
ManagedByKueueLabel = "kueue.x-k8s.io/managed"
)
8 changes: 4 additions & 4 deletions pkg/controller/jobframework/tas.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ limitations under the License.
package jobframework

import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"

kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
)

func PodSetTopologyRequest(template *corev1.PodTemplateSpec) *kueue.PodSetTopologyRequest {
requiredValue, requiredFound := template.Annotations[kueuealpha.PodSetRequiredTopologyAnnotation]
func PodSetTopologyRequest(meta *metav1.ObjectMeta) *kueue.PodSetTopologyRequest {
requiredValue, requiredFound := meta.Annotations[kueuealpha.PodSetRequiredTopologyAnnotation]
if requiredFound {
return &kueue.PodSetTopologyRequest{
Required: ptr.To(requiredValue),
}
}
preferredValue, preferredFound := template.Annotations[kueuealpha.PodSetPreferredTopologyAnnotation]
preferredValue, preferredFound := meta.Annotations[kueuealpha.PodSetPreferredTopologyAnnotation]
if preferredFound {
return &kueue.PodSetTopologyRequest{
Preferred: ptr.To(preferredValue),
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/jobs/job/job_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func (j *Job) PodSets() []kueue.PodSet {
Template: *cleanManagedLabels(j.Spec.Template.DeepCopy()),
Count: j.podsCount(),
MinCount: j.minPodsCount(),
TopologyRequest: jobframework.PodSetTopologyRequest(&j.Spec.Template),
TopologyRequest: jobframework.PodSetTopologyRequest(&j.Spec.Template.ObjectMeta),
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/jobs/jobset/jobset_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (j *JobSet) PodSets() []kueue.PodSet {
Name: replicatedJob.Name,
Template: *replicatedJob.Template.Spec.Template.DeepCopy(),
Count: podsCount(&replicatedJob),
TopologyRequest: jobframework.PodSetTopologyRequest(&replicatedJob.Template.Spec.Template),
TopologyRequest: jobframework.PodSetTopologyRequest(&replicatedJob.Template.Spec.Template.ObjectMeta),
}
}
return podSets
Expand Down
1 change: 1 addition & 0 deletions pkg/controller/jobs/pod/pod_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@ func (p *Pod) PodSets() []kueue.PodSet {
Template: corev1.PodTemplateSpec{
Spec: *p.pod.Spec.DeepCopy(),
},
TopologyRequest: jobframework.PodSetTopologyRequest(&p.pod.ObjectMeta),
},
}
}
Expand Down
60 changes: 60 additions & 0 deletions pkg/controller/jobs/pod/pod_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client/interceptor"
"sigs.k8s.io/controller-runtime/pkg/reconcile"

kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/constants"
controllerconsts "sigs.k8s.io/kueue/pkg/controller/constants"
Expand Down Expand Up @@ -126,6 +127,65 @@ func TestRun(t *testing.T) {
}
}

// TestPodSets verifies that Pod.PodSets produces a single pod set for a plain
// pod and that the topology annotations set on the pod's metadata are
// reflected in the pod set's TopologyRequest.
func TestPodSets(t *testing.T) {
	testCases := map[string]struct {
		// pod is the object under test.
		pod *Pod
		// wantPodSets builds the expectation from the pod itself, because the
		// expected template carries a copy of that pod's spec.
		wantPodSets func(pod *Pod) []kueue.PodSet
	}{
		"no annotations": {
			pod: FromObject(testingpod.MakePod("pod", "ns").Obj()),
			wantPodSets: func(pod *Pod) []kueue.PodSet {
				// Without topology annotations no TopologyRequest is expected.
				return []kueue.PodSet{
					{
						Name:     kueue.DefaultPodSetName,
						Count:    1,
						Template: corev1.PodTemplateSpec{Spec: *pod.pod.Spec.DeepCopy()},
					},
				}
			},
		},
		"with required topology annotation": {
			pod: FromObject(testingpod.MakePod("pod", "ns").
				Annotation(kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block").
				Obj(),
			),
			wantPodSets: func(pod *Pod) []kueue.PodSet {
				return []kueue.PodSet{
					{
						Name:            kueue.DefaultPodSetName,
						Count:           1,
						Template:        corev1.PodTemplateSpec{Spec: *pod.pod.Spec.DeepCopy()},
						TopologyRequest: &kueue.PodSetTopologyRequest{Required: ptr.To("cloud.com/block")},
					},
				}
			},
		},
		// This case sets the preferred (not the required) annotation, so the
		// name mirrors the "required" case above.
		"with preferred topology annotation": {
			pod: FromObject(testingpod.MakePod("pod", "ns").
				Annotation(kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block").
				Obj(),
			),
			wantPodSets: func(pod *Pod) []kueue.PodSet {
				return []kueue.PodSet{
					{
						Name:            kueue.DefaultPodSetName,
						Count:           1,
						Template:        corev1.PodTemplateSpec{Spec: *pod.pod.Spec.DeepCopy()},
						TopologyRequest: &kueue.PodSetTopologyRequest{Preferred: ptr.To("cloud.com/block")},
					},
				}
			},
		},
	}
	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			if diff := cmp.Diff(tc.wantPodSets(tc.pod), tc.pod.PodSets()); diff != "" {
				t.Errorf("pod sets mismatch (-want +got):\n%s", diff)
			}
		})
	}
}

var (
podCmpOpts = []cmp.Option{
cmpopts.EquateEmpty(),
Expand Down
29 changes: 18 additions & 11 deletions pkg/controller/jobs/pod/pod_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ const (
)

var (
labelsPath = field.NewPath("metadata", "labels")
annotationsPath = field.NewPath("metadata", "annotations")
metaPath = field.NewPath("metadata")
labelsPath = metaPath.Child("labels")
annotationsPath = metaPath.Child("annotations")
managedLabelPath = labelsPath.Key(ManagedLabelKey)
groupNameLabelPath = labelsPath.Key(GroupNameLabel)
groupTotalCountAnnotationPath = annotationsPath.Key(GroupTotalCountAnnotation)
Expand Down Expand Up @@ -203,11 +204,9 @@ func (w *PodWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (ad
pod := FromObject(obj)
log := ctrl.LoggerFrom(ctx).WithName("pod-webhook").WithValues("pod", klog.KObj(&pod.pod))
log.V(5).Info("Validating create")
allErrs := jobframework.ValidateJobOnCreate(pod)

allErrs = append(allErrs, validateManagedLabel(pod)...)

allErrs = append(allErrs, validatePodGroupMetadata(pod)...)
allErrs := jobframework.ValidateJobOnCreate(pod)
allErrs = append(allErrs, validateCommon(pod)...)

if warn := warningForPodManagedLabel(pod); warn != "" {
warnings = append(warnings, warn)
Expand All @@ -223,14 +222,11 @@ func (w *PodWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.
newPod := FromObject(newObj)
log := ctrl.LoggerFrom(ctx).WithName("pod-webhook").WithValues("pod", klog.KObj(&newPod.pod))
log.V(5).Info("Validating update")
allErrs := jobframework.ValidateJobOnUpdate(oldPod, newPod)

allErrs = append(allErrs, validateManagedLabel(newPod)...)
allErrs := jobframework.ValidateJobOnUpdate(oldPod, newPod)
allErrs = append(allErrs, validateCommon(newPod)...)

allErrs = append(allErrs, validation.ValidateImmutableField(podGroupName(newPod.pod), podGroupName(oldPod.pod), groupNameLabelPath)...)

allErrs = append(allErrs, validatePodGroupMetadata(newPod)...)

allErrs = append(allErrs, validateUpdateForRetriableInGroupAnnotation(oldPod, newPod)...)

if warn := warningForPodManagedLabel(newPod); warn != "" {
Expand All @@ -244,6 +240,13 @@ func (w *PodWebhook) ValidateDelete(context.Context, runtime.Object) (admission.
return nil, nil
}

// validateCommon collects the validation errors that are checked on both
// create and update: the managed label, the pod group metadata and the
// topology request. Errors are returned in that order.
func validateCommon(pod *Pod) field.ErrorList {
	errs := validateManagedLabel(pod)
	followUps := []func(*Pod) field.ErrorList{
		validatePodGroupMetadata,
		validateTopologyRequest,
	}
	for _, validate := range followUps {
		errs = append(errs, validate(pod)...)
	}
	return errs
}

func validateManagedLabel(pod *Pod) field.ErrorList {
var allErrs field.ErrorList

Expand Down Expand Up @@ -298,6 +301,10 @@ func validatePodGroupMetadata(p *Pod) field.ErrorList {
return allErrs
}

// validateTopologyRequest checks the pod's own metadata by delegating to
// jobframework.ValidateTASPodSetRequest, with reported field paths rooted at
// "metadata".
func validateTopologyRequest(pod *Pod) field.ErrorList {
	podMeta := &pod.pod.ObjectMeta
	return jobframework.ValidateTASPodSetRequest(metaPath, podMeta)
}

func validateUpdateForRetriableInGroupAnnotation(oldPod, newPod *Pod) field.ErrorList {
if podGroupName(newPod.pod) != "" && isUnretriablePod(oldPod.pod) && !isUnretriablePod(newPod.pod) {
return field.ErrorList{
Expand Down
20 changes: 20 additions & 0 deletions pkg/controller/jobs/pod/pod_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

configapi "sigs.k8s.io/kueue/apis/config/v1beta1"
kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1"
"sigs.k8s.io/kueue/pkg/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
Expand Down Expand Up @@ -480,6 +481,25 @@ func TestValidateCreate(t *testing.T) {
},
}.ToAggregate(),
},
"valid topology request": {
pod: testingpod.MakePod("test-pod", "test-ns").
Label(constants.ManagedByKueueLabel, "true").
Annotation(kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block").
Obj(),
},
"invalid topology request": {
pod: testingpod.MakePod("test-pod", "test-ns").
Label(constants.ManagedByKueueLabel, "true").
Annotation(kueuealpha.PodSetRequiredTopologyAnnotation, "cloud.com/block").
Annotation(kueuealpha.PodSetPreferredTopologyAnnotation, "cloud.com/block").
Obj(),
wantErr: field.ErrorList{
&field.Error{
Type: field.ErrorTypeInvalid,
Field: "metadata.annotations",
},
}.ToAggregate(),
},
}

for name, tc := range testCases {
Expand Down
4 changes: 2 additions & 2 deletions pkg/controller/jobs/rayjob/rayjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (j *RayJob) PodSets() []kueue.PodSet {
Name: headGroupPodSetName,
Template: *j.Spec.RayClusterSpec.HeadGroupSpec.Template.DeepCopy(),
Count: 1,
TopologyRequest: jobframework.PodSetTopologyRequest(&j.Spec.RayClusterSpec.HeadGroupSpec.Template),
TopologyRequest: jobframework.PodSetTopologyRequest(&j.Spec.RayClusterSpec.HeadGroupSpec.Template.ObjectMeta),
}

// workers
Expand All @@ -125,7 +125,7 @@ func (j *RayJob) PodSets() []kueue.PodSet {
Name: strings.ToLower(wgs.GroupName),
Template: *wgs.Template.DeepCopy(),
Count: replicas,
TopologyRequest: jobframework.PodSetTopologyRequest(&wgs.Template),
TopologyRequest: jobframework.PodSetTopologyRequest(&wgs.Template.ObjectMeta),
}
}
return podSets
Expand Down

0 comments on commit 6d90178

Please sign in to comment.