From f0331a94466db6e38946d9cb5e00a24bc8c93e38 Mon Sep 17 00:00:00 2001 From: Tomasz Kornuta Date: Fri, 22 May 2020 20:19:43 -0700 Subject: [PATCH] Added NonLinearity component, simplified the FFN, cifar10 - ResNet50 operational Signed-off-by: Tomasz Kornuta --- ...ifar10_convnet_ffn_image_classification.py | 27 +++--- .../cifar10_resnet50_image_classification.py | 79 +++++++++++++++ .../cifar10_vgg16_ffn_image_classification.py | 29 +++--- .../mnist_convnet_ffn_image_classification.py | 19 ++-- .../mnist_ffn_image_classification.py | 30 +++--- .../mnist_lenet5_image_classification.py | 16 ++-- .../cv/modules/non_trainables/__init__.py | 1 + .../modules/non_trainables/non_linearity.py | 96 +++++++++++++++++++ .../modules/non_trainables/reshape_tensor.py | 4 +- .../trainables/feed_forward_network.py | 20 +--- .../trainables/generic_image_encoder.py | 36 ++++--- nemo/utils/configuration_parsing.py | 24 +++-- 12 files changed, 275 insertions(+), 106 deletions(-) create mode 100644 nemo/collections/cv/examples/cifar10_resnet50_image_classification.py create mode 100644 nemo/collections/cv/modules/non_trainables/non_linearity.py diff --git a/nemo/collections/cv/examples/cifar10_convnet_ffn_image_classification.py b/nemo/collections/cv/examples/cifar10_convnet_ffn_image_classification.py index e327989c0a31..a7708d8ccc59 100644 --- a/nemo/collections/cv/examples/cifar10_convnet_ffn_image_classification.py +++ b/nemo/collections/cv/examples/cifar10_convnet_ffn_image_classification.py @@ -17,11 +17,10 @@ import argparse import nemo.utils.argparse as nm_argparse -from nemo.collections.cv.modules.data_layers.cifar10_datalayer import CIFAR10DataLayer -from nemo.collections.cv.modules.losses.nll_loss import NLLLoss -from nemo.collections.cv.modules.non_trainables.reshape_tensor import ReshapeTensor -from nemo.collections.cv.modules.trainables.convnet_encoder import ConvNetEncoder -from nemo.collections.cv.modules.trainables.feed_forward_network import FeedForwardNetwork +from 
nemo.collections.cv.modules.data_layers import CIFAR10DataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import ConvNetEncoder, FeedForwardNetwork from nemo.core import ( DeviceType, NeuralGraph, @@ -38,24 +37,26 @@ # Parse the arguments args = parser.parse_args() - # 0. Instantiate Neural Factory. + # Instantiate Neural Factory. nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) - # Data layers for training and validation. - dl = CIFAR10DataLayer(train=True) - # Model. + # Data layer for training. + cifar10_dl = CIFAR10DataLayer(train=True) + # The "model". cnn = ConvNetEncoder(input_depth=3, input_height=32, input_width=32) reshaper = ReshapeTensor(input_sizes=[-1, 16, 2, 2], output_sizes=[-1, 64]) - ffn = FeedForwardNetwork(input_size=64, output_size=10, dropout_rate=0.1, final_logsoftmax=True) + ffn = FeedForwardNetwork(input_size=64, output_size=10, dropout_rate=0.1) + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) # Loss. nll_loss = NLLLoss() - # 2. Create a training graph. + # Create a training graph. with NeuralGraph(operation_mode=OperationMode.training) as training_graph: - img, tgt = dl() + img, tgt = cifar10_dl() feat_map = cnn(inputs=img) res_img = reshaper(inputs=feat_map) - pred = ffn(inputs=res_img) + logits = ffn(inputs=res_img) + pred = nl(inputs=logits) loss = nll_loss(predictions=pred, targets=tgt) # Set output - that output will be used for training. 
training_graph.outputs["loss"] = loss diff --git a/nemo/collections/cv/examples/cifar10_resnet50_image_classification.py b/nemo/collections/cv/examples/cifar10_resnet50_image_classification.py new file mode 100644 index 000000000000..5b1472b293ab --- /dev/null +++ b/nemo/collections/cv/examples/cifar10_resnet50_image_classification.py @@ -0,0 +1,79 @@ +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import argparse + +import nemo.utils.argparse as nm_argparse +from nemo.collections.cv.modules.data_layers import CIFAR10DataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import FeedForwardNetwork, GenericImageEncoder +from nemo.core import ( + DeviceType, + NeuralGraph, + NeuralModuleFactory, + OperationMode, + SimpleLossLoggerCallback, + WandbCallback, +) +from nemo.utils import logging + +if __name__ == "__main__": + # Create the default parser. + parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve') + # Parse the arguments + args = parser.parse_args() + + # Instantiate Neural Factory. 
+ nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) + + # Data layer - upscale the CIFAR10 images to ImageNet resolution. + cifar10_dl = CIFAR10DataLayer(height=224, width=224, train=True) + # The "model". + image_classifier = GenericImageEncoder(model_type="resnet50", output_size=10, pretrained=True, name="resnet50") + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) + # Loss. + nll_loss = NLLLoss() + + # Create a training graph. + with NeuralGraph(operation_mode=OperationMode.training) as training_graph: + img, tgt = cifar10_dl() + logits = image_classifier(inputs=img) + pred = nl(inputs=logits) + loss = nll_loss(predictions=pred, targets=tgt) + # Set output - that output will be used for training. + training_graph.outputs["loss"] = loss + + # Show info. + logging.info(training_graph.summary()) + + # SimpleLossLoggerCallback will print loss values to console. + callback = SimpleLossLoggerCallback( + tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {str(x[0].item())}') + ) + + # Log training metrics to W&B. + wand_callback = WandbCallback( + train_tensors=[loss], wandb_name="simple-mnist-fft", wandb_project="cv-collection-image-classification", + ) + + # Invoke the "train" action. 
+ nf.train( + training_graph=training_graph, + callbacks=[callback, wand_callback], + optimization_params={"num_epochs": 10, "lr": 0.001}, + optimizer="adam", + ) diff --git a/nemo/collections/cv/examples/cifar10_vgg16_ffn_image_classification.py b/nemo/collections/cv/examples/cifar10_vgg16_ffn_image_classification.py index 130c9947952e..72ab14bacbe5 100644 --- a/nemo/collections/cv/examples/cifar10_vgg16_ffn_image_classification.py +++ b/nemo/collections/cv/examples/cifar10_vgg16_ffn_image_classification.py @@ -17,11 +17,10 @@ import argparse import nemo.utils.argparse as nm_argparse -from nemo.collections.cv.modules.data_layers.cifar10_datalayer import CIFAR10DataLayer -from nemo.collections.cv.modules.losses.nll_loss import NLLLoss -from nemo.collections.cv.modules.non_trainables.reshape_tensor import ReshapeTensor -from nemo.collections.cv.modules.trainables.generic_image_encoder import GenericImageEncoder -from nemo.collections.cv.modules.trainables.feed_forward_network import FeedForwardNetwork +from nemo.collections.cv.modules.data_layers import CIFAR10DataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import FeedForwardNetwork, GenericImageEncoder from nemo.core import ( DeviceType, NeuralGraph, @@ -38,30 +37,32 @@ # Parse the arguments args = parser.parse_args() - # 0. Instantiate Neural Factory. + # Instantiate Neural Factory. nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) - # Data layers for training and validation - upscale the CIFAR10 images to ImageNet resolution. - dl = CIFAR10DataLayer(height=224, width=224, train=True) - # Model. + # Data layer - upscale the CIFAR10 images to ImageNet resolution. + cifar10_dl = CIFAR10DataLayer(height=224, width=224, train=True) + # The "model". 
image_encoder = GenericImageEncoder(model_type="vgg16", return_feature_maps=True, pretrained=True, name="vgg16") reshaper = ReshapeTensor(input_sizes=[-1, 7, 7, 512], output_sizes=[-1, 25088]) - ffn = FeedForwardNetwork(input_size=25088, output_size=10, hidden_sizes=[1000, 1000], dropout_rate=0.1, final_logsoftmax=True) + ffn = FeedForwardNetwork(input_size=25088, output_size=10, hidden_sizes=[1000, 1000], dropout_rate=0.1) + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) # Loss. nll_loss = NLLLoss() - # 2. Create a training graph. + # Create a training graph. with NeuralGraph(operation_mode=OperationMode.training) as training_graph: - img, tgt = dl() + img, tgt = cifar10_dl() feat_map = image_encoder(inputs=img) res_img = reshaper(inputs=feat_map) - pred = ffn(inputs=res_img) + logits = ffn(inputs=res_img) + pred = nl(inputs=logits) loss = nll_loss(predictions=pred, targets=tgt) # Set output - that output will be used for training. training_graph.outputs["loss"] = loss # Freeze the pretrained encoder. - training_graph.freeze() + training_graph.freeze(["vgg16"]) logging.info(training_graph.summary()) # SimpleLossLoggerCallback will print loss values to console. 
diff --git a/nemo/collections/cv/examples/mnist_convnet_ffn_image_classification.py b/nemo/collections/cv/examples/mnist_convnet_ffn_image_classification.py index 8b88bdd688b6..988990ab63df 100644 --- a/nemo/collections/cv/examples/mnist_convnet_ffn_image_classification.py +++ b/nemo/collections/cv/examples/mnist_convnet_ffn_image_classification.py @@ -17,11 +17,10 @@ import argparse import nemo.utils.argparse as nm_argparse -from nemo.collections.cv.modules.data_layers.mnist_datalayer import MNISTDataLayer -from nemo.collections.cv.modules.losses.nll_loss import NLLLoss -from nemo.collections.cv.modules.non_trainables.reshape_tensor import ReshapeTensor -from nemo.collections.cv.modules.trainables.convnet_encoder import ConvNetEncoder -from nemo.collections.cv.modules.trainables.feed_forward_network import FeedForwardNetwork +from nemo.collections.cv.modules.data_layers import MNISTDataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import ConvNetEncoder, FeedForwardNetwork from nemo.core import ( DeviceType, NeuralGraph, @@ -43,19 +42,21 @@ # Data layers for training and validation. dl = MNISTDataLayer(height=28, width=28, train=True) - # Model. + # The "model". cnn = ConvNetEncoder(input_depth=1, input_height=28, input_width=28) reshaper = ReshapeTensor(input_sizes=[-1, 16, 1, 1], output_sizes=[-1, 16]) - ffn = FeedForwardNetwork(input_size=16, output_size=10, dropout_rate=0.1, final_logsoftmax=True) + ffn = FeedForwardNetwork(input_size=16, output_size=10, dropout_rate=0.1) + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) # Loss. nll_loss = NLLLoss() - # 2. Create a training graph. + # Create a training graph. 
with NeuralGraph(operation_mode=OperationMode.training) as training_graph: img, tgt = dl() feat_map = cnn(inputs=img) res_img = reshaper(inputs=feat_map) - pred = ffn(inputs=res_img) + logits = ffn(inputs=res_img) + pred = nl(inputs=logits) loss = nll_loss(predictions=pred, targets=tgt) # Set output - that output will be used for training. training_graph.outputs["loss"] = loss diff --git a/nemo/collections/cv/examples/mnist_ffn_image_classification.py b/nemo/collections/cv/examples/mnist_ffn_image_classification.py index b1c961177d80..cdaf6139e27c 100644 --- a/nemo/collections/cv/examples/mnist_ffn_image_classification.py +++ b/nemo/collections/cv/examples/mnist_ffn_image_classification.py @@ -16,13 +16,11 @@ import argparse -from torch import max, mean, stack, tensor - import nemo.utils.argparse as nm_argparse -from nemo.collections.cv.modules.data_layers.mnist_datalayer import MNISTDataLayer -from nemo.collections.cv.modules.losses.nll_loss import NLLLoss -from nemo.collections.cv.modules.non_trainables.reshape_tensor import ReshapeTensor -from nemo.collections.cv.modules.trainables.feed_forward_network import FeedForwardNetwork +from nemo.collections.cv.modules.data_layers import MNISTDataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor +from nemo.collections.cv.modules.trainables import FeedForwardNetwork from nemo.core import ( DeviceType, NeuralGraph, @@ -39,25 +37,25 @@ # Parse the arguments args = parser.parse_args() - # 0. Instantiate Neural Factory. + # Instantiate Neural Factory. nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU) # Data layers for training and validation. dl = MNISTDataLayer(height=28, width=28, train=True) - # Model. + # The "model". 
reshaper = ReshapeTensor(input_sizes=[-1, 1, 32, 32], output_sizes=[-1, 784]) - ffn = FeedForwardNetwork( - input_size=784, output_size=10, hidden_sizes=[100, 100], dropout_rate=0.1, final_logsoftmax=True - ) + ffn = FeedForwardNetwork(input_size=784, output_size=10, hidden_sizes=[100, 100], dropout_rate=0.1) + nl = NonLinearity(type="logsoftmax", sizes=[-1, 10]) # Loss. nll_loss = NLLLoss() - # 2. Create a training graph. + # Create a training graph. with NeuralGraph(operation_mode=OperationMode.training) as training_graph: - img, tgt = dl() - res_img = reshaper(inputs=img) - pred = ffn(inputs=res_img) - loss = nll_loss(predictions=pred, targets=tgt) + imgs, tgts = dl() + res_imgs = reshaper(inputs=imgs) + logits = ffn(inputs=res_imgs) + preds = nl(inputs=logits) + loss = nll_loss(predictions=preds, targets=tgts) # Set output - that output will be used for training. training_graph.outputs["loss"] = loss diff --git a/nemo/collections/cv/examples/mnist_lenet5_image_classification.py b/nemo/collections/cv/examples/mnist_lenet5_image_classification.py index d30cdacb3a55..c736ef5723c6 100644 --- a/nemo/collections/cv/examples/mnist_lenet5_image_classification.py +++ b/nemo/collections/cv/examples/mnist_lenet5_image_classification.py @@ -19,9 +19,9 @@ from torch import max, mean, stack, tensor import nemo.utils.argparse as nm_argparse -from nemo.collections.cv.modules.data_layers.mnist_datalayer import MNISTDataLayer -from nemo.collections.cv.modules.losses.nll_loss import NLLLoss -from nemo.collections.cv.modules.trainables.lenet5 import LeNet5 +from nemo.collections.cv.modules.data_layers import MNISTDataLayer +from nemo.collections.cv.modules.losses import NLLLoss +from nemo.collections.cv.modules.trainables import LeNet5 from nemo.core import ( DeviceType, EvaluatorCallback, @@ -38,18 +38,18 @@ # Parse the arguments args = parser.parse_args() - # 0. Instantiate Neural Factory. + # Instantiate Neural Factory. 
nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.GPU) # Data layers for training and validation. dl = MNISTDataLayer(height=32, width=32, train=True) dl_e = MNISTDataLayer(height=32, width=32, train=False) - # Model. + # The "model". lenet5 = LeNet5() # Loss. nll_loss = NLLLoss() - # 2. Create a training graph. + # Create a training graph. with NeuralGraph(operation_mode=OperationMode.training) as training_graph: x, y = dl() p = lenet5(images=x) @@ -57,13 +57,13 @@ # Set output - that output will be used for training. training_graph.outputs["loss"] = loss - # 3. Create a validation graph, starting from the second data layer. + # Create a validation graph, starting from the second data layer. with NeuralGraph(operation_mode=OperationMode.evaluation) as evaluation_graph: x, y = dl_e() p = lenet5(images=x) loss_e = nll_loss(predictions=p, targets=y) - # 4. Create the callbacks. + # Create the callbacks. def eval_loss_per_batch_callback(tensors, global_vars): if "eval_loss" not in global_vars.keys(): global_vars["eval_loss"] = [] diff --git a/nemo/collections/cv/modules/non_trainables/__init__.py b/nemo/collections/cv/modules/non_trainables/__init__.py index 5e5de464c162..8ee079e36de8 100644 --- a/nemo/collections/cv/modules/non_trainables/__init__.py +++ b/nemo/collections/cv/modules/non_trainables/__init__.py @@ -14,4 +14,5 @@ # limitations under the License. 
# ============================================================================= +from nemo.collections.cv.modules.non_trainables.non_linearity import * from nemo.collections.cv.modules.non_trainables.reshape_tensor import * diff --git a/nemo/collections/cv/modules/non_trainables/non_linearity.py b/nemo/collections/cv/modules/non_trainables/non_linearity.py new file mode 100644 index 000000000000..de2e9272203d --- /dev/null +++ b/nemo/collections/cv/modules/non_trainables/non_linearity.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# ============================================================================= +# Copyright (c) 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + + +import torch + +from nemo.backends.pytorch.nm import NonTrainableNM +from nemo.core.neural_types import AxisKind, AxisType, LogprobsType, NeuralType, VoidType +from nemo.utils import logging +from nemo.utils.decorators import add_port_docs + +__all__ = ['NonLinearity'] + + +class NonLinearity(NonTrainableNM): + """ + Class responsible for applying additional non-linearity along the last axis of the input tensor. + + """ + + def __init__(self, type="logsoftmax", sizes=[-1], name=None): + """ + Initializes the object. + + """ + # Call constructor of parent classes. + NonTrainableNM.__init__(self, name=name) + + # Store params. 
+ self._type = type + self._sizes = sizes + + # Apply the non-linearity along the last dimension. + # TODO: if self._type != "logsoftmax" + assert type == "logsoftmax" + dim = len(sizes) - 1 + self._non_linearity = torch.nn.LogSoftmax(dim=dim) + + @property + @add_port_docs() + def input_ports(self): + """ + Returns definitions of module input ports. + Batch of input tensors [BATCH_SIZE x ... x INPUT_SIZE] + """ + # Prepare list of axes. + axes = [AxisType(kind=AxisKind.Batch)] + for size in self._sizes[1:]: + axes.append(AxisType(kind=AxisKind.Any, size=size)) + # Return neural type. + return {"inputs": NeuralType(axes, VoidType())} + + @property + @add_port_docs() + def output_ports(self): + """ + Returns definitions of module output ports. + """ + # Prepare list of axes. + axes = [AxisType(kind=AxisKind.Batch)] + for size in self._sizes[1:]: + axes.append(AxisType(kind=AxisKind.Any, size=size)) + # Return neural type. + # TODO: if self._type != "logsoftmax" + return {"outputs": NeuralType(axes, LogprobsType())} + + def forward(self, inputs): + """ + Applies the non-linearity along the last axis of the "inputs" tensor. + Currently only log-softmax is supported; the tensor shape is unchanged. + + Args: + inputs: a tensor [BATCH_SIZE x ...] + + Returns: + Outputs a tensor [BATCH_SIZE x ...] + """ + # print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device)) + + # Apply the non-linearity.
+ # TODO: if self._type != "logsoftmax" + return self._non_linearity(inputs) diff --git a/nemo/collections/cv/modules/non_trainables/reshape_tensor.py b/nemo/collections/cv/modules/non_trainables/reshape_tensor.py index eb852ab69f10..a30cdaf53ddd 100644 --- a/nemo/collections/cv/modules/non_trainables/reshape_tensor.py +++ b/nemo/collections/cv/modules/non_trainables/reshape_tensor.py @@ -36,8 +36,6 @@ https://github.com/IBM/pytorchpipe/blob/develop/ptp/components/transforms/reshape_tensor.py """ -import torch - from nemo.backends.pytorch.nm import NonTrainableNM from nemo.core.neural_types import AxisKind, AxisType, NeuralType, VoidType from nemo.utils import logging @@ -55,7 +53,7 @@ class ReshapeTensor(NonTrainableNM): def __init__(self, input_sizes, output_sizes, name=None): """ - Initializes object. + Initializes the object. """ # Call constructor of parent classes. diff --git a/nemo/collections/cv/modules/trainables/feed_forward_network.py b/nemo/collections/cv/modules/trainables/feed_forward_network.py index 13d913c0d7a7..c6ba332eba79 100644 --- a/nemo/collections/cv/modules/trainables/feed_forward_network.py +++ b/nemo/collections/cv/modules/trainables/feed_forward_network.py @@ -39,7 +39,7 @@ import torch from nemo.backends.pytorch.nm import TrainableNM -from nemo.core.neural_types import AxisKind, AxisType, LogprobsType, NeuralType, VoidType +from nemo.core.neural_types import AxisKind, AxisType, NeuralType, VoidType from nemo.utils import logging from nemo.utils.configuration_error import ConfigurationError from nemo.utils.decorators import add_port_docs @@ -57,9 +57,7 @@ class FeedForwardNetwork(TrainableNM): Additionally, the module applies log softmax non-linearity on the output of the last layer (logits). 
""" - def __init__( - self, input_size, output_size, hidden_sizes=[], dimensions=2, dropout_rate=0, final_logsoftmax=False, name=None - ): + def __init__(self, input_size, output_size, hidden_sizes=[], dimensions=2, dropout_rate=0, name=None): """ Initializes the classifier. @@ -127,11 +125,6 @@ def __init__( ) ) - # Create the final non-linearity. - self._final_logsoftmax = final_logsoftmax - if self._final_logsoftmax: - modules.append(torch.nn.LogSoftmax(dim=1)) - # Finally create the sequential model out of those modules. self.layers = torch.nn.Sequential(*modules) @@ -165,13 +158,8 @@ def output_ports(self): axes.append(AxisType(kind=AxisKind.Any)) # Add the last axis: input_size axes.append(AxisType(kind=AxisKind.Any, size=self._output_size)) - # Return neural type. - if self._final_logsoftmax: - # Batch of predictions, each represented as probability distribution over classes. - return {"outputs": NeuralType(axes, LogprobsType())} - else: - # Batch of "logits" of "any type". - return {"outputs": NeuralType(axes, VoidType())} + # Return neural type: batch of "logits" of "any type". + return {"outputs": NeuralType(axes, VoidType())} def forward(self, inputs): """ diff --git a/nemo/collections/cv/modules/trainables/generic_image_encoder.py b/nemo/collections/cv/modules/trainables/generic_image_encoder.py index 1adf1185c61d..d195289d4181 100644 --- a/nemo/collections/cv/modules/trainables/generic_image_encoder.py +++ b/nemo/collections/cv/modules/trainables/generic_image_encoder.py @@ -49,10 +49,12 @@ __all__ = ['GenericImageEncoder'] + class GenericImageEncoder(TrainableNM): """ Class """ + def __init__(self, model_type, output_size=None, return_feature_maps=False, pretrained=False, name=None): """ Initializes the ``GenericImageEncoder`` model, creates the required backend. @@ -64,12 +66,14 @@ def __init__(self, model_type, output_size=None, return_feature_maps=False, pret self._return_feature_maps = return_feature_maps # Get model type. 
- self._model_type = get_value_from_dictionary(model_type, "vgg16 | densenet121 | resnet152 | resnet50".split(" | ")) + self._model_type = get_value_from_dictionary( + model_type, "vgg16 | densenet121 | resnet152 | resnet50".split(" | ") + ) - # Get output size (optional). + # Get output size (optional - not in feature_maps). self._output_size = output_size - if(self._model_type == 'vgg16'): + if self._model_type == 'vgg16': # Get VGG16 self._model = models.vgg16(pretrained=pretrained) @@ -86,7 +90,7 @@ def __init__(self, model_type, output_size=None, return_feature_maps=False, pret # Use the whole model, but "reshape"/reinstantiate the last layer ("FC6"). self._model.classifier._modules['6'] = torch.nn.Linear(4096, self._output_size) - elif(self._model_type == 'densenet121'): + elif self._model_type == 'densenet121': # Get densenet121 self._model = models.densenet121(pretrained=pretrained) @@ -96,15 +100,14 @@ def __init__(self, model_type, output_size=None, return_feature_maps=False, pret # Use the whole model, but "reshape"/reinstantiate the last layer ("FC6"). self._model.classifier = torch.nn.Linear(1024, self._output_size) - - elif(self._model_type == 'resnet152'): + elif self._model_type == 'resnet152': # Get resnet152 self._model = models.resnet152(pretrained=pretrained) if self._return_feature_maps: # Get all modules exluding last (avgpool) and (fc) - modules=list(self._model.children())[:-2] - self._model=torch.nn.Sequential(*modules) + modules = list(self._model.children())[:-2] + self._model = torch.nn.Sequential(*modules) # Remember the output feature map dims. self._feature_map_height = 7 @@ -115,14 +118,14 @@ def __init__(self, model_type, output_size=None, return_feature_maps=False, pret # Use the whole model, but "reshape"/reinstantiate the last layer ("FC6"). 
self._model.fc = torch.nn.Linear(2048, self._output_size) - elif(self._model_type == 'resnet50'): + elif self._model_type == 'resnet50': # Get resnet50 self._model = models.resnet50(pretrained=pretrained) if self._return_feature_maps: # Get all modules exluding last (avgpool) and (fc) - modules=list(self._model.children())[:-2] - self._model=torch.nn.Sequential(*modules) + modules = list(self._model.children())[:-2] + self._model = torch.nn.Sequential(*modules) # Remember the output feature map dims. self._feature_map_height = 7 @@ -133,7 +136,6 @@ def __init__(self, model_type, output_size=None, return_feature_maps=False, pret # Use the whole model, but "reshape"/reinstantiate the last layer ("FC6"). self._model.fc = torch.nn.Linear(2048, self._output_size) - @property @add_port_docs() def input_ports(self): @@ -174,15 +176,11 @@ def output_ports(self): else: return { "outputs": NeuralType( - axes=( - AxisType(kind=AxisKind.Batch), - AxisType(kind=AxisKind.Any, size=self._output_size), - ), + axes=(AxisType(kind=AxisKind.Batch), AxisType(kind=AxisKind.Any, size=self._output_size),), elements_type=VoidType(), ) } - def forward(self, inputs): """ Main forward pass of the model. @@ -194,9 +192,9 @@ def forward(self, inputs): outpus: added stream containing outputs [BATCH_SIZE x OUTPUT_SIZE] OR [BATCH_SIZE x OUTPUT_DEPTH x OUTPUT_HEIGHT x OUTPUT_WIDTH] """ - #print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device)) + # print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device)) outputs = self._model(inputs) # Add outputs to datadict. 
- return outputs \ No newline at end of file + return outputs diff --git a/nemo/utils/configuration_parsing.py b/nemo/utils/configuration_parsing.py index 94aabd061b05..34931e388935 100644 --- a/nemo/utils/configuration_parsing.py +++ b/nemo/utils/configuration_parsing.py @@ -37,7 +37,7 @@ from nemo.utils.configuration_error import ConfigurationError -def get_value_list_from_dictionary(parameter: str, accepted_values = []): +def get_value_list_from_dictionary(parameter: str, accepted_values=[]): """ Parses parameter values retrieved from a given parameter dictionary using key. Optionally, checks is all values are accepted. @@ -50,15 +50,15 @@ def get_value_list_from_dictionary(parameter: str, accepted_values = []): List of parsed values """ # Preprocess parameter value. - if (type(parameter) == str): + if type(parameter) == str: if parameter == '': # Return empty list. return [] else: # Process and split. - values = parameter.replace(" ","").split(",") + values = parameter.replace(" ", "").split(",") else: - values = parameter # list + values = parameter # list if type(values) != list: ConfigurationError("'parameter' must be a list") @@ -66,13 +66,17 @@ def get_value_list_from_dictionary(parameter: str, accepted_values = []): if len(accepted_values) > 0: for value in values: if value not in accepted_values: - raise ConfigurationError("One of the values in '{}' is invalid (current: '{}', accepted: {})".format(key, value, accepted_values)) + raise ConfigurationError( + "One of the values in '{}' is invalid (current: '{}', accepted: {})".format( + key, value, accepted_values + ) + ) # Return list. return values -def get_value_from_dictionary(parameter: str, accepted_values = []): +def get_value_from_dictionary(parameter: str, accepted_values=[]): """ Parses value of the parameter retrieved from a given parameter dictionary using key. Optionally, checks is the values is one of the accepted values. 
@@ -93,7 +97,11 @@ def get_value_from_dictionary(parameter: str, accepted_values = []): # Test values one by one. if len(accepted_values) > 0: if parameter not in accepted_values: - raise ConfigurationError("One of the values in '{}' is invalid (current: '{}', accepted: {})".format(key, value, accepted_values)) + raise ConfigurationError( + "One of the values in '{}' is invalid (current: '{}', accepted: {})".format( + key, value, accepted_values + ) + ) # Return value. - return parameter \ No newline at end of file + return parameter