Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MNIST dataset & drop torchvision dep. from tests #986

Merged
merged 52 commits into from
Mar 30, 2020
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
d9a6c63
added custom mnist without torchvision dep
Feb 26, 2020
72ea8a4
move files so it does not conflict with mnist gitignore
Feb 26, 2020
6dedc60
mock torchvision for tests
Feb 29, 2020
2a5d0f8
fix line too long
Feb 29, 2020
5ef53a7
fix line too long
Feb 29, 2020
6cbc638
fix "module level import not at top of file" warning
Feb 29, 2020
2dbbfcf
move mock imports to __init__.py
Feb 29, 2020
59000f0
simplify MNIST a lot and download directly the .pt files
Feb 29, 2020
bffacaf
further simplify and clean up mnist
Feb 29, 2020
5cf4c27
revert import overrides
Feb 29, 2020
423d76c
make as before
Feb 29, 2020
dbe211c
drop PIL requirement
Feb 29, 2020
59c1f16
move mnist.py to datasets subfolder
Mar 1, 2020
a1b78f0
use logging instead of print
Mar 1, 2020
999375e
choose same name as in torchvision
Mar 1, 2020
d03270a
remove torchvision and pillow also from yml file
Mar 1, 2020
22a57bd
refactor if train
Mar 4, 2020
f1c7e52
capitalized class attr
Mar 4, 2020
a6d5fcf
moved mnist to models
Mar 4, 2020
b8e8046
re-added datasets ignore
Mar 4, 2020
4465e39
better name for file variable
Mar 4, 2020
fe708a8
Update mnist.py
Borda Mar 4, 2020
8ca713d
move dataset classes to datasets.py
Mar 4, 2020
b01ee46
new line
Mar 4, 2020
19dd316
Merge branch 'master' into mnistcopy
Mar 7, 2020
65aefdc
update
Mar 7, 2020
4ebcdf7
update
Mar 7, 2020
63f4f70
Merge branch 'master' into mnistcopy
Mar 11, 2020
0eb72e4
Merge branch 'master' into mnistcopy
Mar 13, 2020
962e0bd
Merge branch 'master' into mnistcopy
Mar 20, 2020
3abe8ca
Merge branch 'master' into mnistcopy
Mar 25, 2020
e634bfc
fix automerge
Mar 25, 2020
d1c4e61
move to base folder
Mar 25, 2020
82e9a90
adapt testingmnist to new mnist base class
Mar 25, 2020
def3804
remove temporal fix
Mar 25, 2020
a8c9d33
fix datatype
Mar 25, 2020
a0b12b9
remove old testingmnist
Mar 25, 2020
5bf5603
readable
Mar 25, 2020
0ee7d30
fix import
Mar 25, 2020
075aca0
fix whitespace
Mar 25, 2020
5d04fd4
docstring
Mar 25, 2020
e987455
Update tests/base/datasets.py
Mar 25, 2020
577b471
changelog
Mar 25, 2020
be425e3
added types
Mar 25, 2020
0ae8ac4
Update CHANGELOG.md
Mar 26, 2020
3204684
exist->isfile
Mar 26, 2020
ec6e6b2
index -> idx
Mar 26, 2020
5d1b2c6
Merge branch 'master' into mnistcopy
Mar 26, 2020
85c9a4a
temporary fix for trains error
Mar 26, 2020
b360d1a
Merge remote-tracking branch 'PyTorchLightning/master' into mnistcopy
Mar 27, 2020
dfa5f6e
Merge branch 'master' into mnistcopy
Mar 27, 2020
aca2cb9
better changelog message
Mar 27, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ docs/source/pl_examples*.rst
docs/source/pytorch_lightning*.rst
docs/source/tests*.rst
docs/source/*.md
tests/tests/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
2 changes: 0 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ dependencies:
- future>=0.17.1

# For dev and testing
- torchvision>=0.4.0
- tox
- coverage
- codecov
Expand All @@ -26,7 +25,6 @@ dependencies:
- autopep8
- check-manifest
- twine==1.13.0
- pillow<7.0.0

- pip:
- test-tube>=0.7.5
Expand Down
104 changes: 104 additions & 0 deletions tests/base/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os
import logging
import urllib.request

import torch
from torch.utils.data import Dataset


class MNIST(Dataset):
    """
    Customized `MNIST <http://yann.lecun.com/exdb/mnist/>`_ dataset for testing PyTorch Lightning
    without the torchvision dependency.

    Part of the code was copied from
    https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py

    Args:
        root (str): Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        normalize (tuple, optional): ``(mean, std)`` used to normalize each image
            tensor in ``__getitem__``; pass ``None`` to skip normalization.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        RuntimeError: If the dataset files are missing and ``download`` is False
            (or the download failed to produce them).
    """

    RESOURCES = (
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/training.pt",
        "https://pl-public-data.s3.amazonaws.com/MNIST/processed/test.pt",
    )

    TRAIN_FILE_NAME = 'training.pt'
    TEST_FILE_NAME = 'test.pt'

    def __init__(self, root, train=True, normalize=(0.5, 1.0), download=False):
        super().__init__()
        self.root = root
        self.train = train  # training set or test set
        self.normalize = normalize

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. Use `download=True` to download it.')

        data_file = self.TRAIN_FILE_NAME if self.train else self.TEST_FILE_NAME
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        # add a channel dimension so images are (1, H, W) float tensors
        img = self.data[index].float().unsqueeze(0)
        target = int(self.targets[index])

        if self.normalize is not None:
            img = normalize_tensor(img, mean=self.normalize[0], std=self.normalize[1])

        return img, target

    def __len__(self):
        return len(self.data)

    @property
    def processed_folder(self):
        """Directory holding the ``training.pt`` / ``test.pt`` files."""
        return os.path.join(self.root, 'MNIST', 'processed')

    def _check_exists(self):
        """Return True only when both split files exist as regular files."""
        train_file = os.path.join(self.processed_folder, self.TRAIN_FILE_NAME)
        test_file = os.path.join(self.processed_folder, self.TEST_FILE_NAME)
        # isfile (not exists) so a stray directory with the same name doesn't count
        return os.path.isfile(train_file) and os.path.isfile(test_file)

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        os.makedirs(self.processed_folder, exist_ok=True)

        for url in self.RESOURCES:
            logging.info(f'Downloading {url}')
            fpath = os.path.join(self.processed_folder, os.path.basename(url))
            urllib.request.urlretrieve(url, fpath)


def normalize_tensor(tensor, mean=0.0, std=1.0):
    """Return a new tensor equal to ``(tensor - mean) / std``.

    The input tensor is left unmodified; ``mean`` and ``std`` are cast to the
    input's dtype and device before the arithmetic.
    """
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
    return (tensor - mean_t) / std_t


class TestingMNIST(MNIST):
    """Subset of :class:`MNIST` kept small so tests run fast.

    Args:
        root (str): Root directory of the dataset (see :class:`MNIST`).
        train (bool, optional): If True, use the training split, else the test split.
        normalize (tuple, optional): ``(mean, std)`` normalization, or ``None``.
        download (bool, optional): If true, download the data if it is missing.
        num_samples (int, optional): Keep only the first ``num_samples`` examples.
    """

    def __init__(self, root, train=True, normalize=(0.5, 1.0), download=False, num_samples=8000):
        super().__init__(
            root,
            train=train,
            normalize=normalize,
            download=download
        )
        # take just a subset of MNIST dataset
        self.data = self.data[:num_samples]
        self.targets = self.targets[:num_samples]
2 changes: 1 addition & 1 deletion tests/base/debug.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

import pytorch_lightning as pl
from tests.base.datasets import MNIST


# from test_models import assert_ok_test_acc, load_model, \
Expand Down
38 changes: 5 additions & 33 deletions tests/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

from tests.base.datasets import TestingMNIST

try:
from test_tube import HyperOptArgumentParser
Expand All @@ -18,29 +18,6 @@

from pytorch_lightning.core.lightning import LightningModule

# TODO: remove after getting own MNIST
# TEMPORAL FIX, https://github.com/pytorch/vision/issues/1938
import urllib.request
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)


class TestingMNIST(MNIST):

def __init__(self, root, train=True, transform=None, target_transform=None,
download=False, num_samples=8000):
super().__init__(
root,
train=train,
transform=transform,
target_transform=target_transform,
download=download
)
# take just a subset of MNIST dataset
self.data = self.data[:num_samples]
self.targets = self.targets[:num_samples]


class DictHparamsModel(LightningModule):

Expand All @@ -61,8 +38,7 @@ def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)

def train_dataloader(self):
return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True,
transform=transforms.ToTensor()), batch_size=32)
return DataLoader(TestingMNIST(os.getcwd(), train=True, download=True), batch_size=32)


class TestModelBase(LightningModule):
Expand Down Expand Up @@ -178,17 +154,13 @@ def configure_optimizers(self):
return [optimizer], [scheduler]

def prepare_data(self):
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
_ = TestingMNIST(root=self.hparams.data_root, train=True,
transform=transform, download=True, num_samples=2000)
download=True, num_samples=2000)

def _dataloader(self, train):
# init data generators
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = TestingMNIST(root=self.hparams.data_root, train=train,
transform=transform, download=False, num_samples=2000)
download=False, num_samples=2000)

# when using multi-node we need to add the datasampler
batch_size = self.hparams.batch_size
Expand Down
Empty file added tests/datasets/__init__.py
Empty file.
4 changes: 1 addition & 3 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
-r ../requirements-extra.txt

# extended list of dependencies for development and running lint and tests
torchvision>=0.4.0, < 0.5 # the 0.5. has some issues with torch JIT
tox
coverage
codecov
Expand All @@ -11,5 +10,4 @@ pytest-cov
pytest-flake8
flake8
check-manifest
twine==1.13.0
pillow<7.0.0
twine==1.13.0