TAS: cleanups after merging main PRs
mimowo committed Oct 23, 2024
1 parent ffc4382 · commit 997c7a0
Showing 5 changed files with 55 additions and 70 deletions.
2 changes: 1 addition & 1 deletion pkg/cache/tas_cache_test.go
@@ -62,7 +62,7 @@ func TestFindTopologyAssignment(t *testing.T) {
 			Labels: map[string]string{
 				tasBlockLabel: "b1",
 				tasRackLabel:  "r2",
-				tasHostLabel:  "x1",
+				tasHostLabel:  "x2",
 			},
 		},
 		Status: corev1.NodeStatus{
6 changes: 3 additions & 3 deletions pkg/cache/tas_flavor.go
@@ -33,13 +33,13 @@ import (
 )
 
 // usageOp indicates whether we should add or subtract the usage.
-type usageOp bool
+type usageOp int
 
 const (
 	// add usage to the cache
-	add usageOp = true
+	add usageOp = iota
 	// subtract usage from the cache
-	subtract usageOp = false
+	subtract
 )
 
 type TASFlavorCache struct {
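
For context, switching usageOp from bool to an iota-based int keeps the add/subtract pair as named constants of a dedicated enum type rather than overloading true/false. A minimal standalone sketch of how such an enum can drive a cache update (applyUsage and the plain int64 maps below are illustrative, not code from this commit):

package main

import "fmt"

// usageOp indicates whether we should add or subtract the usage.
type usageOp int

const (
	// add usage to the cache
	add usageOp = iota
	// subtract usage from the cache
	subtract
)

// applyUsage is a hypothetical helper: it adds or subtracts a delta from a
// per-resource usage map depending on the operation.
func applyUsage(usage, delta map[string]int64, op usageOp) {
	for name, v := range delta {
		switch op {
		case add:
			usage[name] += v
		case subtract:
			usage[name] -= v
		}
	}
}

func main() {
	usage := map[string]int64{"cpu": 4000}
	applyUsage(usage, map[string]int64{"cpu": 1000}, subtract)
	fmt.Println(usage["cpu"]) // 3000
}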
96 changes: 39 additions & 57 deletions pkg/cache/tas_flavor_snapshot.go
@@ -25,6 +25,7 @@ import (

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/resources"
utilmaps "sigs.k8s.io/kueue/pkg/util/maps"
utiltas "sigs.k8s.io/kueue/pkg/util/tas"
)

@@ -36,7 +37,9 @@ var (
 // domain in the hierarchy of topology domains.
 type domain struct {
 	// sortName indicates name used for sorting when two domains can fit
-	// the same number of pods
+	// the same number of pods.
+	// Example for domain corresponding to "rack2" in "block1" the value is
+	// "rack2 <domainID>" (where domainID is "block1,rack2").
 	sortName string
 
 	// id is the globally unique id of the domain
@@ -59,9 +62,9 @@ type TASFlavorSnapshot struct {
 	// on the Topology object
 	levelKeys []string
 
-	// freeCapacityPerDomain stores the free capacity per domain, only for the
+	// freeCapacityPerLeafDomain stores the free capacity per domain, only for the
 	// lowest level of topology
-	freeCapacityPerDomain map[utiltas.TopologyDomainID]resources.Requests
+	freeCapacityPerLeafDomain map[utiltas.TopologyDomainID]resources.Requests
 
 	// levelValuesPerDomain stores the mapping from domain ID back to the
 	// ordered list of values. It stores the information for all levels.
@@ -85,12 +88,12 @@ type TASFlavorSnapshot struct {

 func newTASFlavorSnapshot(log logr.Logger, levels []string) *TASFlavorSnapshot {
 	snapshot := &TASFlavorSnapshot{
-		log:                   log,
-		levelKeys:             slices.Clone(levels),
-		freeCapacityPerDomain: make(map[utiltas.TopologyDomainID]resources.Requests),
-		levelValuesPerDomain:  make(map[utiltas.TopologyDomainID][]string),
-		domainsPerLevel:       make([]domainByID, len(levels)),
-		state:                 make(statePerDomain),
+		log:                       log,
+		levelKeys:                 slices.Clone(levels),
+		freeCapacityPerLeafDomain: make(map[utiltas.TopologyDomainID]resources.Requests),
+		levelValuesPerDomain:      make(map[utiltas.TopologyDomainID][]string),
+		domainsPerLevel:           make([]domainByID, len(levels)),
+		state:                     make(statePerDomain),
 	}
 	return snapshot
 }
@@ -103,10 +106,10 @@ func newTASFlavorSnapshot(log logr.Logger, levels []string) *TASFlavorSnapshot {
 func (s *TASFlavorSnapshot) initialize() {
 	levelCount := len(s.levelKeys)
 	lastLevelIdx := levelCount - 1
-	for levelIdx := 0; levelIdx < len(s.levelKeys); levelIdx++ {
+	for levelIdx := range s.levelKeys {
 		s.domainsPerLevel[levelIdx] = make(domainByID)
 	}
-	for childID := range s.freeCapacityPerDomain {
+	for childID := range s.freeCapacityPerLeafDomain {
 		childDomain := &domain{
 			sortName: s.sortName(lastLevelIdx, childID),
 			id:       childID,
@@ -141,7 +144,7 @@ func (s *TASFlavorSnapshot) sortName(levelIdx int, domainID utiltas.TopologyDoma
 			"levelIdx", levelIdx,
 			"domainID", domainID,
 			"levelValuesPerDomain", s.levelValuesPerDomain,
-			"freeCapacityPerDomain", s.freeCapacityPerDomain)
+			"freeCapacityPerDomain", s.freeCapacityPerLeafDomain)
 	}
 	// we prefix with the node label value to make it ordered naturally, but
 	// append also domain ID as the node label value may not be globally unique.
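
Per the comment above and the sortName example in the domain struct ("rack2 <domainID>"), the sort key is the level's node label value with the domain ID appended as a tie-breaker. A hypothetical, self-contained rendering of that key construction (the real method looks the values up in levelValuesPerDomain):

package main

import "fmt"

// sortName orders naturally by the level's label value and appends the domain
// ID so that equal label values under different parents still compare distinctly.
func sortName(levelValue, domainID string) string {
	return levelValue + " " + domainID
}

func main() {
	fmt.Println(sortName("rack2", "block1,rack2")) // "rack2 block1,rack2"
}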
@@ -150,17 +153,17 @@

 func (s *TASFlavorSnapshot) addCapacity(domainID utiltas.TopologyDomainID, capacity resources.Requests) {
 	s.initializeFreeCapacityPerDomain(domainID)
-	s.freeCapacityPerDomain[domainID].Add(capacity)
+	s.freeCapacityPerLeafDomain[domainID].Add(capacity)
 }
 
 func (s *TASFlavorSnapshot) addUsage(domainID utiltas.TopologyDomainID, usage resources.Requests) {
 	s.initializeFreeCapacityPerDomain(domainID)
-	s.freeCapacityPerDomain[domainID].Sub(usage)
+	s.freeCapacityPerLeafDomain[domainID].Sub(usage)
 }
 
 func (s *TASFlavorSnapshot) initializeFreeCapacityPerDomain(domainID utiltas.TopologyDomainID) {
-	if _, found := s.freeCapacityPerDomain[domainID]; !found {
-		s.freeCapacityPerDomain[domainID] = resources.Requests{}
+	if _, found := s.freeCapacityPerLeafDomain[domainID]; !found {
+		s.freeCapacityPerLeafDomain[domainID] = resources.Requests{}
 	}
 }

@@ -175,6 +178,7 @@ func (s *TASFlavorSnapshot) initializeFreeCapacityPerDomain(domainID utiltas.Top
 // a) select the domain at requested level with count >= requestedCount
 // b) traverse the structure down level-by-level optimizing the number of used
 // domains at each level
+// c) build the assignment for the lowest level in the hierarchy
 func (s *TASFlavorSnapshot) FindTopologyAssignment(
 	topologyRequest *kueue.PodSetTopologyRequest,
 	requests resources.Requests,
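
The newly documented step c) completes the three-phase description: pick a level whose domains can fit the requested pod count, walk the hierarchy downwards while minimizing the number of domains used at each level, and emit the assignment for the leaf level. A self-contained toy of those phases on a two-level block/rack hierarchy (toyDomain and fitMinimum are simplifications for illustration, not the snapshot's real types):

package main

import (
	"cmp"
	"fmt"
	"slices"
)

// toy two-level hierarchy (block -> rack); count is the free capacity in pods
type toyDomain struct {
	id       string
	count    int32
	children []*toyDomain
}

// fitMinimum keeps the largest domains first and trims the last one to the
// remainder, the same idea as sortedDomains plus updateCountsToMinimum below.
func fitMinimum(domains []*toyDomain, count int32) []*toyDomain {
	sorted := slices.Clone(domains)
	slices.SortFunc(sorted, func(a, b *toyDomain) int { return cmp.Compare(b.count, a.count) })
	var result []*toyDomain
	remaining := count
	for _, d := range sorted {
		if d.count >= remaining {
			d.count = remaining
			return append(result, d)
		}
		remaining -= d.count
		result = append(result, d)
	}
	return nil // does not fit at this level
}

func main() {
	// leaf (rack) capacities
	r1 := &toyDomain{id: "r1", count: 2}
	r2 := &toyDomain{id: "r2", count: 4}
	r3 := &toyDomain{id: "r3", count: 1}
	r4 := &toyDomain{id: "r4", count: 3}
	// blocks aggregate the free capacity of their racks
	blocks := []*toyDomain{
		{id: "b1", count: 6, children: []*toyDomain{r1, r2}},
		{id: "b2", count: 4, children: []*toyDomain{r3, r4}},
	}

	const podCount = 5
	// a) at the requested (block) level a single domain, b1, fits all 5 pods
	chosen := fitMinimum(blocks, podCount)
	// b) go one level down and minimize the number of racks used within b1
	var racks []*toyDomain
	for _, b := range chosen {
		racks = append(racks, b.children...)
	}
	chosen = fitMinimum(racks, podCount)
	// c) build the leaf-level assignment
	for _, d := range chosen {
		fmt.Printf("%s: %d pods\n", d.id, d.count) // r2: 4 pods, then r1: 1 pods
	}
}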
@@ -221,10 +225,11 @@ func (s *TASFlavorSnapshot) resolveLevelIdx(
 }
 
 func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool, count int32) (int, []*domain) {
-	levelDomains := s.domainsForLevel(levelIdx)
-	if len(levelDomains) == 0 {
+	domains := s.domainsPerLevel[levelIdx]
+	if len(domains) == 0 {
 		return 0, nil
 	}
+	levelDomains := utilmaps.Values(domains)
 	sortedDomain := s.sortedDomains(levelDomains)
 	topDomain := sortedDomain[0]
 	if s.state[topDomain.id] < count {
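
The new utilmaps.Values call replaces the dedicated domainsForLevel helper removed further down in this file. Assuming it is the usual generic "collect map values" helper, its behavior is roughly:

package main

import "fmt"

// Values sketches a generic map-values helper such as utilmaps.Values;
// Go map iteration order, and hence the result order, is unspecified,
// which is why the result is passed through sortedDomains before use.
func Values[K comparable, V any](m map[K]V) []V {
	result := make([]V, 0, len(m))
	for _, v := range m {
		result = append(result, v)
	}
	return result
}

func main() {
	fmt.Println(Values(map[string]int{"a": 1, "b": 2})) // e.g. [1 2]
}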
@@ -236,7 +241,7 @@ func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool,
 	}
 	lastIdx := 0
 	remainingCount := count - s.state[sortedDomain[lastIdx].id]
-	for remainingCount > 0 && lastIdx < len(sortedDomain)-1 {
+	for remainingCount > 0 && lastIdx < len(sortedDomain)-1 && s.state[sortedDomain[lastIdx].id] > 0 {
 		lastIdx++
 		remainingCount -= s.state[sortedDomain[lastIdx].id]
 	}
@@ -251,23 +256,20 @@ func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool,
 func (s *TASFlavorSnapshot) updateCountsToMinimum(domains []*domain, count int32) []*domain {
 	result := make([]*domain, 0)
 	remainingCount := count
-	for i := 0; i < len(domains); i++ {
-		domain := domains[i]
+	for _, domain := range domains {
 		if s.state[domain.id] >= remainingCount {
 			s.state[domain.id] = remainingCount
 			result = append(result, domain)
 			return result
 		}
-		if s.state[domain.id] > 0 {
-			remainingCount -= s.state[domain.id]
-			result = append(result, domain)
-		}
+		remainingCount -= s.state[domain.id]
+		result = append(result, domain)
 	}
 	s.log.Error(errCodeAssumptionsViolated, "unexpected remainingCount",
 		"remainingCount", remainingCount,
 		"count", count,
 		"levelValuesPerDomain", s.levelValuesPerDomain,
-		"freeCapacityPerDomain", s.freeCapacityPerDomain)
+		"freeCapacityPerDomain", s.freeCapacityPerLeafDomain)
 	return nil
 }

@@ -276,49 +278,29 @@ func (s *TASFlavorSnapshot) buildAssignment(domains []*domain) *kueue.TopologyAs
 		Levels:  s.levelKeys,
 		Domains: make([]kueue.TopologyDomainAssignment, 0),
 	}
-	for i := 0; i < len(domains); i++ {
+	for _, domain := range domains {
 		assignment.Domains = append(assignment.Domains, kueue.TopologyDomainAssignment{
-			Values: s.asLevelValues(domains[i].id),
-			Count:  s.state[domains[i].id],
+			Values: s.levelValuesPerDomain[domain.id],
+			Count:  s.state[domain.id],
 		})
 	}
 	return &assignment
 }
 
-func (s *TASFlavorSnapshot) asLevelValues(domainID utiltas.TopologyDomainID) []string {
-	result := make([]string, len(s.levelKeys))
-	for i := range s.levelKeys {
-		result[i] = s.levelValuesPerDomain[domainID][i]
-	}
-	return result
-}
-
-func (s *TASFlavorSnapshot) lowerLevelDomains(levelIdx int, infos []*domain) []*domain {
-	result := make([]*domain, 0, len(infos))
-	for _, info := range infos {
+func (s *TASFlavorSnapshot) lowerLevelDomains(levelIdx int, domains []*domain) []*domain {
+	result := make([]*domain, 0, len(domains))
+	for _, info := range domains {
 		for _, childDomainID := range info.childIDs {
-			if childDomain := s.domainsPerLevel[levelIdx+1][childDomainID]; childDomain != nil {
-				result = append(result, childDomain)
-			}
+			childDomain := s.domainsPerLevel[levelIdx+1][childDomainID]
+			result = append(result, childDomain)
 		}
 	}
 	return result
 }
 
-func (s *TASFlavorSnapshot) domainsForLevel(levelIdx int) []*domain {
-	domains := s.domainsPerLevel[levelIdx]
-	result := make([]*domain, len(domains))
-	index := 0
-	for _, domain := range domains {
-		result[index] = domain
-		index++
-	}
-	return result
-}
-
-func (s *TASFlavorSnapshot) sortedDomains(infos []*domain) []*domain {
-	result := make([]*domain, len(infos))
-	copy(result, infos)
+func (s *TASFlavorSnapshot) sortedDomains(domains []*domain) []*domain {
+	result := make([]*domain, len(domains))
+	copy(result, domains)
 	slices.SortFunc(result, func(a, b *domain) int {
 		aCount := s.state[a.id]
 		bCount := s.state[b.id]
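
The comparator body lies outside this hunk; consistent with the sortName comment earlier in the file, domains are ordered by how many pods they can still accommodate (largest first), with sortName breaking ties. A hypothetical, self-contained version of such an ordering:

package main

import (
	"cmp"
	"fmt"
	"slices"
)

type dom struct {
	sortName string
	count    int32 // pods that still fit in this domain
}

func main() {
	domains := []dom{
		{sortName: "r3 b2,r3", count: 2},
		{sortName: "r2 b1,r2", count: 4},
		{sortName: "r1 b1,r1", count: 4},
	}
	slices.SortFunc(domains, func(a, b dom) int {
		// more free capacity first; equal capacity falls back to sortName
		if c := cmp.Compare(b.count, a.count); c != 0 {
			return c
		}
		return cmp.Compare(a.sortName, b.sortName)
	})
	fmt.Println(domains) // [{r1 b1,r1 4} {r2 b1,r2 4} {r3 b2,r3 2}]
}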
@@ -335,7 +317,7 @@ func (s *TASFlavorSnapshot) sortedDomains(infos []*domain) []*domain {
 }
 
 func (s *TASFlavorSnapshot) fillInCounts(requests resources.Requests) {
-	for domainID, capacity := range s.freeCapacityPerDomain {
+	for domainID, capacity := range s.freeCapacityPerLeafDomain {
 		s.state[domainID] = requests.CountIn(capacity)
 	}
 	lastLevelIdx := len(s.domainsPerLevel) - 1
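
fillInCounts seeds the per-leaf-domain state with the number of pods that still fit, via requests.CountIn(capacity). A sketch of what such a helper computes, under the assumption that the count is limited by the scarcest resource (the real resources.Requests.CountIn may differ in detail):

package main

import "fmt"

// countIn returns how many copies of the per-pod requests fit into the free
// capacity, i.e. the minimum over resources of capacity divided by request.
func countIn(requests, capacity map[string]int64) int32 {
	result := int32(-1)
	for name, perPod := range requests {
		if perPod == 0 {
			continue
		}
		fits := int32(capacity[name] / perPod)
		if result == -1 || fits < result {
			result = fits
		}
	}
	return result
}

func main() {
	requests := map[string]int64{"cpu": 500, "memory": 1 << 30}
	capacity := map[string]int64{"cpu": 4000, "memory": 3 << 30}
	fmt.Println(countIn(requests, capacity)) // 3, memory is the limiting resource
}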
2 changes: 2 additions & 0 deletions pkg/controller/tas/topology_ungater.go
@@ -33,6 +33,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/controller"
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/reconcile"

configapi "sigs.k8s.io/kueue/apis/config/v1beta1"
@@ -67,6 +68,7 @@ type podWithUngateInfo struct {
 }
 
 var _ reconcile.Reconciler = (*topologyUngater)(nil)
+var _ predicate.Predicate = (*topologyUngater)(nil)
 
 // +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;update;patch;delete
 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch
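
The added var _ predicate.Predicate line is a compile-time assertion that *topologyUngater implements controller-runtime's predicate.Predicate interface, alongside the existing reconcile.Reconciler check. A self-contained illustration of the pattern (Logger and noopLogger are made-up names for the example):

package main

import "fmt"

type Logger interface {
	Log(msg string)
}

type noopLogger struct{}

func (noopLogger) Log(msg string) {}

// compile-time assertion: the build fails on this line if noopLogger ever
// stops satisfying the Logger interface
var _ Logger = (*noopLogger)(nil)

func main() {
	fmt.Println("noopLogger satisfies Logger")
}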
19 changes: 10 additions & 9 deletions test/integration/tas/tas_job_test.go
@@ -314,6 +314,7 @@ var _ = ginkgo.Describe("Topology Aware Scheduling", ginkgo.Ordered, func() {
 				Required: ptr.To(tasRackLabel),
 			}
 			gomega.Expect(k8sClient.Create(ctx, wl4)).Should(gomega.Succeed())
+			util.ExpectWorkloadsToBePending(ctx, k8sClient, wl4)
 		})
 
 		ginkgo.By("finish wl3", func() {
@@ -349,7 +350,11 @@ var _ = ginkgo.Describe("Topology Aware Scheduling", ginkgo.Ordered, func() {
 		})
 	})
 
-	ginkgo.When("Nodes node structure is mutated during test cases", func() {
+	ginkgo.When("Node structure is mutated during test cases", func() {
+		var (
+			nodes []corev1.Node
+		)
+
 		ginkgo.BeforeEach(func() {
 			ns = &corev1.Namespace{
 				ObjectMeta: metav1.ObjectMeta{
@@ -383,12 +388,14 @@ var _ = ginkgo.Describe("Topology Aware Scheduling", ginkgo.Ordered, func() {
 			gomega.Expect(util.DeleteObject(ctx, k8sClient, topology)).Should(gomega.Succeed())
 			util.ExpectObjectToBeDeleted(ctx, k8sClient, clusterQueue, true)
 			util.ExpectObjectToBeDeleted(ctx, k8sClient, tasFlavor, true)
+			for _, node := range nodes {
+				util.ExpectObjectToBeDeleted(ctx, k8sClient, &node, true)
+			}
 		})
 
 		ginkgo.It("should admit workload when nodes become available", func() {
 			var (
-				nodes []corev1.Node
-				wl1   *kueue.Workload
+				wl1 *kueue.Workload
 			)
 
 			ginkgo.By("creating a workload which requires rack, but does not fit in any", func() {
@@ -433,12 +440,6 @@ var _ = ginkgo.Describe("Topology Aware Scheduling", ginkgo.Ordered, func() {
 				util.ExpectReservingActiveWorkloadsMetric(clusterQueue, 1)
 				util.ExpectPendingWorkloadsMetric(clusterQueue, 0, 0)
 			})
-
-			ginkgo.By("Delete the node for cleanup", func() {
-				for _, node := range nodes {
-					gomega.Expect(k8sClient.Delete(ctx, &node)).Should(gomega.Succeed())
-				}
-			})
 		})
 	})
 })
