Skip to content

Commit

Permalink
Fix CPU error always being raised (#1175)
Browse files: browse the repository at this point in the history
* Save state

* Revert to old behavior

* Fix failing test/update

* Remove duplicate test
  • Loading branch information
muellerzr authored and sgugger committed Mar 13, 2023
1 parent 1a63f7d commit 66065a5
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
17 changes: 7 additions & 10 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,6 @@ class PartialState:

def __init__(self, cpu: bool = False, **kwargs):
self.__dict__ = self._shared_state
# Raise an error if the user tries to reinitialize on a different device setup in the same launch
if self.initialized and (self._cpu != cpu):
raise AssertionError(
"The current device and desired device are not the same. If the `PartialState` was generated "
"before the `Accelerator` has been instantiated, ensure the `cpu` flag is the same for both. In this case, "
f"the `PartialState` has {self._cpu} and the desired device is {cpu}. Please use `cpu={self._cpu}`."
)
if not self.initialized:
self._cpu = cpu
self.backend = None
Expand Down Expand Up @@ -540,10 +533,12 @@ def __init__(
**kwargs,
):
self.__dict__ = self._shared_state
if PartialState._shared_state == {} or (cpu != PartialState._shared_state.get("_cpu", False)):
if parse_flag_from_env("ACCELERATE_USE_CPU"):
cpu = True
if PartialState._shared_state == {}:
PartialState(cpu, **kwargs)
self.__dict__.update(PartialState._shared_state)
self._check_initialized(mixed_precision)
self._check_initialized(mixed_precision, cpu)
if not self.initialized:
self.deepspeed_plugin = None
mixed_precision = (
Expand Down Expand Up @@ -599,10 +594,12 @@ def __repr__(self):
repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n"
return repr

def _check_initialized(self, mixed_precision=None):
def _check_initialized(self, mixed_precision=None, cpu=None):
"Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized"
if self.initialized:
err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`."
if cpu and self.device.type != "cpu":
raise ValueError(err.format(flag="cpu=True"))
if (
mixed_precision is not None
and mixed_precision != self._mixed_precision
Expand Down
9 changes: 8 additions & 1 deletion tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_accelerator_can_be_reinstantiated(self):
_ = Accelerator()
assert PartialState._shared_state["_cpu"] is False
assert PartialState._shared_state["device"].type == "cuda"
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
_ = Accelerator(cpu=True)

def test_prepared_objects_are_referenced(self):
Expand Down Expand Up @@ -226,3 +226,10 @@ def test_accelerator_bnb_multi_gpu(self):
# This should not work and get value error
with self.assertRaises(ValueError):
_ = accelerator.prepare(model)

@require_cuda
def test_accelerator_cpu_flag_prepare(self):
    """Smoke test: an `Accelerator` built with `cpu=True` can `prepare` an optimizer.

    The test makes no explicit assertions; it passes as long as neither the
    `Accelerator(cpu=True)` construction nor the `prepare` call raises.
    """
    linear = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(linear.parameters(), lr=0.01)
    cpu_accelerator = Accelerator(cpu=True)
    # The prepared optimizer returned here is intentionally discarded.
    cpu_accelerator.prepare(optimizer)

0 comments on commit 66065a5

Please sign in to comment.