CV collection: image classification (#654)
* CV collection init: MNIST image classification

Signed-off-by: Tomasz Kornuta <[email protected]>

* Ported FFN and TensorReshaper, MNIST classification working on CPU

Signed-off-by: Tomasz Kornuta <[email protected]>

* reformatted code

Signed-off-by: Tomasz Kornuta <[email protected]>

* convnet encoder ported, example working... but showing that CNNs are not updated

Signed-off-by: Tomasz Kornuta <[email protected]>

* format fix

Signed-off-by: Tomasz Kornuta <[email protected]>

* Trainable NM fix - removing no_grad()

Signed-off-by: Tomasz Kornuta <[email protected]>

* CIFAR10 working

Signed-off-by: Tomasz Kornuta <[email protected]>

* Made the types of FFN and ReshapeTensor more

Signed-off-by: Tomasz Kornuta <[email protected]>

* formatting fix

Signed-off-by: Tomasz Kornuta <[email protected]>

* LGTM unused import fixes

Signed-off-by: Tomasz Kornuta <[email protected]>

* LGTM fixes: unused variable in the loop

Signed-off-by: Tomasz Kornuta <[email protected]>

* GenericImageEncoder ported + CIFAR10 VGG16 classification example

Signed-off-by: Tomasz Kornuta <[email protected]>

* Added NonLinearity component, simplified the FFN, cifar10 - ResNet50 operational

Signed-off-by: Tomasz Kornuta <[email protected]>

* LGTM fixes

Signed-off-by: Tomasz Kornuta <[email protected]>

* Stronger typing in CV modules and examples, introduced several new ElementTypes

Signed-off-by: Tomasz Kornuta <[email protected]>

* formatting

Signed-off-by: Tomasz Kornuta <[email protected]>

* updated requirements, docs, setup, added information about CV collection to readme

Signed-off-by: Tomasz Kornuta <[email protected]>

* updated description in changelog

Signed-off-by: Tomasz Kornuta <[email protected]>

* minor comment polish

Signed-off-by: Tomasz Kornuta <[email protected]>

* rst fix

Signed-off-by: Tomasz Kornuta <[email protected]>

* minor nemo typing fix - imagetype

Signed-off-by: Tomasz Kornuta <[email protected]>

* polished datalayers, added CIFAR100, added Index and Label types, polished types

Signed-off-by: Tomasz Kornuta <[email protected]>

* formatting fix

Signed-off-by: Tomasz Kornuta <[email protected]>

* GenericImageEncoder -> ImageEncoder, updated readme file

Signed-off-by: Tomasz Kornuta <[email protected]>

* changed assert to get_value_from_dict

Signed-off-by: Tomasz Kornuta <[email protected]>

* formatting fix

Signed-off-by: Tomasz Kornuta <[email protected]>

* added python 3 typing to all inits, fixed LGTM issue, formatted

Signed-off-by: Tomasz Kornuta <[email protected]>

* Updated docstrings

Signed-off-by: Tomasz Kornuta <[email protected]>

* raise ConfigurationError

Signed-off-by: Tomasz Kornuta <[email protected]>

* reshape tensor docstring update

Signed-off-by: Tomasz Kornuta <[email protected]>

* Label -> StringLabel, description of ImageEncoder

Signed-off-by: Tomasz Kornuta <[email protected]>
tkornuta-nvidia authored Jun 3, 2020
1 parent 78b6bef commit 74bf41d
Showing 33 changed files with 2,530 additions and 12 deletions.
9 changes: 2 additions & 7 deletions CHANGELOG.md
@@ -80,6 +80,8 @@ To release a new version, please update the changelog as followed:
 - ContextNet Encoder + Decoder Initial Support ([PR #630](https://github.com/NVIDIA/NeMo/pull/630)) - @titu1994
 - Added finetuning with Megatron-LM ([PR #601](https://github.com/NVIDIA/NeMo/pull/601)) - @ekmb
 - Added documentation for 8 kHz model ([PR #632](https://github.com/NVIDIA/NeMo/pull/632)) - @jbalam-nv
+- The Neural Graph is a high-level abstract concept empowering the users to build graphs consisting of many, interconnected Neural Modules. A user in his/her application can build any number of graphs, potentially spanning over the same modules. The import/export options combined with the lightweight API make Neural Graphs a perfect tool for rapid prototyping and experimentation. ([PR #413](https://github.com/NVIDIA/NeMo/pull/413)) - @tkornuta-nvidia
+- Created the NeMo CV collection, added MNIST and CIFAR10 thin datalayers, implemented/ported several general usage trainable and non-trainable modules, added several new ElementTypes ([PR #654](https://github.com/NVIDIA/NeMo/pull/654)) - @tkornuta-nvidia


 ### Changed
@@ -97,13 +99,6 @@ To release a new version, please update the changelog as followed:

 ### Security

-### Contributors
-
-## [0.10.2] - 2020-05-05
-
-### Added
-- The Neural Graph is a high-level abstract concept empowering the users to build graphs consisting of many, interconnected Neural Modules. A user in his/her application can build any number of graphs, potentially spanning over the same modules. The import/export options combined with the lightweight API make Neural Graphs a perfect tool for rapid prototyping and experimentation. ([PR #413](https://github.com/NVIDIA/NeMo/pull/413)) - @tkornuta
-
 ## [0.10.0] - 2020-04-03

 ### Added
1 change: 1 addition & 0 deletions docs/sources/source/collections/modules.rst
@@ -8,5 +8,6 @@ NeMo Collections API

     core
     nemo_asr
+    nemo_cv
     nemo_tts
     nemo_nlp
34 changes: 34 additions & 0 deletions docs/sources/source/collections/nemo_cv.rst
@@ -0,0 +1,34 @@
NeMo CV collection
==================

DataLayers
----------
.. automodule:: nemo.collections.cv.modules.data_layers
    :members:
    :undoc-members:
    :show-inheritance:
    :exclude-members: forward

Trainable Modules
-----------------
.. automodule:: nemo.collections.cv.modules.trainables
    :members:
    :undoc-members:
    :show-inheritance:
    :exclude-members: forward

NonTrainable Modules
--------------------
.. automodule:: nemo.collections.cv.modules.non_trainables
    :members:
    :undoc-members:
    :show-inheritance:
    :exclude-members: forward

Losses
------
.. automodule:: nemo.collections.cv.modules.losses
    :members:
    :undoc-members:
    :show-inheritance:
    :exclude-members: forward
5 changes: 2 additions & 3 deletions nemo/backends/pytorch/nm.py
@@ -171,8 +171,7 @@ def __init__(self, name=None):
     def __call__(self, force_pt=False, *input, **kwargs):
         pt_call = len(input) > 0 or force_pt
         if pt_call:
-            with t.no_grad():
-                return self.forward(*input, **kwargs)
+            return self.forward(*input, **kwargs)
         else:
             return NeuralModule.__call__(self, **kwargs)

@@ -320,13 +319,13 @@ def dataset(self):
         pass

     @property
-    @abstractmethod
     def data_iterator(self):
         """"Iterator over the dataset. It is a good idea to return
         torch.utils.data.DataLoader here. Should implement either this or
         `dataset`.
         If this is implemented, `dataset` property should return None.
         """
+        return None

     @property
     def batch_size(self):
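The second hunk above drops `@abstractmethod` from `data_iterator` and makes it return `None`, so a data layer can implement `dataset` alone. A minimal standard-library sketch of that pattern follows. The class names `DataLayerSketch` and `ListDataLayer` are hypothetical stand-ins and do not reproduce NeMo's actual base class.

```python
from abc import ABC, abstractmethod


class DataLayerSketch(ABC):
    """Hypothetical stand-in for a data layer base class (not NeMo's real one)."""

    @property
    @abstractmethod
    def dataset(self):
        """Subclasses must provide a dataset (or return None if they iterate)."""

    @property
    def data_iterator(self):
        # Not abstract: defaults to None, so subclasses may skip it entirely.
        return None


class ListDataLayer(DataLayerSketch):
    """Implements only `dataset` and inherits the default `data_iterator`."""

    def __init__(self, samples):
        self._samples = list(samples)

    @property
    def dataset(self):
        return self._samples


# Instantiation succeeds even though `data_iterator` is never overridden.
layer = ListDataLayer([1, 2, 3])
```

Had `data_iterator` stayed abstract, `ListDataLayer([1, 2, 3])` would raise `TypeError` at instantiation, which is presumably why thin dataset wrappers needed this change.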
21 changes: 21 additions & 0 deletions nemo/collections/cv/README.md
@@ -0,0 +1,21 @@
NeMo CV Collection: Neural Modules for Computer Vision
====================================================================

The NeMo CV collection offers modules useful for the following computer vision applications.

For now the collection focuses only on Image Classification.

1. MNIST classification:
   * a thin DL wrapper around torchvision MNIST dataset
   * classification with the classic LeNet-5
   * classification with a graph: ReshapeTensor -> FeedForwardNetwork -> LogProbs
   * classification with a graph: ConvNet -> ReshapeTensor -> FeedForwardNetwork -> LogProbs

2. CIFAR10 classification:
   * a thin DL wrapper around torchvision CIFAR10 dataset
   * classification with a graph: ConvNet -> ReshapeTensor -> FeedForwardNetwork -> LogProbs
   * classification with a graph: ImageEncoder (ResNet-50 feature maps) -> FeedForwardNetwork -> LogProbs

3. CIFAR100 classification:
   * a thin DL wrapper around torchvision CIFAR100 dataset
   * classification with a graph: ImageEncoder (VGG-16 with FC6 reshaped) -> LogProbs
20 changes: 20 additions & 0 deletions nemo/collections/cv/__init__.py
@@ -0,0 +1,20 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from nemo.collections.cv.modules import *

# __version__ = "0.1"
# __name__ = "nemo.collections.cv"
@@ -0,0 +1,72 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse

import nemo.utils.argparse as nm_argparse
from nemo.collections.cv.modules.data_layers import CIFAR100DataLayer
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import FeedForwardNetwork, ImageEncoder
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
    # Create the default parser.
    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
    # Parse the arguments
    args = parser.parse_args()

    # Instantiate Neural Factory.
    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)

    # Data layer - upscale the CIFAR100 images to ImageNet resolution.
    cifar100_dl = CIFAR100DataLayer(height=224, width=224, train=True)
    # The "model".
    image_encoder = ImageEncoder(model_type="vgg16", return_feature_maps=True, pretrained=True, name="vgg16")
    reshaper = ReshapeTensor(input_sizes=[-1, 7, 7, 512], output_sizes=[-1, 25088])
    ffn = FeedForwardNetwork(input_size=25088, output_size=100, hidden_sizes=[1000, 1000], dropout_rate=0.1)
    nl = NonLinearity(type="logsoftmax", sizes=[-1, 100])
    # Loss.
    nll_loss = NLLLoss()

    # Create a training graph.
    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
        _, img, _, _, fine_target, _ = cifar100_dl()
        feat_map = image_encoder(inputs=img)
        res_img = reshaper(inputs=feat_map)
        logits = ffn(inputs=res_img)
        pred = nl(inputs=logits)
        loss = nll_loss(predictions=pred, targets=fine_target)
        # Set output - that output will be used for training.
        training_graph.outputs["loss"] = loss

    # Freeze the pretrained encoder.
    training_graph.freeze(["vgg16"])
    logging.info(training_graph.summary())

    # SimpleLossLoggerCallback will print loss values to console.
    callback = SimpleLossLoggerCallback(
        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}')
    )

    # Invoke the "train" action.
    nf.train(
        training_graph=training_graph,
        callbacks=[callback],
        optimization_params={"num_epochs": 10, "lr": 0.001},
        optimizer="adam",
    )
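The `ReshapeTensor` call in the script above flattens a `[-1, 7, 7, 512]` feature map into `[-1, 25088]`, and `FeedForwardNetwork` is given `input_size=25088` to match. The sketch below only checks that arithmetic and shows how a `-1` dimension can be resolved from the element count. The helper `resolve_sizes` and the batch size of 4 are illustrative assumptions, not part of the commit.

```python
from math import prod


def resolve_sizes(sizes, total_elements):
    """Replace a single -1 in `sizes` with the dimension implied by `total_elements`."""
    if sizes.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = prod(s for s in sizes if s != -1)
    if total_elements % known != 0:
        raise ValueError(f"cannot fit {total_elements} elements into {sizes}")
    resolved = [total_elements // known if s == -1 else s for s in sizes]
    if prod(resolved) != total_elements:
        raise ValueError(f"cannot fit {total_elements} elements into {sizes}")
    return resolved


batch = 4  # hypothetical batch size, chosen only for illustration
total = batch * 7 * 7 * 512

resolved_in = resolve_sizes([-1, 7, 7, 512], total)
resolved_out = resolve_sizes([-1, 25088], total)
```

Both shapes resolve to the same batch dimension because 7 * 7 * 512 = 25088, so the reshape is a plain per-sample flatten.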
@@ -0,0 +1,71 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse

import nemo.utils.argparse as nm_argparse
from nemo.collections.cv.modules.data_layers import CIFAR10DataLayer
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import ConvNetEncoder, FeedForwardNetwork
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
    # Create the default parser.
    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
    # Parse the arguments
    args = parser.parse_args()

    # Instantiate Neural Factory.
    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)

    # Data layer for training.
    cifar10_dl = CIFAR10DataLayer(train=True)
    # The "model".
    cnn = ConvNetEncoder(input_depth=3, input_height=32, input_width=32)
    reshaper = ReshapeTensor(input_sizes=[-1, 16, 2, 2], output_sizes=[-1, 64])
    ffn = FeedForwardNetwork(input_size=64, output_size=10, dropout_rate=0.1)
    nl = NonLinearity(type="logsoftmax", sizes=[-1, 10])
    # Loss.
    nll_loss = NLLLoss()

    # Create a training graph.
    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
        _, img, tgt = cifar10_dl()
        feat_map = cnn(inputs=img)
        res_img = reshaper(inputs=feat_map)
        logits = ffn(inputs=res_img)
        pred = nl(inputs=logits)
        loss = nll_loss(predictions=pred, targets=tgt)
        # Set output - that output will be used for training.
        training_graph.outputs["loss"] = loss

    # Display the graph summary.
    logging.info(training_graph.summary())

    # SimpleLossLoggerCallback will print loss values to console.
    callback = SimpleLossLoggerCallback(
        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}')
    )

    # Invoke the "train" action.
    nf.train(
        training_graph=training_graph,
        callbacks=[callback],
        optimization_params={"num_epochs": 10, "lr": 0.001},
        optimizer="adam",
    )
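The script above reshapes the encoder output from `[-1, 16, 2, 2]` to `[-1, 64]` for a 32x32 input. One configuration that yields exactly that shape is three stages of a 3x3 convolution (stride 1, no padding) followed by 2x2 max pooling, with 16 channels in the last stage. Those layer settings are an assumption made for this sketch; the commit page shown here does not include the `ConvNetEncoder` defaults.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Spatial size after a convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Spatial size after max pooling with stride equal to the kernel size."""
    return size // kernel


side = 32                  # CIFAR10 images are 32x32
trace = [side]
for _ in range(3):         # assumed: three conv + pool stages
    side = conv_out(side)  # 3x3 conv, stride 1, no padding
    trace.append(side)
    side = pool_out(side)  # 2x2 max pool
    trace.append(side)

out_channels = 16          # assumed channel count of the last stage
flattened = out_channels * side * side
```

Under these assumptions the spatial size goes 32, 30, 15, 13, 6, 4, 2, and 16 * 2 * 2 = 64 matches the `input_size=64` passed to `FeedForwardNetwork`.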
@@ -0,0 +1,67 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse

import nemo.utils.argparse as nm_argparse
from nemo.collections.cv.modules.data_layers import CIFAR10DataLayer
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity
from nemo.collections.cv.modules.trainables import ImageEncoder
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
    # Create the default parser.
    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
    # Parse the arguments
    args = parser.parse_args()

    # Instantiate Neural Factory.
    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)

    # Data layer - upscale the CIFAR10 images to ImageNet resolution.
    cifar10_dl = CIFAR10DataLayer(height=224, width=224, train=True)
    # The "model".
    image_classifier = ImageEncoder(model_type="resnet50", output_size=10, pretrained=True, name="resnet50")
    nl = NonLinearity(type="logsoftmax", sizes=[-1, 10])
    # Loss.
    nll_loss = NLLLoss()

    # Create a training graph.
    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
        _, img, tgt = cifar10_dl()
        logits = image_classifier(inputs=img)
        pred = nl(inputs=logits)
        loss = nll_loss(predictions=pred, targets=tgt)
        # Set output - that output will be used for training.
        training_graph.outputs["loss"] = loss

    # Display the graph summary.
    logging.info(training_graph.summary())

    # SimpleLossLoggerCallback will print loss values to console.
    callback = SimpleLossLoggerCallback(
        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}')
    )

    # Invoke the "train" action.
    nf.train(
        training_graph=training_graph,
        callbacks=[callback],
        optimization_params={"num_epochs": 10, "lr": 0.001},
        optimizer="adam",
    )
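All three example scripts end with the same pair: a `logsoftmax` non-linearity feeding a negative log-likelihood loss. The pure-Python sketch below shows what that pair computes on a tiny hand-made batch. The function names and the sample logits are illustrative only, and the mean reduction over the batch is an assumption about how the loss is aggregated.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def nll_loss(log_probs_batch, targets):
    """Mean negative log-probability assigned to the target classes."""
    picked = [row[t] for row, t in zip(log_probs_batch, targets)]
    return -sum(picked) / len(picked)


logits_batch = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]  # made-up logits, 3 classes
targets = [0, 2]                                    # made-up class indices

log_probs = [log_softmax(row) for row in logits_batch]
loss = nll_loss(log_probs, targets)
```

Applying log-softmax and then NLL is equivalent to cross-entropy on the raw logits, which is why the graphs can keep the non-linearity and the loss as two separate modules.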