From e88265859de335f96b034adc9abc195c3b9719d5 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Wed, 24 Jul 2024 19:27:48 -0700
Subject: [PATCH] Dtensor oom (#1395)

---
 llmfoundry/callbacks/hf_checkpointer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 7127d37f40..35508cc0c7 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -455,9 +455,9 @@ def tensor_hook(
                             state_dict[fqn] = tensor
                     else:
                         state_dict[fqn] = None
-                    # Convert the state dict to the requested precision
-                    if isinstance(tensor, torch.Tensor):
-                        state_dict[fqn] = tensor.to(dtype=self.dtype)
+
+                    if isinstance(state_dict[fqn], torch.Tensor):
+                        state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype)
                     del tensor
             if dist.get_global_rank() != 0:
                 state_dict = {}
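
Note on the fix (an editorial sketch, not part of the patch): before this change, the
precision cast checked the local variable `tensor`, which on every rank still held the
result of `DTensor.full_tensor()`. Non-zero ranks had just stored `None` into
`state_dict[fqn]`, but the cast then wrote the full, unsharded tensor back into the
state dict, so every rank retained a complete copy of each gathered weight and large
models ran out of memory. Casting `state_dict[fqn]` instead is a no-op on ranks where
the entry is `None`. Below is a minimal standalone sketch of the fixed logic; the
function name `cast_gathered_state_dict` and its scaffolding are illustrative
assumptions, not llm-foundry API.

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DTensor

    def cast_gathered_state_dict(state_dict, dtype, offload_to_cpu=True):
        """Gather DTensor shards, keep full tensors only on rank 0, then cast."""
        rank = dist.get_rank() if dist.is_initialized() else 0
        for fqn in list(state_dict.keys()):
            tensor = state_dict[fqn]
            if isinstance(tensor, DTensor):
                # full_tensor() is a collective: every rank must call it, and
                # every rank briefly holds the unsharded result in `tensor`.
                tensor = tensor.full_tensor()
                if rank == 0:
                    state_dict[fqn] = tensor.cpu() if offload_to_cpu else tensor
                else:
                    state_dict[fqn] = None  # drop the full copy on non-zero ranks

                # Cast the *stored* value, not the local `tensor`: casting
                # `tensor` would write the full tensor back on every rank,
                # which is exactly the OOM this patch fixes.
                if isinstance(state_dict[fqn], torch.Tensor):
                    state_dict[fqn] = state_dict[fqn].to(dtype=dtype)
                del tensor
        return state_dict if rank == 0 else {}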