Skip to content

Commit

Permalink
fix TPU tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DuYicong515 committed Mar 21, 2022
1 parent 7a8bf82 commit a26e409
Showing 1 changed file with 4 additions and 30 deletions.
34 changes: 4 additions & 30 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,11 @@ def test_dataloaders_passed_to_fit(tmpdir):
assert trainer.state.finished, f"Training failed with {trainer.state}"


def test_tpu_misconfiguration():
"""Test if trainer.tpu_id is set as expected."""
@RunIf(tpu=True)
@pytest.mark.parametrize("tpu_cores", [1, 8, [1], "9, ", [9], [0], 2, 10])
def test_tpu_misconfiguration(tpu_cores):
with pytest.raises(MisconfigurationException, match="`tpu_cores` can only be"):
Trainer(tpu_cores=[1, 8])
Trainer(tpu_cores=tpu_cores)


@pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU")
Expand Down Expand Up @@ -279,33 +280,6 @@ def test_broadcast(rank):
xmp.spawn(test_broadcast, nprocs=8, start_method="fork")


@pytest.mark.parametrize(
["tpu_cores", "expected_tpu_id", "error_expected"],
[
(1, None, False),
(8, None, False),
([1], 1, False),
([8], 8, False),
("1,", 1, False),
("1", None, False),
("9, ", 9, True),
([9], 9, True),
([0], 0, True),
(2, None, True),
(10, None, True),
],
)
@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
if error_expected:
with pytest.raises(MisconfigurationException, match=r".*tpu_cores` can only be 1, 8 or [<1-8>]*"):
Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores)
else:
trainer = Trainer(default_root_dir=tmpdir, tpu_cores=tpu_cores)
assert trainer._accelerator_connector.tpu_id == expected_tpu_id


@pytest.mark.parametrize(
["cli_args", "expected"],
[("--tpu_cores=8", {"tpu_cores": 8}), ("--tpu_cores=1,", {"tpu_cores": "1,"})],
Expand Down

0 comments on commit a26e409

Please sign in to comment.