diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 87ccd6b105d..487868d1d9e 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -340,14 +340,15 @@ def get_accelerator_from_label_value(cls, value: str) -> str: """ canonical_gpu_names = [ 'A100-80GB', 'A100', 'A10G', 'H100', 'K80', 'M60', 'T4g', 'T4', - 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L4' + 'V100', 'A10', 'P4000', 'P100', 'P40', 'P4', 'L40', 'L4' ] for canonical_name in canonical_gpu_names: # A100-80G accelerator is A100-SXM-80GB or A100-PCIE-80GB if canonical_name == 'A100-80GB' and re.search( r'A100.*-80GB', value): return canonical_name - elif canonical_name in value: + # Use word boundary matching to prevent substring matches + elif re.search(rf'\b{re.escape(canonical_name)}\b', value): return canonical_name # If we didn't find a canonical name: diff --git a/tests/unit_tests/kubernetes/test_gpu_label_formatters.py b/tests/unit_tests/kubernetes/test_gpu_label_formatters.py new file mode 100644 index 00000000000..cd7337dc7a1 --- /dev/null +++ b/tests/unit_tests/kubernetes/test_gpu_label_formatters.py @@ -0,0 +1,22 @@ +"""Tests for GPU label formatting in Kubernetes integration. + +Tests verify correct GPU detection from Kubernetes labels. +""" +import pytest + +from sky.provision.kubernetes.utils import GFDLabelFormatter + + +def test_gfd_label_formatter(): + """Test word boundary regex matching in GFDLabelFormatter.""" + # Test various GPU name patterns + test_cases = [ + ('NVIDIA-L4-24GB', 'L4'), + ('NVIDIA-L40-48GB', 'L40'), + ('NVIDIA-L400', 'L400'), # Should not match L4 or L40 + ('NVIDIA-L4', 'L4'), + ('L40-GPU', 'L40'), + ] + for input_value, expected in test_cases: + result = GFDLabelFormatter.get_accelerator_from_label_value(input_value) + assert result == expected, f'Failed for {input_value}'