Skip to content

Commit

Permalink
PodSet label and Workload annotation for PodTemplates
Browse files Browse the repository at this point in the history
  • Loading branch information
mimowo committed Oct 14, 2024
1 parent f5eaeed commit c028214
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 0 deletions.
10 changes: 10 additions & 0 deletions pkg/controller/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,14 @@ const (

// ProvReqAnnotationPrefix is the prefix for annotations that should be pass to ProvisioningRequest as Parameters.
ProvReqAnnotationPrefix = "provreq.kueue.x-k8s.io/"

// WorkloadAnnotation is an annotation set on the Job's PodTemplate to
// indicate the name of the admitted Workload corresponding to the Job. The
// annotation is set when starting the Job, and removed on stopping the Job.
WorkloadAnnotation = "kueue.x-k8s.io/workload"

// PodSetLabel is a label set on the Job's PodTemplate to indicate the name
// of the PodSet of the admitted Workload corresponding to the PodTemplate.
// The label is set when starting the Job, and removed on stopping the Job.
PodSetLabel = "kueue.x-k8s.io/podset"
)
2 changes: 2 additions & 0 deletions pkg/controller/jobframework/reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,8 @@ func getPodSetsInfoFromStatus(ctx context.Context, c client.Client, w *kueue.Wor
if err != nil {
return nil, err
}
info.Labels[controllerconsts.PodSetLabel] = podSetFlavor.Name
info.Annotations[controllerconsts.WorkloadAnnotation] = w.Name

for _, admissionCheck := range w.Status.AdmissionChecks {
for _, podSetUpdate := range admissionCheck.PodSetUpdates {
Expand Down
8 changes: 8 additions & 0 deletions pkg/controller/jobs/job/job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ func TestReconciler(t *testing.T) {
wantJob: *baseJobWrapper.Clone().
Suspend(false).
PodLabel("ac-key", "ac-value").
PodAnnotation(controllerconsts.WorkloadAnnotation, "wl").
PodLabel(controllerconsts.PodSetLabel, kueue.DefaultPodSetName).
Obj(),
workloads: []kueue.Workload{
*baseWorkloadWrapper.Clone().
Expand Down Expand Up @@ -1341,7 +1343,9 @@ func TestReconciler(t *testing.T) {
Suspend(false).
PodAnnotation("annotation-key1", "common-value").
PodAnnotation("annotation-key2", "only-in-check1").
PodAnnotation(controllerconsts.WorkloadAnnotation, "wl").
PodLabel("label-key1", "common-value").
PodLabel(controllerconsts.PodSetLabel, kueue.DefaultPodSetName).
NodeSelector("node-selector-key1", "common-value").
NodeSelector("node-selector-key2", "only-in-check2").
Obj(),
Expand Down Expand Up @@ -1447,6 +1451,8 @@ func TestReconciler(t *testing.T) {
job: *baseJobWrapper.DeepCopy(),
wantJob: *baseJobWrapper.Clone().
Suspend(false).
PodAnnotation(controllerconsts.WorkloadAnnotation, "wl").
PodLabel(controllerconsts.PodSetLabel, kueue.DefaultPodSetName).
Obj(),
workloads: []kueue.Workload{
*baseWorkloadWrapper.Clone().
Expand Down Expand Up @@ -1528,6 +1534,8 @@ func TestReconciler(t *testing.T) {
Obj(),
wantJob: *baseJobWrapper.Clone().
SetAnnotation(JobMinParallelismAnnotation, "5").
PodAnnotation(controllerconsts.WorkloadAnnotation, "a").
PodLabel(controllerconsts.PodSetLabel, kueue.DefaultPodSetName).
Suspend(false).
Parallelism(8).
Obj(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,12 @@ func TestReconciler(t *testing.T) {
wantJob: testingmxjob.MakeMXJob("mxjob", "ns").
Image("").
Queue("foo").
ReplicaLabel(kftraining.MXJobReplicaTypeScheduler, controllerconsts.PodSetLabel, "scheduler").
ReplicaLabel(kftraining.MXJobReplicaTypeServer, controllerconsts.PodSetLabel, "server").
ReplicaLabel(kftraining.MXJobReplicaTypeWorker, controllerconsts.PodSetLabel, "worker").
ReplicaAnnotation(kftraining.MXJobReplicaTypeScheduler, controllerconsts.WorkloadAnnotation, "a").
ReplicaAnnotation(kftraining.MXJobReplicaTypeServer, controllerconsts.WorkloadAnnotation, "a").
ReplicaAnnotation(kftraining.MXJobReplicaTypeWorker, controllerconsts.WorkloadAnnotation, "a").
Suspend(false).
Parallelism(1, 1).
Request(kftraining.MXJobReplicaTypeScheduler, corev1.ResourceCPU, "1").
Expand Down
12 changes: 12 additions & 0 deletions pkg/controller/jobs/pod/pod_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ func TestReconciler(t *testing.T) {
wantPods: []corev1.Pod{*basePodWrapper.
Clone().
Label(constants.ManagedByKueueLabel, "true").
Label(controllerconsts.PodSetLabel, kueue.DefaultPodSetName).
Annotation(controllerconsts.WorkloadAnnotation, "unit-test").
NodeSelector("kubernetes.io/arch", "arm64").
KueueFinalizer().
Obj()},
Expand Down Expand Up @@ -772,6 +774,8 @@ func TestReconciler(t *testing.T) {
*basePodWrapper.
Clone().
Label(constants.ManagedByKueueLabel, "true").
Label(controllerconsts.PodSetLabel, "dc85db45").
Annotation(controllerconsts.WorkloadAnnotation, "test-group").
KueueFinalizer().
Group("test-group").
GroupTotalCount("2").
Expand All @@ -781,6 +785,8 @@ func TestReconciler(t *testing.T) {
Clone().
Name("pod2").
Label(constants.ManagedByKueueLabel, "true").
Label(controllerconsts.PodSetLabel, "dc85db45").
Annotation(controllerconsts.WorkloadAnnotation, "test-group").
KueueFinalizer().
Group("test-group").
GroupTotalCount("2").
Expand Down Expand Up @@ -1327,6 +1333,8 @@ func TestReconciler(t *testing.T) {
Clone().
Name("pod2").
Label(constants.ManagedByKueueLabel, "true").
Label(controllerconsts.PodSetLabel, "dc85db45").
Annotation(controllerconsts.WorkloadAnnotation, "test-group").
KueueFinalizer().
Group("test-group").
GroupTotalCount("1").
Expand Down Expand Up @@ -1466,6 +1474,8 @@ func TestReconciler(t *testing.T) {
Clone().
Name("replacement").
Label(constants.ManagedByKueueLabel, "true").
Label(controllerconsts.PodSetLabel, "dc85db45").
Annotation(controllerconsts.WorkloadAnnotation, "test-group").
KueueFinalizer().
Group("test-group").
GroupTotalCount("3").
Expand Down Expand Up @@ -4067,6 +4077,8 @@ func TestReconciler(t *testing.T) {
Clone().
Name("replacement").
Label(constants.ManagedByKueueLabel, "true").
Label(controllerconsts.PodSetLabel, "dc85db45").
Annotation(controllerconsts.WorkloadAnnotation, "test-group").
KueueFinalizer().
Group("test-group").
GroupTotalCount("2").
Expand Down
5 changes: 5 additions & 0 deletions pkg/controller/jobs/raycluster/raycluster_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/reconcile"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
kueueconstants "sigs.k8s.io/kueue/pkg/controller/constants"
"sigs.k8s.io/kueue/pkg/controller/jobframework"
"sigs.k8s.io/kueue/pkg/podset"
utiltesting "sigs.k8s.io/kueue/pkg/util/testing"
Expand Down Expand Up @@ -79,6 +80,10 @@ func TestReconciler(t *testing.T) {
wantJob: *baseJobWrapper.Clone().
Suspend(false).
NodeSelectorHeadGroup("kubernetes.io/arch", "arm64").
LabelHeadGroup(kueueconstants.PodSetLabel, "head").
AnnotationHeadGroup(kueueconstants.WorkloadAnnotation, "test").
LabelWorker(kueueconstants.PodSetLabel, "workers-group-0").
AnnotationWorker(kueueconstants.WorkloadAnnotation, "test").
Obj(),
workloads: []kueue.Workload{
*utiltesting.MakeWorkload("test", "ns").
Expand Down
18 changes: 18 additions & 0 deletions pkg/util/testingjobs/mxjob/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,24 @@ func (j *MXJobWrapper) Queue(queue string) *MXJobWrapper {
return j
}

// ReplicaLabel adds a label into the indicated PodTemplate.
func (j *MXJobWrapper) ReplicaLabel(r kftraining.ReplicaType, k, v string) *MXJobWrapper {
if j.Spec.MXReplicaSpecs[r].Template.Labels == nil {
j.Spec.MXReplicaSpecs[r].Template.Labels = make(map[string]string)
}
j.Spec.MXReplicaSpecs[r].Template.Labels[k] = v
return j
}

// ReplicaAnnotation adds a label into the indicated PodTemplate.
func (j *MXJobWrapper) ReplicaAnnotation(r kftraining.ReplicaType, k, v string) *MXJobWrapper {
if j.Spec.MXReplicaSpecs[r].Template.Annotations == nil {
j.Spec.MXReplicaSpecs[r].Template.Annotations = make(map[string]string)
}
j.Spec.MXReplicaSpecs[r].Template.Annotations[k] = v
return j
}

// Annotations updates annotations of the job.
func (j *MXJobWrapper) Annotations(annotations map[string]string) *MXJobWrapper {
j.ObjectMeta.Annotations = annotations
Expand Down
36 changes: 36 additions & 0 deletions pkg/util/testingjobs/raycluster/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,42 @@ func (j *ClusterWrapper) NodeSelectorHeadGroup(k, v string) *ClusterWrapper {
return j
}

// LabelHeadGroup adds a label to the job's head.
func (j *ClusterWrapper) LabelHeadGroup(k, v string) *ClusterWrapper {
if j.Spec.HeadGroupSpec.Template.Labels == nil {
j.Spec.HeadGroupSpec.Template.Labels = make(map[string]string)
}
j.Spec.HeadGroupSpec.Template.Labels[k] = v
return j
}

// LabelHeadGroup adds an annotation to the job's head.
func (j *ClusterWrapper) AnnotationHeadGroup(k, v string) *ClusterWrapper {
if j.Spec.HeadGroupSpec.Template.Annotations == nil {
j.Spec.HeadGroupSpec.Template.Annotations = make(map[string]string)
}
j.Spec.HeadGroupSpec.Template.Annotations[k] = v
return j
}

// LabelHeadGroup adds a label to the job's first worker.
func (j *ClusterWrapper) LabelWorker(k, v string) *ClusterWrapper {
if j.Spec.WorkerGroupSpecs[0].Template.Labels == nil {
j.Spec.WorkerGroupSpecs[0].Template.Labels = make(map[string]string)
}
j.Spec.WorkerGroupSpecs[0].Template.Labels[k] = v
return j
}

// LabelHeadGroup adds an annotation to the job's first worker.
func (j *ClusterWrapper) AnnotationWorker(k, v string) *ClusterWrapper {
if j.Spec.WorkerGroupSpecs[0].Template.Annotations == nil {
j.Spec.WorkerGroupSpecs[0].Template.Annotations = make(map[string]string)
}
j.Spec.WorkerGroupSpecs[0].Template.Annotations[k] = v
return j
}

// Obj returns the inner Job.
func (j *ClusterWrapper) Obj() *rayv1.RayCluster {
return &j.RayCluster
Expand Down
13 changes: 13 additions & 0 deletions test/integration/controller/jobs/job/job_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ var _ = ginkgo.Describe("Job controller", ginkgo.Ordered, ginkgo.ContinueOnFailu
}, util.Timeout, util.Interval).Should(gomega.BeTrue())
gomega.Expect(createdJob.Spec.Template.Spec.NodeSelector).Should(gomega.HaveLen(1))
gomega.Expect(createdJob.Spec.Template.Spec.NodeSelector[instanceKey]).Should(gomega.Equal(onDemandFlavor.Name))

ginkgo.By("checking the PodSet label and Workload annotation are set on PodTemplate on Job start", func() {
gomega.Expect(createdJob.Spec.Template.Labels).Should(gomega.HaveKey(constants.PodSetLabel))
gomega.Expect(createdJob.Spec.Template.Labels[constants.PodSetLabel]).Should(gomega.Equal("main"))
gomega.Expect(createdJob.Spec.Template.Annotations).Should(gomega.HaveKey(constants.WorkloadAnnotation))
gomega.Expect(createdJob.Spec.Template.Annotations[constants.WorkloadAnnotation]).Should(gomega.Equal(createdWorkload.Name))
})
gomega.Consistently(func() bool {
if err := k8sClient.Get(ctx, wlLookupKey, createdWorkload); err != nil {
return false
Expand All @@ -234,6 +241,12 @@ var _ = ginkgo.Describe("Job controller", ginkgo.Ordered, ginkgo.ContinueOnFailu
return createdJob.Spec.Suspend != nil && *createdJob.Spec.Suspend && createdJob.Status.StartTime == nil &&
len(createdJob.Spec.Template.Spec.NodeSelector) == 0
}, util.Timeout, util.Interval).Should(gomega.BeTrue())

ginkgo.By("checking the PodSet label and Workload annotation are removed from PodTemplate on suspend", func() {
gomega.Expect(createdJob.Spec.Template.Labels).ShouldNot(gomega.HaveKey(constants.PodSetLabel))
gomega.Expect(createdJob.Spec.Template.Annotations).ShouldNot(gomega.HaveKey(constants.WorkloadAnnotation))
})

gomega.Eventually(func() bool {
ok, _ := testing.CheckEventRecordedFor(ctx, k8sClient, "DeletedWorkload", corev1.EventTypeNormal, fmt.Sprintf("Deleted not matching Workload: %v", wlLookupKey.String()), lookupKey)
return ok
Expand Down

0 comments on commit c028214

Please sign in to comment.