Skip to content

Commit

Permalink
Deduplicate tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Aug 28, 2021
1 parent 8d1c423 commit 0a475af
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 47 deletions.
2 changes: 2 additions & 0 deletions .azure-pipelines/gpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ jobs:
bash pl_examples/run_examples.sh --trainer.gpus=1
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=ddp
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=ddp --trainer.precision=16
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=dp
bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=dp --trainer.precision=16
env:
PL_USE_MOCKED_MNIST: "1"
displayName: 'Testing: examples'
Expand Down
51 changes: 4 additions & 47 deletions pl_examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,70 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import platform
from unittest import mock

import pytest
import torch

from pl_examples import _DALI_AVAILABLE
from tests.helpers.runif import RunIf

ARGS_DEFAULT = (
"--trainer.default_root_dir %(tmpdir)s "
"--trainer.max_epochs 1 "
"--trainer.limit_train_batches 2 "
"--trainer.limit_val_batches 2 "
"--trainer.limit_test_batches 2 "
"--trainer.limit_predict_batch4es 2 "
"--data.batch_size 32 "
)
ARGS_GPU = ARGS_DEFAULT + "--trainer.gpus 1 "
ARGS_DP = ARGS_DEFAULT + "--trainer.gpus 2 --trainer.accelerator dp "
ARGS_AMP = "--trainer.precision 16 "


@pytest.mark.parametrize(
"import_cli",
[
"pl_examples.basic_examples.simple_image_classifier",
"pl_examples.basic_examples.backbone_image_classifier",
"pl_examples.basic_examples.autoencoder",
],
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("cli_args", [ARGS_DP, ARGS_DP + ARGS_AMP])
def test_examples_dp(tmpdir, import_cli, cli_args):

module = importlib.import_module(import_cli)
# update the temp dir
cli_args = cli_args % {"tmpdir": tmpdir}

with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
module.cli_main()


@pytest.mark.parametrize(
"import_cli",
[
"pl_examples.basic_examples.simple_image_classifier",
"pl_examples.basic_examples.backbone_image_classifier",
"pl_examples.basic_examples.autoencoder",
],
)
@pytest.mark.parametrize("cli_args", [ARGS_DEFAULT])
def test_examples_cpu(tmpdir, import_cli, cli_args):

module = importlib.import_module(import_cli)
# update the temp dir
cli_args = cli_args % {"tmpdir": tmpdir}

with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
module.cli_main()


@pytest.mark.skipif(not _DALI_AVAILABLE, reason="Nvidia DALI required")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(platform.system() != "Linux", reason="Only applies to Linux platform.")
@RunIf(min_gpus=1, skip_windows=True)
@pytest.mark.parametrize("cli_args", [ARGS_GPU])
def test_examples_mnist_dali(tmpdir, cli_args):
from pl_examples.basic_examples.dali_image_classifier import cli_main
Expand Down

0 comments on commit 0a475af

Please sign in to comment.