[low-bit optim] Fix edge cases for FSDP2 integration (#1269)
gau-nernst authored Nov 26, 2024
1 parent 615fb0e commit b3493eb
Showing 2 changed files with 53 additions and 20 deletions.
31 changes: 27 additions & 4 deletions test/prototype/test_low_bit_optim.py
@@ -7,6 +7,7 @@
import torch
from packaging.version import Version
from torch import nn
+from torch.distributed._composable.fsdp import fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
@@ -381,9 +382,9 @@ def world_size(self) -> int:
return _FSDP_WORLD_SIZE

@pytest.mark.skipif(
-not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required."
+not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
)
-@skip_if_lt_x_gpu(2)
+@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
def test_fsdp2(self):
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
if torch.cuda.get_device_capability() >= (8, 9):
@@ -398,7 +399,6 @@ def _test_fsdp2(self, optim_cls):
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.utils._pytree as pytree
-from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor import DTensor
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
@@ -412,7 +412,7 @@ def _test_fsdp2(self, optim_cls):
model_args = ModelArgs(
n_layers=3,
n_heads=4,
-dim=1024,
+dim=512,
vocab_size=vocab_size,
max_seq_len=seq_len,
dropout_p=0,
@@ -491,6 +491,29 @@ def _test_fsdp2(self, optim_cls):
v2 = v2.dequantize()
self.assertEqual(v1, v2)

+    @pytest.mark.skipif(
+        not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
+    )
+    @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
+    def test_uneven_shard(self):
+        in_dim = 512
+        out_dim = _FSDP_WORLD_SIZE * 16 + 1
+
+        # 1st dim of linear weight will not be divisible by WORLD_SIZE
+        model = nn.Linear(in_dim, out_dim, device="cuda")
+        assert model.weight.shape[0] % _FSDP_WORLD_SIZE != 0
+        fully_shard(model)
+
+        # currently all of our low-bit Adam/AdamW share the same implementation.
+        # thus, we only need to test for 1 optimizer class.
+        optim = low_bit_optim.AdamW8bit(model.parameters())
+
+        for _ in range(2):
+            inputs = torch.randn(2, in_dim, device="cuda")
+            model(inputs).sum().backward()
+            optim.step()
+            optim.zero_grad()
+

instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)
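
For readers without a multi-GPU setup, the sketch below (not part of the commit) illustrates the edge case the new test_uneven_shard exercises: FSDP2 shards parameters along dim 0 with chunk-style splits, so with out_dim = _FSDP_WORLD_SIZE * 16 + 1 the ranks end up holding different numbers of rows. The world_size value of 2 is an assumption that mirrors _FSDP_WORLD_SIZE above.

# Single-process illustration only; mimics chunk-style dim-0 sharding on CPU.
import torch

world_size = 2                      # assumed, mirrors _FSDP_WORLD_SIZE
in_dim, out_dim = 512, world_size * 16 + 1
weight = torch.empty(out_dim, in_dim)

local_shards = torch.chunk(weight, world_size, dim=0)
print([tuple(s.shape) for s in local_shards])
# [(17, 512), (16, 512)] -> per-rank optimizer-state buffers cannot assume
# that every rank holds the same local shape.
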
42 changes: 26 additions & 16 deletions torchao/prototype/low_bit_optim/adam.py
@@ -55,23 +55,29 @@ def __setstate__(self, state):
    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
        raise NotImplementedError

-    # follow bitsandbytes, only quantize tensors >= 4096 values
-    # also wrap subclass in DTensor when needed
    def _new_buffer(self, p: Tensor, signed: bool):
-        if p.numel() >= 4096 and p.numel() % self.block_size == 0:
-            if isinstance(p, DTensor):
-                out = DTensor.from_local(
-                    local_tensor=self._subclass_zeros(
-                        p.to_local(), signed, self.block_size
-                    ),
-                    device_mesh=p.device_mesh,
-                    placements=p.placements,
-                    run_check=False,
-                )
-            else:
-                out = self._subclass_zeros(p, signed, self.block_size)
+        local_p = p.to_local() if isinstance(p, DTensor) else p
+
+        # follow bitsandbytes, only quantize tensors >= 4096 values
+        if local_p.numel() >= 4096 and local_p.numel() % self.block_size == 0:
+            out = self._subclass_zeros(local_p, signed, self.block_size)
        else:
-            out = torch.zeros_like(p)
+            out = torch.zeros_like(local_p)
+
+        # wrap subclass in DTensor as needed
+        # NOTE: local tensor may have different shapes across ranks.
+        # this happens when the 1st dim is not divisible by WORLD_SIZE.
+        # thus, we must supply shape (and stride) to DTensor.from_local()
+        if isinstance(p, DTensor):
+            out = DTensor.from_local(
+                local_tensor=out,
+                device_mesh=p.device_mesh,
+                placements=p.placements,
+                run_check=False,
+                shape=p.shape,
+                stride=p.stride(),
+            )
+
        return out

    @torch.no_grad()
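
A minimal sketch (not from this commit) of the DTensor.from_local() pattern used above, assuming a 2-rank CPU/gloo mesh launched with torchrun: with uneven dim-0 shards the ranks hold different local shapes, so the global shape and stride are supplied explicitly rather than letting run_check=False derive them from each rank's local tensor.

# Sketch; launch with: torchrun --nproc_per_node=2 uneven_from_local.py  (file name is hypothetical)
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (2,))
rank = dist.get_rank()

# Uneven dim-0 shards: rank 0 holds 17 rows, rank 1 holds 16 (global is 33 x 512).
local = torch.zeros(17 if rank == 0 else 16, 512)

dt = DTensor.from_local(
    local,
    device_mesh=mesh,
    placements=[Shard(0)],
    run_check=False,
    shape=torch.Size([33, 512]),  # supplied explicitly, as _new_buffer does above
    stride=(512, 1),
)
print(rank, dt.shape, dt.to_local().shape)
dist.destroy_process_group()
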
@@ -111,8 +117,12 @@ def step(self, closure=None):
"optim.param_groups[0]['lr'].fill_(new_lr)"
)

+# without calling p.detach(), torch.compile() will have issues with FSDP2 in some cases
+# https://github.com/pytorch/ao/issues/652#issuecomment-2285040894
+# thus, by calling p.detach(), DTensor won't have .grad anymore, which is ok since we
+# are passing grad separately anyway.
torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
-p,
+p.detach(),
grad,
state["step"],
state["exp_avg"],
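
To make the p.detach() change concrete, here is a small standalone illustration (not part of the commit): Tensor.detach() returns an alias of the same storage, so in-place updates through the detached handle still reach the original parameter, while the detached alias carries no .grad, matching the comment that the gradient is passed as a separate argument.

# Plain-tensor illustration; no FSDP2 or torch.compile involved.
import torch

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 0.5)

alias = p.detach()               # same storage, no autograd history
alias.add_(p.grad, alpha=-0.1)   # hand-rolled SGD-style in-place update

print(p)           # values are now 0.95: the update landed on the parameter
print(alias.grad)  # None: the detached alias has no .grad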
