diff --git a/test/unit/test_pytorch_xla.py b/test/unit/test_pytorch_xla.py
index 9303410e..8766538c 100644
--- a/test/unit/test_pytorch_xla.py
+++ b/test/unit/test_pytorch_xla.py
@@ -21,28 +21,28 @@ from sagemaker_training.pytorch_xla import PyTorchXLARunner


-@pytest.fixture(autouse=True)
-def cluster(cluster_size):
+@pytest.fixture(autouse=True, name='cluster')
+def _cluster(cluster_size):
     return [f"algo-{i+1}" for i in range(cluster_size)]


-@pytest.fixture(autouse=True)
-def master(cluster):
+@pytest.fixture(autouse=True, name='master')
+def _master(cluster):
     return cluster[0]


-@pytest.fixture(autouse=True)
-def cluster_size():
+@pytest.fixture(autouse=True, name='cluster_size')
+def _cluster_size():
     return 2


-@pytest.fixture(autouse=True)
-def instance_type():
+@pytest.fixture(autouse=True, name='instance_type')
+def _instance_type():
     return "ml.p3.16xlarge"


-@pytest.fixture(autouse=True)
-def num_gpus(instance_type):
+@pytest.fixture(autouse=True, name='num_gpus')
+def _num_gpus(instance_type):
     if instance_type in [
         "ml.p3.16xlarge",
     ]:
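
The pattern applied throughout this diff is pytest's `name=` fixture parameter: the fixture keeps its public name (`cluster`, `master`, etc.) while the function itself gets a leading-underscore name, so test arguments that request the fixture no longer shadow the module-level function (the source of pylint's `redefined-outer-name` warning). Below is a minimal self-contained sketch of the pattern, assuming illustrative fixture values and a hypothetical test function not present in the diff:

```python
# Minimal sketch of the fixture-renaming pattern in this diff.
# Assumption: the test function and assertion below are illustrative only.
import pytest


@pytest.fixture(autouse=True, name="cluster_size")
def _cluster_size():
    return 2


@pytest.fixture(autouse=True, name="cluster")
def _cluster(cluster_size):
    # Hostnames follow the SageMaker convention algo-1, algo-2, ...
    return [f"algo-{i+1}" for i in range(cluster_size)]


def test_cluster_hosts(cluster):
    # "cluster" resolves via the registered fixture name, not the
    # underscored function name, so nothing is shadowed.
    assert cluster == ["algo-1", "algo-2"]
```

Because the fixture is registered under `name="cluster"`, tests and dependent fixtures keep referring to `cluster` unchanged; only the function identifier moves out of the way.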