Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct loading of models with shared tensors when using accelerator.load_state() #2875

Merged
merged 5 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np
import torch
from safetensors.torch import load_file
from safetensors.torch import load_model
from torch.cuda.amp import GradScaler

from .utils import (
Expand Down Expand Up @@ -196,12 +196,12 @@ def load_accelerator_state(
ending = f"_{i}" if i > 0 else ""
input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors")
if input_model_file.exists():
state_dict = load_file(input_model_file, device=str(map_location))
load_model(model, input_model_file, device=str(map_location), **load_model_func_kwargs)
else:
# Load with torch
input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin")
state_dict = torch.load(input_model_file, map_location=map_location)
models[i].load_state_dict(state_dict, **load_model_func_kwargs)
model.load_state_dict(state_dict, **load_model_func_kwargs)
logger.info("All model weights loaded successfully")
Comment on lines -204 to 205
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason for this change? I'd expect only the prior to be modified.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the if statement, he's loading the safetensors model directly whereas before, we were only getting the state dict.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load_model does both: it loads the file and uses it to populate the state_dict. Previously, each branch of the if-condition only loaded the file and after the if-condition, the model would load the state dict. Since load_model does both, I indented the statement on line 204 to become part of the else-clause. This becomes clearer when you have a look at the complete surroundings of the changes instead of only the affected lines.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other change in this line (aside from the indent) namely using model instead of models[i] is mostly cosmetic. My linter was complaining that the enumerate call defines model but it's never used.


# Optimizer states
Expand Down
30 changes: 23 additions & 7 deletions tests/test_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,20 @@
from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model


def create_components():
model = torch.nn.Linear(2, 4)
class ModelWithTiedWeights(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(2, 4)
self.linear2 = torch.nn.Linear(4, 2)
self.linear2.weight = self.linear1.weight
self.linear2.bias = self.linear1.bias

def forward(self, x):
return self.linear2(self.linear1(x))


def create_components(tied_weights=False):
model = ModelWithTiedWeights() if tied_weights else torch.nn.Linear(2, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)
train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3])))
Expand All @@ -54,18 +66,22 @@ def forward(self, x):


def get_signature(model):
return (model.weight.abs().sum() + model.bias.abs().sum()).item()
return sum(param.abs().sum().item() for param in model.parameters())


def load_random_weights(model):
state = torch.nn.Linear(*tuple(model.weight.T.shape)).state_dict()
if isinstance(model, torch.nn.Linear):
state = torch.nn.Linear(*tuple(model.weight.T.shape)).state_dict()
elif isinstance(model, ModelWithTiedWeights):
state = ModelWithTiedWeights().state_dict()
model.load_state_dict(state)


def parameterized_custom_name_func(func, param_num, param):
# customize the test name generator function as we want both params to appear in the sub-test
# name, as by default it shows only the first param
param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch"
param_based_name += "_tied_weights" if (len(param.args) == 2 and param.args[1] is True) else ""
return f"{func.__name__}_{param_based_name}"


Expand Down Expand Up @@ -230,10 +246,10 @@ def noop(*args, **kwargs):
accelerator = Accelerator()
assert str(accelerator.state.device) == "cuda:64"

@parameterized.expand((True, False), name_func=parameterized_custom_name_func)
def test_save_load_model(self, use_safetensors):
@parameterized.expand([(True, True), (True, False), (False, False)], name_func=parameterized_custom_name_func)
def test_save_load_model(self, use_safetensors, tied_weights):
accelerator = Accelerator()
model, optimizer, scheduler, train_dl, valid_dl = create_components()
model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)
accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)

model_signature = get_signature(model)
Expand Down
Loading