Skip to content

Commit

Permalink
formatting fix
Browse files Browse the repository at this point in the history
Signed-off-by: Tomasz Kornuta <[email protected]>
  • Loading branch information
tkornuta-nvidia committed Jun 2, 2020
1 parent ee6ec4a commit 3730537
Show file tree
Hide file tree
Showing 10 changed files with 16 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,7 @@
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import FeedForwardNetwork, GenericImageEncoder
from nemo.core import (
DeviceType,
NeuralGraph,
NeuralModuleFactory,
OperationMode,
SimpleLossLoggerCallback,
)
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
Expand All @@ -51,7 +45,7 @@

# Create a training graph.
with NeuralGraph(operation_mode=OperationMode.training) as training_graph:
_, img, _, _, fine_target, _ = cifar100_dl()
_, img, _, _, fine_target, _ = cifar100_dl()
feat_map = image_encoder(inputs=img)
res_img = reshaper(inputs=feat_map)
logits = ffn(inputs=res_img)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,7 @@
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import ConvNetEncoder, FeedForwardNetwork
from nemo.core import (
DeviceType,
NeuralGraph,
NeuralModuleFactory,
OperationMode,
SimpleLossLoggerCallback,
)
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,7 @@
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity
from nemo.collections.cv.modules.trainables import GenericImageEncoder
from nemo.core import (
DeviceType,
NeuralGraph,
NeuralModuleFactory,
OperationMode,
SimpleLossLoggerCallback,
)
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,7 @@
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import ConvNetEncoder, FeedForwardNetwork
from nemo.core import (
DeviceType,
NeuralGraph,
NeuralModuleFactory,
OperationMode,
SimpleLossLoggerCallback,
)
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,7 @@
from nemo.collections.cv.modules.losses import NLLLoss
from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor
from nemo.collections.cv.modules.trainables import FeedForwardNetwork
from nemo.core import (
DeviceType,
NeuralGraph,
NeuralModuleFactory,
OperationMode,
SimpleLossLoggerCallback,
)
from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode, SimpleLossLoggerCallback
from nemo.utils import logging

if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@

# Create a validation graph, starting from the second data layer.
with NeuralGraph(operation_mode=OperationMode.evaluation) as evaluation_graph:
_, x, y, _ = dl_e()
_, x, y, _ = dl_e()
p = lenet5(images=x)
loss_e = nll_loss(predictions=p, targets=y)

Expand Down
12 changes: 2 additions & 10 deletions nemo/collections/cv/modules/data_layers/cifar100_datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,7 @@
from torchvision.transforms import Compose, Resize, ToTensor

from nemo.backends.pytorch.nm import DataLayerNM
from nemo.core.neural_types import (
AxisKind,
AxisType,
ClassificationTarget,
ImageValue,
Index,
Label,
NeuralType,
)
from nemo.core.neural_types import AxisKind, AxisType, ClassificationTarget, ImageValue, Index, Label, NeuralType
from nemo.utils.decorators import add_port_docs

__all__ = ['CIFAR100DataLayer']
Expand Down Expand Up @@ -204,7 +196,7 @@ def __getitem__(self, index: int):
img, fine_target = self._dataset.__getitem__(index)
# Get coarse target.
coarse_target = self._fine_to_coarse_id_mapping[fine_target]

# Labels.
fine_label = self._fine_ix_to_word[fine_target]
coarse_label = self._coarse_ix_to_word[self._fine_to_coarse_id_mapping[fine_target]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __getitem__(self, index: int):
"""
# Get image and target.
img, target = self._dataset.__getitem__(index)

# Return sample.
return index, img, target

Expand Down
5 changes: 2 additions & 3 deletions nemo/collections/cv/modules/data_layers/mnist_datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
AxisType,
ClassificationTarget,
Index,
Label,
NeuralType,
NormalizedImageValue,
Label,
)
from nemo.utils.decorators import add_port_docs

Expand Down Expand Up @@ -129,7 +129,7 @@ def __getitem__(self, index: int):
"""
# Get image and target.
img, target = self._dataset.__getitem__(index)

# Return sample.
return index, img, target, self._ix_to_word[target]

Expand All @@ -149,4 +149,3 @@ def dataset(self):
Self - just to be "compatible" with the current NeMo train action.
"""
return self # ! Important - as we want to use this __getitem__ method!

4 changes: 4 additions & 0 deletions nemo/core/neural_types/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,19 @@ class Target(ElementType):
Type representing an element being a target value.
"""


class ClassificationTarget(Target):
"""
Type representing an element being target value in the classification task, i.e. identifier of a desired class.
"""


class Label(ElementType):
"""
Type representing an element being a target value.
"""


class ImageValue(ElementType):
"""
Type representing an element/value of a single image channel,
Expand All @@ -235,5 +238,6 @@ class NormalizedImageValue(ImageValue):
e.g. a single element (R) of normalized RGB image.
"""


class ImageFeatureValue(ImageValue):
"""Type representing an element (single value) of a (image) feature maps."""

0 comments on commit 3730537

Please sign in to comment.