Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jul 2, 2020
1 parent cbf87d5 commit 3158efd
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions pl_examples/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
from unittest import mock

import numpy as np
import PIL
import pytest
import torch

Expand Down Expand Up @@ -27,30 +30,25 @@ def test_gpu_template(cli_args):

@pytest.mark.parametrize('cli_args', ['--epochs 1'])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_imagenet(cli_args):
def test_imagenet(tmpdir, cli_args):
"""Test running CLI for an example with default params."""
from pl_examples.domain_templates.imagenet import run_cli

import os
import tempfile
import PIL
import numpy as np
# https://github.com/pytorch/vision/blob/master/test/fakedata_generation.py#L105
def _make_image(file):
PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file)

with tempfile.TemporaryDirectory() as tmpdir:
for split in ['train', 'val']:
for class_id in ['a', 'b']:
os.makedirs(os.path.join(tmpdir, split, class_id))
# Generate 5 black images
for image_id in range(5):
_make_image(os.path.join(tmpdir, split, class_id, str(image_id)+'.JPEG'))

cli_args = cli_args.split(' ') if cli_args else []
cli_args.extend(['--data-path', tmpdir])
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
run_cli()
def _make_image(file_path):
PIL.Image.fromarray(np.zeros((32, 32, 3), dtype=np.uint8)).save(file_path)

for split in ['train', 'val']:
for class_id in ['a', 'b']:
os.makedirs(os.path.join(tmpdir, split, class_id))
# Generate 5 black images
for image_id in range(5):
_make_image(os.path.join(tmpdir, split, class_id, str(image_id)+'.JPEG'))

cli_args = cli_args.split(' ') if cli_args else []
cli_args += ['--data-path', tmpdir]
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
run_cli()


# @pytest.mark.parametrize('cli_args', ['--max_epochs 1 --max_steps 3 --num_nodes 1 --gpus 2'])
Expand Down

0 comments on commit 3158efd

Please sign in to comment.