Add test coverage for lion and lion8b checkpoint interop #679

Merged · 3 commits · Oct 31, 2023
tests/test_lion8b.py — 55 changes: 40 additions & 15 deletions
@@ -24,6 +24,7 @@
LocalOptimStateDictConfig = MagicMock()
ShardedOptimStateDictConfig = MagicMock()

from llmfoundry.optim import DecoupledLionW
from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit

warnings.filterwarnings('ignore')
@@ -406,8 +407,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore
@pytest.mark.parametrize('use_errors', [False, True])
@pytest.mark.parametrize('state_sharding',
[_FULL_STATE, _SHARDED_STATE, _LOCAL_STATE])
@pytest.mark.parametrize('save_as_lion8b, load_as_lion8b', [(False, True),
(True, False),
(True, True)])
def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
state_sharding: fsdp.StateDictType):
state_sharding: fsdp.StateDictType,
save_as_lion8b: bool, load_as_lion8b: bool):
device = 'cuda'
if torch.cuda.device_count() < 2:
pytest.skip(f'This test requires 2+ GPUs.')
@@ -419,6 +424,10 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
dist.init_process_group(backend='nccl')
assert dist.get_world_size() >= 2, 'Misconfigured test run!'

# nb: this is the line that causes:
# `Warning: Deallocating Tensor that still has live PyObject references.`
# suggesting this warning isn't an issue with our test code. It's also
# going to stdout (probably from cpp) so we can't suppress it with warnings
mod = FSDP(_DummyModule(device=device, dtype=dtype))

# actual forward pass instead of setting p.grad to avoid FSDP issues
@@ -429,7 +438,10 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool,
p.grad = torch.rand_like(p)

# create optimizer and have it step so that state gets populated
opt = Lion8bit(mod.parameters(), error_correction=use_errors)
if save_as_lion8b:
opt = Lion8bit(mod.parameters(), error_correction=use_errors)
else:
opt = DecoupledLionW(mod.parameters())
opt.step()
opt.zero_grad()

@@ -449,13 +461,22 @@ def _set_state_dict_type(model: nn.Module):
FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg,
optim_cfg)

def _local_shard(t: torch.Tensor) -> torch.Tensor:
try: # can't operate on ShardedTensors directly
return t.local_tensor() # type: ignore
except AttributeError:
return t

# load FSDP state dict
_set_state_dict_type(mod)
opt_state_dict = FSDP.optim_state_dict(mod, opt)

# make a new model and optimizer
mod_new = FSDP(_DummyModule(device=device, dtype=dtype))
opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors)
if load_as_lion8b:
opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors)
else:
opt_new = DecoupledLionW(mod_new.parameters())
_set_state_dict_type(mod_new)

# load state dict into the new optimizer
@@ -480,22 +501,26 @@ def _set_state_dict_type(model: nn.Module):
mom_new = d_new['exp_avg']

assert mom_orig.shape == mom_new.shape
assert mom_orig.dtype == mom_new.dtype
if use_errors and (dtype != torch.float32):
errs_orig = d_orig['errors']
errs_new = d_new['errors']
assert errs_orig.shape == errs_new.shape
assert errs_orig.dtype == errs_new.dtype

if state_sharding != _FULL_STATE:
continue # more detailed checks lean on FSDP impl details
both_lion8b = save_as_lion8b and load_as_lion8b
check_errors = both_lion8b and use_errors and (dtype != torch.float32)
if both_lion8b:
assert mom_orig.dtype == mom_new.dtype
if check_errors:
errs_orig = d_orig['errors']
errs_new = d_new['errors']
assert errs_orig.shape == errs_new.shape
assert errs_orig.dtype == errs_new.dtype

# momentums may not be bit-for-bit identical because Optimizer upcasts
# to f32 and we convert back to bf16, possibly with different rounding
torch.testing.assert_close(mom_orig, mom_new)
torch.testing.assert_close(_local_shard(mom_orig).float(),
_local_shard(mom_new).float(),
atol=1e-4,
rtol=1. / 128)
# errors not bit-for-bit identical because scales get upcast too
if use_errors and (dtype != torch.float32):
torch.testing.assert_close(d_orig['errors'], d_new['errors'])
if check_errors:
torch.testing.assert_close(_local_shard(d_orig['errors']),
_local_shard(d_new['errors']))


@pytest.mark.gpu
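
For reference, here is a minimal non-FSDP sketch of the interop path the new save_as_lion8b / load_as_lion8b parametrization exercises: populate optimizer state with DecoupledLionW, then load that checkpoint into Lion8bit (the reverse direction is symmetric). The tiny Linear module and the use of plain state_dict() / load_state_dict() are illustrative assumptions only; the test itself routes state through FSDP.optim_state_dict and FSDP.optim_state_dict_to_load.

import torch

from llmfoundry.optim import DecoupledLionW
from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit

# small stand-in module; the test wraps a _DummyModule in FSDP instead
model = torch.nn.Linear(16, 16, device='cuda')

# populate exp_avg state with the fp32 Lion optimizer, then checkpoint it
opt_save = DecoupledLionW(model.parameters())
model(torch.randn(4, 16, device='cuda')).sum().backward()
opt_save.step()
opt_save.zero_grad()
checkpoint = opt_save.state_dict()

# load the same checkpoint into the 8-bit optimizer and keep training;
# momenta may be requantized, so comparisons need a tolerance (as above)
opt_load = Lion8bit(model.parameters())
opt_load.load_state_dict(checkpoint)
model(torch.randn(4, 16, device='cuda')).sum().backward()
opt_load.step()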