GenericImageEncoder ported + CIFAR10 VGG16 classification example
Signed-off-by: Tomasz Kornuta <[email protected]>
tkornuta-nvidia committed May 23, 2020
1 parent cf237a4 commit b165288
Showing 5 changed files with 388 additions and 4 deletions.
@@ -0,0 +1,83 @@
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import argparse

import nemo.utils.argparse as nm_argparse
from nemo.collections.cv.modules.data_layers.cifar10_datalayer import CIFAR10DataLayer
from nemo.collections.cv.modules.losses.nll_loss import NLLLoss
from nemo.collections.cv.modules.non_trainables.reshape_tensor import ReshapeTensor
from nemo.collections.cv.modules.trainables.generic_image_encoder import GenericImageEncoder
from nemo.collections.cv.modules.trainables.feed_forward_network import FeedForwardNetwork
from nemo.core import (
    DeviceType,
    NeuralGraph,
    NeuralModuleFactory,
    OperationMode,
    SimpleLossLoggerCallback,
    WandbCallback,
)
from nemo.utils import logging

if __name__ == "__main__":
    # Create the default parser.
    parser = argparse.ArgumentParser(parents=[nm_argparse.NemoArgParser()], conflict_handler='resolve')
    # Parse the arguments.
    args = parser.parse_args()

    # 0. Instantiate Neural Factory.
    nf = NeuralModuleFactory(local_rank=args.local_rank, placement=DeviceType.CPU)

    # 1. Instantiate the modules.
    # Data layer for training - upscale the CIFAR10 images to the ImageNet resolution.
    dl = CIFAR10DataLayer(height=224, width=224, train=True)
    # Model: pretrained VGG16 encoder (frozen below), a reshaper and a classification head.
    image_encoder = GenericImageEncoder(model_type="vgg16", return_feature_maps=True, pretrained=True, name="vgg16")
    # VGG16 feature maps are channels-first [BATCH, 512, 7, 7]; flatten them to [BATCH, 25088].
    reshaper = ReshapeTensor(input_sizes=[-1, 512, 7, 7], output_sizes=[-1, 25088])
    ffn = FeedForwardNetwork(input_size=25088, output_size=10, hidden_sizes=[1000, 1000], dropout_rate=0.1, final_logsoftmax=True)
    # Loss.
    nll_loss = NLLLoss()

    # 2. Create a training graph.
    with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
        img, tgt = dl()
        feat_map = image_encoder(inputs=img)
        res_img = reshaper(inputs=feat_map)
        pred = ffn(inputs=res_img)
        loss = nll_loss(predictions=pred, targets=tgt)
        # Set output - that output will be used for training.
        training_graph.outputs["loss"] = loss
    # Freeze the pretrained encoder, so only the reshaper and classification head are trained.
    # (Passing a list of module names to freeze() is an assumption about the NeuralGraph API.)
    training_graph.freeze(["vgg16"])
    logging.info(training_graph.summary())

    # SimpleLossLoggerCallback will print loss values to the console.
    callback = SimpleLossLoggerCallback(
        tensors=[loss], print_func=lambda x: logging.info(f'Training Loss: {x[0].item()}')
    )

    # Log training metrics to W&B.
    wandb_callback = WandbCallback(
        train_tensors=[loss], wandb_name="cifar10-vgg16", wandb_project="cv-collection-image-classification",
    )

    # Invoke the "train" action.
    nf.train(
        training_graph=training_graph,
        callbacks=[callback, wandb_callback],
        optimization_params={"num_epochs": 10, "lr": 0.001},
        optimizer="adam",
    )
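
For reference, here is a minimal plain-PyTorch sketch of the same pipeline, useful for checking the tensor shapes the graph above relies on. It uses only torch/torchvision; the head layout (ReLU + dropout between the hidden layers) is an assumption about FeedForwardNetwork, not part of this commit, and pretrained=False merely avoids the weight download:

import torch
import torch.nn as nn
import torchvision.models as models

# Frozen VGG16 convolutional backbone, as wrapped by GenericImageEncoder above.
backbone = models.vgg16(pretrained=False).features.eval()
for p in backbone.parameters():
    p.requires_grad = False

# Classification head: 25088 -> 1000 -> 1000 -> 10, with LogSoftmax to match NLLLoss.
head = nn.Sequential(
    nn.Flatten(),                                        # [B, 512, 7, 7] -> [B, 25088]
    nn.Linear(25088, 1000), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(1000, 1000), nn.ReLU(), nn.Dropout(0.1),
    nn.Linear(1000, 10), nn.LogSoftmax(dim=1),
)

images = torch.randn(4, 3, 224, 224)                     # CIFAR10 batch upscaled to 224x224
log_probs = head(backbone(images))                       # [4, 10]
loss = nn.NLLLoss()(log_probs, torch.randint(0, 10, (4,)))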
1 change: 1 addition & 0 deletions nemo/collections/cv/modules/trainables/__init__.py
@@ -16,4 +16,5 @@

 from nemo.collections.cv.modules.trainables.convnet_encoder import *
 from nemo.collections.cv.modules.trainables.feed_forward_network import *
+from nemo.collections.cv.modules.trainables.generic_image_encoder import *
 from nemo.collections.cv.modules.trainables.lenet5 import *
7 changes: 3 additions & 4 deletions nemo/collections/cv/modules/trainables/convnet_encoder.py
@@ -38,7 +38,6 @@


 import numpy as np
-import torch
 import torch.nn as nn
 
 from nemo.backends.pytorch.nm import TrainableNM
@@ -337,19 +336,19 @@ def forward(self, inputs):
         out_conv1 = self._conv1(inputs)
 
         # apply max_pooling and relu
-        out_maxpool1 = torch.nn.functional.relu(self._maxpool1(out_conv1))
+        out_maxpool1 = nn.functional.relu(self._maxpool1(out_conv1))
 
         # apply Convolutional layer 2
         out_conv2 = self._conv2(out_maxpool1)
 
         # apply max_pooling and relu
-        out_maxpool2 = torch.nn.functional.relu(self._maxpool2(out_conv2))
+        out_maxpool2 = nn.functional.relu(self._maxpool2(out_conv2))
 
         # apply Convolutional layer 3
         out_conv3 = self._conv3(out_maxpool2)
 
         # apply max_pooling and relu
-        out_maxpool3 = torch.nn.functional.relu(self._maxpool3(out_conv3))
+        out_maxpool3 = nn.functional.relu(self._maxpool3(out_conv3))
 
         # Return output.
         return out_maxpool3
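
This hunk is a pure refactor: once `torch.nn` is imported as `nn`, the alias refers to the very same module object as `torch.nn`, so the shorter spelling cannot change behavior. A one-line check:

import torch
import torch.nn as nn

# Both names point at the same module object; only the spelling changed.
assert nn.functional is torch.nn.functional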
202 changes: 202 additions & 0 deletions nemo/collections/cv/modules/trainables/generic_image_encoder.py
@@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
# =============================================================================
# Copyright (c) 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
#
# Copyright (C) IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

__author__ = "Tomasz Kornuta"

"""
This file contains code artifacts adapted from the original implementation:
https://github.com/IBM/pytorchpipe/blob/develop/ptp/components/models/vision/generic_image_encoder.py
"""

import torch
import torchvision.models as models

from nemo.backends.pytorch.nm import TrainableNM
from nemo.core.neural_types import AxisKind, AxisType, NeuralType, VoidType
from nemo.utils import logging
# ConfigurationError is raised below for unsupported settings; the import path is
# assumed by analogy with the other configuration utilities in nemo.utils.
from nemo.utils.configuration_error import ConfigurationError
from nemo.utils.configuration_parsing import get_value_from_dictionary
from nemo.utils.decorators import add_port_docs

__all__ = ['GenericImageEncoder']

class GenericImageEncoder(TrainableNM):
    """
    Neural module wrapping the torchvision image encoders (VGG16, DenseNet121, ResNet152, ResNet50).
    Depending on the settings, it returns either raw feature maps from the (pretrained) convolutional
    backbone or predictions from a reinstantiated, task-sized final layer.
    """
    def __init__(self, model_type, output_size=None, return_feature_maps=False, pretrained=False, name=None):
        """
        Initializes the ``GenericImageEncoder`` model and creates the required backend.

        Args:
            model_type: Backbone to use ("vgg16" | "densenet121" | "resnet152" | "resnet50").
            output_size: Size of the output (prediction) layer, used when returning predictions.
            return_feature_maps: When True, returns the output of the last convolutional block.
            pretrained: When True, loads weights pretrained on ImageNet.
            name: Name of the module (DEFAULT: None).
        """
        TrainableNM.__init__(self, name=name)

        # Get operation modes.
        self._return_feature_maps = return_feature_maps

        # Get model type.
        self._model_type = get_value_from_dictionary(model_type, "vgg16 | densenet121 | resnet152 | resnet50".split(" | "))

        # Get output size (optional).
        self._output_size = output_size

        if self._model_type == 'vgg16':
            # Get VGG16.
            self._model = models.vgg16(pretrained=pretrained)

            if self._return_feature_maps:
                # Use only the "feature encoder" (the convolutional part).
                self._model = self._model.features

                # Remember the output feature map dims.
                self._feature_map_height = 7
                self._feature_map_width = 7
                self._feature_map_depth = 512

            else:
                # Use the whole model, but reinstantiate the last classifier layer.
                self._model.classifier._modules['6'] = torch.nn.Linear(4096, self._output_size)

        elif self._model_type == 'densenet121':
            # Get DenseNet121.
            self._model = models.densenet121(pretrained=pretrained)

            if self._return_feature_maps:
                raise ConfigurationError("'densenet121' doesn't support the 'return_feature_maps' mode (yet)")

            # Use the whole model, but reinstantiate the last classifier layer.
            self._model.classifier = torch.nn.Linear(1024, self._output_size)

        elif self._model_type == 'resnet152':
            # Get ResNet152.
            self._model = models.resnet152(pretrained=pretrained)

            if self._return_feature_maps:
                # Get all modules excluding the last two: (avgpool) and (fc).
                modules = list(self._model.children())[:-2]
                self._model = torch.nn.Sequential(*modules)

                # Remember the output feature map dims.
                self._feature_map_height = 7
                self._feature_map_width = 7
                self._feature_map_depth = 2048

            else:
                # Use the whole model, but reinstantiate the last classifier layer.
                self._model.fc = torch.nn.Linear(2048, self._output_size)

        elif self._model_type == 'resnet50':
            # Get ResNet50.
            self._model = models.resnet50(pretrained=pretrained)

            if self._return_feature_maps:
                # Get all modules excluding the last two: (avgpool) and (fc).
                modules = list(self._model.children())[:-2]
                self._model = torch.nn.Sequential(*modules)

                # Remember the output feature map dims.
                self._feature_map_height = 7
                self._feature_map_width = 7
                self._feature_map_depth = 2048

            else:
                # Use the whole model, but reinstantiate the last classifier layer.
                self._model.fc = torch.nn.Linear(2048, self._output_size)

    @property
    @add_port_docs()
    def input_ports(self):
        """
        Returns definitions of module input ports.
        """
        return {
            "inputs": NeuralType(
                axes=(
                    AxisType(kind=AxisKind.Batch),
                    AxisType(kind=AxisKind.Channel, size=3),
                    AxisType(kind=AxisKind.Height, size=224),
                    AxisType(kind=AxisKind.Width, size=224),
                ),
                elements_type=VoidType(),
            )
        }

    @property
    @add_port_docs()
    def output_ports(self):
        """
        Returns definitions of module output ports.
        """
        # Return neural type depending on the operation mode.
        if self._return_feature_maps:
            return {
                "outputs": NeuralType(
                    axes=(
                        AxisType(kind=AxisKind.Batch),
                        AxisType(kind=AxisKind.Channel, size=self._feature_map_depth),
                        AxisType(kind=AxisKind.Height, size=self._feature_map_height),
                        AxisType(kind=AxisKind.Width, size=self._feature_map_width),
                    ),
                    elements_type=VoidType(),
                )
            }
        else:
            return {
                "outputs": NeuralType(
                    axes=(
                        AxisType(kind=AxisKind.Batch),
                        AxisType(kind=AxisKind.Any, size=self._output_size),
                    ),
                    elements_type=VoidType(),
                )
            }


    def forward(self, inputs):
        """
        Main forward pass of the model.

        Args:
            inputs: expected stream containing images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE_WIDTH]

        Returns:
            outputs: stream containing either predictions [BATCH_SIZE x OUTPUT_SIZE]
            or feature maps [BATCH_SIZE x OUTPUT_DEPTH x OUTPUT_HEIGHT x OUTPUT_WIDTH]
        """
        # Pass the inputs through the wrapped torchvision model.
        outputs = self._model(inputs)

        # Return the outputs.
        return outputs
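
As a quick sanity check of the module's two operating modes, the following torchvision-only sketch (illustrative, not part of the commit) reproduces the truncation and head replacement performed above, together with the shapes declared in the output ports:

import torch
import torchvision.models as models

x = torch.randn(2, 3, 224, 224)

# Feature-map mode for VGG16: keep only the convolutional "features" block.
vgg_features = models.vgg16(pretrained=False).features
print(vgg_features(x).shape)    # torch.Size([2, 512, 7, 7])

# Prediction mode for VGG16: reinstantiate the last classifier layer for 10 classes.
vgg = models.vgg16(pretrained=False)
vgg.classifier[6] = torch.nn.Linear(4096, 10)
print(vgg(x).shape)             # torch.Size([2, 10])

# Feature-map mode for ResNet50: drop the trailing (avgpool) and (fc) modules.
resnet = torch.nn.Sequential(*list(models.resnet50(pretrained=False).children())[:-2])
print(resnet(x).shape)          # torch.Size([2, 2048, 7, 7])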