Skip to content

Commit

Permalink
CIFAR10 working
Browse files Browse the repository at this point in the history
Signed-off-by: Tomasz Kornuta <[email protected]>
  • Loading branch information
tkornuta-nvidia committed May 22, 2020
1 parent f244a30 commit 4fdaa03
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse
from copy import deepcopy

import numpy as np
from torch import max, mean, stack, tensor

import nemo.utils.argparse as nm_argparse
from nemo.backends import get_state_dict
from nemo.collections.cv.modules.data_layers.cifar10_datalayer import CIFAR10DataLayer
from nemo.collections.cv.modules.losses.nll_loss import NLLLoss
from nemo.collections.cv.modules.non_trainables.reshape_tensor import ReshapeTensor
from nemo.collections.cv.modules.trainables.convnet_encoder import ConvNetEncoder
from nemo.collections.cv.modules.trainables.feed_forward_network import FeedForwardNetwork
from nemo.core import (
DeviceType,
NeuralGraph,
NeuralModuleFactory,
OperationMode,
SimpleLossLoggerCallback,
WandbCallback,
)
from nemo.utils import logging

if __name__ == "__main__":
    # Create the default parser, inheriting NeMo's common command-line arguments.
    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
    # Parse the arguments.
    args = parser.parse_args()

    # 0. Instantiate Neural Factory (CPU placement keeps this example runnable without a GPU).
    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)

    # 1. Instantiate the modules.
    # Data layer serving the CIFAR10 training split.
    dl = CIFAR10DataLayer(train=True)
    # Model: convolutional encoder -> flatten -> feed-forward classifier head.
    cnn = ConvNetEncoder(input_depth=3, input_height=32, input_width=32)
    reshaper = ReshapeTensor(input_dims=[-1, 16, 2, 2], output_dims=[-1, 64])
    # final_logsoftmax=True because NLLLoss expects log-probabilities.
    ffn = FeedForwardNetwork(input_size=64, output_size=10, dropout_rate=0.1, final_logsoftmax=True)
    # Loss.
    nll_loss = NLLLoss()

    # 2. Create a training graph by chaining the modules.
    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
        img, tgt = dl()
        feat_map = cnn(inputs=img)
        res_img = reshaper(inputs=feat_map)
        pred = ffn(inputs=res_img)
        loss = nll_loss(predictions=pred, targets=tgt)
        # Set output - that output will be used for training.
        training_graph.outputs["loss"] = loss

    # 3. Create the callbacks.
    # SimpleLossLoggerCallback will print loss values to console.
    callback = SimpleLossLoggerCallback(
        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}')
    )

    # Log training metrics to W&B.
    # Fixed run name: the original "simple-mnist-fft" was copied from the MNIST
    # example and mislabeled this CIFAR10 experiment.
    wand_callback = WandbCallback(
        train_tensors=[loss],
        wandb_name="cifar10-convnet",
        wandb_project="cv-collection-image-classification",
    )

    # 4. Invoke the "train" action.
    nf.train(
        training_graph=training_graph,
        callbacks=[callback, wand_callback],
        optimization_params={"num_epochs": 10, "lr": 0.001},
        optimizer="adam",
    )
Original file line number Diff line number Diff line change
Expand Up @@ -71,28 +71,16 @@
)

# Log training metrics to W&B.
# wand_callback = WandbCallback(
# train_tensors=[loss],
# wandb_name="simple-mnist-fft",
# wandb_project="cv-collection-image-classification",
# )

# Get CNN weights before training.
weights = deepcopy(get_state_dict(cnn)["_conv3.bias"]).numpy()
wand_callback = WandbCallback(
train_tensors=[loss],
wandb_name="simple-mnist-fft",
wandb_project="cv-collection-image-classification",
)

# Invoke the "train" action.
nf.train(
training_graph=training_graph,
callbacks=[callback],
optimization_params={"max_steps": 100, "lr": 0.001},
callbacks=[callback, wand_callback],
optimization_params={"num_epochs": 10, "lr": 0.001},
optimizer="adam",
)

# Get CNN weights after training.
weights2 = deepcopy(get_state_dict(cnn)["_conv3.bias"]).numpy()

logging.info("Before training:\n{}".format(weights))
logging.info("After training:\n{}".format(weights2))

if np.array_equal(weights, weights2):
logging.error("Module weights not updated during training")
1 change: 1 addition & 0 deletions nemo/collections/cv/modules/data_layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
# =============================================================================

from nemo.collections.cv.modules.data_layers.mnist_datalayer import *
from nemo.collections.cv.modules.data_layers.cifar10_datalayer import *
104 changes: 104 additions & 0 deletions nemo/collections/cv/modules/data_layers/cifar10_datalayer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from os.path import expanduser

import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Resize, ToTensor

from nemo.backends.pytorch.nm import DataLayerNM
from nemo.core.neural_types import AxisKind, AxisType, LabelsType, NeuralType, NormalizedValueType
from nemo.utils.decorators import add_port_docs

__all__ = ['CIFAR10DataLayer']


class CIFAR10DataLayer(DataLayerNM, CIFAR10):
    """
    A "thin DataLayer" - wrapper around the torchvision's CIFAR10 dataset.

    Inherits from both the NeMo DataLayerNM (module machinery, ports) and the
    torchvision CIFAR10 dataset (actual data access).
    """

    def __init__(
        self, name=None, height=32, width=32, data_folder="~/data/cifar10", train=True, batch_size=64, shuffle=True
    ):
        """
        Initializes the CIFAR10 datalayer.

        Args:
            name: Name of the module (DEFAULT: None)
            height: image height (DEFAULT: 32)
            width: image width (DEFAULT: 32)
            data_folder: path to the folder with data, can be relative to user (DEFAULT: "~/data/cifar10")
            train: use train or test splits (DEFAULT: True)
            batch_size: size of batch (DEFAULT: 64) [PARAMETER OF DATALOADER]
            shuffle: shuffle data (DEFAULT: True) [PARAMETER OF DATALOADER]
        """
        # Call the base class constructor of DataLayer.
        DataLayerNM.__init__(self, name=name)

        # Store height and width - used in the output port definitions.
        self._height = height
        self._width = width

        # Create transformations: resize to the requested size and convert PIL images to tensors.
        # (Renamed from "mnist_transforms" - this is the CIFAR10 layer.)
        cifar10_transforms = Compose([Resize((self._height, self._width)), ToTensor()])

        # Get absolute path - expand "~" to the user's home directory.
        abs_data_folder = expanduser(data_folder)

        # Call the base class constructor of the CIFAR10 dataset (fixed comment: not MNIST).
        # download=True fetches the data on first use.
        CIFAR10.__init__(self, root=abs_data_folder, train=train, download=True, transform=cifar10_transforms)

        # Remember the params passed to DataLoader. :]
        self._batch_size = batch_size
        self._shuffle = shuffle

    @property
    @add_port_docs()
    def output_ports(self):
        """
        Creates definitions of output ports.
        By default, it sets image width and height to 32.
        """
        return {
            "images": NeuralType(
                axes=(
                    AxisType(kind=AxisKind.Batch),
                    AxisType(kind=AxisKind.Channel, size=3),
                    AxisType(kind=AxisKind.Height, size=self._height),
                    AxisType(kind=AxisKind.Width, size=self._width),
                ),
                elements_type=NormalizedValueType(),
            ),
            "targets": NeuralType(tuple('B'), elements_type=LabelsType()),
        }

    def __len__(self):
        """
        Returns:
            Number of samples, taken from the underlying torchvision dataset's
            ``data`` array - overwrites the abstract method (which is already
            overwritten by the other dependency).
        """
        return len(self.data)

    @property
    def dataset(self):
        """
        Returns:
            Self - just to be "compatible" with the current NeMo train action.
        """
        return self
4 changes: 4 additions & 0 deletions nemo/collections/cv/modules/data_layers/mnist_datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def output_ports(self):
}

def __len__(self):
    """
    Returns:
        Number of samples, taken from the underlying torchvision dataset's
        ``data`` array - overwrites the abstract method (which is already
        overwritten by the other dependency).
    """
    return len(self.data)

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
Initializes the classifier.
"""
# Call constructor of parent classes.
# Call constructor of the parent class.
TrainableNM.__init__(self, name=name)

# Get input size.
Expand Down

0 comments on commit 4fdaa03

Please sign in to comment.