diff --git a/cmd/kueue/main.go b/cmd/kueue/main.go index 8a97e425f1..69c47ca51d 100644 --- a/cmd/kueue/main.go +++ b/cmd/kueue/main.go @@ -51,6 +51,7 @@ import ( "sigs.k8s.io/kueue/pkg/controller/core" "sigs.k8s.io/kueue/pkg/controller/core/indexer" "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/tas" "sigs.k8s.io/kueue/pkg/debugger" "sigs.k8s.io/kueue/pkg/features" "sigs.k8s.io/kueue/pkg/metrics" @@ -216,6 +217,13 @@ func setupIndexes(ctx context.Context, mgr ctrl.Manager, cfg *configapi.Configur } } + if features.Enabled(features.TopologyAwareScheduling) { + if err := tas.SetupIndexes(ctx, mgr.GetFieldIndexer()); err != nil { + setupLog.Error(err, "Could not setup TAS indexer") + os.Exit(1) + } + } + if features.Enabled(features.MultiKueue) { if err := multikueue.SetupIndexer(ctx, mgr.GetFieldIndexer(), *cfg.Namespace); err != nil { setupLog.Error(err, "Could not setup multikueue indexer") @@ -265,6 +273,13 @@ func setupControllers(ctx context.Context, mgr ctrl.Manager, cCache *cache.Cache } } + if features.Enabled(features.TopologyAwareScheduling) { + if failedCtrl, err := tas.SetupControllers(mgr, queues, cCache, cfg); err != nil { + setupLog.Error(err, "Could not setup TAS controller", "controller", failedCtrl) + os.Exit(1) + } + } + if failedWebhook, err := webhooks.Setup(mgr); err != nil { setupLog.Error(err, "Unable to create webhook", "webhook", failedWebhook) os.Exit(1) diff --git a/go.mod b/go.mod index fedcf8d2de..87fda945d3 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/fsnotify/fsnotify v1.7.0 github.com/go-logr/logr v1.4.2 github.com/google/go-cmp v0.6.0 + github.com/json-iterator/go v1.1.12 github.com/kubeflow/mpi-operator v0.5.0 github.com/kubeflow/training-operator v1.8.1 github.com/onsi/ginkgo/v2 v2.20.2 @@ -80,7 +81,6 @@ require ( github.com/imdario/mergo v0.3.16 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect - github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de // indirect diff --git a/pkg/controller/tas/constants.go b/pkg/controller/tas/constants.go new file mode 100644 index 0000000000..851e26edb7 --- /dev/null +++ b/pkg/controller/tas/constants.go @@ -0,0 +1,21 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tas + +const ( + TASTopologyUngater = "tas-topology-ungater" +) diff --git a/pkg/controller/tas/controllers.go b/pkg/controller/tas/controllers.go new file mode 100644 index 0000000000..d063039a11 --- /dev/null +++ b/pkg/controller/tas/controllers.go @@ -0,0 +1,33 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tas + +import ( + ctrl "sigs.k8s.io/controller-runtime" + + configapi "sigs.k8s.io/kueue/apis/config/v1beta1" + "sigs.k8s.io/kueue/pkg/cache" + "sigs.k8s.io/kueue/pkg/queue" +) + +func SetupControllers(mgr ctrl.Manager, queues *queue.Manager, cache *cache.Cache, cfg *configapi.Configuration) (string, error) { + topologyUngater := newTopologyUngater(mgr.GetClient()) + if ctrlName, err := topologyUngater.setupWithManager(mgr, cfg); err != nil { + return ctrlName, err + } + return "", nil +} diff --git a/pkg/controller/tas/indexer.go b/pkg/controller/tas/indexer.go new file mode 100644 index 0000000000..b11a044624 --- /dev/null +++ b/pkg/controller/tas/indexer.go @@ -0,0 +1,41 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tas + +import ( + "context" + + corev1 "k8s.io/api/core/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + + kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" +) + +const ( + workloadNameKey = "metadata.workload" +) + +func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error { + return indexer.IndexField(ctx, &corev1.Pod{}, workloadNameKey, func(o client.Object) []string { + pod := o.(*corev1.Pod) + value, found := pod.Annotations[kueuealpha.WorkloadAnnotation] + if !found { + return nil + } + return []string{value} + }) +} diff --git a/pkg/controller/tas/topology_ungater.go b/pkg/controller/tas/topology_ungater.go new file mode 100644 index 0000000000..7f3973000b --- /dev/null +++ b/pkg/controller/tas/topology_ungater.go @@ -0,0 +1,300 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tas + +import ( + "context" + "time" + + "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/util/workqueue" + "k8s.io/klog/v2" + "k8s.io/utils/ptr" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + + configapi "sigs.k8s.io/kueue/apis/config/v1beta1" + kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/pkg/controller/core" + utilclient "sigs.k8s.io/kueue/pkg/util/client" + "sigs.k8s.io/kueue/pkg/util/parallelize" + utilpod "sigs.k8s.io/kueue/pkg/util/pod" + utiltas "sigs.k8s.io/kueue/pkg/util/tas" +) + +const ( + ungateBatchPeriod = time.Second +) + +type topologyUngater struct { + client client.Client +} + +type podWithUngateInfo struct { + pod *corev1.Pod + nodeLabels map[string]string +} + +var _ reconcile.Reconciler = (*topologyUngater)(nil) + +// +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;update;patch;delete +// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch +// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get + +func newTopologyUngater(c client.Client) *topologyUngater { + return &topologyUngater{ + client: c, + } +} + +func (r *topologyUngater) setupWithManager(mgr ctrl.Manager, cfg *configapi.Configuration) (string, error) { + podHandler := podHandler{} + return TASTopologyUngater, ctrl.NewControllerManagedBy(mgr). + Named(TASTopologyUngater). + For(&kueue.Workload{}). + Watches(&corev1.Pod{}, &podHandler). + WithOptions(controller.Options{NeedLeaderElection: ptr.To(false)}). + WithEventFilter(r). + Complete(core.WithLeadingManager(mgr, r, &kueue.ClusterQueue{}, cfg)) +} + +var _ handler.EventHandler = (*podHandler)(nil) + +type podHandler struct { +} + +func (h *podHandler) Create(_ context.Context, e event.CreateEvent, q workqueue.TypedRateLimitingInterface[reconcile.Request]) { + pod, isPod := e.Object.(*corev1.Pod) + if !isPod { + return + } + h.queueReconcileForPod(pod, q) +} + +func (h *podHandler) Update(ctx context.Context, e event.UpdateEvent, q workqueue.TypedRateLimitingInterface[reconcile.Request]) { + oldPod, isOldPod := e.ObjectOld.(*corev1.Pod) + newPod, isNewPod := e.ObjectNew.(*corev1.Pod) + if !isOldPod || !isNewPod { + return + } + h.queueReconcileForPod(oldPod, q) + h.queueReconcileForPod(newPod, q) +} + +func (h *podHandler) Delete(_ context.Context, e event.DeleteEvent, q workqueue.TypedRateLimitingInterface[reconcile.Request]) { + pod, isPod := e.Object.(*corev1.Pod) + if !isPod { + return + } + h.queueReconcileForPod(pod, q) +} + +func (h *podHandler) queueReconcileForPod(pod *corev1.Pod, q workqueue.TypedRateLimitingInterface[reconcile.Request]) { + if pod == nil { + return + } + if !utilpod.HasGate(pod, kueuealpha.TopologySchedulingGate) { + return + } + if wlName, found := pod.Annotations[kueuealpha.WorkloadAnnotation]; found { + q.AddAfter(reconcile.Request{NamespacedName: types.NamespacedName{ + Name: wlName, + Namespace: pod.Namespace, + }}, ungateBatchPeriod) + } +} + +func (h *podHandler) Generic(context.Context, event.GenericEvent, workqueue.TypedRateLimitingInterface[reconcile.Request]) { +} + +func (r *topologyUngater) Reconcile(ctx context.Context, req reconcile.Request) (reconcile.Result, error) { + log := ctrl.LoggerFrom(ctx).WithValues("workload", req.NamespacedName.Name) + log.V(2).Info("Reconcile Topology Ungater") + + wl := &kueue.Workload{} + if err := r.client.Get(ctx, req.NamespacedName, wl); err != nil { + if client.IgnoreNotFound(err) != nil { + return reconcile.Result{}, err + } + log.Info("workload not found") + return reconcile.Result{}, nil + } + if wl.Status.Admission == nil { + log.Info("workload is not admitted") + return reconcile.Result{}, nil + } + + allToUngate := make([]podWithUngateInfo, 0) + for _, psa := range wl.Status.Admission.PodSetAssignments { + if psa.TopologyAssignment != nil { + toUngate, err := r.podsetPodsToUngate(ctx, log, wl, &psa) + if err != nil { + log.Error(err, "failed to identify pods to ungate", "podset", psa.Name, "count", psa.Count) + return reconcile.Result{}, err + } else { + log.Info("identified pods to ungate for podset", "podset", psa.Name, "count", len(toUngate)) + allToUngate = append(allToUngate, toUngate...) + } + } + } + var err error + if len(allToUngate) > 0 { + log.V(2).Info("identified pods to ungate", "count", len(allToUngate)) + err = parallelize.Until(ctx, len(allToUngate), func(i int) error { + podWithUngateInfo := &allToUngate[i] + e := utilclient.Patch(ctx, r.client, podWithUngateInfo.pod, true, func() (bool, error) { + log.V(3).Info("ungating pod", "pod", klog.KObj(podWithUngateInfo.pod), "nodeLabels", podWithUngateInfo.nodeLabels) + utilpod.Ungate(podWithUngateInfo.pod, kueuealpha.TopologySchedulingGate) + if podWithUngateInfo.pod.Spec.NodeSelector == nil { + podWithUngateInfo.pod.Spec.NodeSelector = make(map[string]string) + } + for labelKey, labelValue := range podWithUngateInfo.nodeLabels { + podWithUngateInfo.pod.Spec.NodeSelector[labelKey] = labelValue + } + return true, nil + }) + if e != nil { + log.Error(e, "failed ungating pod", "pod", klog.KObj(podWithUngateInfo.pod)) + } + return e + }) + if err != nil { + return reconcile.Result{}, err + } + } + return reconcile.Result{}, nil +} + +func (r *topologyUngater) Create(event event.CreateEvent) bool { + wl, isWl := event.Object.(*kueue.Workload) + if isWl { + return isTASWorkload(wl) + } + return true +} + +func (r *topologyUngater) Delete(event event.DeleteEvent) bool { + wl, isWl := event.Object.(*kueue.Workload) + if isWl { + return isTASWorkload(wl) + } + return true +} + +func (r *topologyUngater) Update(event event.UpdateEvent) bool { + _, isOldWl := event.ObjectOld.(*kueue.Workload) + newWl, isNewWl := event.ObjectNew.(*kueue.Workload) + if isOldWl && isNewWl { + return isTASWorkload(newWl) + } + return true +} + +func isTASWorkload(wl *kueue.Workload) bool { + if wl.Status.Admission == nil { + return false + } + for _, psa := range wl.Status.Admission.PodSetAssignments { + if psa.TopologyAssignment != nil { + return true + } + } + return false +} + +func (r *topologyUngater) Generic(event event.GenericEvent) bool { + return false +} + +func (r *topologyUngater) podsetPodsToUngate(ctx context.Context, log logr.Logger, wl *kueue.Workload, psa *kueue.PodSetAssignment) ([]podWithUngateInfo, error) { + levelKeys := psa.TopologyAssignment.Levels + domainIDToLabelValues := make(map[utiltas.TopologyDomainID][]string) + domainIDToExpectedCount := make(map[utiltas.TopologyDomainID]int32) + for _, psaDomain := range psa.TopologyAssignment.Domains { + domainID := utiltas.DomainID(psaDomain.Values) + domainIDToExpectedCount[domainID] = psaDomain.Count + domainIDToLabelValues[domainID] = psaDomain.Values + } + pods, err := r.podsForDomain(ctx, wl.Namespace, wl.Name, psa.Name) + if err != nil { + return nil, err + } + gatedPods := make([]*corev1.Pod, 0) + domainIDToUngatedCnt := make(map[utiltas.TopologyDomainID]int32) + for i := range pods { + pod := pods[i] + isGated := utilpod.HasGate(pod, kueuealpha.TopologySchedulingGate) + if isGated { + gatedPods = append(gatedPods, pod) + } else { + levelValues := utiltas.LevelValues(levelKeys, pod.Spec.NodeSelector) + domainID := utiltas.DomainID(levelValues) + domainIDToUngatedCnt[domainID]++ + } + } + log.V(5).Info("searching pods to ungate", + "podSetName", psa.Name, + "podSetCount", psa.Count, + "domainIDToUngatedCount", domainIDToUngatedCnt, + "domainIDToLabelValues", domainIDToLabelValues, + "levelKeys", levelKeys) + toUngate := make([]podWithUngateInfo, 0) + for domainID, expectedInDomainCnt := range domainIDToExpectedCount { + ungatedInDomainCnt := domainIDToUngatedCnt[domainID] + remainingUngatedInDomain := max(expectedInDomainCnt-ungatedInDomainCnt, 0) + if remainingUngatedInDomain > 0 { + domainValues := domainIDToLabelValues[domainID] + + nodeLabels := utiltas.NodeLabelsFromKeysAndValues(levelKeys, domainValues) + remainingGatedCnt := int32(max(len(gatedPods)-len(toUngate), 0)) + toUngateCnt := min(remainingUngatedInDomain, remainingGatedCnt) + if toUngateCnt > 0 { + podsToUngateInDomain := gatedPods[len(toUngate) : int32(len(toUngate))+toUngateCnt] + for i := range podsToUngateInDomain { + toUngate = append(toUngate, podWithUngateInfo{ + pod: podsToUngateInDomain[i], + nodeLabels: nodeLabels, + }) + } + } + } + } + return toUngate, nil +} + +func (r *topologyUngater) podsForDomain(ctx context.Context, ns, wlName, psName string) ([]*corev1.Pod, error) { + var pods corev1.PodList + if err := r.client.List(ctx, &pods, client.InNamespace(ns), client.MatchingLabels{ + kueuealpha.PodSetLabel: psName, + }, client.MatchingFields{ + workloadNameKey: wlName, + }); err != nil { + return nil, err + } + result := make([]*corev1.Pod, 0) + for _, p := range pods.Items { + result = append(result, &p) + } + return result, nil +} diff --git a/pkg/controller/tas/topology_ungater_test.go b/pkg/controller/tas/topology_ungater_test.go new file mode 100644 index 0000000000..3e0765a941 --- /dev/null +++ b/pkg/controller/tas/topology_ungater_test.go @@ -0,0 +1,363 @@ +/* +Copyright The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tas + +import ( + "maps" + "testing" + + gocmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + jsoniter "github.com/json-iterator/go" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + + kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" + kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" + "sigs.k8s.io/kueue/pkg/constants" + utiltesting "sigs.k8s.io/kueue/pkg/util/testing" + testingpod "sigs.k8s.io/kueue/pkg/util/testingjobs/pod" + + _ "sigs.k8s.io/kueue/pkg/controller/jobs/job" + _ "sigs.k8s.io/kueue/pkg/controller/jobs/raycluster" +) + +const ( + tasBlockLabel = "cloud.com/topology-block" + tasRackLabel = "cloud.com/topology-rack" +) + +var ( + podCmpOpts = []gocmp.Option{ + cmpopts.EquateEmpty(), + cmpopts.IgnoreFields(corev1.Pod{}, "TypeMeta", "ObjectMeta.ResourceVersion", + "ObjectMeta.DeletionTimestamp"), + cmpopts.IgnoreFields(corev1.PodCondition{}, "LastTransitionTime"), + } + defaultTestLevels = []string{ + tasBlockLabel, + tasRackLabel, + } +) + +func TestReconcile(t *testing.T) { + // this code is meant to deal with the fact that the pod assignment to + // topology domains is non-deterministic (i.e. sometimes pod1 gets the + // assigned to domain1, but sometimes to domain2). We only care about the + // number of pods ungated to a given domain. + type counts struct { + NodeSelector map[string]string + Count int32 + } + + mapToJSON := func(t *testing.T, m map[string]string) string { + json := jsoniter.Config{ + SortMapKeys: true, + }.Froze() + bytes, err := json.Marshal(m) + if err != nil { + t.Fatalf("failed to serialize map: %v, error=%s", m, err) + } + return string(bytes) + } + + extractCountsMapFromPods := func(pods []corev1.Pod) map[string]*counts { + result := make(map[string]*counts, len(pods)) + for i := range pods { + pod := pods[i] + key := mapToJSON(t, pod.Spec.NodeSelector) + if _, found := result[key]; !found { + result[key] = &counts{ + NodeSelector: maps.Clone(pod.Spec.NodeSelector), + Count: 0, + } + } + result[key].Count++ + } + return result + } + + testCases := map[string]struct { + workloads []kueue.Workload + pods []corev1.Pod + wantPods []corev1.Pod + wantCounts []counts + wantErr error + }{ + "ungate single pod": { + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("unit-test", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 1).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota( + utiltesting.MakeAdmission("cq"). + Assignment(corev1.ResourceCPU, "unit-test-flavor", "1"). + AssignmentPodCount(1). + TopologyAssignment(&kueue.TopologyAssignment{ + Levels: defaultTestLevels, + Domains: []kueue.TopologyDomainAssignment{ + { + Count: 1, + Values: []string{ + "b1", + "r1", + }, + }, + }, + }). + Obj(), + ). + Admitted(true). + Obj(), + }, + pods: []corev1.Pod{ + *testingpod.MakePod("pod", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + TopologySchedulingGate(). + Obj(), + }, + wantPods: []corev1.Pod{ + *testingpod.MakePod("pod", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + Obj(), + }, + wantCounts: []counts{ + { + NodeSelector: map[string]string{ + tasBlockLabel: "b1", + tasRackLabel: "r1", + }, + Count: 1, + }, + }, + }, + "ungate multiple pods in a single domain": { + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("unit-test", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 3).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota( + utiltesting.MakeAdmission("cq"). + Assignment(corev1.ResourceCPU, "unit-test-flavor", "1"). + AssignmentPodCount(3). + TopologyAssignment(&kueue.TopologyAssignment{ + Levels: defaultTestLevels, + Domains: []kueue.TopologyDomainAssignment{ + { + Count: 3, + Values: []string{ + "b1", + "r1", + }, + }, + }, + }). + Obj(), + ). + Admitted(true). + Obj(), + }, + pods: []corev1.Pod{ + *testingpod.MakePod("pod1", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + TopologySchedulingGate(). + Obj(), + *testingpod.MakePod("pod2", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + TopologySchedulingGate(). + Obj(), + }, + wantPods: []corev1.Pod{ + *testingpod.MakePod("pod1", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + Obj(), + *testingpod.MakePod("pod2", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + Obj(), + }, + wantCounts: []counts{ + { + NodeSelector: map[string]string{ + tasBlockLabel: "b1", + tasRackLabel: "r1", + }, + Count: 2, + }, + }, + }, + "ungate multiple pods across multiple domains": { + workloads: []kueue.Workload{ + *utiltesting.MakeWorkload("unit-test", "ns").Finalizers(kueue.ResourceInUseFinalizerName). + PodSets(*utiltesting.MakePodSet(kueue.DefaultPodSetName, 2).Request(corev1.ResourceCPU, "1").Obj()). + ReserveQuota( + utiltesting.MakeAdmission("cq"). + Assignment(corev1.ResourceCPU, "unit-test-flavor", "1"). + AssignmentPodCount(2). + TopologyAssignment(&kueue.TopologyAssignment{ + Levels: defaultTestLevels, + Domains: []kueue.TopologyDomainAssignment{ + { + Count: 1, + Values: []string{ + "b1", + "r1", + }, + }, + { + Count: 1, + Values: []string{ + "b1", + "r2", + }, + }, + }, + }). + Obj(), + ). + Admitted(true). + Obj(), + }, + pods: []corev1.Pod{ + *testingpod.MakePod("pod1", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + TopologySchedulingGate(). + Obj(), + *testingpod.MakePod("pod2", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + TopologySchedulingGate(). + Obj(), + }, + wantPods: []corev1.Pod{ + *testingpod.MakePod("pod1", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + Obj(), + *testingpod.MakePod("pod2", "ns"). + Annotation(kueuealpha.WorkloadAnnotation, "unit-test"). + Label(constants.ManagedByKueueLabel, "true"). + Label(kueuealpha.PodSetLabel, kueue.DefaultPodSetName). + KueueFinalizer(). + Obj(), + }, + wantCounts: []counts{ + { + NodeSelector: map[string]string{ + tasBlockLabel: "b1", + tasRackLabel: "r1", + }, + Count: 1, + }, + { + NodeSelector: map[string]string{ + tasBlockLabel: "b1", + tasRackLabel: "r2", + }, + Count: 1, + }, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + ctx, _ := utiltesting.ContextWithLog(t) + clientBuilder := utiltesting.NewClientBuilder().WithInterceptorFuncs(interceptor.Funcs{SubResourcePatch: utiltesting.TreatSSAAsStrategicMerge}) + if err := SetupIndexes(ctx, utiltesting.AsIndexer(clientBuilder)); err != nil { + t.Fatalf("Could not setup indexes: %v", err) + } + + kcBuilder := clientBuilder.WithObjects() + for i := range tc.pods { + kcBuilder = kcBuilder.WithObjects(&tc.pods[i]) + } + + for i := range tc.workloads { + kcBuilder = kcBuilder.WithStatusSubresource(&tc.workloads[i]) + } + + kClient := kcBuilder.Build() + for i := range tc.workloads { + if err := kClient.Create(ctx, &tc.workloads[i]); err != nil { + t.Fatalf("Could not create workload: %v", err) + } + } + topologyUngater := newTopologyUngater(kClient) + + reconcileRequest := reconcile.Request{ + NamespacedName: client.ObjectKeyFromObject(&tc.workloads[0]), + } + + _, err := topologyUngater.Reconcile(ctx, reconcileRequest) + + if diff := gocmp.Diff(tc.wantErr, err, cmpopts.EquateErrors(), cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Reconcile returned error (-want,+got):\n%s", diff) + } + + var gotPods corev1.PodList + if err := kClient.List(ctx, &gotPods); err != nil { + if tc.wantPods != nil || !apierrors.IsNotFound(err) { + t.Fatalf("Could not get Pod after reconcile: %v", err) + } + } + + extPodCmpOpts := append(podCmpOpts, cmpopts.IgnoreFields(corev1.PodSpec{}, "NodeSelector")) + if diff := gocmp.Diff(tc.wantPods, gotPods.Items, extPodCmpOpts...); diff != "" { + t.Errorf("Pods after reconcile (-want,+got):\n%s", diff) + } + + wantCountsMap := make(map[string]*counts) + for i := range tc.wantCounts { + key := mapToJSON(t, tc.wantCounts[i].NodeSelector) + wantCountsMap[key] = &counts{ + NodeSelector: maps.Clone(tc.wantCounts[i].NodeSelector), + Count: tc.wantCounts[i].Count, + } + } + gotCountsMap := extractCountsMapFromPods(gotPods.Items) + if diff := gocmp.Diff(wantCountsMap, gotCountsMap); diff != "" { + t.Errorf("unexpected counts (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/pkg/util/tas/tas.go b/pkg/util/tas/tas.go new file mode 100644 index 0000000000..1bb8201b7e --- /dev/null +++ b/pkg/util/tas/tas.go @@ -0,0 +1,46 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tas + +import ( + "strings" +) + +type TopologyDomainID string + +func DomainID(levelValues []string) TopologyDomainID { + if len(levelValues) == 0 { + panic("hash invoked without levelValues") + } + return TopologyDomainID(strings.Join(levelValues, ",")) +} + +func NodeLabelsFromKeysAndValues(keys, values []string) map[string]string { + result := make(map[string]string, len(keys)) + for i := range keys { + result[keys[i]] = values[i] + } + return result +} + +func LevelValues(levelKeys []string, objectLabels map[string]string) []string { + levelValues := make([]string, len(levelKeys)) + for levelIdx, levelKey := range levelKeys { + levelValues[levelIdx] = objectLabels[levelKey] + } + return levelValues +} diff --git a/pkg/util/testing/wrappers.go b/pkg/util/testing/wrappers.go index 7d8cdc288f..5c1c5217bc 100644 --- a/pkg/util/testing/wrappers.go +++ b/pkg/util/testing/wrappers.go @@ -504,6 +504,11 @@ func (w *AdmissionWrapper) AssignmentPodCount(value int32) *AdmissionWrapper { return w } +func (w *AdmissionWrapper) TopologyAssignment(ts *kueue.TopologyAssignment) *AdmissionWrapper { + w.PodSetAssignments[0].TopologyAssignment = ts + return w +} + func (w *AdmissionWrapper) PodSets(podSets ...kueue.PodSetAssignment) *AdmissionWrapper { w.PodSetAssignments = podSets return w diff --git a/pkg/util/testingjobs/pod/wrappers.go b/pkg/util/testingjobs/pod/wrappers.go index 0aa39a750d..054d6ad1a1 100644 --- a/pkg/util/testingjobs/pod/wrappers.go +++ b/pkg/util/testingjobs/pod/wrappers.go @@ -28,8 +28,10 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/utils/ptr" + kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" "sigs.k8s.io/kueue/pkg/constants" controllerconsts "sigs.k8s.io/kueue/pkg/controller/constants" + utilpod "sigs.k8s.io/kueue/pkg/util/pod" ) // PodWrapper wraps a Pod. @@ -135,6 +137,15 @@ func (p *PodWrapper) KueueSchedulingGate() *PodWrapper { return p } +// TopologySchedulingGate adds kueue scheduling gate to the Pod +func (p *PodWrapper) TopologySchedulingGate() *PodWrapper { + if p.Spec.SchedulingGates == nil { + p.Spec.SchedulingGates = make([]corev1.PodSchedulingGate, 0) + } + utilpod.Gate(&p.Pod, kueuealpha.TopologySchedulingGate) + return p +} + // Finalizer adds a finalizer to the Pod func (p *PodWrapper) Finalizer(f string) *PodWrapper { if p.ObjectMeta.Finalizers == nil {