Handle edge case in Fabric.setup() when model has no parameters (#17441)

(cherry picked from commit 0631fa0)
awaelchli authored and lantiga committed Apr 24, 2023
1 parent 8c9cf00 commit 6e344b8
Showing 3 changed files with 11 additions and 3 deletions.
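For context, the edge case this commit handles is easy to reproduce: `nn.Sequential()` has no parameters, so calling `next()` on its parameter iterator raises `StopIteration`, which is exactly what the pre-fix device lookup in `Fabric.setup()` ran into. A minimal standalone sketch (not part of the commit):

```python
from torch import nn

empty = nn.Sequential()  # a module with no parameters

try:
    # The pre-fix lookup pattern: read the device of the first parameter.
    device = next(empty.parameters()).device
except StopIteration:
    print("module has no parameters; nothing to inspect")  # this branch runs
```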
src/lightning/fabric/CHANGELOG.md (3 additions, 0 deletions)
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `LightningModule.*_step` methods bypassing the DDP/FSDP wrapper ([#17424](https://github.com/Lightning-AI/lightning/pull/17424))


+ - Fixed device handling in `Fabric.setup()` when the model has no parameters ([#17441](https://github.com/Lightning-AI/lightning/pull/17441))


## [2.0.1] - 2023-03-30

### Changed
src/lightning/fabric/fabric.py (3 additions, 3 deletions)
@@ -204,7 +204,7 @@ def setup(
module = _FabricModule(module, self._precision, original_module=original_module)

# Update the _DeviceDtypeModuleMixin's device parameter
- module.to(self.device if move_to_device else next(module.parameters()).device)
+ module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)

optimizers = [_FabricOptimizer(optimizer=optimizer, strategy=self._strategy) for optimizer in optimizers]

@@ -248,7 +248,7 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True) -> _Fabri

if not isinstance(self._strategy, FSDPStrategy):
# Update the _DeviceDtypeModuleMixin's device parameter
- module.to(self.device if move_to_device else next(module.parameters()).device)
+ module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)

if hasattr(original_module, "_fabric"): # this is probably a LightningModule
original_module._fabric = self # type: ignore[assignment]
@@ -741,7 +741,7 @@ def _run_with_setup(self, run_function: Callable, *args: Any, **kwargs: Any) ->
return run_function(*args, **kwargs)

def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
- initial_device = next(model.parameters()).device
+ initial_device = next(model.parameters(), torch.tensor(0)).device
if any(param.device != initial_device for param in model.parameters()):
rank_zero_warn(
"The model passed to `Fabric.setup()` has parameters on different devices. Since `move_to_device=True`,"
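The fix uses the two-argument form of Python's built-in `next()`, which returns a default value instead of raising `StopIteration` when the iterator is empty. With `torch.tensor(0)` as the default, the device lookup falls back to CPU for parameterless modules. A minimal sketch of the pattern outside of Fabric; the helper name `first_param_device` is made up for illustration:

```python
import torch
from torch import nn


def first_param_device(model: nn.Module) -> torch.device:
    # Fall back to a throwaway CPU tensor when the model has no parameters.
    return next(model.parameters(), torch.tensor(0)).device


print(first_param_device(nn.Linear(2, 2)))  # device of the layer's weight (cpu here)
print(first_param_device(nn.Sequential()))  # cpu, and no StopIteration is raised
```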
tests/tests_fabric/test_fabric.py (5 additions, 0 deletions)
@@ -137,6 +137,11 @@ def test_setup_module_move_to_device(setup_method, move_to_device, accelerator,
assert fabric_model.device == expected_device
assert fabric.device == target_device

+ # edge case: model has no parameters
+ model = nn.Sequential()
+ fabric_model = setup_method(model, move_to_device=move_to_device)
+ assert fabric_model.device == target_device if move_to_device else torch.device("cpu")


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize("move_to_device", [True, False])
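With this change in place, `Fabric.setup()` accepts a module without parameters, and the wrapped model reports the target device (or stays on CPU when `move_to_device=False`), mirroring the test above. A rough usage sketch, assuming a lightning release that includes this fix:

```python
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

model = nn.Sequential()             # edge case: no parameters
fabric_model = fabric.setup(model)  # previously hit StopIteration in the device lookup

print(fabric_model.device)          # device(type='cpu')
```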
