Skip to content

Commit

Permalink
copy cache directly
Browse files Browse the repository at this point in the history
  • Loading branch information
kkunapuli committed Jun 25, 2024
1 parent d1ce157 commit f004433
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
8 changes: 5 additions & 3 deletions master/internal/rm/kubernetesrm/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/rm"
"github.com/determined-ai/determined/master/internal/rm/rmevents"
"github.com/determined-ai/determined/master/internal/sproto"
"github.com/determined-ai/determined/master/pkg/cproto"
Expand Down Expand Up @@ -945,7 +946,7 @@ func (j *jobsService) refreshPodStates(allocationID model.AllocationID) error {
return nil
}

func (j *jobsService) GetAgents() *apiv1.GetAgentsResponse {
func (j *jobsService) GetAgents() (*apiv1.GetAgentsResponse, error) {
j.mu.Lock()
defer j.mu.Unlock()
return j.getAgents()
Expand Down Expand Up @@ -1598,7 +1599,7 @@ func (j *jobsService) getSlot(agentID string, slotID string) *apiv1.GetSlotRespo

const getAgentsCacheDuration = 15 * time.Second

func (j *jobsService) getAgents() *apiv1.GetAgentsResponse {
func (j *jobsService) getAgents() (*apiv1.GetAgentsResponse, error) {
j.getAgentsCacheLock.Lock()
defer j.getAgentsCacheLock.Unlock()

Expand All @@ -1615,7 +1616,8 @@ func (j *jobsService) getAgents() *apiv1.GetAgentsResponse {
}
}

return j.getAgentsCache
// Ensure cached response is not inadvertently modified.
return rm.CopyGetAgentsResponse(j.getAgentsCache)
}

func (j *jobsService) getAgent(agentID string) *apiv1.GetAgentResponse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ func (k *ResourceManager) GetAgent(msg *apiv1.GetAgentRequest) (*apiv1.GetAgentR

// GetAgents implements rm.ResourceManager.
func (k *ResourceManager) GetAgents() (*apiv1.GetAgentsResponse, error) {
resp := k.jobsService.GetAgents()
// Ensure cached response is not inadvertently modified.
return rm.CopyGetAgentsResponse(resp)
return k.jobsService.GetAgents()
}

// GetAllocationSummaries implements rm.ResourceManager.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ func TestGetAgents(t *testing.T) {

for _, test := range agentsTests {
t.Run(test.Name, func(t *testing.T) {
agentsResp := test.jobsService.getAgents()
agentsResp, err := test.jobsService.getAgents()
require.NoError(t, err)
require.Equal(t, len(test.wantedAgentIDs), len(agentsResp.Agents))
for _, agent := range agentsResp.Agents {
_, ok := test.wantedAgentIDs[agent.Id]
Expand Down Expand Up @@ -243,7 +244,8 @@ func TestGetAgentsNodeSelectors(t *testing.T) {
}},
}}

agentsResp := js.getAgents()
agentsResp, err := js.getAgents()
require.NoError(t, err)
require.Equal(t, len(test.agentsMatched), len(agentsResp.Agents))

for _, agent := range agentsResp.Agents {
Expand Down Expand Up @@ -863,7 +865,7 @@ func TestRMValidateResources(t *testing.T) {

func testROCMGetAgents() {
ps := createMockJobsService(createCompNodeMap(), device.ROCM, false)
ps.getAgents()
ps.getAgents() // nolint
}

func testROCMGetAgent() {
Expand Down
12 changes: 8 additions & 4 deletions master/internal/rm/kubernetesrm/resource_pool_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,9 @@ func TestPartialJobsShowQueuedStates(t *testing.T) {
require.NoError(t, err)

var slots int
for _, n := range rp.jobsService.GetAgents().Agents {
agents, err := rp.jobsService.GetAgents()
require.NoError(t, err)
for _, n := range agents.Agents {
slots += len(n.Slots)
}

Expand Down Expand Up @@ -944,11 +946,12 @@ func TestNodeWorkflows(t *testing.T) {
j := newTestJobsService(t)
rp := newTestResourcePool(j)

resp := j.getAgents()
resp, err := j.getAgents()
require.NoError(t, err)
require.Len(t, resp.Agents, 1)
nodeID := resp.Agents[0].Id

_, err := rp.jobsService.DisableAgent(&apiv1.DisableAgentRequest{AgentId: nodeID})
_, err = rp.jobsService.DisableAgent(&apiv1.DisableAgentRequest{AgentId: nodeID})
defer func() {
// Ensure we re-enable the agent, otherwise failures in this test will break others.
_, err := rp.jobsService.EnableAgent(&apiv1.EnableAgentRequest{AgentId: nodeID})
Expand All @@ -963,7 +966,8 @@ func TestNodeWorkflows(t *testing.T) {
j.getAgentsCacheTime = j.getAgentsCacheTime.Add(-time.Hour)
j.mu.Unlock()

resp = j.GetAgents()
resp, err := j.GetAgents()
require.NoError(t, err)
require.Len(t, resp.Agents, 1)
return !resp.Agents[0].Enabled
}), "GetAgents didn't say %s is disabled, but we just disabled it", nodeID)
Expand Down

0 comments on commit f004433

Please sign in to comment.