Fix integration tests
lantiga committed Nov 26, 2024
1 parent 4eaac17 commit a8310cf
Showing 3 changed files with 37 additions and 21 deletions.
1 change: 1 addition & 0 deletions tests/tests_fabric/conftest.py
@@ -69,6 +69,7 @@ def restore_env_variables():
"OMP_NUM_THREADS", # set by our launchers
# set by torchdynamo
"TRITON_CACHE_DIR",
"TORCHINDUCTOR_CACHE_DIR",
}
leaked_vars.difference_update(allowlist)
assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"
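
For context: the hunk above is the tail of the test suite's `restore_env_variables` autouse fixture, which fails any test that leaves new environment variables behind. The diff groups `TORCHINDUCTOR_CACHE_DIR` with `TRITON_CACHE_DIR` under "set by torchdynamo", i.e. compiled runs may export it, so it joins the allowlist. A minimal sketch of the pattern, with assumed details rather than Lightning's exact implementation:

import os

import pytest


@pytest.fixture(autouse=True)
def restore_env_variables():
    """Sketch: snapshot os.environ before the test and flag anything the test leaves behind."""
    env_before = dict(os.environ)
    yield
    leaked_vars = os.environ.keys() - env_before.keys()
    # Restore the original values of any pre-existing variables the test modified.
    os.environ.update(env_before)
    # Variables that launchers or the torchdynamo/inductor stack may legitimately set.
    allowlist = {
        "OMP_NUM_THREADS",
        "TRITON_CACHE_DIR",
        "TORCHINDUCTOR_CACHE_DIR",
    }
    leaked_vars.difference_update(allowlist)
    assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}"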
32 changes: 20 additions & 12 deletions tests/tests_fabric/strategies/test_model_parallel_integration.py
@@ -29,6 +29,13 @@
from tests_fabric.helpers.runif import RunIf


+@pytest.fixture
+def distributed():
+    yield
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()

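The `distributed` fixture above is the core of the fix: it has no setup step, it simply yields so the test body runs first, and on teardown it destroys whatever default process group the test left initialized, so that it cannot leak into whatever runs next in the same process. A self-contained usage sketch (the test name and body below are hypothetical, for illustration only):

import pytest
import torch


@pytest.fixture
def distributed():
    # No setup: run the test first.
    yield
    # Teardown: tear down any process group the test left initialized.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def test_uses_process_group(distributed):  # hypothetical test, for illustration
    # A real test would initialize the group, e.g. via fabric.launch();
    # the fixture guarantees it is destroyed once the test finishes.
    ...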

class FeedForward(nn.Module):
def __init__(self):
super().__init__()
@@ -81,7 +88,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
-def test_setup_device_mesh():
+def test_setup_device_mesh(distributed):
from torch.distributed.device_mesh import DeviceMesh

for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)):
@@ -129,7 +136,7 @@ def fn(model, device_mesh):
"compile",
[True, False],
)
-def test_tensor_parallel(compile):
+def test_tensor_parallel(distributed, compile):
from torch.distributed._tensor import DTensor

parallelize = _parallelize_feed_forward_tp
@@ -182,7 +189,7 @@ def test_tensor_parallel(compile):
"compile",
[True, False],
)
-def test_fsdp2_tensor_parallel(compile):
+def test_fsdp2_tensor_parallel(distributed, compile):
from torch.distributed._tensor import DTensor

parallelize = _parallelize_feed_forward_fsdp2_tp
@@ -264,14 +271,15 @@ def _train(fabric, model=None, optimizer=None):


@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
+@pytest.mark.filterwarnings("ignore::UserWarning")
@pytest.mark.parametrize(
"precision",
[
pytest.param("32-true"),
pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
],
)
-def test_train_save_load(precision, tmp_path):
+def test_train_save_load(distributed, precision, tmp_path):
"""Test 2D-parallel training, saving and loading precision settings."""
strategy = ModelParallelStrategy(
_parallelize_feed_forward_fsdp2_tp,
@@ -329,7 +337,7 @@ def test_train_save_load(precision, tmp_path):

@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_save_full_state_dict(tmp_path):
+def test_save_full_state_dict(distributed, tmp_path):
"""Test that ModelParallelStrategy saves the full state into a single file with
`save_distributed_checkpoint=False`."""
from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
@@ -430,7 +438,7 @@ def test_save_full_state_dict(tmp_path):

@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_load_full_state_dict_into_sharded_model(tmp_path):
+def test_load_full_state_dict_into_sharded_model(distributed, tmp_path):
"""Test that the strategy can load a full-state checkpoint into a distributed model."""
fabric = Fabric(accelerator="cuda", devices=1)
fabric.seed_everything(0)
@@ -476,7 +484,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path):
@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning.fabric.wrappers._FabricModule")
-def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
+def test_setup_module_move_to_device(fabric_module_mock, move_to_device, distributed):
"""Test that `move_to_device` does nothing, ModelParallel decides which device parameters get moved to which device
(sharding)."""
from torch.distributed._tensor import DTensor
@@ -508,7 +516,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
],
)
-def test_module_init_context(precision, expected_dtype):
+def test_module_init_context(distributed, precision, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
strategy = ModelParallelStrategy(parallelize_fn=_parallelize_feed_forward_fsdp2)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision=precision)
@@ -531,7 +539,7 @@ def _run_setup_assertions(empty_init, expected_device):


@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_save_filter(tmp_path):
+def test_save_filter(distributed, tmp_path):
strategy = ModelParallelStrategy(
parallelize_fn=_parallelize_feed_forward_fsdp2,
save_distributed_checkpoint=False,
@@ -584,7 +592,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
"val",
],
)
-def test_clip_gradients(clip_type, precision):
+def test_clip_gradients(distributed, clip_type, precision):
strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2)
fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
fabric.launch()
@@ -626,7 +634,7 @@ def test_clip_gradients(clip_type, precision):


@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
-def test_save_sharded_and_consolidate_and_load(tmp_path):
+def test_save_sharded_and_consolidate_and_load(distributed, tmp_path):
"""Test the consolidation of a distributed (DTensor) checkpoint into a single file."""
strategy = ModelParallelStrategy(
_parallelize_feed_forward_fsdp2_tp,
@@ -683,7 +691,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):


@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_load_raw_module_state():
+def test_load_raw_module_state(distributed):
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

25 changes: 16 additions & 9 deletions tests/tests_pytorch/strategies/test_model_parallel_integration.py
@@ -86,6 +86,13 @@ def fn(model, device_mesh):
return fn


+@pytest.fixture
+def distributed():
+    yield
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()


class TemplateModel(LightningModule):
def __init__(self, compile=False):
super().__init__()
@@ -130,7 +137,7 @@ def configure_model(self):


@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
-def test_setup_device_mesh():
+def test_setup_device_mesh(distributed):
from torch.distributed.device_mesh import DeviceMesh

for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)):
@@ -191,7 +198,7 @@ def configure_model(self):
"compile",
[True, False],
)
-def test_tensor_parallel(compile):
+def test_tensor_parallel(distributed, compile):
from torch.distributed._tensor import DTensor

class Model(TensorParallelModel):
@@ -236,7 +243,7 @@ def training_step(self, batch):
"compile",
[True, False],
)
-def test_fsdp2_tensor_parallel(compile):
+def test_fsdp2_tensor_parallel(distributed, compile):
from torch.distributed._tensor import DTensor

class Model(FSDP2TensorParallelModel):
@@ -293,7 +300,7 @@ def training_step(self, batch):


@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_modules_without_parameters(tmp_path):
+def test_modules_without_parameters(distributed, tmp_path):
"""Test that TorchMetrics get moved to the device despite not having any parameters."""

class MetricsModel(TensorParallelModel):
@@ -336,7 +343,7 @@ def training_step(self, batch):
"compile",
[True, False],
)
-def test_module_init_context(compile, precision, expected_dtype, tmp_path):
+def test_module_init_context(distributed, compile, precision, expected_dtype, tmp_path):
"""Test that the module under the init-context gets moved to the right device and dtype."""

class Model(FSDP2Model):
@@ -375,7 +382,7 @@ def _run_setup_assertions(empty_init, expected_device):

@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("save_distributed_checkpoint", [True, False])
-def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):
+def test_strategy_state_dict(distributed, tmp_path, save_distributed_checkpoint):
"""Test that the strategy returns the correct state dict of the LightningModule."""
model = FSDP2Model()
correct_state_dict = model.state_dict() # State dict before wrapping
@@ -408,7 +415,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):


@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
-def test_load_full_state_checkpoint_into_regular_model(tmp_path):
+def test_load_full_state_checkpoint_into_regular_model(distributed, tmp_path):
"""Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model."""

# Save a regular full-state checkpoint from a distributed model
@@ -450,7 +457,7 @@ def test_load_full_state_checkpoint_into_regular_model(tmp_path):

@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
-def test_load_standard_checkpoint_into_distributed_model(tmp_path):
+def test_load_standard_checkpoint_into_distributed_model(distributed, tmp_path):
"""Test that a regular checkpoint (weights and optimizer states) can be loaded into a distributed model."""

# Save a regular DDP checkpoint
@@ -491,7 +498,7 @@ def test_load_standard_checkpoint_into_distributed_model(tmp_path):

@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
-def test_save_load_sharded_state_dict(tmp_path):
+def test_save_load_sharded_state_dict(distributed, tmp_path):
"""Test saving and loading with the distributed state dict format."""

class CheckpointModel(FSDP2Model):