From b165288e017f4b88f40b46ec83ac5aeb1493a2aa Mon Sep 17 00:00:00 2001
From: Tomasz Kornuta
Date: Fri, 22 May 2020 19:16:45 -0700
Subject: [PATCH] GenericImageEncoder ported + CIFAR10 VGG16 classification example

Signed-off-by: Tomasz Kornuta
---
 .../cifar10_vgg16_ffn_image_classification.py |  83 +++++++
 .../cv/modules/trainables/__init__.py         |   1 +
 .../cv/modules/trainables/convnet_encoder.py  |   7 +-
 .../trainables/generic_image_encoder.py       | 202 ++++++++++++++++++
 nemo/utils/configuration_parsing.py           |  99 +++++++++
 5 files changed, 388 insertions(+), 4 deletions(-)
 create mode 100644 nemo/collections/cv/examples/cifar10_vgg16_ffn_image_classification.py
 create mode 100644 nemo/collections/cv/modules/trainables/generic_image_encoder.py
 create mode 100644 nemo/utils/configuration_parsing.py

diff --git a/nemo/collections/cv/examples/cifar10_vgg16_ffn_image_classification.py b/nemo/collections/cv/examples/cifar10_vgg16_ffn_image_classification.py
new file mode 100644
index 000000000000..130c9947952e
--- /dev/null
+++ b/nemo/collections/cv/examples/cifar10_vgg16_ffn_image_classification.py
@@ -0,0 +1,83 @@
+# =============================================================================
+# Copyright (c) 2020 NVIDIA. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import argparse
+
+import nemo.utils.argparse as nm_argparse
+from nemo.collections.cv.modules.data_layers.cifar10_datalayer import CIFAR10DataLayer
+from nemo.collections.cv.modules.losses.nll_loss import NLLLoss
+from nemo.collections.cv.modules.non_trainables.reshape_tensor import ReshapeTensor
+from nemo.collections.cv.modules.trainables.feed_forward_network import FeedForwardNetwork
+from nemo.collections.cv.modules.trainables.generic_image_encoder import GenericImageEncoder
+from nemo.core import (
+    DeviceType,
+    NeuralGraph,
+    NeuralModuleFactory,
+    OperationMode,
+    SimpleLossLoggerCallback,
+    WandbCallback,
+)
+from nemo.utils import logging
+
+if __name__ == "__main__":
+    # Create the default parser.
+    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
+    # Parse the arguments.
+    args = parser.parse_args()
+
+    # 0. Instantiate the Neural Factory.
+    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)
+
+    # 1. Instantiate the required neural modules.
+    # Data layer for training - upscale the CIFAR10 images to the ImageNet resolution.
+    dl = CIFAR10DataLayer(height=224, width=224, train=True)
+    # Model: a pretrained VGG16 "feature encoder" followed by a feed-forward classification head.
+    image_encoder = GenericImageEncoder(model_type="vgg16", return_feature_maps=True, pretrained=True, name="vgg16")
+    # The encoder outputs [BATCH x 512 x 7 x 7] feature maps - flatten them for the classifier.
+    reshaper = ReshapeTensor(input_sizes=[-1, 512, 7, 7], output_sizes=[-1, 25088])
+    ffn = FeedForwardNetwork(
+        input_size=25088, output_size=10, hidden_sizes=[1000, 1000], dropout_rate=0.1, final_logsoftmax=True
+    )
+    # Loss.
+    nll_loss = NLLLoss()
+
+    # 2. Create a training graph.
+    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
+        img, tgt = dl()
+        feat_map = image_encoder(inputs=img)
+        res_img = reshaper(inputs=feat_map)
+        pred = ffn(inputs=res_img)
+        loss = nll_loss(predictions=pred, targets=tgt)
+        # Set the graph output - it will be used during training.
+        training_graph.outputs["loss"] = loss
+
+    # Freeze the pretrained encoder.
+    training_graph.freeze(["vgg16"])
+    logging.info(training_graph.summary())
+
+    # SimpleLossLoggerCallback will print loss values to the console.
+    callback = SimpleLossLoggerCallback(
+        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {x[0].item()}')
+    )
+
+    # Log training metrics to W&B.
+    wandb_callback = WandbCallback(
+        train_tensors=[loss], wandb_name="cifar10-vgg16-ffn", wandb_project="cv-collection-image-classification",
+    )
+
+    # Invoke the "train" action.
+    nf.train(
+        training_graph=training_graph,
+        callbacks=[callback, wandb_callback],
+        optimization_params={"num_epochs": 10, "lr": 0.001},
+        optimizer="adam",
+    )
diff --git a/nemo/collections/cv/modules/trainables/__init__.py b/nemo/collections/cv/modules/trainables/__init__.py
index 55d73e48a2e9..dbc585131cb0 100644
--- a/nemo/collections/cv/modules/trainables/__init__.py
+++ b/nemo/collections/cv/modules/trainables/__init__.py
@@ -16,4 +16,5 @@
 from nemo.collections.cv.modules.trainables.convnet_encoder import *
 from nemo.collections.cv.modules.trainables.feed_forward_network import *
+from nemo.collections.cv.modules.trainables.generic_image_encoder import *
 from nemo.collections.cv.modules.trainables.lenet5 import *
diff --git a/nemo/collections/cv/modules/trainables/convnet_encoder.py b/nemo/collections/cv/modules/trainables/convnet_encoder.py
index f0bfc4b53b37..74128fcea7a3 100644
--- a/nemo/collections/cv/modules/trainables/convnet_encoder.py
+++ b/nemo/collections/cv/modules/trainables/convnet_encoder.py
@@ -38,7 +38,6 @@
 
 import numpy as np
 
-import torch
 import torch.nn as nn
 
 from nemo.backends.pytorch.nm import TrainableNM
@@ -337,19 +336,19 @@ def forward(self, inputs):
         out_conv1 = self._conv1(inputs)
 
         # apply max_pooling and relu
-        out_maxpool1 = torch.nn.functional.relu(self._maxpool1(out_conv1))
+        out_maxpool1 = nn.functional.relu(self._maxpool1(out_conv1))
 
         # apply Convolutional layer 2
         out_conv2 = self._conv2(out_maxpool1)
 
         # apply max_pooling and relu
-        out_maxpool2 = torch.nn.functional.relu(self._maxpool2(out_conv2))
+        out_maxpool2 = nn.functional.relu(self._maxpool2(out_conv2))
 
         # apply Convolutional layer 3
         out_conv3 = self._conv3(out_maxpool2)
 
         # apply max_pooling and relu
-        out_maxpool3 = torch.nn.functional.relu(self._maxpool3(out_conv3))
+        out_maxpool3 = nn.functional.relu(self._maxpool3(out_conv3))
 
         # Return output.
         return out_maxpool3
diff --git a/nemo/collections/cv/modules/trainables/generic_image_encoder.py b/nemo/collections/cv/modules/trainables/generic_image_encoder.py
new file mode 100644
index 000000000000..1adf1185c61d
--- /dev/null
+++ b/nemo/collections/cv/modules/trainables/generic_image_encoder.py
@@ -0,0 +1,202 @@
+# -*- coding: utf-8 -*-
+# =============================================================================
+# Copyright (c) 2020 NVIDIA. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+# Copyright (C) IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""
+This file contains code artifacts adapted from the original implementation:
+https://github.com/IBM/pytorchpipe/blob/develop/ptp/components/models/vision/generic_image_encoder.py
+"""
+
+__author__ = "Tomasz Kornuta"
+
+import torch
+import torchvision.models as models
+
+from nemo.backends.pytorch.nm import TrainableNM
+from nemo.core.neural_types import AxisKind, AxisType, NeuralType, VoidType
+from nemo.utils.configuration_error import ConfigurationError
+from nemo.utils.configuration_parsing import get_value_from_dictionary
+from nemo.utils.decorators import add_port_docs
+
+__all__ = ['GenericImageEncoder']
+
+
+class GenericImageEncoder(TrainableNM):
+    """
+    Neural module wrapping an image encoder backbone from ``torchvision.models``
+    (VGG16, DenseNet-121, ResNet-152 or ResNet-50), optionally pretrained on ImageNet.
+    Depending on the settings, the module returns either the output feature maps or
+    the predictions of a reinstantiated classification head.
+    """
+
+    def __init__(self, model_type, output_size=None, return_feature_maps=False, pretrained=False, name=None):
+        """
+        Initializes the ``GenericImageEncoder`` model and creates the required torchvision backend.
+
+        Args:
+            model_type: Backbone to use ("vgg16" | "densenet121" | "resnet152" | "resnet50").
+            output_size: Size of the output (prediction) layer; required when
+                ``return_feature_maps`` is False (DEFAULT: None).
+            return_feature_maps: Return raw feature maps instead of predictions (DEFAULT: False).
+            pretrained: Load weights pretrained on ImageNet (DEFAULT: False).
+            name: Name of the module (DEFAULT: None).
+        """
+        TrainableNM.__init__(self, name=name)
+
+        # Get operation modes.
+        self._return_feature_maps = return_feature_maps
+
+        # Get and validate the model type.
+        self._model_type = get_value_from_dictionary(
+            model_type, "vgg16 | densenet121 | resnet152 | resnet50".split(" | ")
+        )
+
+        # Get output size (optional - used only by the classification head).
+        self._output_size = output_size
+        if not self._return_feature_maps and self._output_size is None:
+            raise ConfigurationError("'output_size' must be set when 'return_feature_maps' is False")
+
+        if self._model_type == 'vgg16':
+            # Get VGG16.
+            self._model = models.vgg16(pretrained=pretrained)
+
+            if self._return_feature_maps:
+                # Use only the "feature encoder" part.
+                self._model = self._model.features
+
+                # Remember the output feature map dims.
+                self._feature_map_height = 7
+                self._feature_map_width = 7
+                self._feature_map_depth = 512
+
+            else:
+                # Use the whole model, but reinstantiate the last (classification) layer.
+                self._model.classifier._modules['6'] = torch.nn.Linear(4096, self._output_size)
+
+        elif self._model_type == 'densenet121':
+            # Get DenseNet-121.
+            self._model = models.densenet121(pretrained=pretrained)
+
+            if self._return_feature_maps:
+                raise ConfigurationError("'densenet121' doesn't support 'return_feature_maps' mode (yet)")
+
+            # Use the whole model, but reinstantiate the last (classification) layer.
+            self._model.classifier = torch.nn.Linear(1024, self._output_size)
+
+        elif self._model_type == 'resnet152':
+            # Get ResNet-152.
+            self._model = models.resnet152(pretrained=pretrained)
+
+            if self._return_feature_maps:
+                # Get all modules excluding the last two: avgpool and fc.
+                modules = list(self._model.children())[:-2]
+                self._model = torch.nn.Sequential(*modules)
+
+                # Remember the output feature map dims.
+                self._feature_map_height = 7
+                self._feature_map_width = 7
+                self._feature_map_depth = 2048
+
+            else:
+                # Use the whole model, but reinstantiate the last (classification) layer.
+                self._model.fc = torch.nn.Linear(2048, self._output_size)
+
+        elif self._model_type == 'resnet50':
+            # Get ResNet-50.
+            self._model = models.resnet50(pretrained=pretrained)
+
+            if self._return_feature_maps:
+                # Get all modules excluding the last two: avgpool and fc.
+                modules = list(self._model.children())[:-2]
+                self._model = torch.nn.Sequential(*modules)
+
+                # Remember the output feature map dims.
+                self._feature_map_height = 7
+                self._feature_map_width = 7
+                self._feature_map_depth = 2048
+
+            else:
+                # Use the whole model, but reinstantiate the last (classification) layer.
+                self._model.fc = torch.nn.Linear(2048, self._output_size)
+
+    @property
+    @add_port_docs()
+    def input_ports(self):
+        """
+        Returns definitions of module input ports.
+        """
+        return {
+            "inputs": NeuralType(
+                axes=(
+                    AxisType(kind=AxisKind.Batch),
+                    AxisType(kind=AxisKind.Channel, size=3),
+                    AxisType(kind=AxisKind.Height, size=224),
+                    AxisType(kind=AxisKind.Width, size=224),
+                ),
+                elements_type=VoidType(),
+            )
+        }
+
+    @property
+    @add_port_docs()
+    def output_ports(self):
+        """
+        Returns definitions of module output ports.
+        """
+        # Return the neural type depending on the operation mode.
+        if self._return_feature_maps:
+            return {
+                "outputs": NeuralType(
+                    axes=(
+                        AxisType(kind=AxisKind.Batch),
+                        AxisType(kind=AxisKind.Channel, size=self._feature_map_depth),
+                        AxisType(kind=AxisKind.Height, size=self._feature_map_height),
+                        AxisType(kind=AxisKind.Width, size=self._feature_map_width),
+                    ),
+                    elements_type=VoidType(),
+                )
+            }
+        else:
+            return {
+                "outputs": NeuralType(
+                    axes=(
+                        AxisType(kind=AxisKind.Batch),
+                        AxisType(kind=AxisKind.Any, size=self._output_size),
+                    ),
+                    elements_type=VoidType(),
+                )
+            }
+
+    def forward(self, inputs):
+        """
+        Main forward pass of the model.
+
+        Args:
+            inputs: expected stream containing images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE_WIDTH]
+
+        Returns:
+            outputs: stream containing outputs [BATCH_SIZE x OUTPUT_SIZE]
+                OR [BATCH_SIZE x OUTPUT_DEPTH x OUTPUT_HEIGHT x OUTPUT_WIDTH]
+        """
+        # Pass the inputs through the wrapped torchvision model.
+        outputs = self._model(inputs)
+
+        # Return the outputs.
+        return outputs
\ No newline at end of file
diff --git a/nemo/utils/configuration_parsing.py b/nemo/utils/configuration_parsing.py
new file mode 100644
index 000000000000..94aabd061b05
--- /dev/null
+++ b/nemo/utils/configuration_parsing.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+# =============================================================================
+# Copyright (c) 2020 NVIDIA. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+# Copyright (C) IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""
+This file contains code artifacts adapted from the original implementation:
+https://github.com/IBM/pytorchpipe/blob/develop/ptp/configuration/config_parsing.py
+"""
+
+from nemo.utils.configuration_error import ConfigurationError
+
+
+def get_value_list_from_dictionary(parameter, accepted_values=[]):
+    """
+    Parses a parameter value (e.g. retrieved from a configuration dictionary) into a list.
+    Optionally, checks whether all values belong to the list of accepted values.
+
+    Args:
+        parameter: Value to be parsed - a comma-separated string or a list.
+        accepted_values: List of accepted values (DEFAULT: []).
+
+    Returns:
+        List of parsed values.
+    """
+    # Preprocess the parameter value.
+    if isinstance(parameter, str):
+        if parameter == '':
+            # Return empty list.
+            return []
+        else:
+            # Remove whitespace and split on commas.
+            values = parameter.replace(" ", "").split(",")
+    else:
+        values = parameter  # Assume it is already a list.
+    if not isinstance(values, list):
+        raise ConfigurationError("'parameter' must be a string or a list")
+
+    # Test the values one by one.
+    if len(accepted_values) > 0:
+        for value in values:
+            if value not in accepted_values:
+                raise ConfigurationError(
+                    "One of the values in '{}' is invalid (current: '{}', accepted: {})".format(
+                        parameter, value, accepted_values
+                    )
+                )
+
+    # Return the list.
+    return values
+
+
+def get_value_from_dictionary(parameter: str, accepted_values=[]):
+    """
+    Parses a single parameter value (e.g. retrieved from a configuration dictionary).
+    Optionally, checks whether the value is one of the accepted values.
+
+    Args:
+        parameter: Value to be parsed.
+        accepted_values: List of accepted values (DEFAULT: []).
+
+    Returns:
+        The parsed value (or None for an empty string).
+    """
+    if not isinstance(parameter, str):
+        raise ConfigurationError("'parameter' must be a string")
+    # Preprocess the parameter value.
+    if parameter == '':
+        return None
+
+    # Test the value.
+    if len(accepted_values) > 0:
+        if parameter not in accepted_values:
+            raise ConfigurationError("Value '{}' is invalid (accepted: {})".format(parameter, accepted_values))
+
+    # Return the value.
+    return parameter
\ No newline at end of file
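
A note for reviewers: the backbone surgery performed in GenericImageEncoder.__init__ can be
reproduced with plain torchvision, outside of NeMo. The sketch below is an illustration, not
part of the patch; it shows both operating modes and the resulting shapes, which match the
module's declared output ports:

    import torch
    import torchvision.models as models

    # Feature-map mode: keep only the convolutional "features" part of VGG16.
    vgg = models.vgg16(pretrained=False)
    feature_maps = vgg.features(torch.randn(2, 3, 224, 224))
    print(feature_maps.shape)  # torch.Size([2, 512, 7, 7])

    # Prediction mode: reinstantiate the last classifier layer, e.g. for 10 classes.
    vgg.classifier[6] = torch.nn.Linear(4096, 10)
    predictions = vgg(torch.randn(2, 3, 224, 224))
    print(predictions.shape)  # torch.Size([2, 10])

    # ResNet feature-map mode: drop the trailing avgpool and fc modules.
    resnet = models.resnet50(pretrained=False)
    backbone = torch.nn.Sequential(*list(resnet.children())[:-2])
    print(backbone(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 2048, 7, 7])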
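
Similarly, a minimal usage sketch for the two parsing helpers in nemo/utils/configuration_parsing.py,
assuming the patch is applied (the input values are made up for illustration):

    from nemo.utils.configuration_error import ConfigurationError
    from nemo.utils.configuration_parsing import get_value_from_dictionary, get_value_list_from_dictionary

    # A comma-separated string is stripped of whitespace and split into a list.
    model_list = get_value_list_from_dictionary("vgg16, resnet50", accepted_values=["vgg16", "resnet50", "resnet152"])
    print(model_list)  # ['vgg16', 'resnet50']

    # A single value is validated against the accepted values and returned as-is.
    model = get_value_from_dictionary("vgg16", accepted_values=["vgg16", "densenet121"])
    print(model)  # vgg16

    # A value outside the accepted list raises a ConfigurationError.
    try:
        get_value_from_dictionary("alexnet", accepted_values=["vgg16", "densenet121"])
    except ConfigurationError as error:
        print(error)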