CV collection: image classification (#654)
* CV collection init: MNIST image classification
* Ported FFN and TensorReshaper, MNIST classification working on CPU
* reformatted code
* convnet encoder ported, example working... but showing that CNNs are not updated
* format fix
* Trainable NM fix - removing no_grad()
* CIFAR10 working
* Made the types of FFN and ReshapeTensor more
* formatting fix
* LGTM unused import fixes
* LGTM fixes: unused variable in the loop
* GenericImageEncoder ported + CIFAR10 VGG16 classification example
* Added NonLinearity component, simplified the FFN, cifar10 - ResNet50 operational
* LGTM fixes
* Stronger typing in CV modules and examples, introduced several new ElementTypes
* formatting
* updated requirements, docs, setup, added information about CV collection to readme
* updated description in changelog
* minor comment polish
* rst fix
* minor nemo typing fix - imagetype
* polished datalayers, added CIFAR100, added Index and Label types, polished types
* formatting fix
* GenericImageEncoder -> ImageEncoder, updated readme file
* changed assert to get_value_from_dict
* formatting fix
* added python 3 typing to all inits, fixed LGTM issue, formatted
* Updated docstrings
* raise ConfigurationError
* reshape tensor docstring update
* Label -> StringLabel, description of ImageEncoder

Each item above is signed off with: Signed-off-by: Tomasz Kornuta <[email protected]>
1 parent 78b6bef, commit 74bf41d
Showing 33 changed files with 2,530 additions and 12 deletions.
@@ -8,5 +8,6 @@ NeMo Collections API

   core
   nemo_asr
   nemo_cv
   nemo_tts
   nemo_nlp
@@ -0,0 +1,34 @@
NeMo CV collection
==================

DataLayers
----------
.. automodule:: nemo.collections.cv.modules.data_layers
    :members:
    :undoc-members:
    :show-inheritance:
    :exclude-members: forward

Trainable Modules
-----------------
.. automodule:: nemo.collections.cv.modules.trainables
    :members:
    :undoc-members:
    :show-inheritance:
    :exclude-members: forward

NonTrainable Modules
--------------------
.. automodule:: nemo.collections.cv.modules.non_trainables
    :members:
    :undoc-members:
    :show-inheritance:
    :exclude-members: forward

Losses
------
.. automodule:: nemo.collections.cv.modules.losses
    :members:
    :undoc-members:
    :show-inheritance:
    :exclude-members: forward
@@ -0,0 +1,21 @@
NeMo CV Collection: Neural Modules for Computer Vision
====================================================================

The NeMo CV collection offers modules useful for the following computer vision applications.

For now the collection focuses only on Image Classification.

1. MNIST classification:
    * a thin DL wrapper around torchvision MNIST dataset
    * classification with the classic LeNet-5
    * classification with a graph: ReshapeTensor -> FeedForwardNetwork -> LogProbs
    * classification with a graph: ConvNet -> ReshapeTensor -> FeedForwardNetwork -> LogProbs

2. CIFAR10 classification:
    * a thin DL wrapper around torchvision CIFAR10 dataset
    * classification with a graph: ConvNet -> ReshapeTensor -> FeedForwardNetwork -> LogProbs
    * classification with a graph: ImageEncoder (ResNet-50 feature maps) -> FeedForwardNetwork -> LogProbs

3. CIFAR100 classification:
    * a thin DL wrapper around torchvision CIFAR100 dataset
    * classification with a graph: ImageEncoder (VGG-16 with FC6 reshaped) -> LogProbs
@@ -0,0 +1,20 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from nemo.collections.cv.modules import *

# __version__ = "0.1"
# __name__ = "nemo.collections.cv"
72 changes: 72 additions & 0 deletions
nemo/collections/cv/examples/cifar100_vgg16_ffn_image_classification.py
@@ -0,0 +1,72 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse

import nemo.utils.argparse as nm_argparse
from nemo.collections.cv.modules.data_layers import CIFAR100DataLayer
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import FeedForwardNetwork, ImageEncoder
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
    # Create the default parser.
    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
    # Parse the arguments
    args = parser.parse_args()

    # Instantiate Neural Factory.
    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)

    # Data layer - upscale the CIFAR100 images to ImageNet resolution.
    cifar100_dl = CIFAR100DataLayer(height=224, width=224, train=True)
    # The "model".
    image_encoder = ImageEncoder(model_type="vgg16", return_feature_maps=True, pretrained=True, name="vgg16")
    reshaper = ReshapeTensor(input_sizes=[-1, 7, 7, 512], output_sizes=[-1, 25088])
    ffn = FeedForwardNetwork(input_size=25088, output_size=100, hidden_sizes=[1000, 1000], dropout_rate=0.1)
    nl = NonLinearity(type="logsoftmax", sizes=[-1, 100])
    # Loss.
    nll_loss = NLLLoss()

    # Create a training graph.
    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
        _, img, _, _, fine_target, _ = cifar100_dl()
        feat_map = image_encoder(inputs=img)
        res_img = reshaper(inputs=feat_map)
        logits = ffn(inputs=res_img)
        pred = nl(inputs=logits)
        loss = nll_loss(predictions=pred, targets=fine_target)
        # Set output - that output will be used for training.
        training_graph.outputs["loss"] = loss

    # Freeze the pretrained encoder.
    training_graph.freeze(["vgg16"])
    logging.info(training_graph.summary())

    # SimpleLossLoggerCallback will print loss values to console.
    callback = SimpleLossLoggerCallback(
        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}')
    )

    # Invoke the "train" action.
    nf.train(
        training_graph=training_graph,
        callbacks=[callback],
        optimization_params={"num_epochs": 10, "lr": 0.001},
        optimizer="adam",
    )
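
For a sense of what remains trainable once `vgg16` is frozen in the CIFAR100 example, the arithmetic below counts the parameters of a feed-forward network with the sizes used there (25088 inputs, hidden sizes 1000 and 1000, 100 outputs). This is only a rough sketch: it assumes `FeedForwardNetwork` is a plain stack of fully connected layers with biases, which the diff does not show, and `dense_param_count` is a hypothetical helper that is not part of NeMo.

```python
def dense_param_count(input_size, hidden_sizes, output_size):
    """Count weights and biases of a stack of fully connected layers.

    Assumption: every layer is a dense layer with a bias vector.
    """
    sizes = [input_size, *hidden_sizes, output_size]
    # Each layer contributes n_in * n_out weights plus n_out biases.
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


# Sizes copied from the CIFAR100 example: 25088 -> 1000 -> 1000 -> 100.
ffn_params = dense_param_count(25088, [1000, 1000], 100)
print(ffn_params)  # 26190100
```

Under that assumption the first layer alone accounts for roughly 25 million of the 26.2 million parameters, because it consumes the flattened 7 x 7 x 512 feature map.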
71 changes: 71 additions & 0 deletions
nemo/collections/cv/examples/cifar10_convnet_ffn_image_classification.py
@@ -0,0 +1,71 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse

import nemo.utils.argparse as nm_argparse
from nemo.collections.cv.modules.data_layers import CIFAR10DataLayer
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import ConvNetEncoder, FeedForwardNetwork
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
    # Create the default parser.
    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
    # Parse the arguments
    args = parser.parse_args()

    # Instantiate Neural Factory.
    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)

    # Data layer for training.
    cifar10_dl = CIFAR10DataLayer(train=True)
    # The "model".
    cnn = ConvNetEncoder(input_depth=3, input_height=32, input_width=32)
    reshaper = ReshapeTensor(input_sizes=[-1, 16, 2, 2], output_sizes=[-1, 64])
    ffn = FeedForwardNetwork(input_size=64, output_size=10, dropout_rate=0.1)
    nl = NonLinearity(type="logsoftmax", sizes=[-1, 10])
    # Loss.
    nll_loss = NLLLoss()

    # Create a training graph.
    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
        _, img, tgt = cifar10_dl()
        feat_map = cnn(inputs=img)
        res_img = reshaper(inputs=feat_map)
        logits = ffn(inputs=res_img)
        pred = nl(inputs=logits)
        loss = nll_loss(predictions=pred, targets=tgt)
        # Set output - that output will be used for training.
        training_graph.outputs["loss"] = loss

    # Display the graph summary.
    logging.info(training_graph.summary())

    # SimpleLossLoggerCallback will print loss values to console.
    callback = SimpleLossLoggerCallback(
        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}')
    )

    # Invoke the "train" action.
    nf.train(
        training_graph=training_graph,
        callbacks=[callback],
        optimization_params={"num_epochs": 10, "lr": 0.001},
        optimizer="adam",
    )
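
The `ReshapeTensor` calls in the ConvNet and VGG-16 examples only make sense if the input and output size lists describe the same number of elements per sample. The short check below illustrates that arithmetic for the two size pairs used in this commit. It is a sketch with hypothetical helper names (`flattened_size`, `check_reshape`); whether `ReshapeTensor` performs such a validation itself is not visible in the diff.

```python
from math import prod


def flattened_size(sizes):
    """Return the per-sample element count, skipping the leading batch dimension (-1)."""
    if not sizes or sizes[0] != -1:
        raise ValueError("expected a leading -1 batch dimension")
    return prod(sizes[1:])


def check_reshape(input_sizes, output_sizes):
    """Return True when both size lists hold the same number of elements per sample."""
    return flattened_size(input_sizes) == flattened_size(output_sizes)


# Size pairs copied from the examples in this commit.
vgg16_ok = check_reshape([-1, 7, 7, 512], [-1, 25088])   # 7 * 7 * 512 = 25088
convnet_ok = check_reshape([-1, 16, 2, 2], [-1, 64])     # 16 * 2 * 2 = 64
print(vgg16_ok, convnet_ok)
```

The same product also has to match the `input_size` of the `FeedForwardNetwork` that follows the reshaper (25088 and 64 respectively).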
67 changes: 67 additions & 0 deletions
nemo/collections/cv/examples/cifar10_resnet50_image_classification.py
@@ -0,0 +1,67 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse

import nemo.utils.argparse as nm_argparse
from nemo.collections.cv.modules.data_layers import CIFAR10DataLayer
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity
from nemo.collections.cv.modules.trainables import ImageEncoder
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
    # Create the default parser.
    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
    # Parse the arguments
    args = parser.parse_args()

    # Instantiate Neural Factory.
    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)

    # Data layer - upscale the CIFAR10 images to ImageNet resolution.
    cifar10_dl = CIFAR10DataLayer(height=224, width=224, train=True)
    # The "model".
    image_classifier = ImageEncoder(model_type="resnet50", output_size=10, pretrained=True, name="resnet50")
    nl = NonLinearity(type="logsoftmax", sizes=[-1, 10])
    # Loss.
    nll_loss = NLLLoss()

    # Create a training graph.
    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
        _, img, tgt = cifar10_dl()
        logits = image_classifier(inputs=img)
        pred = nl(inputs=logits)
        loss = nll_loss(predictions=pred, targets=tgt)
        # Set output - that output will be used for training.
        training_graph.outputs["loss"] = loss

    # Display the graph summary.
    logging.info(training_graph.summary())

    # SimpleLossLoggerCallback will print loss values to console.
    callback = SimpleLossLoggerCallback(
        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}')
    )

    # Invoke the "train" action.
    nf.train(
        training_graph=training_graph,
        callbacks=[callback],
        optimization_params={"num_epochs": 10, "lr": 0.001},
        optimizer="adam",
    )
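
All three examples end with the same pair: a `NonLinearity(type="logsoftmax")` followed by `NLLLoss`. The sketch below shows, in plain Python and for a single sample, what that pair computes under the usual textbook definitions of log-softmax and negative log-likelihood. It assumes the NeMo modules follow those definitions (their implementation is not part of the lines shown here); `log_softmax` and `nll_loss` are illustrative stand-ins, not the NeMo classes.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    m = max(logits)  # subtract the maximum to avoid overflow in exp()
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class index."""
    return -log_probs[target]


# Made-up logits for a 3-class problem, with class 0 as the target.
logits = [2.0, 0.5, -1.0]
log_probs = log_softmax(logits)
loss = nll_loss(log_probs, target=0)
print(log_probs, loss)
```

Chaining the two steps is equivalent to a cross-entropy loss applied directly to the logits, which is why the graphs pass `logits` through the non-linearity before the loss.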