diff --git a/CHANGELOG.md b/CHANGELOG.md index d5d85e0f472b..d152d8b761dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,8 @@ To release a new version, please update the changelog as followed: - ContextNet Encoder + Decoder Initial Support ([PR #630](https://github.com/NVIDIA/NeMo/pull/630)) - @titu1994 - Added finetuning with Megatron-LM ([PR #601](https://github.com/NVIDIA/NeMo/pull/601)) - @ekmb - Added documentation for 8 kHz model ([PR #632](https://github.com/NVIDIA/NeMo/pull/632)) - @jbalam-nv +- The Neural Graph is a high-level abstract concept empowering the users to build graphs consisting of many, interconnected Neural Modules. A user in his/her application can build any number of graphs, potentially spanning over the same modules. The import/export options combined with the lightweight API make Neural Graphs a perfect tool for rapid prototyping and experimentation. ([PR #413](https://github.com/NVIDIA/NeMo/pull/413)) - @tkornuta-nvidia +- Created the NeMo CV collection, added MNIST and CIFAR10 thin datalayers, implemented/ported several general usage trainable and non-trainable modules, added several new ElementTypes ([PR #654](https://github.com/NVIDIA/NeMo/pull/654)) - @tkornuta-nvidia ### Changed @@ -97,13 +99,6 @@ To release a new version, please update the changelog as followed: ### Security -### Contributors - -## [0.10.2] - 2020-05-05 - -### Added -- The Neural Graph is a high-level abstract concept empowering the users to build graphs consisting of many, interconnected Neural Modules. A user in his/her application can build any number of graphs, potentially spanning over the same modules. The import/export options combined with the lightweight API make Neural Graphs a perfect tool for rapid prototyping and experimentation. ([PR #413](https://github.com/NVIDIA/NeMo/pull/413)) - @tkornuta - ## [0.10.0] - 2020-04-03 ### Added diff --git a/docs/sources/source/collections/modules.rst b/docs/sources/source/collections/modules.rst index 3a26a3eafc1d..a4b769001317 100644 --- a/docs/sources/source/collections/modules.rst +++ b/docs/sources/source/collections/modules.rst @@ -8,5 +8,6 @@ NeMo Collections API core nemo_asr + nemo_cv nemo_tts nemo_nlp diff --git a/docs/sources/source/collections/nemo_cv.rst b/docs/sources/source/collections/nemo_cv.rst new file mode 100644 index 000000000000..137446f20042 --- /dev/null +++ b/docs/sources/source/collections/nemo_cv.rst @@ -0,0 +1,34 @@ +NeMo CV collection +================== + +DataLayers +---------- +.. automodule:: nemo.collections.cv.modules.data_layers + :members: + :undoc-members: + :show-inheritance: + :exclude-members: forward + +Trainable Modules +----------------- +.. automodule:: nemo.collections.cv.modules.trainables + :members: + :undoc-members: + :show-inheritance: + :exclude-members: forward + +NonTrainable Modules +-------------------- +.. automodule:: nemo.collections.cv.modules.non_trainables + :members: + :undoc-members: + :show-inheritance: + :exclude-members: forward + +Losses +------ +.. 
automodule:: nemo.collections.cv.modules.losses + :members: + :undoc-members: + :show-inheritance: + :exclude-members: forward diff --git a/nemo/backends/pytorch/nm.py b/nemo/backends/pytorch/nm.py index 0ed8e4ee66de..fe882a818a84 100644 --- a/nemo/backends/pytorch/nm.py +++ b/nemo/backends/pytorch/nm.py @@ -171,8 +171,7 @@ def __init__(self, name=None): def __call__(self, force_pt=False, *input, **kwargs): pt_call = len(input) > 0 or force_pt if pt_call: - with t.no_grad(): - return self.forward(*input, **kwargs) + return self.forward(*input, **kwargs) else: return NeuralModule.__call__(self, **kwargs) @@ -320,13 +319,13 @@ def dataset(self): pass @property - @abstractmethod def data_iterator(self): """"Iterator over the dataset. It is a good idea to return torch.utils.data.DataLoader here. Should implement either this or `dataset`. If this is implemented, `dataset` property should return None. """ + return None @property def batch_size(self): diff --git a/nemo/collections/cv/README.md b/nemo/collections/cv/README.md new file mode 100644 index 000000000000..1861a474970f --- /dev/null +++ b/nemo/collections/cv/README.md @@ -0,0 +1,21 @@ +NeMo CV Collection: Neural Modules for Computer Vision +==================================================================== + +The NeMo CV collection offers modules useful for the following computer vision applications. + +For now the collection focuses only on Image Classification. + +1. MNIST classification: + * a thin DL wrapper around torchvision MNIST dataset + * classification with the classic LeNet-5 + * classification with a graph: ReshapeTensor -> FeedForwardNetwork -> LogProbs + * classification with a graph: ConvNet -> ReshapeTensor -> FeedForwardNetwork -> LogProbs + +2. CIFAR10 classification: + * a thin DL wrapper around torchvision CIFAR10 dataset + * classification with a graph: ConvNet -> ReshapeTensor -> FeedForwardNetwork -> LogProbs + * classification with a graph: ImageEncoder (ResNet-50 feature maps) -> FeedForwardNetwork -> LogProbs + +3. CIFAR100 classification: + * a thin DL wrapper around torchvision CIFAR100 dataset + * classification with a graph: ImageEncoder (VGG-16 with FC6 reshaped) -> LogProbs \ No newline at end of file diff --git a/nemo/collections/cv/__init__.py b/nemo/collections/cv/__init__.py new file mode 100644 index 000000000000..80cb6e13d70e --- /dev/null +++ b/nemo/collections/cv/__init__.py @@ -0,0 +1,20 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from nemo.collections.cv.modules import * + +# __version__ = "0.1" +# __name__ = "nemo.collections.cv" diff --git a/nemo/collections/cv/examples/cifar100_vgg16_ffn_image_classification.py b/nemo/collections/cv/examples/cifar100_vgg16_ffn_image_classification.py new file mode 100644 index 000000000000..8dc819aa6753 --- /dev/null +++ b/nemo/collections/cv/examples/cifar100_vgg16_ffn_image_classification.py @@ -0,0 +1,72 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import argparse + +import nemo.utils.argparse as nm_argparse +from nemo.collections.cv.modules.data_layers import CIFAR100DataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import FeedForwardNetwork, ImageEncoder +from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback +from nemo.utils import logging + +if __name__ == "__main__": + # Create the default parser. + parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve') + # Parse the arguments + args = parser.parse_args() + + # Instantiate Neural Factory. + nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) + + # Data layer - upscale the CIFAR100 images to ImageNet resolution. + cifar100_dl = CIFAR100DataLayer(height=224, width=224, train=True) + # The "model". + image_encoder = ImageEncoder(model_type="vgg16", return_feature_maps=True, pretrained=True, name="vgg16") + reshaper = ReshapeTensor(input_sizes=[-1, 7, 7, 512], output_sizes=[-1, 25088]) + ffn = FeedForwardNetwork(input_size=25088, output_size=100, hidden_sizes=[1000, 1000], dropout_rate=0.1) + nl = NonLinearity(type="logsoftmax", sizes=[-1, 100]) + # Loss. + nll_loss = NLLLoss() + + # Create a training graph. + with NeuralGraph(operation_mode=OperationMode.training) as training_graph: + _, img, _, _, fine_target, _ = cifar100_dl() + feat_map = image_encoder(inputs=img) + res_img = reshaper(inputs=feat_map) + logits = ffn(inputs=res_img) + pred = nl(inputs=logits) + loss = nll_loss(predictions=pred, targets=fine_target) + # Set output - that output will be used for training. + training_graph.outputs["loss"] = loss + + # Freeze the pretrained encoder. + training_graph.freeze(["vgg16"]) + logging.info(training_graph.summary()) + + # SimpleLossLoggerCallback will print loss values to console. + callback = SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}') + ) + + # Invoke the "train" action. 
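+    # With the VGG-16 encoder frozen above, only the FeedForwardNetwork parameters are
+    # actually updated by the optimizer (the reshaper and log-softmax carry no weights).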
+ nf.train( + training_graph=training_graph, + callbacks=[callback], + optimization_params={"num_epochs": 10, "lr": 0.001}, + optimizer="adam", + ) diff --git a/nemo/collections/cv/examples/cifar10_convnet_ffn_image_classification.py b/nemo/collections/cv/examples/cifar10_convnet_ffn_image_classification.py new file mode 100644 index 000000000000..d5c9d088827b --- /dev/null +++ b/nemo/collections/cv/examples/cifar10_convnet_ffn_image_classification.py @@ -0,0 +1,71 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import argparse + +import nemo.utils.argparse as nm_argparse +from nemo.collections.cv.modules.data_layers import CIFAR10DataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import ConvNetEncoder, FeedForwardNetwork +from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback +from nemo.utils import logging + +if __name__ == "__main__": + # Create the default parser. + parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve') + # Parse the arguments + args = parser.parse_args() + + # Instantiate Neural Factory. + nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) + + # Data layer for training. + cifar10_dl = CIFAR10DataLayer(train=True) + # The "model". + cnn = ConvNetEncoder(input_depth=3, input_height=32, input_width=32) + reshaper = ReshapeTensor(input_sizes=[-1, 16, 2, 2], output_sizes=[-1, 64]) + ffn = FeedForwardNetwork(input_size=64, output_size=10, dropout_rate=0.1) + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) + # Loss. + nll_loss = NLLLoss() + + # Create a training graph. + with NeuralGraph(operation_mode=OperationMode.training) as training_graph: + _, img, tgt = cifar10_dl() + feat_map = cnn(inputs=img) + res_img = reshaper(inputs=feat_map) + logits = ffn(inputs=res_img) + pred = nl(inputs=logits) + loss = nll_loss(predictions=pred, targets=tgt) + # Set output - that output will be used for training. + training_graph.outputs["loss"] = loss + + # Display the graph summmary. + logging.info(training_graph.summary()) + + # SimpleLossLoggerCallback will print loss values to console. + callback = SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}') + ) + + # Invoke the "train" action. 
+ nf.train( + training_graph=training_graph, + callbacks=[callback], + optimization_params={"num_epochs": 10, "lr": 0.001}, + optimizer="adam", + ) diff --git a/nemo/collections/cv/examples/cifar10_resnet50_image_classification.py b/nemo/collections/cv/examples/cifar10_resnet50_image_classification.py new file mode 100644 index 000000000000..9e2b1b42c43e --- /dev/null +++ b/nemo/collections/cv/examples/cifar10_resnet50_image_classification.py @@ -0,0 +1,67 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import argparse + +import nemo.utils.argparse as nm_argparse +from nemo.collections.cv.modules.data_layers import CIFAR10DataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity +from nemo.collections.cv.modules.trainables import ImageEncoder +from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback +from nemo.utils import logging + +if __name__ == "__main__": + # Create the default parser. + parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve') + # Parse the arguments + args = parser.parse_args() + + # Instantiate Neural Factory. + nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) + + # Data layer - upscale the CIFAR10 images to ImageNet resolution. + cifar10_dl = CIFAR10DataLayer(height=224, width=224, train=True) + # The "model". + image_classifier = ImageEncoder(model_type="resnet50", output_size=10, pretrained=True, name="resnet50") + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) + # Loss. + nll_loss = NLLLoss() + + # Create a training graph. + with NeuralGraph(operation_mode=OperationMode.training) as training_graph: + _, img, tgt = cifar10_dl() + logits = image_classifier(inputs=img) + pred = nl(inputs=logits) + loss = nll_loss(predictions=pred, targets=tgt) + # Set output - that output will be used for training. + training_graph.outputs["loss"] = loss + + # Display the graph summmary. + logging.info(training_graph.summary()) + + # SimpleLossLoggerCallback will print loss values to console. + callback = SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}') + ) + + # Invoke the "train" action. 
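+    # Note: unlike the CIFAR100/VGG-16 example, the pretrained encoder is not frozen here,
+    # so the whole ResNet-50 is fine-tuned end-to-end on CIFAR10.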
+ nf.train( + training_graph=training_graph, + callbacks=[callback], + optimization_params={"num_epochs": 10, "lr": 0.001}, + optimizer="adam", + ) diff --git a/nemo/collections/cv/examples/mnist_convnet_ffn_image_classification.py b/nemo/collections/cv/examples/mnist_convnet_ffn_image_classification.py new file mode 100644 index 000000000000..ee97ea614e2d --- /dev/null +++ b/nemo/collections/cv/examples/mnist_convnet_ffn_image_classification.py @@ -0,0 +1,71 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import argparse + +import nemo.utils.argparse as nm_argparse +from nemo.collections.cv.modules.data_layers import MNISTDataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import ConvNetEncoder, FeedForwardNetwork +from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback +from nemo.utils import logging + +if __name__ == "__main__": + # Create the default parser. + parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve') + # Parse the arguments + args = parser.parse_args() + + # 0. Instantiate Neural Factory. + nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) + + # Data layers for training and validation. + dl = MNISTDataLayer(height=28, width=28, train=True) + # The "model". + cnn = ConvNetEncoder(input_depth=1, input_height=28, input_width=28) + reshaper = ReshapeTensor(input_sizes=[-1, 16, 1, 1], output_sizes=[-1, 16]) + ffn = FeedForwardNetwork(input_size=16, output_size=10, dropout_rate=0.1) + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) + # Loss. + nll_loss = NLLLoss() + + # Create a training graph. + with NeuralGraph(operation_mode=OperationMode.training) as training_graph: + _, img, tgt, _ = dl() + feat_map = cnn(inputs=img) + res_img = reshaper(inputs=feat_map) + logits = ffn(inputs=res_img) + pred = nl(inputs=logits) + loss = nll_loss(predictions=pred, targets=tgt) + # Set output - that output will be used for training. + training_graph.outputs["loss"] = loss + + # Display the graph summmary. + logging.info(training_graph.summary()) + + # SimpleLossLoggerCallback will print loss values to console. + callback = SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}') + ) + + # Invoke the "train" action. 
+ nf.train( + training_graph=training_graph, + callbacks=[callback], + optimization_params={"num_epochs": 10, "lr": 0.001}, + optimizer="adam", + ) diff --git a/nemo/collections/cv/examples/mnist_ffn_image_classification.py b/nemo/collections/cv/examples/mnist_ffn_image_classification.py new file mode 100644 index 000000000000..20091b87cd97 --- /dev/null +++ b/nemo/collections/cv/examples/mnist_ffn_image_classification.py @@ -0,0 +1,69 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import argparse + +import nemo.utils.argparse as nm_argparse +from nemo.collections.cv.modules.data_layers import MNISTDataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import FeedForwardNetwork +from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback +from nemo.utils import logging + +if __name__ == "__main__": + # Create the default parser. + parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve') + # Parse the arguments + args = parser.parse_args() + + # Instantiate Neural Factory. + nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) + + # Data layers for training and validation. + dl = MNISTDataLayer(height=28, width=28, train=True) + # The "model". + reshaper = ReshapeTensor(input_sizes=[-1, 1, 32, 32], output_sizes=[-1, 784]) + ffn = FeedForwardNetwork(input_size=784, output_size=10, hidden_sizes=[100, 100], dropout_rate=0.1) + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) + # Loss. + nll_loss = NLLLoss() + + # Create a training graph. + with NeuralGraph(operation_mode=OperationMode.training) as training_graph: + _, imgs, tgts, _ = dl() + res_imgs = reshaper(inputs=imgs) + logits = ffn(inputs=res_imgs) + preds = nl(inputs=logits) + loss = nll_loss(predictions=preds, targets=tgts) + # Set output - that output will be used for training. + training_graph.outputs["loss"] = loss + + # Display the graph summmary. + logging.info(training_graph.summary()) + + # SimpleLossLoggerCallback will print loss values to console. + callback = SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}') + ) + + # Invoke the "train" action. 
+ nf.train( + training_graph=training_graph, + callbacks=[callback], + optimization_params={"num_epochs": 10, "lr": 0.001}, + optimizer="adam", + ) diff --git a/nemo/collections/cv/examples/mnist_lenet5_image_classification.py b/nemo/collections/cv/examples/mnist_lenet5_image_classification.py new file mode 100644 index 000000000000..0bcb6be159d7 --- /dev/null +++ b/nemo/collections/cv/examples/mnist_lenet5_image_classification.py @@ -0,0 +1,103 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import argparse + +from torch import max, mean, stack, tensor + +import nemo.utils.argparse as nm_argparse +from nemo.collections.cv.modules.data_layers import MNISTDataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.trainables import LeNet5 +from nemo.core import ( + DeviceType, + EvaluatorCallback, + NeuralGraph, + NeuralModuleFactory, + OperationMode, + SimpleLossLoggerCallback, +) +from nemo.utils import logging + +if __name__ == "__main__": + # Create the default parser. + parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve') + # Parse the arguments + args = parser.parse_args() + + # Instantiate Neural Factory. + nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.GPU) + + # Data layers for training and validation. + dl = MNISTDataLayer(height=32, width=32, train=True) + dl_e = MNISTDataLayer(height=32, width=32, train=False) + # The "model". + lenet5 = LeNet5() + # Loss. + nll_loss = NLLLoss() + + # Create a training graph. + with NeuralGraph(operation_mode=OperationMode.training) as training_graph: + _, x, y, _ = dl() + p = lenet5(images=x) + loss = nll_loss(predictions=p, targets=y) + # Set output - that output will be used for training. + training_graph.outputs["loss"] = loss + + # Display the graph summmary. + logging.info(training_graph.summary()) + + # Create a validation graph, starting from the second data layer. + with NeuralGraph(operation_mode=OperationMode.evaluation) as evaluation_graph: + _, x, y, _ = dl_e() + p = lenet5(images=x) + loss_e = nll_loss(predictions=p, targets=y) + + # Display the graph summmary. + logging.info(evaluation_graph.summary()) + + # Create the callbacks. 
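+    # The two functions below are the user hooks consumed by EvaluatorCallback: the first
+    # collects the per-batch evaluation loss into `global_vars`, the second reduces the
+    # collected values once the evaluation pass finishes and returns the value to be logged.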
+ def eval_loss_per_batch_callback(tensors, global_vars): + if "eval_loss" not in global_vars.keys(): + global_vars["eval_loss"] = [] + for key, value in tensors.items(): + if key.startswith("loss"): + global_vars["eval_loss"].append(mean(stack(value))) + + def eval_loss_epoch_finished_callback(global_vars): + eloss = max(tensor(global_vars["eval_loss"])) + logging.info("Evaluation Loss: {0}".format(eloss)) + return dict({"Evaluation Loss": eloss}) + + ecallback = EvaluatorCallback( + eval_tensors=[loss_e], + user_iter_callback=eval_loss_per_batch_callback, + user_epochs_done_callback=eval_loss_epoch_finished_callback, + eval_step=100, + ) + + # SimpleLossLoggerCallback will print loss values to console. + callback = SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}') + ) + + # Invoke the "train" action. + nf.train( + training_graph=training_graph, + callbacks=[callback, ecallback], + optimization_params={"num_epochs": 10, "lr": 0.001}, + optimizer="adam", + ) diff --git a/nemo/collections/cv/modules/__init__.py b/nemo/collections/cv/modules/__init__.py new file mode 100644 index 000000000000..01216a63f0ef --- /dev/null +++ b/nemo/collections/cv/modules/__init__.py @@ -0,0 +1,20 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import nemo.collections.cv.modules.data_layers +import nemo.collections.cv.modules.losses +import nemo.collections.cv.modules.non_trainables +import nemo.collections.cv.modules.trainables diff --git a/nemo/collections/cv/modules/data_layers/__init__.py b/nemo/collections/cv/modules/data_layers/__init__.py new file mode 100644 index 000000000000..1eb1a0a59893 --- /dev/null +++ b/nemo/collections/cv/modules/data_layers/__init__.py @@ -0,0 +1,19 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from nemo.collections.cv.modules.data_layers.cifar10_datalayer import * +from nemo.collections.cv.modules.data_layers.cifar100_datalayer import * +from nemo.collections.cv.modules.data_layers.mnist_datalayer import * diff --git a/nemo/collections/cv/modules/data_layers/cifar100_datalayer.py b/nemo/collections/cv/modules/data_layers/cifar100_datalayer.py new file mode 100644 index 000000000000..4ddb2e17a13f --- /dev/null +++ b/nemo/collections/cv/modules/data_layers/cifar100_datalayer.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +# Copyright (C) IBM Corporation 2019 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +__author__ = "Tomasz Kornuta" + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/IBM/pytorchpipe/blob/develop/ptp/components/tasks/image_to_class/cifar_100.py +""" + +from os.path import expanduser +from typing import Optional + +from torch.utils.data import Dataset +from torchvision.datasets import CIFAR100 +from torchvision.transforms import Compose, Resize, ToTensor + +from nemo.backends.pytorch.nm import DataLayerNM +from nemo.core.neural_types import AxisKind, AxisType, ClassificationTarget, ImageValue, Index, NeuralType, StringLabel +from nemo.utils.decorators import add_port_docs + +__all__ = ['CIFAR100DataLayer'] + + +class CIFAR100DataLayer(DataLayerNM, Dataset): + """ + A "thin DataLayer" - wrapper around the torchvision's CIFAR100 dataset. + + Reference page: http://www.cs.toronto.edu/~kriz/cifar.html + """ + + def __init__( + self, + height: int = 32, + width: int = 32, + data_folder: str = "~/data/cifar100", + train: bool = True, + name: Optional[str] = None, + batch_size: int = 64, + shuffle: bool = True, + ): + """ + Initializes the CIFAR100 datalayer. 
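+
+        Besides the image, each returned sample contains both the coarse ("superclass")
+        and the fine class target, together with their human-readable string labels.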
+ + Args: + height: image height (DEFAULT: 32) + width: image width (DEFAULT: 32) + data_folder: path to the folder with data, can be relative to user (DEFAULT: "~/data/cifar10") + train: use train or test splits (DEFAULT: True) + name: Name of the module (DEFAULT: None) + batch_size: size of batch (DEFAULT: 64) [PARAMETER OF DATALOADER] + shuffle: shuffle data (DEFAULT: True) [PARAMETER OF DATALOADER] + """ + # Call the base class constructor of DataLayer. + DataLayerNM.__init__(self, name=name) + + # Store height and width. + self._height = height + self._width = width + + # Create transformations: up-scale and transform to tensors. + mnist_transforms = Compose([Resize((self._height, self._width)), ToTensor()]) + + # Get absolute path. + abs_data_folder = expanduser(data_folder) + + # Create the CIFAR10 dataset object. + self._dataset = CIFAR100(root=abs_data_folder, train=train, download=True, transform=mnist_transforms) + + # Remember the params passed to DataLoader. :] + self._batch_size = batch_size + self._shuffle = shuffle + + # Process labels. + all_labels = { + "aquatic_mammals": "beaver, dolphin, otter, seal, whale".split(", "), + "fish": "aquarium_fish, flatfish, ray, shark, trout".split(", "), + "flowers": "orchid, poppy, rose, sunflower, tulip".split(", "), + "food_containers": "bottle, bowl, can, cup, plate".split(", "), + "fruit_and_vegetables": "apple, mushroom, orange, pear, sweet_pepper".split(", "), + "household_electrical_devices": "clock, keyboard, lamp, telephone, television".split(", "), + "household_furniture": "bed, chair, couch, table, wardrobe".split(", "), + "insects": "bee, beetle, butterfly, caterpillar, cockroach".split(", "), + "large_carnivores": "bear, leopard, lion, tiger, wolf".split(", "), + "large_man-made_outdoor_things": "bridge, castle, house, road, skyscraper".split(", "), + "large_natural_outdoor_scenes": "cloud, forest, mountain, plain, sea".split(", "), + "large_omnivores_and_herbivores": "camel, cattle, chimpanzee, elephant, kangaroo".split(", "), + "medium-sized_mammals": "fox, porcupine, possum, raccoon, skunk".split(", "), + "non-insect_invertebrates": "crab, lobster, snail, spider, worm".split(", "), + "people": "baby, boy, girl, man, woman".split(", "), + "reptiles": "crocodile, dinosaur, lizard, snake, turtle".split(", "), + "small_mammals": "hamster, mouse, rabbit, shrew, squirrel".split(", "), + "trees": "maple_tree, oak_tree, palm_tree, pine_tree, willow_tree".split(", "), + "vehicles_1": "bicycle, bus, motorcycle, pickup_truck, train".split(", "), + "vehicles_2": "lawn_mower, rocket, streetcar, tank, tractor".split(", "), + } + + coarse_word_to_ix = {} + fine_to_coarse_mapping = {} + fine_labels = [] + for coarse_id, (key, values) in enumerate(all_labels.items()): + # Add mapping from coarse category name to coarse id. + coarse_word_to_ix[key] = coarse_id + # Add mappings from fine category names to coarse id. + for value in values: + fine_to_coarse_mapping[value] = coarse_id + # Add values to list of fine labels. + fine_labels.extend(values) + + # Sort fine labels. + fine_labels = sorted(fine_labels) + + # Generate fine word mappings. + fine_word_to_ix = {fine_labels[i]: i for i in range(len(fine_labels))} + + # Reverse mapping - for labels. + self._fine_ix_to_word = {value: key for (key, value) in fine_word_to_ix.items()} + + # Reverse mapping - for labels. + self._coarse_ix_to_word = {value: key for (key, value) in coarse_word_to_ix.items()} + + # Create fine to coarse id mapping. 
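+        # (e.g. the fine label "apple" is mapped to the coarse category "fruit_and_vegetables")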
+ self._fine_to_coarse_id_mapping = {} + for fine_label, fine_id in fine_word_to_ix.items(): + self._fine_to_coarse_id_mapping[fine_id] = fine_to_coarse_mapping[fine_label] + # print(" {} ({}) : {} ".format(fine_label, fine_id, self.coarse_ix_to_word[fine_to_coarse_mapping[fine_label]])) + + @property + @add_port_docs() + def output_ports(self): + """ + Creates definitions of output ports. + By default, it sets image width and height to 32. + """ + return { + "indices": NeuralType(tuple('B'), elements_type=Index()), + "images": NeuralType( + axes=( + AxisType(kind=AxisKind.Batch), + AxisType(kind=AxisKind.Channel, size=3), + AxisType(kind=AxisKind.Height, size=self._height), + AxisType(kind=AxisKind.Width, size=self._width), + ), + elements_type=ImageValue(), # uint8, <0-255> + ), + "coarse_targets": NeuralType(tuple('B'), elements_type=ClassificationTarget()), + "coarse_labels": NeuralType(tuple('B'), elements_type=StringLabel()), # Labels is string! + "fine_targets": NeuralType(tuple('B'), elements_type=ClassificationTarget()), + "fine_labels": NeuralType(tuple('B'), elements_type=StringLabel()), # Labels is string! + } + + def __len__(self): + """ + Returns: + Length of the dataset. + """ + return len(self._dataset) + + def __getitem__(self, index: int): + """ + Returns a single sample. + + Args: + index: index of the sample to return. + + """ + # Get image and target. + img, fine_target = self._dataset.__getitem__(index) + # Get coarse target. + coarse_target = self._fine_to_coarse_id_mapping[fine_target] + + # Labels. + fine_label = self._fine_ix_to_word[fine_target] + coarse_label = self._coarse_ix_to_word[self._fine_to_coarse_id_mapping[fine_target]] + + # Return sample. + return index, img, coarse_target, coarse_label, fine_target, fine_label + + @property + def dataset(self): + """ + Returns: + Self - just to be "compatible" with the current NeMo train action. + """ + return self # ! Important - as we want to use this __getitem__ method! diff --git a/nemo/collections/cv/modules/data_layers/cifar10_datalayer.py b/nemo/collections/cv/modules/data_layers/cifar10_datalayer.py new file mode 100644 index 000000000000..8faa234ea51f --- /dev/null +++ b/nemo/collections/cv/modules/data_layers/cifar10_datalayer.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from os.path import expanduser +from typing import Optional + +from torch.utils.data import Dataset +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Resize, ToTensor + +from nemo.backends.pytorch.nm import DataLayerNM +from nemo.core.neural_types import AxisKind, AxisType, ClassificationTarget, ImageValue, Index, NeuralType +from nemo.utils.decorators import add_port_docs + +__all__ = ['CIFAR10DataLayer'] + + +class CIFAR10DataLayer(DataLayerNM, Dataset): + """ + A "thin DataLayer" - wrapper around the torchvision's CIFAR10 dataset. + + Reference page: http://www.cs.toronto.edu/~kriz/cifar.html + """ + + def __init__( + self, + height: int = 32, + width: int = 32, + data_folder: str = "~/data/cifar10", + train: bool = True, + name: Optional[str] = None, + batch_size: int = 64, + shuffle: bool = True, + ): + """ + Initializes the CIFAR10 datalayer. + + Args: + height: image height (DEFAULT: 32) + width: image width (DEFAULT: 32) + data_folder: path to the folder with data, can be relative to user (DEFAULT: "~/data/cifar10") + train: use train or test splits (DEFAULT: True) + name: Name of the module (DEFAULT: None) + batch_size: size of batch (DEFAULT: 64) [PARAMETER OF DATALOADER] + shuffle: shuffle data (DEFAULT: True) [PARAMETER OF DATALOADER] + """ + # Call the base class constructor of DataLayer. + DataLayerNM.__init__(self, name=name) + + # Store height and width. + self._height = height + self._width = width + + # Create transformations: up-scale and transform to tensors. + mnist_transforms = Compose([Resize((self._height, self._width)), ToTensor()]) + + # Get absolute path. + abs_data_folder = expanduser(data_folder) + + # Create the CIFAR10 dataset object. + self._dataset = CIFAR10(root=abs_data_folder, train=train, download=True, transform=mnist_transforms) + + # Remember the params passed to DataLoader. :] + self._batch_size = batch_size + self._shuffle = shuffle + + @property + @add_port_docs() + def output_ports(self): + """ + Creates definitions of output ports. + By default, it sets image width and height to 32. + """ + return { + "indices": NeuralType(tuple('B'), elements_type=Index()), + "images": NeuralType( + axes=( + AxisType(kind=AxisKind.Batch), + AxisType(kind=AxisKind.Channel, size=3), + AxisType(kind=AxisKind.Height, size=self._height), + AxisType(kind=AxisKind.Width, size=self._width), + ), + elements_type=ImageValue(), # uint8, <0-255> + ), + "targets": NeuralType(tuple('B'), elements_type=ClassificationTarget()), + } + + def __len__(self): + """ + Returns: + Length of the dataset. + """ + return len(self._dataset) + + def __getitem__(self, index: int): + """ + Returns a single sample. + + Args: + index: index of the sample to return. + """ + # Get image and target. + img, target = self._dataset.__getitem__(index) + + # Return sample. + return index, img, target + + @property + def dataset(self): + """ + Returns: + Self - just to be "compatible" with the current NeMo train action. + """ + return self # ! Important - as we want to use this __getitem__ method! 
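+
+
+# Minimal usage sketch (assumes only the standard torch DataLoader API): because the layer
+# implements the Dataset interface and its `dataset` property returns `self`, it can be
+# wrapped in a plain DataLoader outside of a NeuralGraph, e.g.:
+#
+#   from torch.utils.data import DataLoader
+#
+#   dl = CIFAR10DataLayer(train=False, batch_size=4, shuffle=False)
+#   loader = DataLoader(dl.dataset, batch_size=4, shuffle=False)
+#   indices, images, targets = next(iter(loader))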
diff --git a/nemo/collections/cv/modules/data_layers/mnist_datalayer.py b/nemo/collections/cv/modules/data_layers/mnist_datalayer.py new file mode 100644 index 000000000000..a2ff835ffdc9 --- /dev/null +++ b/nemo/collections/cv/modules/data_layers/mnist_datalayer.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from os.path import expanduser +from typing import Optional + +from torch.utils.data import Dataset +from torchvision.datasets import MNIST +from torchvision.transforms import Compose, Resize, ToTensor + +from nemo.backends.pytorch.nm import DataLayerNM +from nemo.core.neural_types import ( + AxisKind, + AxisType, + ClassificationTarget, + Index, + NeuralType, + NormalizedImageValue, + StringLabel, +) +from nemo.utils.decorators import add_port_docs + +__all__ = ['MNISTDataLayer'] + + +class MNISTDataLayer(DataLayerNM, Dataset): + """ + A "thin DataLayer" - wrapper around the torchvision's MNIST dataset. + """ + + def __init__( + self, + height: int = 28, + width: int = 28, + data_folder: str = "~/data/mnist", + train: bool = True, + name: Optional[str] = None, + batch_size: int = 64, + shuffle: bool = True, + ): + """ + Initializes the MNIST datalayer. + + Args: + height: image height (DEFAULT: 28) + width: image width (DEFAULT: 28) + data_folder: path to the folder with data, can be relative to user (DEFAULT: "~/data/mnist") + train: use train or test splits (DEFAULT: True) + name: Name of the module (DEFAULT: None) + batch_size: size of batch (DEFAULT: 64) [PARAMETER OF DATALOADER] + shuffle: shuffle data (DEFAULT: True) [PARAMETER OF DATALOADER] + """ + # Call the base class constructor of DataLayer. + DataLayerNM.__init__(self, name=name) + + # Store height and width. + self._height = height + self._width = width + + # Create transformations: up-scale and transform to tensors. + mnist_transforms = Compose([Resize((self._height, self._width)), ToTensor()]) + + # Get absolute path. + abs_data_folder = expanduser(data_folder) + + # Create the MNIST dataset object. + self._dataset = MNIST(root=abs_data_folder, train=train, download=True, transform=mnist_transforms) + + # Remember the params passed to DataLoader. :] + self._batch_size = batch_size + self._shuffle = shuffle + + # Class names. + labels = 'Zero One Two Three Four Five Six Seven Eight Nine'.split(' ') + word_to_ix = {labels[i]: i for i in range(10)} + + # Reverse mapping. + self._ix_to_word = {value: key for (key, value) in word_to_ix.items()} + + @property + @add_port_docs() + def output_ports(self): + """ + Creates definitions of output ports. + By default, it sets image width and height to 32. 
+ """ + return { + "indices": NeuralType(tuple('B'), elements_type=Index()), + "images": NeuralType( + axes=( + AxisType(kind=AxisKind.Batch), + AxisType(kind=AxisKind.Channel, size=1), + AxisType(kind=AxisKind.Height, size=self._height), + AxisType(kind=AxisKind.Width, size=self._width), + ), + elements_type=NormalizedImageValue(), # float, <0-1> + ), + "targets": NeuralType(tuple('B'), elements_type=ClassificationTarget()), # Target are ints! + "labels": NeuralType(tuple('B'), elements_type=StringLabel()), # Labels is string! + } + + def __len__(self): + """ + Returns: + Length of the dataset. + """ + return len(self._dataset) + + def __getitem__(self, index: int): + """ + Returns a single sample. + + Args: + index: index of the sample to return. + """ + # Get image and target. + img, target = self._dataset.__getitem__(index) + + # Return sample. + return index, img, target, self._ix_to_word[target] + + @property + def ix_to_word(self): + """ + Returns: + Dictionary with mapping of target indices (int) to labels (class names as strings) + that can we used by other modules. + """ + return self._ix_to_word + + @property + def dataset(self): + """ + Returns: + Self - just to be "compatible" with the current NeMo train action. + """ + return self # ! Important - as we want to use this __getitem__ method! diff --git a/nemo/collections/cv/modules/losses/__init__.py b/nemo/collections/cv/modules/losses/__init__.py new file mode 100644 index 000000000000..e866e912d4ee --- /dev/null +++ b/nemo/collections/cv/modules/losses/__init__.py @@ -0,0 +1,17 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from nemo.collections.cv.modules.losses.nll_loss import * diff --git a/nemo/collections/cv/modules/losses/nll_loss.py b/nemo/collections/cv/modules/losses/nll_loss.py new file mode 100644 index 000000000000..1e447453400b --- /dev/null +++ b/nemo/collections/cv/modules/losses/nll_loss.py @@ -0,0 +1,60 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from typing import Optional + +import torch + +from nemo.backends.pytorch.nm import LossNM +from nemo.core.neural_types import ClassificationTarget, LogprobsType, LossType, NeuralType +from nemo.utils.decorators import add_port_docs + +__all__ = ['NLLLoss'] + + +class NLLLoss(LossNM): + """ Class representing a simple NLL loss. """ + + def __init__(self, name: Optional[str] = None): + """ + Constructor. + + Args: + name: Name of the module (DEFAULT: None) + """ + # Call the base class constructor. + LossNM.__init__(self, name=name) + # Set criterion. + self._criterion = torch.nn.NLLLoss() + + @property + @add_port_docs() + def input_ports(self): + """ Returns definitions of module input ports. """ + return { + "predictions": NeuralType(axes=('B', 'ANY'), elements_type=LogprobsType()), + "targets": NeuralType(axes=('B'), elements_type=ClassificationTarget()), + } + + @property + @add_port_docs() + def output_ports(self): + """ Returns definitions of module output ports. """ + return {"loss": NeuralType(elements_type=LossType())} + + # You need to implement this function + def _loss_function(self, **kwargs): + return self._criterion(*(kwargs.values())) diff --git a/nemo/collections/cv/modules/non_trainables/__init__.py b/nemo/collections/cv/modules/non_trainables/__init__.py new file mode 100644 index 000000000000..8ee079e36de8 --- /dev/null +++ b/nemo/collections/cv/modules/non_trainables/__init__.py @@ -0,0 +1,18 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from nemo.collections.cv.modules.non_trainables.non_linearity import * +from nemo.collections.cv.modules.non_trainables.reshape_tensor import * diff --git a/nemo/collections/cv/modules/non_trainables/non_linearity.py b/nemo/collections/cv/modules/non_trainables/non_linearity.py new file mode 100644 index 000000000000..5aaa5d96dfbe --- /dev/null +++ b/nemo/collections/cv/modules/non_trainables/non_linearity.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + + +from typing import List, Optional + +import torch + +from nemo.backends.pytorch.nm import NonTrainableNM +from nemo.core.neural_types import AxisKind, AxisType, LogprobsType, NeuralType, VoidType +from nemo.utils import logging +from nemo.utils.configuration_parsing import get_value_from_dictionary +from nemo.utils.decorators import add_port_docs + +__all__ = ['NonLinearity'] + + +class NonLinearity(NonTrainableNM): + """ + Class responsible for applying additional non-linearity along the last axis of the input tensor. + + """ + + def __init__(self, type: str = "logsoftmax", sizes: List[int] = [-1], name: Optional[str] = None): + """ + Constructor initializing the non-linearity. + + Args: + type: Type of non-linearity (currently only logsoftmax is supported) + sizes: Sizes of dimensions of the input/output tensors (DEFAULT: [-1] - variable size batch) + name: Name of the module (DEFAULT: None) + """ + # Call constructor of parent classes. + NonTrainableNM.__init__(self, name=name) + + # Store params. + self._type = type + self._sizes = sizes + + # Get type - only one option accepted (for now). + self._non_linearity_type = get_value_from_dictionary(type, ["logsoftmax"]) + + # Apply the non-linearity along the last dimension. + dim = len(sizes) - 1 + self._non_linearity = torch.nn.LogSoftmax(dim=dim) + + @property + @add_port_docs() + def input_ports(self): + """ + Returns definitions of module input ports. + Batch of inputs, each represented as index [BATCH_SIZE x ... x INPUT_SIZE] + """ + # Prepare list of axes. + axes = [AxisType(kind=AxisKind.Batch)] + for size in self._sizes[1:]: + axes.append(AxisType(kind=AxisKind.Any, size=size)) + # Return neural type. + return {"inputs": NeuralType(axes, VoidType())} + + @property + @add_port_docs() + def output_ports(self): + """ + Returns definitions of module output ports. + """ + # Prepare list of axes. + axes = [AxisType(kind=AxisKind.Batch)] + for size in self._sizes[1:]: + axes.append(AxisType(kind=AxisKind.Any, size=size)) + # Return neural type. + # TODO: if self._type != "logsoftmax" + return {"outputs": NeuralType(axes, LogprobsType())} + + def forward(self, inputs): + """ + Encodes "inputs" in the format of a single tensor. + Stores reshaped tensor in "outputs" field of in data_streams. + + Args: + inputs: a tensor [BATCH_SIZE x ...] + + Returns: + Outputs a tensor [BATCH_SIZE x ...] + """ + # print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device)) + + # Reshape. + # TODO: if self._type != "logsoftmax" + return self._non_linearity(inputs) diff --git a/nemo/collections/cv/modules/non_trainables/reshape_tensor.py b/nemo/collections/cv/modules/non_trainables/reshape_tensor.py new file mode 100644 index 000000000000..ac5dcb2cf9b1 --- /dev/null +++ b/nemo/collections/cv/modules/non_trainables/reshape_tensor.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +# Copyright (C) tkornuta, IBM Corporation 2019 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +__author__ = "Tomasz Kornuta" + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/IBM/pytorchpipe/blob/develop/ptp/components/transforms/reshape_tensor.py +""" + +from typing import List, Optional + +from nemo.backends.pytorch.nm import NonTrainableNM +from nemo.core.neural_types import AxisKind, AxisType, NeuralType, VoidType +from nemo.utils import logging +from nemo.utils.configuration_error import ConfigurationError +from nemo.utils.decorators import add_port_docs + +__all__ = ['ReshapeTensor'] + + +class ReshapeTensor(NonTrainableNM): + """ + Class responsible for reshaping the input tensor. + + Reshapes tensor from e.g. [64, 16, 2, 2] to [64, 64]. + + For more details please refer to: https://pytorch.org/docs/master/generated/torch.reshape.html + """ + + def __init__(self, input_sizes: List[int], output_sizes: List[int], name: Optional[str] = None): + """ + Initializes the object. + + Args: + input_sizes: Sizes of dimensions of the input tensor. + output_sizes: Sizes of dimensions of the output. + name: Name of the module (DEFAULT: None) + """ + # Call constructor of parent classes. + NonTrainableNM.__init__(self, name=name) + + # Validate params. + if type(input_sizes) != list or len(input_sizes) < 2: + raise ConfigurationError( + "'input_sizes' must be at least a list with two values (received {})".format(self.input_sizes) + ) + if type(output_sizes) != list or len(output_sizes) < 2: + raise ConfigurationError( + "'output_sizes' must be at least a list with two values (received {})".format(self.output_sizes) + ) + + # Get input and output shapes from configuration. + self._input_sizes = input_sizes + self._output_sizes = output_sizes + + @property + @add_port_docs() + def input_ports(self): + """ + Returns definitions of module input ports. + Batch of inputs, each represented as index [BATCH_SIZE x ... x INPUT_SIZE] + """ + # Prepare list of axes. + axes = [AxisType(kind=AxisKind.Batch)] + for size in self._input_sizes[1:]: + axes.append(AxisType(kind=AxisKind.Any, size=size)) + # Return neural type. + return {"inputs": NeuralType(axes, VoidType())} + + @property + @add_port_docs() + def output_ports(self): + """ + Returns definitions of module output ports. + """ + # Prepare list of axes. + axes = [AxisType(kind=AxisKind.Batch)] + for size in self._output_sizes[1:]: + axes.append(AxisType(kind=AxisKind.Any, size=size)) + # Return neural type. + return {"outputs": NeuralType(axes, VoidType())} + + def forward(self, inputs): + """ + Encodes "inputs" in the format of a single tensor. + Stores reshaped tensor in "outputs" field of in data_streams. 
+ + Args: + inputs: a tensor [BATCH_SIZE x ...] + + Returns: + Outputs a tensor [BATCH_SIZE x ...] + """ + # print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device)) + + # Reshape. + return inputs.view(self._output_sizes) diff --git a/nemo/collections/cv/modules/trainables/__init__.py b/nemo/collections/cv/modules/trainables/__init__.py new file mode 100644 index 000000000000..6eb1a4a61b73 --- /dev/null +++ b/nemo/collections/cv/modules/trainables/__init__.py @@ -0,0 +1,20 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from nemo.collections.cv.modules.trainables.convnet_encoder import * +from nemo.collections.cv.modules.trainables.feed_forward_network import * +from nemo.collections.cv.modules.trainables.image_encoder import * +from nemo.collections.cv.modules.trainables.lenet5 import * diff --git a/nemo/collections/cv/modules/trainables/convnet_encoder.py b/nemo/collections/cv/modules/trainables/convnet_encoder.py new file mode 100644 index 000000000000..d13796d280b2 --- /dev/null +++ b/nemo/collections/cv/modules/trainables/convnet_encoder.py @@ -0,0 +1,364 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +# Copyright (C) IBM Corporation 2019 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
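`ReshapeTensor.forward` above boils down to a single `Tensor.view` call. A small hedged sketch of the reshape it performs for `input_sizes=[-1, 16, 2, 2]` and `output_sizes=[-1, 64]` (the concrete batch size is illustrative only):

```python
import torch

inputs = torch.randn(64, 16, 2, 2)   # [BATCH_SIZE x 16 x 2 x 2]

# Equivalent of ReshapeTensor(input_sizes=[-1, 16, 2, 2], output_sizes=[-1, 64]).forward(inputs):
outputs = inputs.view([-1, 64])

print(outputs.shape)                 # torch.Size([64, 64])
```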
+# ============================================================================= + +__author__ = "Younes Bouhadjar, Vincent Marois, Tomasz Kornuta" + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/IBM/pytorchpipe/blob/develop/ptp/components/models/vision/convnet_encoder.py +""" + + +from typing import Optional + +import numpy as np +import torch.nn as nn + +from nemo.backends.pytorch.nm import TrainableNM +from nemo.core.neural_types import AxisKind, AxisType, ImageFeatureValue, ImageValue, NeuralType +from nemo.utils import logging +from nemo.utils.decorators import add_port_docs + +__all__ = ['ConvNetEncoder'] + + +class ConvNetEncoder(TrainableNM): + """ + A simple image encoder consisting of 3 consecutive convolutional layers. + The parameters of input image (height, width and depth) are not hardcoded + so the encoder can be adjusted for a given application (image dimensions). + """ + + def __init__( + self, + input_depth: int, + input_height: int, + input_width: int, + conv1_out_channels: int = 64, + conv1_kernel_size: int = 3, + conv1_stride: int = 1, + conv1_padding: int = 0, + maxpool1_kernel_size: int = 2, + conv2_out_channels: int = 32, + conv2_kernel_size: int = 3, + conv2_stride: int = 1, + conv2_padding: int = 0, + maxpool2_kernel_size: int = 2, + conv3_out_channels: int = 16, + conv3_kernel_size: int = 3, + conv3_stride: int = 1, + conv3_padding: int = 0, + maxpool3_kernel_size: int = 2, + name: Optional[str] = None, + ): + """ + Constructor of the a simple CNN. + + The overall structure of this CNN is as follows: + + (Conv1 -> MaxPool1 -> ReLu) -> (Conv2 -> MaxPool2 -> ReLu) -> (Conv3 -> MaxPool3 -> ReLu) + + The parameters that the user can change are: + + - For Conv1, Conv2 & Conv3: number of output channels, kernel size, stride and padding. + - For MaxPool1, MaxPool2 & MaxPool3: Kernel size + + + .. note:: + + We are using the default values of ``dilatation``, ``groups`` & ``bias`` for ``nn.Conv2D``. + + Similarly for the ``stride``, ``padding``, ``dilatation``, ``return_indices`` & ``ceil_mode`` of \ + ``nn.MaxPool2D``. + + Args: + input_depth: Depth of the input image + input_height: Height of the input image + input_width: Width of the input image + convX_out_channels: Number of output channels of layer X (X=1,2,3) + convX_kernel_size: Kernel size of layer X (X=1,2,3) + convX_stride: Stride of layer X (X=1,2,3) + convX_padding: Padding of layer X (X=1,2,3) + name: Name of the module (DEFAULT: None) + """ + # Call base constructor. + TrainableNM.__init__(self, name=name) + + # Get input image information from the global parameters. + self._input_depth = input_depth + self._input_height = input_height + self._input_width = input_width + + # Retrieve the Conv1 parameters. + self._conv1_out_channels = conv1_out_channels + self._conv1_kernel_size = conv1_kernel_size + self._conv1_stride = conv1_stride + self._conv1_padding = conv1_padding + + # Retrieve the MaxPool1 parameter. + self._maxpool1_kernel_size = maxpool1_kernel_size + + # Retrieve the Conv2 parameters. + self._conv2_out_channels = conv2_out_channels + self._conv2_kernel_size = conv2_kernel_size + self._conv2_stride = conv2_stride + self._conv2_padding = conv2_padding + + # Retrieve the MaxPool2 parameter. + self._maxpool2_kernel_size = maxpool2_kernel_size + + # Retrieve the Conv3 parameters. 
+ self._conv3_out_channels = conv3_out_channels + self._conv3_kernel_size = conv3_kernel_size + self._conv3_stride = conv3_stride + self._conv3_padding = conv3_padding + + # Retrieve the MaxPool3 parameter. + self._maxpool3_kernel_size = maxpool3_kernel_size + + # We can compute the spatial size of the output volume as a function of the input volume size (W), + # the receptive field size of the Conv Layer neurons (F), the stride with which they are applied (S), + # and the amount of zero padding used (P) on the border. + # The corresponding equation is conv_size = ((W−F+2P)/S)+1. + + # doc for nn.Conv2D: https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d + # doc for nn.MaxPool2D: https://pytorch.org/docs/stable/nn.html#torch.nn.MaxPool2d + + # ---------------------------------------------------- + # Conv1 + self._conv1 = nn.Conv2d( + in_channels=self._input_depth, + out_channels=self._conv1_out_channels, + kernel_size=self._conv1_kernel_size, + stride=self._conv1_stride, + padding=self._conv1_padding, + dilation=1, + groups=1, + bias=True, + ) + + width_features_conv1 = np.floor( + ((self._input_width - self._conv1_kernel_size + 2 * self._conv1_padding) / self._conv1_stride) + 1 + ) + height_features_conv1 = np.floor( + ((self._input_height - self._conv1_kernel_size + 2 * self._conv1_padding) / self._conv1_stride) + 1 + ) + + # ---------------------------------------------------- + # MaxPool1 + self._maxpool1 = nn.MaxPool2d(kernel_size=self._maxpool1_kernel_size) + + width_features_maxpool1 = np.floor( + ((width_features_conv1 - self._maxpool1_kernel_size + 2 * self._maxpool1.padding) / self._maxpool1.stride) + + 1 + ) + + height_features_maxpool1 = np.floor( + ((height_features_conv1 - self._maxpool1_kernel_size + 2 * self._maxpool1.padding) / self._maxpool1.stride) + + 1 + ) + + # ---------------------------------------------------- + # Conv2 + self._conv2 = nn.Conv2d( + in_channels=self._conv1_out_channels, + out_channels=self._conv2_out_channels, + kernel_size=self._conv2_kernel_size, + stride=self._conv2_stride, + padding=self._conv2_padding, + dilation=1, + groups=1, + bias=True, + ) + + width_features_conv2 = np.floor( + ((width_features_maxpool1 - self._conv2_kernel_size + 2 * self._conv2_padding) / self._conv2_stride) + 1 + ) + height_features_conv2 = np.floor( + ((height_features_maxpool1 - self._conv2_kernel_size + 2 * self._conv2_padding) / self._conv2_stride) + 1 + ) + + # ---------------------------------------------------- + # MaxPool2 + self._maxpool2 = nn.MaxPool2d(kernel_size=self._maxpool2_kernel_size) + + width_features_maxpool2 = np.floor( + ((width_features_conv2 - self._maxpool2_kernel_size + 2 * self._maxpool2.padding) / self._maxpool2.stride) + + 1 + ) + height_features_maxpool2 = np.floor( + ((height_features_conv2 - self._maxpool2_kernel_size + 2 * self._maxpool2.padding) / self._maxpool2.stride) + + 1 + ) + + # ---------------------------------------------------- + # Conv3 + self._conv3 = nn.Conv2d( + in_channels=self._conv2_out_channels, + out_channels=self._conv3_out_channels, + kernel_size=self._conv3_kernel_size, + stride=self._conv3_stride, + padding=self._conv3_padding, + dilation=1, + groups=1, + bias=True, + ) + + width_features_conv3 = np.floor( + ((width_features_maxpool2 - self._conv3_kernel_size + 2 * self._conv3_padding) / self._conv3_stride) + 1 + ) + height_features_conv3 = np.floor( + ((height_features_maxpool2 - self._conv3_kernel_size + 2 * self._conv3_padding) / self._conv3_stride) + 1 + ) + + # 
----------------------------------------------------
+        # MaxPool3
+        self._maxpool3 = nn.MaxPool2d(kernel_size=self._maxpool3_kernel_size)
+
+        width_features_maxpool3 = np.floor(
+            ((width_features_conv3 - self._maxpool3_kernel_size + 2 * self._maxpool3.padding) / self._maxpool3.stride)
+            + 1
+        )
+
+        height_features_maxpool3 = np.floor(
+            ((height_features_conv3 - self._maxpool3_kernel_size + 2 * self._maxpool3.padding) / self._maxpool3.stride)
+            + 1
+        )
+
+        # Remember the output dims.
+        self._feature_map_height = height_features_maxpool3
+        self._feature_map_width = width_features_maxpool3
+        self._feature_map_depth = self._conv3_out_channels
+
+        # Log info about dimensions.
+        logging.info('Input shape: [-1, {}, {}, {}]'.format(self._input_depth, self._input_height, self._input_width))
+        logging.debug('Computed output shape of each layer:')
+        logging.debug(
+            ' * Conv1: [-1, {}, {}, {}]'.format(self._conv1_out_channels, height_features_conv1, width_features_conv1)
+        )
+        logging.debug(
+            ' * MaxPool1: [-1, {}, {}, {}]'.format(
+                self._conv1_out_channels, height_features_maxpool1, width_features_maxpool1
+            )
+        )
+        logging.debug(
+            ' * Conv2: [-1, {}, {}, {}]'.format(self._conv2_out_channels, height_features_conv2, width_features_conv2)
+        )
+        logging.debug(
+            ' * MaxPool2: [-1, {}, {}, {}]'.format(
+                self._conv2_out_channels, height_features_maxpool2, width_features_maxpool2
+            )
+        )
+        logging.debug(
+            ' * Conv3: [-1, {}, {}, {}]'.format(
+                self._conv3_out_channels, height_features_conv3, width_features_conv3
+            )
+        )
+        logging.debug(
+            ' * MaxPool3: [-1, {}, {}, {}]'.format(
+                self._conv3_out_channels, height_features_maxpool3, width_features_maxpool3
+            )
+        )
+        logging.info(
+            'Output shape: [-1, {}, {}, {}]'.format(
+                self._feature_map_depth, self._feature_map_height, self._feature_map_width
+            )
+        )
+
+    @property
+    @add_port_docs()
+    def input_ports(self):
+        """
+        Returns definitions of module input ports.
+        """
+        return {
+            "inputs": NeuralType(
+                axes=(
+                    AxisType(kind=AxisKind.Batch),
+                    AxisType(kind=AxisKind.Channel, size=self._input_depth),
+                    AxisType(kind=AxisKind.Height, size=self._input_height),
+                    AxisType(kind=AxisKind.Width, size=self._input_width),
+                ),
+                elements_type=ImageValue(),
+            )
+        }
+
+    @property
+    @add_port_docs()
+    def output_ports(self):
+        """
+        Returns definitions of module output ports.
+        """
+        return {
+            "outputs": NeuralType(
+                axes=(
+                    AxisType(kind=AxisKind.Batch),
+                    AxisType(kind=AxisKind.Channel, size=self._feature_map_depth),
+                    AxisType(kind=AxisKind.Height, size=self._feature_map_height),
+                    AxisType(kind=AxisKind.Width, size=self._feature_map_width),
+                ),
+                elements_type=ImageFeatureValue(),
+            )
+        }
+
+    def forward(self, inputs):
+        """
+        Forward pass of the convnet module.
+
+        Args:
+            inputs: Batch of inputs to be processed [BATCH_SIZE x INPUT_DEPTH x INPUT_HEIGHT x INPUT_WIDTH]
+
+        Returns:
+            Batch of outputs [BATCH_SIZE x OUTPUT_DEPTH x OUTPUT_HEIGHT x OUTPUT_WIDTH]
+        """
+        # Apply convolutional layer 1.
+        out_conv1 = self._conv1(inputs)
+
+        # Apply max pooling and ReLU.
+        out_maxpool1 = nn.functional.relu(self._maxpool1(out_conv1))
+
+        # Apply convolutional layer 2.
+        out_conv2 = self._conv2(out_maxpool1)
+
+        # Apply max pooling and ReLU.
+        out_maxpool2 = nn.functional.relu(self._maxpool2(out_conv2))
+
+        # Apply convolutional layer 3.
+        out_conv3 = self._conv3(out_maxpool2)
+
+        # Apply max pooling and ReLU.
+        out_maxpool3 = nn.functional.relu(self._maxpool3(out_conv3))
+
+        # Return output.
+ return out_maxpool3 diff --git a/nemo/collections/cv/modules/trainables/feed_forward_network.py b/nemo/collections/cv/modules/trainables/feed_forward_network.py new file mode 100644 index 000000000000..29cfdc13e2e6 --- /dev/null +++ b/nemo/collections/cv/modules/trainables/feed_forward_network.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +# Copyright (C) tkornuta, IBM Corporation 2019 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +__author__ = "Tomasz Kornuta" + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/IBM/pytorchpipe/blob/develop/ptp/components/models/general_usage/feed_forward_network.py +""" + +from typing import List, Optional + +import torch + +from nemo.backends.pytorch.nm import TrainableNM +from nemo.core.neural_types import AxisKind, AxisType, NeuralType, VoidType +from nemo.utils import logging +from nemo.utils.configuration_error import ConfigurationError +from nemo.utils.decorators import add_port_docs + +__all__ = ['FeedForwardNetwork'] + + +class FeedForwardNetwork(TrainableNM): + """ + A simple trainable module consisting of several stacked fully connected layers + with ReLU non-linearities and dropout between them. + + # TODO: parametrize with other non-linearities. + + """ + + def __init__( + self, + input_size: int, + output_size: int, + hidden_sizes: List[int] = [], + dimensions: int = 2, + dropout_rate: float = 0, + name: Optional[str] = None, + ): + """ + Initializes the feed-forwad network. + + Args: + input_size: Size of input (1D) + output_sizes: Size of the output (1D) + hidden_sizes: Sizes of the consecutive hidden layers (DEFAULT: [] = no hidden) + dimensions: Number of dimensions of input/output tensors (DEFAULT: 2 = BATCH X INPUT_SIZE) + dropout_rate: Dropout rage (Default: 0) + name: Name of the module (DEFAULT: None) + """ + # Call constructor of the parent class. + TrainableNM.__init__(self, name=name) + + # Get input size. 
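All of the spatial bookkeeping in the `ConvNetEncoder` constructor above follows the single formula quoted there, conv_size = ((W - F + 2P) / S) + 1, applied once per convolution and once per pooling layer. A hedged worked example for the default hyper-parameters; the 32 x 32 input size is an assumption (e.g. MNIST padded to 32 x 32):

```python
import math

def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer: floor((W - F + 2P) / S) + 1."""
    return math.floor((size - kernel + 2 * padding) / stride) + 1

size = 32
size = conv_out(size, kernel=3)             # Conv1    -> 30
size = conv_out(size, kernel=2, stride=2)   # MaxPool1 -> 15
size = conv_out(size, kernel=3)             # Conv2    -> 13
size = conv_out(size, kernel=2, stride=2)   # MaxPool2 -> 6
size = conv_out(size, kernel=3)             # Conv3    -> 4
size = conv_out(size, kernel=2, stride=2)   # MaxPool3 -> 2
print(size)  # 2, i.e. an output volume of [-1, 16, 2, 2] with the default conv3_out_channels=16
```

Note that this [-1, 16, 2, 2] output is exactly the shape the `ReshapeTensor` example earlier flattens to [-1, 64].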
+ self._input_size = input_size + if type(self._input_size) == list: + if len(self._input_size) == 1: + self._input_size = self._input_size[0] + else: + raise ConfigurationError("'input_size' must be a single value (received {})".format(self._input_size)) + + # Get input/output dimensions, i.e. number of axes of the input [BATCH_SIZE x ... x INPUT_SIZE]. + # The module will "broadcast" over those dimensions. + self._dimensions = dimensions + if self._dimensions < 2: + raise ConfigurationError("'dimensions' must be bigger than two (received {})".format(self._dimensions)) + + # Get output (prediction/logits) size. + self._output_size = output_size + if type(self._output_size) == list: + if len(self._output_size) == 1: + self._output_size = self._output_size[0] + else: + raise ConfigurationError( + "'output_size' must be a single value (received {})".format(self._output_size) + ) + + logging.info( + "Initializing network with input size = {} and output size = {}".format( + self._input_size, self._output_size + ) + ) + + # Create the module list. + modules = [] + + # Retrieve number of hidden layers, along with their sizes (numbers of hidden neurons from configuration). + if type(hidden_sizes) == list: + # Stack linear layers. + input_dim = self._input_size + for hidden_dim in hidden_sizes: + # Add linear layer. + modules.append(torch.nn.Linear(input_dim, hidden_dim)) + # Add activation. + modules.append(torch.nn.ReLU()) + # Add dropout. + if dropout_rate > 0: + modules.append(torch.nn.Dropout(dropout_rate)) + # Remember size. + input_dim = hidden_dim + + # Add the last output" (or in a special case: the only) layer. + modules.append(torch.nn.Linear(input_dim, self._output_size)) + + logging.info("Created {} hidden layers with sizes {}".format(len(hidden_sizes), hidden_sizes)) + + else: + raise ConfigurationError( + "'hidden_sizes' must contain a list with numbers of neurons in consecutive hidden layers (received {})".format( + hidden_sizes + ) + ) + + # Finally create the sequential model out of those modules. + self.layers = torch.nn.Sequential(*modules) + + @property + @add_port_docs() + def input_ports(self): + """ + Returns definitions of module input ports. + Batch of inputs, each represented as index [BATCH_SIZE x ... x INPUT_SIZE] + """ + # Prepare list of axes. + axes = [AxisType(kind=AxisKind.Batch)] + # Add the "additional dimensions". + for _ in range(self._dimensions)[1:-1]: + axes.append(AxisType(kind=AxisKind.Any)) + # Add the last axis: input_size + axes.append(AxisType(kind=AxisKind.Any, size=self._input_size)) + # Return neural type. + return {"inputs": NeuralType(axes, VoidType())} + + @property + @add_port_docs() + def output_ports(self): + """ + Returns definitions of module output ports. + """ + # Prepare list of axes. + axes = [AxisType(kind=AxisKind.Batch)] + # Add the "additional dimensions". + for _ in range(self._dimensions)[1:-1]: + axes.append(AxisType(kind=AxisKind.Any)) + # Add the last axis: input_size + axes.append(AxisType(kind=AxisKind.Any, size=self._output_size)) + # Return neural type: batch of "logits" of "any type". + return {"outputs": NeuralType(axes, VoidType())} + + def forward(self, inputs): + """ + Performs the forward step of the module. + + Args: + inputs: Batch of inputs to be processed [BATCH_SIZE x ... x INPUT_SIZE] + + Returns: + Batch of outputs/predictions (log_probs) [BATCH_SIZE x ... 
x NUM_CLASSES] + """ + + # print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device)) + + # Check that the input has the number of dimensions that we expect + if len(inputs.shape) != self._dimensions: + raise ConfigurationError( + "Expected `{}` dimensions for input, but received `{}` instead. " + F"Check fix the dimensions in your script.".format(self._dimensions, len(inputs.shape)) + ) + + # Reshape such that we do a broadcast over the last dimension + origin_shape = inputs.shape + inputs = inputs.contiguous().view(-1, origin_shape[-1]) + + # Propagate inputs through all layers and activations. + outputs = self.layers(inputs) + + # Restore the input dimensions but the last one (as it's been resized by the FFN) + outputs = outputs.view(*origin_shape[0 : self._dimensions - 1], -1) + + # Return the result. + return outputs diff --git a/nemo/collections/cv/modules/trainables/image_encoder.py b/nemo/collections/cv/modules/trainables/image_encoder.py new file mode 100644 index 000000000000..49e3d0b24c9d --- /dev/null +++ b/nemo/collections/cv/modules/trainables/image_encoder.py @@ -0,0 +1,219 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +# -*- coding: utf-8 -*- +# +# Copyright (C) IBM Corporation 2019 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +__author__ = "Tomasz Kornuta" + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/IBM/pytorchpipe/blob/develop/ptp/components/models/vision/image_encoder.py +""" + +from typing import Optional + +import torch +import torchvision.models as models + +from nemo.backends.pytorch.nm import TrainableNM +from nemo.core.neural_types import AxisKind, AxisType, ImageFeatureValue, ImageValue, LogitsType, NeuralType +from nemo.utils import logging +from nemo.utils.configuration_parsing import get_value_from_dictionary +from nemo.utils.decorators import add_port_docs + +__all__ = ['ImageEncoder'] + + +class ImageEncoder(TrainableNM): + """ + Neural Module implementing a general-usage image encoderds. + It encapsulates several models from TorchVision (VGG16, ResNet152 and DensNet121, naming a few). 
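The broadcast trick in `FeedForwardNetwork.forward` above (collapse every leading axis into the batch dimension, run the MLP, then restore the original shape) can be sketched independently of NeMo; the sizes below are illustrative and correspond to `dimensions=3`:

```python
import torch

input_size, hidden, output_size = 8, 16, 5
layers = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden, output_size),
)

# A 3-dimensional input [BATCH_SIZE x SEQ_LEN x INPUT_SIZE].
inputs = torch.randn(4, 7, input_size)
origin_shape = inputs.shape

# Collapse all leading axes, apply the network, then restore them.
flat = inputs.contiguous().view(-1, origin_shape[-1])   # [28 x 8]
outputs = layers(flat).view(*origin_shape[:-1], -1)     # [4 x 7 x 5]
print(outputs.shape)
```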
+ Offers two operation modes and can return: image embeddings vs feature maps. + """ + + def __init__( + self, + model_type: str, + output_size: Optional[int] = None, + return_feature_maps: bool = False, + pretrained: bool = False, + name: Optional[str] = None, + ): + """ + Initializes the ``ImageEncoder`` model, creates the required "backbone". + + Args: + model_type: Type of backbone (Options: VGG16 | DenseNet121 | ResNet152 | ResNet50) + output_size: Size of the output layer (Optional, Default: None) + return_feature_maps: Return mode: image embeddings vs feature maps (Default: False) + pretrained: Loads pretrained model (Default: False) + name: Name of the module (DEFAULT: None) + """ + TrainableNM.__init__(self, name=name) + + # Get operation modes. + self._return_feature_maps = return_feature_maps + + # Get model type. + self._model_type = get_value_from_dictionary( + model_type, "vgg16 | densenet121 | resnet152 | resnet50".split(" | ") + ) + + # Get output size (optional - not in feature_maps). + self._output_size = output_size + + if self._model_type == 'vgg16': + # Get VGG16 + self._model = models.vgg16(pretrained=pretrained) + + if self._return_feature_maps: + # Use only the "feature encoder". + self._model = self._model.features + + # Remember the output feature map dims. + self._feature_map_height = 7 + self._feature_map_width = 7 + self._feature_map_depth = 512 + + else: + # Use the whole model, but "reshape"/reinstantiate the last layer ("FC6"). + self._model.classifier._modules['6'] = torch.nn.Linear(4096, self._output_size) + + elif self._model_type == 'densenet121': + # Get densenet121 + self._model = models.densenet121(pretrained=pretrained) + + if self._return_feature_maps: + raise ConfigurationError("'densenet121' doesn't support 'return_feature_maps' mode (yet)") + + # Use the whole model, but "reshape"/reinstantiate the last layer ("FC6"). + self._model.classifier = torch.nn.Linear(1024, self._output_size) + + elif self._model_type == 'resnet152': + # Get resnet152 + self._model = models.resnet152(pretrained=pretrained) + + if self._return_feature_maps: + # Get all modules exluding last (avgpool) and (fc) + modules = list(self._model.children())[:-2] + self._model = torch.nn.Sequential(*modules) + + # Remember the output feature map dims. + self._feature_map_height = 7 + self._feature_map_width = 7 + self._feature_map_depth = 2048 + + else: + # Use the whole model, but "reshape"/reinstantiate the last layer ("FC6"). + self._model.fc = torch.nn.Linear(2048, self._output_size) + + elif self._model_type == 'resnet50': + # Get resnet50 + self._model = models.resnet50(pretrained=pretrained) + + if self._return_feature_maps: + # Get all modules exluding last (avgpool) and (fc) + modules = list(self._model.children())[:-2] + self._model = torch.nn.Sequential(*modules) + + # Remember the output feature map dims. + self._feature_map_height = 7 + self._feature_map_width = 7 + self._feature_map_depth = 2048 + + else: + # Use the whole model, but "reshape"/reinstantiate the last layer ("FC6"). + self._model.fc = torch.nn.Linear(2048, self._output_size) + + @property + @add_port_docs() + def input_ports(self): + """ + Returns definitions of module input ports. 
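The "feature maps" mode used for `resnet50`/`resnet152` above simply drops the final average-pool and fully-connected children of the torchvision model. A minimal sketch of that surgery, purely for shape checking (random weights, `pretrained=False`):

```python
import torch
import torchvision.models as models

# Keep everything except the last two children (avgpool and fc), as ImageEncoder does.
backbone = models.resnet50(pretrained=False)
feature_extractor = torch.nn.Sequential(*list(backbone.children())[:-2])

images = torch.randn(2, 3, 224, 224)      # [BATCH_SIZE x 3 x 224 x 224]
feature_maps = feature_extractor(images)

print(feature_maps.shape)                 # torch.Size([2, 2048, 7, 7])
```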
+ """ + return { + "inputs": NeuralType( + axes=( + AxisType(kind=AxisKind.Batch), + AxisType(kind=AxisKind.Channel, size=3), + AxisType(kind=AxisKind.Height, size=224), + AxisType(kind=AxisKind.Width, size=224), + ), + elements_type=ImageValue(), + # TODO: actually encoders pretrained on ImageNet require special image normalization. + # Probably this should be a new image type. + ) + } + + @property + @add_port_docs() + def output_ports(self): + """ + Returns definitions of module output ports. + """ + # Return neural type. + if self._return_feature_maps: + return { + "outputs": NeuralType( + axes=( + AxisType(kind=AxisKind.Batch), + AxisType(kind=AxisKind.Channel, size=self._feature_map_depth), + AxisType(kind=AxisKind.Height, size=self._feature_map_height), + AxisType(kind=AxisKind.Width, size=self._feature_map_width), + ), + elements_type=ImageFeatureValue(), + ) + } + else: + return { + "outputs": NeuralType( + axes=(AxisType(kind=AxisKind.Batch), AxisType(kind=AxisKind.Any, size=self._output_size),), + elements_type=LogitsType(), + ) + } + + def forward(self, inputs): + """ + Main forward pass of the model. + + Args: + inputs: expected stream containing images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE WIDTH] + + Returns: + outpus: added stream containing outputs [BATCH_SIZE x OUTPUT_SIZE] + OR [BATCH_SIZE x OUTPUT_DEPTH x OUTPUT_HEIGHT x OUTPUT_WIDTH] + """ + # print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device)) + + outputs = self._model(inputs) + + # Add outputs to datadict. + return outputs diff --git a/nemo/collections/cv/modules/trainables/lenet5.py b/nemo/collections/cv/modules/trainables/lenet5.py new file mode 100644 index 000000000000..c5478c479079 --- /dev/null +++ b/nemo/collections/cv/modules/trainables/lenet5.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from typing import Optional + +import torch + +from nemo.backends.pytorch.nm import TrainableNM +from nemo.core.neural_types import AxisKind, AxisType, ImageValue, LogprobsType, NeuralType +from nemo.utils.decorators import add_port_docs + +__all__ = ['LeNet5'] + + +class LeNet5(TrainableNM): + """ + Classical LeNet-5 model for MNIST image classification. + """ + + def __init__(self, name: Optional[str] = None): + """ + Creates the LeNet-5 model. + + Args: + name: Name of the module (DEFAULT: None) + """ + # Call the base class constructor. + super().__init__(name=name) + + # Create the LeNet-5 model. 
+ self.model = torch.nn.Sequential( + torch.nn.Conv2d(1, 6, kernel_size=(5, 5)), + torch.nn.ReLU(), + torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2), + torch.nn.Conv2d(6, 16, kernel_size=(5, 5)), + torch.nn.ReLU(), + torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2), + torch.nn.Conv2d(16, 120, kernel_size=(5, 5)), + torch.nn.ReLU(), + # reshape to [-1, 120] + torch.nn.Flatten(), + torch.nn.Linear(120, 84), + torch.nn.ReLU(), + torch.nn.Linear(84, 10), + torch.nn.LogSoftmax(dim=1), + ) + self.to(self._device) + + @property + @add_port_docs() + def input_ports(self): + """ Returns definitions of module input ports. """ + return { + "images": NeuralType( + axes=( + AxisType(kind=AxisKind.Batch), + AxisType(kind=AxisKind.Channel, size=1), + AxisType(kind=AxisKind.Height, size=32), + AxisType(kind=AxisKind.Width, size=32), + ), + elements_type=ImageValue(), + ) + } + + @property + @add_port_docs() + def output_ports(self): + """ Returns definitions of module output ports. """ + return { + "predictions": NeuralType( + axes=(AxisType(kind=AxisKind.Batch), AxisType(kind=AxisKind.Dimension)), elements_type=LogprobsType() + ) + } + + def forward(self, images): + """ + Performs the forward step of the LeNet-5 model. + + Args: + images: Batch of images to be classified. + + Returns: + Batch of predictions. + """ + + predictions = self.model(images) + return predictions diff --git a/nemo/core/neural_types/elements.py b/nemo/core/neural_types/elements.py index 945506065a34..42786e860aba 100644 --- a/nemo/core/neural_types/elements.py +++ b/nemo/core/neural_types/elements.py @@ -35,7 +35,15 @@ 'EmbeddedTextType', 'EncodedRepresentation', 'MaskType', + 'Target', + 'ClassificationTarget', + 'ImageFeatureValue', + 'Index', + 'ImageValue', + 'NormalizedImageValue', + 'StringLabel', ] + import abc from abc import ABC, abstractmethod from typing import Dict, Optional, Tuple @@ -192,4 +200,44 @@ class CategoricalValuesType(PredictionsType): class MaskType(PredictionsType): - """Element type to represent boolean mask""" + """Element type to represent a boolean mask""" + + +class Index(ElementType): + """Type representing an element being an index of the sample.""" + + +class Target(ElementType): + """ + Type representing an element being a target value. + """ + + +class ClassificationTarget(Target): + """ + Type representing an element being target value in the classification task, i.e. identifier of a desired class. + """ + + +class StringLabel(ElementType): + """ + Type representing an label being a string with class name (e.g. the "hamster" class in CIFAR100). + """ + + +class ImageValue(ElementType): + """ + Type representing an element/value of a single image channel, + e.g. a single element (R) of RGB image. + """ + + +class NormalizedImageValue(ImageValue): + """ + Type representing an element/value of a single image channel normalized to <0-1> range, + e.g. a single element (R) of normalized RGB image. 
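For reference, the shapes flowing through the LeNet-5 stack above can be sanity-checked outside NeMo; the sketch below just re-builds the same `Sequential` for the declared 1 x 32 x 32 input (the batch size is illustrative):

```python
import torch

# Same stack as LeNet5.model above; shapes for a 1 x 32 x 32 input:
# conv(5x5): 6x28x28 -> pool: 6x14x14 -> conv(5x5): 16x10x10 -> pool: 16x5x5
# -> conv(5x5): 120x1x1 -> flatten: 120 -> 84 -> 10 log-probabilities.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 6, kernel_size=(5, 5)),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2),
    torch.nn.Conv2d(6, 16, kernel_size=(5, 5)),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2),
    torch.nn.Conv2d(16, 120, kernel_size=(5, 5)),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(120, 84),
    torch.nn.ReLU(),
    torch.nn.Linear(84, 10),
    torch.nn.LogSoftmax(dim=1),
)

images = torch.randn(4, 1, 32, 32)
print(model(images).shape)   # torch.Size([4, 10])
```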
+ """ + + +class ImageFeatureValue(ImageValue): + """Type representing an element (single value) of a (image) feature maps.""" diff --git a/nemo/core/neural_types/neural_type.py b/nemo/core/neural_types/neural_type.py index 699d87c99662..29634b2d2595 100644 --- a/nemo/core/neural_types/neural_type.py +++ b/nemo/core/neural_types/neural_type.py @@ -120,7 +120,9 @@ def compare_and_raise_error(self, parent_type_name, port_name, second_object): type_comatibility != NeuralTypeComparisonResult.SAME and type_comatibility != NeuralTypeComparisonResult.GREATER ): - raise NeuralPortNmTensorMismatchError(parent_type_name, port_name, self, second_object, type_comatibility) + raise NeuralPortNmTensorMismatchError( + parent_type_name, port_name, str(self), str(second_object), type_comatibility + ) @staticmethod def __check_sanity(axes): diff --git a/nemo/utils/configuration_error.py b/nemo/utils/configuration_error.py new file mode 100644 index 000000000000..05e08fb37184 --- /dev/null +++ b/nemo/utils/configuration_error.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +# Copyright (C) IBM tkornuta, Corporation 2019 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +__author__ = "Tomasz Kornuta" + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/IBM/pytorchpipe/blob/develop/ptp/configuration/configuration_error.py +""" + + +class ConfigurationError(Exception): + """ Error thrown when encountered a configuration issue. 
""" + + def __init__(self, msg): + """ Stores message """ + self.msg = msg + + def __str__(self): + """ Prints the message """ + return repr(self.msg) diff --git a/nemo/utils/configuration_parsing.py b/nemo/utils/configuration_parsing.py new file mode 100644 index 000000000000..81d33233fa69 --- /dev/null +++ b/nemo/utils/configuration_parsing.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +# Copyright (C) IBM Corporation 2019 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +""" +This file contains code artifacts adapted from the original implementation: +https://github.com/IBM/pytorchpipe/blob/develop/ptp/configuration/config_parsing.py +""" + +from nemo.utils.configuration_error import ConfigurationError + + +def get_value_list_from_dictionary(parameter: str, accepted_values=[]): + """ + Parses parameter values retrieved from a given parameter dictionary using key. + Optionally, checks is all values are accepted. + + Args: + parameter: Value to be checked. + accepted_values: List of accepted values (DEFAULT: []) + + Returns: + List of parsed values + """ + # Preprocess parameter value. + if type(parameter) == str: + if parameter == '': + # Return empty list. + return [] + else: + # Process and split. + values = parameter.replace(" ", "").split(",") + else: + values = parameter # list + if type(values) != list: + raise ConfigurationError("'parameter' must be a list") + + # Test values one by one. + if len(accepted_values) > 0: + for value in values: + if value not in accepted_values: + raise ConfigurationError( + "One of the values in '{}' is invalid (current: '{}', accepted: {})".format( + key, value, accepted_values + ) + ) + + # Return list. + return values + + +def get_value_from_dictionary(parameter: str, accepted_values=[]): + """ + Parses value of the parameter retrieved from a given parameter dictionary using key. + Optionally, checks is the values is one of the accepted values. + + Args: + parameter: Value to be checked. + accepted_values: List of accepted values (DEFAULT: []) + + Returns: + List of parsed values + """ + if type(parameter) != str: + raise ConfigurationError("'parameter' must be a string") + # Preprocess parameter value. 
+    if parameter == '':
+        return None
+
+    # Test the value.
+    if len(accepted_values) > 0:
+        if parameter not in accepted_values:
+            raise ConfigurationError(
+                "Parameter value '{}' is invalid (accepted values: {})".format(parameter, accepted_values)
+            )
+
+    # Return value.
+    return parameter
diff --git a/requirements/requirements_cv.txt b/requirements/requirements_cv.txt
new file mode 100644
index 000000000000..f0eabe59f453
--- /dev/null
+++ b/requirements/requirements_cv.txt
@@ -0,0 +1,2 @@
+pillow
+torchvision
diff --git a/setup.py b/setup.py
index fae6a943613d..618fb980f7e4 100644
--- a/setup.py
+++ b/setup.py
@@ -94,6 +94,7 @@ def req_file(filename, folder="requirements"):
     'test': req_file("requirements_test.txt"),
     # Collections Packages
     'asr': req_file("requirements_asr.txt"),
+    'cv': req_file("requirements_cv.txt"),
     'nlp': req_file("requirements_nlp.txt"),
     'simple_gan': req_file("requirements_simple_gan.txt"),
     'tts': req_file("requirements_tts.txt"),
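Taken together, the two helpers in `configuration_parsing.py` normalise user-supplied configuration values before the CV modules use them. A short illustration of the intended behaviour, assuming NeMo with this PR installed:

```python
from nemo.utils.configuration_error import ConfigurationError
from nemo.utils.configuration_parsing import get_value_from_dictionary, get_value_list_from_dictionary

# Single-choice parameter, validated against a fixed set of options (as NonLinearity and ImageEncoder do).
assert get_value_from_dictionary("logsoftmax", ["logsoftmax"]) == "logsoftmax"

# A comma-separated string is split into a list of validated values.
accepted = ["vgg16", "densenet121", "resnet152", "resnet50"]
assert get_value_list_from_dictionary("vgg16, resnet50", accepted) == ["vgg16", "resnet50"]

# An unsupported value raises ConfigurationError.
try:
    get_value_from_dictionary("softmax", ["logsoftmax"])
except ConfigurationError as error:
    print(error)
```

With the `cv` extra registered in `setup.py`, the collection's requirements (pillow, torchvision) would presumably be pulled in via an extras install such as `pip install "nemo_toolkit[cv]"`; the exact published package name is an assumption here.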