diff --git a/Makefile b/Makefile index 90f20601e9..4d2f37f715 100644 --- a/Makefile +++ b/Makefile @@ -84,7 +84,7 @@ radical_local_test: .PHONY: config_local_test config_local_test: $(CCTOOLS_INSTALL) - pip3 install ".[monitoring,visualization,proxystore]" + pip3 install ".[monitoring,visualization,proxystore,kubernetes]" PYTHONPATH=/tmp/cctools/lib/python3.8/site-packages pytest parsl/tests/ -k "not cleannet" --config local --random-order --durations 10 .PHONY: site_test diff --git a/parsl/providers/kubernetes/kube.py b/parsl/providers/kubernetes/kube.py index 40b5b430a5..8c3b081ad0 100644 --- a/parsl/providers/kubernetes/kube.py +++ b/parsl/providers/kubernetes/kube.py @@ -1,10 +1,5 @@ import logging -import time - -from parsl.providers.kubernetes.template import template_string - -logger = logging.getLogger(__name__) - +import uuid from typing import Any, Dict, List, Optional, Tuple import typeguard @@ -12,7 +7,8 @@ from parsl.errors import OptionalModuleMissing from parsl.jobs.states import JobState, JobStatus from parsl.providers.base import ExecutionProvider -from parsl.utils import RepresentationMixin +from parsl.providers.kubernetes.template import template_string +from parsl.utils import RepresentationMixin, sanitize_dns_subdomain_rfc1123 try: from kubernetes import client, config @@ -20,6 +16,8 @@ except (ImportError, NameError, FileNotFoundError): _kubernetes_enabled = False +logger = logging.getLogger(__name__) + translate_table = { 'Running': JobState.RUNNING, 'Pending': JobState.PENDING, @@ -161,7 +159,7 @@ def __init__(self, self.resources: Dict[object, Dict[str, Any]] self.resources = {} - def submit(self, cmd_string, tasks_per_node, job_name="parsl"): + def submit(self, cmd_string: str, tasks_per_node: int, job_name: str = "parsl.kube"): """ Submit a job Args: - cmd_string :(String) - Name of the container to initiate @@ -173,15 +171,19 @@ def submit(self, cmd_string, tasks_per_node, job_name="parsl"): Returns: - job_id: (string) Identifier for the 
job """ + job_id = uuid.uuid4().hex[:8] - cur_timestamp = str(time.time() * 1000).split(".")[0] - job_name = "{0}-{1}".format(job_name, cur_timestamp) - - if not self.pod_name: - pod_name = '{}'.format(job_name) - else: - pod_name = '{}-{}'.format(self.pod_name, - cur_timestamp) + pod_name = self.pod_name or job_name + try: + pod_name = sanitize_dns_subdomain_rfc1123(pod_name) + except ValueError: + logger.warning( + f"Invalid pod name '{pod_name}' for job '{job_id}', falling back to 'parsl.kube'" + ) + pod_name = "parsl.kube" + pod_name = pod_name[:253 - 1 - len(job_id)] # Leave room for the job ID + pod_name = pod_name.rstrip(".-") # Remove trailing dot or hyphen after trim + pod_name = f"{pod_name}.{job_id}" formatted_cmd = template_string.format(command=cmd_string, worker_init=self.worker_init) @@ -189,14 +191,14 @@ def submit(self, cmd_string, tasks_per_node, job_name="parsl"): logger.debug("Pod name: %s", pod_name) self._create_pod(image=self.image, pod_name=pod_name, - job_name=job_name, + job_id=job_id, cmd_string=formatted_cmd, volumes=self.persistent_volumes, service_account_name=self.service_account_name, annotations=self.annotations) - self.resources[pod_name] = {'status': JobStatus(JobState.RUNNING)} + self.resources[job_id] = {'status': JobStatus(JobState.RUNNING), 'pod_name': pod_name} - return pod_name + return job_id def status(self, job_ids): """ Get the status of a list of jobs identified by the job identifiers @@ -212,6 +214,9 @@ def status(self, job_ids): self._status() return [self.resources[jid]['status'] for jid in job_ids] + def _get_pod_name(self, job_id: str) -> str: + return self.resources[job_id]['pod_name'] + def cancel(self, job_ids): """ Cancels the jobs specified by a list of job ids Args: @@ -221,7 +226,8 @@ def cancel(self, job_ids): """ for job in job_ids: logger.debug("Terminating job/pod: {0}".format(job)) - self._delete_pod(job) + pod_name = self._get_pod_name(job) + self._delete_pod(pod_name) self.resources[job]['status'] = 
JobStatus(JobState.CANCELLED) rets = [True for i in job_ids] @@ -242,7 +248,8 @@ def _status(self): for jid in to_poll_job_ids: phase = None try: - pod = self.kube_client.read_namespaced_pod(name=jid, namespace=self.namespace) + pod_name = self._get_pod_name(jid) + pod = self.kube_client.read_namespaced_pod(name=pod_name, namespace=self.namespace) except Exception: logger.exception("Failed to poll pod {} status, most likely because pod was terminated".format(jid)) if self.resources[jid]['status'] is JobStatus(JobState.RUNNING): @@ -257,10 +264,10 @@ def _status(self): self.resources[jid]['status'] = JobStatus(status) def _create_pod(self, - image, - pod_name, - job_name, - port=80, + image: str, + pod_name: str, + job_id: str, + port: int = 80, cmd_string=None, volumes=[], service_account_name=None, @@ -269,7 +276,7 @@ def _create_pod(self, Args: - image (string) : Docker image to launch - pod_name (string) : Name of the pod - - job_name (string) : App label + - job_id (string) : Job ID KWargs: - port (integer) : Container port Returns: @@ -299,7 +306,7 @@ def _create_pod(self, ) # Configure Pod template container container = client.V1Container( - name=pod_name, + name=job_id, image=image, resources=resources, ports=[client.V1ContainerPort(container_port=port)], @@ -322,7 +329,7 @@ def _create_pod(self, claim_name=volume[0]))) metadata = client.V1ObjectMeta(name=pod_name, - labels={"app": job_name}, + labels={"parsl-job-id": job_id}, annotations=annotations) spec = client.V1PodSpec(containers=[container], image_pull_secrets=[secret], diff --git a/parsl/tests/test_providers/test_kubernetes_provider.py b/parsl/tests/test_providers/test_kubernetes_provider.py new file mode 100644 index 0000000000..453dc57422 --- /dev/null +++ b/parsl/tests/test_providers/test_kubernetes_provider.py @@ -0,0 +1,102 @@ +import re +from unittest import mock + +import pytest + +from parsl.providers.kubernetes.kube import KubernetesProvider +from parsl.tests.test_utils.test_sanitize_dns 
import DNS_SUBDOMAIN_REGEX + +_MOCK_BASE = "parsl.providers.kubernetes.kube" + + +@pytest.fixture(autouse=True) +def mock_kube_config(): + with mock.patch(f"{_MOCK_BASE}.config") as mock_config: + mock_config.load_kube_config.return_value = None + yield mock_config + + +@pytest.fixture +def mock_kube_client(): + mock_client = mock.MagicMock() + with mock.patch(f"{_MOCK_BASE}.client.CoreV1Api") as mock_api: + mock_api.return_value = mock_client + yield mock_client + + +@pytest.mark.local +def test_submit_happy_path(mock_kube_client: mock.MagicMock): + image = "test-image" + namespace = "test-namespace" + cmd_string = "test-command" + volumes = [("test-volume", "test-mount-path")] + service_account_name = "test-service-account" + annotations = {"test-annotation": "test-value"} + max_cpu = 2 + max_mem = "2Gi" + init_cpu = 1 + init_mem = "1Gi" + provider = KubernetesProvider( + image=image, + persistent_volumes=volumes, + namespace=namespace, + service_account_name=service_account_name, + annotations=annotations, + max_cpu=max_cpu, + max_mem=max_mem, + init_cpu=init_cpu, + init_mem=init_mem, + ) + + job_name = "test.job.name" + job_id = provider.submit(cmd_string=cmd_string, tasks_per_node=1, job_name=job_name) + + assert job_id in provider.resources + assert mock_kube_client.create_namespaced_pod.call_count == 1 + + call_args = mock_kube_client.create_namespaced_pod.call_args[1] + pod = call_args["body"] + container = pod.spec.containers[0] + volume = container.volume_mounts[0] + + assert image == container.image + assert namespace == call_args["namespace"] + assert any(cmd_string in arg for arg in container.args) + assert volumes[0] == (volume.name, volume.mount_path) + assert service_account_name == pod.spec.service_account_name + assert annotations == pod.metadata.annotations + assert str(max_cpu) == container.resources.limits["cpu"] + assert max_mem == container.resources.limits["memory"] + assert str(init_cpu) == container.resources.requests["cpu"] + assert 
init_mem == container.resources.requests["memory"] + assert job_id == pod.metadata.labels["parsl-job-id"] + assert job_id == container.name + assert f"{job_name}.{job_id}" == pod.metadata.name + + +@pytest.mark.local +@mock.patch(f"{_MOCK_BASE}.KubernetesProvider._create_pod") +@pytest.mark.parametrize("char", (".", "-")) +def test_submit_pod_name_includes_job_id(mock_create_pod: mock.MagicMock, char: str): + provider = KubernetesProvider(image="test-image") + + job_name = "a." * 121 + f"a{char}" + "a" * 9 + assert len(job_name) == 253 # Max length for pod name + job_id = provider.submit(cmd_string="test-command", tasks_per_node=1, job_name=job_name) + + expected_pod_name = job_name[:253 - len(job_id) - 2] + f".{job_id}" + actual_pod_name = mock_create_pod.call_args[1]["pod_name"] + assert re.match(DNS_SUBDOMAIN_REGEX, actual_pod_name) + assert expected_pod_name == actual_pod_name + + +@pytest.mark.local +@mock.patch(f"{_MOCK_BASE}.KubernetesProvider._create_pod") +@mock.patch(f"{_MOCK_BASE}.logger") +@pytest.mark.parametrize("job_name", ("", ".", "-", "a.-.a", "$$$")) +def test_submit_invalid_job_name(mock_logger: mock.MagicMock, mock_create_pod: mock.MagicMock, job_name: str): + provider = KubernetesProvider(image="test-image") + job_id = provider.submit(cmd_string="test-command", tasks_per_node=1, job_name=job_name) + assert mock_logger.warning.call_count == 1 + assert f"Invalid pod name '{job_name}' for job '{job_id}'" in mock_logger.warning.call_args[0][0] + assert f"parsl.kube.{job_id}" == mock_create_pod.call_args[1]["pod_name"] diff --git a/parsl/tests/test_utils/test_sanitize_dns.py b/parsl/tests/test_utils/test_sanitize_dns.py new file mode 100644 index 0000000000..17b801339c --- /dev/null +++ b/parsl/tests/test_utils/test_sanitize_dns.py @@ -0,0 +1,76 @@ +import random +import re + +import pytest + +from parsl.utils import sanitize_dns_label_rfc1123, sanitize_dns_subdomain_rfc1123 + +# Ref: https://datatracker.ietf.org/doc/html/rfc1123 +DNS_LABEL_REGEX = 
r'^[a-z0-9]([-a-z0-9]{0,61}[a-z0-9])?$' +DNS_SUBDOMAIN_REGEX = r'^[a-z0-9]([-a-z0-9]{0,61}[a-z0-9])?(\.[a-z0-9]([-a-z0-9]{0,61}[a-z0-9])?)*$' + +test_labels = [ + "example-label-123", # Valid label + "EXAMPLE", # Case sensitivity + "!@#example*", # Remove invalid characters + "--leading-and-trailing--", # Leading and trailing hyphens + "..leading.and.trailing..", # Leading and trailing dots + "multiple..dots", # Consecutive dots + "valid--label", # Consecutive hyphens + "a" * random.randint(64, 70), # Longer than 63 characters + f"{'a' * 62}-a", # Trailing hyphen at max length +] + + +def _generate_test_subdomains(num_subdomains: int): + subdomains = [] + for _ in range(num_subdomains): + num_labels = random.randint(1, 5) + labels = [test_labels[random.randint(0, num_labels - 1)] for _ in range(num_labels)] + subdomain = ".".join(labels) + subdomains.append(subdomain) + return subdomains + + +@pytest.mark.local +@pytest.mark.parametrize("raw_string", test_labels) +def test_sanitize_dns_label_rfc1123(raw_string: str): + print(sanitize_dns_label_rfc1123(raw_string)) + assert re.match(DNS_LABEL_REGEX, sanitize_dns_label_rfc1123(raw_string)) + + +@pytest.mark.local +@pytest.mark.parametrize("raw_string", ("", "-", "@", "$$$")) +def test_sanitize_dns_label_rfc1123_empty(raw_string: str): + with pytest.raises(ValueError) as e_info: + sanitize_dns_label_rfc1123(raw_string) + assert str(e_info.value) == f"Sanitized DNS label is empty for input '{raw_string}'" + + +@pytest.mark.local +@pytest.mark.parametrize("raw_string", _generate_test_subdomains(10)) +def test_sanitize_dns_subdomain_rfc1123(raw_string: str): + assert re.match(DNS_SUBDOMAIN_REGEX, sanitize_dns_subdomain_rfc1123(raw_string)) + + +@pytest.mark.local +@pytest.mark.parametrize("char", ("-", ".")) +def test_sanitize_dns_subdomain_rfc1123_trailing_non_alphanumeric_at_max_length(char: str): + raw_string = (f"{'a' * 61}." 
* 4) + f".aaaa{char}a" + assert re.match(DNS_SUBDOMAIN_REGEX, sanitize_dns_subdomain_rfc1123(raw_string)) + + +@pytest.mark.local +@pytest.mark.parametrize("raw_string", ("", ".", "...")) +def test_sanitize_dns_subdomain_rfc1123_empty(raw_string: str): + with pytest.raises(ValueError) as e_info: + sanitize_dns_subdomain_rfc1123(raw_string) + assert str(e_info.value) == f"Sanitized DNS subdomain is empty for input '{raw_string}'" + + +@pytest.mark.local +@pytest.mark.parametrize( + "raw_string", ("a" * 253, "a" * random.randint(254, 300)), ids=("254 chars", ">253 chars") +) +def test_sanitize_dns_subdomain_rfc1123_max_length(raw_string: str): + assert len(sanitize_dns_subdomain_rfc1123(raw_string)) <= 253 diff --git a/parsl/utils.py b/parsl/utils.py index 6f36d4506a..0ea5d7d9eb 100644 --- a/parsl/utils.py +++ b/parsl/utils.py @@ -1,6 +1,7 @@ import inspect import logging import os +import re import shlex import subprocess import threading @@ -380,3 +381,80 @@ def __exit__( exc_tb: Optional[TracebackType] ) -> None: self.cancel() + + +def sanitize_dns_label_rfc1123(raw_string: str) -> str: + """Convert input string to a valid RFC 1123 DNS label. + + Parameters + ---------- + raw_string : str + String to sanitize. + + Returns + ------- + str + Sanitized string. + + Raises + ------ + ValueError + If the string is empty after sanitization. + """ + # Convert to lowercase and replace non-alphanumeric characters with hyphen + sanitized = re.sub(r'[^a-z0-9]', '-', raw_string.lower()) + + # Remove consecutive hyphens + sanitized = re.sub(r'-+', '-', sanitized) + + # DNS label cannot exceed 63 characters + sanitized = sanitized[:63] + + # Strip after trimming to avoid trailing hyphens + sanitized = sanitized.strip("-") + + if not sanitized: + raise ValueError(f"Sanitized DNS label is empty for input '{raw_string}'") + + return sanitized + + +def sanitize_dns_subdomain_rfc1123(raw_string: str) -> str: + """Convert input string to a valid RFC 1123 DNS subdomain. 
+ + Parameters + ---------- + raw_string : str + String to sanitize. + + Returns + ------- + str + Sanitized string. + + Raises + ------ + ValueError + If the string is empty after sanitization. + """ + segments = raw_string.split('.') + + sanitized_segments = [] + for segment in segments: + if not segment: + continue + sanitized_segment = sanitize_dns_label_rfc1123(segment) + sanitized_segments.append(sanitized_segment) + + sanitized = '.'.join(sanitized_segments) + + # DNS subdomain cannot exceed 253 characters + sanitized = sanitized[:253] + + # Strip after trimming to avoid trailing dots or hyphens + sanitized = sanitized.strip(".-") + + if not sanitized: + raise ValueError(f"Sanitized DNS subdomain is empty for input '{raw_string}'") + + return sanitized