diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 7127d37f40..35508cc0c7 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -455,9 +455,9 @@ def tensor_hook( state_dict[fqn] = tensor else: state_dict[fqn] = None - # Convert the state dict to the requested precision - if isinstance(tensor, torch.Tensor): - state_dict[fqn] = tensor.to(dtype=self.dtype) + + if isinstance(state_dict[fqn], torch.Tensor): + state_dict[fqn] = state_dict[fqn].to(dtype=self.dtype) del tensor if dist.get_global_rank() != 0: state_dict = {}