Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input validation for num_nodes argument #18598

Merged
merged 2 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def _check_config_and_set_final_flags(
self._parallel_devices = self._strategy_flag.parallel_devices

def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
if not isinstance(num_nodes, int) or num_nodes < 1:
raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.")

self._num_nodes_flag = num_nodes
self._devices_flag = devices

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,10 @@ def _check_config_and_set_final_flags(
self._accelerator_flag = "cuda"
self._parallel_devices = self._strategy_flag.parallel_devices

def _check_device_config_and_set_final_flags(
self,
devices: Union[List[int], str, int],
num_nodes: int,
) -> None:
def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str, int], num_nodes: int) -> None:
if not isinstance(num_nodes, int) or num_nodes < 1:
raise ValueError(f"`num_nodes` must be a positive integer, but got {num_nodes}.")

self._num_nodes_flag = num_nodes
self._devices_flag = devices

Expand Down
7 changes: 7 additions & 0 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,13 @@ def test_strategy_choice_multi_node_gpu(_, strategy, strategy_class, devices):
assert isinstance(connector.strategy, strategy_class)


def test_num_nodes_input_validation():
with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"):
_Connector(num_nodes=0)
with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"):
_Connector(num_nodes=-1)


@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=0)
def test_cuda_accelerator_can_not_run_on_system(_):
connector = _Connector(accelerator="cpu")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,13 @@ def test_connector_sets_num_nodes(strategy, cuda_count_2):
assert trainer.strategy.num_nodes == 2


def test_connector_num_nodes_input_validation():
with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"):
_AcceleratorConnector(num_nodes=0)
with pytest.raises(ValueError, match="`num_nodes` must be a positive integer"):
_AcceleratorConnector(num_nodes=-1)


@pytest.mark.parametrize(
("precision_str", "strategy_str", "expected_precision_cls"),
[
Expand Down