Skip to content

Commit

Permalink
chore: refactor master with DefaultRMRouter
Browse files: browse the repository at this point in the history
  • Loading branch information
carolinaecalderon committed Feb 20, 2024
1 parent 32fbd56 commit f59b7bd
Show file tree
Hide file tree
Showing 28 changed files with 468 additions and 176 deletions.
3 changes: 2 additions & 1 deletion master/internal/api_agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ import (
"github.com/determined-ai/determined/master/internal/authz"
"github.com/determined-ai/determined/master/internal/cluster"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/rm"
"github.com/determined-ai/determined/master/internal/rm/rmerrors"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

func (a *apiServer) GetAgents(
ctx context.Context, req *apiv1.GetAgentsRequest,
) (*apiv1.GetAgentsResponse, error) {
resp, err := a.m.rm.GetAgents(req)
resp, err := rm.GetAgents()
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion master/internal/api_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func (a *apiServer) getCommandLaunchParams(ctx context.Context, req *protoComman
}

poolName, launchWarnings, err := a.m.ResolveResources(
resources.ResourceManager,
resources.ResourcePool,
resources.Slots,
int(cmdSpec.Metadata.WorkspaceID),
Expand All @@ -107,7 +108,7 @@ func (a *apiServer) getCommandLaunchParams(ctx context.Context, req *protoComman
}

// Get the base TaskSpec.
taskSpec, err := a.m.fillTaskSpec(poolName, agentUserGroup, userModel)
taskSpec, err := a.m.fillTaskSpec(resources.ResourceManager, poolName, agentUserGroup, userModel)
if err != nil {
return nil, nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions master/internal/api_generic_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,17 @@ func (a *apiServer) getGenericTaskLaunchParameters(
return nil, nil, nil, fmt.Errorf("resource slots must be >= 0")
}
isSingleNode := resources.IsSingleNode != nil && *resources.IsSingleNode
poolName, launchWarnings, err := a.m.ResolveResources(resources.ResourcePool,
poolName, launchWarnings, err := a.m.ResolveResources(
resources.ResourceManager,
resources.ResourcePool,
resources.Slots,
int(proj.WorkspaceId),
isSingleNode)
if err != nil {
return nil, nil, nil, err
}
// Get the base TaskSpec.
taskSpec, err := a.m.fillTaskSpec(poolName, agentUserGroup, userModel)
taskSpec, err := a.m.fillTaskSpec(resources.ResourceManager, poolName, agentUserGroup, userModel)
if err != nil {
return nil, nil, nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions master/internal/api_resourcepool.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (a *apiServer) GetResourcePools(
if err != nil {
return nil, err
}
resp, err := a.m.rm.GetResourcePools()
resp, err := rm.GetResourcePools()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -253,7 +253,7 @@ func (a *apiServer) canUserModifyWorkspaces(ctx context.Context, ids []int32) er
}

func (a *apiServer) resourcePoolsAsConfigs() ([]config.ResourcePoolConfig, error) {
resp, err := a.m.rm.GetResourcePools()
resp, err := rm.GetResourcePools()
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion master/internal/api_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
expauth "github.com/determined-ai/determined/master/internal/experiment"
"github.com/determined-ai/determined/master/internal/grpcutil"
"github.com/determined-ai/determined/master/internal/logpattern"
"github.com/determined-ai/determined/master/internal/rm"
"github.com/determined-ai/determined/master/internal/task"
"github.com/determined-ai/determined/master/internal/webhooks"
"github.com/determined-ai/determined/master/pkg/model"
Expand Down Expand Up @@ -550,7 +551,7 @@ func (a *apiServer) GetActiveTasksCount(
func (a *apiServer) GetTasks(
ctx context.Context, req *apiv1.GetTasksRequest,
) (resp *apiv1.GetTasksResponse, err error) {
summary, err := a.m.rm.GetAllocationSummaries()
summary, err := rm.GetAllocationSummaries()
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions master/internal/checkpoint_gc.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func runCheckpointGCForCheckpoints(
}

func runCheckpointGCTask(
rm rm.ResourceManager,
rmInterface rm.ResourceManager,
pgDB *db.PgDB,
taskID model.TaskID,
jobID model.JobID,
Expand All @@ -98,15 +98,15 @@ func runCheckpointGCTask(
return nil
}

rp, err := rm.ResolveResourcePool("", -1, 0)
rp, err := rm.ResolveResourcePool(*taskSpec.ResourcesConfig.ResourceManager(), "", -1, 0)
if err != nil {
return fmt.Errorf("resolving resource pool: %w", err)
}

// t.Base is just a shallow copy of the m.taskSpec on the master, so
// use caution when mutating it.
tcd, err := rm.TaskContainerDefaults(
rp,
*taskSpec.ResourcesConfig.ResourceManager(), rp,
config.GetMasterConfig().TaskContainerDefaults)
if err != nil {
return fmt.Errorf("creating task container defaults: %v", err)
Expand Down Expand Up @@ -178,7 +178,7 @@ func runCheckpointGCTask(
},
ResourceManager: "", // TODO (multirm): add RM, once you figure out how to pass it in.
ResourcePool: rp,
}, pgDB, rm, gcSpec, onExit)
}, pgDB, rmInterface, gcSpec, onExit)
if err != nil {
return err
}
Expand Down
12 changes: 7 additions & 5 deletions master/internal/command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,13 @@ func (c *Command) garbageCollect() {

func (c *Command) setNTSCPriority(priority int, forward bool) error {
if forward {
switch err := c.rm.SetGroupPriority(sproto.SetGroupPriority{
Priority: priority,
ResourcePool: c.Config.Resources.ResourcePool,
JobID: c.jobID,
}).(type) {
switch err := rm.SetGroupPriority(
c.Config.Resources.ResourceManager,
sproto.SetGroupPriority{
Priority: priority,
ResourcePool: c.Config.Resources.ResourcePool,
JobID: c.jobID,
}).(type) {
case nil:
case rmerrors.UnsupportedError:
c.syslog.WithError(err).Debug("ignoring unsupported call to set group priority")
Expand Down
12 changes: 7 additions & 5 deletions master/internal/command/command_job_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/internal/rm"
"github.com/determined-ai/determined/master/internal/rm/rmerrors"
"github.com/determined-ai/determined/master/internal/sproto"
"github.com/determined-ai/determined/proto/pkg/jobv1"
Expand Down Expand Up @@ -59,11 +60,12 @@ func (c *Command) SetWeight(weight float64) error {
c.mu.Lock()
defer c.mu.Unlock()

switch err := c.rm.SetGroupWeight(sproto.SetGroupWeight{
Weight: weight,
ResourcePool: c.Config.Resources.ResourcePool,
JobID: c.jobID,
}).(type) {
switch err := rm.SetGroupWeight(c.Config.Resources.ResourceManager,
sproto.SetGroupWeight{
Weight: weight,
ResourcePool: c.Config.Resources.ResourcePool,
JobID: c.jobID,
}).(type) {
case nil:
case rmerrors.UnsupportedError:
c.syslog.WithError(err).Debug("ignoring unsupported call to set group weight")
Expand Down
26 changes: 17 additions & 9 deletions master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ type Master struct {
logs *logger.LogBuffer
echo *echo.Echo
db *db.PgDB
rm rm.ResourceManager
rm rm.ResourceManager // TODO (multirm): remove this.

trialLogBackend TrialLogBackend
taskLogBackend TaskLogBackend
Expand Down Expand Up @@ -1122,17 +1122,25 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error {
}

// Resource Manager.
m.rm = rm.New(
m.db,
m.echo,
&m.config.ResourceConfig,
&m.config.TaskContainerDefaults,
&aproto.MasterSetAgentOptions{
rm.SetDefaultRouter(m.db, m.echo, m.config.ResourceConfig.ResourceManagers,
&m.config.TaskContainerDefaults, &aproto.MasterSetAgentOptions{
MasterInfo: m.Info(),
LoggingOptions: m.config.Logging,
},
cert,
)
cert)
/*
m.rm = rm.New(
m.db,
m.echo,
&m.config.ResourceConfig,
&m.config.TaskContainerDefaults,
&aproto.MasterSetAgentOptions{
MasterInfo: m.Info(),
LoggingOptions: m.config.Logging,
},
cert,
)
*/
jobservice.SetDefaultService(m.rm)

tasksGroup := m.echo.Group("/tasks")
Expand Down
16 changes: 13 additions & 3 deletions master/internal/core_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/determined-ai/determined/master/internal/db"
expauth "github.com/determined-ai/determined/master/internal/experiment"
"github.com/determined-ai/determined/master/internal/project"
"github.com/determined-ai/determined/master/internal/rm"
"github.com/determined-ai/determined/master/internal/templates"
"github.com/determined-ai/determined/master/internal/user"
"github.com/determined-ai/determined/master/internal/workspace"
Expand Down Expand Up @@ -295,14 +296,23 @@ func (m *Master) parseCreateExperiment(req *apiv1.CreateExperimentRequest, owner
}
workspaceID := resolveWorkspaceID(workspaceModel)
isSingleNode := resources.IsSingleNode() != nil && *resources.IsSingleNode()
poolName, _, err := m.ResolveResources(resources.ResourcePool(), resources.SlotsPerTrial(), workspaceID, isSingleNode)
poolName, _, err := m.ResolveResources(
*resources.ResourceManager(), resources.ResourcePool(),
resources.SlotsPerTrial(), workspaceID, isSingleNode)
if err != nil {
return nil, nil, config, nil, nil, errors.Wrapf(err, "invalid resource configuration")
}
if err = m.rm.ValidateResources(poolName, resources.SlotsPerTrial(), isSingleNode); err != nil {

if err = rm.ValidateResources(
*resources.ResourceManager(),
poolName,
resources.SlotsPerTrial(),
isSingleNode,
); err != nil {
return nil, nil, config, nil, nil, errors.Wrapf(err, "error validating resources")
}
taskContainerDefaults, err := m.rm.TaskContainerDefaults(
taskContainerDefaults, err := rm.TaskContainerDefaults(
*resources.ResourceManager(),
poolName,
m.config.TaskContainerDefaults,
)
Expand Down
4 changes: 2 additions & 2 deletions master/internal/core_observability.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"github.com/labstack/echo/v4"

"github.com/determined-ai/determined/master/internal/prom"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/master/internal/rm"
)

func (m *Master) getPrometheusTargets(c echo.Context) (interface{}, error) {
resp, err := m.rm.GetAgents(&apiv1.GetAgentsRequest{})
resp, err := rm.GetAgents()
if err != nil {
return nil, fmt.Errorf("gather agent statuses: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion master/internal/core_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"github.com/determined-ai/determined/master/internal/authz"
"github.com/determined-ai/determined/master/internal/context"
expauth "github.com/determined-ai/determined/master/internal/experiment"
"github.com/determined-ai/determined/master/internal/rm"
)

func (m *Master) getTasks(c echo.Context) (interface{}, error) {
summary, err := m.rm.GetAllocationSummaries()
summary, err := rm.GetAllocationSummaries()
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit f59b7bd

Please sign in to comment.