Skip to content

Commit

Permalink
Input validation for num_nodes argument (#18598)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Sep 20, 2023
1 parent 3bfd7b2 commit 66f15cf
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def _check_config_and_set_final_flags(
self._parallel_devices = self._strategy_flag.parallel_devices

def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
if not isinstance(num_nodes, int) or num_nodes < 1:
raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.")

self._num_nodes_flag = num_nodes
self._devices_flag = devices

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,10 @@ def _check_config_and_set_final_flags(
self._accelerator_flag = "cuda"
self._parallel_devices = self._strategy_flag.parallel_devices

def _check_device_config_and_set_final_flags(
self,
devices: Union[List[int], str, int],
num_nodes: int,
) -> None:
def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
if not isinstance(num_nodes, int) or num_nodes < 1:
raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.")

self._num_nodes_flag = num_nodes
self._devices_flag = devices

Expand Down
7 changes: 7 additions & 0 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,13 @@ def test_strategy_choice_multi_node_gpu(_, strategy, strategy_class, devices):
assert isinstance(connector.strategy, strategy_class)


def test_num_nodes_input_validation():
    """Check that `_Connector` raises `ValueError` for a zero or negative `num_nodes`."""
    expected_message = "`num_nodes` must be a positive integer"
    # Zero first, then a negative count; each must be rejected with the same message.
    for invalid_num_nodes in (0, -1):
        with pytest.raises(ValueError, match=expected_message):
            _Connector(num_nodes=invalid_num_nodes)


@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0)
def test_cuda_accelerator_can_not_run_on_system(_):
connector = _Connector(accelerator="cpu")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,13 @@ def test_connector_sets_num_nodes(strategy, cuda_count_2):
assert trainer.strategy.num_nodes == 2


def test_connector_num_nodes_input_validation():
    """Check that `_AcceleratorConnector` raises `ValueError` for non-positive `num_nodes`."""
    error_pattern = "`num_nodes` must be a positive integer"
    # Exercise the boundary value (0) and then a negative value, in that order.
    for bad_value in (0, -1):
        with pytest.raises(ValueError, match=error_pattern):
            _AcceleratorConnector(num_nodes=bad_value)


@pytest.mark.parametrize(
("precision_str", "strategy_str", "expected_precision_cls"),
[
Expand Down

0 comments on commit 66f15cf

Please sign in to comment.