diff --git a/docs/release-notes/pods-to-jobs.rst b/docs/release-notes/pods-to-jobs.rst
new file mode 100644
index 00000000000..ef6a026cc74
--- /dev/null
+++ b/docs/release-notes/pods-to-jobs.rst
@@ -0,0 +1,10 @@
+:orphan:
+
+**New Features**
+
+- Kubernetes: The system now launches Kubernetes jobs on behalf of users when they submit workloads
+  to Determined, instead of launching Kubernetes pods. This change allows Determined to work
+  properly with other Kubernetes features like resource quotas.
+
+  As a result, permissions are now required to create, get, list, delete, and watch Kubernetes job
+  resources.
diff --git a/docs/setup-cluster/k8s/_index.rst b/docs/setup-cluster/k8s/_index.rst
index 0cc7e67aa0f..694234e86ad 100644
--- a/docs/setup-cluster/k8s/_index.rst
+++ b/docs/setup-cluster/k8s/_index.rst
@@ -23,9 +23,10 @@ Determined master and a Postgres database in the Kubernetes cluster. Once the ma
 running, you can launch :ref:`experiments `, :ref:`notebooks `, :ref:`TensorBoards `,
 :ref:`commands `, and :ref:`shells `. When new workloads are submitted to the Determined master, the master
-launches pods and configMaps on the Kubernetes cluster to execute those workloads. Users of
+launches jobs and config maps on the Kubernetes cluster to execute those workloads. Users of
 Determined shouldn't need to interact with Kubernetes directly after installation, as Determined
-handles all the necessary interaction with the Kubernetes cluster.
+handles all the necessary interaction with the Kubernetes cluster. Kubernetes creates and cleans up
+the pods for every job that Determined requests.
 
 It is also important to note that when running Determined on Kubernetes, a higher priority value
 means a higher priority (e.g. a priority 50 task will run before a priority 40 task). This is
@@ -138,20 +139,6 @@ for diagnosing any issues that arise during installation.
    # Get logs for the pod running the Determined master.
    kubectl logs 
 
-Get All Running Task Pods
-=========================
-
-These ``kubectl`` commands list and delete pods which are running Determined tasks:
-
-.. code:: bash
-
-   # Get all pods that are running Determined tasks.
-   kubectl get pods -l=determined
-
-   # Delete all Determined task pods. Users should never have to run this,
-   # unless they are removing a deployment of Determined.
-   kubectl get pods --no-headers=true -l=determined | awk '{print $1}' | xargs kubectl delete pod
-
 .. toctree::
    :maxdepth: 1
    :hidden:
diff --git a/docs/setup-cluster/k8s/custom-pod-specs.rst b/docs/setup-cluster/k8s/custom-pod-specs.rst
index bd06baa99d1..6fb09db785a 100644
--- a/docs/setup-cluster/k8s/custom-pod-specs.rst
+++ b/docs/setup-cluster/k8s/custom-pod-specs.rst
@@ -5,8 +5,8 @@
 #################
 
 In a :ref:`Determined cluster running on Kubernetes `, tasks (e.g.,
-experiments, notebooks) are executed by launching one or more Kubernetes pods. You can customize
-these pods by providing custom `pod specs
+experiments, notebooks) are executed by launching a Kubernetes job. Each job launches one or more
+Kubernetes pods. You can customize these pods by providing custom `pod specs
 `__. Common use cases include assigning pods to specific nodes,
 specifying additional volume mounts, and attaching permissions. Configuring pod specs is not
 required to use Determined on Kubernetes.
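A minimal sketch of the kind of customization described above, assuming the experiment configuration exposes the pod spec under ``environment.pod_spec`` and that the task container is named ``determined-container``; the node selector and volume names below are illustrative placeholders, not taken from this change:

.. code:: yaml

   environment:
     pod_spec:
       apiVersion: v1
       kind: Pod
       spec:
         nodeSelector:
           disktype: ssd
         volumes:
           - name: extra-config
             configMap:
               name: extra-config
         containers:
           - name: determined-container
             volumeMounts:
               - name: extra-config
                 mountPath: /etc/extra-config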
diff --git a/docs/setup-cluster/k8s/helm-commands.rst b/docs/setup-cluster/k8s/helm-commands.rst
index 228c8f9885e..3f0d52a88e9 100644
--- a/docs/setup-cluster/k8s/helm-commands.rst
+++ b/docs/setup-cluster/k8s/helm-commands.rst
@@ -76,20 +76,31 @@ for diagnosing any issues that arise during installation.
    # Get logs for the pod running the Determined master.
    kubectl logs 
 
-***************************
- Get All Running Task Pods
-***************************
+*********************************************
+ Get All Determined-launched Kubernetes Jobs
+*********************************************
 
-These ``kubectl`` commands list and delete pods which are running Determined tasks:
+When running Determined on Kubernetes, each task starts its own Kubernetes job, which has associated
+pods. These ``kubectl`` commands list and delete the jobs and pods that are running Determined tasks:
 
 .. code:: bash
 
-   # Get all pods that are running Determined tasks.
-   kubectl get pods -l=determined
+   # Get all jobs that are running Determined tasks.
+   kubectl get jobs -l=determined
+
+   # Get all pods associated with a given job.
+   kubectl get pods -l="batch.kubernetes.io/job-name=<job-name>"
 
-   # Delete all Determined task pods. Users should never have to run this,
+   # Delete all Determined task jobs, for ALL Determined clusters. Users should never have to run this,
    # unless they are removing a deployment of Determined.
-   kubectl get pods --no-headers=true -l=determined | awk '{print $1}' | xargs kubectl delete pod
+   kubectl get jobs --no-headers=true -l=determined | awk '{print $1}' | xargs kubectl delete jobs
+
+   # Get logs for a Determined task that made it to STDOUT or STDERR. Most logs are shipped to the
+   # Determined API server, but logs that can't be shipped still go here. This is useful for debugging
+   # log shipping failures.
+   # For Determined tasks that require multiple pods, this will return logs for only one pod. It is
+   # recommended that you search the logs for each pod individually.
+   kubectl logs jobs/<job-name>
 
 ***************************
  Useful Debugging Commands
diff --git a/harness/determined/exec/prep_container.py b/harness/determined/exec/prep_container.py
index e56686f028e..7f98e1497a8 100644
--- a/harness/determined/exec/prep_container.py
+++ b/harness/determined/exec/prep_container.py
@@ -111,6 +111,60 @@ def do_rendezvous_slurm(
     )
 
 
+def do_rendezvous_kubernetes(
+    sess: api.Session,
+    allocation_id: str,
+    resources_id: str,
+) -> "det.RendezvousInfo":
+    job_parallelism_str = os.environ.get("DET_KUBERNETES_JOB_PARALLELISM")
+    assert job_parallelism_str, "Unable to rendezvous without DET_KUBERNETES_JOB_PARALLELISM"
+    job_parallelism = int(job_parallelism_str)
+
+    pod_ip_str = os.environ.get("DET_KUBERNETES_POD_IP")
+    assert pod_ip_str, "Unable to rendezvous without DET_KUBERNETES_POD_IP"
+
+    num_slots_str = os.environ.get("DET_SLOT_IDS")
+    assert num_slots_str, "Unable to rendezvous without DET_SLOT_IDS"
+    num_slots = len(json.loads(os.environ["DET_SLOT_IDS"]))
+
+    request_uuid = str(uuid.uuid4())
+    resp = bindings.post_AllocationAllGather(
+        sess,
+        allocationId=allocation_id,
+        body=bindings.v1AllocationAllGatherRequest(
+            allocationId=allocation_id,
+            requestUuid=request_uuid,
+            numPeers=job_parallelism,
+            data={
+                # We use the lexicographical order of request IDs to
+                # agree on ranks among peers, so they all need it.
+                "request_uuid": request_uuid,
+                "rendezvous_ip": pod_ip_str,
+                "slots": num_slots,
+            },
+        ),
+    )
+
+    # TODO(RM-306): Use indexed completions and JOB_COMPLETION_INDEX to get pod rank.
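+    # Note on rank agreement: every peer receives the same all-gathered data and sorts it by
+    # request_uuid, so taking the index of our own UUID in that sorted order yields a rank that
+    # all peers agree on without further coordination (e.g., with UUIDs ["f3...", "0a..."], the
+    # pod that sent "0a..." computes rank 0 on every peer).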
+ data_by_rank = [] + our_rank = None + for i, d in enumerate(sorted(resp.data, key=lambda d: str(d["request_uuid"]))): + if d["request_uuid"] == request_uuid: + our_rank = i + data_by_rank.append(d) + assert our_rank is not None, "rendezvous was missing our own information" + assert len(data_by_rank) == job_parallelism, "didn't receive enough peers from rendezvous" + + addrs = [d["rendezvous_ip"] for d in data_by_rank] + slots = [d["slots"] for d in data_by_rank] + + return det.RendezvousInfo( + container_addrs=addrs, + container_rank=our_rank, + container_slot_counts=slots, + ) + + # On HPC, the "launcher" tells the Determined Master that the job is "Running" # as soon as the workload manager (e.g., Slurm, PBS, etc) starts running the job. # However, if the container is not already cached on the compute node, it will @@ -194,7 +248,7 @@ def get_eth_interface_name() -> Optional[str]: # The canonical definitions of these consts live in Go code. -RESOURCES_TYPE_K8S_POD = "k8s-pod" +RESOURCES_TYPE_K8S_JOB = "k8s-job" RESOURCES_TYPE_DOCKER_CONTAINER = "docker-container" RESOURCES_TYPE_SLURM_JOB = "slurm-job" @@ -207,10 +261,12 @@ def do_rendezvous(sess: api.Session, allocation_id: str) -> None: assert r_type, "Unable to complete rendezvous info without DET_RESOURCES_TYPE" rendezvous_info = None - if r_type == RESOURCES_TYPE_DOCKER_CONTAINER or r_type == RESOURCES_TYPE_K8S_POD: + if r_type == RESOURCES_TYPE_DOCKER_CONTAINER: rendezvous_info = do_rendezvous_rm_provided(sess, allocation_id, r_id) elif r_type == RESOURCES_TYPE_SLURM_JOB: rendezvous_info = do_rendezvous_slurm(sess, allocation_id, r_id) + elif r_type == RESOURCES_TYPE_K8S_JOB: + rendezvous_info = do_rendezvous_kubernetes(sess, allocation_id, r_id) else: raise ValueError(f"unsupported resources type: {r_type}") @@ -251,14 +307,29 @@ def set_proxy_address(sess: api.Session, allocation_id: str) -> None: ) +def set_proxy_address_kubernetes(sess: api.Session, allocation_id: str) -> None: + pod_ip_str = os.environ.get("DET_KUBERNETES_POD_IP") + assert pod_ip_str, "Unable to complete rendezvous without DET_KUBERNETES_POD_IP" + + bindings.post_PostAllocationProxyAddress( + sess, + allocationId=allocation_id, + body=bindings.v1PostAllocationProxyAddressRequest( + proxyAddress=pod_ip_str, + ), + ) + + def do_proxy(sess: api.Session, allocation_id: str) -> None: r_type = os.environ.get("DET_RESOURCES_TYPE") assert r_type, "Unable to complete rendezvous info without DET_RESOURCES_TYPE" - if r_type == RESOURCES_TYPE_DOCKER_CONTAINER or r_type == RESOURCES_TYPE_K8S_POD: + if r_type == RESOURCES_TYPE_DOCKER_CONTAINER: return elif r_type == RESOURCES_TYPE_SLURM_JOB: set_proxy_address(sess, allocation_id) + elif r_type == RESOURCES_TYPE_K8S_JOB: + set_proxy_address_kubernetes(sess, allocation_id) else: raise ValueError(f"unsupported resources type: {r_type}") diff --git a/helm/charts/determined/templates/master-permissions.yaml b/helm/charts/determined/templates/master-permissions.yaml index dd3acf603bc..2b032d1229a 100644 --- a/helm/charts/determined/templates/master-permissions.yaml +++ b/helm/charts/determined/templates/master-permissions.yaml @@ -35,6 +35,9 @@ rules: - apiGroups: [""] resources: ["nodes"] verbs: ["list", "watch", "patch"] + - apiGroups: ["batch"] + resources: ["jobs"] + verbs: ["create", "get", "list", "delete", "watch"] --- diff --git a/master/.gitignore b/master/.gitignore index 6ae53e8dfa0..b8c045e5868 100644 --- a/master/.gitignore +++ b/master/.gitignore @@ -20,7 +20,7 @@ vendor/ # Test binary, build with `go test -c` 
*.test -# Output of the go coverage tool, specifically when used with LiteIDE +# Output of the go coverage tool *.out # VS Code Workspace diff --git a/master/Makefile b/master/Makefile index 4f5c48dad8a..9901b545863 100644 --- a/master/Makefile +++ b/master/Makefile @@ -5,7 +5,7 @@ SCHEMA_INPUTS = ../schemas/gen.py $(shell find ./pkg/schemas/ -name 'zgen_*.go' STREAM_INPUTS = $(shell find ./internal/stream/ -name '*_test.go' -prune -o -name '*.go' -print) STREAM_PYTHON_CLIENT = ../harness/determined/common/streams/wire.py STREAM_TS_CLIENT = ../webui/react/src/services/stream/wire.ts -MOCK_INPUTS = ./internal/sproto/task.go ./internal/db/database.go ./internal/command/authz_iface.go ../go.mod ../go.sum ./internal/rm/resource_manager_iface.go ./internal/task/allocation_service_iface.go +MOCK_INPUTS = Makefile ./internal/sproto/task.go ./internal/db/database.go ./internal/command/authz_iface.go ../go.mod ../go.sum ./internal/rm/resource_manager_iface.go ./internal/task/allocation_service_iface.go GORELEASER = goreleaser export VERSION := $(shell cat ../VERSION) @@ -99,11 +99,13 @@ build/mock_gen.stamp: $(MOCK_INPUTS) mockery --quiet --name=PodInterface --srcpkg=k8s.io/client-go/kubernetes/typed/core/v1 --output internal/mocks --filename pod_iface.go mockery --quiet --name=EventInterface --srcpkg=k8s.io/client-go/kubernetes/typed/core/v1 --output internal/mocks --filename event_iface.go mockery --quiet --name=NodeInterface --srcpkg=k8s.io/client-go/kubernetes/typed/core/v1 --output internal/mocks --filename node_iface.go + mockery --quiet --name=JobInterface --srcpkg=k8s.io/client-go/kubernetes/typed/batch/v1 --output internal/mocks --filename job_iface.go mockery --quiet --name=ResourceManager --dir=internal/rm --output internal/mocks --filename rm.go mockery --quiet --name=AllocationService --dir=internal/task --output internal/mocks/allocationmocks --filename allocation_service.go --outpkg allocationmocks mockery --quiet --name=ResourceManagerAuthZ --dir=internal/rm --output internal/mocks --filename rm_authz_iface.go mockery --quiet --name=Interface --output internal/mocks --srcpkg "k8s.io/client-go/kubernetes" --filename k8s_clientset.go --structname K8sClientsetInterface mockery --quiet --name=CoreV1Interface --output internal/mocks --srcpkg "k8s.io/client-go/kubernetes/typed/core/v1" --filename k8s_corev1_iface.go --structname K8sCoreV1Interface + mockery --quiet --name=BatchV1Interface --output internal/mocks --srcpkg "k8s.io/client-go/kubernetes/typed/batch/v1" --filename k8s_batchv1_iface.go --structname K8sBatchV1Interface mkdir -p build touch $@ diff --git a/master/internal/db/postgres_agent_intg_test.go b/master/internal/db/postgres_agent_intg_test.go index 3bd0f68ff06..963a8b08e0e 100644 --- a/master/internal/db/postgres_agent_intg_test.go +++ b/master/internal/db/postgres_agent_intg_test.go @@ -112,7 +112,7 @@ func TestEndAllAgentStats(t *testing.T) { setTimesTo(a0, a0Start, nil) // Cluster heartbeat between these. - // TODO(!!!) make cluster heartbeat a timestamptz. + // TODO(nickb): make cluster heartbeat a timestamptz. 
_, err := db.GetOrCreateClusterID("") require.NoError(t, err) heartBeatTime := time.Date(2021, 10, 10, 0, 0, 0, 0, time.Local).Truncate(time.Millisecond) diff --git a/master/internal/db/postgres_experiments_intg_test.go b/master/internal/db/postgres_experiments_intg_test.go index bb5d2564d97..3628177260d 100644 --- a/master/internal/db/postgres_experiments_intg_test.go +++ b/master/internal/db/postgres_experiments_intg_test.go @@ -437,7 +437,7 @@ func TestProjectHyperparameters(t *testing.T) { require.NoError(t, RemoveProjectHyperparameters(ctx, nil, []int32{int32(exp1.ID)})) - require.ElementsMatch(t, []string{}, // TODO(!!!) this is a bug in the query. + require.ElementsMatch(t, []string{}, // TODO(nickb): This is a bug in the query. RequireGetProjectHParams(t, db, projectID)) } diff --git a/master/internal/db/postgres_rp_workspace_bindings_intg_test.go b/master/internal/db/postgres_rp_workspace_bindings_intg_test.go index dabf4e80d4f..9107a666255 100644 --- a/master/internal/db/postgres_rp_workspace_bindings_intg_test.go +++ b/master/internal/db/postgres_rp_workspace_bindings_intg_test.go @@ -119,7 +119,7 @@ func TestGetDefaultPoolsForWorkspace(t *testing.T) { MustMigrateTestPostgres(t, pgDB, MigrationsFromDB) comp, aux, err := GetDefaultPoolsForWorkspace(ctx, -1) - require.NoError(t, err) // TODO(!!!) we should return errors for these cases. + require.NoError(t, err) // TODO(nickb): We should return errors for these cases. require.Equal(t, "", comp) require.Equal(t, "", aux) diff --git a/master/internal/rm/agentrm/resource_pool.go b/master/internal/rm/agentrm/resource_pool.go index 1a3acd18b7b..320383a7208 100644 --- a/master/internal/rm/agentrm/resource_pool.go +++ b/master/internal/rm/agentrm/resource_pool.go @@ -155,7 +155,7 @@ func (rp *resourcePool) allocateRequest(msg sproto.AllocateRequest) { log.WithError(err).Error("error restoring resources") // Clear out the state / close and terminate the allocation. 
- rmevents.Publish(msg.AllocationID, &sproto.ResourcesRestoreError{ + rmevents.Publish(msg.AllocationID, &sproto.ResourcesFailedError{ FailureType: sproto.RestoreError, ErrMsg: err.Error(), ExitCode: nil, diff --git a/master/internal/rm/kubernetesrm/informer.go b/master/internal/rm/kubernetesrm/informer.go index 65522ba5ef1..6aa605a48d0 100644 --- a/master/internal/rm/kubernetesrm/informer.go +++ b/master/internal/rm/kubernetesrm/informer.go @@ -88,7 +88,11 @@ func newEventInformer( "namespace": namespace, }) for i := range events.Items { - syslog.Debugf("informer added event: %s", events.Items[i].Name) + syslog.Debugf( + "informer added %s event: %s", + events.Items[i].InvolvedObject.Kind, + events.Items[i].Name, + ) cb(watch.Event{Object: &events.Items[i]}) } diff --git a/master/internal/rm/kubernetesrm/job.go b/master/internal/rm/kubernetesrm/job.go new file mode 100644 index 00000000000..fbb288bb8a7 --- /dev/null +++ b/master/internal/rm/kubernetesrm/job.go @@ -0,0 +1,708 @@ +package kubernetesrm + +import ( + "context" + "fmt" + "regexp" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/docker/docker/pkg/stdcopy" + "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + batchV1 "k8s.io/api/batch/v1" + k8sV1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + k8sClient "k8s.io/client-go/kubernetes" + typedV1 "k8s.io/client-go/kubernetes/typed/core/v1" + + "github.com/determined-ai/determined/master/internal/config" + "github.com/determined-ai/determined/master/internal/rm/rmevents" + "github.com/determined-ai/determined/master/internal/sproto" + "github.com/determined-ai/determined/master/pkg/aproto" + "github.com/determined-ai/determined/master/pkg/cproto" + "github.com/determined-ai/determined/master/pkg/device" + "github.com/determined-ai/determined/master/pkg/logger" + "github.com/determined-ai/determined/master/pkg/model" + "github.com/determined-ai/determined/master/pkg/set" + "github.com/determined-ai/determined/master/pkg/tasks" +) + +var successfulExit = exitReason{} + +// describes why a job failed. empty value indicates success. +type exitReason struct { + code int + msg string + failureType sproto.FailureType +} + +func (r *exitReason) String() string { + if isSuccessfulExit(r) { + return "success" + } + return fmt.Sprintf("%s code=%d type=%s", r.msg, r.code, r.failureType) +} + +type podNodeInfo struct { + nodeName string + numSlots int + slotType device.Type + container *cproto.Container +} + +// job manages the lifecycle of a Kubernetes Job that executes a +// Determined task. +type job struct { + // Configuration details. Set in initialization (the `newJob` constructor) and never modified after. + clusterID string + masterIP string + masterPort int32 + masterTLSConfig model.TLSClientConfig + jobName string + configMapName string + allocationID model.AllocationID + // req.State is mutated, we should change this. + req *sproto.AllocateRequest + // Kubernetes-specific request information. + namespace string + slotsPerPod int + numPods int + containerNames set.Set[string] + scheduler string + slotType device.Type + slotResourceRequests config.PodSlotResourceRequests + restore bool + + // System dependencies. Also set in initialization and never modified after. 
+ syslog *logrus.Entry + clientSet k8sClient.Interface + podInterface typedV1.PodInterface + configMapInterface typedV1.ConfigMapInterface + resourceRequestQueue *requestQueue + + // Internal state. Access should be protected. + mu sync.Mutex + podKillSent map[string]bool + podLogStreamerStarted map[string]bool + podNodeNames map[string]string + podStates map[string]cproto.State + podExits map[string]bool + jobExitCause *exitReason + sentStartingEvent bool + sentRunningEvent bool + sentTerminationEvent bool + // TODO(DET-10013) : Remove container field from pod struct. And get away from having several IDs, just use job name. + container cproto.Container + resourcesDeleted atomic.Bool +} + +func newJob( + name string, + msg startJob, + clusterID string, + clientSet k8sClient.Interface, + namespace string, + masterIP string, + masterPort int32, + masterTLSConfig model.TLSClientConfig, + podInterface typedV1.PodInterface, + configMapInterface typedV1.ConfigMapInterface, + resourceRequestQueue *requestQueue, + slotType device.Type, + slotResourceRequests config.PodSlotResourceRequests, + scheduler string, +) *job { + // The lifecycle of the containers specified in this map will be monitored. + // As soon as one or more of them exits, the pod will be terminated. + containerNames := set.FromSlice([]string{model.DeterminedK8ContainerName}) + + p := &job{ + req: msg.req, + clusterID: clusterID, + allocationID: msg.allocationID, + clientSet: clientSet, + namespace: namespace, + masterIP: masterIP, + masterPort: masterPort, + masterTLSConfig: masterTLSConfig, + numPods: msg.numPods, + slotsPerPod: msg.slots, + podInterface: podInterface, + configMapInterface: configMapInterface, + resourceRequestQueue: resourceRequestQueue, + jobName: name, + configMapName: name, + podNodeNames: make(map[string]string), + podStates: make(map[string]cproto.State), + podKillSent: make(map[string]bool), + podExits: make(map[string]bool), + podLogStreamerStarted: make(map[string]bool), + container: cproto.Container{ + ID: cproto.ID(msg.spec.ContainerID), + State: cproto.Assigned, + Description: msg.spec.Description, + }, + containerNames: containerNames, + scheduler: scheduler, + slotType: slotType, + slotResourceRequests: slotResourceRequests, + syslog: logrus.WithField("component", "job").WithFields( + logger.MergeContexts(msg.logContext, logger.Context{ + "job": name, + }).Fields(), + ), + } + return p +} + +func (j *job) finalize() { + j.mu.Lock() + defer j.mu.Unlock() + + // If an error occurred during the lifecycle of the pods, we need to update the scheduler + // and the task handler with new state. 
+ if j.container.State != cproto.Terminated { + j.kill() + j.syslog.Warnf("killed job after our handler exited unexpectedly") + j.container.State = cproto.Terminated + j.jobExitCause = &exitReason{ + failureType: sproto.TaskError, + msg: "job crashed", + } + j.informTaskResourcesStopped() + } +} + +func (j *job) exitCause() *sproto.ResourcesFailedError { + if isSuccessfulExit(j.jobExitCause) { + return nil + } + + failureType := j.jobExitCause.failureType + if failureType == "" { + failureType = sproto.ResourcesFailed + } + var exitCode *sproto.ExitCode + if j.jobExitCause.code > 0 { + exitCode = (*sproto.ExitCode)(&j.jobExitCause.code) + } + return &sproto.ResourcesFailedError{ + FailureType: failureType, + ErrMsg: j.jobExitCause.msg, + ExitCode: exitCode, + } +} + +func isSuccessfulExit(cause *exitReason) bool { + return cause == nil || *cause == successfulExit +} + +func (j *job) jobUpdatedCallback(updatedJob *batchV1.Job) (cproto.State, error) { + j.mu.Lock() + defer j.mu.Unlock() + + if j.container.State == cproto.Terminated { + return j.container.State, nil + } + + conds := updatedJob.Status.Conditions + if len(conds) == 0 { + return j.container.State, nil + } + + for _, cond := range conds { + if cond.Status != k8sV1.ConditionTrue { + continue + } + + switch cond.Type { + case batchV1.JobComplete: + if j.jobExitCause == nil { + j.jobExitCause = &successfulExit + } + j.syslog.Infof( + "job %s completed and transitioned from %s to %s", + updatedJob.Name, j.container.State, cproto.Terminated, + ) + j.container.State = cproto.Terminated + j.informTaskResourcesStopped() + return cproto.Terminated, nil + + case batchV1.JobFailed: + if j.jobExitCause == nil { + j.jobExitCause = &exitReason{msg: fmt.Sprintf( + "job exited with a failure but we don't have pod-level detail: %s", + cond.Message, + )} + } + j.syslog.Infof("job %s failed and transitioned from %s to %s", updatedJob.Name, j.container.State, cproto.Terminated) + j.container.State = cproto.Terminated + j.informTaskResourcesStopped() + return cproto.Terminated, nil + } + } + + return j.container.State, nil +} + +func (j *job) jobDeletedCallback() { + j.mu.Lock() + defer j.mu.Unlock() + + if j.container.State == cproto.Terminated { + return + } + + if j.jobExitCause == nil { + j.jobExitCause = &exitReason{msg: "job was deleted"} + } + j.syslog.Info("job deleted") + j.container.State = cproto.Terminated + j.informTaskResourcesStopped() +} + +func (j *job) podUpdatedCallback(updatedPod k8sV1.Pod) error { + j.mu.Lock() + defer j.mu.Unlock() + + podName := updatedPod.Name + updatedPodState, err := j.getPodState(updatedPod) + if err != nil { + return err + } + j.podStates[podName] = updatedPodState + + j.podNodeNames[podName] = updatedPod.Spec.NodeName + + // Jobs with pods in ImagePullBackOff get stuck (https://github.com/kubernetes/kubernetes/issues/101584). + for _, s := range append(updatedPod.Status.InitContainerStatuses, updatedPod.Status.ContainerStatuses...) { + // Only check for ImagePullBackOff, ErrImagePull could be an intermittent issue and we want to be sure. + // Waiting for backoff doesn't take very long. + if waiting := s.State.Waiting; waiting != nil && waiting.Reason == "ImagePullBackOff" { + j.jobExitCause = &exitReason{msg: "job was stuck due to unrecoverable image pull errors"} + j.syslog.WithField("detail", waiting.Message).Infof(j.jobExitCause.msg) + j.kill() + } + } + + allPodsAtLeastStarting := all(cproto.Starting.Before, maps.Values(j.podStates)...) 
+ if allPodsAtLeastStarting && !j.sentStartingEvent { + // Kubernetes does not have an explicit state for pulling container images. + // We insert it here because our current implementation of the trial actor requires it. + j.syslog.WithField("pod-name", podName).Info("pod is pulling images and starting") + j.container.State = cproto.Pulling + j.informTaskResourcesState() + + j.container.State = cproto.Starting + j.informTaskResourcesState() + j.sentStartingEvent = true + } + + if updatedPodState == cproto.Running && !j.podLogStreamerStarted[podName] { + err := startPodLogStreamer(j.podInterface, podName, func(log []byte) { + j.receiveContainerLog(sproto.ContainerLog{ + Timestamp: time.Now().UTC(), + RunMessage: &aproto.RunMessage{ + Value: string(log), + StdType: stdcopy.Stdout, + }, + }) + }) + if err != nil { + return fmt.Errorf("starting pod logs streamer for %s: %w", podName, err) + } + j.podLogStreamerStarted[podName] = true + } + + allPodsAtLeastRunning := all(cproto.Running.Before, maps.Values(j.podStates)...) + if allPodsAtLeastRunning && !j.sentRunningEvent { + j.syslog.WithField("pod-name", podName).Info("pod is running") + j.container.State = cproto.Running + j.informTaskResourcesStarted(sproto.ResourcesStarted{NativeResourcesID: j.jobName}) + j.sentRunningEvent = true + } + + if updatedPodState == cproto.Terminated && !j.podExits[podName] { + j.syslog.WithField("pod-name", podName).Info("pod is terminated") + exit, err := getExitCodeAndMessage(&updatedPod, j.containerNames) + if err != nil { + if updatedPod.ObjectMeta.DeletionTimestamp == nil { + return err + } + // When a pod is deleted, it is possible that it will exit before the + // determined containers generates an exit code. To check if this is + // the case we check if a deletion timestamp has been set. + exit = &exitReason{msg: "unable to get exit code or exit message from deleted pod"} + } + if !isSuccessfulExit(exit) { + if j.jobExitCause == nil { + j.jobExitCause = exit + } + j.syslog. + WithField("code", exit.code). + WithField("cause", j.jobExitCause). + Infof("detected a determined containers crashed, cleaning up job: %s", exit.msg) + j.killPod(podName) + } + j.podExits[podName] = true + } + + if len(j.podExits) == j.numPods { + if j.jobExitCause == nil { + // Explicitly mark this case as a success before we delete the job. + j.jobExitCause = &successfulExit + } + + j.syslog. + WithField("cause", j.jobExitCause). 
+ Infof("detected all determined containers exited, cleaning up job") + j.kill() + } + + return nil +} + +func (j *job) podDeletedCallback(deleted *k8sV1.Pod) { + j.mu.Lock() + defer j.mu.Unlock() + + j.syslog.WithField("pod-name", deleted.Name).Info("pod deleted") + if j.jobExitCause == nil { + j.jobExitCause = &exitReason{ + failureType: sproto.TaskError, + msg: fmt.Sprintf("pod %s deleted", deleted.Name), + } + } +} + +func (j *job) newEventCallback(event *k8sV1.Event) { + j.mu.Lock() + defer j.mu.Unlock() + + msgText := j.preparePodUpdateMessage(event.Message) + message := fmt.Sprintf("%s %s: %s", event.InvolvedObject.Kind, event.InvolvedObject.Name, msgText) + j.insertLog(event.CreationTimestamp.Time, message) +} + +func (j *job) preemptionCallback() { + j.syslog.Info("received preemption command") + rmevents.Publish(j.allocationID, &sproto.ReleaseResources{Reason: "preempted by the scheduler"}) +} + +func (j *job) changePriority() { + j.syslog.Info("interrupting job to change priorities") + rmevents.Publish(j.allocationID, &sproto.ReleaseResources{Reason: "priority changed"}) +} + +func (j *job) changePosition() { + j.syslog.Info("interrupting job to change positions") + rmevents.Publish(j.allocationID, &sproto.ReleaseResources{Reason: "queue position changed"}) +} + +func (j *job) Kill() { + j.mu.Lock() + defer j.mu.Unlock() + + j.syslog.Info("received request to stop job") + if j.jobExitCause == nil { + j.jobExitCause = &exitReason{msg: "killed"} + } + j.kill() +} + +func (j *job) kill() { + if !j.resourcesDeleted.CompareAndSwap(false, true) { + return + } + + j.syslog.Infof("requesting to delete kubernetes resources %s", j.jobName) + j.resourceRequestQueue.deleteKubernetesResources(j.namespace, j.jobName, j.configMapName, "") +} + +func (j *job) killPod(name string) { + if j.podKillSent[name] { + return + } + + j.syslog.Infof("requesting to delete kubernetes resources %s", j.jobName) + j.resourceRequestQueue.deleteKubernetesResources(j.namespace, "", "", name) + j.podKillSent[name] = true +} + +func (j *job) getNodeInfoForPods() []podNodeInfo { + j.mu.Lock() + defer j.mu.Unlock() + + var infos []podNodeInfo + for _, nodeName := range j.podNodeNames { + infos = append(infos, podNodeInfo{ + nodeName: nodeName, + numSlots: j.slotsPerPod, + slotType: j.slotType, + container: j.container.DeepCopy(), + }) + } + return infos +} + +func (j *job) startPodLogStreamers() error { + podList, err := j.podInterface.List(context.TODO(), metav1.ListOptions{ + LabelSelector: fmt.Sprintf("%s=%s", determinedLabel, j.req.AllocationID), + }) + if err != nil { + return fmt.Errorf("listing job pods to reattach log streamers: %w", err) + } + for _, pod := range podList.Items { + if pod.Status.Phase != k8sV1.PodRunning { + j.syslog.Warnf("skipped reattaching pod log streamer for pod %s in phase %s", pod.Name, pod.Status.Phase) + continue + } + + err := startPodLogStreamer(j.podInterface, pod.Name, func(log []byte) { + j.receiveContainerLog(sproto.ContainerLog{ + Timestamp: time.Now().UTC(), + RunMessage: &aproto.RunMessage{ + Value: string(log), + StdType: stdcopy.Stdout, + }, + }) + }) + if err != nil { + return fmt.Errorf("starting pod logs streamer for %s: %w", pod.Name, err) + } + } + return nil +} + +func (j *job) createSpecAndSubmit(spec *tasks.TaskSpec) error { + jobSpec, configMapSpec, err := j.createSpec(j.scheduler, spec) + if err != nil { + return err + } + + j.resourceRequestQueue.createKubernetesResources(jobSpec, configMapSpec) + return nil +} + +func (j *job) 
receiveResourceCreationFailed(msg resourceCreationFailed) { + j.syslog.WithError(msg.err).Error("pod handler notified that resource creation failed") + j.insertLog(time.Now().UTC(), msg.err.Error()) +} + +func (j *job) receiveResourceCreationCancelled() { + j.syslog.Info("pod creation canceled") + j.resourcesDeleted.Store(true) +} + +func (j *job) receiveResourceDeletionFailed(msg resourceDeletionFailed) { + j.syslog.WithError(msg.err).Error("pod handler notified that resource deletion failed") +} + +func (j *job) informTaskResourcesState() { + rmevents.Publish(j.allocationID, &sproto.ResourcesStateChanged{ + ResourcesID: sproto.FromContainerID(j.container.ID), + ResourcesState: sproto.FromContainerState(j.container.State), + Container: j.container.DeepCopy(), + }) +} + +func (j *job) informTaskResourcesStarted(rs sproto.ResourcesStarted) { + rmevents.Publish(j.allocationID, &sproto.ResourcesStateChanged{ + ResourcesID: sproto.FromContainerID(j.container.ID), + ResourcesState: sproto.FromContainerState(j.container.State), + ResourcesStarted: &rs, + Container: j.container.DeepCopy(), + }) +} + +func (j *job) informTaskResourcesStopped() { + if j.sentTerminationEvent { + return + } + + rmevents.Publish(j.allocationID, &sproto.ResourcesStateChanged{ + ResourcesID: sproto.FromContainerID(j.container.ID), + ResourcesState: sproto.FromContainerState(j.container.State), + ResourcesStopped: &sproto.ResourcesStopped{Failure: j.exitCause()}, + Container: j.container.DeepCopy(), + }) + j.sentTerminationEvent = true +} + +func (j *job) receiveContainerLog(msg sproto.ContainerLog) { + msg.ContainerID = j.container.ID + rmevents.Publish(j.allocationID, &msg) +} + +func (j *job) insertLog(timestamp time.Time, msg string) { + j.receiveContainerLog(sproto.ContainerLog{ + Timestamp: timestamp, + AuxMessage: &msg, + }) +} + +// Converts k8s message to be more understandable. +func (j *job) preparePodUpdateMessage(msgText string) string { + // Handle simple message replacements. + replacements := map[string]string{ + "pod triggered scale-up": "Job requires additional resources, scaling up cluster.", + "Successfully assigned": "Pod resources allocated.", + "skip schedule deleting pod": "Deleting unscheduled pod.", + } + + simpleReplacement := false + + for k, v := range replacements { + matched, err := regexp.MatchString(k, msgText) + if err != nil { + break + } else if matched { + msgText = v + simpleReplacement = true + } + } + + // Otherwise, try special treatment for slots availability message. + if !simpleReplacement { + matched, err := regexp.MatchString("nodes are available", msgText) + if err == nil && matched { + available := string(msgText[0]) + required := strconv.Itoa(j.slotsPerPod) + var resourceName string + switch j.slotType { + case device.CPU: + resourceName = "CPU slots" + default: + resourceName = "GPUs" + } + + msgText = fmt.Sprintf("Waiting for resources. %s %s are available, %s %s required", + available, resourceName, required, resourceName) + } + } + + return msgText +} + +func (j *job) getPodState(pod k8sV1.Pod) (cproto.State, error) { + switch pod.Status.Phase { + case k8sV1.PodPending: + // When pods are deleted, Kubernetes sometimes transitions pod statuses to pending + // prior to deleting them. In these cases we have observed that we do not always + // receive a PodFailed or a PodSucceeded message. We check if pods have a set pod + // deletion timestamp to see if this is the case. 
+ if pod.ObjectMeta.DeletionTimestamp != nil { + j.syslog.Warn("marking pod as terminated due to deletion timestamp") + return cproto.Terminated, nil + } + + for _, condition := range pod.Status.Conditions { + if condition.Type == k8sV1.PodScheduled && condition.Status == k8sV1.ConditionTrue { + return cproto.Starting, nil + } + } + return cproto.Assigned, nil + + case k8sV1.PodRunning: + // Pods are in a running state as long as at least one container has not terminated. + // We check the status of the Determined containers directly to determine if they + // are still running. + containerStatuses, err := getDeterminedContainersStatus( + pod.Status.ContainerStatuses, j.containerNames) + if err != nil { + return "", err + } + + for _, containerStatus := range containerStatuses { + if containerStatus.State.Terminated != nil { + return cproto.Terminated, nil + } + } + + for _, containerStatus := range containerStatuses { + // Check that all Determined containers are running. + if containerStatus.State.Running == nil { + return cproto.Starting, nil + } + } + + return cproto.Running, nil + + case k8sV1.PodFailed, k8sV1.PodSucceeded: + return cproto.Terminated, nil + + default: + return "", fmt.Errorf("unexpected pod status %s for pod %s", pod.Status.Phase, pod.Name) + } +} + +func getExitCodeAndMessage(pod *k8sV1.Pod, containerNames set.Set[string]) (*exitReason, error) { + if len(pod.Status.InitContainerStatuses) == 0 { + return nil, fmt.Errorf("unexpected number of init containers when processing exit code for pod %s", pod.Name) + } + + for _, initContainerStatus := range pod.Status.InitContainerStatuses { + if initContainerStatus.State.Terminated == nil { + continue + } + exitCode := initContainerStatus.State.Terminated.ExitCode + if exitCode != aproto.SuccessExitCode { + errMessage := fmt.Sprintf( + "container %s: %s", initContainerStatus.Name, + initContainerStatus.State.Terminated.Message, + ) + return &exitReason{ + code: int(exitCode), + msg: errMessage, + }, nil + } + } + + if len(pod.Status.ContainerStatuses) < len(containerNames) { + return nil, fmt.Errorf("unexpected number of containers when processing exit code for pod %s", pod.Name) + } + + containerStatuses, err := getDeterminedContainersStatus( + pod.Status.ContainerStatuses, containerNames) + if err != nil { + return nil, err + } + + for _, containerStatus := range containerStatuses { + terminationStatus := containerStatus.State.Terminated + if terminationStatus != nil { + return &exitReason{ + code: int(terminationStatus.ExitCode), + msg: terminationStatus.Message, + }, nil + } + } + + return nil, fmt.Errorf("unable to get exit code from pod %s", pod.Name) +} + +func getDeterminedContainersStatus( + statuses []k8sV1.ContainerStatus, + containerNames set.Set[string], +) ([]*k8sV1.ContainerStatus, error) { + containerStatuses := make([]*k8sV1.ContainerStatus, 0, len(statuses)) + for idx, containerStatus := range statuses { + if !containerNames.Contains(containerStatus.Name) { + continue + } + containerStatuses = append(containerStatuses, &statuses[idx]) + } + + if len(containerStatuses) != len(containerNames) { + containerNamesFound := make([]string, 0, len(containerStatuses)) + for _, containerStatus := range containerStatuses { + containerNamesFound = append(containerNamesFound, containerStatus.Name) + } + return nil, fmt.Errorf("found container statuses only for: %v", containerNamesFound) + } + + return containerStatuses, nil +} diff --git a/master/internal/rm/kubernetesrm/job_test.go 
b/master/internal/rm/kubernetesrm/job_test.go
new file mode 100644
index 00000000000..56df9230ce4
--- /dev/null
+++ b/master/internal/rm/kubernetesrm/job_test.go
@@ -0,0 +1 @@
+package kubernetesrm
diff --git a/master/internal/rm/kubernetesrm/jobs.go b/master/internal/rm/kubernetesrm/jobs.go
new file mode 100644
index 00000000000..63f435de3c2
--- /dev/null
+++ b/master/internal/rm/kubernetesrm/jobs.go
@@ -0,0 +1,1859 @@
+package kubernetesrm
+
+import (
+    "context"
+    "encoding/json"
+    "fmt"
+    "os"
+    "path/filepath"
+    "strconv"
+    "strings"
+    "sync"
+    "time"
+
+    "github.com/pkg/errors"
+    "github.com/sirupsen/logrus"
+    "golang.org/x/exp/maps"
+    batchV1 "k8s.io/api/batch/v1"
+    k8sV1 "k8s.io/api/core/v1"
+    k8error "k8s.io/apimachinery/pkg/api/errors"
+    "k8s.io/apimachinery/pkg/api/resource"
+    metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+    "k8s.io/apimachinery/pkg/types"
+    "k8s.io/apimachinery/pkg/watch"
+    "k8s.io/client-go/informers"
+    k8sClient "k8s.io/client-go/kubernetes"
+    typedBatchV1 "k8s.io/client-go/kubernetes/typed/batch/v1"
+    typedV1 "k8s.io/client-go/kubernetes/typed/core/v1"
+    "k8s.io/client-go/rest"
+    "k8s.io/client-go/tools/cache"
+    "k8s.io/client-go/tools/clientcmd"
+    "k8s.io/client-go/util/homedir"
+
+    "github.com/determined-ai/determined/master/internal/config"
+    "github.com/determined-ai/determined/master/internal/db"
+    "github.com/determined-ai/determined/master/internal/rm/rmevents"
+    "github.com/determined-ai/determined/master/internal/sproto"
+    "github.com/determined-ai/determined/master/pkg/cproto"
+    "github.com/determined-ai/determined/master/pkg/device"
+    "github.com/determined-ai/determined/master/pkg/logger"
+    "github.com/determined-ai/determined/master/pkg/model"
+    "github.com/determined-ai/determined/master/pkg/set"
+    "github.com/determined-ai/determined/master/pkg/syncx/waitgroupx"
+    "github.com/determined-ai/determined/master/pkg/tasks"
+    "github.com/determined-ai/determined/proto/pkg/apiv1"
+
+    // Used to load all auth plugins.
+    _ "k8s.io/client-go/plugin/pkg/client/auth"
+)
+
+const (
+    determinedLabel           = "determined"
+    determinedPreemptionLabel = "determined-preemption"
+    determinedSystemLabel     = "determined-system"
+
+    kubernetesJobNameLabel = "batch.kubernetes.io/job-name"
+
+    resourceTypeNvidia = "nvidia.com/gpu"
+)
+
+type summarizeResult struct {
+    summary map[string]model.AgentSummary
+    err     error
+}
+
+type jobMetadata struct {
+    jobName      string
+    allocationID model.AllocationID
+}
+
+// High-level overview of the actors within the kubernetes package:
+//
+// jobsService
+// +- job(s): manages pod lifecycle. One per container in a task.
+// +- podLogStreamer: streams logs for a specific pod.
+// +- informer: sends updates about pod states.
+// +- events: sends updates about kubernetes events.
+// +- requestQueue: queues requests to create / delete kubernetes resources.
+// +- requestProcessingWorkers: processes requests to create / delete kubernetes resources.
+type jobsService struct {
+    // Configuration details. Set in initialization (the `newJobsService` constructor) and never modified after.
+ namespace string + namespaceToPoolName map[string]string + scheduler string + slotType device.Type + slotResourceRequests config.PodSlotResourceRequests + resourcePoolConfigs []config.ResourcePoolConfig + baseContainerDefaults *model.TaskContainerDefaultsConfig + masterServiceName string + masterTLSConfig model.TLSClientConfig + detMasterIP string + detMasterPort int32 + kubeconfigPath string + + // System dependencies. Also set in initialization and never modified after. + syslog *logrus.Entry + clientSet k8sClient.Interface + podInterfaces map[string]typedV1.PodInterface + configMapInterfaces map[string]typedV1.ConfigMapInterface + jobInterfaces map[string]typedBatchV1.JobInterface + resourceRequestQueue *requestQueue + jobSchedulingStateCallback jobSchedulingStateCallback + + // Internal state. Access should be protected. + wg waitgroupx.Group + mu sync.RWMutex + jobNameToJobHandler map[string]*job + jobNameToResourcePool map[string]string + jobNameToPodNameToSchedulingState map[string]map[string]sproto.SchedulingState + allocationIDToJobName map[model.AllocationID]string + jobHandlerToMetadata map[*job]jobMetadata + nodeToSystemResourceRequests map[string]int64 + currentNodes map[string]*k8sV1.Node + // TODO(RM-236) make one cache and make this code more straightforward. + summarizeCacheLock sync.RWMutex + summarizeCache summarizeResult + summarizeCacheTime time.Time + getAgentsCacheLock sync.Mutex + getAgentsCache *apiv1.GetAgentsResponse + getAgentsCacheTime time.Time +} + +// newJobsService creates a new pod service for launching, querying and interacting with k8s pods. +func newJobsService( + namespace string, + namespaceToPoolName map[string]string, + masterServiceName string, + masterTLSConfig model.TLSClientConfig, + scheduler string, + slotType device.Type, + slotResourceRequests config.PodSlotResourceRequests, + resourcePoolConfigs []config.ResourcePoolConfig, + taskContainerDefaults *model.TaskContainerDefaultsConfig, + detMasterIP string, + detMasterPort int32, + kubeconfigPath string, + jobSchedulingStateCb jobSchedulingStateCallback, +) (*jobsService, error) { + p := &jobsService{ + wg: waitgroupx.WithContext(context.Background()), + + namespace: namespace, + namespaceToPoolName: namespaceToPoolName, + masterServiceName: masterServiceName, + masterTLSConfig: masterTLSConfig, + scheduler: scheduler, + jobNameToJobHandler: make(map[string]*job), + jobNameToResourcePool: make(map[string]string), + allocationIDToJobName: make(map[model.AllocationID]string), + jobNameToPodNameToSchedulingState: make(map[string]map[string]sproto.SchedulingState), + jobHandlerToMetadata: make(map[*job]jobMetadata), + slotType: slotType, + slotResourceRequests: slotResourceRequests, + resourcePoolConfigs: resourcePoolConfigs, + baseContainerDefaults: taskContainerDefaults, + detMasterIP: detMasterIP, + detMasterPort: detMasterPort, + currentNodes: make(map[string]*k8sV1.Node), + nodeToSystemResourceRequests: make(map[string]int64), + podInterfaces: make(map[string]typedV1.PodInterface), + configMapInterfaces: make(map[string]typedV1.ConfigMapInterface), + jobInterfaces: make(map[string]typedBatchV1.JobInterface), + syslog: logrus.WithField("namespace", namespace), + jobSchedulingStateCallback: jobSchedulingStateCb, + + kubeconfigPath: kubeconfigPath, + } + + if err := p.startClientSet(); err != nil { + return nil, err + } + if err := p.getMasterIPAndPort(); err != nil { + return nil, err + } + if err := p.getSystemResourceRequests(); err != nil { + return nil, err + } + + 
p.startResourceRequestQueue() + + if err := p.deleteDoomedKubernetesResources(); err != nil { + return nil, err + } + + err := p.startNodeInformer() + switch { + case err != nil && k8error.IsForbidden(err): + p.syslog.Warnf("unable to start node informer due to permission error,"+ + "some features will be degraded: %s", err, + ) + case err != nil: + return nil, err + } + + err = p.startEventListeners() + if err != nil { + return nil, err + } + + err = p.startPreemptionListeners() + if err != nil { + return nil, err + } + + var cacheSyncs []cache.InformerSynced + for namespace := range p.namespaceToPoolName { + factory := informers.NewSharedInformerFactoryWithOptions(p.clientSet, time.Hour, informers.WithNamespace(namespace)) + + jobsInformer := factory.Batch().V1().Jobs() + jobsInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + p.mu.Lock() + defer p.mu.Unlock() + p.jobUpdatedCallback(obj) + }, + UpdateFunc: func(_, obj interface{}) { + p.mu.Lock() + defer p.mu.Unlock() + p.jobUpdatedCallback(obj) + }, + + // If a job is deleted out from under us, this is the only hook we have to not + // leave our workloads running or pending forever. + DeleteFunc: func(obj interface{}) { + p.mu.Lock() + defer p.mu.Unlock() + p.jobDeletedCallback(obj) + }, + }) + cacheSyncs = append(cacheSyncs, jobsInformer.Informer().HasSynced) + + podsInformer := factory.Core().V1().Pods() + podsInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj interface{}) { + p.mu.Lock() + defer p.mu.Unlock() + p.podStatusCallback(obj) + }, + UpdateFunc: func(_, obj interface{}) { + p.mu.Lock() + defer p.mu.Unlock() + p.podStatusCallback(obj) + }, + + // If a pod is deleted out from under us, it is nice to let the user know that + // is what happened. + DeleteFunc: func(obj interface{}) { + p.mu.Lock() + defer p.mu.Unlock() + p.podDeletedCallback(obj) + }, + }) + cacheSyncs = append(cacheSyncs, podsInformer.Informer().HasSynced) + + factory.Start(nil) + } + if !cache.WaitForCacheSync(nil, cacheSyncs...) { + return nil, errors.New("failed to wait for cache sync for jobs informer") + } + return p, nil +} + +func (j *jobsService) startClientSet() error { + config, err := readClientConfig(j.kubeconfigPath) + if err != nil { + return fmt.Errorf("error building kubernetes config: %w", err) + } + + j.clientSet, err = k8sClient.NewForConfig(config) + if err != nil { + return fmt.Errorf("failed to initialize kubernetes clientSet: %w", err) + } + + for _, ns := range append(maps.Keys(j.namespaceToPoolName), j.namespace) { + j.podInterfaces[ns] = j.clientSet.CoreV1().Pods(ns) + j.configMapInterfaces[ns] = j.clientSet.CoreV1().ConfigMaps(ns) + j.jobInterfaces[ns] = j.clientSet.BatchV1().Jobs(ns) + } + + j.syslog.Infof("kubernetes clientSet initialized") + return nil +} + +func readClientConfig(kubeconfigPath string) (*rest.Config, error) { + if len(kubeconfigPath) == 0 { + // The default in-cluster case. Internally, k8s.io/client-go/rest is going to look for + // environment variables: + // - KUBERNETES_SERVICE_HOST + // - KUBERNETES_SERVICE_PORT + // and it expects to find files: + // - /var/run/secrets/kubernetes.io/serviceaccount/token + // - /var/run/secrets/kubernetes.io/serviceaccount/ca.crt + return rest.InClusterConfig() + } + + if parts := strings.Split(kubeconfigPath, string(os.PathSeparator)); parts[0] == "~" { + parts[0] = homedir.HomeDir() + expanded := filepath.Join(parts...) 
+ logrus.Infof("expanding kubeconfig path from %s to %s", kubeconfigPath, expanded) + kubeconfigPath = expanded + } + + bs, err := os.ReadFile(kubeconfigPath) // #nosec G304 // User must have fs access to set this config var anyway. + if err != nil { + return nil, fmt.Errorf("reading kubeconfig at %s: %w", kubeconfigPath, err) + } + + cl, err := clientcmd.RESTConfigFromKubeConfig(bs) + if err != nil { + return nil, fmt.Errorf("building rest.Config from kubeconfig at %s: %w", kubeconfigPath, err) + } + return cl, nil +} + +func (j *jobsService) getMasterIPAndPort() error { + if j.detMasterIP != "" && j.detMasterPort != 0 { + // Master ip and port were manually configured. For special circumstances, e.g., the master is running + // outside of this cluster (happens in development or when we spread across multiple k8s clusters). + return nil + } + masterService, err := j.clientSet.CoreV1().Services(j.namespace).Get( + context.TODO(), j.masterServiceName, metaV1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to get master service: %w", err) + } + + j.detMasterIP = masterService.Spec.ClusterIP + j.detMasterPort = masterService.Spec.Ports[0].Port + j.syslog.Infof("master URL set to %s:%d", j.detMasterIP, j.detMasterPort) + return nil +} + +func (j *jobsService) getSystemResourceRequests() error { + systemPods, err := j.podInterfaces[j.namespace].List( + context.TODO(), metaV1.ListOptions{LabelSelector: determinedSystemLabel}) + if err != nil { + return fmt.Errorf("failed to get system pods: %w", err) + } + + for _, systemPod := range systemPods.Items { + for _, container := range systemPod.Spec.Containers { + j.nodeToSystemResourceRequests[systemPod.Spec.NodeName] += container.Resources.Requests.Cpu(). + MilliValue() + } + } + return nil +} + +func (j *jobsService) deleteDoomedKubernetesResources() error { + var openAllocations []model.Allocation + if err := db.Bun().NewSelect().Model(&openAllocations). + Where("end_time IS NULL"). + Scan(context.TODO()); err != nil { + return fmt.Errorf("error querying the database for open allocations: %w", err) + } + openAllocationIDs := make(set.Set[model.AllocationID]) + for _, alloc := range openAllocations { + openAllocationIDs.Insert(alloc.AllocationID) + } + j.syslog.Infof("found open allocations %s", openAllocationIDs) + + listOptions := metaV1.ListOptions{LabelSelector: determinedLabel} + jobs, err := j.listJobsInAllNamespaces(context.TODO(), listOptions) + if err != nil { + return fmt.Errorf("error listing existing pods: %w", err) + } + + toKillJobs := &batchV1.JobList{} + savedJobNames := make(set.Set[string]) + for _, job := range jobs.Items { + if _, ok := j.namespaceToPoolName[job.Namespace]; !ok { + continue + } + + resourcePool := job.Labels[resourcePoolLabel] + if resourcePool == "" { + j.syslog.Warnf("deleting job '%s' without resource pool label", job.Name) + toKillJobs.Items = append(toKillJobs.Items, job) + continue + } + + allocationIDStr := job.Labels[allocationIDLabel] + if allocationIDStr == "" { + j.syslog.Warnf("deleting job '%s' without determined label (whose value is the allocation ID)", job.Name) + toKillJobs.Items = append(toKillJobs.Items, job) + continue + } + allocationID := model.AllocationID(allocationIDStr) + + if !openAllocationIDs.Contains(allocationID) { + j.syslog. + WithField("allocation-id", allocationID). 
+ Warnf("deleting job '%s', did not find an open allocation for it", job.Name) + toKillJobs.Items = append(toKillJobs.Items, job) + continue + } + + savedJobNames.Insert(job.Name) + } + + configMaps, err := j.listConfigMapsInAllNamespaces(context.TODO(), listOptions) + if err != nil { + return fmt.Errorf("error listing existing config maps: %w", err) + } + toKillConfigMaps := &k8sV1.ConfigMapList{} + for _, cm := range configMaps.Items { + if _, ok := j.namespaceToPoolName[cm.Namespace]; !ok { + continue + } + + if savedJobNames.Contains(cm.Name) { // Job name is same as config map name. + continue + } + + j.syslog.Debugf("deleting config map '%s', did not find a matching job that will be restored", cm.Name) + toKillConfigMaps.Items = append(toKillConfigMaps.Items, cm) + } + + j.deleteKubernetesResources(toKillJobs, toKillConfigMaps) + return nil +} + +// startJob notifies the pods actor to start a pod with the task spec. +type startJob struct { + req *sproto.AllocateRequest + allocationID model.AllocationID + spec tasks.TaskSpec + slots int + rank int + resourcePool string + namespace string + + numPods int + + logContext logger.Context +} + +func (j *jobsService) StartJob(msg startJob) error { + j.mu.Lock() + defer j.mu.Unlock() + return j.startJob(msg) +} + +func (j *jobsService) startJob(msg startJob) error { + newJobHandler := newJob( + configureUniqueName(msg.spec), + msg, + msg.spec.ClusterID, + j.clientSet, + msg.namespace, + j.detMasterIP, + j.detMasterPort, + j.masterTLSConfig, + j.podInterfaces[msg.namespace], + j.configMapInterfaces[msg.namespace], + j.resourceRequestQueue, + j.slotType, + j.slotResourceRequests, + j.scheduler, + ) + + if _, alreadyExists := j.jobNameToJobHandler[newJobHandler.jobName]; alreadyExists { + return fmt.Errorf("attempting to register same job name: %s multiple times", newJobHandler.jobName) + } + + err := newJobHandler.createSpecAndSubmit(&msg.spec) + if err != nil { + return fmt.Errorf("creating pod: %w", err) + } + + j.jobNameToJobHandler[newJobHandler.jobName] = newJobHandler + j.jobNameToResourcePool[newJobHandler.jobName] = msg.resourcePool + j.allocationIDToJobName[msg.req.AllocationID] = newJobHandler.jobName + j.jobNameToPodNameToSchedulingState[newJobHandler.jobName] = make(map[string]sproto.SchedulingState) + j.jobHandlerToMetadata[newJobHandler] = jobMetadata{ + jobName: newJobHandler.jobName, + allocationID: newJobHandler.req.AllocationID, + } + + return nil +} + +func (j *jobsService) ChangePriority(id model.AllocationID) { + j.mu.Lock() + defer j.mu.Unlock() + j.changePriority(id) +} + +func (j *jobsService) ChangePosition(id model.AllocationID) { + j.mu.Lock() + defer j.mu.Unlock() + j.changePosition(id) +} + +func (j *jobsService) KillJob(id model.AllocationID) { + j.mu.Lock() + defer j.mu.Unlock() + j.killJob(id) +} + +func (j *jobsService) SummarizeResources(poolName string) (*computeUsageSummary, error) { + j.mu.Lock() + defer j.mu.Unlock() + return j.summarizeComputeUsage(poolName) +} + +func (j *jobsService) ReattachJob(msg reattachJobRequest) (reattachJobResponse, error) { + j.mu.Lock() + defer j.mu.Unlock() + return j.reattachJob(msg) +} + +type reattachJobRequest struct { + req *sproto.AllocateRequest + numPods int + allocationID model.AllocationID + slots int + logContext logger.Context +} + +type reattachJobResponse struct { + started *sproto.ResourcesStarted +} + +func (j *jobsService) reattachJob(msg reattachJobRequest) (reattachJobResponse, error) { + listOptions := metaV1.ListOptions{ + LabelSelector: 
fmt.Sprintf("%s=%s", determinedLabel, msg.allocationID), + } + + jobs, err := j.listJobsInAllNamespaces(context.TODO(), listOptions) + if err != nil { + return reattachJobResponse{}, fmt.Errorf("error listing pods checking if they can be restored: %w", err) + } + + configMaps, err := j.listConfigMapsInAllNamespaces(context.TODO(), listOptions) + if err != nil { + return reattachJobResponse{}, fmt.Errorf("error listing config maps checking if they can be restored: %w", err) + } + existingConfigMaps := make(set.Set[string]) + for _, cm := range configMaps.Items { + if _, ok := j.namespaceToPoolName[cm.Namespace]; !ok { + continue + } + existingConfigMaps.Insert(cm.Name) + } + + if len(jobs.Items) == 0 { + return reattachJobResponse{}, fmt.Errorf("did not find job for allocation %s", msg.allocationID) + } else if len(jobs.Items) > 1 { + return reattachJobResponse{}, fmt.Errorf("found multiple allocation jobs for allocation %s", msg.allocationID) + } + job := jobs.Items[0] + + resourcePool, ok := job.Labels[resourcePoolLabel] + if !ok { + return reattachJobResponse{}, fmt.Errorf("could not recover resource pool for %s", msg.allocationID) + } + + resp, err := j.recreateJobHandler( + job.Name, + msg.req, + msg.allocationID, + resourcePool, + &job, + msg.slots, + msg.numPods, + msg.logContext, + ) + if err != nil { + j.deleteKubernetesResources(jobs, configMaps) + return reattachJobResponse{}, fmt.Errorf("error restoring pod with allocation ID %s: %w", msg.allocationID, err) + } + return resp, nil +} + +func (j *jobsService) recreateJobHandler( + name string, + req *sproto.AllocateRequest, + allocationID model.AllocationID, + resourcePool string, + job *batchV1.Job, + slots int, + numPods int, + logContext logger.Context, +) (reattachJobResponse, error) { + startMsg := startJob{ + req: req, + allocationID: allocationID, + spec: tasks.TaskSpec{ + // This gets used in reattach to find the job by label its determinedLabel. + AllocationID: string(allocationID), + ContainerID: req.AllocationID.String(), // ContainerID is non-sense, make a better abstraction. 
+ }, + slots: slots, + numPods: numPods, + resourcePool: resourcePool, + logContext: logContext, + } + + newJobHandler := newJob( + name, + startMsg, + startMsg.spec.ClusterID, + j.clientSet, + job.Namespace, + j.detMasterIP, + j.detMasterPort, + j.masterTLSConfig, + j.podInterfaces[job.Namespace], + j.configMapInterfaces[job.Namespace], + j.resourceRequestQueue, + j.slotType, + j.slotResourceRequests, + j.scheduler, + ) + + newJobHandler.restore = true + newJobHandler.jobName = job.Name + newJobHandler.configMapName = job.Name + + err := newJobHandler.startPodLogStreamers() + if err != nil { + return reattachJobResponse{}, fmt.Errorf("reattaching pod: %w", err) + } + + j.jobNameToJobHandler[job.Name] = newJobHandler + j.jobNameToResourcePool[job.Name] = resourcePool + j.allocationIDToJobName[newJobHandler.req.AllocationID] = job.Name + j.jobNameToPodNameToSchedulingState[job.Name] = make(map[string]sproto.SchedulingState) + j.jobHandlerToMetadata[newJobHandler] = jobMetadata{ + jobName: job.Name, + allocationID: newJobHandler.req.AllocationID, + } + + return reattachJobResponse{started: nil}, nil +} + +func (j *jobsService) deleteKubernetesResources( + jobs *batchV1.JobList, configMaps *k8sV1.ConfigMapList, +) { + for _, job := range jobs.Items { + j.resourceRequestQueue.deleteKubernetesResources(job.Namespace, job.Name, "", "") + } + + for _, configMap := range configMaps.Items { + j.resourceRequestQueue.deleteKubernetesResources(configMap.Namespace, "", configMap.Name, "") + } +} + +func (j *jobsService) RefreshStates(allocationID model.AllocationID) error { + j.mu.Lock() + defer j.mu.Unlock() + err := j.refreshJobState(allocationID) + if err != nil { + return err + } + return j.refreshPodStates(allocationID) +} + +func (j *jobsService) refreshJobState(allocationID model.AllocationID) error { + if allocationID == "" { + return fmt.Errorf("invalid call: allocationID missing") + } + + jobs, err := j.listJobsInAllNamespaces(context.TODO(), metaV1.ListOptions{ + LabelSelector: fmt.Sprintf("%s=%s", determinedLabel, allocationID), + }) + if err != nil { + return fmt.Errorf("error listing pods checking if they can be restored: %w", err) + } + + for _, job := range jobs.Items { + if _, ok := j.namespaceToPoolName[job.Namespace]; !ok { + continue + } + job := job + j.jobUpdatedCallback(&job) + } + return nil +} + +func (j *jobsService) refreshPodStates(allocationID model.AllocationID) error { + if allocationID == "" { + return fmt.Errorf("invalid call: allocationID missing") + } + + pods, err := j.listPodsInAllNamespaces(context.TODO(), metaV1.ListOptions{ + LabelSelector: fmt.Sprintf("%s=%s", determinedLabel, allocationID), + }) + if err != nil { + return fmt.Errorf("error listing pods checking if they can be restored: %w", err) + } + + for _, pod := range pods.Items { + if _, ok := j.namespaceToPoolName[pod.Namespace]; !ok { + continue + } + pod := pod + j.podStatusCallback(&pod) + } + return nil +} + +func (j *jobsService) GetAgents() *apiv1.GetAgentsResponse { + j.mu.Lock() + defer j.mu.Unlock() + return j.getAgents() +} + +func (j *jobsService) GetAgent(msg *apiv1.GetAgentRequest) *apiv1.GetAgentResponse { + j.mu.Lock() + defer j.mu.Unlock() + return j.getAgent(msg.AgentId) +} + +func (j *jobsService) EnableAgent(msg *apiv1.EnableAgentRequest) (*apiv1.EnableAgentResponse, error) { + j.mu.Lock() + defer j.mu.Unlock() + return j.enableNode(msg.AgentId) +} + +func (j *jobsService) DisableAgent(msg *apiv1.DisableAgentRequest) (*apiv1.DisableAgentResponse, error) { + j.mu.Lock() + defer 
j.mu.Unlock() + return j.disableNode(msg.AgentId, msg.Drain) +} + +func (j *jobsService) GetSlots(msg *apiv1.GetSlotsRequest) *apiv1.GetSlotsResponse { + j.mu.Lock() + defer j.mu.Unlock() + return j.getSlots(msg.AgentId) +} + +func (j *jobsService) GetSlot(msg *apiv1.GetSlotRequest) *apiv1.GetSlotResponse { + j.mu.Lock() + defer j.mu.Unlock() + return j.getSlot(msg.AgentId, msg.SlotId) +} + +func (j *jobsService) HealthStatus() model.HealthStatus { + j.mu.Lock() + defer j.mu.Unlock() + for _, podInterface := range j.podInterfaces { + _, err := podInterface.List(context.TODO(), metaV1.ListOptions{Limit: 1}) + if err != nil { + j.syslog.WithError(err).Error("kubernetes resource manager marked as unhealthy") + return model.Unhealthy + } + return model.Healthy + } + + logrus.Error("expected jobInterface to be non empty") + return model.Unhealthy +} + +func (j *jobsService) startNodeInformer() error { + i, err := newNodeInformer( + context.TODO(), + j.clientSet.CoreV1().Nodes(), + func(event watch.Event) { + j.mu.Lock() + defer j.mu.Unlock() + j.nodeStatusCallback(event) + }) + if err != nil { + return err + } + + go i.run(context.TODO()) + return nil +} + +func (j *jobsService) startEventListeners() error { + for namespace := range j.namespaceToPoolName { + l, err := newEventInformer( + context.TODO(), + j.clientSet.CoreV1().Events(namespace), + namespace, + func(event watch.Event) { + j.mu.Lock() + defer j.mu.Unlock() + j.newEventCallback(event) + }) + if err != nil { + return err + } + go l.run(context.TODO()) + } + return nil +} + +func (j *jobsService) startPreemptionListeners() error { + for namespace := range j.namespaceToPoolName { + l, err := newPodInformer( + context.TODO(), + determinedPreemptionLabel, + "preemption", + namespace, + j.clientSet.CoreV1().Pods(namespace), + func(event watch.Event) { + j.mu.Lock() + defer j.mu.Unlock() + j.preemptionCallback(event) + }) + if err != nil { + return err + } + go l.run(context.TODO()) + } + return nil +} + +func (j *jobsService) startResourceRequestQueue() { + failures := make(chan resourcesRequestFailure, 16) + j.resourceRequestQueue = startRequestQueue(j.jobInterfaces, j.podInterfaces, j.configMapInterfaces, failures) + j.wg.Go(func(ctx context.Context) { + for { + select { + case failure := <-failures: + j.handleResourceRequestFailure(failure) + case <-ctx.Done(): + return + } + } + }) +} + +func (j *jobsService) handleResourceRequestFailure(msg resourcesRequestFailure) { + j.mu.Lock() + defer j.mu.Unlock() + + jobName := msg.getJobName() + jobHandler, ok := j.jobNameToJobHandler[jobName] + if !ok { + j.syslog.Warnf("received resource request error for unregistered pod %s", jobName) + return + } + + switch msg := msg.(type) { + case resourceCreationFailed: + jobHandler.receiveResourceCreationFailed(msg) + case resourceCreationCancelled: + jobHandler.receiveResourceCreationCancelled() + case resourceDeletionFailed: + jobHandler.receiveResourceDeletionFailed(msg) + default: + panic(fmt.Sprintf("unexpected message %T", msg)) + } + + err := j.cleanUpJobHandler(jobHandler) + if err != nil { + j.syslog.WithError(err).Error("cleaning up pod handler after resource request failure") + } +} + +func (j *jobsService) jobUpdatedCallback(obj any) { + job, ok := obj.(*batchV1.Job) + if !ok { + j.syslog.Warnf("error converting event of type %T to *batchV1.Job: %+v", obj, obj) + return + } + syslog := j.syslog.WithField("job", job.Name) + + jobHandler, ok := j.jobNameToJobHandler[job.Name] + if !ok { + syslog.Debugf("received job status update for 
un-registered job %s", job.Name) + return + } + + state, err := jobHandler.jobUpdatedCallback(job) + if err != nil { + syslog.WithError(err).Error("failed to process job status update") + if err := j.cleanUpJobHandler(jobHandler); err != nil { + syslog.WithError(err).Error("unable to cleanup job handler after an error") + } + } else if state == cproto.Terminated { + if err := j.cleanUpJobHandler(jobHandler); err != nil { + syslog.WithError(err).Error("unable to cleanup job handler after termination") + } + } +} + +func (j *jobsService) jobDeletedCallback(obj any) { + job, ok := obj.(*batchV1.Job) + if !ok { + j.syslog.Warnf("failed to convert event of type %T to *batchV1.Job: %+v", obj, obj) + return + } + syslog := j.syslog.WithField("job", job.Name) + + jobHandler, ok := j.jobNameToJobHandler[job.Name] + if !ok { + syslog.Debugf("received job status update for un-registered job %s", job.Name) + return + } + + jobHandler.jobDeletedCallback() + if err := j.cleanUpJobHandler(jobHandler); err != nil { + syslog.WithError(err).Error("unable to cleanup job handler after an error") + } +} + +func (j *jobsService) podStatusCallback(obj any) { + pod, ok := obj.(*k8sV1.Pod) + if !ok { + j.syslog.Warnf("error converting event of type %T to *k8sV1.Pod: %+v", obj, obj) + return + } + syslog := j.syslog.WithField("pod", pod.Name) + + jobName, ok := pod.Labels[kubernetesJobNameLabel] + if !ok { + syslog.Debugf("received pod informer event for pod without %s label", kubernetesJobNameLabel) + return + } + + jobHandler, ok := j.jobNameToJobHandler[jobName] + if !ok { + syslog.Debugf("received pod status update for un-registered job %s", jobName) + return + } + + err := jobHandler.podUpdatedCallback(*pod) + if err != nil { + syslog.WithError(err).Error("error processing pod status update") + return + } + + j.updatePodSchedulingState(jobName, pod) + if j.jobSchedulingStateCallback != nil { + go j.jobSchedulingStateCallback(jobSchedulingStateChanged{ + AllocationID: jobHandler.req.AllocationID, + NumPods: jobHandler.numPods, + State: j.jobSchedulingState(jobName), + }) + } +} + +func (j *jobsService) podDeletedCallback(obj any) { + pod, ok := obj.(*k8sV1.Pod) + if !ok { + j.syslog.Warnf("error converting event of type %T to *k8sV1.Pod: %+v", obj, obj) + return + } + syslog := j.syslog.WithField("pod", pod.Name) + + jobName, ok := pod.Labels[kubernetesJobNameLabel] + if !ok { + syslog.Debugf("received pod informer event for pod without %s label", kubernetesJobNameLabel) + return + } + + jobHandler, ok := j.jobNameToJobHandler[jobName] + if !ok { + syslog.Debugf("received pod status update for un-registered job %s", jobName) + return + } + + jobHandler.podDeletedCallback(pod) +} + +// jobSchedulingState is a roll-up of the sceduling states of its individual pods. +func (j *jobsService) jobSchedulingState(jobName string) sproto.SchedulingState { + states, ok := j.jobNameToPodNameToSchedulingState[jobName] + if !ok { + return sproto.SchedulingStateQueued + } + if !allEqual(sproto.SchedulingStateScheduled, maps.Values(states)...) { + return sproto.SchedulingStateQueued + } + return sproto.SchedulingStateScheduled +} + +// updatePodSchedulingState stores the scheduling state of a pod based on its state (in particular the phase). 
+func (j *jobsService) updatePodSchedulingState(jobName string, pod *k8sV1.Pod) { + states, ok := j.jobNameToPodNameToSchedulingState[jobName] + if !ok { + states = make(map[string]sproto.SchedulingState) + } + + states[pod.Name] = sproto.SchedulingStateQueued + if pod.Status.Phase == "Running" { + states[pod.Name] = sproto.SchedulingStateScheduled + } + j.jobNameToPodNameToSchedulingState[jobName] = states +} + +var ( + clusterID string + once sync.Once +) + +func setClusterID(s string) { + once.Do(func() { + clusterID = s + }) +} + +func clusterIDNodeLabel() string { + return fmt.Sprintf("determined.ai/cluster-id-%s", clusterID) +} + +const ( + noExecuteNodeLabelValue = "no-execute" + noScheduleNodeLabelValue = "no-schedule" +) + +func (j *jobsService) enableNode( + nodeName string, +) (*apiv1.EnableAgentResponse, error) { + patch := []byte(fmt.Sprintf(`{ + "metadata": { + "labels": { + "%s": null + } + } + }`, clusterIDNodeLabel())) + + _, err := j.clientSet.CoreV1().Nodes(). + Patch(context.TODO(), nodeName, types.StrategicMergePatchType, patch, metaV1.PatchOptions{}) + if k8error.IsForbidden(err) { + return nil, fmt.Errorf("the Determined master Kubernetes service account " + + "is missing permissions to patch nodes. " + + "Enabling or disabling nodes requires this permission, " + + "however Determined will otherwise still function correctly without " + + "these Kubernetes permissions") + } else if err != nil { + return nil, fmt.Errorf( + "enabling node %s by removing the Determined no schedule label: %w", nodeName, err) + } + j.syslog.Infof("node %s enabled by an user", nodeName) + + n, ok := j.summarizeClusterByNodes()[nodeName] + if !ok { + return nil, fmt.Errorf("node %s enabled without error, error getting node summary", nodeName) + } + n.Enabled = true + n.Draining = false + for slotKey := range n.Slots { + s := n.Slots[slotKey] + s.Enabled = n.Enabled + s.Draining = n.Draining + n.Slots[slotKey] = s + } + + return &apiv1.EnableAgentResponse{ + Agent: n.ToProto(), + }, nil +} + +func (j *jobsService) disableNode( + nodeName string, shouldDrain bool, +) (*apiv1.DisableAgentResponse, error) { + labelValue := noExecuteNodeLabelValue + if shouldDrain { + labelValue = noScheduleNodeLabelValue + } + + patchStruct := metaV1.ObjectMeta{ + Labels: map[string]string{clusterIDNodeLabel(): labelValue}, + } + patch, err := json.Marshal(map[string]any{"metadata": patchStruct}) + if err != nil { + return nil, fmt.Errorf("marshaling JSON patch %v: %s", patchStruct, err) + } + + _, err = j.clientSet.CoreV1().Nodes(). + Patch(context.TODO(), nodeName, types.StrategicMergePatchType, patch, metaV1.PatchOptions{}) + if k8error.IsForbidden(err) { + return nil, fmt.Errorf("the Determined master Kubernetes service account " + + "is missing permissions to patch nodes. " + + "Enabling or disabling nodes requires this permission, " + + "however Determined will otherwise still function correctly without " + + "these Kubernetes permissions") + } else if err != nil { + return nil, fmt.Errorf( + "disabling node %s by adding the Determined no schedule label: %w", nodeName, err) + } + j.syslog.Infof("node %s disabled by a user", nodeName) + + if !shouldDrain { // See note in spec.go about how we could remove killing all pods here. 
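The enableNode and disableNode methods above toggle scheduling by patching a per-cluster node label. A minimal standalone sketch of the two strategic-merge-patch bodies they send, assuming an illustrative cluster ID of ``abc123`` (a ``null`` value removes the label and re-enables the node):

.. code:: go

   package main

   import (
       "encoding/json"
       "fmt"
   )

   func main() {
       // Illustrative label name; the real suffix is the cluster ID.
       label := "determined.ai/cluster-id-abc123"

       // Patch sent by enableNode: a null value deletes the label.
       enablePatch := map[string]any{
           "metadata": map[string]any{"labels": map[string]any{label: nil}},
       }
       // Patch sent by disableNode with the drain option; without drain the
       // value is "no-execute" and running allocations are released as well.
       disablePatch := map[string]any{
           "metadata": map[string]any{"labels": map[string]string{label: "no-schedule"}},
       }

       for _, p := range []map[string]any{enablePatch, disablePatch} {
           b, _ := json.Marshal(p)
           fmt.Println(string(b))
       }
   }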
+ if err := j.releaseAllocationsOnDisabledNode(nodeName); err != nil { + return nil, fmt.Errorf( + "node disabled without error, error killing existing pod on node: %w", err) + } + } + + n, ok := j.summarizeClusterByNodes()[nodeName] + if !ok { + return nil, fmt.Errorf("node %s disabled without error, error getting node summary", nodeName) + } + n.Enabled = false + n.Draining = shouldDrain + for slotKey := range n.Slots { + s := n.Slots[slotKey] + s.Enabled = n.Enabled + s.Draining = n.Draining + n.Slots[slotKey] = s + } + + return &apiv1.DisableAgentResponse{ + Agent: n.ToProto(), + }, nil +} + +func (j *jobsService) releaseAllocationsOnDisabledNode(nodeName string) error { + listOptions := metaV1.ListOptions{ + LabelSelector: determinedLabel, + FieldSelector: fmt.Sprintf("spec.nodeName=%s", nodeName), + } + pods, err := j.listPodsInAllNamespaces(context.TODO(), listOptions) + if err != nil { + return fmt.Errorf("listing pods on node %s: %w", nodeName, err) + } + + notifiedAllocations := make(map[model.AllocationID]bool) + for _, pod := range pods.Items { + jobName, ok := pod.Labels[kubernetesJobNameLabel] + if !ok { + j.syslog.Debugf("found pod when disabling node without %s label", kubernetesJobNameLabel) + continue + } + + jobHandler, ok := j.jobNameToJobHandler[jobName] + if !ok { + j.syslog.Warnf( + "during node disable couldn't find pod %s's actor to kill", pod.Name) + continue + } + + j.syslog.Infof( + "stopping pod %s because node %s was disabled without drain option", pod.Name, nodeName) + if notifiedAllocations[jobHandler.allocationID] { + continue + } + + rmevents.Publish(jobHandler.allocationID, &sproto.ReleaseResources{ + Reason: "node disabled without drain", + ForceKill: true, + }) + notifiedAllocations[jobHandler.allocationID] = true + } + + return nil +} + +func (j *jobsService) nodeStatusCallback(event watch.Event) { + node, ok := event.Object.(*k8sV1.Node) + if !ok { + j.syslog.Warnf("error converting event of type %T to *k8sV1.Node: %+v", event, event) + return + } + + j.syslog.Debugf(`informer got new node event for node '%s': %s %s`, + node.Name, event.Type, node.Status.Phase) + + switch event.Type { + case watch.Added: + j.currentNodes[node.Name] = node + case watch.Modified: + j.currentNodes[node.Name] = node + case watch.Deleted: + delete(j.currentNodes, node.Name) + default: + } +} + +func (j *jobsService) newEventCallback(event watch.Event) { + newEvent, ok := event.Object.(*k8sV1.Event) + if !ok { + j.syslog.Warnf("error converting object type %T to *k8sV1.Event: %+v", event, event) + return + } + syslog := j.syslog.WithFields(logrus.Fields{ + "name": newEvent.InvolvedObject.Name, + "kind": newEvent.InvolvedObject.Kind, + }) + + switch newEvent.InvolvedObject.Kind { + case "Pod": //nolint:goconst // Useless lint. 
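releaseAllocationsOnDisabledNode above narrows its listing with both a label selector and a field selector so that only Determined pods bound to the disabled node are considered. A minimal client-go sketch of that selector combination, assuming the caller supplies a clientset and namespace:

.. code:: go

   package sketch

   import (
       "context"
       "fmt"

       metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"
       "k8s.io/client-go/kubernetes"
   )

   // listDeterminedPodsOnNode returns the names of Determined-labeled pods that
   // are bound to the given node, mirroring the selector usage above.
   func listDeterminedPodsOnNode(cs kubernetes.Interface, namespace, node string) ([]string, error) {
       pods, err := cs.CoreV1().Pods(namespace).List(context.TODO(), metaV1.ListOptions{
           LabelSelector: "determined",
           FieldSelector: fmt.Sprintf("spec.nodeName=%s", node),
       })
       if err != nil {
           return nil, err
       }
       names := make([]string, 0, len(pods.Items))
       for _, p := range pods.Items {
           names = append(names, p.Name)
       }
       return names, nil
   }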
+ podName := newEvent.InvolvedObject.Name + jobNameParts := strings.Split(podName, "-") + if len(jobNameParts) <= 1 { + syslog.Tracef("received pod event for an un-registered pod %s", podName) + return + } + jobName := strings.Join(jobNameParts[:len(jobNameParts)-1], "-") + ref, ok := j.jobNameToJobHandler[jobName] + if !ok { + syslog.Tracef("received pod event for an un-registered job %s", jobName) + return + } + ref.newEventCallback(newEvent) + case "Job": + jobName := newEvent.InvolvedObject.Name + ref, ok := j.jobNameToJobHandler[jobName] + if !ok { + syslog.Tracef("received job event for an un-registered job %s", jobName) + return + } + ref.newEventCallback(newEvent) + } +} + +type computeUsageSummary struct { + numAgentsUsed int + slotsAvailable int +} + +// TODO(!!!): good func comment. +func (j *jobsService) summarizeComputeUsage(poolName string) (*computeUsageSummary, error) { + summary, err := j.summarize() + if err != nil { + return nil, err + } + + slots := 0 + if len(poolName) > 0 { + slots = numSlots(summary[poolName].Slots) + } else { + for _, pool := range summary { + slots += numSlots(pool.Slots) + } + } + return &computeUsageSummary{numAgentsUsed: len(summary), slotsAvailable: slots}, nil +} + +func (j *jobsService) preemptionCallback(event watch.Event) { + pod, ok := event.Object.(*k8sV1.Pod) + if !ok { + j.syslog.Warnf("error converting event of type %T to *k8sV1.Pod: %+v", event, event) + return + } + j.syslog.Debugf("informer got new preemption event for pod %s ", pod.Name) + + ref, ok := j.jobNameToJobHandler[pod.Name] + if !ok { + j.syslog.Debug("received preemption command for unregistered pod") + return + } + ref.preemptionCallback() +} + +func (j *jobsService) verifyJobAndGetRef(id model.AllocationID) (*job, error) { + jobName, ok := j.allocationIDToJobName[id] + if !ok { + return nil, fmt.Errorf("unknown allocation %s", id) + } + + ref, ok := j.jobNameToJobHandler[jobName] + if !ok { + return nil, fmt.Errorf("unknown job %s", jobName) + } + return ref, nil +} + +func (j *jobsService) changePriority(id model.AllocationID) { + ref, err := j.verifyJobAndGetRef(id) + if err != nil { + j.syslog.WithError(err).Debug("changing allocation priority") + return + } + ref.changePriority() +} + +func (j *jobsService) changePosition(id model.AllocationID) { + ref, err := j.verifyJobAndGetRef(id) + if err != nil { + j.syslog.WithError(err).Debug("changing allocation position") + return + } + ref.changePosition() +} + +func (j *jobsService) killJob(id model.AllocationID) { + ref, err := j.verifyJobAndGetRef(id) + if err != nil { + j.syslog.WithError(err).Debug("killing allocation") + return + } + ref.Kill() +} + +func (j *jobsService) cleanUpJobHandler(jobHandler *job) error { + jobHandler.finalize() + + jobInfo, ok := j.jobHandlerToMetadata[jobHandler] + if !ok { + return fmt.Errorf("unknown job handler being deleted %s", jobHandler.jobName) + } + + j.syslog. + WithField("pod", jobInfo.jobName). + WithField("handler", jobHandler.jobName). + Infof("de-registering job handler") + delete(j.jobNameToJobHandler, jobInfo.jobName) + delete(j.jobNameToResourcePool, jobInfo.jobName) + delete(j.allocationIDToJobName, jobInfo.allocationID) + delete(j.jobNameToPodNameToSchedulingState, jobInfo.jobName) + delete(j.jobHandlerToMetadata, jobHandler) + + // launch this work async, since we hold the lock and it does API calls. + j.wg.Go(func(ctx context.Context) { + name := fmt.Sprintf("%s-priorityclass", jobInfo.allocationID) + err := j.clientSet. + SchedulingV1(). + PriorityClasses(). 
+ Delete(ctx, name, metaV1.DeleteOptions{}) + if err != nil && !k8error.IsNotFound(err) { + j.syslog.Warnf("Deletion of PriorityClass %s failed.", name) + } + }) + + return nil +} + +func (j *jobsService) getSlots(agentID string) *apiv1.GetSlotsResponse { + agentResp := j.getAgent(agentID) + if agentResp == nil { + j.syslog.Warnf("no agent with id %s", agentID) + return nil + } + return &apiv1.GetSlotsResponse{Slots: maps.Values(agentResp.Agent.Slots)} +} + +func (j *jobsService) getSlot(agentID string, slotID string) *apiv1.GetSlotResponse { + agentResp := j.getAgent(agentID) + if agentResp == nil { + j.syslog.Warnf("no agent with id %s", agentID) + return nil + } + slots := agentResp.Agent.Slots + slot, ok := slots[slotID] + if !ok { + // Try converting an index input to a slot and see if that exists (1 to 001). + tryIndex, err := strconv.Atoi(slotID) + if s, ok := slots[model.SortableSlotIndex(tryIndex)]; err == nil && ok { + slot = s + } else { + j.syslog.Warnf("no slot with id %s", slotID) + return nil + } + } + return &apiv1.GetSlotResponse{Slot: slot} +} + +const getAgentsCacheDuration = 15 * time.Second + +func (j *jobsService) getAgents() *apiv1.GetAgentsResponse { + j.getAgentsCacheLock.Lock() + defer j.getAgentsCacheLock.Unlock() + + if time.Since(j.getAgentsCacheTime) > getAgentsCacheDuration { + j.getAgentsCacheTime = time.Now() + + nodeSummaries := j.summarizeClusterByNodes() + _, nodesToPools := j.getNodeResourcePoolMapping(nodeSummaries) + + j.getAgentsCache = &apiv1.GetAgentsResponse{} + for _, summary := range nodeSummaries { + summary.ResourcePool = nodesToPools[summary.ID] + j.getAgentsCache.Agents = append(j.getAgentsCache.Agents, summary.ToProto()) + } + } + + return j.getAgentsCache +} + +func (j *jobsService) getAgent(agentID string) *apiv1.GetAgentResponse { + nodeSummaries := j.summarizeClusterByNodes() + _, nodesToPools := j.getNodeResourcePoolMapping(nodeSummaries) + agentSummary, ok := nodeSummaries[agentID] + if !ok { + // TODO(DET-10029): We should return an error indicating the invalid ID request (rather + // than a warn). + j.syslog.Warnf("no agent with id %s", agentID) + return nil + } + agentSummary.ResourcePool = nodesToPools[agentSummary.ID] + return &apiv1.GetAgentResponse{Agent: agentSummary.ToProto()} +} + +const summarizeCacheDuration = 5 * time.Second + +// summarize describes pods' available resources. When there's exactly one resource pool, it uses +// the whole cluster's info. Otherwise, it matches nodes to resource pools using taints and +// tolerations to derive that info. This may be cached, so don't use this for decisions +// that require up-to-date information. +func (j *jobsService) summarize() (map[string]model.AgentSummary, error) { + j.summarizeCacheLock.Lock() + defer j.summarizeCacheLock.Unlock() + + if time.Since(j.summarizeCacheTime) > summarizeCacheDuration { + summary, err := j.computeSummary() + j.summarizeCacheTime = time.Now() + j.summarizeCache = summarizeResult{ + summary: summary, + err: err, + } + } + + return j.summarizeCache.summary, j.summarizeCache.err +} + +// Get the mapping of many-to-many relationship between nodes and resource pools. +func (j *jobsService) getNodeResourcePoolMapping(nodeSummaries map[string]model.AgentSummary) ( + map[string][]*k8sV1.Node, map[string][]string, +) { + poolTaskContainerDefaults := extractTCDs(j.resourcePoolConfigs) + + // Nvidia automatically taints nodes, so we should tolerate that when users don't customize + // their resource pool config. 
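getAgents and summarize above both guard a relatively expensive cluster walk behind a short-lived cache (15 and 5 seconds respectively). A minimal sketch of that check-then-refresh pattern; the ttlCache type and its fields are illustrative names, not taken from the source:

.. code:: go

   package main

   import (
       "fmt"
       "sync"
       "time"
   )

   // ttlCache recomputes a value at most once per ttl; callers inside the window
   // see the cached copy, mirroring the getAgentsCache/summarizeCache guards.
   type ttlCache[T any] struct {
       mu      sync.Mutex
       ttl     time.Duration
       last    time.Time
       value   T
       refresh func() T
   }

   func (c *ttlCache[T]) get() T {
       c.mu.Lock()
       defer c.mu.Unlock()
       if time.Since(c.last) > c.ttl {
           c.value = c.refresh()
           c.last = time.Now()
       }
       return c.value
   }

   func main() {
       calls := 0
       c := &ttlCache[int]{ttl: time.Second, refresh: func() int { calls++; return calls }}
       fmt.Println(c.get(), c.get()) // 1 1: the second call is served from cache
   }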
+ defaultTolerations := []k8sV1.Toleration{{ + Key: resourceTypeNvidia, + Value: "present", + Operator: k8sV1.TolerationOpEqual, + }} + cpuTolerations, gpuTolerations := extractTolerations(j.baseContainerDefaults) + poolsToNodes := make(map[string][]*k8sV1.Node, len(j.namespaceToPoolName)) + nodesToPools := make(map[string][]string, len(j.namespaceToPoolName)) + + for _, node := range j.currentNodes { + _, slotType := extractSlotInfo(nodeSummaries[node.Name]) + + for poolName, tcd := range poolTaskContainerDefaults { + var poolTolerations []k8sV1.Toleration + + // If they're using the default RP config, use the default tolerations. + if len(j.resourcePoolConfigs) <= 1 && + (tcd == nil || (tcd.CPUPodSpec == nil && tcd.GPUPodSpec == nil)) { + if slotType == device.CUDA { + //nolint:gocritic + poolTolerations = append(defaultTolerations, gpuTolerations...) + } else if slotType == device.CPU { + //nolint:gocritic + poolTolerations = append(defaultTolerations, cpuTolerations...) + } + } else if tcd != nil { + // Decide which poolTolerations to use based on slot device type + if slotType == device.CUDA && tcd.GPUPodSpec != nil { + //nolint:gocritic + poolTolerations = append(tcd.GPUPodSpec.Spec.Tolerations, gpuTolerations...) + } else if tcd.CPUPodSpec != nil { + //nolint:gocritic + poolTolerations = append(tcd.CPUPodSpec.Spec.Tolerations, cpuTolerations...) + } + } + + // add default toleration so that autoscaling nodes will still be counted. + poolTolerations = append(poolTolerations, k8sV1.Toleration{ + Key: "DeletionCandidateOfClusterAutoscaler", + Operator: "Exists", + Effect: "PreferNoSchedule", + TolerationSeconds: nil, + }) + // If all of a node's taints are tolerated by a pool, that node belongs to the pool. + if allTaintsTolerated(node.Spec.Taints, poolTolerations) { + poolsToNodes[poolName] = append(poolsToNodes[poolName], node) + nodesToPools[node.Name] = append(nodesToPools[node.Name], poolName) + } + } + } + + return poolsToNodes, nodesToPools +} + +var programStartTime = time.Now() + +func (j *jobsService) computeSummary() (map[string]model.AgentSummary, error) { + nodeSummaries := j.summarizeClusterByNodes() + + // Build the many-to-many relationship between nodes and resource pools + poolsToNodes, _ := j.getNodeResourcePoolMapping(nodeSummaries) + + // Build the set of summaries for each resource pool + containers := j.containersPerResourcePool() + summaries := make(map[string]model.AgentSummary, len(j.namespaceToPoolName)) + for poolName, nodes := range poolsToNodes { + slots := model.SlotsSummary{} + numContainersInPool := containers[poolName] + + // We'll create a number of pseudo-containers in the summary equal to the number of + // running containers in this pool. 
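The mapping above assigns a node to a pool only when every taint on the node is tolerated by the pool's tolerations, via the allTaintsTolerated helper defined later in this file. A minimal sketch of a single taint/toleration check, assuming the Nvidia resource key is ``nvidia.com/gpu``:

.. code:: go

   package main

   import (
       "fmt"

       k8sV1 "k8s.io/api/core/v1"
   )

   func main() {
       // A GPU node tainted the way GPU nodes commonly are.
       taint := k8sV1.Taint{
           Key:    "nvidia.com/gpu",
           Value:  "present",
           Effect: k8sV1.TaintEffectNoSchedule,
       }
       // The default toleration added above; an empty Effect tolerates any effect.
       toleration := k8sV1.Toleration{
           Key:      "nvidia.com/gpu",
           Value:    "present",
           Operator: k8sV1.TolerationOpEqual,
       }
       fmt.Println(toleration.ToleratesTaint(&taint)) // true
   }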
+ pseudoContainersAdded := 0 + + for _, node := range nodes { + numSlots, slotType := extractSlotInfo(nodeSummaries[node.Name]) + + for j := 0; j < numSlots; j++ { + id := fmt.Sprintf("%s/%s/%s/%d", poolName, node.Name, string(slotType), j) + + var container *cproto.Container + if pseudoContainersAdded < numContainersInPool { + container = &cproto.Container{ + ID: cproto.ID(id), + State: "RUNNING", + } + pseudoContainersAdded++ + } + + slots[id] = model.SlotSummary{ + ID: id, + Device: device.Device{Type: slotType}, + Enabled: true, + Container: container, + } + } + } + + summaries[poolName] = model.AgentSummary{ + ID: poolName, + RegisteredTime: programStartTime, + NumContainers: numContainersInPool, + ResourcePool: []string{poolName}, + Slots: slots, + } + } + + return summaries, nil +} + +func (j *jobsService) summarizeClusterByNodes() map[string]model.AgentSummary { + var allPods []podNodeInfo + + for _, p := range j.jobNameToJobHandler { + allPods = append(allPods, p.getNodeInfoForPods()...) + } + + // Separate pods by nodes. + podByNode := make(map[string][]podNodeInfo, len(allPods)) + for _, podInfo := range allPods { + if len(podInfo.nodeName) == 0 { + // If a pod doesn't have a nodeName it means it has not yet + // been allocated to a node. + continue + } + podByNode[podInfo.nodeName] = append(podByNode[podInfo.nodeName], podInfo) + } + + nodeToTasks, taskSlots := j.getNonDetSlots(j.slotType) + summary := make(map[string]model.AgentSummary, len(j.currentNodes)) + for _, node := range j.currentNodes { + disabledLabel, isDisabled := node.Labels[clusterIDNodeLabel()] + isDraining := isDisabled && disabledLabel == noScheduleNodeLabelValue + + var numSlots int64 + var deviceType device.Type + + // TODO(DET-10010): slot type per node probably shouldn't be decided from pods literal + // (which has the same value for all nodes). + switch j.slotType { + case device.CPU: + resources := node.Status.Allocatable[k8sV1.ResourceCPU] + milliCPUs := resources.MilliValue() - j.nodeToSystemResourceRequests[node.Name] + numSlots = int64(float32(milliCPUs) / (1000. 
* j.slotResourceRequests.CPU)) + deviceType = device.CPU + case device.ROCM: + panic("ROCm is not supported on k8s yet") + case device.CUDA: + fallthrough + default: + resources := node.Status.Allocatable[resourceTypeNvidia] + numSlots = resources.Value() + deviceType = device.CUDA + } + + if numSlots < 1 { + continue + } + + slotsSummary := make(model.SlotsSummary) + curSlot := 0 + for _, podInfo := range podByNode[node.Name] { + for i := 0; i < podInfo.numSlots; i++ { + if curSlot >= int(numSlots) { + j.syslog.Warnf("too many pods mapping to node %s", node.Name) + continue + } + + slotsSummary[model.SortableSlotIndex(curSlot)] = model.SlotSummary{ + ID: model.SortableSlotIndex(curSlot), + Device: device.Device{Type: deviceType}, + Draining: isDraining, + Enabled: !isDisabled, + Container: podInfo.container, + } + curSlot++ + } + } + + for _, taskName := range nodeToTasks[node.Name] { + for i := int64(0); i < taskSlots[taskName]; i++ { + if curSlot >= int(numSlots) { + j.syslog.Warnf("too many pods mapping to node %s", node.Name) + continue + } + + slotsSummary[model.SortableSlotIndex(curSlot)] = model.SlotSummary{ + ID: model.SortableSlotIndex(curSlot), + Device: device.Device{Type: deviceType}, + Draining: isDraining, + Enabled: !isDisabled, + Container: &cproto.Container{ + ID: cproto.ID(taskName), + State: "RUNNING", + Devices: []device.Device{}, + Description: "unknown", + }, + } + curSlot++ + } + } + + for i := curSlot; i < int(numSlots); i++ { + slotsSummary[model.SortableSlotIndex(i)] = model.SlotSummary{ + ID: model.SortableSlotIndex(i), + Device: device.Device{Type: deviceType}, + Draining: isDraining, + Enabled: !isDisabled, + } + } + + var addrs []string + for _, addr := range node.Status.Addresses { + addrs = append(addrs, addr.Address) + } + + summary[node.Name] = model.AgentSummary{ + ID: node.Name, + RegisteredTime: node.ObjectMeta.CreationTimestamp.Time, + Slots: slotsSummary, + NumContainers: len(podByNode[node.Name]) + len(nodeToTasks[node.Name]), + ResourcePool: []string{""}, + Addresses: addrs, + Draining: isDraining, + Enabled: !isDisabled, + } + } + + return summary +} + +func (j *jobsService) getNonDetPods() ([]k8sV1.Pod, error) { + // TODO(RM-235) use a filter in metaV1.ListOptions. This change gets a lot easier after + // we have K8s integration tests. Using a filter means we should really talk to a real + // k8s server. Doing an e2e test for this is possible but would take a lot more work. + allPods, err := j.listPodsInAllNamespaces(context.TODO(), metaV1.ListOptions{}) + if err != nil { + return nil, err + } + + var nonDetPods []k8sV1.Pod + for _, p := range allPods.Items { + _, isDet := p.Labels[determinedLabel] + _, isDetSystem := p.Labels[determinedSystemLabel] + + if !(isDet || isDetSystem) { + if p.Spec.NodeName != "" { + nonDetPods = append(nonDetPods, p) + } + } + } + return nonDetPods, nil +} + +func (j *jobsService) getNonDetSlots(deviceType device.Type) (map[string][]string, map[string]int64) { + nodeToTasks := make(map[string][]string, len(j.currentNodes)) + taskSlots := make(map[string]int64) + + nonDetPods, err := j.getNonDetPods() + if err != nil { + j.syslog.WithError(err).Warn("getting non determined pods, " + + "this may cause slots to look free when they are in use") + } + + if len(nonDetPods) == 0 { + return nodeToTasks, taskSlots + } + for _, node := range j.currentNodes { + nodeToTasks[node.Name] = []string{} + } + + // Ignore pods not yet scheduled on a node. 
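For CPU pools, summarizeClusterByNodes above derives a node's slot count from its allocatable milli-CPUs minus the system pods' requests, divided by the per-slot CPU request. A worked example of that arithmetic with illustrative numbers:

.. code:: go

   package main

   import "fmt"

   func main() {
       // Illustrative numbers, not taken from the source.
       allocatableMilliCPUs := int64(16000) // node reports 16 allocatable CPUs
       systemMilliCPUs := int64(1000)       // system pods already request 1 CPU
       slotCPURequest := float32(2)         // each slot requests 2 CPUs

       milliCPUs := allocatableMilliCPUs - systemMilliCPUs
       numSlots := int64(float32(milliCPUs) / (1000. * slotCPURequest))
       fmt.Println(numSlots) // 7: the remaining half slot is not counted
   }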
+ for _, pod := range nonDetPods { + if _, ok := nodeToTasks[pod.Spec.NodeName]; !ok { + continue + } + reqs := int64(0) + for _, c := range pod.Spec.Containers { + if deviceType == device.CPU { + reqs += j.getCPUReqs(c) + } else if deviceType == device.CUDA { + reqs += c.Resources.Requests.Name(resourceTypeNvidia, resource.DecimalSI).Value() + } + } + if reqs > 0 { + nodeToTasks[pod.Spec.NodeName] = append(nodeToTasks[pod.Spec.NodeName], pod.Name) + taskSlots[pod.Name] = reqs + } + } + return nodeToTasks, taskSlots +} + +func (j *jobsService) getCPUReqs(c k8sV1.Container) int64 { + requested := float32(c.Resources.Requests.Cpu().MilliValue()) / + (1000. * j.slotResourceRequests.CPU) + return int64(requested) +} + +func (j *jobsService) containersPerResourcePool() map[string]int { + counts := make(map[string]int, len(j.namespaceToPoolName)) + for name, pool := range j.jobNameToResourcePool { + handler, ok := j.jobNameToJobHandler[name] + if !ok { + j.syslog.Errorf("job %s not in jobNameToResourcePool but in jobNameToJobHandler map", name) + continue + } + counts[pool] += handler.numPods + } + return counts +} + +func numSlots(slots model.SlotsSummary) int { + slotCountsByType := make(map[device.Type]int) + for _, slot := range slots { + slotCountsByType[slot.Device.Type]++ + } + + if slotCountsByType[device.CUDA] > 0 { + return slotCountsByType[device.CUDA] + } + + return slotCountsByType[device.CPU] +} + +func (j *jobsService) listJobsInAllNamespaces( + ctx context.Context, opts metaV1.ListOptions, +) (*batchV1.JobList, error) { + res := &batchV1.JobList{} + for n, i := range j.jobInterfaces { + pods, err := i.List(ctx, opts) + if err != nil { + return nil, fmt.Errorf("error listing pods for namespace %s: %w", n, err) + } + + res.Items = append(res.Items, pods.Items...) + } + + return res, nil +} + +func (j *jobsService) listPodsInAllNamespaces( + ctx context.Context, opts metaV1.ListOptions, +) (*k8sV1.PodList, error) { + res := &k8sV1.PodList{} + for n, i := range j.podInterfaces { + pods, err := i.List(ctx, opts) + if err != nil { + return nil, fmt.Errorf("error listing pods for namespace %s: %w", n, err) + } + + res.Items = append(res.Items, pods.Items...) + } + + return res, nil +} + +func (j *jobsService) listConfigMapsInAllNamespaces( + ctx context.Context, opts metaV1.ListOptions, +) (*k8sV1.ConfigMapList, error) { + res := &k8sV1.ConfigMapList{} + for n, i := range j.configMapInterfaces { + cms, err := i.List(ctx, opts) + if err != nil { + return nil, fmt.Errorf("error listing config maps for namespace %s: %w", n, err) + } + res.Items = append(res.Items, cms.Items...) 
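getNonDetSlots above counts the slots occupied by non-Determined pods from their container resource requests. A minimal sketch of reading those requests from a container spec, again assuming ``nvidia.com/gpu`` as the GPU resource name:

.. code:: go

   package main

   import (
       "fmt"

       k8sV1 "k8s.io/api/core/v1"
       "k8s.io/apimachinery/pkg/api/resource"
   )

   func main() {
       c := k8sV1.Container{
           Resources: k8sV1.ResourceRequirements{
               Requests: k8sV1.ResourceList{
                   k8sV1.ResourceCPU:                    resource.MustParse("1500m"),
                   k8sV1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
               },
           },
       }

       // CPU requests are read in milli-CPUs; GPU requests as a whole number.
       fmt.Println(c.Resources.Requests.Cpu().MilliValue())                                 // 1500
       fmt.Println(c.Resources.Requests.Name("nvidia.com/gpu", resource.DecimalSI).Value()) // 2
   }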
+ } + + return res, nil +} + +func extractTCDs(resourcePoolConfigs []config.ResourcePoolConfig, +) map[string]*model.TaskContainerDefaultsConfig { + result := map[string]*model.TaskContainerDefaultsConfig{} + + for _, config := range resourcePoolConfigs { + result[config.PoolName] = config.TaskContainerDefaults + } + + return result +} + +func taintTolerated(taint k8sV1.Taint, tolerations []k8sV1.Toleration) bool { + for _, toleration := range tolerations { + if toleration.ToleratesTaint(&taint) { + return true + } + } + + return false +} + +func allTaintsTolerated(taints []k8sV1.Taint, tolerations []k8sV1.Toleration) bool { + for _, taint := range taints { + if !taintTolerated(taint, tolerations) { + return false + } + } + + return true +} + +func extractSlotInfo(node model.AgentSummary) (numSlots int, devType device.Type) { + var gpuSlots, cpuSlots int + + for _, slot := range node.Slots { + if slot.Device.Type == device.CPU { + cpuSlots++ + } else if slot.Device.Type == device.CUDA { + gpuSlots++ + } + } + + if gpuSlots > 0 { + return gpuSlots, device.CUDA + } + + return cpuSlots, device.CPU +} + +func extractTolerations(tcd *model.TaskContainerDefaultsConfig) ( + cpuTolerations, gpuTolerations []k8sV1.Toleration, +) { + if tcd != nil { + if tcd.GPUPodSpec != nil { + gpuTolerations = tcd.GPUPodSpec.Spec.Tolerations + } + if tcd.CPUPodSpec != nil { + cpuTolerations = tcd.CPUPodSpec.Spec.Tolerations + } + } + + return cpuTolerations, gpuTolerations +} + +func all[T any](pred func(T) bool, elems ...T) bool { + for _, elem := range elems { + if !pred(elem) { + return false + } + } + return true +} + +func allEqual[T comparable](other T, elems ...T) bool { + return all(func(elem T) bool { + return elem == other + }, elems...) +} diff --git a/master/internal/rm/kubernetesrm/pods_test.go b/master/internal/rm/kubernetesrm/jobs_test.go similarity index 99% rename from master/internal/rm/kubernetesrm/pods_test.go rename to master/internal/rm/kubernetesrm/jobs_test.go index dd5d3a0bf3e..590524e221a 100644 --- a/master/internal/rm/kubernetesrm/pods_test.go +++ b/master/internal/rm/kubernetesrm/jobs_test.go @@ -1,3 +1,5 @@ +//go:build integration + package kubernetesrm import ( @@ -61,7 +63,7 @@ func TestGetNonDetPods(t *testing.T) { ns2.On("List", mock.Anything, mock.Anything).Once(). Return(&k8sV1.PodList{Items: append(hiddenPods, expectedPods[1])}, nil) - p := pods{ + p := jobsService{ podInterfaces: map[string]typedV1.PodInterface{ "ns1": ns1, "ns2": ns2, diff --git a/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go b/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go index 109fe45cfc9..433935a97a9 100644 --- a/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go +++ b/master/internal/rm/kubernetesrm/kubernetes_resource_manager.go @@ -41,7 +41,7 @@ type ResourceManager struct { poolsConfig []config.ResourcePoolConfig taskContainerDefaults *model.TaskContainerDefaultsConfig - podsService *pods + jobsService *jobsService pools map[string]*kubernetesResourcePool // immutable after initialization in new. masterTLSConfig model.TLSClientConfig @@ -65,11 +65,11 @@ func New( } // TODO(DET-9833) clusterID should just be a `internal/config` package singleton. 
- clusterID, err := db.GetOrCreateClusterID("") + id, err := db.GetOrCreateClusterID("") if err != nil { return nil, fmt.Errorf("getting clusterID: %w", err) } - setClusterID(clusterID) + setClusterID(id) k := &ResourceManager{ syslog: logrus.WithField("component", "k8srm"), @@ -95,12 +95,11 @@ func New( poolNamespaces[k.poolsConfig[i].KubernetesNamespace] = k.poolsConfig[i].PoolName } - k.podsService = newPodsService( + k.jobsService, err = newJobsService( k.config.Namespace, poolNamespaces, k.config.MasterServiceName, k.masterTLSConfig, - k.loggingConfig, k.config.DefaultScheduler, k.config.SlotType, config.PodSlotResourceRequests{CPU: k.config.SlotResourceRequests.CPU}, @@ -109,8 +108,11 @@ func New( k.config.DetMasterIP, k.config.DetMasterPort, k.config.KubeconfigPath, - k.podStatusUpdateCallback, + k.jobSchedulingStateCallback, ) + if err != nil { + return nil, err + } for _, poolConfig := range k.poolsConfig { maxSlotsPerPod := 0 @@ -124,7 +126,7 @@ func New( } poolConfig := poolConfig - rp := newResourcePool(maxSlotsPerPod, &poolConfig, k.podsService, k.db) + rp := newResourcePool(maxSlotsPerPod, &poolConfig, k.jobsService, k.db) go func() { t := time.NewTicker(podSubmissionInterval) defer t.Stop() @@ -178,19 +180,19 @@ func (k *ResourceManager) HealthCheck() []model.ResourceManagerHealth { return []model.ResourceManagerHealth{ { Name: k.config.Name, - Status: k.podsService.HealthStatus(), + Status: k.jobsService.HealthStatus(), }, } } // GetAgent implements rm.ResourceManager. func (k *ResourceManager) GetAgent(msg *apiv1.GetAgentRequest) (*apiv1.GetAgentResponse, error) { - return k.podsService.GetAgent(msg), nil + return k.jobsService.GetAgent(msg), nil } // GetAgents implements rm.ResourceManager. func (k *ResourceManager) GetAgents() (*apiv1.GetAgentsResponse, error) { - return k.podsService.GetAgents(), nil + return k.jobsService.GetAgents(), nil } // GetAllocationSummaries implements rm.ResourceManager. @@ -292,12 +294,12 @@ func (k *ResourceManager) GetResourcePools() (*apiv1.GetResourcePoolsResponse, e // GetSlot implements rm.ResourceManager. func (k *ResourceManager) GetSlot(msg *apiv1.GetSlotRequest) (*apiv1.GetSlotResponse, error) { - return k.podsService.GetSlot(msg), nil + return k.jobsService.GetSlot(msg), nil } // GetSlots implements rm.ResourceManager. func (k *ResourceManager) GetSlots(msg *apiv1.GetSlotsRequest) (*apiv1.GetSlotsResponse, error) { - return k.podsService.GetSlots(msg), nil + return k.jobsService.GetSlots(msg), nil } // MoveJob implements rm.ResourceManager. 
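HealthCheck above surfaces the health computed by the jobs service, which probes the cluster by listing a single pod and reports unhealthy on any error. A minimal sketch of such a probe against one namespace, assuming a client-go clientset supplied by the caller:

.. code:: go

   package sketch

   import (
       "context"

       metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"
       "k8s.io/client-go/kubernetes"
   )

   // healthy reports whether the API server answers a cheap, single-item pod
   // list in the given namespace; Limit keeps the probe inexpensive.
   func healthy(cs kubernetes.Interface, namespace string) bool {
       _, err := cs.CoreV1().Pods(namespace).List(context.TODO(), metaV1.ListOptions{Limit: 1})
       return err == nil
   }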
@@ -515,9 +517,17 @@ func (k ResourceManager) TaskContainerDefaults( return result, nil } -func (k *ResourceManager) podStatusUpdateCallback(msg sproto.UpdatePodStatus) { +type jobSchedulingStateCallback func(jobSchedulingStateChanged) + +type jobSchedulingStateChanged struct { + AllocationID model.AllocationID + NumPods int + State sproto.SchedulingState +} + +func (k *ResourceManager) jobSchedulingStateCallback(msg jobSchedulingStateChanged) { for _, rp := range k.pools { - rp.UpdatePodStatus(msg) + rp.JobSchedulingStateChanged(msg) } } @@ -588,7 +598,7 @@ func (k *ResourceManager) createResourcePoolSummary( return &resourcepoolv1.ResourcePool{}, err } - resourceSummary, err := rp.getResourceSummary(getResourceSummary{}) + resourceSummary, err := rp.getResourceSummary() if err != nil { return &resourcepoolv1.ResourcePool{}, err } @@ -635,14 +645,14 @@ func (k *ResourceManager) getResourcePoolConfig(poolName string) ( func (k *ResourceManager) EnableAgent( req *apiv1.EnableAgentRequest, ) (resp *apiv1.EnableAgentResponse, err error) { - return k.podsService.EnableAgent(req) + return k.jobsService.EnableAgent(req) } // DisableAgent prevents scheduling on a node and has the option to kill running jobs. func (k *ResourceManager) DisableAgent( req *apiv1.DisableAgentRequest, ) (resp *apiv1.DisableAgentResponse, err error) { - return k.podsService.DisableAgent(req) + return k.jobsService.DisableAgent(req) } // EnableSlot implements 'det slot enable...' functionality. diff --git a/master/internal/rm/kubernetesrm/kubernetes_resource_manager_intg_test.go b/master/internal/rm/kubernetesrm/kubernetes_resource_manager_intg_test.go index 1c178e2cf30..4169f7bb28e 100644 --- a/master/internal/rm/kubernetesrm/kubernetes_resource_manager_intg_test.go +++ b/master/internal/rm/kubernetesrm/kubernetes_resource_manager_intg_test.go @@ -17,6 +17,7 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + batchV1 "k8s.io/api/batch/v1" k8sV1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -27,7 +28,6 @@ import ( "github.com/determined-ai/determined/master/internal/mocks" "github.com/determined-ai/determined/master/internal/rm/tasklist" "github.com/determined-ai/determined/master/internal/sproto" - "github.com/determined-ai/determined/master/pkg/cproto" "github.com/determined-ai/determined/master/pkg/device" "github.com/determined-ai/determined/master/pkg/etc" "github.com/determined-ai/determined/master/pkg/model" @@ -53,7 +53,7 @@ const ( func TestMain(m *testing.M) { // Need to set up the DB for TestJobQueueStats - pgDB, _, err := db.ResolveTestPostgres() + pgDB, _, err := db.ResolveNewPostgresDatabase() if err != nil { log.Panicln(err) } @@ -68,13 +68,15 @@ func TestMain(m *testing.M) { log.Panicln(err) } + setClusterID(uuid.NewString()) + os.Exit(m.Run()) } func TestGetAgents(t *testing.T) { type AgentsTestCase struct { Name string - podsService *pods + jobsService *jobsService wantedAgentIDs map[string]int } @@ -83,7 +85,7 @@ func TestGetAgents(t *testing.T) { agentsTests := []AgentsTestCase{ { Name: "GetAgents-CPU-NoPodLabels-NoAgents", - podsService: createMockPodsService(make(map[string]*k8sV1.Node), + jobsService: createMockJobsService(make(map[string]*k8sV1.Node), device.CPU, false, ), @@ -91,7 +93,7 @@ func TestGetAgents(t *testing.T) { }, { Name: "GetAgents-CPU-NoPodLabels", - podsService: 
createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ auxNode1Name: auxNode1, auxNode2Name: auxNode2, }, @@ -102,7 +104,7 @@ func TestGetAgents(t *testing.T) { }, { Name: "GetAgents-CPU-PodLabels", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ auxNode1Name: auxNode1, auxNode2Name: auxNode2, }, @@ -113,7 +115,7 @@ func TestGetAgents(t *testing.T) { }, { Name: "GetAgents-GPU-PodLabels-NonDetAgent", - podsService: createMockPodsService(make(map[string]*k8sV1.Node), + jobsService: createMockJobsService(make(map[string]*k8sV1.Node), slotTypeGPU, true, ), @@ -121,7 +123,7 @@ func TestGetAgents(t *testing.T) { }, { Name: "GetAgents-GPU-NoPodNoLabels", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -132,7 +134,7 @@ func TestGetAgents(t *testing.T) { }, { Name: "GetAgents-GPU-PodLabels", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -143,7 +145,7 @@ func TestGetAgents(t *testing.T) { }, { Name: "GetAgents-CUDA-NoPodLabels", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -154,7 +156,7 @@ func TestGetAgents(t *testing.T) { }, { Name: "GetAgents-CUDA-PodLabels", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -167,7 +169,7 @@ func TestGetAgents(t *testing.T) { for _, test := range agentsTests { t.Run(test.Name, func(t *testing.T) { - agentsResp := test.podsService.handleGetAgentsRequest() + agentsResp := test.jobsService.getAgents() require.Equal(t, len(test.wantedAgentIDs), len(agentsResp.Agents)) for _, agent := range agentsResp.Agents { _, ok := test.wantedAgentIDs[agent.Id] @@ -181,7 +183,7 @@ func TestGetAgents(t *testing.T) { func TestGetAgent(t *testing.T) { type AgentTestCase struct { Name string - podsService *pods + jobsService *jobsService agentExists bool wantedAgentID string } @@ -195,7 +197,7 @@ func TestGetAgent(t *testing.T) { }, Status: k8sV1.NodeStatus{ Allocatable: map[k8sV1.ResourceName]resource.Quantity{ - k8sV1.ResourceName(ResourceTypeNvidia): *resource.NewQuantity( + k8sV1.ResourceName(resourceTypeNvidia): *resource.NewQuantity( 16, resource.DecimalSI, ), @@ -206,7 +208,7 @@ func TestGetAgent(t *testing.T) { agentTests := []AgentTestCase{ { Name: "GetAgent-CPU-NoPodLabels-Aux1", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ auxNode1Name: auxNode1, auxNode2Name: auxNode2, }, @@ -218,7 +220,7 @@ func TestGetAgent(t *testing.T) { }, { Name: "GetAgent-CPU-PodLabels-Aux2", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ auxNode1Name: auxNode1, auxNode2Name: auxNode2, }, @@ -230,7 +232,7 @@ func TestGetAgent(t *testing.T) { }, { Name: "GetAgent-GPU-PodLabels-Comp1", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -242,7 +244,7 @@ func TestGetAgent(t 
*testing.T) { }, { Name: "GetAgent-CUDA-NoPodLabels-Comp2", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -254,7 +256,7 @@ func TestGetAgent(t *testing.T) { }, { Name: "GetAgent-CUDA-NoPodLabels-NonexistentAgent", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -266,7 +268,7 @@ func TestGetAgent(t *testing.T) { }, { Name: "GetAgent-CUDA-NoPodLabels-EmptyAgentID", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -278,7 +280,7 @@ func TestGetAgent(t *testing.T) { }, { Name: "GetAgent-CUDA-Large-Node", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: largeNode, }, slotTypeGPU, false), wantedAgentID: compNode1Name, @@ -288,7 +290,7 @@ func TestGetAgent(t *testing.T) { for _, test := range agentTests { t.Run(test.Name, func(t *testing.T) { - agentResp := test.podsService.handleGetAgentRequest(test.wantedAgentID) + agentResp := test.jobsService.getAgent(test.wantedAgentID) if agentResp == nil { require.True(t, !test.agentExists) return @@ -316,7 +318,7 @@ func TestGetAgent(t *testing.T) { func TestGetSlots(t *testing.T) { type SlotsTestCase struct { Name string - podsService *pods + jobsService *jobsService agentID string agentExists bool wantedSlotsNum int @@ -327,7 +329,7 @@ func TestGetSlots(t *testing.T) { slotsTests := []SlotsTestCase{ { Name: "GetSlots-CPU-NoPodLabels-Aux1", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ auxNode1Name: auxNode1, auxNode2Name: auxNode2, }, @@ -340,7 +342,7 @@ func TestGetSlots(t *testing.T) { }, { Name: "GetSlots-GPU-NoPodLabels-Comp2", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -353,7 +355,7 @@ func TestGetSlots(t *testing.T) { }, { Name: "GetSlots-CUDA-PodLabels-Comp1", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -366,7 +368,7 @@ func TestGetSlots(t *testing.T) { }, { Name: "GetSlots-CUDA-PodLabels-NonexistentAgent", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -379,7 +381,7 @@ func TestGetSlots(t *testing.T) { }, { Name: "GetSlots-CUDA-PodLabels-EmptyAgentID", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -401,7 +403,7 @@ func TestGetSlots(t *testing.T) { } for _, test := range slotsTests { t.Run(test.Name, func(t *testing.T) { - slotsResp := test.podsService.handleGetSlotsRequest(test.agentID) + slotsResp := test.jobsService.getSlots(test.agentID) if slotsResp == nil { require.True(t, !test.agentExists) return @@ -428,7 +430,7 @@ func TestGetSlots(t *testing.T) { func TestGetSlot(t *testing.T) { type SlotTestCase struct { Name string - podsService *pods + jobsService 
*jobsService agentID string wantedSlotNum string } @@ -438,7 +440,7 @@ func TestGetSlot(t *testing.T) { slotTests := []SlotTestCase{ { Name: "GetSlot-CPU-PodLabels-Aux1-LastId", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ auxNode1Name: auxNode1, auxNode2Name: auxNode2, }, @@ -450,7 +452,7 @@ func TestGetSlot(t *testing.T) { }, { Name: "GetSlot-GPU-PodLabels-Comp1-Id4", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -462,7 +464,7 @@ func TestGetSlot(t *testing.T) { }, { Name: "GetSlot-GPU-PodLabels-Comp1-Id4", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -474,7 +476,7 @@ func TestGetSlot(t *testing.T) { }, { Name: "GetSlot-GPU-PodLabels-Comp1-Id0", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -486,7 +488,7 @@ func TestGetSlot(t *testing.T) { }, { Name: "GetSlot-CUDA-NoPodLabels-Comp1-BadSlotReq", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -498,7 +500,7 @@ func TestGetSlot(t *testing.T) { }, { Name: "GetSlot-CUDA-PodLabels-Comp2-BadSlotReq", - podsService: createMockPodsService(map[string]*k8sV1.Node{ + jobsService: createMockJobsService(map[string]*k8sV1.Node{ compNode1Name: compNode1, compNode2Name: compNode2, }, @@ -515,7 +517,7 @@ func TestGetSlot(t *testing.T) { wantedSlotInt, err := strconv.Atoi(test.wantedSlotNum) require.NoError(t, err) - slotResp := test.podsService.handleGetSlotRequest(test.agentID, test.wantedSlotNum) + slotResp := test.jobsService.getSlot(test.agentID, test.wantedSlotNum) if slotResp == nil { require.True(t, wantedSlotInt < 0 || wantedSlotInt >= int(nodeNumSlots)) return @@ -539,14 +541,12 @@ func TestAssignResourcesTime(t *testing.T) { groups[allocateReq.JobID] = &tasklist.Group{ JobID: allocateReq.JobID, } - mockPods := createMockPodsService(make(map[string]*k8sV1.Node), device.CUDA, true) + mockPods := createMockJobsService(make(map[string]*k8sV1.Node), device.CUDA, true) poolRef := &kubernetesResourcePool{ poolConfig: &config.ResourcePoolConfig{PoolName: "cpu-pool"}, - podsService: mockPods, + jobsService: mockPods, reqList: taskList, groups: groups, - allocationIDToContainerID: map[model.AllocationID]cproto.ID{}, - containerIDtoAllocationID: map[string]model.AllocationID{}, jobIDToAllocationID: map[model.JobID]model.AllocationID{}, allocationIDToJobID: map[model.AllocationID]model.JobID{}, slotsUsedPerGroup: map[*tasklist.Group]int{}, @@ -579,15 +579,15 @@ func TestGetResourcePools(t *testing.T) { }, } - mockPods := createMockPodsService(make(map[string]*k8sV1.Node), device.CUDA, true) + mockPods := createMockJobsService(make(map[string]*k8sV1.Node), device.CUDA, true) cpuPoolRef := &kubernetesResourcePool{ poolConfig: &config.ResourcePoolConfig{PoolName: "cpu-pool"}, - podsService: mockPods, + jobsService: mockPods, reqList: tasklist.New(), } gpuPoolRef := &kubernetesResourcePool{ poolConfig: &config.ResourcePoolConfig{PoolName: "gpu-pool"}, - podsService: mockPods, + jobsService: mockPods, reqList: tasklist.New(), } kubernetesRM := &ResourceManager{ 
@@ -648,15 +648,15 @@ func TestGetResourcePools(t *testing.T) { } func TestGetJobQueueStatsRequest(t *testing.T) { - mockPods := createMockPodsService(make(map[string]*k8sV1.Node), device.CUDA, true) + mockPods := createMockJobsService(make(map[string]*k8sV1.Node), device.CUDA, true) pool1 := &kubernetesResourcePool{ poolConfig: &config.ResourcePoolConfig{PoolName: "pool1"}, - podsService: mockPods, + jobsService: mockPods, reqList: tasklist.New(), } pool2 := &kubernetesResourcePool{ poolConfig: &config.ResourcePoolConfig{PoolName: "pool2"}, - podsService: mockPods, + jobsService: mockPods, reqList: tasklist.New(), } k8sRM := &ResourceManager{ @@ -691,7 +691,7 @@ func TestHealthCheck(t *testing.T) { config: &config.KubernetesResourceManagerConfig{ Name: "testname", }, - podsService: &pods{ + jobsService: &jobsService{ podInterfaces: map[string]typedV1.PodInterface{ "namespace": mockPodInterface, }, @@ -721,7 +721,7 @@ func TestHealthCheck(t *testing.T) { }) } -func TestROCmPodsService(t *testing.T) { +func TestROCmJobsService(t *testing.T) { tests := []struct { name string testFunc func() @@ -737,27 +737,27 @@ func TestROCmPodsService(t *testing.T) { } func testROCMGetAgents() { - ps := createMockPodsService(createCompNodeMap(), device.ROCM, false) - ps.handleGetAgentsRequest() + ps := createMockJobsService(createCompNodeMap(), device.ROCM, false) + ps.getAgents() } func testROCMGetAgent() { nodes := createCompNodeMap() - ps := createMockPodsService(nodes, device.ROCM, false) - ps.handleGetAgentRequest(compNode1Name) + ps := createMockJobsService(nodes, device.ROCM, false) + ps.getAgent(compNode1Name) } func testROCMGetSlots() { nodes := createCompNodeMap() - ps := createMockPodsService(nodes, device.ROCM, false) - ps.handleGetSlotsRequest(compNode1Name) + ps := createMockJobsService(nodes, device.ROCM, false) + ps.getSlots(compNode1Name) } func testROCMGetSlot() { nodes := createCompNodeMap() - ps := createMockPodsService(nodes, device.ROCM, false) + ps := createMockJobsService(nodes, device.ROCM, false) for i := 0; i < int(nodeNumSlots); i++ { - ps.handleGetSlotRequest(compNode1Name, strconv.Itoa(i)) + ps.getSlot(compNode1Name, strconv.Itoa(i)) } } @@ -767,7 +767,7 @@ func setupNodes() (*k8sV1.Node, *k8sV1.Node, *k8sV1.Node, *k8sV1.Node) { } compResourceList := map[k8sV1.ResourceName]resource.Quantity{ - k8sV1.ResourceName(ResourceTypeNvidia): *resource.NewQuantity( + k8sV1.ResourceName(resourceTypeNvidia): *resource.NewQuantity( nodeNumSlots, resource.DecimalSI, ), @@ -824,45 +824,56 @@ func createCompNodeMap() map[string]*k8sV1.Node { } } -// createMockPodsService creates two pods. One pod is run on the auxiliary node and the other is +// createMockJobsService creates two pods. One pod is run on the auxiliary node and the other is // run on the compute node. -func createMockPodsService(nodes map[string]*k8sV1.Node, devSlotType device.Type, +func createMockJobsService(nodes map[string]*k8sV1.Node, devSlotType device.Type, labels bool, -) *pods { +) *jobsService { + var jobsList batchV1.JobList + var podsList k8sV1.PodList // Create two pods that are scheduled on a node. 
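The mock service below stubs the Kubernetes clientset with the project's generated testify mocks so that job and pod listings return canned objects. A pared-down, hedged sketch of that stubbing pattern in isolation (the test name is illustrative):

.. code:: go

   package kubernetesrm

   import (
       "context"
       "testing"

       "github.com/stretchr/testify/mock"
       "github.com/stretchr/testify/require"
       k8sV1 "k8s.io/api/core/v1"
       metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1"

       "github.com/determined-ai/determined/master/internal/mocks"
   )

   // TestPodListStubSketch shows the List stubbing used by createMockJobsService:
   // the generated PodInterface mock answers any List call with a canned pod list.
   func TestPodListStubSketch(t *testing.T) {
       podsInterface := &mocks.PodInterface{}
       podsInterface.On("List", mock.Anything, mock.Anything).Return(
           &k8sV1.PodList{Items: []k8sV1.Pod{{ObjectMeta: metaV1.ObjectMeta{Name: "pod-a"}}}}, nil)

       pods, err := podsInterface.List(context.TODO(), metaV1.ListOptions{})
       require.NoError(t, err)
       require.Len(t, pods.Items, 1)
   }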
- pod1 := &pod{ + jobName1 := uuid.NewString() + job1 := &job{ allocationID: model.AllocationID(uuid.New().String()), - slots: pod1NumSlots, - pod: &k8sV1.Pod{ - Spec: k8sV1.PodSpec{NodeName: auxNode1Name}, + jobName: jobName1, + slotsPerPod: pod1NumSlots, + podNodeNames: map[string]string{ + jobName1: auxNode1Name, }, } - pod2 := &pod{ + jobsList.Items = append(jobsList.Items, batchV1.Job{ObjectMeta: metaV1.ObjectMeta{Name: jobName1}}) + podsList.Items = append(podsList.Items, k8sV1.Pod{ObjectMeta: metaV1.ObjectMeta{Name: jobName1}}) + + jobName2 := uuid.NewString() + job2 := &job{ allocationID: model.AllocationID(uuid.New().String()), - slots: pod2NumSlots, - pod: &k8sV1.Pod{ - Spec: k8sV1.PodSpec{NodeName: compNode1Name}, + jobName: jobName2, + slotsPerPod: pod2NumSlots, + podNodeNames: map[string]string{ + uuid.NewString(): compNode1Name, }, } + jobsList.Items = append(jobsList.Items, batchV1.Job{ObjectMeta: metaV1.ObjectMeta{Name: jobName2}}) + podsList.Items = append(podsList.Items, k8sV1.Pod{ObjectMeta: metaV1.ObjectMeta{Name: jobName2}}) // Create pod that is not yet scheduled on a node. - pod3 := &pod{ + jobName3 := uuid.NewString() + job3 := &job{ allocationID: model.AllocationID(uuid.New().String()), - slots: 0, - pod: &k8sV1.Pod{ - Spec: k8sV1.PodSpec{NodeName: ""}, - }, + jobName: jobName3, + slotsPerPod: 0, + podNodeNames: map[string]string{}, } + jobsList.Items = append(jobsList.Items, batchV1.Job{ObjectMeta: metaV1.ObjectMeta{Name: jobName3}}) + podsList.Items = append(podsList.Items, k8sV1.Pod{ObjectMeta: metaV1.ObjectMeta{Name: jobName3}}) - podsList := &k8sV1.PodList{Items: []k8sV1.Pod{*pod1.pod, *pod2.pod, *pod3.pod}} - - var nonDetPod *pod + var nonDetPod *job if labels { // Give labels to all determined pods. - pod1.pod.ObjectMeta = metaV1.ObjectMeta{Labels: map[string]string{"determined": ""}} - pod2.pod.ObjectMeta = metaV1.ObjectMeta{Labels: map[string]string{"determined": ""}} - pod3.pod.ObjectMeta = metaV1.ObjectMeta{Labels: map[string]string{"determined": ""}} + // Iterate by index so the labels are set on the list items themselves, not on copies. + for i := range jobsList.Items { + jobsList.Items[i].ObjectMeta.Labels = map[string]string{"determined": ""} + } resourceList := make(map[k8sV1.ResourceName]resource.Quantity) @@ -870,7 +881,7 @@ func createMockPodsService(nodes map[string]*k8sV1.Node, devSlotType device.Type resourceList[k8sV1.ResourceName(device.CPU)] = *resource.NewQuantity(nodeNumSlotsCPU, resource.DecimalSI) } else { - resourceList[k8sV1.ResourceName(ResourceTypeNvidia)] = *resource.NewQuantity(nodeNumSlots, + resourceList[k8sV1.ResourceName(resourceTypeNvidia)] = *resource.NewQuantity(nodeNumSlots, resource.DecimalSI) } nonDetNode := k8sV1.Node{ @@ -884,23 +895,23 @@ func createMockPodsService(nodes map[string]*k8sV1.Node, devSlotType device.Type nodes[nonDetNode.Name] = &nonDetNode // Create pod without determined label.
- nonDetPod = &pod{ + nonDetPod = &job{ allocationID: model.AllocationID(uuid.New().String()), - slots: 0, - pod: &k8sV1.Pod{ - Spec: k8sV1.PodSpec{NodeName: nonDetNodeName}, + slotsPerPod: 0, + podNodeNames: map[string]string{ + uuid.NewString(): nonDetNodeName, }, } - podsList.Items = append(podsList.Items, *nonDetPod.pod) + jobsList.Items = append(jobsList.Items, batchV1.Job{}) } - podHandlers := map[string]*pod{ - string(pod1.allocationID): pod1, - string(pod2.allocationID): pod2, - string(pod3.allocationID): pod3, + jobHandlers := map[string]*job{ + string(job1.allocationID): job1, + string(job2.allocationID): job2, + string(job3.allocationID): job3, } if nonDetPod != nil { - podHandlers[string(nonDetPod.allocationID)] = nonDetPod + jobHandlers[string(nonDetPod.allocationID)] = nonDetPod } // Create pod service client set. @@ -908,14 +919,19 @@ func createMockPodsService(nodes map[string]*k8sV1.Node, devSlotType device.Type coreV1Interface := &mocks.K8sCoreV1Interface{} podsInterface := &mocks.PodInterface{} podsInterface.On("List", mock.Anything, mock.Anything).Return(podsList, nil) + batchV1Interface := &mocks.K8sBatchV1Interface{} + jobsInterface := &mocks.JobInterface{} + jobsInterface.On("List", mock.Anything, mock.Anything).Return(jobsList, nil) coreV1Interface.On("Pods", mock.Anything).Return(podsInterface) + batchV1Interface.On("Jobs", mock.Anything).Return(jobsInterface) podsClientSet.On("CoreV1").Return(coreV1Interface) + podsClientSet.On("BatchV1").Return(batchV1Interface) - return &pods{ + return &jobsService{ namespace: "default", namespaceToPoolName: make(map[string]string), currentNodes: nodes, - podNameToPodHandler: podHandlers, + jobNameToJobHandler: jobHandlers, slotType: devSlotType, syslog: logrus.WithField("namespace", namespace), nodeToSystemResourceRequests: map[string]int64{ diff --git a/master/internal/rm/kubernetesrm/log.go b/master/internal/rm/kubernetesrm/log.go index fba315a1a02..9272466ba46 100644 --- a/master/internal/rm/kubernetesrm/log.go +++ b/master/internal/rm/kubernetesrm/log.go @@ -33,8 +33,8 @@ func startPodLogStreamer( return errors.Wrapf(err, "failed to initialize log stream for pod: %s", podName) } syslog := logrus.WithField("podName", podName) - logger := &podLogStreamer{callback} + logger := &podLogStreamer{callback} go logger.receiveStreamLogs(syslog, logReader) return nil @@ -50,6 +50,7 @@ func (p *podLogStreamer) receiveStreamLogs( syslog *logrus.Entry, logReader io.ReadCloser, ) { + syslog.Debug("starting pod log streamer") _, err := io.Copy(p, logReader) if err != nil { syslog.WithError(err).Debug("error reading logs") diff --git a/master/internal/rm/kubernetesrm/mock_client_test.go b/master/internal/rm/kubernetesrm/mock_client_test.go index 16cc1ac2282..895fd68ef92 100644 --- a/master/internal/rm/kubernetesrm/mock_client_test.go +++ b/master/internal/rm/kubernetesrm/mock_client_test.go @@ -12,6 +12,7 @@ import ( "github.com/hashicorp/go-cleanhttp" "github.com/pkg/errors" + batchV1 "k8s.io/api/batch/v1" k8sV1 "k8s.io/api/core/v1" "k8s.io/api/policy/v1beta1" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -106,12 +107,6 @@ type mockPodInterface struct { mux sync.Mutex } -func (m *mockPodInterface) hasPod(name string) bool { - m.mux.Lock() - defer m.mux.Unlock() - return m.pods[name] != nil -} - func (m *mockPodInterface) Create( ctx context.Context, pod *k8sV1.Pod, opts metaV1.CreateOptions, ) (*k8sV1.Pod, error) { @@ -154,13 +149,6 @@ func (m *mockPodInterface) Delete( return nil } -func (m *mockPodInterface) 
delete(name string) { - m.mux.Lock() - defer m.mux.Unlock() - - delete(m.pods, name) -} - func (m *mockPodInterface) DeleteCollection( ctx context.Context, options metaV1.DeleteOptions, listOptions metaV1.ListOptions, ) error { @@ -248,3 +236,91 @@ func (m *mockRoundTripInterface) RoundTrip(req *http.Request) (*http.Response, e Body: io.NopCloser(strings.NewReader(msg)), }, nil } + +type mockJobInterface struct { + jobs map[string]*batchV1.Job + // Simulates latency of the real k8s API server. + operationalDelay time.Duration + mux sync.Mutex +} + +func (m *mockJobInterface) Create( + ctx context.Context, job *batchV1.Job, opts metaV1.CreateOptions, +) (*batchV1.Job, error) { + time.Sleep(m.operationalDelay) + m.mux.Lock() + defer m.mux.Unlock() + + if _, present := m.jobs[job.Name]; present { + return nil, errors.Errorf("job with name %s already exists", job.Name) + } + + m.jobs[job.Name] = job.DeepCopy() + return m.jobs[job.Name], nil +} + +func (m *mockJobInterface) Update( + context.Context, *batchV1.Job, metaV1.UpdateOptions, +) (*batchV1.Job, error) { + panic("implement me") +} + +func (m *mockJobInterface) UpdateStatus( + context.Context, *batchV1.Job, metaV1.UpdateOptions, +) (*batchV1.Job, error) { + panic("implement me") +} + +func (m *mockJobInterface) Delete( + ctx context.Context, name string, options metaV1.DeleteOptions, +) error { + m.mux.Lock() + defer m.mux.Unlock() + + if _, present := m.jobs[name]; !present { + return errors.Errorf("job with name %s doesn't exist", name) + } + + delete(m.jobs, name) + return nil +} + +func (m *mockJobInterface) DeleteCollection( + ctx context.Context, options metaV1.DeleteOptions, listOptions metaV1.ListOptions, ) error { + panic("implement me") +} + +func (m *mockJobInterface) Get( + ctx context.Context, name string, options metaV1.GetOptions, +) (*batchV1.Job, error) { + panic("implement me") +} + +func (m *mockJobInterface) List( + ctx context.Context, opts metaV1.ListOptions, +) (*batchV1.JobList, error) { + time.Sleep(m.operationalDelay) + m.mux.Lock() + defer m.mux.Unlock() + + jobList := &batchV1.JobList{} + for _, job := range m.jobs { + jobList.Items = append(jobList.Items, *job) + } + + return jobList, nil +} + +func (m *mockJobInterface) Watch( + ctx context.Context, opts metaV1.ListOptions, +) (watch.Interface, error) { + panic("implement me") +} + +func (m *mockJobInterface) Patch( + ctx context.Context, name string, pt types.PatchType, data []byte, opts metaV1.PatchOptions, + subresources ...string, +) (result *batchV1.Job, err error) { + panic("implement me") +} diff --git a/master/internal/rm/kubernetesrm/pod.go b/master/internal/rm/kubernetesrm/pod.go deleted file mode 100644 index b77d5e33714..00000000000 --- a/master/internal/rm/kubernetesrm/pod.go +++ /dev/null @@ -1,608 +0,0 @@ -package kubernetesrm - -import ( - "fmt" - "regexp" - "strconv" - "sync" - "sync/atomic" - "time" - - "github.com/docker/docker/pkg/stdcopy" - "github.com/sirupsen/logrus" - - "github.com/pkg/errors" - - "github.com/determined-ai/determined/master/internal/config" - "github.com/determined-ai/determined/master/internal/rm/rmevents" - "github.com/determined-ai/determined/master/internal/sproto" - "github.com/determined-ai/determined/master/pkg/aproto" - "github.com/determined-ai/determined/master/pkg/cproto" - "github.com/determined-ai/determined/master/pkg/device" - "github.com/determined-ai/determined/master/pkg/logger" -
"github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/set" - "github.com/determined-ai/determined/master/pkg/tasks" - - k8sV1 "k8s.io/api/core/v1" - k8sClient "k8s.io/client-go/kubernetes" - typedV1 "k8s.io/client-go/kubernetes/typed/core/v1" -) - -const ( - initContainerTarSrcPath = "/run/determined/temp/tar/src" - initContainerTarDstPath = "/run/determined/temp/tar/dst" - initContainerWorkDir = "/run/determined/temp/" - determinedLabel = "determined" - determinedPreemptionLabel = "determined-preemption" - determinedSystemLabel = "determined-system" -) - -type podSubmissionInfo struct { - taskSpec tasks.TaskSpec -} - -// TODO(mar). -// podStatusUpdate: messages that are sent by the pod informer. -type podStatusUpdate struct { - updatedPod *k8sV1.Pod -} - -// pod manages the lifecycle of a Kubernetes pod that executes a -// Determined task. The lifecycle of the pod is managed based on -// the status of the specified set of containers. -// -// TODO(DET-10011): Give this literal a more intuitive name. -type pod struct { - mu sync.Mutex - - req *sproto.AllocateRequest - - clusterID string - allocationID model.AllocationID - clientSet k8sClient.Interface - namespace string - masterIP string - masterPort int32 - // submissionInfo will be nil when the pod is restored. - // These fields can not be relied on after a pod is submitted. - submissionInfo *podSubmissionInfo - masterTLSConfig model.TLSClientConfig - loggingTLSConfig model.TLSClientConfig - loggingConfig model.LoggingConfig - slots int - podInterface typedV1.PodInterface - configMapInterface typedV1.ConfigMapInterface - resourceRequestQueue *requestQueue - scheduler string - slotType device.Type - slotResourceRequests config.PodSlotResourceRequests - - pod *k8sV1.Pod - podName string - configMap *k8sV1.ConfigMap - configMapName string - // TODO(DET-10013) : Remove container field from pod struct. - container cproto.Container - ports []int - resourcesDeleted atomic.Bool - containerNames set.Set[string] - - restore bool - - syslog *logrus.Entry -} - -type podNodeInfo struct { - nodeName string - numSlots int - slotType device.Type - container *cproto.Container -} - -func newPod( - msg StartTaskPod, - clusterID string, - clientSet k8sClient.Interface, - namespace string, - masterIP string, - masterPort int32, - masterTLSConfig model.TLSClientConfig, - loggingTLSConfig model.TLSClientConfig, - loggingConfig model.LoggingConfig, - podInterface typedV1.PodInterface, - configMapInterface typedV1.ConfigMapInterface, - resourceRequestQueue *requestQueue, - slotType device.Type, - slotResourceRequests config.PodSlotResourceRequests, - scheduler string, -) *pod { - podContainer := cproto.Container{ - ID: cproto.ID(msg.Spec.ContainerID), - State: cproto.Assigned, - Description: msg.Spec.Description, - } - uniqueName := configureUniqueName(msg.Spec, msg.Rank) - - // The lifecycle of the containers specified in this map will be monitored. - // As soon as one or more of them exits, the pod will be terminated. 
- containerNames := set.FromSlice([]string{model.DeterminedK8ContainerName}) - - p := &pod{ - req: msg.Req, - submissionInfo: &podSubmissionInfo{ - taskSpec: msg.Spec, - }, - clusterID: clusterID, - allocationID: msg.AllocationID, - clientSet: clientSet, - namespace: namespace, - masterIP: masterIP, - masterPort: masterPort, - masterTLSConfig: masterTLSConfig, - loggingTLSConfig: loggingTLSConfig, - loggingConfig: loggingConfig, - slots: msg.Slots, - podInterface: podInterface, - configMapInterface: configMapInterface, - resourceRequestQueue: resourceRequestQueue, - podName: uniqueName, - configMapName: uniqueName, - container: podContainer, - containerNames: containerNames, - scheduler: scheduler, - slotType: slotType, - slotResourceRequests: slotResourceRequests, - syslog: logrus.New().WithField("component", "pod").WithFields( - logger.MergeContexts(msg.LogContext, logger.Context{ - "pod": uniqueName, - }).Fields(), - ), - } - return p -} - -func (p *pod) start() error { - if p.restore { - if p.container.State == cproto.Running { - err := p.startPodLogStreamer() - if err != nil { - return err - } - } - } else { - if err := p.createPodSpecAndSubmit(); err != nil { - return fmt.Errorf("creating pod spec: %w", err) - } - } - return nil -} - -func (p *pod) finalize() { - p.kill() - p.finalizeTaskState() -} - -func (p *pod) podStatusUpdate(updatedPod *k8sV1.Pod) (cproto.State, error) { - p.mu.Lock() - defer p.mu.Unlock() - - if p.container.State == cproto.Terminated { - return p.container.State, nil - } - - p.pod = updatedPod - - containerState, err := p.getPodState(p.pod, p.containerNames) - if err != nil { - return p.container.State, err - } - - if containerState == p.container.State { - return p.container.State, nil - } - - switch containerState { - case cproto.Assigned: - // Don't need to do anything. - - case cproto.Starting: - // Kubernetes does not have an explicit state for pulling container images. - // We insert it here because our current implementation of the trial actor requires it. - p.syslog.Infof( - "transitioning pod state from %s to %s", p.container.State, cproto.Pulling) - p.container = p.container.Transition(cproto.Pulling) - p.informTaskResourcesState() - - p.syslog.Infof("transitioning pod state from %s to %s", p.container.State, containerState) - p.container = p.container.Transition(cproto.Starting) - p.informTaskResourcesState() - - case cproto.Running: - p.syslog.Infof("transitioning pod state from %s to %s", p.container.State, containerState) - p.container = p.container.Transition(cproto.Running) - p.informTaskResourcesStarted(getResourcesStartedForPod(p.pod, p.ports)) - err := p.startPodLogStreamer() - if err != nil { - return p.container.State, err - } - - case cproto.Terminated: - exitCode, exitMessage, err := getExitCodeAndMessage(p.pod, p.containerNames) - if err != nil { - // When a pod is deleted, it is possible that it will exit before the - // determined containers generates an exit code. To check if this is - // the case we check if a deletion timestamp has been set. 
- if p.pod.ObjectMeta.DeletionTimestamp != nil { - p.syslog.Info("unable to get exit code for pod, setting exit code to 1025") - exitCode = 1025 - exitMessage = "unable to get exit code or exit message from pod" - } else { - return p.container.State, err - } - } - - p.syslog.Infof("transitioning pod state from %s to %s", p.container.State, containerState) - p.container = p.container.Transition(cproto.Terminated) - - var resourcesStopped sproto.ResourcesStopped - switch exitCode { - case aproto.SuccessExitCode: - p.syslog.Infof("pod exited successfully") - default: - p.syslog.Infof("pod failed with exit code: %d %s", exitCode, exitMessage) - resourcesStopped.Failure = sproto.NewResourcesFailure( - sproto.ResourcesFailed, - exitMessage, - ptrs.Ptr(sproto.ExitCode(exitCode))) - } - p.informTaskResourcesStopped(resourcesStopped) - return p.container.State, nil - - default: - panic(fmt.Sprintf("unexpected container state %s", containerState)) - } - - return p.container.State, nil -} - -func (p *pod) podEventUpdate(event *k8sV1.Event) { - p.mu.Lock() - defer p.mu.Unlock() - - // We only forward messages while pods are starting up. - switch p.container.State { - case cproto.Running, cproto.Terminated: - return - } - - msgText := p.preparePodUpdateMessage(event.Message) - event.Message = msgText - - message := fmt.Sprintf("Pod %s: %s", event.InvolvedObject.Name, msgText) - p.insertLog(event.CreationTimestamp.Time, message) -} - -func (p *pod) PreemptTaskPod() { - p.syslog.Info("received preemption command") - rmevents.Publish(p.allocationID, &sproto.ReleaseResources{Reason: "preempted by the scheduler"}) -} - -func (p *pod) ChangePriority() { - p.syslog.Info("interrupting pod to change priorities") - rmevents.Publish(p.allocationID, &sproto.ReleaseResources{Reason: "priority changed"}) -} - -func (p *pod) ChangePosition() { - p.syslog.Info("interrupting pod to change positions") - rmevents.Publish(p.allocationID, &sproto.ReleaseResources{Reason: "queue position changed"}) -} - -func (p *pod) KillTaskPod() { - p.syslog.Info("received request to stop pod") - p.kill() -} - -func (p *pod) kill() { - if !p.resourcesDeleted.CompareAndSwap(false, true) { - return - } - - p.syslog.Infof("requesting to delete kubernetes resources") - p.resourceRequestQueue.deleteKubernetesResources( - p.namespace, - p.podName, - p.configMapName, - ) -} - -func (p *pod) getPodNodeInfo() podNodeInfo { - p.mu.Lock() - defer p.mu.Unlock() - - return podNodeInfo{ - nodeName: p.pod.Spec.NodeName, - numSlots: p.slots, - slotType: p.slotType, - container: p.container.DeepCopy(), - } -} - -func (p *pod) startPodLogStreamer() error { - return startPodLogStreamer(p.podInterface, p.podName, func(log []byte) { - p.receiveContainerLog(sproto.ContainerLog{ - Timestamp: time.Now().UTC(), - RunMessage: &aproto.RunMessage{ - Value: string(log), - StdType: stdcopy.Stdout, - }, - }) - }) -} - -func (p *pod) createPodSpecAndSubmit() error { - if err := p.createPodSpec(p.scheduler); err != nil { - return err - } - - p.resourceRequestQueue.createKubernetesResources(p.pod, p.configMap) - return nil -} - -func (p *pod) receiveResourceCreationFailed(msg resourceCreationFailed) { - p.syslog.WithError(msg.err).Error("pod handler notified that resource creation failed") - p.insertLog(time.Now().UTC(), msg.err.Error()) -} - -func (p *pod) receiveResourceCreationCancelled() { - p.syslog.Info("pod creation canceled") - p.resourcesDeleted.Store(true) -} - -func (p *pod) receiveResourceDeletionFailed(err resourceDeletionFailed) { - 
p.syslog.WithError(err.err).Error("pod handler notified that resource deletion failed") -} - -func (p *pod) finalizeTaskState() { - p.mu.Lock() - defer p.mu.Unlock() - - // If an error occurred during the lifecycle of the pods, we need to update the scheduler - // and the task handler with new state. - if p.container.State != cproto.Terminated { - p.syslog.Warnf("updating container state after pod exited unexpectedly") - p.container = p.container.Transition(cproto.Terminated) - - p.informTaskResourcesStopped(sproto.ResourcesError( - sproto.TaskError, - errors.New("pod handler exited while pod was running"), - )) - } -} - -func (p *pod) informTaskResourcesState() { - rmevents.Publish(p.allocationID, &sproto.ResourcesStateChanged{ - ResourcesID: sproto.FromContainerID(p.container.ID), - ResourcesState: sproto.FromContainerState(p.container.State), - Container: p.container.DeepCopy(), - }) -} - -func (p *pod) informTaskResourcesStarted(rs sproto.ResourcesStarted) { - rmevents.Publish(p.allocationID, &sproto.ResourcesStateChanged{ - ResourcesID: sproto.FromContainerID(p.container.ID), - ResourcesState: sproto.FromContainerState(p.container.State), - ResourcesStarted: &rs, - Container: p.container.DeepCopy(), - }) -} - -func (p *pod) informTaskResourcesStopped(rs sproto.ResourcesStopped) { - rmevents.Publish(p.allocationID, &sproto.ResourcesStateChanged{ - ResourcesID: sproto.FromContainerID(p.container.ID), - ResourcesState: sproto.FromContainerState(p.container.State), - ResourcesStopped: &rs, - Container: p.container.DeepCopy(), - }) -} - -func (p *pod) receiveContainerLog(msg sproto.ContainerLog) { - msg.ContainerID = p.container.ID - rmevents.Publish(p.allocationID, &msg) -} - -func (p *pod) insertLog(timestamp time.Time, msg string) { - p.receiveContainerLog(sproto.ContainerLog{ - Timestamp: timestamp, - AuxMessage: &msg, - }) -} - -// Converts k8s message to be more understandable. -func (p *pod) preparePodUpdateMessage(msgText string) string { - // Handle simple message replacements. - replacements := map[string]string{ - "pod triggered scale-up": "Job requires additional resources, scaling up cluster.", - "Successfully assigned": "Pod resources allocated.", - "skip schedule deleting pod": "Deleting unscheduled pod.", - } - - simpleReplacement := false - - for k, v := range replacements { - matched, err := regexp.MatchString(k, msgText) - if err != nil { - break - } else if matched { - msgText = v - simpleReplacement = true - } - } - - // Otherwise, try special treatment for slots availability message. - if !simpleReplacement { - matched, err := regexp.MatchString("nodes are available", msgText) - if err == nil && matched { - available := string(msgText[0]) - required := strconv.Itoa(p.slots) - var resourceName string - switch p.slotType { - case device.CPU: - resourceName = "CPU slots" - default: - resourceName = "GPUs" - } - - msgText = fmt.Sprintf("Waiting for resources. %s %s are available, %s %s required", - available, resourceName, required, resourceName) - } - } - - return msgText -} - -func (p *pod) getPodState( - pod *k8sV1.Pod, - containerNames set.Set[string], -) (cproto.State, error) { - switch pod.Status.Phase { - case k8sV1.PodPending: - // When pods are deleted, Kubernetes sometimes transitions pod statuses to pending - // prior to deleting them. In these cases we have observed that we do not always - // receive a PodFailed or a PodSucceeded message. We check if pods have a set pod - // deletion timestamp to see if this is the case. 
- if pod.ObjectMeta.DeletionTimestamp != nil { - p.syslog.Warn("marking pod as terminated due to deletion timestamp") - return cproto.Terminated, nil - } - - for _, condition := range pod.Status.Conditions { - if condition.Type == k8sV1.PodScheduled && condition.Status == k8sV1.ConditionTrue { - return cproto.Starting, nil - } - } - return cproto.Assigned, nil - - case k8sV1.PodRunning: - // Pods are in a running state as long as at least one container has not terminated. - // We check the status of the Determined containers directly to determine if they - // are still running. - containerStatuses, err := getDeterminedContainersStatus( - pod.Status.ContainerStatuses, containerNames) - if err != nil { - return "", err - } - - for _, containerStatus := range containerStatuses { - if containerStatus.State.Terminated != nil { - return cproto.Terminated, nil - } - } - - for _, containerStatus := range containerStatuses { - // Check that all Determined containers are running. - if containerStatus.State.Running == nil { - return cproto.Starting, nil - } - } - - return cproto.Running, nil - - case k8sV1.PodFailed, k8sV1.PodSucceeded: - return cproto.Terminated, nil - - default: - return "", errors.Errorf( - "unexpected pod status %s for pod %s", pod.Status.Phase, pod.Name) - } -} - -func getExitCodeAndMessage(pod *k8sV1.Pod, containerNames set.Set[string]) (int, string, error) { - if len(pod.Status.InitContainerStatuses) == 0 { - return 0, "", errors.Errorf( - "unexpected number of init containers when processing exit code for pod %s", pod.Name) - } - - for _, initContainerStatus := range pod.Status.InitContainerStatuses { - if initContainerStatus.State.Terminated == nil { - continue - } - exitCode := initContainerStatus.State.Terminated.ExitCode - if exitCode != aproto.SuccessExitCode { - errMessage := fmt.Sprintf( - "container %s: %s", initContainerStatus.Name, - initContainerStatus.State.Terminated.Message, - ) - return int(exitCode), errMessage, nil - } - } - - if len(pod.Status.ContainerStatuses) < len(containerNames) { - return 0, "", errors.Errorf( - "unexpected number of containers when processing exit code for pod %s", pod.Name) - } - - containerStatuses, err := getDeterminedContainersStatus( - pod.Status.ContainerStatuses, containerNames) - if err != nil { - return 0, "", err - } - - for _, containerStatus := range containerStatuses { - terminationStatus := containerStatus.State.Terminated - if terminationStatus != nil { - return int(terminationStatus.ExitCode), terminationStatus.Message, nil - } - } - - return 0, "", errors.Errorf("unable to get exit code from pod %s", pod.Name) -} - -func getResourcesStartedForPod(pod *k8sV1.Pod, ports []int) sproto.ResourcesStarted { - addresses := []cproto.Address{} - for _, port := range ports { - addresses = append(addresses, cproto.Address{ - ContainerIP: pod.Status.PodIP, - ContainerPort: port, - HostIP: pod.Status.PodIP, - HostPort: port, - }) - } - - var taskContainerID string - for _, containerStatus := range pod.Status.ContainerStatuses { - if containerStatus.Name == model.DeterminedK8ContainerName { - taskContainerID = containerStatus.ContainerID - break - } - } - - return sproto.ResourcesStarted{ - Addresses: addresses, - NativeResourcesID: taskContainerID, - } -} - -func getDeterminedContainersStatus( - statuses []k8sV1.ContainerStatus, - containerNames set.Set[string], -) ([]*k8sV1.ContainerStatus, error) { - containerStatuses := make([]*k8sV1.ContainerStatus, 0, len(statuses)) - for idx, containerStatus := range statuses { - if 
!containerNames.Contains(containerStatus.Name) { - continue - } - containerStatuses = append(containerStatuses, &statuses[idx]) - } - - if len(containerStatuses) != len(containerNames) { - containerNamesFound := make([]string, 0, len(containerStatuses)) - for _, containerStatus := range containerStatuses { - containerNamesFound = append(containerNamesFound, containerStatus.Name) - } - return nil, errors.Errorf("found container statuses only for: %v", containerNamesFound) - } - - return containerStatuses, nil -} diff --git a/master/internal/rm/kubernetesrm/pod_test.go b/master/internal/rm/kubernetesrm/pod_test.go deleted file mode 100644 index 20cee39121a..00000000000 --- a/master/internal/rm/kubernetesrm/pod_test.go +++ /dev/null @@ -1,811 +0,0 @@ -package kubernetesrm - -import ( - "context" - "fmt" - "reflect" - "testing" - "time" - - "github.com/google/uuid" - "github.com/pkg/errors" - "github.com/stretchr/testify/require" - "gotest.tools/assert" - - "github.com/determined-ai/determined/master/internal/config" - "github.com/determined-ai/determined/master/internal/rm/rmevents" - "github.com/determined-ai/determined/master/internal/sproto" - "github.com/determined-ai/determined/master/pkg/cproto" - "github.com/determined-ai/determined/master/pkg/device" - "github.com/determined-ai/determined/master/pkg/etc" - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/set" - "github.com/determined-ai/determined/master/pkg/tasks" - - k8sV1 "k8s.io/api/core/v1" - metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" - k8sClient "k8s.io/client-go/kubernetes" - typedV1 "k8s.io/client-go/kubernetes/typed/core/v1" -) - -func createPod( - allocationID model.AllocationID, - resourceHandler *requestQueue, - task tasks.TaskSpec, -) *pod { - msg := StartTaskPod{ - Req: &sproto.AllocateRequest{}, - AllocationID: allocationID, - Spec: task, - Slots: 1, - } - clusterID := "test" - clientSet := k8sClient.Clientset{} - namespace := "default" - masterIP := "0.0.0.0" - var masterPort int32 = 32 - podInterface := &mockPodInterface{} - configMapInterface := clientSet.CoreV1().ConfigMaps(namespace) - resourceRequestQueue := resourceHandler - slotType := device.CUDA - slotResourceRequests := config.PodSlotResourceRequests{} - - newPodHandler := newPod( - msg, clusterID, &clientSet, namespace, masterIP, masterPort, - model.TLSClientConfig{}, model.TLSClientConfig{}, - model.LoggingConfig{DefaultLoggingConfig: &model.DefaultLoggingConfig{}}, - podInterface, configMapInterface, resourceRequestQueue, - slotType, slotResourceRequests, "default-scheduler", - ) - - return newPodHandler -} - -func createAgentUserGroup() *model.AgentUserGroup { - return &model.AgentUserGroup{ - ID: 1, - UserID: 1, - User: "determined", - UID: 1, - Group: "test-group", - GID: 1, - } -} - -func createUser() *model.User { - return &model.User{ - ID: 1, - Username: "determined", - Active: true, - Admin: false, - } -} - -func createPodWithMockQueue(t *testing.T, k8sRequestQueue *requestQueue) ( - *pod, - model.AllocationID, - *sproto.ResourcesSubscription, -) { - commandSpec := tasks.GenericCommandSpec{ - Base: tasks.TaskSpec{ - AllocationID: "task", - ContainerID: "container", - ClusterID: "cluster", - AgentUserGroup: createAgentUserGroup(), - Owner: createUser(), - UserSessionToken: "bogus", - }, - Config: model.CommandConfig{Description: "test-config"}, - } - - ctx, cancel := 
context.WithCancel(context.Background()) - t.Cleanup(cancel) - - failures := make(chan resourcesRequestFailure, 1024) - if k8sRequestQueue == nil { - podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} - configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} - k8sRequestQueue = startRequestQueue( - map[string]typedV1.PodInterface{"default": podInterface}, - map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, - failures, - ) - } - - aID := model.AllocationID(uuid.NewString()) - sub := rmevents.Subscribe(aID) - newPod := createPod( - aID, - k8sRequestQueue, - commandSpec.ToTaskSpec(), - ) - - go consumeResourceRequestFailures(ctx, failures, newPod) - - err := newPod.start() - require.NoError(t, err) - time.Sleep(500 * time.Millisecond) - - return newPod, aID, sub -} - -func setupEntrypoint(t *testing.T) { - err := etc.SetRootPath("../../../static/srv") - if err != nil { - t.Logf("Failed to set root directory") - } -} - -func checkReceiveTermination( - t *testing.T, - update podStatusUpdate, - newPod *pod, - sub *sproto.ResourcesSubscription, -) { - state, err := newPod.podStatusUpdate(update.updatedPod) - switch { - case err != nil, state == cproto.Terminated: - newPod.finalize() - } - time.Sleep(time.Second) - - assert.Equal(t, sub.Len(), 1) - message := sub.Get() - containerMsg, ok := message.(*sproto.ResourcesStateChanged) - if !ok { - t.Errorf( - "expected sproto.ResourcesStateChanged but received %s", - reflect.TypeOf(message), - ) - } - if containerMsg.ResourcesStopped == nil { - t.Errorf("container stopped message not present (state=%s)", containerMsg.ResourcesState) - } - - assert.Equal(t, newPod.container.State, cproto.Terminated) -} - -func TestResourceCreationFailed(t *testing.T) { - setupEntrypoint(t) - - const correctMsg = "already exists" - - ref, aID, sub := createPodWithMockQueue(t, nil) - - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - // Send a second start message to trigger an additional resource creation failure. - err := ref.start() - require.NoError(t, err) - time.Sleep(time.Second) - - // We expect two messages in the queue because the pod actor sends itself a stop message. 
- assert.Equal(t, sub.Len(), 2) - message := sub.Get() - containerMsg, ok := message.(*sproto.ContainerLog) - if !ok { - t.Errorf("expected sproto.ContainerLog but received %s", reflect.TypeOf(message)) - } - assert.ErrorContains(t, errors.New(*containerMsg.AuxMessage), correctMsg) -} - -func TestReceivePodStatusUpdateTerminated(t *testing.T) { - setupEntrypoint(t) - - typeMeta := metaV1.TypeMeta{Kind: "rest test"} - objectMeta := metaV1.ObjectMeta{ - Name: "test meta", - DeletionTimestamp: &metaV1.Time{Time: time.Now()}, - } - - t.Run("pod deleting, but in pending state", func(t *testing.T) { - t.Logf("Testing PodPending status") - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: k8sV1.PodStatus{Phase: k8sV1.PodPending}, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - - checkReceiveTermination(t, statusUpdate, ref, sub) - }) - - t.Run("pod failed", func(t *testing.T) { - t.Logf("Testing PodFailed status") - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: k8sV1.PodStatus{Phase: k8sV1.PodFailed}, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - - checkReceiveTermination(t, statusUpdate, ref, sub) - }) - - // Pod succeeded. - t.Run("pod succeeded", func(t *testing.T) { - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: k8sV1.PodStatus{Phase: k8sV1.PodSucceeded}, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - - checkReceiveTermination(t, statusUpdate, ref, sub) - }) -} - -func TestMultipleContainerTerminate(t *testing.T) { - // Status update test involving two containers. - setupEntrypoint(t) - - containerStatuses := []k8sV1.ContainerStatus{ - { - Name: "test-pod-1", - State: k8sV1.ContainerState{ - Running: &k8sV1.ContainerStateRunning{}, - }, - }, - { - Name: "test-pod-2", - State: k8sV1.ContainerState{ - Terminated: &k8sV1.ContainerStateTerminated{}, - }, - }, - } - - t.Run("pod running with > 1 container, and one terminated", func(t *testing.T) { - t.Logf("two pods with one in terminated state") - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - ref.containerNames = set.FromSlice([]string{"test-pod-1", "test-pod-2"}) - - pod := k8sV1.Pod{ - TypeMeta: metaV1.TypeMeta{Kind: "rest test"}, - ObjectMeta: metaV1.ObjectMeta{ - Name: "test meta", - DeletionTimestamp: &metaV1.Time{Time: time.Now()}, - }, - Status: k8sV1.PodStatus{ - Phase: k8sV1.PodRunning, - ContainerStatuses: containerStatuses, - }, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - checkReceiveTermination(t, statusUpdate, ref, sub) - }) - - t.Run("multiple pods, 1 termination, no deletion timestamp", func(t *testing.T) { - // This results in an error, which causes pod termination and the same outcome. 
- t.Logf("two pods with one in terminated state and no deletion timestamp") - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - pod := k8sV1.Pod{ - TypeMeta: metaV1.TypeMeta{Kind: "rest test"}, - ObjectMeta: metaV1.ObjectMeta{ - Name: "test meta", - }, - Status: k8sV1.PodStatus{ - Phase: k8sV1.PodRunning, - ContainerStatuses: containerStatuses, - }, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - checkReceiveTermination(t, statusUpdate, ref, sub) - }) -} - -func TestReceivePodStatusUpdateAssigned(t *testing.T) { - setupEntrypoint(t) - - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - typeMeta := metaV1.TypeMeta{Kind: "rest test"} - objectMeta := metaV1.ObjectMeta{ - Name: "test meta", - } - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: k8sV1.PodStatus{Phase: k8sV1.PodPending}, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - - assert.Equal(t, ref.container.State, cproto.Assigned) - _, err := ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - - time.Sleep(time.Second) - assert.Equal(t, sub.Len(), 0) - - ref.container.State = cproto.Starting - - _, err = ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - - time.Sleep(time.Second) - assert.Equal(t, sub.Len(), 0) - assert.Equal(t, ref.container.State, cproto.Starting) -} - -func TestReceivePodStatusUpdateStarting(t *testing.T) { - setupEntrypoint(t) - - typeMeta := metaV1.TypeMeta{Kind: "rest test"} - objectMeta := metaV1.ObjectMeta{ - Name: "test meta", - } - - t.Run("pod status pending, pod scheduled", func(t *testing.T) { - t.Logf("Testing pod scheduled with pending status") - - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - condition := k8sV1.PodCondition{ - Type: k8sV1.PodScheduled, - Status: k8sV1.ConditionTrue, - Message: "This doesn't matter :)", - } - status := k8sV1.PodStatus{ - Phase: k8sV1.PodPending, - Conditions: []k8sV1.PodCondition{condition}, - } - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: status, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - - _, err := ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - time.Sleep(time.Second) - - assert.Equal(t, sub.Len(), 2) - assert.Equal(t, ref.container.State, cproto.Starting) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - _, err = ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - time.Sleep(time.Second) - - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - assert.Equal(t, ref.container.State, cproto.Starting) - }) - - t.Run("pod status Running, but container status waiting", func(t *testing.T) { - t.Logf("Testing pod running with waiting status") - - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - containerStatuses := []k8sV1.ContainerStatus{ - { - Name: "determined-container", - State: k8sV1.ContainerState{Waiting: &k8sV1.ContainerStateWaiting{}}, - }, - } - status := k8sV1.PodStatus{ - Phase: k8sV1.PodRunning, - ContainerStatuses: containerStatuses, - } - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: status, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - - _, err := ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - time.Sleep(time.Second) - - assert.Equal(t, sub.Len(), 2) - assert.Equal(t, ref.container.State, 
cproto.Starting) - }) - - t.Run("pod status running, but no container State inside", func(t *testing.T) { - t.Logf("Testing pod running with no status") - - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - status := k8sV1.PodStatus{ - Phase: k8sV1.PodRunning, - ContainerStatuses: []k8sV1.ContainerStatus{ - {Name: "determined-container"}, - }, - } - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: status, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - _, err := ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - time.Sleep(time.Second) - - assert.Equal(t, sub.Len(), 2) - assert.Equal(t, ref.container.State, cproto.Starting) - }) -} - -func TestMultipleContainersRunning(t *testing.T) { - // Status update test involving two containers. - setupEntrypoint(t) - - typeMeta := metaV1.TypeMeta{Kind: "rest test"} - objectMeta := metaV1.ObjectMeta{ - Name: "test meta", - } - containerStatuses := []k8sV1.ContainerStatus{ - { - Name: "determined-container", - State: k8sV1.ContainerState{Running: &k8sV1.ContainerStateRunning{}}, - }, - { - Name: "test-pod", - }, - } - - t.Run("pod with two containers and one doesn't have running state", func(t *testing.T) { - t.Logf("Testing two pods and one doesn't have running state") - - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - ref.container.State = cproto.Starting - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - status := k8sV1.PodStatus{ - Phase: k8sV1.PodRunning, - ContainerStatuses: containerStatuses, - } - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: status, - } - ref.containerNames = set.FromSlice([]string{ - "determined-container", - "test-pod", - }) - statusUpdate := podStatusUpdate{updatedPod: &pod} - - _, err := ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - time.Sleep(time.Second) - assert.Equal(t, sub.Len(), 0) - assert.Equal(t, ref.container.State, cproto.Starting) - }) - - // . 
- t.Run("multiple containers, all in running state, results in a running state", func(t *testing.T) { - t.Logf("Testing two pods with running states") - - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - ref.container.State = cproto.Starting - containerStatuses[1] = k8sV1.ContainerStatus{ - Name: "test-pod-2", - State: k8sV1.ContainerState{Running: &k8sV1.ContainerStateRunning{}}, - } - status := k8sV1.PodStatus{ - Phase: k8sV1.PodRunning, - ContainerStatuses: containerStatuses, - } - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: status, - } - statusUpdate := podStatusUpdate{updatedPod: &pod} - _, err := ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - time.Sleep(time.Second) - - assert.Equal(t, sub.Len(), 1) - message := sub.Get() - containerMsg, ok := message.(*sproto.ResourcesStateChanged) - if !ok { - t.Errorf("expected *sproto.ResourcesStateChanged but received %s", reflect.TypeOf(message)) - } - if containerMsg.ResourcesStarted == nil { - t.Errorf("container started message not present") - } - }) -} - -func TestReceivePodEventUpdate(t *testing.T) { - setupEntrypoint(t) - - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - object := k8sV1.ObjectReference{Kind: "mock", Namespace: "test", Name: "MockObject"} - newEvent := k8sV1.Event{ - InvolvedObject: object, - Reason: "testing", - Message: "0/99 nodes are available: 99 Insufficient cpu", - } - ref.slots = 99 - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - ref.podEventUpdate(&newEvent) - time.Sleep(time.Second) // TODO(DET-9790): Remove sleeps. - - assert.Equal(t, sub.Len(), 1) - message := sub.Get() - correctMsg := fmt.Sprintf("Pod %s: %s", object.Name, - "Waiting for resources. 0 GPUs are available, 99 GPUs required") - - containerMsg, ok := message.(*sproto.ContainerLog) - if !ok { - t.Errorf("expected sproto.ContainerLog but received %s", reflect.TypeOf(message)) - } - assert.Equal(t, *containerMsg.AuxMessage, correctMsg) - - // When container is in Running state, pod actor should not forward message. - purge(aID, sub) - ref.container.State = cproto.Running - ref.podEventUpdate(&newEvent) - time.Sleep(time.Second) - assert.Equal(t, sub.Len(), 0) - - // When container is in Terminated state, pod actor should not forward message. 
- ref.container.State = cproto.Terminated - ref.podEventUpdate(&newEvent) - time.Sleep(time.Second) - assert.Equal(t, sub.Len(), 0) -} - -func TestReceiveContainerLog(t *testing.T) { - setupEntrypoint(t) - - mockLogMessage := "mock log message" - ref, aID, sub := createPodWithMockQueue(t, nil) - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - ref.restore = true - ref.container.State = cproto.Running - ref.podInterface = &mockPodInterface{logMessage: &mockLogMessage} - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - err := ref.start() - require.NoError(t, err) - time.Sleep(time.Second) - - assert.Equal(t, sub.Len(), 1) - message := sub.Get() - containerMsg, ok := message.(*sproto.ContainerLog) - if !ok { - t.Errorf("expected sproto.ContainerLog but received %s", reflect.TypeOf(message)) - } - assert.Equal(t, containerMsg.RunMessage.Value, mockLogMessage) - - // reset state to starting - ref.container.State = cproto.Starting - mockLogMessage = "new mock log message" - - typeMeta := metaV1.TypeMeta{Kind: "running log test"} - objectMeta := metaV1.ObjectMeta{ - Name: "test meta", - } - containerStatuses := []k8sV1.ContainerStatus{ - { - Name: "sample-container", - State: k8sV1.ContainerState{Running: &k8sV1.ContainerStateRunning{}}, - }, - } - status := k8sV1.PodStatus{ - Phase: k8sV1.PodRunning, - ContainerStatuses: containerStatuses, - } - pod := k8sV1.Pod{ - TypeMeta: typeMeta, - ObjectMeta: objectMeta, - Status: status, - } - ref.containerNames = set.FromSlice([]string{ - "sample-container", - }) - statusUpdate := podStatusUpdate{updatedPod: &pod} - - _, err = ref.podStatusUpdate(statusUpdate.updatedPod) - require.NoError(t, err) - time.Sleep(time.Second) - assert.Equal(t, sub.Len(), 2) - assert.Equal(t, ref.container.State, cproto.Running) - - message = sub.Get() - resourceMsg, ok := message.(*sproto.ResourcesStateChanged) - if !ok { - t.Errorf("expected sproto.ResourcesStateChanged but received %s", reflect.TypeOf(message)) - } - assert.Equal(t, resourceMsg.Container.State, cproto.Running) - - message = sub.Get() - containerMsg, ok = message.(*sproto.ContainerLog) - if !ok { - t.Errorf("expected sproto.ContainerLog but received %s", reflect.TypeOf(message)) - } - assert.Equal(t, containerMsg.RunMessage.Value, mockLogMessage) -} - -func TestKillTaskPod(t *testing.T) { - setupEntrypoint(t) - - podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} - configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} - failures := make(chan resourcesRequestFailure, 1024) - k8sRequestQueue := startRequestQueue( - map[string]typedV1.PodInterface{"default": podInterface}, - map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, - failures, - ) - ref, _, _ := createPodWithMockQueue(t, k8sRequestQueue) - - // We take a quick nap immediately so we can purge the start message after it arrives. 
- time.Sleep(time.Second) - assert.Check(t, podInterface.hasPod(ref.podName)) - ref.KillTaskPod() - time.Sleep(time.Second) - assert.Check(t, !podInterface.hasPod(ref.podName)) - assert.Check(t, ref.resourcesDeleted.Load()) -} - -func TestResourceCreationCancelled(t *testing.T) { - setupEntrypoint(t) - - podInterface := &mockPodInterface{ - pods: make(map[string]*k8sV1.Pod), - operationalDelay: time.Minute * numKubernetesWorkers, - } - configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} - failures := make(chan resourcesRequestFailure, 1024) - k8sRequestQueue := startRequestQueue( - map[string]typedV1.PodInterface{"default": podInterface}, - map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, - failures, - ) - - for i := 0; i < numKubernetesWorkers; i++ { - createPodWithMockQueue(t, k8sRequestQueue) - } - time.Sleep(time.Second) - ref, aID, sub := createPodWithMockQueue(t, k8sRequestQueue) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go consumeResourceRequestFailures(ctx, failures, ref) - - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - ref.KillTaskPod() - - time.Sleep(time.Second) - assert.Equal(t, sub.Len(), 1) - - message := sub.Get() - containerMsg, ok := message.(*sproto.ResourcesStateChanged) - if !ok { - t.Errorf("expected *sproto.ResourcesStateChanged but received %s", - reflect.TypeOf(message)) - } - - var correctContainerStarted *sproto.ResourcesStarted - correctFailType := "task failed without an associated exit code" - correctErrMsg := "pod handler exited while pod was running" - var correctCode *sproto.ExitCode - - assert.Equal(t, containerMsg.ResourcesStarted, correctContainerStarted) - assert.Equal(t, containerMsg.ResourcesStopped.Failure.FailureType, - sproto.FailureType(correctFailType)) - assert.Equal(t, containerMsg.ResourcesStopped.Failure.ErrMsg, correctErrMsg) - assert.Equal(t, containerMsg.ResourcesStopped.Failure.ExitCode, correctCode) -} - -func TestResourceDeletionFailed(t *testing.T) { - setupEntrypoint(t) - - podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} - configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} - failures := make(chan resourcesRequestFailure, 1024) - k8sRequestQueue := startRequestQueue( - map[string]typedV1.PodInterface{"default": podInterface}, - map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, - failures, - ) - - ref, aID, sub := createPodWithMockQueue(t, k8sRequestQueue) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go consumeResourceRequestFailures(ctx, failures, ref) - - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - podInterface.delete(ref.podName) - - ref.KillTaskPod() - time.Sleep(time.Second) - assert.Equal(t, sub.Len(), 1) - - message := sub.Get() - containerMsg, ok := message.(*sproto.ResourcesStateChanged) - if !ok { - t.Errorf("expected *sproto.ResourcesStateChanged but received %s", - reflect.TypeOf(message)) - } - - var correctContainerStarted *sproto.ResourcesStarted - var correctCode *sproto.ExitCode - - assert.Equal(t, containerMsg.ResourcesStarted, correctContainerStarted) - assert.Equal(t, containerMsg.ResourcesStopped.Failure.FailureType, - sproto.FailureType("task failed without an associated exit code")) - assert.Equal(t, containerMsg.ResourcesStopped.Failure.ErrMsg, - "pod handler exited while pod was running") - assert.Equal(t, containerMsg.ResourcesStopped.Failure.ExitCode, correctCode) -} - -func 
TestGetPodNodeInfo(t *testing.T) { - setupEntrypoint(t) - - ref, aID, sub := createPodWithMockQueue(t, nil) - ref.slots = 99 - time.Sleep(time.Second) - - purge(aID, sub) - assert.Equal(t, sub.Len(), 0) - - podInfo := ref.getPodNodeInfo() - time.Sleep(time.Second) - - assert.Equal(t, podInfo.nodeName, ref.pod.Spec.NodeName) - assert.Equal(t, podInfo.numSlots, ref.slots) -} - -var sentinelEvent = &sproto.ContainerLog{ContainerID: "sentinel"} - -func purge(aID model.AllocationID, sub *sproto.ResourcesSubscription) { - rmevents.Publish(aID, sentinelEvent) - for { - event := sub.Get() - if event == sentinelEvent { - return - } - } -} diff --git a/master/internal/rm/kubernetesrm/pods.go b/master/internal/rm/kubernetesrm/pods.go deleted file mode 100644 index 13df704700b..00000000000 --- a/master/internal/rm/kubernetesrm/pods.go +++ /dev/null @@ -1,1783 +0,0 @@ -package kubernetesrm - -import ( - "context" - "encoding/json" - "fmt" - "os" - "path/filepath" - "reflect" - "strconv" - "strings" - "sync" - "time" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - k8sV1 "k8s.io/api/core/v1" - k8error "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/api/resource" - metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" - "k8s.io/apimachinery/pkg/watch" - k8sClient "k8s.io/client-go/kubernetes" - typedV1 "k8s.io/client-go/kubernetes/typed/core/v1" - "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" - "k8s.io/client-go/util/homedir" - - "github.com/determined-ai/determined/master/internal/config" - "github.com/determined-ai/determined/master/internal/db" - "github.com/determined-ai/determined/master/internal/rm/rmevents" - "github.com/determined-ai/determined/master/internal/sproto" - "github.com/determined-ai/determined/master/pkg/cproto" - "github.com/determined-ai/determined/master/pkg/device" - "github.com/determined-ai/determined/master/pkg/logger" - "github.com/determined-ai/determined/master/pkg/model" - "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/set" - "github.com/determined-ai/determined/master/pkg/syncx/waitgroupx" - "github.com/determined-ai/determined/master/pkg/tasks" - "github.com/determined-ai/determined/proto/pkg/apiv1" - - // Used to load all auth plugins. - _ "k8s.io/client-go/plugin/pkg/client/auth" -) - -// ResourceTypeNvidia describes the GPU resource type. -const ResourceTypeNvidia = "nvidia.com/gpu" - -const ( - getAgentsCacheDuration = 15 * time.Second - summarizeCacheDuration = 5 * time.Second -) - -type podMetadata struct { - podName string - containerID string -} - -type podStatusUpdateCallback func(sproto.UpdatePodStatus) - -// High lever overview of the actors within the kubernetes package: -// -// pods -// +- pod(s): manages pod lifecycle. One per container in a task. -// +- podLogStreamer: stream logs for a specific pod. -// +- informer: sends updates about pod states -// +- events: sends updates about kubernetes events. -// +- requestQueue: queues requests to create / delete kubernetes resources. -// +- requestProcessingWorkers: processes request to create / delete kubernetes resources. -// -// TODO(DET-10011): Give this literal a more intuitive name. 
-type pods struct { - mu sync.RWMutex - wg waitgroupx.Group - - namespace string - namespaceToPoolName map[string]string - masterServiceName string - scheduler string - slotType device.Type - slotResourceRequests config.PodSlotResourceRequests - resourcePoolConfigs []config.ResourcePoolConfig - baseContainerDefaults *model.TaskContainerDefaultsConfig - - kubeconfigPath string - - clientSet k8sClient.Interface - detMasterIP string - detMasterPort int32 - masterTLSConfig model.TLSClientConfig - loggingTLSConfig model.TLSClientConfig - loggingConfig model.LoggingConfig - - resourceRequestQueue *requestQueue - podNameToPodHandler map[string]*pod - podNameToResourcePool map[string]string - containerIDToPodName map[string]string - containerIDToSchedulingState map[string]sproto.SchedulingState - podNameToContainerID map[string]string - podHandlerToMetadata map[*pod]podMetadata - nodeToSystemResourceRequests map[string]int64 - - currentNodes map[string]*k8sV1.Node - - podInterfaces map[string]typedV1.PodInterface - configMapInterfaces map[string]typedV1.ConfigMapInterface - - // TODO(RM-236) make one cache and make this code more straightforward. - summarizeCacheLock sync.RWMutex - summarizeCache summarizeResult - summarizeCacheTime time.Time - getAgentsCacheLock sync.Mutex - getAgentsCache *apiv1.GetAgentsResponse - getAgentsCacheTime time.Time - - syslog *logrus.Entry - - podStatusUpdateCallback podStatusUpdateCallback -} - -type summarizeResult struct { - summary map[string]model.AgentSummary - err error -} - -// PodsInfo contains information for pods. -type PodsInfo struct { - NumAgents int - SlotsAvailable int -} - -// SummarizeResources summerize pods resource. -type SummarizeResources struct { - PoolName string -} - -type reattachAllocationPods struct { - req *sproto.AllocateRequest - numPods int - allocationID model.AllocationID - slots int - logContext logger.Context -} - -type reattachPodResponse struct { - containerID string - started *sproto.ResourcesStarted -} - -type refreshPodStates struct { - allocationID model.AllocationID -} - -// newPodsService creates a new pod service for launching, querying and interacting with k8s pods. 
-func newPodsService( - namespace string, - namespaceToPoolName map[string]string, - masterServiceName string, - masterTLSConfig model.TLSClientConfig, - loggingConfig model.LoggingConfig, - scheduler string, - slotType device.Type, - slotResourceRequests config.PodSlotResourceRequests, - resourcePoolConfigs []config.ResourcePoolConfig, - taskContainerDefaults *model.TaskContainerDefaultsConfig, - detMasterIP string, - detMasterPort int32, - kubeconfigPath string, - podStatusUpdateCallback podStatusUpdateCallback, -) *pods { - loggingTLSConfig := masterTLSConfig - if loggingConfig.ElasticLoggingConfig != nil { - loggingTLSConfig = loggingConfig.ElasticLoggingConfig.Security.TLS - } - p := &pods{ - wg: waitgroupx.WithContext(context.Background()), - - namespace: namespace, - namespaceToPoolName: namespaceToPoolName, - masterServiceName: masterServiceName, - masterTLSConfig: masterTLSConfig, - scheduler: scheduler, - loggingTLSConfig: loggingTLSConfig, - loggingConfig: loggingConfig, - podNameToPodHandler: make(map[string]*pod), - podNameToResourcePool: make(map[string]string), - containerIDToPodName: make(map[string]string), - containerIDToSchedulingState: make(map[string]sproto.SchedulingState), - podNameToContainerID: make(map[string]string), - podHandlerToMetadata: make(map[*pod]podMetadata), - slotType: slotType, - slotResourceRequests: slotResourceRequests, - resourcePoolConfigs: resourcePoolConfigs, - baseContainerDefaults: taskContainerDefaults, - detMasterIP: detMasterIP, - detMasterPort: detMasterPort, - currentNodes: make(map[string]*k8sV1.Node), - nodeToSystemResourceRequests: make(map[string]int64), - podInterfaces: make(map[string]typedV1.PodInterface), - configMapInterfaces: make(map[string]typedV1.ConfigMapInterface), - syslog: logrus.WithField("namespace", namespace), - podStatusUpdateCallback: podStatusUpdateCallback, - - kubeconfigPath: kubeconfigPath, - } - - if err := p.startClientSet(); err != nil { - panic(err) - } - if err := p.getMasterIPAndPort(); err != nil { - panic(err) - } - if err := p.getSystemResourceRequests(); err != nil { - panic(err) - } - - p.startResourceRequestQueue() - - if err := p.deleteDoomedKubernetesResources(); err != nil { - panic(err) - } - - err := p.startPodInformer() - if err != nil { - panic(err) - } - - err = p.startNodeInformer() - switch { - case err != nil && k8error.IsForbidden(err): - p.syslog.Warnf("unable to start node informer due to permission error,"+ - "some features will be degraded: %s", err, - ) - case err != nil: - panic(err) - } - - err = p.startEventListeners() - if err != nil { - panic(err) - } - - err = p.startPreemptionListeners() - if err != nil { - panic(err) - } - - return p -} - -// StartTaskPod notifies the pods actor to start a pod with the task spec. 
-type StartTaskPod struct { - Req *sproto.AllocateRequest - AllocationID model.AllocationID - Spec tasks.TaskSpec - Slots int - Rank int - ResourcePool string - Namespace string - - LogContext logger.Context -} - -func (p *pods) StartTaskPod(msg StartTaskPod) error { - p.mu.Lock() - defer p.mu.Unlock() - return p.receiveStartTaskPod(msg) -} - -func (p *pods) ChangePriority(podID cproto.ID) { - p.mu.Lock() - defer p.mu.Unlock() - p.receivePriorityChange(podID) -} - -func (p *pods) ChangePosition(podID cproto.ID) { - p.mu.Lock() - defer p.mu.Unlock() - p.receivePositionChange(podID) -} - -func (p *pods) KillPod(podID cproto.ID) { - p.mu.Lock() - defer p.mu.Unlock() - p.receiveKillPod(podID) -} - -func (p *pods) SummarizeResources(msg SummarizeResources) (*PodsInfo, error) { - p.mu.Lock() - defer p.mu.Unlock() - return p.receiveResourceSummarize(msg) -} - -func (p *pods) ReattachAllocationPods(msg reattachAllocationPods) ([]reattachPodResponse, error) { - p.mu.Lock() - defer p.mu.Unlock() - return p.reattachAllocationPods(msg) -} - -func (p *pods) RefreshPodStates(msg refreshPodStates) error { - p.mu.Lock() - defer p.mu.Unlock() - return p.refreshPodStates(msg.allocationID) -} - -func (p *pods) GetSlots(msg *apiv1.GetSlotsRequest) *apiv1.GetSlotsResponse { - p.mu.Lock() - defer p.mu.Unlock() - return p.handleGetSlotsRequest(msg.AgentId) -} - -func (p *pods) GetSlot(msg *apiv1.GetSlotRequest) *apiv1.GetSlotResponse { - p.mu.Lock() - defer p.mu.Unlock() - return p.handleGetSlotRequest(msg.AgentId, msg.SlotId) -} - -func (p *pods) HealthStatus() model.HealthStatus { - p.mu.Lock() - defer p.mu.Unlock() - for _, podInterface := range p.podInterfaces { - _, err := podInterface.List(context.TODO(), metaV1.ListOptions{Limit: 1}) - if err != nil { - p.syslog.WithError(err).Error("kubernetes resource manager marked as unhealthy") - return model.Unhealthy - } - return model.Healthy - } - - logrus.Error("expected podInterfaces to be non empty") - return model.Unhealthy -} - -func (p *pods) GetAgents() *apiv1.GetAgentsResponse { - p.mu.Lock() - defer p.mu.Unlock() - return p.handleGetAgentsRequest() -} - -func (p *pods) GetAgent(msg *apiv1.GetAgentRequest) *apiv1.GetAgentResponse { - p.mu.Lock() - defer p.mu.Unlock() - return p.handleGetAgentRequest(msg.AgentId) -} - -func (p *pods) EnableAgent(msg *apiv1.EnableAgentRequest) (*apiv1.EnableAgentResponse, error) { - p.mu.Lock() - defer p.mu.Unlock() - return p.enableNode(msg.AgentId) -} - -func (p *pods) DisableAgent(msg *apiv1.DisableAgentRequest) (*apiv1.DisableAgentResponse, error) { - p.mu.Lock() - defer p.mu.Unlock() - return p.disableNode(msg.AgentId, msg.Drain) -} - -func readClientConfig(kubeconfigPath string) (*rest.Config, error) { - if len(kubeconfigPath) == 0 { - // The default in-cluster case. Internally, k8s.io/client-go/rest is going to look for - // environment variables: - // - KUBERNETES_SERVICE_HOST - // - KUBERNETES_SERVICE_PORT - // and it expects to find files: - // - /var/run/secrets/kubernetes.io/serviceaccount/token - // - /var/run/secrets/kubernetes.io/serviceaccount/ca.crt - return rest.InClusterConfig() - } - - if parts := strings.Split(kubeconfigPath, string(os.PathSeparator)); parts[0] == "~" { - parts[0] = homedir.HomeDir() - expanded := filepath.Join(parts...) - logrus.Infof("expanding kubeconfig path from %s to %s", kubeconfigPath, expanded) - kubeconfigPath = expanded - } - - bs, err := os.ReadFile(kubeconfigPath) // #nosec G304 // User must have fs access to set this config var anyway. 
- if err != nil { - return nil, fmt.Errorf("reading kubeconfig at %s: %w", kubeconfigPath, err) - } - - cl, err := clientcmd.RESTConfigFromKubeConfig(bs) - if err != nil { - return nil, fmt.Errorf("building rest.Config from kubeconfig at %s: %w", kubeconfigPath, err) - } - return cl, nil -} - -func (p *pods) startClientSet() error { - config, err := readClientConfig(p.kubeconfigPath) - if err != nil { - return errors.Wrap(err, "error building kubernetes config") - } - - p.clientSet, err = k8sClient.NewForConfig(config) - if err != nil { - return errors.Wrap(err, "failed to initialize kubernetes clientSet") - } - - for _, ns := range append(maps.Keys(p.namespaceToPoolName), p.namespace) { - p.podInterfaces[ns] = p.clientSet.CoreV1().Pods(ns) - p.configMapInterfaces[ns] = p.clientSet.CoreV1().ConfigMaps(ns) - } - - p.syslog.Infof("kubernetes clientSet initialized") - return nil -} - -func (p *pods) getMasterIPAndPort() error { - if p.detMasterIP != "" && p.detMasterPort != 0 { - // Master ip and port were manually configured. For special circumstances, e.g., the master is running - // outside of this cluster (happens in development or when we spread across multiple k8s clusters). - return nil - } - masterService, err := p.clientSet.CoreV1().Services(p.namespace).Get( - context.TODO(), p.masterServiceName, metaV1.GetOptions{}) - if err != nil { - return errors.Wrap(err, "failed to get master service") - } - - p.detMasterIP = masterService.Spec.ClusterIP - p.detMasterPort = masterService.Spec.Ports[0].Port - p.syslog.Infof("master URL set to %s:%d", p.detMasterIP, p.detMasterPort) - return nil -} - -func (p *pods) getSystemResourceRequests() error { - systemPods, err := p.podInterfaces[p.namespace].List( - context.TODO(), metaV1.ListOptions{LabelSelector: determinedSystemLabel}) - if err != nil { - return errors.Wrap(err, "failed to get system pods") - } - - for _, systemPod := range systemPods.Items { - for _, container := range systemPod.Spec.Containers { - p.nodeToSystemResourceRequests[systemPod.Spec.NodeName] += container.Resources.Requests.Cpu(). 
- MilliValue() - } - } - return nil -} - -func (p *pods) reattachAllocationPods(msg reattachAllocationPods) ([]reattachPodResponse, error) { - listOptions := metaV1.ListOptions{ - LabelSelector: fmt.Sprintf("%s=%s", determinedLabel, msg.allocationID), - } - - pods, err := p.listPodsInAllNamespaces(context.TODO(), listOptions) - if err != nil { - return nil, errors.Wrap(err, "error listing pods checking if they can be restored") - } - - configMaps, err := p.listConfigMapsInAllNamespaces(context.TODO(), listOptions) - if err != nil { - return nil, errors.Wrap(err, "error listing config maps checking if they can be restored") - } - existingConfigMaps := make(set.Set[string]) - for _, cm := range configMaps.Items { - if _, ok := p.namespaceToPoolName[cm.Namespace]; !ok { - continue - } - existingConfigMaps.Insert(cm.Name) - } - - var containerIDs []string - var k8sPods []*k8sV1.Pod - var ports [][]int - var resourcePool string - for _, pod := range pods.Items { - if _, ok := p.namespaceToPoolName[pod.Namespace]; !ok { - continue - } - - foundID := false - foundPool := false - for _, container := range pod.Spec.Containers { - for _, env := range container.Env { - switch env.Name { - case "DET_CONTAINER_ID": - if !existingConfigMaps.Contains(pod.Name) { - p.deleteKubernetesResources(pods, configMaps) - return nil, fmt.Errorf("pod missing config map %s", pod.Name) - } - - p := pod - k8sPods = append(k8sPods, &p) - containerIDs = append(containerIDs, env.Value) - - var podPorts []int - for _, p := range container.Ports { - podPorts = append(podPorts, int(p.ContainerPort)) - } - ports = append(ports, podPorts) - - foundID = true - case resourcePoolEnvVar: - resourcePool = env.Value - foundPool = true - } - } - if foundID && foundPool { - break - } - } - } - - if len(k8sPods) != msg.numPods { - p.deleteKubernetesResources(pods, configMaps) - return nil, fmt.Errorf("not enough pods found for allocation expected %d got %d instead", - msg.numPods, len(k8sPods)) - } - - if err := p.dontReattachQueuedPreAgentDisabledPods(pods, configMaps); err != nil { - return nil, err - } - - var restoreResponses []reattachPodResponse - for i, containerID := range containerIDs { - resp, err := p.reattachPod(msg.req, msg.allocationID, resourcePool, containerID, - k8sPods[i], ports[i], msg.slots, msg.logContext) - if err != nil { - p.deleteKubernetesResources(pods, configMaps) - return nil, errors.Wrapf(err, - "error restoring pod with containerID %s", containerID) - } - restoreResponses = append(restoreResponses, resp) - } - - return restoreResponses, nil -} - -func (p *pods) dontReattachQueuedPreAgentDisabledPods( - pods *k8sV1.PodList, configMaps *k8sV1.ConfigMapList, -) error { - // This is needed to label pods created before Determined supported k8s agent enable disable. - // We will not reattach pods that are queued and don't have the affinity that respects - // agent disabling. Not many people should be relying on this feature when this will be released - // since it was behind _agent_reattach_enabled until the version this is also released on. - // We can't patch the pods with the needed field, as a limitation of Kubernetes. - for _, pod := range pods.Items { - pod := pod - if pod.Spec.NodeName == "" { // Only do this for pods not assigned to a node yet. 
- before := pod.DeepCopy() - addNodeDisabledAffinityToPodSpec(&pod, clusterIDNodeLabel()) - - if !reflect.DeepEqual(pod.Spec, before.Spec) { - p.deleteKubernetesResources(pods, configMaps) - return fmt.Errorf( - "unable to restore pod %s since it was queued and does not have the needed "+ - "Determined's affinity to prevent scheduling on disabled nodes. "+ - "This is expected to happen on allocations with queued pods "+ - "when upgrading from before 0.25.1 "+ - "to after or equal to 0.26.1", pod.Name) - } - } - } - - return nil -} - -func (p *pods) reattachPod( - req *sproto.AllocateRequest, - allocationID model.AllocationID, - resourcePool string, - containerID string, - pod *k8sV1.Pod, - ports []int, - slots int, - logContext logger.Context, -) (reattachPodResponse, error) { - startMsg := StartTaskPod{ - Req: req, - AllocationID: allocationID, - Spec: tasks.TaskSpec{ - ContainerID: containerID, - }, - Slots: slots, - ResourcePool: resourcePool, - LogContext: logContext, - } - - newPodHandler := newPod( - startMsg, - startMsg.Spec.ClusterID, - p.clientSet, - pod.Namespace, - p.detMasterIP, - p.detMasterPort, - p.masterTLSConfig, - p.loggingTLSConfig, - p.loggingConfig, - p.podInterfaces[pod.Namespace], - p.configMapInterfaces[pod.Namespace], - p.resourceRequestQueue, - p.slotType, - p.slotResourceRequests, - p.scheduler, - ) - - newPodHandler.restore = true - newPodHandler.podName = pod.Name - newPodHandler.configMapName = pod.Name - newPodHandler.ports = ports - - state, err := newPodHandler.getPodState(pod, newPodHandler.containerNames) - if err != nil { - return reattachPodResponse{}, errors.Wrap(err, "error finding pod state to restore") - } - // Don't set container state if the state is terminated. - // This is so that when we send the update message we will go - // through pod shutdown logic and avoid dropping a duplicate state messages. 
- if state != cproto.Terminated { - newPodHandler.container.State = state - } - - var started *sproto.ResourcesStarted - if newPodHandler.container.State == cproto.Running { - started = ptrs.Ptr(getResourcesStartedForPod(pod, newPodHandler.ports)) - } - - newPodHandler.pod = pod - - err = newPodHandler.start() - if err != nil { - return reattachPodResponse{}, fmt.Errorf("reattaching pod: %w", err) - } - - p.podNameToPodHandler[pod.Name] = newPodHandler - p.podNameToResourcePool[pod.Name] = resourcePool - p.containerIDToPodName[containerID] = pod.Name - p.podNameToContainerID[pod.Name] = containerID - p.containerIDToSchedulingState[containerID] = sproto.SchedulingStateQueued - p.podHandlerToMetadata[newPodHandler] = podMetadata{ - podName: pod.Name, - containerID: containerID, - } - - return reattachPodResponse{containerID: containerID, started: started}, nil -} - -func (p *pods) refreshPodStates(allocationID model.AllocationID) error { - if allocationID == "" { - return fmt.Errorf("invalid call: allocationID missing") - } - - pods, err := p.listPodsInAllNamespaces(context.TODO(), metaV1.ListOptions{ - LabelSelector: fmt.Sprintf("%s=%s", determinedLabel, allocationID), - }) - if err != nil { - return errors.Wrap(err, "error listing pods checking if they can be restored") - } - - for _, pod := range pods.Items { - if _, ok := p.namespaceToPoolName[pod.Namespace]; !ok { - continue - } - pod := pod - p.podStatusCallback(watch.Event{Object: &pod}) - } - return nil -} - -func (p *pods) deleteKubernetesResources( - pods *k8sV1.PodList, configMaps *k8sV1.ConfigMapList, -) { - for _, pod := range pods.Items { - p.resourceRequestQueue.deleteKubernetesResources(pod.Namespace, pod.Name, "") - } - - for _, configMap := range configMaps.Items { - p.resourceRequestQueue.deleteKubernetesResources(configMap.Namespace, "", configMap.Name) - } -} - -func (p *pods) deleteDoomedKubernetesResources() error { - var openAllocations []model.Allocation - if err := db.Bun().NewSelect().Model(&openAllocations). - Where("end_time IS NULL"). 
- Scan(context.TODO()); err != nil { - return errors.Wrap(err, "error querying the database for open allocations") - } - openAllocationIDs := make(set.Set[model.AllocationID]) - for _, alloc := range openAllocations { - openAllocationIDs.Insert(alloc.AllocationID) - } - - listOptions := metaV1.ListOptions{LabelSelector: determinedLabel} - pods, err := p.listPodsInAllNamespaces(context.TODO(), listOptions) - if err != nil { - return errors.Wrap(err, "error listing existing pods") - } - toKillPods := &k8sV1.PodList{} - savedPodNames := make(set.Set[string]) - for _, pod := range pods.Items { - if _, ok := p.namespaceToPoolName[pod.Namespace]; !ok { - continue - } - - resourcePool := (func() string { - for _, c := range pod.Spec.Containers { - for _, e := range c.Env { - if e.Name == resourcePoolEnvVar { - return e.Value - } - } - } - return "" - })() - - if resourcePool == "" { - p.syslog.Debugf("deleting pod '%s' without environment variable '%s'", - pod.Name, resourcePoolEnvVar) - toKillPods.Items = append(toKillPods.Items, pod) - continue - } - - if !openAllocationIDs.Contains(model.AllocationID(pod.Labels[determinedLabel])) { - p.syslog.Warnf("deleting pod '%s', did not find open allocation '%s'", - pod.Name, pod.Labels[determinedLabel]) - toKillPods.Items = append(toKillPods.Items, pod) - continue - } - savedPodNames.Insert(pod.Name) - } - - configMaps, err := p.listConfigMapsInAllNamespaces(context.TODO(), listOptions) - if err != nil { - return errors.Wrap(err, "error listing existing config maps") - } - toKillConfigMaps := &k8sV1.ConfigMapList{} - for _, cm := range configMaps.Items { - if _, ok := p.namespaceToPoolName[cm.Namespace]; !ok { - continue - } - - if savedPodNames.Contains(cm.Name) { // PodName is same as config map name. - continue - } - - p.syslog.Debugf("Deleting config map '%s' did not find a matching pod that will be restored", - cm.Name) - toKillConfigMaps.Items = append(toKillConfigMaps.Items, cm) - } - - p.deleteKubernetesResources(toKillPods, toKillConfigMaps) - return nil -} - -func (p *pods) startPodInformer() error { - for namespace := range p.namespaceToPoolName { - i, err := newPodInformer( - context.TODO(), - determinedLabel, - "pod", - namespace, - p.podInterfaces[namespace], - func(event watch.Event) { - p.mu.Lock() - defer p.mu.Unlock() - p.podStatusCallback(event) - }, - ) - if err != nil { - return err - } - - go i.run(context.TODO()) - } - return nil -} - -func (p *pods) startNodeInformer() error { - i, err := newNodeInformer( - context.TODO(), - p.clientSet.CoreV1().Nodes(), - func(event watch.Event) { - p.mu.Lock() - defer p.mu.Unlock() - p.nodeStatusCallback(event) - }) - if err != nil { - return err - } - - go i.run(context.TODO()) - return nil -} - -func (p *pods) startEventListeners() error { - for namespace := range p.namespaceToPoolName { - l, err := newEventInformer( - context.TODO(), - p.clientSet.CoreV1().Events(namespace), - namespace, - func(event watch.Event) { - p.mu.Lock() - defer p.mu.Unlock() - p.eventStatusCallback(event) - }) - if err != nil { - return err - } - go l.run(context.TODO()) - } - return nil -} - -func (p *pods) startPreemptionListeners() error { - for namespace := range p.namespaceToPoolName { - l, err := newPodInformer( - context.TODO(), - determinedPreemptionLabel, - "preemption", - namespace, - p.clientSet.CoreV1().Pods(namespace), - func(event watch.Event) { - p.mu.Lock() - defer p.mu.Unlock() - p.preemptionCallback(event) - }) - if err != nil { - return err - } - go l.run(context.TODO()) - } - return nil -} - 
-func (p *pods) startResourceRequestQueue() { - failures := make(chan resourcesRequestFailure, 16) - p.resourceRequestQueue = startRequestQueue(p.podInterfaces, p.configMapInterfaces, failures) - p.wg.Go(func(ctx context.Context) { - for { - select { - case failure := <-failures: - p.handleResourceRequestFailure(failure) - case <-ctx.Done(): - return - } - } - }) -} - -func (p *pods) handleResourceRequestFailure(msg resourcesRequestFailure) { - p.mu.Lock() - defer p.mu.Unlock() - - podName := msg.getPodName() - podHandler, ok := p.podNameToPodHandler[podName] - if !ok { - p.syslog.Warnf("received resource request error for unregistered pod %s", podName) - return - } - - switch msg := msg.(type) { - case resourceCreationFailed: - podHandler.receiveResourceCreationFailed(msg) - case resourceCreationCancelled: - podHandler.receiveResourceCreationCancelled() - case resourceDeletionFailed: - podHandler.receiveResourceDeletionFailed(msg) - default: - panic(fmt.Sprintf("unexpected message %T", msg)) - } - - err := p.cleanUpPodHandler(podHandler) - if err != nil { - p.syslog.WithError(err).Error("cleaning up pod handler after resource request failure") - } -} - -func (p *pods) receiveStartTaskPod(msg StartTaskPod) error { - newPodHandler := newPod( - msg, - msg.Spec.ClusterID, - p.clientSet, - msg.Namespace, - p.detMasterIP, - p.detMasterPort, - p.masterTLSConfig, - p.loggingTLSConfig, - p.loggingConfig, - p.podInterfaces[msg.Namespace], - p.configMapInterfaces[msg.Namespace], - p.resourceRequestQueue, - p.slotType, - p.slotResourceRequests, - p.scheduler, - ) - - if _, alreadyExists := p.podNameToPodHandler[newPodHandler.podName]; alreadyExists { - return errors.Errorf( - "attempting to register same pod name: %s multiple times", newPodHandler.podName) - } - - err := newPodHandler.start() - if err != nil { - return fmt.Errorf("creating pod: %w", err) - } - - p.podNameToPodHandler[newPodHandler.podName] = newPodHandler - p.podNameToResourcePool[newPodHandler.podName] = msg.ResourcePool - p.containerIDToPodName[msg.Spec.ContainerID] = newPodHandler.podName - p.podNameToContainerID[newPodHandler.podName] = msg.Spec.ContainerID - p.containerIDToSchedulingState[msg.Spec.ContainerID] = sproto.SchedulingStateQueued - p.podHandlerToMetadata[newPodHandler] = podMetadata{ - podName: newPodHandler.podName, - containerID: msg.Spec.ContainerID, - } - - return nil -} - -func (p *pods) podStatusCallback(event watch.Event) { - pod, ok := event.Object.(*k8sV1.Pod) - if !ok { - p.syslog.Warnf("error converting event of type %T to *k8sV1.Pod: %+v", event, event) - return - } - syslog := p.syslog.WithField("pod", pod.Name) - syslog.WithField("event.Type", event.Type).Debug("received pod informer event") - - podHandler, ok := p.podNameToPodHandler[pod.Name] - if !ok { - syslog.Debug("received status update for un-registered pod") - return - } - - state, err := podHandler.podStatusUpdate(pod) - switch { - case err != nil: - syslog.WithError(err).Error("error processing pod status update") - err := p.cleanUpPodHandler(podHandler) - if err != nil { - syslog.WithError(err).Error("unable to cleanup pod handler after update error") - } - return - case state == cproto.Terminated: - err := p.cleanUpPodHandler(podHandler) - if err != nil { - syslog.WithError(err).Error("unable to cleanup pod handler after termination") - } - } - - if containerID, ok := p.podNameToContainerID[pod.Name]; ok { - if state, ok := p.containerIDToSchedulingState[containerID]; ok { - currState := sproto.SchedulingStateQueued - if pod.Status.Phase == 
"Running" { - currState = sproto.SchedulingStateScheduled - } - if currState != state { - p.containerIDToSchedulingState[containerID] = currState - go p.podStatusUpdateCallback(sproto.UpdatePodStatus{ - ContainerID: containerID, - State: currState, - }) - } - } - } -} - -var ( - clusterID string - once sync.Once -) - -func setClusterID(s string) { - once.Do(func() { - clusterID = s - }) -} - -func clusterIDNodeLabel() string { - return fmt.Sprintf("determined.ai/cluster-id-%s", clusterID) -} - -const ( - noExecuteNodeLabelValue = "no-execute" - noScheduleNodeLabelValue = "no-schedule" -) - -func (p *pods) enableNode( - nodeName string, -) (*apiv1.EnableAgentResponse, error) { - patch := []byte(fmt.Sprintf(`{ - "metadata": { - "labels": { - "%s": null - } - } - }`, clusterIDNodeLabel())) - - _, err := p.clientSet.CoreV1().Nodes(). - Patch(context.TODO(), nodeName, types.StrategicMergePatchType, patch, metaV1.PatchOptions{}) - if k8error.IsForbidden(err) { - return nil, fmt.Errorf("the Determined master Kubernetes service account " + - "is missing permissions to patch nodes. " + - "Enabling or disabling nodes requires this permission, " + - "however Determined will otherwise still function correctly without " + - "these Kubernetes permissions") - } else if err != nil { - return nil, fmt.Errorf( - "enabling node %s by removing the Determined no schedule label: %w", nodeName, err) - } - p.syslog.Infof("node %s enabled by an user", nodeName) - - n, ok := p.summarizeClusterByNodes()[nodeName] - if !ok { - return nil, fmt.Errorf("node %s enabled without error, error getting node summary", nodeName) - } - n.Enabled = true - n.Draining = false - for slotKey := range n.Slots { - s := n.Slots[slotKey] - s.Enabled = n.Enabled - s.Draining = n.Draining - n.Slots[slotKey] = s - } - - return &apiv1.EnableAgentResponse{ - Agent: n.ToProto(), - }, nil -} - -func (p *pods) disableNode( - nodeName string, shouldDrain bool, -) (*apiv1.DisableAgentResponse, error) { - labelValue := noExecuteNodeLabelValue - if shouldDrain { - labelValue = noScheduleNodeLabelValue - } - - patchStruct := metaV1.ObjectMeta{ - Labels: map[string]string{clusterIDNodeLabel(): labelValue}, - } - patch, err := json.Marshal(map[string]any{"metadata": patchStruct}) - if err != nil { - return nil, fmt.Errorf("marshaling JSON patch %v: %s", patchStruct, err) - } - - _, err = p.clientSet.CoreV1().Nodes(). - Patch(context.TODO(), nodeName, types.StrategicMergePatchType, patch, metaV1.PatchOptions{}) - if k8error.IsForbidden(err) { - return nil, fmt.Errorf("the Determined master Kubernetes service account " + - "is missing permissions to patch nodes. " + - "Enabling or disabling nodes requires this permission, " + - "however Determined will otherwise still function correctly without " + - "these Kubernetes permissions") - } else if err != nil { - return nil, fmt.Errorf( - "disabling node %s by adding the Determined no schedule label: %w", nodeName, err) - } - p.syslog.Infof("node %s disabled by an user", nodeName) - - if !shouldDrain { // See note in spec.go about how we could remove killing all pods here. 
- if err := p.releaseAllocationsOnDisabledNode(nodeName); err != nil { - return nil, fmt.Errorf( - "node disabled without error, error killing existing pod on node: %w", err) - } - } - - n, ok := p.summarizeClusterByNodes()[nodeName] - if !ok { - return nil, fmt.Errorf("node %s disabled without error, error getting node summary", nodeName) - } - n.Enabled = false - n.Draining = shouldDrain - for slotKey := range n.Slots { - s := n.Slots[slotKey] - s.Enabled = n.Enabled - s.Draining = n.Draining - n.Slots[slotKey] = s - } - - return &apiv1.DisableAgentResponse{ - Agent: n.ToProto(), - }, nil -} - -func (p *pods) releaseAllocationsOnDisabledNode(nodeName string) error { - listOptions := metaV1.ListOptions{ - LabelSelector: determinedLabel, - FieldSelector: fmt.Sprintf("spec.nodeName=%s", nodeName), - } - pods, err := p.listPodsInAllNamespaces(context.TODO(), listOptions) - if err != nil { - return fmt.Errorf("listing pods on node %s: %w", nodeName, err) - } - - notifiedAllocations := make(map[model.AllocationID]bool) - for _, pod := range pods.Items { - podHandler, ok := p.podNameToPodHandler[pod.Name] - if !ok { - p.syslog.Warnf( - "during node disable couldn't find pod %s's actor to kill", pod.Name) - continue - } - - p.syslog.Infof( - "stopping pod %s because node %s was disabled without drain option", pod.Name, nodeName) - if notifiedAllocations[podHandler.allocationID] { - continue - } - - rmevents.Publish(podHandler.allocationID, &sproto.ReleaseResources{ - Reason: "node disabled without drain", - ForceKill: true, - }) - notifiedAllocations[podHandler.allocationID] = true - } - - return nil -} - -func (p *pods) nodeStatusCallback(event watch.Event) { - node, ok := event.Object.(*k8sV1.Node) - if !ok { - p.syslog.Warnf("error converting event of type %T to *k8sV1.Node: %+v", event, event) - return - } - - p.syslog.Debugf(`informer got new node event for node '%s': %s %s`, - node.Name, event.Type, node.Status.Phase) - - switch event.Type { - case watch.Added: - p.currentNodes[node.Name] = node - case watch.Modified: - p.currentNodes[node.Name] = node - case watch.Deleted: - delete(p.currentNodes, node.Name) - default: - } -} - -func (p *pods) eventStatusCallback(event watch.Event) { - newEvent, ok := event.Object.(*k8sV1.Event) - if !ok { - p.syslog.Warnf("error converting object type %T to *k8sV1.Event: %+v", event, event) - return - } - - syslog := p.syslog.WithFields(logrus.Fields{ - "name": newEvent.InvolvedObject.Name, - "kind": newEvent.InvolvedObject.Kind, - }) - - syslog.Debugf("listener got new event: %s", newEvent.Message) - ref, ok := p.podNameToPodHandler[newEvent.InvolvedObject.Name] - if !ok { - // We log at the debug level because we are unable to filter - // pods based on their labels the way we do with pod status updates. 
- syslog.Debug("received pod event for an un-registered pod") - return - } - - ref.podEventUpdate(newEvent) -} - -func (p *pods) receiveResourceSummarize(msg SummarizeResources) (*PodsInfo, error) { - summary, err := p.summarize() - if err != nil { - return nil, err - } - - slots := 0 - if len(msg.PoolName) > 0 { - slots = numSlots(summary[msg.PoolName].Slots) - } else { - for _, pool := range summary { - slots += numSlots(pool.Slots) - } - } - return &PodsInfo{NumAgents: len(summary), SlotsAvailable: slots}, nil -} - -func (p *pods) preemptionCallback(event watch.Event) { - pod, ok := event.Object.(*k8sV1.Pod) - if !ok { - p.syslog.Warnf("error converting event of type %T to *k8sV1.Pod: %+v", event, event) - return - } - p.syslog.Debugf("informer got new preemption event for pod %s ", pod.Name) - - ref, ok := p.podNameToPodHandler[pod.Name] - if !ok { - p.syslog.Debug("received preemption command for unregistered pod") - return - } - ref.PreemptTaskPod() -} - -func (p *pods) verifyPodAndGetRef(podID string) *pod { - podName, ok := p.containerIDToPodName[podID] - if !ok { - p.syslog.WithField("pod-id", podID).Debug( - "received change priority command for unregistered container id") - return nil - } - ref, ok := p.podNameToPodHandler[podName] - if !ok { - p.syslog.WithField("pod-id", podID).Debug( - "received change priority command for unregistered container id") - return nil - } - - return ref -} - -func (p *pods) receivePriorityChange(podID cproto.ID) { - ref := p.verifyPodAndGetRef(podID.String()) - if ref != nil { - ref.ChangePriority() - } -} - -func (p *pods) receivePositionChange(podID cproto.ID) { - ref := p.verifyPodAndGetRef(podID.String()) - if ref != nil { - ref.ChangePosition() - } -} - -func (p *pods) receiveKillPod(podID cproto.ID) { - name, ok := p.containerIDToPodName[podID.String()] - if !ok { - // For multi-pod tasks, when the chief pod exits, the scheduler - // will request to terminate pods all other pods that have - // notified the scheduler that they have exited. - p.syslog.WithField("pod-id", podID).Info( - "received stop pod command for unregistered container id") - return - } - - ref, ok := p.podNameToPodHandler[name] - if !ok { - p.syslog.WithField("pod-id", podID).Info( - "received stop pod command for unregistered container id") - return - } - - ref.KillTaskPod() -} - -func (p *pods) cleanUpPodHandler(podHandler *pod) error { - podHandler.finalize() - - podInfo, ok := p.podHandlerToMetadata[podHandler] - if !ok { - return errors.Errorf("unknown pod handler being deleted %s", podHandler.podName) - } - - p.syslog.WithField("pod", podInfo.podName).WithField( - "handler", podHandler.podName).Infof("de-registering pod handler") - delete(p.podNameToPodHandler, podInfo.podName) - delete(p.podNameToResourcePool, podInfo.podName) - delete(p.podNameToContainerID, podInfo.podName) - delete(p.containerIDToPodName, podInfo.containerID) - delete(p.containerIDToSchedulingState, podInfo.containerID) - delete(p.podHandlerToMetadata, podHandler) - - // launch this work async, since we hold the lock and it does API calls. - p.wg.Go(func(ctx context.Context) { - name := fmt.Sprintf("%s-priorityclass", podInfo.containerID) - err := p.clientSet. - SchedulingV1(). - PriorityClasses(). 
- Delete(ctx, name, metaV1.DeleteOptions{}) - if err != nil && !k8error.IsNotFound(err) { - p.syslog.Warnf("Deletion of PriorityClass %s failed.", name) - } - }) - - return nil -} - -func (p *pods) handleGetSlotsRequest(agentID string) *apiv1.GetSlotsResponse { - agentResp := p.handleGetAgentRequest(agentID) - if agentResp == nil { - p.syslog.Warnf("no agent with id %s", agentID) - return nil - } - return &apiv1.GetSlotsResponse{Slots: maps.Values(agentResp.Agent.Slots)} -} - -func (p *pods) handleGetSlotRequest(agentID string, slotID string) *apiv1.GetSlotResponse { - agentResp := p.handleGetAgentRequest(agentID) - if agentResp == nil { - p.syslog.Warnf("no agent with id %s", agentID) - return nil - } - slots := agentResp.Agent.Slots - slot, ok := slots[slotID] - if !ok { - // Try converting an index input to a slot and see if that exists (1 to 001). - tryIndex, err := strconv.Atoi(slotID) - if s, ok := slots[model.SortableSlotIndex(tryIndex)]; err == nil && ok { - slot = s - } else { - p.syslog.Warnf("no slot with id %s", slotID) - return nil - } - } - return &apiv1.GetSlotResponse{Slot: slot} -} - -func (p *pods) handleGetAgentsRequest() *apiv1.GetAgentsResponse { - p.getAgentsCacheLock.Lock() - defer p.getAgentsCacheLock.Unlock() - - if time.Since(p.getAgentsCacheTime) > getAgentsCacheDuration { - p.getAgentsCacheTime = time.Now() - - nodeSummaries := p.summarizeClusterByNodes() - _, nodesToPools := p.getNodeResourcePoolMapping(nodeSummaries) - - p.getAgentsCache = &apiv1.GetAgentsResponse{} - for _, summary := range nodeSummaries { - summary.ResourcePool = nodesToPools[summary.ID] - p.getAgentsCache.Agents = append(p.getAgentsCache.Agents, summary.ToProto()) - } - } - - return p.getAgentsCache -} - -func (p *pods) handleGetAgentRequest(agentID string) *apiv1.GetAgentResponse { - nodeSummaries := p.summarizeClusterByNodes() - _, nodesToPools := p.getNodeResourcePoolMapping(nodeSummaries) - agentSummary, ok := nodeSummaries[agentID] - if !ok { - // TODO(DET-10029): We should return an error indicating the invalid ID request (rather - // than a warn). - p.syslog.Warnf("no agent with id %s", agentID) - return nil - } - agentSummary.ResourcePool = nodesToPools[agentSummary.ID] - return &apiv1.GetAgentResponse{Agent: agentSummary.ToProto()} -} - -// summarize describes pods' available resources. When there's exactly one resource pool, it uses -// the whole cluster's info. Otherwise, it matches nodes to resource pools using taints and -// tolerations to derive that info. This may be cached, so don't use this for decisions -// that require up-to-date information. -func (p *pods) summarize() (map[string]model.AgentSummary, error) { - p.summarizeCacheLock.Lock() - defer p.summarizeCacheLock.Unlock() - - if time.Since(p.summarizeCacheTime) > summarizeCacheDuration { - summary, err := p.computeSummary() - p.summarizeCacheTime = time.Now() - p.summarizeCache = summarizeResult{ - summary: summary, - err: err, - } - } - - return p.summarizeCache.summary, p.summarizeCache.err -} - -// Get the mapping of many-to-many relationship between nodes and resource pools. -func (p *pods) getNodeResourcePoolMapping(nodeSummaries map[string]model.AgentSummary) ( - map[string][]*k8sV1.Node, map[string][]string, -) { - poolTaskContainerDefaults := extractTCDs(p.resourcePoolConfigs) - - // Nvidia automatically taints nodes, so we should tolerate that when users don't customize - // their resource pool config. 
- defaultTolerations := []k8sV1.Toleration{{ - Key: ResourceTypeNvidia, - Value: "present", - Operator: k8sV1.TolerationOpEqual, - }} - cpuTolerations, gpuTolerations := extractTolerations(p.baseContainerDefaults) - poolsToNodes := make(map[string][]*k8sV1.Node, len(p.namespaceToPoolName)) - nodesToPools := make(map[string][]string, len(p.namespaceToPoolName)) - - for _, node := range p.currentNodes { - _, slotType := extractSlotInfo(nodeSummaries[node.Name]) - - for poolName, tcd := range poolTaskContainerDefaults { - var poolTolerations []k8sV1.Toleration - - // If they're using the default RP config, use the default tolerations. - if len(p.resourcePoolConfigs) <= 1 && - (tcd == nil || (tcd.CPUPodSpec == nil && tcd.GPUPodSpec == nil)) { - if slotType == device.CUDA { - //nolint:gocritic - poolTolerations = append(defaultTolerations, gpuTolerations...) - } else if slotType == device.CPU { - //nolint:gocritic - poolTolerations = append(defaultTolerations, cpuTolerations...) - } - } else if tcd != nil { - // Decide which poolTolerations to use based on slot device type - if slotType == device.CUDA && tcd.GPUPodSpec != nil { - //nolint:gocritic - poolTolerations = append(tcd.GPUPodSpec.Spec.Tolerations, gpuTolerations...) - } else if tcd.CPUPodSpec != nil { - //nolint:gocritic - poolTolerations = append(tcd.CPUPodSpec.Spec.Tolerations, cpuTolerations...) - } - } - - // add default toleration so that autoscaling nodes will still be counted. - poolTolerations = append(poolTolerations, k8sV1.Toleration{ - Key: "DeletionCandidateOfClusterAutoscaler", - Operator: "Exists", - Effect: "PreferNoSchedule", - TolerationSeconds: nil, - }) - // If all of a node's taints are tolerated by a pool, that node belongs to the pool. - if allTaintsTolerated(node.Spec.Taints, poolTolerations) { - poolsToNodes[poolName] = append(poolsToNodes[poolName], node) - nodesToPools[node.Name] = append(nodesToPools[node.Name], poolName) - } - } - } - - return poolsToNodes, nodesToPools -} - -var programStartTime = time.Now() - -func (p *pods) computeSummary() (map[string]model.AgentSummary, error) { - nodeSummaries := p.summarizeClusterByNodes() - - // Build the many-to-many relationship between nodes and resource pools - poolsToNodes, _ := p.getNodeResourcePoolMapping(nodeSummaries) - - // Build the set of summaries for each resource pool - containers := p.containersPerResourcePool() - summaries := make(map[string]model.AgentSummary, len(p.namespaceToPoolName)) - for poolName, nodes := range poolsToNodes { - slots := model.SlotsSummary{} - numContainersInPool := containers[poolName] - - // We'll create a number of pseudo-containers in the summary equal to the number of - // running containers in this pool. 
- pseudoContainersAdded := 0 - - for _, node := range nodes { - numSlots, slotType := extractSlotInfo(nodeSummaries[node.Name]) - - for j := 0; j < numSlots; j++ { - id := fmt.Sprintf("%s/%s/%s/%d", poolName, node.Name, string(slotType), j) - - var container *cproto.Container - if pseudoContainersAdded < numContainersInPool { - container = &cproto.Container{ - ID: cproto.ID(id), - State: "RUNNING", - } - pseudoContainersAdded++ - } - - slots[id] = model.SlotSummary{ - ID: id, - Device: device.Device{Type: slotType}, - Enabled: true, - Container: container, - } - } - } - - summaries[poolName] = model.AgentSummary{ - ID: poolName, - RegisteredTime: programStartTime, - NumContainers: numContainersInPool, - ResourcePool: []string{poolName}, - Slots: slots, - } - } - - return summaries, nil -} - -func (p *pods) summarizeClusterByNodes() map[string]model.AgentSummary { - var allPods []podNodeInfo - - for _, p := range p.podNameToPodHandler { - allPods = append(allPods, p.getPodNodeInfo()) - } - - // Separate pods by nodes. - podByNode := make(map[string][]podNodeInfo, len(allPods)) - for _, podInfo := range allPods { - if len(podInfo.nodeName) == 0 { - // If a pod doesn't have a nodeName it means it has not yet - // been allocated to a node. - continue - } - podByNode[podInfo.nodeName] = append(podByNode[podInfo.nodeName], podInfo) - } - - nodeToTasks, taskSlots := p.getNonDetSlots(p.slotType) - summary := make(map[string]model.AgentSummary, len(p.currentNodes)) - for _, node := range p.currentNodes { - disabledLabel, isDisabled := node.Labels[clusterIDNodeLabel()] - isDraining := isDisabled && disabledLabel == noScheduleNodeLabelValue - - var numSlots int64 - var deviceType device.Type - - // TODO(DET-10010): slot type per node probably shouldn't be decided from pods literal - // (which has the same value for all nodes). - switch p.slotType { - case device.CPU: - resources := node.Status.Allocatable[k8sV1.ResourceCPU] - milliCPUs := resources.MilliValue() - p.nodeToSystemResourceRequests[node.Name] - numSlots = int64(float32(milliCPUs) / (1000. 
* p.slotResourceRequests.CPU)) - deviceType = device.CPU - case device.ROCM: - panic("ROCm is not supported on k8s yet") - case device.CUDA: - fallthrough - default: - resources := node.Status.Allocatable[ResourceTypeNvidia] - numSlots = resources.Value() - deviceType = device.CUDA - } - - if numSlots < 1 { - continue - } - - slotsSummary := make(model.SlotsSummary) - curSlot := 0 - for _, podInfo := range podByNode[node.Name] { - for i := 0; i < podInfo.numSlots; i++ { - if curSlot >= int(numSlots) { - p.syslog.Warnf("too many pods mapping to node %s", node.Name) - continue - } - - slotsSummary[model.SortableSlotIndex(curSlot)] = model.SlotSummary{ - ID: model.SortableSlotIndex(curSlot), - Device: device.Device{Type: deviceType}, - Draining: isDraining, - Enabled: !isDisabled, - Container: podInfo.container, - } - curSlot++ - } - } - - for _, taskName := range nodeToTasks[node.Name] { - for i := int64(0); i < taskSlots[taskName]; i++ { - if curSlot >= int(numSlots) { - p.syslog.Warnf("too many pods mapping to node %s", node.Name) - continue - } - - slotsSummary[model.SortableSlotIndex(curSlot)] = model.SlotSummary{ - ID: model.SortableSlotIndex(curSlot), - Device: device.Device{Type: deviceType}, - Draining: isDraining, - Enabled: !isDisabled, - Container: &cproto.Container{ - ID: cproto.ID(taskName), - State: "RUNNING", - Devices: []device.Device{}, - Description: "unknown", - }, - } - curSlot++ - } - } - - for i := curSlot; i < int(numSlots); i++ { - slotsSummary[model.SortableSlotIndex(i)] = model.SlotSummary{ - ID: model.SortableSlotIndex(i), - Device: device.Device{Type: deviceType}, - Draining: isDraining, - Enabled: !isDisabled, - } - } - - var addrs []string - for _, addr := range node.Status.Addresses { - addrs = append(addrs, addr.Address) - } - - summary[node.Name] = model.AgentSummary{ - ID: node.Name, - RegisteredTime: node.ObjectMeta.CreationTimestamp.Time, - Slots: slotsSummary, - NumContainers: len(podByNode[node.Name]) + len(nodeToTasks[node.Name]), - ResourcePool: []string{""}, - Addresses: addrs, - Draining: isDraining, - Enabled: !isDisabled, - } - } - - return summary -} - -func (p *pods) getNonDetPods() ([]k8sV1.Pod, error) { - // TODO(RM-235) use a filter in metaV1.ListOptions. This change gets a lot easier after - // we have K8s integration tests. Using a filter means we should really talk to a real - // k8s server. Doing an e2e test for this is possible but would take a lot more work. - allPods, err := p.listPodsInAllNamespaces(context.TODO(), metaV1.ListOptions{}) - if err != nil { - return nil, err - } - - var nonDetPods []k8sV1.Pod - for _, p := range allPods.Items { - _, isDet := p.Labels[determinedLabel] - _, isDetSystem := p.Labels[determinedSystemLabel] - - if !(isDet || isDetSystem) { - if p.Spec.NodeName != "" { - nonDetPods = append(nonDetPods, p) - } - } - } - return nonDetPods, nil -} - -func (p *pods) getNonDetSlots(deviceType device.Type) (map[string][]string, map[string]int64) { - nodeToTasks := make(map[string][]string, len(p.currentNodes)) - taskSlots := make(map[string]int64) - - nonDetPods, err := p.getNonDetPods() - if err != nil { - p.syslog.WithError(err).Warn("getting non determined pods, " + - "this may cause slots to look free when they are in use") - } - - if len(nonDetPods) == 0 { - return nodeToTasks, taskSlots - } - for _, node := range p.currentNodes { - nodeToTasks[node.Name] = []string{} - } - - // Ignore pods not yet scheduled on a node. 
- for _, pod := range nonDetPods { - if _, ok := nodeToTasks[pod.Spec.NodeName]; !ok { - continue - } - reqs := int64(0) - for _, c := range pod.Spec.Containers { - if deviceType == device.CPU { - reqs += p.getCPUReqs(c) - } else if deviceType == device.CUDA { - reqs += c.Resources.Requests.Name(ResourceTypeNvidia, resource.DecimalSI).Value() - } - } - if reqs > 0 { - nodeToTasks[pod.Spec.NodeName] = append(nodeToTasks[pod.Spec.NodeName], pod.Name) - taskSlots[pod.Name] = reqs - } - } - return nodeToTasks, taskSlots -} - -func (p *pods) getCPUReqs(c k8sV1.Container) int64 { - requested := float32(c.Resources.Requests.Cpu().MilliValue()) / - (1000. * p.slotResourceRequests.CPU) - return int64(requested) -} - -func (p *pods) containersPerResourcePool() map[string]int { - counts := make(map[string]int, len(p.namespaceToPoolName)) - for _, pool := range p.podNameToResourcePool { - counts[pool]++ - } - return counts -} - -func numSlots(slots model.SlotsSummary) int { - slotCountsByType := make(map[device.Type]int) - for _, slot := range slots { - slotCountsByType[slot.Device.Type]++ - } - - if slotCountsByType[device.CUDA] > 0 { - return slotCountsByType[device.CUDA] - } - - return slotCountsByType[device.CPU] -} - -func (p *pods) listPodsInAllNamespaces( - ctx context.Context, opts metaV1.ListOptions, -) (*k8sV1.PodList, error) { - res := &k8sV1.PodList{} - for n, i := range p.podInterfaces { - pods, err := i.List(ctx, opts) - if err != nil { - return nil, errors.Wrapf(err, "error listing pods for namespace %s", n) - } - - res.Items = append(res.Items, pods.Items...) - } - - return res, nil -} - -func (p *pods) listConfigMapsInAllNamespaces( - ctx context.Context, opts metaV1.ListOptions, -) (*k8sV1.ConfigMapList, error) { - res := &k8sV1.ConfigMapList{} - for n, i := range p.configMapInterfaces { - cms, err := i.List(ctx, opts) - if err != nil { - return nil, errors.Wrapf(err, "error listing config maps for namespace %s", n) - } - - res.Items = append(res.Items, cms.Items...) 
- } - - return res, nil -} - -func extractTCDs(resourcePoolConfigs []config.ResourcePoolConfig, -) map[string]*model.TaskContainerDefaultsConfig { - result := map[string]*model.TaskContainerDefaultsConfig{} - - for _, config := range resourcePoolConfigs { - result[config.PoolName] = config.TaskContainerDefaults - } - - return result -} - -func taintTolerated(taint k8sV1.Taint, tolerations []k8sV1.Toleration) bool { - for _, toleration := range tolerations { - if toleration.ToleratesTaint(&taint) { - return true - } - } - - return false -} - -func allTaintsTolerated(taints []k8sV1.Taint, tolerations []k8sV1.Toleration) bool { - for _, taint := range taints { - if !taintTolerated(taint, tolerations) { - return false - } - } - - return true -} - -func extractSlotInfo(node model.AgentSummary) (numSlots int, devType device.Type) { - var gpuSlots, cpuSlots int - - for _, slot := range node.Slots { - if slot.Device.Type == device.CPU { - cpuSlots++ - } else if slot.Device.Type == device.CUDA { - gpuSlots++ - } - } - - if gpuSlots > 0 { - return gpuSlots, device.CUDA - } - - return cpuSlots, device.CPU -} - -func extractTolerations(tcd *model.TaskContainerDefaultsConfig) ( - cpuTolerations, gpuTolerations []k8sV1.Toleration, -) { - if tcd != nil { - if tcd.GPUPodSpec != nil { - gpuTolerations = tcd.GPUPodSpec.Spec.Tolerations - } - if tcd.CPUPodSpec != nil { - cpuTolerations = tcd.CPUPodSpec.Spec.Tolerations - } - } - - return cpuTolerations, gpuTolerations -} diff --git a/master/internal/rm/kubernetesrm/request_queue.go b/master/internal/rm/kubernetesrm/request_queue.go index 9cbcf5b9728..9dcb7247c92 100644 --- a/master/internal/rm/kubernetesrm/request_queue.go +++ b/master/internal/rm/kubernetesrm/request_queue.go @@ -4,6 +4,9 @@ import ( "strconv" "sync" + batchV1 "k8s.io/api/batch/v1" + typedBatchV1 "k8s.io/client-go/kubernetes/typed/batch/v1" + "github.com/sirupsen/logrus" k8sV1 "k8s.io/api/core/v1" typedV1 "k8s.io/client-go/kubernetes/typed/core/v1" @@ -19,14 +22,15 @@ const ( // message types that are sent to the requestProcessingWorkers channel. type ( createKubernetesResources struct { - podSpec *k8sV1.Pod + jobSpec *batchV1.Job configMapSpec *k8sV1.ConfigMap } deleteKubernetesResources struct { namespace string - podName string + jobName string configMapName string + podName string } ) @@ -34,26 +38,26 @@ type ( // to creation or deletion requests. type ( resourceCreationFailed struct { - podName string + jobName string err error } resourceDeletionFailed struct { - podName string + jobName string err error } resourceCreationCancelled struct { - podName string + jobName string } ) type resourcesRequestFailure interface { - getPodName() string + getJobName() string resourcesRequestFailure() } -func (e resourceCreationFailed) getPodName() string { return e.podName } -func (e resourceDeletionFailed) getPodName() string { return e.podName } -func (e resourceCreationCancelled) getPodName() string { return e.podName } +func (e resourceCreationFailed) getJobName() string { return e.jobName } +func (e resourceDeletionFailed) getJobName() string { return e.jobName } +func (e resourceCreationCancelled) getJobName() string { return e.jobName } func (resourceCreationFailed) resourcesRequestFailure() {} func (resourceDeletionFailed) resourcesRequestFailure() {} @@ -101,6 +105,7 @@ type queuedResourceRequest struct { // requestProcessingWorkers notify the requestQueue that they are available to receive work // by sending a `workerAvailable` message. 
type requestQueue struct { + jobInterfaces map[string]typedBatchV1.JobInterface podInterfaces map[string]typedV1.PodInterface configMapInterfaces map[string]typedV1.ConfigMapInterface failures chan<- resourcesRequestFailure @@ -120,11 +125,13 @@ type requestQueue struct { type requestID string func startRequestQueue( + jobInterfaces map[string]typedBatchV1.JobInterface, podInterfaces map[string]typedV1.PodInterface, configMapInterfaces map[string]typedV1.ConfigMapInterface, failures chan<- resourcesRequestFailure, ) *requestQueue { r := &requestQueue{ + jobInterfaces: jobInterfaces, podInterfaces: podInterfaces, configMapInterfaces: configMapInterfaces, failures: failures, @@ -137,7 +144,7 @@ func startRequestQueue( pendingResourceCreations: make(map[requestID]*queuedResourceRequest), blockedResourceDeletions: make(map[requestID]*queuedResourceRequest), - syslog: logrus.New().WithField("component", "kubernetesrm-queue"), + syslog: logrus.WithField("component", "kubernetesrm-queue"), } r.startWorkers() return r @@ -146,6 +153,7 @@ func startRequestQueue( func (r *requestQueue) startWorkers() { for i := 0; i < numKubernetesWorkers; i++ { startRequestProcessingWorker( + r.jobInterfaces, r.podInterfaces, r.configMapInterfaces, strconv.Itoa(i), @@ -157,8 +165,8 @@ func (r *requestQueue) startWorkers() { } func keyForCreate(msg createKubernetesResources) requestID { - if msg.podSpec != nil { - return requestID(msg.podSpec.Namespace + "/" + msg.podSpec.Name) + if msg.jobSpec != nil { + return requestID(msg.jobSpec.Namespace + "/" + msg.jobSpec.Name) } if msg.configMapSpec != nil { return requestID(msg.configMapSpec.Namespace + "/" + msg.configMapSpec.Name) @@ -167,6 +175,9 @@ func keyForCreate(msg createKubernetesResources) requestID { } func keyForDelete(msg deleteKubernetesResources) requestID { + if msg.jobName != "" { + return requestID(msg.namespace + "/" + msg.jobName) + } if msg.podName != "" { return requestID(msg.namespace + "/" + msg.podName) } @@ -177,13 +188,13 @@ func keyForDelete(msg deleteKubernetesResources) requestID { } func (r *requestQueue) createKubernetesResources( - podSpec *k8sV1.Pod, + jobSpec *batchV1.Job, configMapSpec *k8sV1.ConfigMap, ) { r.mu.Lock() defer r.mu.Unlock() - msg := createKubernetesResources{podSpec, configMapSpec} + msg := createKubernetesResources{jobSpec, configMapSpec} ref := keyForCreate(msg) if _, requestAlreadyExists := r.pendingResourceCreations[ref]; requestAlreadyExists { @@ -203,13 +214,19 @@ func (r *requestQueue) createKubernetesResources( func (r *requestQueue) deleteKubernetesResources( namespace string, - podName string, + jobName string, configMapName string, + podName string, ) { r.mu.Lock() defer r.mu.Unlock() - msg := deleteKubernetesResources{namespace, podName, configMapName} + msg := deleteKubernetesResources{ + namespace: namespace, + jobName: jobName, + configMapName: configMapName, + podName: podName, + } ref := keyForDelete(msg) // If the request has not been processed yet, cancel it and inform the handler. 
@@ -217,7 +234,7 @@ func (r *requestQueue) deleteKubernetesResources( r.pendingResourceCreations[ref].createResources = nil delete(r.pendingResourceCreations, ref) r.failures <- resourceCreationCancelled{ - podName: podName, + jobName: jobName, } r.syslog.Warnf("delete issued with pending create request for %s", ref) return diff --git a/master/internal/rm/kubernetesrm/request_queue_test.go b/master/internal/rm/kubernetesrm/request_queue_test.go index 2e4d747a92c..1b0f283ced2 100644 --- a/master/internal/rm/kubernetesrm/request_queue_test.go +++ b/master/internal/rm/kubernetesrm/request_queue_test.go @@ -11,23 +11,25 @@ import ( "github.com/sirupsen/logrus" "gotest.tools/assert" + batchV1 "k8s.io/api/batch/v1" k8sV1 "k8s.io/api/core/v1" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" + typedBatchV1 "k8s.io/client-go/kubernetes/typed/batch/v1" typedV1 "k8s.io/client-go/kubernetes/typed/core/v1" ) -type mockPod struct { +type mockJob struct { requestQueue *requestQueue name string syslog *logrus.Entry } -func startMockPod(requestQueue *requestQueue) *mockPod { - m := &mockPod{ +func startMockJob(requestQueue *requestQueue) *mockJob { + m := &mockJob{ requestQueue: requestQueue, name: petName.Generate(3, "-"), } - m.syslog = logrus.New().WithField("component", "kubernetesrm-mock-pod").WithField("name", m.name) + m.syslog = logrus.WithField("component", "kubernetesrm-mock-pod").WithField("name", m.name) m.create() return m } @@ -52,38 +54,8 @@ func runDefaultErrorHandler(ctx context.Context, failures <-chan resourcesReques } } -func consumeResourceRequestFailures( - ctx context.Context, - failures <-chan resourcesRequestFailure, - ref *pod, -) { - for { - select { - case failure := <-failures: - switch e := failure.(type) { - case resourceCreationFailed: - logrus.Errorf("defaultErrorHandler resource creation failed: %v", e) - ref.receiveResourceCreationFailed(e) - ref.finalize() - case resourceDeletionFailed: - logrus.Errorf("defaultErrorHandler resource deletion failed: %v", e) - ref.receiveResourceDeletionFailed(e) - ref.finalize() - case resourceCreationCancelled: - logrus.Infof("defaultErrorHandler resource deletion failed: %v", e) - ref.receiveResourceCreationCancelled() - ref.finalize() - default: - panic(fmt.Sprintf("unexpected error %T", e)) - } - case <-ctx.Done(): - return - } - } -} - -func (m *mockPod) create() { - podSpec := k8sV1.Pod{ObjectMeta: metaV1.ObjectMeta{ +func (m *mockJob) create() { + jobSpec := batchV1.Job{ObjectMeta: metaV1.ObjectMeta{ Name: m.name, Namespace: "default", }} @@ -91,20 +63,20 @@ func (m *mockPod) create() { Name: m.name, Namespace: "default", }} - m.requestQueue.createKubernetesResources(&podSpec, &cmSpec) + m.requestQueue.createKubernetesResources(&jobSpec, &cmSpec) } -func (m *mockPod) delete() { - m.requestQueue.deleteKubernetesResources("default", m.name, m.name) +func (m *mockJob) delete() { + m.requestQueue.deleteKubernetesResources("default", m.name, m.name, "") } -func getNumberOfActivePods(podInterface typedV1.PodInterface) int { - podList, err := podInterface.List(context.TODO(), metaV1.ListOptions{}) +func getNumberOfActiveJobs(jobInterface typedBatchV1.JobInterface) int { + jobList, err := jobInterface.List(context.TODO(), metaV1.ListOptions{}) if err != nil { panic(err) } - return len(podList.Items) + return len(jobList.Items) } func requestQueueIsDone(r *requestQueue) bool { @@ -125,18 +97,20 @@ func waitForPendingRequestToFinish(k8RequestQueue *requestQueue) { time.Sleep(time.Second) } -func deleteAll(pods []*mockPod) { +func 
deleteAll(pods []*mockJob) { for _, p := range pods { p.delete() } } func TestRequestQueueCreatingManyPod(t *testing.T) { + jobInterface := &mockJobInterface{jobs: make(map[string]*batchV1.Job)} podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} failures := make(chan resourcesRequestFailure, 64) k8sRequestQueue := startRequestQueue( + map[string]typedBatchV1.JobInterface{"default": jobInterface}, map[string]typedV1.PodInterface{"default": podInterface}, map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, failures, @@ -148,19 +122,21 @@ func TestRequestQueueCreatingManyPod(t *testing.T) { numPods := 15 for i := 0; i < numPods; i++ { - startMockPod(k8sRequestQueue) + startMockJob(k8sRequestQueue) } waitForPendingRequestToFinish(k8sRequestQueue) - assert.Equal(t, getNumberOfActivePods(podInterface), numPods) + assert.Equal(t, getNumberOfActiveJobs(jobInterface), numPods) } func TestRequestQueueCreatingAndDeletingManyPod(t *testing.T) { + jobInterface := &mockJobInterface{jobs: make(map[string]*batchV1.Job)} podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} failures := make(chan resourcesRequestFailure, 64) k8sRequestQueue := startRequestQueue( + map[string]typedBatchV1.JobInterface{"default": jobInterface}, map[string]typedV1.PodInterface{"default": podInterface}, map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, failures, @@ -171,22 +147,24 @@ func TestRequestQueueCreatingAndDeletingManyPod(t *testing.T) { go runDefaultErrorHandler(ctx, failures) numPods := 15 - pods := make([]*mockPod, 0) + pods := make([]*mockJob, 0) for i := 0; i < numPods; i++ { - pods = append(pods, startMockPod(k8sRequestQueue)) + pods = append(pods, startMockJob(k8sRequestQueue)) } deleteAll(pods) waitForPendingRequestToFinish(k8sRequestQueue) - assert.Equal(t, getNumberOfActivePods(podInterface), 0) + assert.Equal(t, getNumberOfActiveJobs(jobInterface), 0) } func TestRequestQueueCreatingThenDeletingManyPods(t *testing.T) { + jobInterface := &mockJobInterface{jobs: make(map[string]*batchV1.Job)} podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} failures := make(chan resourcesRequestFailure, 64) k8sRequestQueue := startRequestQueue( + map[string]typedBatchV1.JobInterface{"default": jobInterface}, map[string]typedV1.PodInterface{"default": podInterface}, map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, failures, @@ -197,29 +175,31 @@ func TestRequestQueueCreatingThenDeletingManyPods(t *testing.T) { go runDefaultErrorHandler(ctx, failures) numPods := 15 - pods := make([]*mockPod, 0) + pods := make([]*mockJob, 0) for i := 0; i < numPods; i++ { - pods = append(pods, startMockPod(k8sRequestQueue)) + pods = append(pods, startMockJob(k8sRequestQueue)) } waitForPendingRequestToFinish(k8sRequestQueue) - assert.Equal(t, getNumberOfActivePods(podInterface), numPods) + assert.Equal(t, getNumberOfActiveJobs(jobInterface), numPods) deleteAll(pods) waitForPendingRequestToFinish(k8sRequestQueue) - assert.Equal(t, getNumberOfActivePods(podInterface), 0) + assert.Equal(t, getNumberOfActiveJobs(jobInterface), 0) } func TestRequestQueueCreatingAndDeletingManyPodWithDelay(t *testing.T) { - podInterface := &mockPodInterface{ - pods: 
make(map[string]*k8sV1.Pod), + jobInterface := &mockJobInterface{ + jobs: make(map[string]*batchV1.Job), operationalDelay: time.Millisecond * 500, } + podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod), operationalDelay: time.Millisecond * 500} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} failures := make(chan resourcesRequestFailure, 64) k8sRequestQueue := startRequestQueue( + map[string]typedBatchV1.JobInterface{"default": jobInterface}, map[string]typedV1.PodInterface{"default": podInterface}, map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, failures, @@ -230,32 +210,34 @@ func TestRequestQueueCreatingAndDeletingManyPodWithDelay(t *testing.T) { go runDefaultErrorHandler(ctx, failures) numPods := 15 - pods := make([]*mockPod, 0) + pods := make([]*mockJob, 0) for i := 0; i < numPods; i++ { - pods = append(pods, startMockPod(k8sRequestQueue)) + pods = append(pods, startMockJob(k8sRequestQueue)) } deleteAll(pods) waitForPendingRequestToFinish(k8sRequestQueue) - assert.Equal(t, getNumberOfActivePods(podInterface), 0) + assert.Equal(t, getNumberOfActiveJobs(jobInterface), 0) } func TestRequestQueueCreationCancelled(t *testing.T) { - podInterface := &mockPodInterface{ - pods: make(map[string]*k8sV1.Pod), + jobInterface := &mockJobInterface{ + jobs: make(map[string]*batchV1.Job), operationalDelay: time.Millisecond * 500, } + podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod), operationalDelay: time.Millisecond * 500} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} failures := make(chan resourcesRequestFailure, 64) k8sRequestQueue := startRequestQueue( + map[string]typedBatchV1.JobInterface{"default": jobInterface}, map[string]typedV1.PodInterface{"default": podInterface}, map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, failures, ) for i := 0; i < numKubernetesWorkers; i++ { - startMockPod(k8sRequestQueue) + startMockJob(k8sRequestQueue) } time.Sleep(time.Millisecond * 100) @@ -275,7 +257,7 @@ func TestRequestQueueCreationCancelled(t *testing.T) { } }() - pod := startMockPod(k8sRequestQueue) + pod := startMockJob(k8sRequestQueue) assert.Equal(t, createCancelled, false) pod.delete() wg.Wait() @@ -283,11 +265,13 @@ func TestRequestQueueCreationCancelled(t *testing.T) { } func TestRequestQueueCreationFailed(t *testing.T) { + jobInterface := &mockJobInterface{jobs: make(map[string]*batchV1.Job)} podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} failures := make(chan resourcesRequestFailure, 64) k8sRequestQueue := startRequestQueue( + map[string]typedBatchV1.JobInterface{"default": jobInterface}, map[string]typedV1.PodInterface{"default": podInterface}, map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, failures, @@ -309,7 +293,7 @@ func TestRequestQueueCreationFailed(t *testing.T) { } }() - pod := startMockPod(k8sRequestQueue) + pod := startMockJob(k8sRequestQueue) waitForPendingRequestToFinish(k8sRequestQueue) assert.Equal(t, createFailed, false) @@ -319,11 +303,13 @@ func TestRequestQueueCreationFailed(t *testing.T) { } func TestRequestQueueDeletionFailed(t *testing.T) { + jobInterface := &mockJobInterface{jobs: make(map[string]*batchV1.Job)} podInterface := &mockPodInterface{pods: make(map[string]*k8sV1.Pod)} configMapInterface := &mockConfigMapInterface{configMaps: make(map[string]*k8sV1.ConfigMap)} 
failures := make(chan resourcesRequestFailure, 64) k8sRequestQueue := startRequestQueue( + map[string]typedBatchV1.JobInterface{"default": jobInterface}, map[string]typedV1.PodInterface{"default": podInterface}, map[string]typedV1.ConfigMapInterface{"default": configMapInterface}, failures, @@ -345,7 +331,7 @@ func TestRequestQueueDeletionFailed(t *testing.T) { } }() - pod := startMockPod(k8sRequestQueue) + pod := startMockJob(k8sRequestQueue) waitForPendingRequestToFinish(k8sRequestQueue) assert.Equal(t, deleteFailed, false) diff --git a/master/internal/rm/kubernetesrm/request_workers.go b/master/internal/rm/kubernetesrm/request_workers.go index c219d21fdd6..cd07c1fde38 100644 --- a/master/internal/rm/kubernetesrm/request_workers.go +++ b/master/internal/rm/kubernetesrm/request_workers.go @@ -5,13 +5,17 @@ import ( "fmt" "github.com/sirupsen/logrus" - + k8serrors "k8s.io/apimachinery/pkg/api/errors" metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" + batchV1 "k8s.io/client-go/kubernetes/typed/batch/v1" typedV1 "k8s.io/client-go/kubernetes/typed/core/v1" + + "github.com/determined-ai/determined/master/pkg/ptrs" ) type requestProcessingWorker struct { - podInterfaces map[string]typedV1.PodInterface + jobInterface map[string]batchV1.JobInterface + podInterface map[string]typedV1.PodInterface configMapInterfaces map[string]typedV1.ConfigMapInterface failures chan<- resourcesRequestFailure syslog *logrus.Entry @@ -20,16 +24,18 @@ type requestProcessingWorker struct { type readyCallbackFunc func(createRef requestID) func startRequestProcessingWorker( - podInterfaces map[string]typedV1.PodInterface, + jobInterface map[string]batchV1.JobInterface, + podInterface map[string]typedV1.PodInterface, configMapInterfaces map[string]typedV1.ConfigMapInterface, id string, in <-chan interface{}, ready readyCallbackFunc, failures chan<- resourcesRequestFailure, ) *requestProcessingWorker { - syslog := logrus.New().WithField("component", "kubernetesrm-worker").WithField("id", id) + syslog := logrus.WithField("component", "kubernetesrm-worker").WithField("id", id) r := &requestProcessingWorker{ - podInterfaces: podInterfaces, + jobInterface: jobInterface, + podInterface: podInterface, configMapInterfaces: configMapInterfaces, failures: failures, syslog: syslog, @@ -59,26 +65,26 @@ func (r *requestProcessingWorker) receive(in <-chan interface{}, ready readyCall func (r *requestProcessingWorker) receiveCreateKubernetesResources( msg createKubernetesResources, ) { - r.syslog.Debugf("creating configMap with spec %v", msg.configMapSpec) - configMap, err := r.configMapInterfaces[msg.podSpec.Namespace].Create( + r.syslog.Debugf("creating configMap %v", msg.configMapSpec.Name) + configMap, err := r.configMapInterfaces[msg.jobSpec.Namespace].Create( context.TODO(), msg.configMapSpec, metaV1.CreateOptions{}) if err != nil { r.syslog.WithError(err).Errorf("error creating configMap %s", msg.configMapSpec.Name) - r.failures <- resourceCreationFailed{podName: msg.podSpec.Name, err: err} + r.failures <- resourceCreationFailed{jobName: msg.jobSpec.Name, err: err} return } r.syslog.Infof("created configMap %s", configMap.Name) - r.syslog.Debugf("launching pod with spec %v", msg.podSpec) - pod, err := r.podInterfaces[msg.podSpec.Namespace].Create( - context.TODO(), msg.podSpec, metaV1.CreateOptions{}, + r.syslog.Debugf("creating job %s", msg.jobSpec.Name) + job, err := r.jobInterface[msg.jobSpec.Namespace].Create( + context.TODO(), msg.jobSpec, metaV1.CreateOptions{}, ) if err != nil { - 
r.syslog.WithError(err).Errorf("error creating pod %s", msg.podSpec.Name) - r.failures <- resourceCreationFailed{podName: msg.podSpec.Name, err: err} + r.syslog.WithError(err).Errorf("error creating job %s", msg.jobSpec.Name) + r.failures <- resourceCreationFailed{jobName: msg.jobSpec.Name, err: err} return } - r.syslog.Infof("created pod %s", pod.Name) + r.syslog.Infof("created job %s", job.Name) } func (r *requestProcessingWorker) receiveDeleteKubernetesResources( @@ -89,31 +95,51 @@ func (r *requestProcessingWorker) receiveDeleteKubernetesResources( // If resource creation failed, we will still try to delete those resources which // will also result in a failure. + if len(msg.jobName) > 0 { + err = r.jobInterface[msg.namespace].Delete(context.TODO(), msg.jobName, metaV1.DeleteOptions{ + GracePeriodSeconds: &gracePeriod, + PropagationPolicy: ptrs.Ptr(metaV1.DeletePropagationBackground), + }) + switch { + case k8serrors.IsNotFound(err): + r.syslog.Infof("job %s is already deleted", msg.jobName) + case err != nil: + r.syslog.WithError(err).Errorf("failed to delete job %s", msg.jobName) + default: + r.syslog.Infof("deleted job %s", msg.jobName) + } + } + if len(msg.podName) > 0 { - err = r.podInterfaces[msg.namespace].Delete( + err = r.podInterface[msg.namespace].Delete( context.TODO(), msg.podName, metaV1.DeleteOptions{GracePeriodSeconds: &gracePeriod}) - if err != nil { - r.syslog.WithError(err).Errorf("failed to delete pod %s", msg.podName) - } else { - r.syslog.Infof("deleted pod %s", msg.podName) + switch { + case k8serrors.IsNotFound(err): + r.syslog.Infof("pod %s is already deleted", msg.jobName) + case err != nil: + r.syslog.WithError(err).Errorf("failed to delete pod %s", msg.jobName) + default: + r.syslog.Infof("deleted pod %s", msg.jobName) } } if len(msg.configMapName) > 0 { - errDeletingConfigMap := r.configMapInterfaces[msg.namespace].Delete( + err = r.configMapInterfaces[msg.namespace].Delete( context.TODO(), msg.configMapName, metaV1.DeleteOptions{GracePeriodSeconds: &gracePeriod}) - if errDeletingConfigMap != nil { - r.syslog.WithError(err).Errorf("failed to delete configMap %s", msg.configMapName) - err = errDeletingConfigMap - } else { - r.syslog.Infof("deleted configMap %s", msg.configMapName) + switch { + case k8serrors.IsNotFound(err): + r.syslog.Infof("configMap %s is already deleted", msg.jobName) + case err != nil: + r.syslog.WithError(err).Errorf("failed to delete configMap %s", msg.jobName) + default: + r.syslog.Infof("deleted configMap %s", msg.jobName) } } // It is possible that the creator of the message is no longer around. // However this should have no impact on correctness. 
if err != nil { - r.failures <- resourceDeletionFailed{podName: msg.podName, err: err} + r.failures <- resourceDeletionFailed{jobName: msg.jobName, err: err} } } diff --git a/master/internal/rm/kubernetesrm/resource_pool.go b/master/internal/rm/kubernetesrm/resource_pool.go index eec0841e2ed..880841ddfa5 100644 --- a/master/internal/rm/kubernetesrm/resource_pool.go +++ b/master/internal/rm/kubernetesrm/resource_pool.go @@ -17,7 +17,6 @@ import ( "github.com/determined-ai/determined/master/internal/rm/tasklist" "github.com/determined-ai/determined/master/internal/sproto" "github.com/determined-ai/determined/master/pkg/aproto" - "github.com/determined-ai/determined/master/pkg/cproto" "github.com/determined-ai/determined/master/pkg/device" "github.com/determined-ai/determined/master/pkg/logger" "github.com/determined-ai/determined/master/pkg/model" @@ -29,27 +28,21 @@ import ( const resourcePoolEnvVar = "DET_K8S_RESOURCE_POOL" -// getResourceSummary is a message to request a summary of the resources used by the -// resource pool (agents, slots, cpu containers). -type getResourceSummary struct{} - type kubernetesResourcePool struct { mu sync.Mutex maxSlotsPerPod int poolConfig *config.ResourcePoolConfig - reqList *tasklist.TaskList - groups map[model.JobID]*tasklist.Group - allocationIDToContainerID map[model.AllocationID]cproto.ID - containerIDtoAllocationID map[string]model.AllocationID + reqList *tasklist.TaskList + groups map[model.JobID]*tasklist.Group // TODO(DET-9613): Jobs have many allocs. jobIDToAllocationID map[model.JobID]model.AllocationID allocationIDToJobID map[model.AllocationID]model.JobID slotsUsedPerGroup map[*tasklist.Group]int allocationIDToRunningPods map[model.AllocationID]int - podsService *pods + jobsService *jobsService queuePositions tasklist.JobSortState reschedule bool @@ -62,7 +55,7 @@ type kubernetesResourcePool struct { func newResourcePool( maxSlotsPerPod int, poolConfig *config.ResourcePoolConfig, - podsService *pods, + jobsService *jobsService, db *db.PgDB, ) *kubernetesResourcePool { return &kubernetesResourcePool{ @@ -70,13 +63,11 @@ func newResourcePool( poolConfig: poolConfig, reqList: tasklist.New(), groups: map[model.JobID]*tasklist.Group{}, - allocationIDToContainerID: map[model.AllocationID]cproto.ID{}, - containerIDtoAllocationID: map[string]model.AllocationID{}, jobIDToAllocationID: map[model.JobID]model.AllocationID{}, allocationIDToJobID: map[model.AllocationID]model.JobID{}, slotsUsedPerGroup: map[*tasklist.Group]int{}, allocationIDToRunningPods: map[model.AllocationID]int{}, - podsService: podsService, + jobsService: jobsService, queuePositions: tasklist.InitializeJobSortState(true), db: db, syslog: logrus.WithField("component", "k8s-rp"), @@ -107,22 +98,17 @@ func (k *kubernetesResourcePool) ResourcesReleased(msg sproto.ResourcesReleased) k.resourcesReleased(msg) } -func (k *kubernetesResourcePool) UpdatePodStatus(msg sproto.UpdatePodStatus) { +func (k *kubernetesResourcePool) JobSchedulingStateChanged(msg jobSchedulingStateChanged) { k.mu.Lock() defer k.mu.Unlock() k.reschedule = true - id, ok := k.containerIDtoAllocationID[msg.ContainerID] - if !ok { - return - } - for it := k.reqList.Iterator(); it.Next(); { req := it.Value() - if req.AllocationID == id { + if req.AllocationID == msg.AllocationID { req.State = msg.State if sproto.ScheduledStates[req.State] { - k.allocationIDToRunningPods[id]++ + k.allocationIDToRunningPods[msg.AllocationID] += msg.NumPods } } } @@ -196,10 +182,7 @@ func 
(k *kubernetesResourcePool) SetGroupPriority(msg sproto.SetGroupPriority) e for it := k.reqList.Iterator(); it.Next(); { if it.Value().JobID == msg.JobID { req := it.Value() - if id, ok := k.allocationIDToContainerID[req.AllocationID]; ok { - k.podsService.ChangePriority(id) - delete(k.allocationIDToContainerID, req.AllocationID) - } + k.jobsService.ChangePriority(req.AllocationID) } } return nil @@ -237,7 +220,7 @@ func (k *kubernetesResourcePool) GetAllocationSummaries() map[model.AllocationID return k.reqList.TaskSummaries(k.groups, kubernetesScheduler) } -func (k *kubernetesResourcePool) getResourceSummary(msg getResourceSummary) (*resourceSummary, error) { +func (k *kubernetesResourcePool) getResourceSummary() (*resourceSummary, error) { k.mu.Lock() defer k.mu.Unlock() k.reschedule = true @@ -252,8 +235,8 @@ func (k *kubernetesResourcePool) getResourceSummary(msg getResourceSummary) (*re } return &resourceSummary{ - numAgents: pods.NumAgents, - numTotalSlots: pods.SlotsAvailable, + numAgents: pods.numAgentsUsed, + numTotalSlots: pods.slotsAvailable, numActiveSlots: slotsUsed, maxNumAuxContainers: 1, numActiveAuxContainers: 0, @@ -282,8 +265,8 @@ func (k *kubernetesResourcePool) Schedule() { k.reschedule = false } -func (k *kubernetesResourcePool) summarizePods() (*PodsInfo, error) { - resp, err := k.podsService.SummarizeResources(SummarizeResources{PoolName: k.poolConfig.PoolName}) +func (k *kubernetesResourcePool) summarizePods() (*computeUsageSummary, error) { + resp, err := k.jobsService.SummarizeResources(k.poolConfig.PoolName) if err != nil { return nil, err } @@ -406,13 +389,8 @@ func (k *kubernetesResourcePool) moveJob( if !ok { return fmt.Errorf("job with ID %s has no valid task address", jobID) } - containerID, ok := k.allocationIDToContainerID[allocationID] - if !ok { - return fmt.Errorf("job with ID %s has no valid containerID", jobID) - } - - k.podsService.ChangePosition(containerID) + k.jobsService.ChangePosition(allocationID) return nil } @@ -456,10 +434,7 @@ func (k *kubernetesResourcePool) assignResources( return } - if req.SlotsNeeded <= k.maxSlotsPerPod { - numPods = 1 - slotsPerPod = req.SlotsNeeded - } else { + if req.SlotsNeeded > k.maxSlotsPerPod { if req.SlotsNeeded%k.maxSlotsPerPod != 0 { k.syslog.WithField("allocation-id", req.AllocationID).Errorf( "task number of slots (%d) is not schedulable on the configured "+ @@ -474,13 +449,12 @@ func (k *kubernetesResourcePool) assignResources( group := k.groups[req.JobID] if group == nil { - k.syslog.WithField("allocation-id", req.AllocationID).Errorf( - "cannot find group for job %s", req.JobID) + k.syslog.WithField("allocation-id", req.AllocationID).Errorf("cannot find group for job %s", req.JobID) return } k.slotsUsedPerGroup[group] += req.SlotsNeeded - var resources []*k8sPodResources + var resources *k8sJobResource if req.Restore { var err error resources, err = k.restoreResources(req, slotsPerPod, numPods) @@ -489,7 +463,7 @@ func (k *kubernetesResourcePool) assignResources( WithField("allocation-id", req.AllocationID). 
WithError(err).Error("unable to restore allocation") unknownExit := sproto.ExitCode(-1) - rmevents.Publish(req.AllocationID, &sproto.ResourcesRestoreError{ + rmevents.Publish(req.AllocationID, &sproto.ResourcesFailedError{ FailureType: sproto.ResourcesMissing, ErrMsg: errors.Wrap(err, "unable to restore allocation").Error(), ExitCode: &unknownExit, @@ -501,16 +475,13 @@ func (k *kubernetesResourcePool) assignResources( } allocations := sproto.ResourceList{} - for _, rs := range resources { - allocations[rs.Summary().ResourcesID] = rs - k.allocationIDToContainerID[req.AllocationID] = rs.containerID - k.containerIDtoAllocationID[rs.containerID.String()] = req.AllocationID - } + allocations[resources.Summary().ResourcesID] = resources assigned := sproto.ResourcesAllocated{ ID: req.AllocationID, Resources: allocations, JobSubmissionTime: req.JobSubmissionTime, + Recovered: req.Restore, } k.reqList.AddAllocationRaw(req.AllocationID, &assigned) rmevents.Publish(req.AllocationID, assigned.Clone()) @@ -519,19 +490,21 @@ func (k *kubernetesResourcePool) assignResources( k.syslog. WithField("allocation-id", req.AllocationID). WithField("task-handler", req.Name). - Infof("resources restored with %d pods", numPods) + WithField("num-pods", numPods). + Infof("restored kubernetes job") } else { k.syslog. WithField("allocation-id", req.AllocationID). WithField("task-handler", req.Name). - Infof("resources assigned with %d pods", numPods) + WithField("num-pods", numPods). + Infof("admitting kubernetes job") } if req.Restore { // This call must happen after we publish ResourcesAllocated, otherwise the allocation will // receive an update for resources it does not know about, ignore it, then hang if it missed // the termination. - err := k.podsService.RefreshPodStates(refreshPodStates{allocationID: req.AllocationID}) + err := k.jobsService.RefreshStates(req.AllocationID) if err != nil { k.syslog.WithError(err).Error("failed to refresh pod states after reattach") } @@ -540,26 +513,22 @@ func (k *kubernetesResourcePool) assignResources( func (k *kubernetesResourcePool) createResources( req *sproto.AllocateRequest, slotsPerPod, numPods int, -) []*k8sPodResources { - var resources []*k8sPodResources - for pod := 0; pod < numPods; pod++ { - resources = append(resources, &k8sPodResources{ - req: req, - podsService: k.podsService, - containerID: cproto.NewID(), - slots: slotsPerPod, - group: k.groups[req.JobID], - initialPosition: k.queuePositions[k.allocationIDToJobID[req.AllocationID]], - namespace: k.poolConfig.KubernetesNamespace, - }) - } - return resources +) *k8sJobResource { + return &k8sJobResource{ + numPods: numPods, + req: req, + jobsService: k.jobsService, + slots: slotsPerPod, + group: k.groups[req.JobID], + initialPosition: k.queuePositions[k.allocationIDToJobID[req.AllocationID]], + namespace: k.poolConfig.KubernetesNamespace, + } } func (k *kubernetesResourcePool) restoreResources( req *sproto.AllocateRequest, slotsPerPod, numPods int, -) ([]*k8sPodResources, error) { - restoreResponses, err := k.podsService.ReattachAllocationPods(reattachAllocationPods{ +) (*k8sJobResource, error) { + restored, err := k.jobsService.ReattachJob(reattachJobRequest{ req: req, allocationID: req.AllocationID, numPods: numPods, @@ -570,22 +539,16 @@ func (k *kubernetesResourcePool) restoreResources( return nil, err } - var resources []*k8sPodResources - for _, restoreResponse := range restoreResponses { - resources = append(resources, &k8sPodResources{ - req: req, - podsService: k.podsService, - containerID: 
cproto.ID(restoreResponse.containerID), - slots: slotsPerPod, - group: k.groups[req.JobID], - initialPosition: k.queuePositions[k.allocationIDToJobID[req.AllocationID]], - namespace: k.poolConfig.KubernetesNamespace, - - started: restoreResponse.started, - }) - } + return &k8sJobResource{ + req: req, + jobsService: k.jobsService, + slots: slotsPerPod, + group: k.groups[req.JobID], + initialPosition: k.queuePositions[k.allocationIDToJobID[req.AllocationID]], + namespace: k.poolConfig.KubernetesNamespace, - return resources, nil + started: restored.started, + }, nil } func (k *kubernetesResourcePool) resourcesReleased( @@ -609,15 +572,8 @@ func (k *kubernetesResourcePool) resourcesReleased( } k.reqList.RemoveTaskByID(msg.AllocationID) - delete(k.allocationIDToContainerID, msg.AllocationID) delete(k.allocationIDToRunningPods, msg.AllocationID) - for id, addr := range k.containerIDtoAllocationID { - if addr == msg.AllocationID { - delete(k.containerIDtoAllocationID, id) - break - } - } rmevents.Publish(msg.AllocationID, sproto.ResourcesReleasedEvent{}) } @@ -656,12 +612,12 @@ func (k *kubernetesResourcePool) schedulePendingTasks() { } } -type k8sPodResources struct { +type k8sJobResource struct { req *sproto.AllocateRequest - podsService *pods + jobsService *jobsService group *tasklist.Group - containerID cproto.ID slots int + numPods int initialPosition decimal.Decimal namespace string @@ -669,52 +625,58 @@ type k8sPodResources struct { } // Summary summarizes a container allocation. -func (p k8sPodResources) Summary() sproto.ResourcesSummary { +func (p k8sJobResource) Summary() sproto.ResourcesSummary { return sproto.ResourcesSummary{ AllocationID: p.req.AllocationID, - ResourcesID: sproto.ResourcesID(p.containerID), - ResourcesType: sproto.ResourcesTypeK8sPod, + ResourcesID: sproto.ResourcesID(p.req.AllocationID), + ResourcesType: sproto.ResourcesTypeK8sJob, AgentDevices: map[aproto.ID][]device.Device{ // TODO: Make it more obvious k8s can't be trusted. aproto.ID("pods"): make([]device.Device, p.slots), }, - ContainerID: &p.containerID, - Started: p.started, + Started: p.started, } } // Start notifies the pods actor that it should launch a pod for the provided task spec. 
-func (p k8sPodResources) Start( +func (p k8sJobResource) Start( logCtx logger.Context, spec tasks.TaskSpec, rri sproto.ResourcesRuntimeInfo, ) error { p.setPosition(&spec) - spec.ContainerID = string(p.containerID) - spec.ResourcesID = string(p.containerID) + spec.ContainerID = string(p.req.AllocationID) + spec.ResourcesID = string(p.req.AllocationID) spec.AllocationID = string(p.req.AllocationID) spec.AllocationSessionToken = rri.Token spec.TaskID = string(p.req.TaskID) spec.UseHostMode = rri.IsMultiAgent spec.ResourcesConfig.SetPriority(p.group.Priority) + if spec.LoggingFields == nil { spec.LoggingFields = map[string]string{} } spec.LoggingFields["allocation_id"] = spec.AllocationID spec.LoggingFields["task_id"] = spec.TaskID - spec.ExtraEnvVars[sproto.ResourcesTypeEnvVar] = string(sproto.ResourcesTypeK8sPod) + + if spec.ExtraEnvVars == nil { + spec.ExtraEnvVars = map[string]string{} + } + spec.ExtraEnvVars[sproto.ResourcesTypeEnvVar] = string(sproto.ResourcesTypeK8sJob) spec.ExtraEnvVars[resourcePoolEnvVar] = p.req.ResourcePool - return p.podsService.StartTaskPod(StartTaskPod{ - Req: p.req, - AllocationID: p.req.AllocationID, - Spec: spec, - Slots: p.slots, - Rank: rri.AgentRank, - Namespace: p.namespace, - LogContext: logCtx, + + return p.jobsService.StartJob(startJob{ + req: p.req, + allocationID: p.req.AllocationID, + spec: spec, + slots: p.slots, + rank: rri.AgentRank, + namespace: p.namespace, + numPods: p.numPods, + logContext: logCtx, }) } -func (p k8sPodResources) setPosition(spec *tasks.TaskSpec) { +func (p k8sJobResource) setPosition(spec *tasks.TaskSpec) { newSpec := spec.Environment.PodSpec() if newSpec == nil { newSpec = &expconf.PodSpec{} @@ -727,11 +689,11 @@ func (p k8sPodResources) setPosition(spec *tasks.TaskSpec) { } // Kill notifies the pods actor that it should stop the pod. 
-func (p k8sPodResources) Kill(_ logger.Context) { - p.podsService.KillPod(p.containerID) +func (p k8sJobResource) Kill(_ logger.Context) { + p.jobsService.KillJob(p.req.AllocationID) } -func (p k8sPodResources) Persist() error { +func (p k8sJobResource) Persist() error { return nil } diff --git a/master/internal/rm/kubernetesrm/resource_pool_intg_test.go b/master/internal/rm/kubernetesrm/resource_pool_intg_test.go index 9268d5e0ed2..602dc51f77f 100644 --- a/master/internal/rm/kubernetesrm/resource_pool_intg_test.go +++ b/master/internal/rm/kubernetesrm/resource_pool_intg_test.go @@ -2,50 +2,903 @@ package kubernetesrm import ( "context" + "fmt" + "runtime/debug" + "strings" "testing" + "time" "github.com/google/uuid" - "github.com/sirupsen/logrus" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - k8sClient "k8s.io/client-go/kubernetes" + k8sV1 "k8s.io/api/core/v1" + metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" "github.com/determined-ai/determined/master/internal/config" "github.com/determined-ai/determined/master/internal/db" "github.com/determined-ai/determined/master/internal/rm/rmerrors" + "github.com/determined-ai/determined/master/internal/rm/rmevents" + "github.com/determined-ai/determined/master/internal/rm/tasklist" "github.com/determined-ai/determined/master/internal/sproto" + "github.com/determined-ai/determined/master/pkg/device" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/ptrs" - "github.com/determined-ai/determined/master/pkg/set" - "github.com/determined-ai/determined/master/pkg/syncx/waitgroupx" + "github.com/determined-ai/determined/master/pkg/schemas" + "github.com/determined-ai/determined/master/pkg/schemas/expconf" + "github.com/determined-ai/determined/master/pkg/tasks" + "github.com/determined-ai/determined/proto/pkg/apiv1" ) -var ( - defaultState = sproto.SchedulingStateQueued - defaultSlots = 3 +const ( + defaultResourcePool = "default" ) -func TestAllocateAndRelease(t *testing.T) { - rp := testResourcePool(t, defaultSlots) +type testLaunchOpts struct { + name string + image string + entrypoint []string + aug model.AgentUserGroup + extraEnvVars map[string]string + slots int + wantFailure *sproto.ResourcesFailedError +} + +func TestJobWorkflows(t *testing.T) { + testCases := []testLaunchOpts{ + { + name: "single successful pod", + entrypoint: []string{"/bin/bash", "-c", "exit 0"}, + slots: 1, + wantFailure: nil, + }, + { + name: "extra env vars", + entrypoint: []string{"/bin/bash", "-c", "exit $DET_EXTRA_VAR"}, + extraEnvVars: map[string]string{"DET_EXTRA_VAR": "15"}, + slots: 1, + wantFailure: &sproto.ResourcesFailedError{ + FailureType: sproto.ResourcesFailed, + ExitCode: (*sproto.ExitCode)(ptrs.Ptr(15)), + }, + }, + { + name: "missing container image", + image: "lieblos/notanimageipushed", + entrypoint: []string{"/bin/bash", "-c", "exit 0"}, + slots: 1, + wantFailure: &sproto.ResourcesFailedError{ + FailureType: sproto.ResourcesFailed, + ErrMsg: "unrecoverable image pull errors", + }, + }, + { + name: "single unsuccessful pod", + entrypoint: []string{"/bin/bash", "-c", "exit 1"}, + slots: 1, + wantFailure: &sproto.ResourcesFailedError{ + FailureType: sproto.ResourcesFailed, + ExitCode: (*sproto.ExitCode)(ptrs.Ptr(1)), + }, + }, + { + name: "multiple successful pods", + entrypoint: 
[]string{"/bin/bash", "-c", "exit 0"}, + slots: 2, + wantFailure: nil, + }, + { + name: "invalid job submission", + entrypoint: []string{"exit 0"}, + aug: model.AgentUserGroup{ + UID: -1, + GID: -1, + }, + wantFailure: &sproto.ResourcesFailedError{ + FailureType: sproto.TaskError, + ErrMsg: "job crashed", + }, + }, + { + name: "non-root users", + entrypoint: []string{"/bin/bash", "-c", "exit $(id -u)"}, + aug: model.AgentUserGroup{ + UID: 123, + GID: 123, + }, + wantFailure: &sproto.ResourcesFailedError{ + FailureType: sproto.ResourcesFailed, + ExitCode: (*sproto.ExitCode)(ptrs.Ptr(123)), + }, + }, + { + name: "long job", // Long enough to see all transitions. + entrypoint: []string{"/bin/bash", "-c", "sleep 10"}, + wantFailure: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testLaunch(t, tc) + }) + } +} + +func testLaunch( + t *testing.T, + opts testLaunchOpts, +) { + if opts.image == "" { + opts.image = "ubuntu:latest" + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + id := uuid.NewString() + jobID, taskID, allocationID := model.JobID(id), model.TaskID(id), model.AllocationID(id) + startTime := time.Now() + + err := tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + sub := rmevents.Subscribe(allocationID) + rp.AllocateRequest(sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + SlotsNeeded: opts.slots, + ResourcePool: "default", + }) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + allocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.NotNil(t, allocated) + require.Len(t, allocated.Resources, 1) + + for _, res := range allocated.Resources { + conf := expconf.ExperimentConfig{ //nolint:exhaustruct + RawEnvironment: &expconf.EnvironmentConfigV0{ //nolint:exhaustruct + RawImage: &expconf.EnvironmentImageMapV0{ //nolint:exhaustruct + RawCPU: &opts.image, + }, + }, + } + conf = schemas.WithDefaults(conf) + + err := res.Start(nil, tasks.TaskSpec{ + Description: fmt.Sprintf("test-job-%s", uuid.NewString()[:8]), + Entrypoint: opts.entrypoint, + AgentUserGroup: &opts.aug, + Environment: conf.Environment(), + ResourcesConfig: conf.Resources(), + DontShipLogs: true, + ExtraEnvVars: opts.extraEnvVars, + }, sproto.ResourcesRuntimeInfo{}) + defer res.Kill(nil) + require.NoError(t, err) + } + + // Be careful to allow missing state changes here since the jobs are very short. It's + // all good as long as we don't go backwards and end terminated. 
+ var stop *sproto.ResourcesStopped + for state := sproto.Assigned; state != sproto.Terminated; { + change := poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.True(t, state.BeforeOrEqual(change.ResourcesState)) + state = change.ResourcesState + stop = change.ResourcesStopped + } + + require.NotNil(t, stop) + if opts.wantFailure == nil { + require.Nil(t, stop.Failure) + return + } + require.NotNil(t, stop.Failure) + assert.Equal(t, opts.wantFailure.FailureType, stop.Failure.FailureType) + if opts.wantFailure.ExitCode != nil { + assert.NotNil(t, stop.Failure.ExitCode) + if stop.Failure.ExitCode != nil { + assert.Equal(t, *opts.wantFailure.ExitCode, *stop.Failure.ExitCode) + } + } else { + assert.Nil(t, stop.Failure.ExitCode) + } + assert.Contains(t, stop.Failure.ErrMsg, opts.wantFailure.ErrMsg) +} + +func TestPodLogStreamerReattach(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + user := db.RequireMockUser(t, db.SingleDB()) + task := db.RequireMockTask(t, db.SingleDB(), &user.ID) + alloc := db.RequireMockAllocation(t, db.SingleDB(), task.TaskID) + allocationID, taskID, jobID := alloc.AllocationID, task.TaskID, *task.JobID + startTime := task.StartTime + + err := tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + sub := rmevents.Subscribe(allocationID) + allocateReq := sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + SlotsNeeded: 1, + ResourcePool: "default", + } + rp.AllocateRequest(allocateReq) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + allocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.NotNil(t, allocated) + require.Len(t, allocated.Resources, 1) + + secret := uuid.NewString() + for _, res := range allocated.Resources { + conf := expconf.ExperimentConfig{ //nolint:exhaustruct + RawEnvironment: &expconf.EnvironmentConfigV0{ //nolint:exhaustruct + RawImage: &expconf.EnvironmentImageMapV0{ //nolint:exhaustruct + RawCPU: ptrs.Ptr("ubuntu:latest"), + }, + }, + } + conf = schemas.WithDefaults(conf) + + err := res.Start(nil, tasks.TaskSpec{ + Description: fmt.Sprintf("test-job-%s", uuid.NewString()[:8]), + Entrypoint: []string{"/bin/bash", "-c", fmt.Sprintf("sleep 15 && echo %s", secret)}, + AgentUserGroup: &model.AgentUserGroup{}, + Environment: conf.Environment(), + ResourcesConfig: conf.Resources(), + DontShipLogs: true, + }, sproto.ResourcesRuntimeInfo{}) + defer res.Kill(nil) + require.NoError(t, err) + } + + change := poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Pulling, change.ResourcesState) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Starting, change.ResourcesState) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Running, change.ResourcesState) + require.NotNil(t, change.ResourcesStarted) + + // Remake all component and "reattach" to this new resource pool. This saves + // us from needing to made the k8s code do graceful shutdown, but we should + // do it anyway someday. 
+ rp = newTestResourcePool(newTestJobsService(t)) + + sub = rmevents.Subscribe(allocationID) + allocateReq.Restore = true + rp.AllocateRequest(allocateReq) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + reallocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.True(t, reallocated.Recovered) + require.Len(t, reallocated.Resources, 1) + + seen := 0 // HACK: Because we don't have graceful shutdown, we have two log streamers up and get two events. + for { + log := poll[*sproto.ContainerLog](ctx, t, sub) + if strings.Contains(log.Message(), secret) { + t.Logf("saw one log: %s", log.Message()) + seen++ + } + if seen == 2 { + break + } + } +} + +func TestPodLogStreamer(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + id := uuid.NewString() + jobID, taskID, allocationID := model.JobID(id), model.TaskID(id), model.AllocationID(id) + startTime := time.Now() + + err := tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + sub := rmevents.Subscribe(allocationID) + rp.AllocateRequest(sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + SlotsNeeded: 1, + ResourcePool: "default", + }) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + allocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.NotNil(t, allocated) + require.Len(t, allocated.Resources, 1) + + require.Len(t, allocated.Resources, 1) + secret := uuid.NewString() + for _, res := range allocated.Resources { + conf := expconf.ExperimentConfig{ //nolint:exhaustruct + RawEnvironment: &expconf.EnvironmentConfigV0{ //nolint:exhaustruct + RawImage: &expconf.EnvironmentImageMapV0{ //nolint:exhaustruct + RawCPU: ptrs.Ptr("ubuntu:latest"), + }, + }, + } + conf = schemas.WithDefaults(conf) + + err := res.Start(nil, tasks.TaskSpec{ + Description: fmt.Sprintf("test-job-%s", uuid.NewString()[:8]), + Entrypoint: []string{"/bin/bash", "-c", fmt.Sprintf("sleep 10 && echo %s", secret)}, + AgentUserGroup: &model.AgentUserGroup{}, + Environment: conf.Environment(), + ResourcesConfig: conf.Resources(), + DontShipLogs: true, + }, sproto.ResourcesRuntimeInfo{}) + defer res.Kill(nil) + require.NoError(t, err) + } + + for { + log := poll[*sproto.ContainerLog](ctx, t, sub) + if strings.Contains(log.Message(), secret) { + return + } + } +} + +func TestKill(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + id := uuid.NewString() + jobID, taskID, allocationID := model.JobID(id), model.TaskID(id), model.AllocationID(id) + startTime := time.Now() + + err := tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + sub := rmevents.Subscribe(allocationID) + rp.AllocateRequest(sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + 
SlotsNeeded: 1, + ResourcePool: "default", + }) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + allocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.NotNil(t, allocated) + require.Len(t, allocated.Resources, 1) + + for _, res := range allocated.Resources { + conf := expconf.ExperimentConfig{ //nolint:exhaustruct + RawEnvironment: &expconf.EnvironmentConfigV0{ //nolint:exhaustruct + RawImage: &expconf.EnvironmentImageMapV0{ //nolint:exhaustruct + RawCPU: ptrs.Ptr("ubuntu:latest"), + }, + }, + } + conf = schemas.WithDefaults(conf) + + err := res.Start(nil, tasks.TaskSpec{ + Description: fmt.Sprintf("test-job-%s", uuid.NewString()[:8]), + Entrypoint: []string{"sleep", "99999"}, + AgentUserGroup: &model.AgentUserGroup{}, + Environment: conf.Environment(), + ResourcesConfig: conf.Resources(), + DontShipLogs: true, + }, sproto.ResourcesRuntimeInfo{}) + defer res.Kill(nil) + require.NoError(t, err) + } + + change := poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Pulling, change.ResourcesState) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Starting, change.ResourcesState) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Running, change.ResourcesState) + require.NotNil(t, change.ResourcesStarted) + + for _, res := range allocated.Resources { + res.Kill(nil) + } + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Terminated, change.ResourcesState) + require.NotNil(t, change.ResourcesStopped) + require.NotNil(t, change.ResourcesStopped.Failure) + require.Contains(t, change.ResourcesStopped.Failure.ErrMsg, "kill") +} + +func TestExternalKillWhileQueuedFails(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + id := uuid.NewString() + jobID, taskID, allocationID := model.JobID(id), model.TaskID(id), model.AllocationID(id) + startTime := time.Now() + + err := tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + sub := rmevents.Subscribe(allocationID) + rp.AllocateRequest(sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + SlotsNeeded: 1, + ResourcePool: "default", + }) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + allocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.NotNil(t, allocated) + require.Len(t, allocated.Resources, 1) + + for _, res := range allocated.Resources { + conf := expconf.ExperimentConfig{ //nolint:exhaustruct + RawEnvironment: &expconf.EnvironmentConfigV0{ //nolint:exhaustruct + RawImage: &expconf.EnvironmentImageMapV0{ //nolint:exhaustruct + RawCPU: ptrs.Ptr("ubuntu:latest"), + }, + RawPodSpec: &expconf.PodSpec{ + // Make them unschedulable. 
+ Spec: k8sV1.PodSpec{NodeSelector: map[string]string{"non-existent": uuid.NewString()}}, + }, + }, + } + conf = schemas.WithDefaults(conf) + + err := res.Start(nil, tasks.TaskSpec{ + Description: fmt.Sprintf("test-job-%s", uuid.NewString()[:8]), + Entrypoint: []string{"sleep", "99999"}, + AgentUserGroup: &model.AgentUserGroup{}, + Environment: conf.Environment(), + ResourcesConfig: conf.Resources(), + DontShipLogs: true, + }, sproto.ResourcesRuntimeInfo{}) + defer res.Kill(nil) + require.NoError(t, err) + } + + ctxWaitForStarting, cancelWaitForStarting := context.WithTimeout(ctx, 5*time.Second) + defer cancelWaitForStarting() + for { + ev, err := sub.GetWithContext(ctxWaitForStarting) + if err != nil && errors.Is(err, context.DeadlineExceeded) { + break + } else if err != nil { + t.Error(err) + t.FailNow() + } + + _, ok := ev.(*sproto.ResourcesStateChanged) + if ok { + t.Error("job should've stayed queued") + t.FailNow() + continue + } + } + + podListOpts := metaV1.ListOptions{ + LabelSelector: fmt.Sprintf("%s=%s", allocationIDLabel, string(allocationID)), + } + pods, err := j.clientSet.CoreV1().Pods("default").List(ctx, podListOpts) + require.NoError(t, err) + + require.Len(t, pods.Items, 1) + pod := pods.Items[0] + err = j.clientSet.CoreV1().Pods("default").Delete(ctx, pod.Name, metaV1.DeleteOptions{}) + require.NoError(t, err) + + var stop *sproto.ResourcesStopped + for state := sproto.Assigned; state != sproto.Terminated; { + change := poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.True(t, state.BeforeOrEqual(change.ResourcesState)) + state = change.ResourcesState + stop = change.ResourcesStopped + } + require.NotNil(t, stop.Failure) + require.Contains(t, stop.Failure.ErrMsg, "deleted pod") +} + +func TestExternalPodDelete(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + id := uuid.NewString() + jobID, taskID, allocationID := model.JobID(id), model.TaskID(id), model.AllocationID(id) + startTime := time.Now() + + err := tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + sub := rmevents.Subscribe(allocationID) + rp.AllocateRequest(sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + SlotsNeeded: 1, + ResourcePool: "default", + }) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + allocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.NotNil(t, allocated) + require.Len(t, allocated.Resources, 1) + + for _, res := range allocated.Resources { + conf := expconf.ExperimentConfig{ //nolint:exhaustruct + RawEnvironment: &expconf.EnvironmentConfigV0{ //nolint:exhaustruct + RawImage: &expconf.EnvironmentImageMapV0{ //nolint:exhaustruct + RawCPU: ptrs.Ptr("ubuntu:latest"), + }, + }, + } + conf = schemas.WithDefaults(conf) + + err := res.Start(nil, tasks.TaskSpec{ + Description: fmt.Sprintf("test-job-%s", uuid.NewString()[:8]), + Entrypoint: []string{"sleep", "99999"}, + AgentUserGroup: &model.AgentUserGroup{}, + Environment: conf.Environment(), + ResourcesConfig: conf.Resources(), + DontShipLogs: true, + }, sproto.ResourcesRuntimeInfo{}) + defer res.Kill(nil) + require.NoError(t, err) + } + + change := 
poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Pulling, change.ResourcesState) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Starting, change.ResourcesState) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Running, change.ResourcesState) + require.NotNil(t, change.ResourcesStarted) + + podListOpts := metaV1.ListOptions{ + LabelSelector: fmt.Sprintf("%s=%s", allocationIDLabel, string(allocationID)), + } + pods, err := j.clientSet.CoreV1().Pods("default").List(ctx, podListOpts) + require.NoError(t, err) + + require.Len(t, pods.Items, 1) + pod := pods.Items[0] + err = j.clientSet.CoreV1().Pods("default").Delete(ctx, pod.Name, metaV1.DeleteOptions{}) + require.NoError(t, err) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Terminated, change.ResourcesState) + require.NotNil(t, change.ResourcesStopped) + require.NotNil(t, change.ResourcesStopped.Failure) +} + +func TestReattach(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + user := db.RequireMockUser(t, db.SingleDB()) + task := db.RequireMockTask(t, db.SingleDB(), &user.ID) + alloc := db.RequireMockAllocation(t, db.SingleDB(), task.TaskID) + allocationID, taskID, jobID := alloc.AllocationID, task.TaskID, *task.JobID + startTime := task.StartTime + + err := tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + sub := rmevents.Subscribe(allocationID) + allocateReq := sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + SlotsNeeded: 1, + ResourcePool: "default", + } + rp.AllocateRequest(allocateReq) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + allocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.NotNil(t, allocated) + require.Len(t, allocated.Resources, 1) + + for _, res := range allocated.Resources { + conf := expconf.ExperimentConfig{ //nolint:exhaustruct + RawEnvironment: &expconf.EnvironmentConfigV0{ //nolint:exhaustruct + RawImage: &expconf.EnvironmentImageMapV0{ //nolint:exhaustruct + RawCPU: ptrs.Ptr("ubuntu:latest"), + }, + }, + } + conf = schemas.WithDefaults(conf) + + err := res.Start(nil, tasks.TaskSpec{ + Description: fmt.Sprintf("test-job-%s", uuid.NewString()[:8]), + Entrypoint: []string{"sleep", "99999"}, + AgentUserGroup: &model.AgentUserGroup{}, + Environment: conf.Environment(), + ResourcesConfig: conf.Resources(), + DontShipLogs: true, + }, sproto.ResourcesRuntimeInfo{}) + defer res.Kill(nil) + require.NoError(t, err) + } + + change := poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Pulling, change.ResourcesState) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Starting, change.ResourcesState) + + change = poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.Equal(t, sproto.Running, change.ResourcesState) + require.NotNil(t, change.ResourcesStarted) + + // Remake all component and "reattach" to this new resource pool. 
This saves + // us from needing to made the k8s code do graceful shutdown, but we should + // do it anyway someday. + rp = newTestResourcePool(newTestJobsService(t)) + + sub = rmevents.Subscribe(allocationID) + allocateReq.Restore = true + rp.AllocateRequest(allocateReq) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + reallocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.True(t, reallocated.Recovered) + require.Len(t, reallocated.Resources, 1) + + for _, res := range reallocated.Resources { + res.Kill(nil) + } + + for state := sproto.Assigned; state != sproto.Terminated; { + change := poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.True(t, state.BeforeOrEqual(change.ResourcesState)) + state = change.ResourcesState + } +} + +func TestNodeWorkflows(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + resp := j.getAgents() + require.Equal(t, 1, len(resp.Agents)) + nodeID := resp.Agents[0].Id + + _, err := rp.jobsService.DisableAgent(&apiv1.DisableAgentRequest{AgentId: nodeID}) + defer func() { + // Ensure we re-enable the agent, otherwise failures in this test will break others. + _, err := rp.jobsService.EnableAgent(&apiv1.EnableAgentRequest{AgentId: nodeID}) + require.NoError(t, err) + }() + require.NoError(t, err) + + // Wait because this check relies on our informers (eventual consistency). + require.True(t, waitForCondition(10*time.Second, func() bool { + // Bust the cache. Calls that mutate nodes should probably handle this. + j.mu.Lock() + j.getAgentsCacheTime = j.getAgentsCacheTime.Add(-time.Hour) + j.mu.Unlock() + + resp = j.GetAgents() + require.Equal(t, 1, len(resp.Agents)) + return !resp.Agents[0].Enabled + }), "GetAgents didn't say %s is disabled, but we just disabled it", nodeID) + + id := uuid.NewString() + jobID, taskID, allocationID := model.JobID(id), model.TaskID(id), model.AllocationID(id) + startTime := time.Now() + + err = tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + sub := rmevents.Subscribe(allocationID) + rp.AllocateRequest(sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + SlotsNeeded: 1, + ResourcePool: "default", + }) + + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + allocated := poll[*sproto.ResourcesAllocated](ctx, t, sub) + require.NotNil(t, allocated) + require.Len(t, allocated.Resources, 1) + + for _, res := range allocated.Resources { + conf := expconf.ExperimentConfig{ //nolint:exhaustruct + RawEnvironment: &expconf.EnvironmentConfigV0{ //nolint:exhaustruct + RawImage: &expconf.EnvironmentImageMapV0{ //nolint:exhaustruct + RawCPU: ptrs.Ptr("ubuntu:latest"), + }, + }, + } + conf = schemas.WithDefaults(conf) + + err := res.Start(nil, tasks.TaskSpec{ + Description: fmt.Sprintf("test-job-%s", uuid.NewString()[:8]), + Entrypoint: []string{"/bin/bash", "-c", "exit 0"}, + AgentUserGroup: &model.AgentUserGroup{}, + Environment: conf.Environment(), + ResourcesConfig: conf.Resources(), + DontShipLogs: true, + }, 
sproto.ResourcesRuntimeInfo{}) + defer res.Kill(nil) + require.NoError(t, err) + } + + shortCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + for { + ev, err := sub.GetWithContext(shortCtx) + if err != nil { + break + } + + res, ok := ev.(*sproto.ResourcesStateChanged) + if !ok { + continue + } + if res.ResourcesState.BeforeOrEqual(sproto.Pulling) { + continue + } + t.Error("state went to RUNNING or beyond when all agents were disabled") + t.FailNow() + } + + _, err = rp.jobsService.EnableAgent(&apiv1.EnableAgentRequest{AgentId: nodeID}) + require.NoError(t, err) + + // Be careful to allow missing state changes here since the jobs are very short. It's + // all good as long as we don't go backwards and end terminated. + var stop *sproto.ResourcesStopped + for state := sproto.Assigned; state != sproto.Terminated; { + change := poll[*sproto.ResourcesStateChanged](ctx, t, sub) + require.True(t, state.BeforeOrEqual(change.ResourcesState)) + state = change.ResourcesState + stop = change.ResourcesStopped + } + require.NotNil(t, stop) + require.Nil(t, stop.Failure) +} + +func TestAllocateAndReleaseBeforeStarted(t *testing.T) { + rp := newTestResourcePool(newTestJobsService(t)) allocID := model.AllocationID(uuid.NewString()) - // AllocateRequest allocReq := sproto.AllocateRequest{ AllocationID: allocID, JobID: model.NewJobID(), Name: uuid.NewString(), BlockedNodes: []string{uuid.NewString(), uuid.NewString()}, } - rp.AllocateRequest(allocReq) req, ok := rp.reqList.TaskByID(allocID) - require.True(t, ok) - require.Equal(t, allocID, req.AllocationID) - require.Equal(t, allocReq.JobID, req.JobID) - require.Equal(t, allocReq.BlockedNodes, req.BlockedNodes) - require.Equal(t, allocReq.Name, req.Name) + require.Equal(t, allocReq, *req) - // ResourcesReleased rp.ResourcesReleased(sproto.ResourcesReleased{ AllocationID: allocID, ResourcePool: rp.poolConfig.PoolName, @@ -55,22 +908,83 @@ func TestAllocateAndRelease(t *testing.T) { require.Nil(t, req) } -func TestPendingPreemption(t *testing.T) { - rp := testResourcePool(t, defaultSlots) - err := rp.PendingPreemption(sproto.PendingPreemption{ - AllocationID: *model.NewAllocationID(ptrs.Ptr(uuid.NewString())), +func TestGroupMaxSlots(t *testing.T) { + j := newTestJobsService(t) + rp := newTestResourcePool(j) + + id := uuid.NewString() + jobID := model.JobID(id) + + err := tasklist.GroupPriorityChangeRegistry.Add(jobID, func(i int) error { return nil }) + require.NoError(t, err) + + t.Log("set group to have a max of one slot") + rp.SetGroupMaxSlots(sproto.SetGroupMaxSlots{ + MaxSlots: ptrs.Ptr(1), + JobID: jobID, }) + + t.Log("first one slot task in the job should get scheduled") + taskID, allocationID := model.TaskID(id), model.AllocationID(id) + startTime := time.Now() + rp.AllocateRequest(sproto.AllocateRequest{ + AllocationID: allocationID, + TaskID: taskID, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: true, + Name: "test job", + SlotsNeeded: 1, + ResourcePool: "default", + }) + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID)) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID)) + + t.Log("but the second shouldn't") + id2 := uuid.NewString() + taskID2, allocationID2 := model.TaskID(id2), model.AllocationID(id2) + rp.AllocateRequest(sproto.AllocateRequest{ + AllocationID: allocationID2, + TaskID: taskID2, + JobID: jobID, + RequestTime: startTime, + JobSubmissionTime: startTime, + IsUserVisible: 
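The tests above all follow the same pattern: subscribe to the allocation's event stream, then repeatedly pull events of one concrete type while discarding everything else, using the generic ``poll`` helper defined further down in this file. A simplified, self-contained sketch of that filtering idea, with a hypothetical ``Event`` interface and a plain channel standing in for the real ``sproto`` subscription:

.. code:: go

   package main

   import (
       "context"
       "fmt"
       "time"
   )

   // Event is a hypothetical stand-in for sproto.ResourcesEvent.
   type Event interface{ isEvent() }

   type allocated struct{ id string }   // stand-in for *sproto.ResourcesAllocated
   type stateChanged struct{ s string } // stand-in for *sproto.ResourcesStateChanged

   func (allocated) isEvent()    {}
   func (stateChanged) isEvent() {}

   // pollFor discards events until one of the requested type arrives or the context expires.
   func pollFor[T Event](ctx context.Context, events <-chan Event) (T, error) {
       var zero T
       for {
           select {
           case <-ctx.Done():
               return zero, ctx.Err()
           case ev := <-events:
               if typed, ok := ev.(T); ok {
                   return typed, nil
               }
               // Events of other types are skipped, just as the tests skip them.
           }
       }
   }

   func main() {
       events := make(chan Event, 2)
       events <- stateChanged{s: "RUNNING"}
       events <- allocated{id: "alloc-1"}

       ctx, cancel := context.WithTimeout(context.Background(), time.Second)
       defer cancel()

       got, err := pollFor[allocated](ctx, events)
       fmt.Println(got.id, err) // prints "alloc-1 <nil>"; the stateChanged event was dropped
   }

The real helper additionally fails the test with a stack trace when the context expires before a matching event arrives.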
true, + Name: "test job", + SlotsNeeded: 1, + ResourcePool: "default", + }) + require.True(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID2)) + rp.Schedule() + require.False(t, rp.reschedule) + require.False(t, rp.reqList.IsScheduled(allocationID2)) + + t.Log("and when the first releases it should get scheduled") + rp.ResourcesReleased(sproto.ResourcesReleased{AllocationID: allocationID}) + rp.Schedule() + require.False(t, rp.reschedule) + require.True(t, rp.reqList.IsScheduled(allocationID2)) +} + +func TestPendingPreemption(t *testing.T) { + var rp kubernetesResourcePool + err := rp.PendingPreemption(sproto.PendingPreemption{}) require.Equal(t, rmerrors.ErrNotSupported, err) } func TestSetGroupWeight(t *testing.T) { - rp := testResourcePool(t, defaultSlots) + var rp kubernetesResourcePool err := rp.SetGroupWeight(sproto.SetGroupWeight{}) require.Equal(t, rmerrors.UnsupportedError("set group weight is unsupported in k8s"), err) } func TestSetGroupPriority(t *testing.T) { - rp := testResourcePool(t, defaultSlots) + rp := newTestResourcePool(newTestJobsService(t)) cases := []struct { name string @@ -110,7 +1024,7 @@ func TestSetGroupPriority(t *testing.T) { } func TestValidateResources(t *testing.T) { - rp := testResourcePool(t, defaultSlots) + rp := newTestResourcePool(newTestJobsService(t)) cases := []struct { name string @@ -133,69 +1047,92 @@ func TestValidateResources(t *testing.T) { func TestSchedule(t *testing.T) { // TODO RM-301 t.Skip("skipping test until flake fixed") - rp := testResourcePool(t, defaultSlots) - _, allocID := testAddAllocation(t, rp, defaultState) + rp := newTestResourcePool(newTestJobsService(t)) + jobID := model.NewJobID() + allocID := model.AllocationID(jobID) + + allocReq := sproto.AllocateRequest{ + AllocationID: allocID, + JobID: jobID, + Preemptible: true, + State: sproto.SchedulingStateQueued, + } + rp.AllocateRequest(allocReq) + + _, ok := rp.reqList.TaskByID(allocReq.AllocationID) + require.True(t, ok) require.True(t, rp.reschedule) require.False(t, rp.reqList.IsScheduled(allocID)) - rp.Schedule() - require.False(t, rp.reschedule) require.True(t, rp.reqList.IsScheduled(allocID)) } -func testResourcePool(t *testing.T, slots int) *kubernetesResourcePool { - return newResourcePool(slots, &config.ResourcePoolConfig{}, testPodsService(t), db.SingleDB()) -} +func poll[T sproto.ResourcesEvent](ctx context.Context, t *testing.T, sub *sproto.ResourcesSubscription) T { + for { + ev, err := sub.GetWithContext(ctx) + if err != nil { + var typed T + t.Errorf("failed to receive %T in time: %s", typed, err) + t.Error(string(debug.Stack())) + t.FailNow() + } -func testAddAllocation( - t *testing.T, rp *kubernetesResourcePool, state sproto.SchedulingState, -) (model.JobID, model.AllocationID) { - jobID := model.NewJobID() - allocID := uuid.NewString() - - allocReq := sproto.AllocateRequest{ - AllocationID: *model.NewAllocationID(&allocID), - JobID: jobID, - Preemptible: true, - State: state, + res, ok := ev.(T) + if !ok { + continue + } + return res } +} - rp.AllocateRequest(allocReq) - - req, ok := rp.reqList.TaskByID(allocReq.AllocationID) - require.True(t, ok) +var tickInterval = 10 * time.Millisecond - return jobID, req.AllocationID +func waitForCondition(timeout time.Duration, condition func() bool) bool { + for i := 0; i < int(timeout/tickInterval); i++ { + if condition() { + return true + } + time.Sleep(tickInterval) + } + return false } -func testPodsService(t *testing.T) *pods { - config, err := readClientConfig("~/.kube/config") - 
require.NoError(t, err) +var testResourcePoolConfig = config.ResourcePoolConfig{ + PoolName: defaultResourcePool, + Description: "default test pool", + TaskContainerDefaults: &model.TaskContainerDefaultsConfig{}, + AgentReattachEnabled: false, + AgentReconnectWait: 0, + KubernetesNamespace: "default", + MaxCPUContainersPerAgent: 0, +} - clientSet, err := k8sClient.NewForConfig(config) - require.NoError(t, err) +func newTestResourcePool(j *jobsService) *kubernetesResourcePool { + return newResourcePool(1, &testResourcePoolConfig, j, db.SingleDB()) +} - return &pods{ - wg: waitgroupx.WithContext(context.Background()), - namespace: namespace, - masterServiceName: "master", - clientSet: clientSet, - podNameToPodHandler: make(map[string]*pod), - podNameToResourcePool: make(map[string]string), - containerIDToPodName: make(map[string]string), - containerIDToSchedulingState: make(map[string]sproto.SchedulingState), - podNameToContainerID: make(map[string]string), - podHandlerToMetadata: make(map[*pod]podMetadata), - resourceRequestQueue: &requestQueue{ - failures: make(chan resourcesRequestFailure, 16), - workerChan: make(chan interface{}), - queue: make([]*queuedResourceRequest, 0), - creationInProgress: make(set.Set[requestID]), - pendingResourceCreations: make(map[requestID]*queuedResourceRequest), - blockedResourceDeletions: make(map[requestID]*queuedResourceRequest), - syslog: logrus.New().WithField("component", "kubernetesrm-queue"), +func newTestJobsService(t *testing.T) *jobsService { + j, err := newJobsService( + "default", + map[string]string{"default": defaultResourcePool}, + "", + model.TLSClientConfig{}, + "", + device.CPU, + config.PodSlotResourceRequests{ + CPU: 1, }, - } + []config.ResourcePoolConfig{ + testResourcePoolConfig, + }, + &model.TaskContainerDefaultsConfig{}, + "localhost", + 8080, + "~/.kube/config", + nil, + ) + require.NoError(t, err) + return j } diff --git a/master/internal/rm/kubernetesrm/spec.go b/master/internal/rm/kubernetesrm/spec.go index 6fc94664220..8f70418fdab 100644 --- a/master/internal/rm/kubernetesrm/spec.go +++ b/master/internal/rm/kubernetesrm/spec.go @@ -10,10 +10,11 @@ import ( "strconv" "strings" + batchV1 "k8s.io/api/batch/v1" + "github.com/determined-ai/determined/master/internal/config" "github.com/docker/docker/api/types/mount" - petName "github.com/dustinkirkland/golang-petname" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -37,6 +38,10 @@ import ( const ( coscheduler = "coscheduler" + initContainerTarSrcPath = "/run/determined/temp/tar/src" + initContainerTarDstPath = "/run/determined/temp/tar/dst" + initContainerWorkDir = "/run/determined/temp/" + gcTask = "gc" cmdTask = "cmd" labelPrefix = "determined.ai/" @@ -45,13 +50,14 @@ const ( resourcePoolLabel = labelPrefix + "resource_pool" taskTypeLabel = labelPrefix + "task_type" taskIDLabel = labelPrefix + "task_id" + allocationIDLabel = labelPrefix + "allocation_id" containerIDLabel = labelPrefix + "container_id" ) -func (p *pod) configureResourcesRequirements() k8sV1.ResourceRequirements { - switch p.slotType { +func (j *job) configureResourcesRequirements() k8sV1.ResourceRequirements { + switch j.slotType { case device.CPU: - cpuMillisRequested := int64(p.slotResourceRequests.CPU * float32(p.slots) * 1000) + cpuMillisRequested := int64(j.slotResourceRequests.CPU * float32(j.slotsPerPod) * 1000) return k8sV1.ResourceRequirements{ Limits: map[k8sV1.ResourceName]resource.Quantity{ "cpu": *resource.NewMilliQuantity(cpuMillisRequested, 
resource.DecimalSI), @@ -65,13 +71,13 @@ func (p *pod) configureResourcesRequirements() k8sV1.ResourceRequirements { case device.CUDA: // default to CUDA-backed slots. fallthrough default: - if p.slots > 0 { + if j.slotsPerPod > 0 { return k8sV1.ResourceRequirements{ Limits: map[k8sV1.ResourceName]resource.Quantity{ - ResourceTypeNvidia: *resource.NewQuantity(int64(p.slots), resource.DecimalSI), + resourceTypeNvidia: *resource.NewQuantity(int64(j.slotsPerPod), resource.DecimalSI), }, Requests: map[k8sV1.ResourceName]resource.Quantity{ - ResourceTypeNvidia: *resource.NewQuantity(int64(p.slots), resource.DecimalSI), + resourceTypeNvidia: *resource.NewQuantity(int64(j.slotsPerPod), resource.DecimalSI), }, } } @@ -82,7 +88,7 @@ func (p *pod) configureResourcesRequirements() k8sV1.ResourceRequirements { } } -func (p *pod) configureEnvVars( +func (j *job) configureEnvVars( envVarsMap map[string]string, environment expconf.EnvironmentConfig, deviceType device.Type, @@ -95,23 +101,23 @@ func (p *pod) configureEnvVars( } } - var slotIds []string - for i := 0; i < p.slots; i++ { - slotIds = append(slotIds, strconv.Itoa(i)) + var slotIDs []string + for i := 0; i < j.slotsPerPod; i++ { + slotIDs = append(slotIDs, strconv.Itoa(i)) } masterScheme := "http" - if p.masterTLSConfig.Enabled { + if j.masterTLSConfig.Enabled { masterScheme = "https" } - envVarsMap["DET_CLUSTER_ID"] = p.clusterID - envVarsMap["DET_MASTER"] = fmt.Sprintf("%s://%s:%d", masterScheme, p.masterIP, p.masterPort) - envVarsMap["DET_MASTER_HOST"] = p.masterIP - envVarsMap["DET_MASTER_ADDR"] = p.masterIP - envVarsMap["DET_MASTER_PORT"] = fmt.Sprintf("%d", p.masterPort) - envVarsMap["DET_SLOT_IDS"] = fmt.Sprintf("[%s]", strings.Join(slotIds, ",")) - if p.masterTLSConfig.CertificateName != "" { - envVarsMap["DET_MASTER_CERT_NAME"] = p.masterTLSConfig.CertificateName + envVarsMap["DET_CLUSTER_ID"] = j.clusterID + envVarsMap["DET_MASTER"] = fmt.Sprintf("%s://%s:%d", masterScheme, j.masterIP, j.masterPort) + envVarsMap["DET_MASTER_HOST"] = j.masterIP + envVarsMap["DET_MASTER_ADDR"] = j.masterIP + envVarsMap["DET_MASTER_PORT"] = strconv.Itoa(int(j.masterPort)) + envVarsMap["DET_SLOT_IDS"] = fmt.Sprintf("[%s]", strings.Join(slotIDs, ",")) + if j.masterTLSConfig.CertificateName != "" { + envVarsMap["DET_MASTER_CERT_NAME"] = j.masterTLSConfig.CertificateName } // Without this zero slot tasks will have access to all GPUs. @@ -120,6 +126,8 @@ func (p *pod) configureEnvVars( envVarsMap["NVIDIA_VISIBLE_DEVICES"] = "void" } + envVarsMap["DET_KUBERNETES_JOB_PARALLELISM"] = strconv.Itoa(j.numPods) + envVars := make([]k8sV1.EnvVar, 0, len(envVarsMap)) for envVarKey, envVarValue := range envVarsMap { envVars = append(envVars, k8sV1.EnvVar{Name: envVarKey, Value: envVarValue}) @@ -128,10 +136,19 @@ func (p *pod) configureEnvVars( Name: "DET_AGENT_ID", ValueFrom: &k8sV1.EnvVarSource{FieldRef: &k8sV1.ObjectFieldSelector{FieldPath: "spec.nodeName"}}, }) + envVars = append(envVars, k8sV1.EnvVar{ + Name: "DET_KUBERNETES_POD_IP", + ValueFrom: &k8sV1.EnvVarSource{ + FieldRef: &k8sV1.ObjectFieldSelector{ + FieldPath: "status.podIP", + }, + }, + }) return envVars, nil } -func (p *pod) configureConfigMapSpec( +func (j *job) configureConfigMapSpec( + taskSpec *tasks.TaskSpec, runArchives []cproto.RunArchive, ) (*k8sV1.ConfigMap, error) { configMapData := make(map[string][]byte, len(runArchives)) @@ -152,15 +169,16 @@ func (p *pod) configureConfigMapSpec( // for the init container. 
return &k8sV1.ConfigMap{ ObjectMeta: metaV1.ObjectMeta{ - Name: p.configMapName, - Namespace: p.namespace, - Labels: map[string]string{determinedLabel: p.submissionInfo.taskSpec.AllocationID}, + Name: j.configMapName, + Namespace: j.namespace, + Labels: map[string]string{determinedLabel: taskSpec.AllocationID}, }, BinaryData: configMapData, }, nil } -func (p *pod) configureVolumes( +func (j *job) configureVolumes( + taskSpec *tasks.TaskSpec, dockerMounts []mount.Mount, runArchives []cproto.RunArchive, ) ([]k8sV1.VolumeMount, []k8sV1.VolumeMount, []k8sV1.Volume) { @@ -171,9 +189,9 @@ func (p *pod) configureVolumes( volumeMounts = append(volumeMounts, hostVolumeMounts...) volumes = append(volumes, hostVolumes...) - shmSize := p.submissionInfo.taskSpec.ShmSize + shmSize := taskSpec.ShmSize if shmSize == 0 { - shmSize = p.submissionInfo.taskSpec.TaskContainerDefaults.ShmSizeBytes + shmSize = taskSpec.TaskContainerDefaults.ShmSizeBytes } shmVolumeMount, shmVolume := configureShmVolume(shmSize) volumeMounts = append(volumeMounts, shmVolumeMount) @@ -181,7 +199,7 @@ func (p *pod) configureVolumes( // //nolint:lll // There isn't a great way to break this line that makes it more readable. initContainerVolumeMounts, mainContainerRunArchiveVolumeMounts, runArchiveVolumes := configureAdditionalFilesVolumes( - p.configMapName, + j.configMapName, runArchives, ) @@ -191,12 +209,16 @@ func (p *pod) configureVolumes( return initContainerVolumeMounts, volumeMounts, volumes } -func (p *pod) modifyPodSpec(newPod *k8sV1.Pod, scheduler string) { - if p.submissionInfo.taskSpec.Description == cmdTask { +func (j *job) modifyPodSpec( + taskSpec *tasks.TaskSpec, + newPod *k8sV1.Pod, + scheduler string, +) { + if taskSpec.Description == cmdTask { return } - if p.submissionInfo.taskSpec.Description == gcTask { + if taskSpec.Description == gcTask { if newPod.Spec.PriorityClassName != "" { log.Warnf( "GC Priority is currently using priority class: %s. 
"+ @@ -209,15 +231,15 @@ func (p *pod) modifyPodSpec(newPod *k8sV1.Pod, scheduler string) { if newPod.Spec.SchedulerName == "" { newPod.Spec.SchedulerName = scheduler } - p.configureCoscheduler(newPod, scheduler) + j.configureCoscheduler(taskSpec, newPod, scheduler) } if newPod.Spec.PriorityClassName == "" && - p.submissionInfo.taskSpec.ResourcesConfig.Priority() != nil { - priority := int32(*p.submissionInfo.taskSpec.ResourcesConfig.Priority()) - name := fmt.Sprintf("%s-priorityclass", p.submissionInfo.taskSpec.ContainerID) + taskSpec.ResourcesConfig.Priority() != nil { + priority := int32(*taskSpec.ResourcesConfig.Priority()) + name := fmt.Sprintf("%s-priorityclass", taskSpec.ContainerID) - err := p.createPriorityClass(name, priority) + err := j.createPriorityClass(name, priority) if err == nil { newPod.Spec.PriorityClassName = name @@ -300,30 +322,32 @@ func addNodeSelectorRequirement( } } -func (p *pod) configureCoscheduler(newPod *k8sV1.Pod, scheduler string) { +func (j *job) configureCoscheduler( + taskSpec *tasks.TaskSpec, + newPod *k8sV1.Pod, + scheduler string, +) { if newPod.Spec.SchedulerName != scheduler { return } - resources := p.submissionInfo.taskSpec.ResourcesConfig + resources := taskSpec.ResourcesConfig minAvailable := 0 - if p.slotType == device.CUDA && p.slots > 0 { - minAvailable = int(math.Ceil(float64(resources.SlotsPerTrial()) / float64(p.slots))) + if j.slotType == device.CUDA && j.slotsPerPod > 0 { + minAvailable = int(math.Ceil(float64(resources.SlotsPerTrial()) / float64(j.slotsPerPod))) } if newPod.APIVersion == "" { newPod.APIVersion = "v1" } if newPod.Kind == "" { - newPod.Kind = "Pod" + newPod.Kind = "Pod" //nolint:goconst } _, ok := newPod.ObjectMeta.Labels["pod-group.scheduling.sigs.k8s.io/name"] if !ok { - newPod.ObjectMeta.Labels["pod-group.scheduling.sigs.k8s.io/name"] = trialNameFromPod( - p.podName, - ) + newPod.ObjectMeta.Labels["pod-group.scheduling.sigs.k8s.io/name"] = j.jobName } _, ok = newPod.ObjectMeta.Labels["pod-group.scheduling.sigs.k8s.io/min-available"] if !ok { @@ -332,10 +356,12 @@ func (p *pod) configureCoscheduler(newPod *k8sV1.Pod, scheduler string) { } } -func (p *pod) createPriorityClass(name string, priority int32) error { +var defaultTTLSecondsAfterFinished int32 = 15 * 60 // 15 minutes + +func (j *job) createPriorityClass(name string, priority int32) error { preemptionPolicy := k8sV1.PreemptNever - _, err := p.clientSet.SchedulingV1().PriorityClasses().Create(context.TODO(), + _, err := j.clientSet.SchedulingV1().PriorityClasses().Create(context.TODO(), &schedulingV1.PriorityClass{ TypeMeta: metaV1.TypeMeta{}, ObjectMeta: metaV1.ObjectMeta{ @@ -399,59 +425,61 @@ func validatePodLabelValue(value string) (string, error) { return fixedValue, nil } -func (p *pod) configurePodSpec( +func (j *job) configureJobSpec( + taskSpec *tasks.TaskSpec, volumes []k8sV1.Volume, determinedInitContainers k8sV1.Container, determinedContainer k8sV1.Container, sidecarContainers []k8sV1.Container, podSpec *k8sV1.Pod, scheduler string, -) *k8sV1.Pod { +) *batchV1.Job { if podSpec == nil { podSpec = &k8sV1.Pod{} } else { podSpec = podSpec.DeepCopy() } - podSpec.ObjectMeta.Name = p.podName - podSpec.ObjectMeta.Namespace = p.namespace + podSpec.ObjectMeta.Name = j.jobName + podSpec.ObjectMeta.Namespace = j.namespace if podSpec.ObjectMeta.Labels == nil { podSpec.ObjectMeta.Labels = make(map[string]string) } - if p.submissionInfo.taskSpec.Owner != nil { + if taskSpec.Owner != nil { // Owner label will disappear if Owner is somehow nil. 
- labelValue, err := validatePodLabelValue(p.submissionInfo.taskSpec.Owner.Username) + labelValue, err := validatePodLabelValue(taskSpec.Owner.Username) if err != nil { labelValue = defaultPodLabelValue log.Warnf("unable to reformat username=%s to Kubernetes standards; using %s", - p.submissionInfo.taskSpec.Owner.Username, labelValue) + taskSpec.Owner.Username, labelValue) } podSpec.ObjectMeta.Labels[userLabel] = labelValue } - labelValue, err := validatePodLabelValue(p.submissionInfo.taskSpec.Workspace) + labelValue, err := validatePodLabelValue(taskSpec.Workspace) if err != nil { labelValue = defaultPodLabelValue log.Warnf("unable to reformat workspace=%s to Kubernetes standards; using %s", - p.submissionInfo.taskSpec.Workspace, labelValue) + taskSpec.Workspace, labelValue) } podSpec.ObjectMeta.Labels[workspaceLabel] = labelValue - labelValue, err = validatePodLabelValue(p.req.ResourcePool) + labelValue, err = validatePodLabelValue(j.req.ResourcePool) if err != nil { labelValue = defaultPodLabelValue log.Warnf("unable to reformat resource_pool=%s to Kubernetes standards; using %s", - p.req.ResourcePool, labelValue) + j.req.ResourcePool, labelValue) } podSpec.ObjectMeta.Labels[resourcePoolLabel] = labelValue - podSpec.ObjectMeta.Labels[taskTypeLabel] = string(p.submissionInfo.taskSpec.TaskType) - podSpec.ObjectMeta.Labels[taskIDLabel] = p.submissionInfo.taskSpec.TaskID - podSpec.ObjectMeta.Labels[containerIDLabel] = p.submissionInfo.taskSpec.ContainerID - podSpec.ObjectMeta.Labels[determinedLabel] = p.submissionInfo.taskSpec.AllocationID + podSpec.ObjectMeta.Labels[taskTypeLabel] = string(taskSpec.TaskType) + podSpec.ObjectMeta.Labels[taskIDLabel] = taskSpec.TaskID + podSpec.ObjectMeta.Labels[containerIDLabel] = taskSpec.ContainerID + podSpec.ObjectMeta.Labels[determinedLabel] = taskSpec.AllocationID + podSpec.ObjectMeta.Labels[allocationIDLabel] = taskSpec.AllocationID // If map is not populated, labels will be missing and observability will be impacted. - for k, v := range p.submissionInfo.taskSpec.ExtraPodLabels { + for k, v := range taskSpec.ExtraPodLabels { labelValue, err := validatePodLabelValue(v) if err != nil { labelValue = defaultPodLabelValue @@ -459,10 +487,10 @@ func (p *pod) configurePodSpec( podSpec.ObjectMeta.Labels[labelPrefix+k] = labelValue } - p.modifyPodSpec(podSpec, scheduler) + j.modifyPodSpec(taskSpec, podSpec, scheduler) addNodeDisabledAffinityToPodSpec(podSpec, clusterIDNodeLabel()) - addDisallowedNodesToPodSpec(p.req, podSpec) + addDisallowedNodesToPodSpec(j.req, podSpec) nonDeterminedContainers := make([]k8sV1.Container, 0) for idx, container := range podSpec.Spec.Containers { @@ -497,43 +525,55 @@ func (p *pod) configurePodSpec( podSpec.Spec.Containers = append(podSpec.Spec.Containers, sidecarContainers...) podSpec.Spec.Containers = append(podSpec.Spec.Containers, determinedContainer) podSpec.Spec.Volumes = append(podSpec.Spec.Volumes, volumes...) 
- podSpec.Spec.HostNetwork = p.submissionInfo.taskSpec.TaskContainerDefaults.NetworkMode.IsHost() + podSpec.Spec.HostNetwork = taskSpec.TaskContainerDefaults.NetworkMode.IsHost() podSpec.Spec.InitContainers = append(podSpec.Spec.InitContainers, determinedInitContainers) podSpec.Spec.RestartPolicy = k8sV1.RestartPolicyNever - return podSpec + return &batchV1.Job{ + ObjectMeta: podSpec.ObjectMeta, + Spec: batchV1.JobSpec{ + Parallelism: ptrs.Ptr(int32(j.numPods)), + Completions: ptrs.Ptr(int32(j.numPods)), + BackoffLimit: ptrs.Ptr(int32(0)), + Template: k8sV1.PodTemplateSpec{ + ObjectMeta: podSpec.ObjectMeta, + Spec: podSpec.Spec, + }, + // TTLSeconds is useful for debugging but also must be set reasonably high so we + // can recover job exit codes in the case where the job exits while the master + // is down. + TTLSecondsAfterFinished: &defaultTTLSecondsAfterFinished, + }, + } } -func (p *pod) createPodSpec(scheduler string) error { - deviceType := p.slotType +func (j *job) createSpec(scheduler string, taskSpec *tasks.TaskSpec) (*batchV1.Job, *k8sV1.ConfigMap, error) { + deviceType := j.slotType // Device type is currently configured globally on KubernetesResourceManagerConfig. // So we special case certain functionality to use device.CPU. - if deviceType == device.ZeroSlot || p.slots == 0 { + if deviceType == device.ZeroSlot || j.slotsPerPod == 0 { deviceType = device.CPU } - spec := p.submissionInfo.taskSpec + runArchives, rootArchives := taskSpec.Archives() - runArchives, rootArchives := spec.Archives() + initContainerVolumeMounts, volumeMounts, volumes := j.configureVolumes(taskSpec, taskSpec.Mounts, runArchives) - initContainerVolumeMounts, volumeMounts, volumes := p.configureVolumes(spec.Mounts, runArchives) - - env := spec.Environment + env := taskSpec.Environment // This array containerPorts is set on the container spec. // This field on the container spec is for "primarily informational" // reasons and to allow us to read these ports in reattaching pods. 
var containerPorts []k8sV1.ContainerPort for _, port := range env.Ports() { - p.ports = append(p.ports, port) containerPorts = append(containerPorts, k8sV1.ContainerPort{ ContainerPort: int32(port), }) } - envVars, err := p.configureEnvVars(spec.EnvVars(), env, deviceType) + envVars, err := j.configureEnvVars(taskSpec.EnvVars(), env, deviceType) if err != nil { - return err + return nil, nil, err } initContainer := configureInitContainer( @@ -541,61 +581,67 @@ func (p *pod) createPodSpec(scheduler string) error { initContainerVolumeMounts, env.Image().For(deviceType), configureImagePullPolicy(env), - spec.AgentUserGroup, + taskSpec.AgentUserGroup, ) var sidecars []k8sV1.Container container := k8sV1.Container{ Name: model.DeterminedK8ContainerName, - Command: spec.LogShipperWrappedEntrypoint(), + Command: taskSpec.LogShipperWrappedEntrypoint(), Env: envVars, Image: env.Image().For(deviceType), ImagePullPolicy: configureImagePullPolicy(env), SecurityContext: getDetContainerSecurityContext( - spec.AgentUserGroup, + taskSpec.AgentUserGroup, env.PodSpec(), ), - Resources: p.configureResourcesRequirements(), + Resources: j.configureResourcesRequirements(), VolumeMounts: volumeMounts, - WorkingDir: spec.WorkDir, + WorkingDir: taskSpec.WorkDir, Ports: containerPorts, } - p.configMap, err = p.configureConfigMapSpec(runArchives) + configMapSpec, err := j.configureConfigMapSpec(taskSpec, runArchives) if err != nil { - return err + return nil, nil, err } - rootVolumes, rootVolumeMounts, err := handleRootArchiveFiles(rootArchives, p.configMap) + rootVolumes, rootVolumeMounts, err := handleRootArchiveFiles(rootArchives, configMapSpec) if err != nil { - return err + return nil, nil, err } volumes = append(volumes, rootVolumes...) container.VolumeMounts = append(container.VolumeMounts, rootVolumeMounts...) - p.pod = p.configurePodSpec( - volumes, initContainer, container, sidecars, (*k8sV1.Pod)(env.PodSpec()), scheduler) - return nil + return j.configureJobSpec( + taskSpec, + volumes, + initContainer, + container, + sidecars, + (*k8sV1.Pod)(env.PodSpec()), + scheduler, + ), configMapSpec, nil } -func configureUniqueName(t tasks.TaskSpec, rank int) string { - return fmt.Sprintf("%s-%d-%s-%s", - t.Description, rank, t.AllocationID, petName.Generate(2, "-")) -} +func configureUniqueName(t tasks.TaskSpec) string { + name := t.Description -func trialNameFromPod(podName string) string { - // Given a pod name of the form exp-#-trial-#-rank-#..., returns a string exp#trial# - // e.g. input: exp-1-trial-1-rank-0-71af9..., returns: exp1trial1 - - newName := "" - for i, v := range strings.Split(podName, "-") { - if i > 3 { - break - } - newName += v + // Prefix with a cluster ID so multiple Determined installations can coexist within cluster. But + // limit to the first 8 chars of the cluster ID to avoid the 63 character limit (this is ~53). + // Handle short cluster IDs for tests. 
+ var clusterIDPrefix string + if len(t.ClusterID) >= 8 { + clusterIDPrefix = t.ClusterID[:8] + } else { + clusterIDPrefix = t.ClusterID } - return newName + if clusterIDPrefix != "" { + name = fmt.Sprintf("%s-%s", clusterIDPrefix, name) + } + + return name } func configureSecurityContext(agentUserGroup *model.AgentUserGroup) *k8sV1.SecurityContext { diff --git a/master/internal/rm/kubernetesrm/spec_test.go b/master/internal/rm/kubernetesrm/spec_test.go index d9e249ee1a5..4ec620f2660 100644 --- a/master/internal/rm/kubernetesrm/spec_test.go +++ b/master/internal/rm/kubernetesrm/spec_test.go @@ -188,7 +188,7 @@ func TestLaterEnvironmentVariablesGetSet(t *testing.T) { }, } - p := pod{} + p := job{} actual, err := p.configureEnvVars(make(map[string]string), env, device.CPU) require.NoError(t, err) require.NotContains(t, actual, dontBe, "earlier variable set") @@ -213,7 +213,7 @@ func TestAllPrintableCharactersInEnv(t *testing.T) { }, } - p := pod{} + p := job{} actual, err := p.configureEnvVars(make(map[string]string), env, device.CPU) require.NoError(t, err) require.Contains(t, actual, k8sV1.EnvVar{Name: "test", Value: expectedValue}) @@ -253,7 +253,12 @@ func TestValidatePodLabelValues(t *testing.T) { func TestDeterminedLabels(t *testing.T) { // Fill out task spec. taskSpec := tasks.TaskSpec{ - Owner: createUser(), + Owner: &model.User{ + ID: 1, + Username: "determined", + Active: true, + Admin: false, + }, Workspace: "test-workspace", TaskType: model.TaskTypeCommand, TaskID: model.NewTaskID().String(), @@ -264,13 +269,10 @@ func TestDeterminedLabels(t *testing.T) { }, } - p := pod{ + p := job{ req: &sproto.AllocateRequest{ ResourcePool: "test-rp", }, - submissionInfo: &podSubmissionInfo{ - taskSpec: taskSpec, - }, } // Define expectations. @@ -282,13 +284,16 @@ func TestDeterminedLabels(t *testing.T) { taskTypeLabel: string(taskSpec.TaskType), taskIDLabel: taskSpec.TaskID, containerIDLabel: taskSpec.ContainerID, + allocationIDLabel: taskSpec.AllocationID, } for k, v := range taskSpec.ExtraPodLabels { expectedLabels[labelPrefix+k] = v } - spec := p.configurePodSpec(make([]k8sV1.Volume, 1), k8sV1.Container{}, - k8sV1.Container{}, make([]k8sV1.Container, 1), &k8sV1.Pod{}, "scheduler") + spec := p.configureJobSpec( + &taskSpec, make([]k8sV1.Volume, 1), k8sV1.Container{}, + k8sV1.Container{}, make([]k8sV1.Container, 1), &k8sV1.Pod{}, "scheduler", + ) // Confirm pod spec has required labels. require.NotNil(t, spec) diff --git a/master/internal/sproto/resources.go b/master/internal/sproto/resources.go index e8f35f8d338..08509fc0c90 100644 --- a/master/internal/sproto/resources.go +++ b/master/internal/sproto/resources.go @@ -2,6 +2,7 @@ package sproto import ( "fmt" + "slices" "github.com/pkg/errors" @@ -48,6 +49,21 @@ const ( Unknown ResourcesState = "" ) +var resourcesStateOrdering = []ResourcesState{ + Assigned, + Pulling, + Starting, + Running, + Terminated, +} + +// BeforeOrEqual returns true if one state is chronologically before or the same as the other. +func (s ResourcesState) BeforeOrEqual(other ResourcesState) bool { + selfIndex := slices.Index(resourcesStateOrdering, s) + otherIndex := slices.Index(resourcesStateOrdering, other) + return selfIndex <= otherIndex +} + +// FromContainerState converts a cproto.State to ResourcesState. This may shortly become much less // granular (not a one to one mapping).
func FromContainerState(state cproto.State) ResourcesState { @@ -108,7 +124,7 @@ func FromContainerStarted(cs *aproto.ContainerStarted) *ResourcesStarted { // ResourcesStopped contains the information needed by tasks from container stopped. type ResourcesStopped struct { - Failure *ResourcesRestoreError + Failure *ResourcesFailedError } // Proto returns the proto representation of ResourcesStopped. @@ -129,7 +145,7 @@ func FromContainerStopped(cs *aproto.ContainerStopped) *ResourcesStopped { rs := &ResourcesStopped{} if f := cs.Failure; f != nil { - rs.Failure = &ResourcesRestoreError{ + rs.Failure = &ResourcesFailedError{ FailureType: FromContainerFailureType(f.FailureType), ErrMsg: f.ErrMsg, ExitCode: FromContainerExitCode(f.ExitCode), @@ -143,14 +159,14 @@ func FromContainerStopped(cs *aproto.ContainerStopped) *ResourcesStopped { func ResourcesError(failureType FailureType, err error) ResourcesStopped { if err == nil { return ResourcesStopped{ - Failure: &ResourcesRestoreError{ + Failure: &ResourcesFailedError{ FailureType: failureType, ErrMsg: errors.WithStack(errors.Errorf("unknown error occurred")).Error(), }, } } return ResourcesStopped{ - Failure: &ResourcesRestoreError{ + Failure: &ResourcesFailedError{ FailureType: failureType, ErrMsg: err.Error(), }, @@ -164,15 +180,15 @@ func (r ResourcesStopped) String() string { return r.Failure.Error() } -// ResourcesRestoreError contains information about restored resources' failure. -type ResourcesRestoreError struct { +// ResourcesFailedError contains information about a resources failure. +type ResourcesFailedError struct { FailureType FailureType ErrMsg string ExitCode *ExitCode } // Proto returns the proto representation of ResourcesFailure. -func (f *ResourcesRestoreError) Proto() *taskv1.ResourcesFailure { +func (f *ResourcesFailedError) Proto() *taskv1.ResourcesFailure { if f == nil { return nil } @@ -193,15 +209,15 @@ func (f *ResourcesRestoreError) Proto() *taskv1.ResourcesFailure { // NewResourcesFailure returns a resources failure message wrapping the type, msg and exit code. func NewResourcesFailure( failureType FailureType, msg string, code *ExitCode, -) *ResourcesRestoreError { - return &ResourcesRestoreError{ +) *ResourcesFailedError { + return &ResourcesFailedError{ FailureType: failureType, ErrMsg: msg, ExitCode: code, } } -func (f ResourcesRestoreError) Error() string { +func (f ResourcesFailedError) Error() string { if f.ExitCode == nil { if len(f.ErrMsg) > 0 { return fmt.Sprintf("%s: %s", f.FailureType, f.ErrMsg) @@ -336,7 +352,7 @@ func IsUnrecoverableSystemError(err error) bool { // shouldn't count against `max_restarts`. func IsTransientSystemError(err error) bool { switch err := err.(type) { - case ResourcesRestoreError: + case ResourcesFailedError: switch err.FailureType { case ResourcesFailed, TaskError: return false @@ -368,6 +384,14 @@ type ResourcesStateChanged struct { Container *cproto.Container } +func (r ResourcesStateChanged) String() string { + var reason string + if r.ResourcesStopped != nil { + reason = r.ResourcesStopped.String() + } + return fmt.Sprintf("id=%s state=%s reason=%s", r.ResourcesID, r.ResourcesState, reason) +} + // FromContainerStateChanged converts an aproto.ContainerStateChanged message to // ResourcesStateChanged.
func FromContainerStateChanged(sc aproto.ContainerStateChanged) *ResourcesStateChanged { diff --git a/master/internal/sproto/task.go index 2ab8560eb2f..9792abaa036 100644 --- a/master/internal/sproto/task.go +++ b/master/internal/sproto/task.go @@ -131,7 +131,7 @@ func (*ReleaseResources) ResourcesEvent() {} func (*ResourcesStateChanged) ResourcesEvent() {} // ResourcesEvent implements ResourcesEvent. -func (*ResourcesRestoreError) ResourcesEvent() {} +func (*ResourcesFailedError) ResourcesEvent() {} // ResourcesEvent implements ResourcesEvent. func (*ContainerLog) ResourcesEvent() {} @@ -289,8 +289,8 @@ const ( // SlurmProxyIfaceEnvVar is the env var for overriding the net iface used to proxy between // the master and agents. SlurmProxyIfaceEnvVar = "DET_SLURM_PROXY_IFACE" - // ResourcesTypeK8sPod indicates the resources are a handle for a k8s pod. - ResourcesTypeK8sPod ResourcesType = "k8s-pod" + // ResourcesTypeK8sJob indicates the resources are a handle for a k8s job. + ResourcesTypeK8sJob ResourcesType = "k8s-job" // ResourcesTypeDockerContainer indicates the resources are a handle for a docker container. ResourcesTypeDockerContainer ResourcesType = "docker-container" // ResourcesTypeSlurmJob indicates the resources are a handle for a slurm job. diff --git a/master/internal/sproto/task_actor.go index 3d05d42b448..3acbe4d6ce6 100644 --- a/master/internal/sproto/task_actor.go +++ b/master/internal/sproto/task_actor.go @@ -31,12 +31,6 @@ type ( AgentID *string } - // UpdatePodStatus notifies the resource manager of job state changes. - UpdatePodStatus struct { - ContainerID string - State SchedulingState - } - // SetGroupMaxSlots sets the maximum number of slots that a group can consume in the cluster. SetGroupMaxSlots struct { MaxSlots *int diff --git a/master/internal/task/allocation.go b/master/internal/task/allocation.go index 3d460568136..5790c80067b 100644 --- a/master/internal/task/allocation.go +++ b/master/internal/task/allocation.go @@ -257,7 +257,7 @@ func (a *allocation) HandleRMEvent(msg sproto.ResourcesEvent) (done bool) { a.releaseResources(msg) case *sproto.ContainerLog: a.sendTaskLog(msg.ToTaskLog()) - case *sproto.ResourcesRestoreError: + case *sproto.ResourcesFailedError: a.restoreResourceFailure(msg) return true case *sproto.InvalidResourcesRequestError: @@ -328,7 +328,7 @@ func (a *allocation) SetProxyAddress(ctx context.Context, address string) error defer a.mu.Unlock() if len(a.req.ProxyPorts) == 0 { - a.syslog.Debug("No ports to proxy. Skipping proxy registration.") + a.syslog.Debug("no ports to proxy, skipping proxy registration.") return nil } a.model.ProxyAddress = &address @@ -481,7 +481,7 @@ func (a *allocation) validateRendezvous() error { } switch a.resources.first().Summary().ResourcesType { - case sproto.ResourcesTypeDockerContainer, sproto.ResourcesTypeK8sPod: + case sproto.ResourcesTypeDockerContainer, sproto.ResourcesTypeK8sJob: break default: return BehaviorUnsupportedError{Behavior: "rendezvous"} @@ -591,7 +591,12 @@ func (a *allocation) finalize( // heavy stuff unless it is necessarily (which also works to spread occurrences of the same work // out). Eventually, Allocations should just be started with their TaskSpec.
func (a *allocation) resourcesAllocated(msg *sproto.ResourcesAllocated) error { - a.syslog.WithField("restore", a.req.Restore).Infof("%d resources allocated", len(msg.Resources)) + syslog := a.syslog.WithField("restored", a.req.Restore) + if syslog.Level >= logrus.DebugLevel { + syslog = syslog.WithField("count", len(msg.Resources)) + } + syslog.Infof("resources allocated") + if !a.req.Restore { if a.getModelState() != model.AllocationStatePending { // If we have moved on from the pending state, these must be stale (and we must have @@ -724,7 +729,7 @@ func (a *allocation) resourcesStateChanged(msg *sproto.ResourcesStateChanged) { } a.resources[msg.ResourcesID].Container = msg.Container - a.syslog.Debugf("resources state changed: %+v", msg) + a.syslog.Debugf("resources state changed: %s", msg) switch msg.ResourcesState { case sproto.Pulling: a.setMostProgressedModelState(model.AllocationStatePulling) @@ -847,7 +852,7 @@ func (a *allocation) resourcesStateChanged(msg *sproto.ResourcesStateChanged) { } // restoreResourceFailure handles the restored resource failures. -func (a *allocation) restoreResourceFailure(msg *sproto.ResourcesRestoreError) { +func (a *allocation) restoreResourceFailure(msg *sproto.ResourcesFailedError) { a.syslog.Debugf("allocation resource failure") a.setMostProgressedModelState(model.AllocationStateTerminating) @@ -1028,7 +1033,7 @@ func (a *allocation) exitedWithoutErr() bool { func (a *allocation) SetExitStatus(exitReason string, exitErr error, statusCode *int32) { switch err := exitErr.(type) { - case sproto.ResourcesRestoreError: + case sproto.ResourcesFailedError: a.model.ExitErr = ptrs.Ptr(err.Error()) if err.ExitCode != nil { a.model.StatusCode = ptrs.Ptr(int32(*err.ExitCode)) @@ -1158,7 +1163,7 @@ func (a *allocation) calculateExitStatus(reason string) ( return fmt.Sprintf("allocation stopped early after %s", reason), true, logrus.InfoLevel, nil case a.exitErr != nil: switch err := a.exitErr.(type) { - case sproto.ResourcesRestoreError: + case sproto.ResourcesFailedError: switch err.FailureType { case sproto.ResourcesFailed, sproto.TaskError: if a.killedDaemonsGracefully { diff --git a/master/internal/task/allocation_intg_test.go b/master/internal/task/allocation_intg_test.go index fbb142f9e01..0d3ad3a1964 100644 --- a/master/internal/task/allocation_intg_test.go +++ b/master/internal/task/allocation_intg_test.go @@ -38,7 +38,7 @@ func (m mockTaskSpecifier) ToTaskSpec() (t tasks.TaskSpec) { func TestAllocation(t *testing.T) { cases := []struct { name string - err *sproto.ResourcesRestoreError + err *sproto.ResourcesFailedError acked bool exit *AllocationExited }{ @@ -55,13 +55,13 @@ func TestAllocation(t *testing.T) { { name: "container failed", acked: false, - err: &sproto.ResourcesRestoreError{FailureType: sproto.ResourcesFailed}, - exit: &AllocationExited{Err: sproto.ResourcesRestoreError{FailureType: sproto.ResourcesFailed}}, + err: &sproto.ResourcesFailedError{FailureType: sproto.ResourcesFailed}, + exit: &AllocationExited{Err: sproto.ResourcesFailedError{FailureType: sproto.ResourcesFailed}}, }, { name: "container failed, but acked preemption", acked: true, - err: &sproto.ResourcesRestoreError{FailureType: sproto.ResourcesFailed}, + err: &sproto.ResourcesFailedError{FailureType: sproto.ResourcesFailed}, exit: &AllocationExited{}, }, } diff --git a/master/internal/task/allocation_service_test.go b/master/internal/task/allocation_service_test.go index a94686e07b1..021d37642cd 100644 --- a/master/internal/task/allocation_service_test.go +++ 
b/master/internal/task/allocation_service_test.go @@ -42,7 +42,7 @@ func TestRestoreFailed(t *testing.T) { defer close() defer requireKilled(t, db, id, q, exitFuture) - q.Put(&sproto.ResourcesRestoreError{ + q.Put(&sproto.ResourcesFailedError{ FailureType: sproto.RestoreError, ErrMsg: "things weren't there", }) @@ -670,7 +670,7 @@ func requireAssignedMany( ResourcesID: rID, ResourcesState: sproto.Terminated, ResourcesStopped: &sproto.ResourcesStopped{ - Failure: &sproto.ResourcesRestoreError{ + Failure: &sproto.ResourcesFailedError{ FailureType: sproto.TaskError, ErrMsg: "exit code 137", ExitCode: ptrs.Ptr(sproto.ExitCode(137)), diff --git a/master/internal/webhooks/postgres_webhook.go b/master/internal/webhooks/postgres_webhook.go index e37646a454c..f955dbe9bf2 100644 --- a/master/internal/webhooks/postgres_webhook.go +++ b/master/internal/webhooks/postgres_webhook.go @@ -459,7 +459,7 @@ func generateSlackPayload( var wID int var w *model.Workspace config := conf.GetMasterConfig() - wName := activeConfig.Workspace() // TODO(!!!) this is incorrect on moves. + wName := activeConfig.Workspace() // TODO(ET-288): This is incorrect on moves. pName := activeConfig.Project() webUIBaseURL := config.Webhooks.BaseURL baseURLIsSet := webUIBaseURL != "" diff --git a/master/pkg/cproto/state.go b/master/pkg/cproto/state.go index e04f70c546a..23338442fca 100644 --- a/master/pkg/cproto/state.go +++ b/master/pkg/cproto/state.go @@ -18,7 +18,7 @@ func (s State) String() string { return string(s) } -// Before returns if our state comes before or is equal to another. Callers have an implicit +// Before returns true if our state comes before or is equal to another. Callers have an implicit // assumption that states always transition in order. func (s State) Before(other State) bool { ordering := []State{ diff --git a/master/pkg/tasks/task.go b/master/pkg/tasks/task.go index 29e9172e940..b70196b8121 100644 --- a/master/pkg/tasks/task.go +++ b/master/pkg/tasks/task.go @@ -115,6 +115,9 @@ type TaskSpec struct { Labels []string // Ports required by trial or commands and their respective base port values. UniqueExposedPortRequests map[string]int + + // For testing only. + DontShipLogs bool } // Clone deep copies a taskSpec. @@ -242,6 +245,10 @@ func (t TaskSpec) EnvVars() map[string]string { // LogShipperWrappedEntrypoint returns the configured Entrypoint wrapped with ship_logs.py. func (t *TaskSpec) LogShipperWrappedEntrypoint() []string { + if t.DontShipLogs { + return t.Entrypoint + } + // Prepend the entrypoint like: `ship-logs.sh ship_logs.py "$@"`. shipLogsShell := filepath.Join(RunDir, taskShipLogsShell) shipLogsPython := filepath.Join(RunDir, taskShipLogsPython) diff --git a/master/pkg/tasks/task_trial.go b/master/pkg/tasks/task_trial.go index 0c1a6fce917..fa7b45bdb9d 100644 --- a/master/pkg/tasks/task_trial.go +++ b/master/pkg/tasks/task_trial.go @@ -103,9 +103,10 @@ func (s TrialSpec) ToTaskSpec() TaskSpec { } res.Description = fmt.Sprintf( - "exp-%d-trial-%d", + "exp-%d-trial-%d-attempt-%d", s.ExperimentID, s.TrialID, + s.TrialRunID, ) res.Entrypoint = []string{"/run/determined/train/entrypoint.sh"} diff --git a/master/test/testutils/fixtures.go b/master/test/testutils/fixtures.go index 459670c18c0..c0550d83824 100644 --- a/master/test/testutils/fixtures.go +++ b/master/test/testutils/fixtures.go @@ -192,13 +192,25 @@ func DefaultMasterConfig() (*config.Config, error) { // DefaultElasticConfig returns the default elastic config. 
func DefaultElasticConfig() model.LoggingConfig { - port, err := strconv.Atoi(os.Getenv("DET_INTEGRATION_ES_PORT")) - if err != nil { - panic("elastic config had non-numeric port") + host := os.Getenv("DET_INTEGRATION_ES_HOST") + if host == "" { + host = "localhost" } + + var port int + if portStr := os.Getenv("DET_INTEGRATION_ES_PORT"); portStr != "" { + parsed, err := strconv.Atoi(portStr) + if err != nil { + panic(fmt.Errorf("elastic config had non-numeric port: %s", err)) + } + port = parsed + } else { + port = 9200 + } + return model.LoggingConfig{ ElasticLoggingConfig: &model.ElasticLoggingConfig{ - Host: os.Getenv("DET_INTEGRATION_ES_HOST"), + Host: host, Port: port, }, }