Skip to content

Commit

Permalink
Merge pull request #268 from NVIDIA/flow-naming
Browse files Browse the repository at this point in the history
Named modules output
  • Loading branch information
okuchaiev authored Jan 15, 2020
2 parents e86f8d6 + 8bfb196 commit 5467aba
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
15 changes: 14 additions & 1 deletion nemo/nemo/core/neural_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
from typing import Optional, Dict, Set, Tuple, List
import uuid
import collections

from nemo.core import NeuralModuleFactory

Expand Down Expand Up @@ -213,7 +214,19 @@ def __call__(self, **kwargs):
ntype=out_type,
)
)
return tuple(result)

# Creating ad-hoc class for returning from module's forward pass.
output_class_name = f'{self.__class__.__name__}Output'
field_names = list(output_port_defs)
result_type = collections.namedtuple(
typename=output_class_name,
field_names=field_names,
)

# Tie tuple of output tensors with corresponding names.
result = result_type(*result)

return result

def __str__(self):
    """Return the concrete class name as the module's string representation."""
    # type(self) is equivalent to self.__class__ for ordinary instances.
    return type(self).__name__
Expand Down
25 changes: 25 additions & 0 deletions tests/test_pytorch_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,31 @@ def test_simple_train(self):
optimization_params={"lr": 0.0003, "num_epochs": 1}
)

def test_simple_train_named_output(self):
    """Train the simplest model while consuming the data layer's named output.

    Verifies that calling a neural module returns the auto-generated
    namedtuple type (named after the module class) and that its fields
    can be fed to downstream modules by attribute access.
    """
    print('Simplest train test with using named output.')

    # Minimal graph: data layer -> trainable module -> loss.
    dl = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(
        n=10000,
        batch_size=128,
    )
    model = nemo.backends.pytorch.tutorials.TaylorNet(dim=4)
    mse = nemo.backends.pytorch.tutorials.MSELoss()

    named_output = dl()
    # The output class name must follow the '<ModuleClass>Output' convention.
    self.assertEqual(
        first=type(named_output).__name__,
        second='RealFunctionDataLayerOutput',
        msg='Check output class naming coherence.',
    )

    # Fields of the named output are accessed by attribute, not by index.
    prediction = model(x=named_output.x)
    loss_tensor = mse(predictions=prediction, target=named_output.y)

    optimizer = nemo.backends.pytorch.actions.PtActions()
    optimizer.train(
        tensors_to_optimize=[loss_tensor],
        optimizer="sgd",
        optimization_params={"lr": 0.0003, "num_epochs": 1},
    )

def test_simple_chained_train(self):
print("Chained train test")
data_source = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(
Expand Down

0 comments on commit 5467aba

Please sign in to comment.