From 5f5ba833da19c464d85c6917a960d41c4ba8002f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 20 Sep 2023 15:07:22 +0200 Subject: [PATCH 1/2] input validation for num_nodes --- src/lightning/fabric/connector.py | 3 +++ .../pytorch/trainer/connectors/accelerator_connector.py | 9 ++++----- tests/tests_fabric/test_connector.py | 7 +++++++ .../trainer/connectors/test_accelerator_connector.py | 7 +++++++ 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 1723b341be4f1..f71149354513e 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -293,6 +293,9 @@ def _check_config_and_set_final_flags( self._parallel_devices = self._strategy_flag.parallel_devices def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: + if not isinstance(num_nodes, int) or num_nodes < 1: + raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") + self._num_nodes_flag = num_nodes self._devices_flag = devices diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 300fcd3c5589b..fb44ab8d9040d 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -317,11 +317,10 @@ def _check_config_and_set_final_flags( self._accelerator_flag = "cuda" self._parallel_devices = self._strategy_flag.parallel_devices - def _check_device_config_and_set_final_flags( - self, - devices: Union[List[int], str, int], - num_nodes: int, - ) -> None: + def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None: + if not isinstance(num_nodes, int) or num_nodes < 1: + raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.") + self._num_nodes_flag = num_nodes 
self._devices_flag = devices diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 73b78ca86e28a..56fac505b4654 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -314,6 +314,13 @@ def test_strategy_choice_multi_node_gpu(_, strategy, strategy_class, devices): assert isinstance(connector.strategy, strategy_class) +def test_num_nodes_input_validation(): + with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"): + _Connector(num_nodes=0) + with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"): + _Connector(num_nodes="string") + + @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0) def test_cuda_accelerator_can_not_run_on_system(_): connector = _Connector(accelerator="cpu") diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index ea047040c6bcd..a7a6e928722ef 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -1013,6 +1013,13 @@ def test_connector_sets_num_nodes(strategy, cuda_count_2): assert trainer.strategy.num_nodes == 2 +def test_connector_num_nodes_input_validation(): + with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"): + _AcceleratorConnector(num_nodes=0) + with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"): + _AcceleratorConnector(num_nodes=-1) + + @pytest.mark.parametrize( ("precision_str", "strategy_str", "expected_precision_cls"), [ From 7fb24493bfcec9428719dc9f138eac92819f64d6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 20 Sep 2023 15:13:39 +0200 Subject: [PATCH 2/2] test negative num_nodes instead of a string in the Fabric connector validation test --- tests/tests_fabric/test_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/test_connector.py 
b/tests/tests_fabric/test_connector.py index 56fac505b4654..c2f35362deabf 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -318,7 +318,7 @@ def test_num_nodes_input_validation(): with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"): _Connector(num_nodes=0) with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"): - _Connector(num_nodes="string") + _Connector(num_nodes=-1) @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0)