Skip to content

Commit

Permalink
refactored tests and make sure datasets have same lengths when zipped
Browse files Browse the repository at this point in the history
Signed-off-by: Yang Zhang <[email protected]>
  • Loading branch information
yzhang123 committed Apr 1, 2020
1 parent e90da39 commit efb4b32
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 46 deletions.
20 changes: 9 additions & 11 deletions nemo/backends/pytorch/common/multi_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,15 @@ def __init__(self, datasets: List[torch.utils.data.Dataset], combination_mode: s
"""
self.datasets = datasets
self.combination_mode = combination_mode
self.len = None
if self.combination_mode == "cross_product":
self.len = np.prod([len(d) for d in self.datasets])
elif self.combination_mode == "zip":
ds_lens = [len(d) for d in self.datasets]
self.len = np.min(ds_lens)
if len(set(ds_lens)) != 1:
raise ValueError("datasets do not have equal lengths.")
else:
raise ValueError("combination_mode unknown")

def __getitem__(self, i):
"""
Expand All @@ -115,14 +123,4 @@ def __len__(self):
In case of combination_mode="cross_product" this would be prod(len(d) for d in self.datasets).
In case of combination_mode="zip" this would be min(len(d) for d in self.datasets) given that all datasets have same length.
"""
if not self.len:
if self.combination_mode == "cross_product":
self.len = np.prod([len(d) for d in self.datasets])
elif self.combination_mode == "zip":
ds_lens = [len(d) for d in self.datasets]
self.len = np.min(ds_lens)
if not np.all(ds_lens):
logging.warning("datasets do not have equal lengths and will be pruned to the shortest length.")
else:
raise ValueError("combination_mode unknown")
return self.len
68 changes: 68 additions & 0 deletions tests/integration/test_integration_multidataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import shutil
from unittest import TestCase

import pytest
import torch

import nemo
from nemo.core import ChannelType, NeuralType

logging = nemo.logging


@pytest.mark.usefixtures("neural_factory")
class TestMultiDLIntegration(TestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()

@pytest.mark.integration
def test_pipeline(self):
batch_size = 4
dataset_size_0 = 100
dataset_size_1 = 100
shuffle = False
dl_1 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_0)
dl_2 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_1)

data_layer = nemo.backends.pytorch.common.MultiDataLayer(
data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip"
)
x_0, y_0, x_1, y_1 = data_layer()

trainable_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4)
loss = nemo.backends.pytorch.tutorials.MSELoss()
combined_loss = nemo.backends.pytorch.common.losses.LossAggregatorNM(num_inputs=2)
pred_0 = trainable_module(x=x_0)
pred_1 = trainable_module(x=x_1)
l_0 = loss(predictions=pred_0, target=y_0)
l_1 = loss(predictions=pred_1, target=y_1)
total_loss = combined_loss(loss_1=l_0, loss_2=l_1)

callback = nemo.core.SimpleLossLoggerCallback(
tensors=[total_loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'),
)
# Instantiate an optimizer to perform `train` action
optimizer = nemo.backends.pytorch.actions.PtActions()
optimizer.train(
tensors_to_optimize=[total_loss], optimizer="sgd", optimization_params={"lr": 0.0003, "max_steps": 2},
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@
import torch

import nemo
from nemo.core import ChannelType, LabelsType, MaskType, NeuralType
from nemo.core import ChannelType, NeuralType

logging = nemo.logging


@pytest.mark.usefixtures("neural_factory")
class TestMultiDL(TestCase):
class TestMultiDLUnit(TestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()

@pytest.mark.unclassified
@pytest.mark.unit
def test_port_name_collision_handling(self):
batch_size = 4
dataset_size = 4
Expand All @@ -59,7 +59,7 @@ def test_port_name_collision_handling(self):
self.assertEqual([*data_layer.output_ports], ["a", "b", "a_1", "c"])
self.assertEqual(len(data_layer), dataset_size * dataset_size)

@pytest.mark.unclassified
@pytest.mark.unit
def test_port_renaming(self):
batch_size = 4
dataset_size = 4
Expand All @@ -86,11 +86,12 @@ def test_port_renaming(self):
)
self.assertEqual([*data_layer.output_ports], ["1", "2", "3", "4"])

@pytest.mark.unclassified
def test_multi_dl_zip(self):
@pytest.mark.unit
def test_multi_dl_zip_working(self):
dataset_size_0 = 2
dataset_size_1 = 2
final_dataset_size = 2
batch_size = 4
dataset_size_0 = 4
dataset_size_1 = 5
shuffle = False
dl_1 = nemo.backends.pytorch.common.ZerosDataLayer(
size=dataset_size_0,
Expand All @@ -108,36 +109,53 @@ def test_multi_dl_zip(self):
data_layer = nemo.backends.pytorch.common.MultiDataLayer(
data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip"
)
self.assertEqual(len(data_layer), dataset_size_0)
self.assertEqual(len(data_layer), final_dataset_size)

@pytest.mark.unclassified
def test_pipeline(self):
@pytest.mark.unit
def test_multi_dl_zip_failing(self):
dataset_size_0 = 4
dataset_size_1 = 2
batch_size = 4
dataset_size_0 = 100
dataset_size_1 = 100
shuffle = False
dl_1 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_0)
dl_2 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_1)

data_layer = nemo.backends.pytorch.common.MultiDataLayer(
data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip"
dl_1 = nemo.backends.pytorch.common.ZerosDataLayer(
size=dataset_size_0,
dtype=torch.FloatTensor,
batch_size=batch_size,
output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())},
)
x_0, y_0, x_1, y_1 = data_layer()

trainable_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4)
loss = nemo.backends.pytorch.tutorials.MSELoss()
combined_loss = nemo.backends.pytorch.common.losses.LossAggregatorNM(num_inputs=2)
pred_0 = trainable_module(x=x_0)
pred_1 = trainable_module(x=x_1)
l_0 = loss(predictions=pred_0, target=y_0)
l_1 = loss(predictions=pred_1, target=y_1)
total_loss = combined_loss(loss_1=l_0, loss_2=l_1)

callback = nemo.core.SimpleLossLoggerCallback(
tensors=[total_loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'),
dl_2 = nemo.backends.pytorch.common.ZerosDataLayer(
size=dataset_size_1,
dtype=torch.FloatTensor,
batch_size=batch_size,
output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "c": NeuralType(('B', 'T'), ChannelType())},
)
# Instantiate an optimizer to perform `train` action
optimizer = nemo.backends.pytorch.actions.PtActions()
optimizer.train(
tensors_to_optimize=[total_loss], optimizer="sgd", optimization_params={"lr": 0.0003, "num_epochs": 1},

with pytest.raises(ValueError):
data_layer = nemo.backends.pytorch.common.MultiDataLayer(
data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip"
)

@pytest.mark.unit
def test_multi_dl_wrong_combination(self):
dataset_size_0 = 2
dataset_size_1 = 2
unknown_combination = "cross"
batch_size = 4
shuffle = False
dl_1 = nemo.backends.pytorch.common.ZerosDataLayer(
size=dataset_size_0,
dtype=torch.FloatTensor,
batch_size=batch_size,
output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())},
)
dl_2 = nemo.backends.pytorch.common.ZerosDataLayer(
size=dataset_size_1,
dtype=torch.FloatTensor,
batch_size=batch_size,
output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "c": NeuralType(('B', 'T'), ChannelType())},
)

with pytest.raises(ValueError):
data_layer = nemo.backends.pytorch.common.MultiDataLayer(
data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode=unknown_combination
)

0 comments on commit efb4b32

Please sign in to comment.