Skip to content

Commit

Permalink
Add output class name constructor with test.
Browse files Browse the repository at this point in the history
Signed-off-by: Stanislav Beliaev <[email protected]>
  • Loading branch information
stasbel committed Jan 15, 2020
1 parent 2b98903 commit 8bfb196
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
6 changes: 5 additions & 1 deletion nemo/nemo/core/neural_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,12 @@ def __call__(self, **kwargs):
)

# Creating ad-hoc class for returning from module's forward pass.
output_class_name = f'{self.__class__.__name__}Output'
field_names = list(output_port_defs)
result_type = collections.namedtuple('NmOutput', field_names)
result_type = collections.namedtuple(
typename=output_class_name,
field_names=field_names,
)

# Tie tuple of output tensors with corresponding names.
result = result_type(*result)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_pytorch_trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def test_simple_train_named_output(self):
loss = nemo.backends.pytorch.tutorials.MSELoss()

data = data_source()
self.assertEqual(
first=type(data).__name__,
second='RealFunctionDataLayerOutput',
msg='Check output class naming coherence.',
)
y_pred = trainable_module(x=data.x)
loss_tensor = loss(predictions=y_pred, target=data.y)

Expand Down

0 comments on commit 8bfb196

Please sign in to comment.