Skip to content

Commit

Permalink
changing lookup if rm is not named
Browse files Browse the repository at this point in the history
  • Loading branch information
carolinaecalderon committed Mar 4, 2024
1 parent 9522c2e commit 0d058b4
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 31 deletions.
100 changes: 69 additions & 31 deletions master/internal/rm/multirm/multirm.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,22 @@ import (
"github.com/determined-ai/determined/proto/pkg/jobv1"
)

// ErrRMConflict returns a detailed error if multiple resource managers define a resource pool
// with the same name.
func ErrRMConflict(rmNames []string, rp string) error {
return fmt.Errorf("resource pool %s exists for both resource managers %v,", rp, rmNames)
}

// ErrRMNotDefined returns a detailed error if a resource manager isn't found in the MultiRMRouter map.
func ErrRMNotDefined(rm string) error {
return fmt.Errorf("resource manager %s not defined", rm)
}

// ErrRPNotDefined returns a detailed error if a resource pool isn't found.
func ErrRPNotDefined(rp string) error {
return fmt.Errorf("could not find resource pool %s", rp)
}

// MultiRMRouter tracks all resource managers in the system.
type MultiRMRouter struct {
defaultRMName string
Expand Down Expand Up @@ -58,7 +69,7 @@ func (m *MultiRMRouter) GetAllocationSummaries() (

// Allocate routes an AllocateRequest to the specified RM.
func (m *MultiRMRouter) Allocate(rmName string, req sproto.AllocateRequest) (*sproto.ResourcesSubscription, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
return nil, err
}
Expand All @@ -68,7 +79,7 @@ func (m *MultiRMRouter) Allocate(rmName string, req sproto.AllocateRequest) (*sp

// Release routes an allocation release request.
func (m *MultiRMRouter) Release(rmName string, req sproto.ResourcesReleased) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
m.syslog.WithError(err)
return
Expand All @@ -81,7 +92,7 @@ func (m *MultiRMRouter) Release(rmName string, req sproto.ResourcesReleased) {
func (m *MultiRMRouter) ValidateResources(
rmName string, req sproto.ValidateResourcesRequest,
) ([]command.LaunchWarning, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
return nil, err
}
Expand All @@ -104,7 +115,7 @@ func (m *MultiRMRouter) NotifyContainerRunning(req sproto.NotifyContainerRunning

// SetGroupMaxSlots routes a SetGroupMaxSlots request to a specified resource manager/pool.
func (m *MultiRMRouter) SetGroupMaxSlots(rmName string, req sproto.SetGroupMaxSlots) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
m.syslog.WithError(err)
return
Expand All @@ -115,7 +126,7 @@ func (m *MultiRMRouter) SetGroupMaxSlots(rmName string, req sproto.SetGroupMaxSl

// SetGroupWeight routes a SetGroupWeight request to a specified resource manager/pool.
func (m *MultiRMRouter) SetGroupWeight(rmName string, req sproto.SetGroupWeight) error {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
return err
}
Expand All @@ -125,7 +136,7 @@ func (m *MultiRMRouter) SetGroupWeight(rmName string, req sproto.SetGroupWeight)

// SetGroupPriority routes a SetGroupPriority request to a specified resource manager/pool.
func (m *MultiRMRouter) SetGroupPriority(rmName string, req sproto.SetGroupPriority) error {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
return err
}
Expand All @@ -142,7 +153,7 @@ func (m *MultiRMRouter) ExternalPreemptionPending(allocationID model.AllocationI

// IsReattachableOnlyAfterStarted routes a IsReattachableOnlyAfterStarted call to a specified resource manager/pool.
func (m *MultiRMRouter) IsReattachableOnlyAfterStarted(rmName string) bool {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, "")
if err != nil {
m.syslog.WithError(err)
return false // Not sure what else to return here.
Expand Down Expand Up @@ -171,7 +182,7 @@ func (m *MultiRMRouter) GetResourcePools() (*apiv1.GetResourcePoolsResponse, err
func (m *MultiRMRouter) GetDefaultComputeResourcePool(rmName string) (
sproto.GetDefaultComputeResourcePoolResponse, error,
) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, "")
if err != nil {
return sproto.GetDefaultComputeResourcePoolResponse{}, err
}
Expand All @@ -181,7 +192,7 @@ func (m *MultiRMRouter) GetDefaultComputeResourcePool(rmName string) (

// GetDefaultAuxResourcePool routes a GetDefaultAuxResourcePool to the specified resource manager.
func (m *MultiRMRouter) GetDefaultAuxResourcePool(rmName string) (sproto.GetDefaultAuxResourcePoolResponse, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, "")
if err != nil {
return sproto.GetDefaultAuxResourcePoolResponse{}, err
}
Expand All @@ -191,7 +202,7 @@ func (m *MultiRMRouter) GetDefaultAuxResourcePool(rmName string) (sproto.GetDefa

// ValidateResourcePool routes a ValidateResourcePool call to the specified resource manager.
func (m *MultiRMRouter) ValidateResourcePool(rmName string, rpName string) error {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, rpName)
if err != nil {
return err
}
Expand All @@ -203,7 +214,7 @@ func (m *MultiRMRouter) ValidateResourcePool(rmName string, rpName string) error
func (m *MultiRMRouter) ResolveResourcePool(rmName string, req sproto.ResolveResourcesRequest) (
resourceManager, resourcePool string, err error,
) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
return rmName, req.ResourcePool, err
}
Expand All @@ -216,7 +227,7 @@ func (m *MultiRMRouter) TaskContainerDefaults(
rmName, rpName string,
fallbackConfig model.TaskContainerDefaultsConfig,
) (model.TaskContainerDefaultsConfig, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, rpName)
if err != nil {
return model.TaskContainerDefaultsConfig{}, err
}
Expand All @@ -226,7 +237,7 @@ func (m *MultiRMRouter) TaskContainerDefaults(

// GetJobQ routes a GetJobQ call to a specified resource manager/pool.
func (m *MultiRMRouter) GetJobQ(rmName, rpName string) (map[model.JobID]*sproto.RMJobInfo, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, rpName)
if err != nil {
return nil, err
}
Expand All @@ -238,7 +249,7 @@ func (m *MultiRMRouter) GetJobQ(rmName, rpName string) (map[model.JobID]*sproto.
func (m *MultiRMRouter) GetJobQueueStatsRequest(rmName string, req *apiv1.GetJobQueueStatsRequest) (
*apiv1.GetJobQueueStatsResponse, error,
) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePools[0])
if err != nil {
return nil, err
}
Expand All @@ -248,7 +259,7 @@ func (m *MultiRMRouter) GetJobQueueStatsRequest(rmName string, req *apiv1.GetJob

// MoveJob routes a MoveJob call to a specified resource manager/pool.
func (m *MultiRMRouter) MoveJob(rmName string, req sproto.MoveJob) error {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
return err
}
Expand All @@ -258,7 +269,7 @@ func (m *MultiRMRouter) MoveJob(rmName string, req sproto.MoveJob) error {

// RecoverJobPosition routes a RecoverJobPosition call to a specified resource manager/pool.
func (m *MultiRMRouter) RecoverJobPosition(rmName string, req sproto.RecoverJobPosition) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.ResourcePool)
if err != nil {
m.syslog.WithError(err)
return
Expand All @@ -269,7 +280,7 @@ func (m *MultiRMRouter) RecoverJobPosition(rmName string, req sproto.RecoverJobP

// GetExternalJobs routes a GetExternalJobs request to a specified resource manager.
func (m *MultiRMRouter) GetExternalJobs(rmName string, rpName string) ([]*jobv1.Job, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, rpName)
if err != nil {
return nil, err
}
Expand All @@ -295,7 +306,7 @@ func (m *MultiRMRouter) GetAgents() (*apiv1.GetAgentsResponse, error) {

// GetAgent routes a GetAgent request to the specified resource manager & agent.
func (m *MultiRMRouter) GetAgent(rmName string, req *apiv1.GetAgentRequest) (*apiv1.GetAgentResponse, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.AgentId)
if err != nil {
return nil, err
}
Expand All @@ -305,7 +316,7 @@ func (m *MultiRMRouter) GetAgent(rmName string, req *apiv1.GetAgentRequest) (*ap

// EnableAgent routes an EnableAgent request to the specified resource manager & agent.
func (m *MultiRMRouter) EnableAgent(rmName string, req *apiv1.EnableAgentRequest) (*apiv1.EnableAgentResponse, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.AgentId)
if err != nil {
return nil, err
}
Expand All @@ -317,7 +328,7 @@ func (m *MultiRMRouter) EnableAgent(rmName string, req *apiv1.EnableAgentRequest
func (m *MultiRMRouter) DisableAgent(rmName string, req *apiv1.DisableAgentRequest) (
*apiv1.DisableAgentResponse, error,
) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.AgentId)
if err != nil {
return nil, err
}
Expand All @@ -327,7 +338,7 @@ func (m *MultiRMRouter) DisableAgent(rmName string, req *apiv1.DisableAgentReque

// GetSlots routes an GetSlots request to the specified resource manager & agent.
func (m *MultiRMRouter) GetSlots(rmName string, req *apiv1.GetSlotsRequest) (*apiv1.GetSlotsResponse, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.AgentId)
if err != nil {
return nil, err
}
Expand All @@ -337,7 +348,7 @@ func (m *MultiRMRouter) GetSlots(rmName string, req *apiv1.GetSlotsRequest) (*ap

// GetSlot routes an GetSlot request to the specified resource manager & agent.
func (m *MultiRMRouter) GetSlot(rmName string, req *apiv1.GetSlotRequest) (*apiv1.GetSlotResponse, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.AgentId)
if err != nil {
return nil, err
}
Expand All @@ -347,7 +358,7 @@ func (m *MultiRMRouter) GetSlot(rmName string, req *apiv1.GetSlotRequest) (*apiv

// EnableSlot routes an EnableSlot request to the specified resource manager & agent.
func (m *MultiRMRouter) EnableSlot(rmName string, req *apiv1.EnableSlotRequest) (*apiv1.EnableSlotResponse, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.AgentId)
if err != nil {
return nil, err
}
Expand All @@ -357,25 +368,52 @@ func (m *MultiRMRouter) EnableSlot(rmName string, req *apiv1.EnableSlotRequest)

// DisableSlot routes an DisableSlot request to the specified resource manager & agent.
func (m *MultiRMRouter) DisableSlot(rmName string, req *apiv1.DisableSlotRequest) (*apiv1.DisableSlotResponse, error) {
resolvedRMName, err := m.getRM(rmName)
resolvedRMName, err := m.getRM(rmName, req.AgentId)
if err != nil {
return nil, err
}

return m.rms[resolvedRMName].DisableSlot(resolvedRMName, req)
}

func (m *MultiRMRouter) getRM(rmName string) (string, error) {
if rmName == "" {
m.syslog.Infof("RM undefined, routing to default manager")
func (m *MultiRMRouter) getRM(rmName string, rpName string) (string, error) {
if rmName != "" {
// If explicitly given the RMName, check that it exists in the map.
_, ok := m.rms[rmName]
if !ok {
return rmName, ErrRMNotDefined(rmName)
}
return rmName, nil
}

// If given neither the RM or RP name, route to default RM.
if rpName == "" {
m.syslog.Infof("RM undefined, routing to default resource manager")
return m.defaultRMName, nil
}

_, ok := m.rms[rmName]
if !ok {
return rmName, ErrRMNotDefined(rmName)
// If just given the RP name, search through all resource managers for a single match.
rmMatches := []string{}
for name, r := range m.rms {
rps, err := r.GetResourcePools()
if err != nil {
return name, fmt.Errorf("could not get resource pools for %s", r)
}
for _, p := range rps.ResourcePools {
if p.Name == rpName {
rmMatches = append(rmMatches, name)
}
}
}

if len(rmMatches) == 0 {
// If the resolvedRMName isn't set, then the RP was not found.
return rmName, ErrRPNotDefined(rpName)
} else if len(rmMatches) > 1 {
// If the resolvedRMName is already set, we assume there is a conflict.
return "", ErrRMConflict(rmMatches, rpName)
}
return rmName, nil
return rmMatches[0], nil
}

func fanOutRMCall[TReturn any](m *MultiRMRouter, f func(rm.ResourceManager) (TReturn, error)) ([]TReturn, error) {
Expand Down
57 changes: 57 additions & 0 deletions master/internal/rm/multirm/multirm_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,60 @@ func TestDisableSlot(t *testing.T) {
require.Equal(t, err, ErrRMNotDefined("bogus"))
require.Empty(t, ret)
}

func TestGetRMName(t *testing.T) {
def := mocks.ResourceManager{}
def.On("GetResourcePools").Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{
{Name: "gcp2"},
},
}, nil)

gcp := mocks.ResourceManager{}
gcp.On("GetResourcePools").Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{
{Name: "gcp1"}, {Name: "gcp2"},
},
}, nil)

aws := mocks.ResourceManager{}
aws.On("GetResourcePools").Return(&apiv1.GetResourcePoolsResponse{
ResourcePools: []*resourcepoolv1.ResourcePool{
{Name: "aws1"}, {Name: "gcp2"},
},
}, nil)

mockMultiRM := MultiRMRouter{
defaultRMName: "default",
rms: map[string]rm.ResourceManager{
"default": &def,
"gcp": &gcp,
"aws": &aws,
},
syslog: logrus.WithField("component", "resource-router"),
}

cases := []struct {
name string
rmName string
rpName string
err error
expectedRMName string
}{
{"RM/RP undefined", "", "", nil, mockMultiRM.defaultRMName},
{"RM defined, RP undefined", "aws", "", nil, "aws"},
{"RM defined/doesn't exist, RP undefined", "aws123", "", ErrRMNotDefined("aws123"), "aws123"},
{"RM defined, RP defined", "aws", "aws1", nil, "aws"},
{"RM defined, RP defined/doesn't exist", "aws", "awsa", nil, "aws"},
{"RM undefined, RP defined", "", "aws1", nil, "aws"},
{"RM undefined, RP defined + conflict", "", "gcp2", ErrRMConflict([]string{"default", "gcp", "aws"}, "gcp2"), ""},
{"RM undefined, RP defined/doesn't exist", "", "gcp3", ErrRPNotDefined("gcp3"), ""},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
rmName, err := mockMultiRM.getRM(tt.rmName, tt.rpName)
require.Equal(t, tt.expectedRMName, rmName)
require.Equal(t, tt.err, err)
})
}
}

0 comments on commit 0d058b4

Please sign in to comment.