Skip to content

Commit

Permalink
TAS: cleanups after merging main PRs
Browse files Browse the repository at this point in the history
  • Loading branch information
mimowo committed Oct 23, 2024
1 parent 6300ff3 commit a22958b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 27 deletions.
6 changes: 3 additions & 3 deletions pkg/cache/tas_flavor.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ import (
)

// usageOp indicates whether we should add or subtract the usage.
type usageOp bool
type usageOp int

const (
// add usage to the cache
add usageOp = true
add usageOp = iota
// subtract usage from the cache
subtract usageOp = false
subtract
)

type TASFlavorCache struct {
Expand Down
47 changes: 24 additions & 23 deletions pkg/cache/tas_flavor_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ var (
// domain in the hierarchy of topology domains.
type domain struct {
// sortName indicates name used for sorting when two domains can fit
// the same number of pods
// the same number of pods.
// Example: for the domain corresponding to "rack2" in "block1", the value is
// "rack2 <domainID>" (where domainID is "block1,rack2").
sortName string

// id is the globally unique id of the domain
Expand All @@ -59,9 +61,9 @@ type TASFlavorSnapshot struct {
// on the Topology object
levelKeys []string

// freeCapacityPerDomain stores the free capacity per domain, only for the
// freeCapacityPerLeafDomain stores the free capacity per domain, only for the
// lowest level of topology
freeCapacityPerDomain map[utiltas.TopologyDomainID]resources.Requests
freeCapacityPerLeafDomain map[utiltas.TopologyDomainID]resources.Requests

// levelValuesPerDomain stores the mapping from domain ID back to the
// ordered list of values. It stores the information for all levels.
Expand All @@ -85,12 +87,12 @@ type TASFlavorSnapshot struct {

func newTASFlavorSnapshot(log logr.Logger, levels []string) *TASFlavorSnapshot {
snapshot := &TASFlavorSnapshot{
log: log,
levelKeys: slices.Clone(levels),
freeCapacityPerDomain: make(map[utiltas.TopologyDomainID]resources.Requests),
levelValuesPerDomain: make(map[utiltas.TopologyDomainID][]string),
domainsPerLevel: make([]domainByID, len(levels)),
state: make(statePerDomain),
log: log,
levelKeys: slices.Clone(levels),
freeCapacityPerLeafDomain: make(map[utiltas.TopologyDomainID]resources.Requests),
levelValuesPerDomain: make(map[utiltas.TopologyDomainID][]string),
domainsPerLevel: make([]domainByID, len(levels)),
state: make(statePerDomain),
}
return snapshot
}
Expand All @@ -103,10 +105,10 @@ func newTASFlavorSnapshot(log logr.Logger, levels []string) *TASFlavorSnapshot {
func (s *TASFlavorSnapshot) initialize() {
levelCount := len(s.levelKeys)
lastLevelIdx := levelCount - 1
for levelIdx := 0; levelIdx < len(s.levelKeys); levelIdx++ {
for levelIdx := range s.levelKeys {
s.domainsPerLevel[levelIdx] = make(domainByID)
}
for childID := range s.freeCapacityPerDomain {
for childID := range s.freeCapacityPerLeafDomain {
childDomain := &domain{
sortName: s.sortName(lastLevelIdx, childID),
id: childID,
Expand Down Expand Up @@ -141,7 +143,7 @@ func (s *TASFlavorSnapshot) sortName(levelIdx int, domainID utiltas.TopologyDoma
"levelIdx", levelIdx,
"domainID", domainID,
"levelValuesPerDomain", s.levelValuesPerDomain,
"freeCapacityPerDomain", s.freeCapacityPerDomain)
"freeCapacityPerDomain", s.freeCapacityPerLeafDomain)
}
// we prefix with the node label value to make it ordered naturally, but
// append also domain ID as the node label value may not be globally unique.
Expand All @@ -150,17 +152,17 @@ func (s *TASFlavorSnapshot) sortName(levelIdx int, domainID utiltas.TopologyDoma

func (s *TASFlavorSnapshot) addCapacity(domainID utiltas.TopologyDomainID, capacity resources.Requests) {
s.initializeFreeCapacityPerDomain(domainID)
s.freeCapacityPerDomain[domainID].Add(capacity)
s.freeCapacityPerLeafDomain[domainID].Add(capacity)
}

func (s *TASFlavorSnapshot) addUsage(domainID utiltas.TopologyDomainID, usage resources.Requests) {
s.initializeFreeCapacityPerDomain(domainID)
s.freeCapacityPerDomain[domainID].Sub(usage)
s.freeCapacityPerLeafDomain[domainID].Sub(usage)
}

func (s *TASFlavorSnapshot) initializeFreeCapacityPerDomain(domainID utiltas.TopologyDomainID) {
if _, found := s.freeCapacityPerDomain[domainID]; !found {
s.freeCapacityPerDomain[domainID] = resources.Requests{}
if _, found := s.freeCapacityPerLeafDomain[domainID]; !found {
s.freeCapacityPerLeafDomain[domainID] = resources.Requests{}
}
}

Expand Down Expand Up @@ -251,8 +253,7 @@ func (s *TASFlavorSnapshot) findLevelWithFitDomains(levelIdx int, required bool,
func (s *TASFlavorSnapshot) updateCountsToMinimum(domains []*domain, count int32) []*domain {
result := make([]*domain, 0)
remainingCount := count
for i := 0; i < len(domains); i++ {
domain := domains[i]
for _, domain := range domains {
if s.state[domain.id] >= remainingCount {
s.state[domain.id] = remainingCount
result = append(result, domain)
Expand All @@ -267,7 +268,7 @@ func (s *TASFlavorSnapshot) updateCountsToMinimum(domains []*domain, count int32
"remainingCount", remainingCount,
"count", count,
"levelValuesPerDomain", s.levelValuesPerDomain,
"freeCapacityPerDomain", s.freeCapacityPerDomain)
"freeCapacityPerDomain", s.freeCapacityPerLeafDomain)
return nil
}

Expand All @@ -276,10 +277,10 @@ func (s *TASFlavorSnapshot) buildAssignment(domains []*domain) *kueue.TopologyAs
Levels: s.levelKeys,
Domains: make([]kueue.TopologyDomainAssignment, 0),
}
for i := 0; i < len(domains); i++ {
for _, domain := range domains {
assignment.Domains = append(assignment.Domains, kueue.TopologyDomainAssignment{
Values: s.asLevelValues(domains[i].id),
Count: s.state[domains[i].id],
Values: s.asLevelValues(domain.id),
Count: s.state[domain.id],
})
}
return &assignment
Expand Down Expand Up @@ -335,7 +336,7 @@ func (s *TASFlavorSnapshot) sortedDomains(infos []*domain) []*domain {
}

func (s *TASFlavorSnapshot) fillInCounts(requests resources.Requests) {
for domainID, capacity := range s.freeCapacityPerDomain {
for domainID, capacity := range s.freeCapacityPerLeafDomain {
s.state[domainID] = requests.CountIn(capacity)
}
lastLevelIdx := len(s.domainsPerLevel) - 1
Expand Down
2 changes: 2 additions & 0 deletions pkg/controller/tas/topology_ungater.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/controller"
"sigs.k8s.io/controller-runtime/pkg/event"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/predicate"
"sigs.k8s.io/controller-runtime/pkg/reconcile"

configapi "sigs.k8s.io/kueue/apis/config/v1beta1"
Expand Down Expand Up @@ -67,6 +68,7 @@ type podWithUngateInfo struct {
}

var _ reconcile.Reconciler = (*topologyUngater)(nil)
var _ predicate.Predicate = (*topologyUngater)(nil)

// +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;update;patch;delete
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch
Expand Down
2 changes: 1 addition & 1 deletion test/integration/tas/tas_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ var _ = ginkgo.Describe("Topology Aware Scheduling", ginkgo.Ordered, func() {

ginkgo.By("Delete the node for cleanup", func() {
for _, node := range nodes {
gomega.Expect(k8sClient.Delete(ctx, &node)).Should(gomega.Succeed())
util.ExpectObjectToBeDeleted(ctx, k8sClient, &node, true)
}
})
})
Expand Down

0 comments on commit a22958b

Please sign in to comment.