diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index 9f7e3e67c5a52..b04a29b691529 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -29,7 +29,7 @@ from tests_fabric.helpers.runif import RunIf -@pytest.fixture +@pytest.fixture() def distributed(): yield if torch.distributed.is_initialized(): diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 015b2d417a240..9dcbcc802834b 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -86,7 +86,7 @@ def fn(model, device_mesh): return fn -@pytest.fixture +@pytest.fixture() def distributed(): yield if torch.distributed.is_initialized():