From 109a56c66e41f81b1a4afd05d24d8a1abb6cfb7e Mon Sep 17 00:00:00 2001 From: DuYicong515 Date: Sun, 20 Mar 2022 22:06:03 -0700 Subject: [PATCH] fix TPU tests --- tests/models/test_tpu.py | 34 ++++------------------------------ 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index e39ae9c84d921b..e4ddc52839f389 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -242,10 +242,11 @@ def test_dataloaders_passed_to_fit(tmpdir): assert trainer.state.finished, f"Training failed with {trainer.state}" -def test_tpu_misconfiguration(): - """Test if trainer.tpu_id is set as expected.""" +@RunIf(tpu=True) +@pytest.mark.parametrize("tpu_cores", [[1, 8], [1], "9, ", [9], [0], 2, 10]) +def test_tpu_misconfiguration(tpu_cores): with pytest.raises(MisconfigurationException, match="`tpu_cores` can only be"): - Trainer(tpu_cores=[1, 8]) + Trainer(tpu_cores=tpu_cores) @pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU") @@ -279,33 +280,6 @@ def test_broadcast(rank): xmp.spawn(test_broadcast, nprocs=8, start_method="fork") -@pytest.mark.parametrize( - ["tpu_cores", "expected_tpu_id", "error_expected"], - [ - (1, None, False), - (8, None, False), - ([1], 1, False), - ([8], 8, False), - ("1,", 1, False), - ("1", None, False), - ("9, ", 9, True), - ([9], 9, True), - ([0], 0, True), - (2, None, True), - (10, None, True), - ], -) -@RunIf(tpu=True) -@pl_multi_process_test -def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected): - if error_expected: - with pytest.raises(MisconfigurationException, match=r".*tpu_cores` can only be 1, 8 or [<1-8>]*"): - Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores) - else: - trainer = Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores) - assert trainer._accelerator_connector.tpu_id == expected_tpu_id - - @pytest.mark.parametrize( ["cli_args", "expected"], [("--tpu_cores=8", {"tpu_cores": 8}), ("--tpu_cores=1,", {"tpu_cores": "1,"})],