diff --git a/src/sagemaker_training/pytorch_xla.py b/src/sagemaker_training/pytorch_xla.py
index bc98070e..c37d86b0 100644
--- a/src/sagemaker_training/pytorch_xla.py
+++ b/src/sagemaker_training/pytorch_xla.py
@@ -102,7 +102,7 @@ def _create_command(self):
                 "Please use a python script as the entry-point"
             )
 
-    def __pytorch_xla_command(self):
+    def _pytorch_xla_command(self):
         return [
             self._python_command(),
             "-m",
@@ -111,7 +111,7 @@ def __pytorch_xla_command(self):
             str(self._num_gpus),
         ]
 
-    def __check_compatibility(self):
+    def _check_compatibility(self):
         try:
             import torch_xla  # pylint: disable=unused-import
         except ModuleNotFoundError as exception:
diff --git a/test/unit/test_pytorch_xla.py b/test/unit/test_pytorch_xla.py
index 849dee2e..9303410e 100644
--- a/test/unit/test_pytorch_xla.py
+++ b/test/unit/test_pytorch_xla.py
@@ -57,7 +57,7 @@ def num_gpus(instance_type):
 @pytest.mark.parametrize("instance_type", ["ml.p3.16xlarge", "ml.p3.2xlarge"])
 @pytest.mark.parametrize("cluster_size", [1, 4])
 class TestPyTorchXLARunner:
-    @patch.object(PyTorchXLARunner, "__check_compatibility")
+    @patch.object(PyTorchXLARunner, "_check_compatibility")
     def test_setup(self, *patches):
         for current_host in cluster:
             rank = cluster.index(current_host)
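
Note on why the rename matters: Python name-mangles double-underscore attributes defined inside a class body, so `__check_compatibility` is actually stored as `_PyTorchXLARunner__check_compatibility`. `patch.object(PyTorchXLARunner, "__check_compatibility")` therefore looks up an attribute that does not exist and raises AttributeError when the patch is started. Renaming to a single leading underscore keeps the methods private by convention while leaving them patchable under their literal names. The snippet below is a minimal, standalone sketch of this behavior; the class and method names are illustrative and not taken from the toolkit.

    from unittest.mock import patch


    class Runner:
        def __check(self):   # mangled: stored as Runner._Runner__check
            return "real"

        def _check(self):    # single underscore: stored under its literal name
            return "real"


    # Fails: Runner has no attribute literally named "__check".
    try:
        with patch.object(Runner, "__check", return_value="mocked"):
            pass
    except AttributeError as err:
        print(err)

    # Works: the single-underscore name is found as written.
    with patch.object(Runner, "_check", return_value="mocked"):
        print(Runner()._check())  # -> "mocked"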