From 7ec0e0a2130bcf66f62622007e23a7f9b7975e03 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Wed, 1 Apr 2020 11:04:11 -0700 Subject: [PATCH 1/8] added pytest for multi-dataset Signed-off-by: Yang Zhang --- nemo/backends/pytorch/common/__init__.py | 1 + nemo/backends/pytorch/common/multi_data.py | 128 ++++++++++++++++ nemo/backends/pytorch/tutorials/toys.py | 8 +- .../test_unclassified_multidataset.py | 143 ++++++++++++++++++ 4 files changed, 275 insertions(+), 5 deletions(-) create mode 100644 nemo/backends/pytorch/common/multi_data.py create mode 100644 tests/unclassified/test_unclassified_multidataset.py diff --git a/nemo/backends/pytorch/common/__init__.py b/nemo/backends/pytorch/common/__init__.py index adf89ab704d3..c80017b33e10 100644 --- a/nemo/backends/pytorch/common/__init__.py +++ b/nemo/backends/pytorch/common/__init__.py @@ -1,4 +1,5 @@ from nemo.backends.pytorch.common.losses import * +from nemo.backends.pytorch.common.multi_data import * from nemo.backends.pytorch.common.other import * from nemo.backends.pytorch.common.parts import * from nemo.backends.pytorch.common.rnn import * diff --git a/nemo/backends/pytorch/common/multi_data.py b/nemo/backends/pytorch/common/multi_data.py new file mode 100644 index 000000000000..ab454b67a306 --- /dev/null +++ b/nemo/backends/pytorch/common/multi_data.py @@ -0,0 +1,128 @@ +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from typing import List + +import numpy as np +import torch + +from nemo import logging +from nemo.backends.pytorch.nm import DataLayerNM +from nemo.core.neural_types import * + +__all__ = ['MultiDataLayer'] + + +class MultiDataLayer(DataLayerNM): + def __init__( + self, + data_layers: List[DataLayerNM], + batch_size: int, + shuffle: bool = False, + combination_mode: str = "cross_product", + port_names: List[str] = None, + ): + """ + data_layers: (list) of DataLayerNM objects + batch_size: (int) batchsize when the underlying dataset is loaded + combination_mode: (str) defines how to combine the datasets, Options are ["cross_product", "zip"]. + shuffle: (bool) whether underlying multi dataset should be shuffled in each epoch + port_names: List(str) user can override all port names if specified + """ + super().__init__() + self._data_layers = data_layers + self._batch_size = batch_size + self._shuffle = shuffle + self._combination_mode = combination_mode + self._port_names = port_names + self._dataset = MultiDataset( + datasets=[dl.dataset for dl in self._data_layers], combination_mode=combination_mode + ) + + total_num_port = sum([len(dl.output_ports) for dl in self._data_layers]) + self._ports = dict() + if self._port_names: + assert (len(self._port_names) == total_num_port, "Number of ports is does not match.") + i = 0 + for dl in self._data_layers: + for _, port_type in dl.output_ports.items(): + self._ports[self._port_names[i]] = port_type + i += 1 + else: + for dl_idx, dl in enumerate(self._data_layers): + for port_name, port_type in dl.output_ports.items(): + if port_name in self._ports: + logging.warning(f"name collision {port_name}, will rename") + self._ports[f"{port_name}_{dl_idx}"] = port_type + else: + self._ports[port_name] = port_type + + @property + def output_ports(self): + """Return: dict + Returns union of all individual data_layer output ports + In case of name 
collision, resolve by renaming + """ + return self._ports + + def __len__(self): + return len(self._dataset) + + @property + def dataset(self): + return self._dataset + + @property + def data_iterator(self): + return None + + +class MultiDataset(torch.utils.data.Dataset): + def __init__(self, datasets: List[torch.utils.data.Dataset], combination_mode: str = "cross_product"): + """ + Datasets: list of torch.utils.data.Dataset objects. + combination_mode: str, defines how to combine the datasets, Options are ["cross_product", "zip"]. + """ + self.datasets = datasets + self.combination_mode = combination_mode + self.len = None + + def __getitem__(self, i): + """ + Returns tuple (x1, x2, ...xn) where x1 \in D1, x2 \in D2, ...xn\ Dn + """ + + return [x for d in self.datasets for x in d[i % len(d)]] + + def __len__(self): + """ + Returns length of this dataset (int). + In case of combination_mode="cross_product" this would be prod(len(d) for d in self.datasets). + In case of combination_mode="zip" this would be min(len(d) for d in self.datasets) given that all datasets have same length. 
+ """ + if not self.len: + if self.combination_mode == "cross_product": + self.len = np.prod([len(d) for d in self.datasets]) + elif self.combination_mode == "zip": + ds_lens = [len(d) for d in self.datasets] + self.len = np.min(ds_lens) + if not np.all(ds_lens): + logging.warning("datasets do not have equal lengths and will be pruned to the shortest length.") + else: + raise ValueError("combination_mode unknown") + return self.len diff --git a/nemo/backends/pytorch/tutorials/toys.py b/nemo/backends/pytorch/tutorials/toys.py index 25fa8bd7c277..817658cd4518 100644 --- a/nemo/backends/pytorch/tutorials/toys.py +++ b/nemo/backends/pytorch/tutorials/toys.py @@ -158,10 +158,8 @@ def __init__(self, batch_size, f_name="sin", n=1000, x_lo=-4, x_hi=4): x_data = t.tensor(np.random.uniform(low=x_lo, high=x_hi, size=self._n)).unsqueeze(-1).to(self._device) y_data = func(x_data) - - self._data_iterator = t_utils.DataLoader( - t_utils.TensorDataset(x_data.float(), y_data.float()), batch_size=self._batch_size, - ) + self._dataset = t_utils.TensorDataset(x_data.float(), y_data.float()) + self._data_iterator = t_utils.DataLoader(self._dataset, batch_size=self._batch_size,) @property def data_iterator(self): @@ -169,7 +167,7 @@ def data_iterator(self): @property def dataset(self): - return None + return self._dataset class MSELoss(LossNM): diff --git a/tests/unclassified/test_unclassified_multidataset.py b/tests/unclassified/test_unclassified_multidataset.py new file mode 100644 index 000000000000..21fe9aa81bd4 --- /dev/null +++ b/tests/unclassified/test_unclassified_multidataset.py @@ -0,0 +1,143 @@ +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +import os +import shutil +from unittest import TestCase + +import pytest +import torch + +import nemo +from nemo.core import ChannelType, LabelsType, MaskType, NeuralType + +logging = nemo.logging + + +@pytest.mark.usefixtures("neural_factory") +class TestMultiDL(TestCase): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + @pytest.mark.unclassified + def test_port_name_collision_handling(self): + batch_size = 4 + dataset_size = 4 + shuffle = False + dl_1 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())}, + ) + dl_2 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "c": NeuralType(('B', 'T'), ChannelType())}, + ) + + data_layer = nemo.backends.pytorch.common.MultiDataLayer( + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="cross_product" + ) + self.assertEqual([*data_layer.output_ports], ["a", "b", "a_1", "c"]) + self.assertEqual(len(data_layer), dataset_size * dataset_size) + + @pytest.mark.unclassified + def test_port_renaming(self): + batch_size = 4 + dataset_size = 4 + shuffle = False + dl_1 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size, + dtype=torch.FloatTensor, + batch_size=batch_size, + 
output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())}, + ) + dl_2 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())}, + ) + + data_layer = nemo.backends.pytorch.common.MultiDataLayer( + data_layers=[dl_1, dl_2], + batch_size=batch_size, + shuffle=shuffle, + combination_mode="cross_product", + port_names=["1", "2", "3", "4"], + ) + self.assertEqual([*data_layer.output_ports], ["1", "2", "3", "4"]) + + @pytest.mark.unclassified + def test_multi_dl_zip(self): + batch_size = 4 + dataset_size_0 = 4 + dataset_size_1 = 5 + shuffle = False + dl_1 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size_0, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())}, + ) + dl_2 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size_1, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "c": NeuralType(('B', 'T'), ChannelType())}, + ) + + data_layer = nemo.backends.pytorch.common.MultiDataLayer( + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" + ) + self.assertEqual(len(data_layer), dataset_size_0) + + @pytest.mark.unclassified + def test_pipeline(self): + batch_size = 4 + dataset_size_0 = 100 + dataset_size_1 = 100 + shuffle = False + dl_1 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_0) + dl_2 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_1) + + data_layer = nemo.backends.pytorch.common.MultiDataLayer( + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" + ) + x_0, y_0, x_1, y_1 = data_layer() + + 
trainable_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + loss = nemo.backends.pytorch.tutorials.MSELoss() + combined_loss = nemo.backends.pytorch.common.losses.LossAggregatorNM(num_inputs=2) + pred_0 = trainable_module(x=x_0) + pred_1 = trainable_module(x=x_1) + l_0 = loss(predictions=pred_0, target=y_0) + l_1 = loss(predictions=pred_1, target=y_1) + total_loss = combined_loss(loss_1=l_0, loss_2=l_1) + + callback = nemo.core.SimpleLossLoggerCallback( + tensors=[total_loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), + ) + # Instantiate an optimizer to perform `train` action + optimizer = nemo.backends.pytorch.actions.PtActions() + optimizer.train( + tensors_to_optimize=[total_loss], optimizer="sgd", optimization_params={"lr": 0.0003, "num_epochs": 1}, + ) From 328cc7a26da32a4fd9049008cff4caf6670e251b Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Wed, 1 Apr 2020 12:20:15 -0700 Subject: [PATCH 2/8] fix toy real function data layer Signed-off-by: Yang Zhang --- nemo/backends/pytorch/tutorials/toys.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo/backends/pytorch/tutorials/toys.py b/nemo/backends/pytorch/tutorials/toys.py index 817658cd4518..c3d9027b180e 100644 --- a/nemo/backends/pytorch/tutorials/toys.py +++ b/nemo/backends/pytorch/tutorials/toys.py @@ -154,9 +154,8 @@ def __init__(self, batch_size, f_name="sin", n=1000, x_lo=-4, x_hi=4): self._n = n self._batch_size = batch_size - self._device = t.device("cuda" if self.placement == DeviceType.GPU else "cpu") - x_data = t.tensor(np.random.uniform(low=x_lo, high=x_hi, size=self._n)).unsqueeze(-1).to(self._device) + x_data = t.tensor(np.random.uniform(low=x_lo, high=x_hi, size=self._n)).unsqueeze(-1) y_data = func(x_data) self._dataset = t_utils.TensorDataset(x_data.float(), y_data.float()) self._data_iterator = t_utils.DataLoader(self._dataset, batch_size=self._batch_size,) From c267b8b1f75dfe1cdc42e6ebe6dd607ec31de5d8 Mon Sep 17 00:00:00 2001 From: 
Yang Zhang Date: Wed, 1 Apr 2020 14:16:51 -0700 Subject: [PATCH 3/8] refactored tests and make sure datasets have same lengths when zipped Signed-off-by: Yang Zhang --- nemo/backends/pytorch/common/multi_data.py | 20 ++--- .../test_integration_multidataset.py | 68 ++++++++++++++ .../test_unit_multidataset.py} | 88 +++++++++++-------- 3 files changed, 130 insertions(+), 46 deletions(-) create mode 100644 tests/integration/test_integration_multidataset.py rename tests/{unclassified/test_unclassified_multidataset.py => unit/test_unit_multidataset.py} (63%) diff --git a/nemo/backends/pytorch/common/multi_data.py b/nemo/backends/pytorch/common/multi_data.py index ab454b67a306..ae0ec26a7585 100644 --- a/nemo/backends/pytorch/common/multi_data.py +++ b/nemo/backends/pytorch/common/multi_data.py @@ -100,7 +100,15 @@ def __init__(self, datasets: List[torch.utils.data.Dataset], combination_mode: s """ self.datasets = datasets self.combination_mode = combination_mode - self.len = None + if self.combination_mode == "cross_product": + self.len = np.prod([len(d) for d in self.datasets]) + elif self.combination_mode == "zip": + ds_lens = [len(d) for d in self.datasets] + self.len = np.min(ds_lens) + if len(set(ds_lens)) != 1: + raise ValueError("datasets do not have equal lengths.") + else: + raise ValueError("combination_mode unknown") def __getitem__(self, i): """ @@ -115,14 +123,4 @@ def __len__(self): In case of combination_mode="cross_product" this would be prod(len(d) for d in self.datasets). In case of combination_mode="zip" this would be min(len(d) for d in self.datasets) given that all datasets have same length. 
""" - if not self.len: - if self.combination_mode == "cross_product": - self.len = np.prod([len(d) for d in self.datasets]) - elif self.combination_mode == "zip": - ds_lens = [len(d) for d in self.datasets] - self.len = np.min(ds_lens) - if not np.all(ds_lens): - logging.warning("datasets do not have equal lengths and will be pruned to the shortest length.") - else: - raise ValueError("combination_mode unknown") return self.len diff --git a/tests/integration/test_integration_multidataset.py b/tests/integration/test_integration_multidataset.py new file mode 100644 index 000000000000..35491f620aa6 --- /dev/null +++ b/tests/integration/test_integration_multidataset.py @@ -0,0 +1,68 @@ +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +import os +import shutil +from unittest import TestCase + +import pytest +import torch + +import nemo +from nemo.core import ChannelType, NeuralType + +logging = nemo.logging + + +@pytest.mark.usefixtures("neural_factory") +class TestMultiDLIntegration(TestCase): + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + + @pytest.mark.integration + def test_pipeline(self): + batch_size = 4 + dataset_size_0 = 100 + dataset_size_1 = 100 + shuffle = False + dl_1 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_0) + dl_2 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_1) + + data_layer = nemo.backends.pytorch.common.MultiDataLayer( + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" + ) + x_0, y_0, x_1, y_1 = data_layer() + + trainable_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) + loss = nemo.backends.pytorch.tutorials.MSELoss() + combined_loss = nemo.backends.pytorch.common.losses.LossAggregatorNM(num_inputs=2) + pred_0 = trainable_module(x=x_0) + pred_1 = trainable_module(x=x_1) + l_0 = loss(predictions=pred_0, target=y_0) + l_1 = loss(predictions=pred_1, target=y_1) + total_loss = combined_loss(loss_1=l_0, loss_2=l_1) + + callback = nemo.core.SimpleLossLoggerCallback( + tensors=[total_loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), + ) + # Instantiate an optimizer to perform `train` action + optimizer = nemo.backends.pytorch.actions.PtActions() + optimizer.train( + tensors_to_optimize=[total_loss], optimizer="sgd", optimization_params={"lr": 0.0003, "max_steps": 2}, + ) diff --git a/tests/unclassified/test_unclassified_multidataset.py b/tests/unit/test_unit_multidataset.py similarity index 63% rename from tests/unclassified/test_unclassified_multidataset.py rename to 
tests/unit/test_unit_multidataset.py index 21fe9aa81bd4..f356c61422db 100644 --- a/tests/unclassified/test_unclassified_multidataset.py +++ b/tests/unit/test_unit_multidataset.py @@ -24,18 +24,18 @@ import torch import nemo -from nemo.core import ChannelType, LabelsType, MaskType, NeuralType +from nemo.core import ChannelType, NeuralType logging = nemo.logging @pytest.mark.usefixtures("neural_factory") -class TestMultiDL(TestCase): +class TestMultiDLUnit(TestCase): @classmethod def setUpClass(cls) -> None: super().setUpClass() - @pytest.mark.unclassified + @pytest.mark.unit def test_port_name_collision_handling(self): batch_size = 4 dataset_size = 4 @@ -59,7 +59,7 @@ def test_port_name_collision_handling(self): self.assertEqual([*data_layer.output_ports], ["a", "b", "a_1", "c"]) self.assertEqual(len(data_layer), dataset_size * dataset_size) - @pytest.mark.unclassified + @pytest.mark.unit def test_port_renaming(self): batch_size = 4 dataset_size = 4 @@ -86,11 +86,12 @@ def test_port_renaming(self): ) self.assertEqual([*data_layer.output_ports], ["1", "2", "3", "4"]) - @pytest.mark.unclassified - def test_multi_dl_zip(self): + @pytest.mark.unit + def test_multi_dl_zip_working(self): + dataset_size_0 = 2 + dataset_size_1 = 2 + final_dataset_size = 2 batch_size = 4 - dataset_size_0 = 4 - dataset_size_1 = 5 shuffle = False dl_1 = nemo.backends.pytorch.common.ZerosDataLayer( size=dataset_size_0, @@ -108,36 +109,53 @@ def test_multi_dl_zip(self): data_layer = nemo.backends.pytorch.common.MultiDataLayer( data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" ) - self.assertEqual(len(data_layer), dataset_size_0) + self.assertEqual(len(data_layer), final_dataset_size) - @pytest.mark.unclassified - def test_pipeline(self): + @pytest.mark.unit + def test_multi_dl_zip_failing(self): + dataset_size_0 = 4 + dataset_size_1 = 2 batch_size = 4 - dataset_size_0 = 100 - dataset_size_1 = 100 shuffle = False - dl_1 = 
nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_0) - dl_2 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_1) - - data_layer = nemo.backends.pytorch.common.MultiDataLayer( - data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" + dl_1 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size_0, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())}, ) - x_0, y_0, x_1, y_1 = data_layer() - - trainable_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4) - loss = nemo.backends.pytorch.tutorials.MSELoss() - combined_loss = nemo.backends.pytorch.common.losses.LossAggregatorNM(num_inputs=2) - pred_0 = trainable_module(x=x_0) - pred_1 = trainable_module(x=x_1) - l_0 = loss(predictions=pred_0, target=y_0) - l_1 = loss(predictions=pred_1, target=y_1) - total_loss = combined_loss(loss_1=l_0, loss_2=l_1) - - callback = nemo.core.SimpleLossLoggerCallback( - tensors=[total_loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'), + dl_2 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size_1, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "c": NeuralType(('B', 'T'), ChannelType())}, ) - # Instantiate an optimizer to perform `train` action - optimizer = nemo.backends.pytorch.actions.PtActions() - optimizer.train( - tensors_to_optimize=[total_loss], optimizer="sgd", optimization_params={"lr": 0.0003, "num_epochs": 1}, + + with pytest.raises(ValueError): + data_layer = nemo.backends.pytorch.common.MultiDataLayer( + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" + ) + + @pytest.mark.unit + def test_multi_dl_wrong_combination(self): + dataset_size_0 = 2 + dataset_size_1 = 2 + unknown_combination = "cross" + 
batch_size = 4 + shuffle = False + dl_1 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size_0, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())}, ) + dl_2 = nemo.backends.pytorch.common.ZerosDataLayer( + size=dataset_size_1, + dtype=torch.FloatTensor, + batch_size=batch_size, + output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "c": NeuralType(('B', 'T'), ChannelType())}, + ) + + with pytest.raises(ValueError): + data_layer = nemo.backends.pytorch.common.MultiDataLayer( + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode=unknown_combination + ) From 2187c00d29ad4ca34a5a6747936b3b121e50bade Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Wed, 1 Apr 2020 14:33:44 -0700 Subject: [PATCH 4/8] changed types from string to enum Signed-off-by: Yang Zhang --- nemo/backends/pytorch/common/multi_data.py | 22 ++++++++++++++----- .../test_integration_multidataset.py | 3 ++- tests/unit/test_unit_multidataset.py | 12 ++++++---- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/nemo/backends/pytorch/common/multi_data.py b/nemo/backends/pytorch/common/multi_data.py index ae0ec26a7585..d7ae5cf04c8d 100644 --- a/nemo/backends/pytorch/common/multi_data.py +++ b/nemo/backends/pytorch/common/multi_data.py @@ -16,6 +16,7 @@ # limitations under the License. 
# ============================================================================= +from enum import Enum from typing import List import numpy as np @@ -25,7 +26,12 @@ from nemo.backends.pytorch.nm import DataLayerNM from nemo.core.neural_types import * -__all__ = ['MultiDataLayer'] +__all__ = ['MultiDataLayer', 'DataCombination'] + + +class DataCombination(Enum): + CROSSPRODUCT = 1 + ZIP = 2 class MultiDataLayer(DataLayerNM): @@ -34,13 +40,13 @@ def __init__( data_layers: List[DataLayerNM], batch_size: int, shuffle: bool = False, - combination_mode: str = "cross_product", + combination_mode: DataCombination = DataCombination.CROSSPRODUCT, port_names: List[str] = None, ): """ data_layers: (list) of DataLayerNM objects batch_size: (int) batchsize when the underlying dataset is loaded - combination_mode: (str) defines how to combine the datasets, Options are ["cross_product", "zip"]. + combination_mode: (DataCombination) defines how to combine the datasets. shuffle: (bool) whether underlying multi dataset should be shuffled in each epoch port_names: List(str) user can override all port names if specified """ @@ -93,16 +99,20 @@ def data_iterator(self): class MultiDataset(torch.utils.data.Dataset): - def __init__(self, datasets: List[torch.utils.data.Dataset], combination_mode: str = "cross_product"): + def __init__( + self, + datasets: List[torch.utils.data.Dataset], + combination_mode: DataCombination = DataCombination.CROSSPRODUCT, + ): """ Datasets: list of torch.utils.data.Dataset objects. combination_mode: str, defines how to combine the datasets, Options are ["cross_product", "zip"]. 
""" self.datasets = datasets self.combination_mode = combination_mode - if self.combination_mode == "cross_product": + if self.combination_mode == DataCombination.CROSSPRODUCT: self.len = np.prod([len(d) for d in self.datasets]) - elif self.combination_mode == "zip": + elif self.combination_mode == DataCombination.ZIP: ds_lens = [len(d) for d in self.datasets] self.len = np.min(ds_lens) if len(set(ds_lens)) != 1: diff --git a/tests/integration/test_integration_multidataset.py b/tests/integration/test_integration_multidataset.py index 35491f620aa6..892d3e08bcb4 100644 --- a/tests/integration/test_integration_multidataset.py +++ b/tests/integration/test_integration_multidataset.py @@ -24,6 +24,7 @@ import torch import nemo +from nemo.backends.pytorch.common import DataCombination from nemo.core import ChannelType, NeuralType logging = nemo.logging @@ -45,7 +46,7 @@ def test_pipeline(self): dl_2 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_1) data_layer = nemo.backends.pytorch.common.MultiDataLayer( - data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode=DataCombination.ZIP ) x_0, y_0, x_1, y_1 = data_layer() diff --git a/tests/unit/test_unit_multidataset.py b/tests/unit/test_unit_multidataset.py index f356c61422db..9d8384df8ac4 100644 --- a/tests/unit/test_unit_multidataset.py +++ b/tests/unit/test_unit_multidataset.py @@ -24,6 +24,7 @@ import torch import nemo +from nemo.backends.pytorch.common import DataCombination from nemo.core import ChannelType, NeuralType logging = nemo.logging @@ -54,7 +55,10 @@ def test_port_name_collision_handling(self): ) data_layer = nemo.backends.pytorch.common.MultiDataLayer( - data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="cross_product" + data_layers=[dl_1, dl_2], + batch_size=batch_size, + shuffle=shuffle, + 
combination_mode=DataCombination.CROSSPRODUCT, ) self.assertEqual([*data_layer.output_ports], ["a", "b", "a_1", "c"]) self.assertEqual(len(data_layer), dataset_size * dataset_size) @@ -81,7 +85,7 @@ def test_port_renaming(self): data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, - combination_mode="cross_product", + combination_mode=DataCombination.CROSSPRODUCT, port_names=["1", "2", "3", "4"], ) self.assertEqual([*data_layer.output_ports], ["1", "2", "3", "4"]) @@ -107,7 +111,7 @@ def test_multi_dl_zip_working(self): ) data_layer = nemo.backends.pytorch.common.MultiDataLayer( - data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode=DataCombination.ZIP ) self.assertEqual(len(data_layer), final_dataset_size) @@ -132,7 +136,7 @@ def test_multi_dl_zip_failing(self): with pytest.raises(ValueError): data_layer = nemo.backends.pytorch.common.MultiDataLayer( - data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip" + data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode=DataCombination.ZIP ) @pytest.mark.unit From f7148b505cfa7ad22045af959cb7432f65b96d67 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Wed, 1 Apr 2020 14:36:05 -0700 Subject: [PATCH 5/8] fix lgtm Signed-off-by: Yang Zhang --- nemo/backends/pytorch/common/multi_data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nemo/backends/pytorch/common/multi_data.py b/nemo/backends/pytorch/common/multi_data.py index d7ae5cf04c8d..fc0ce3417836 100644 --- a/nemo/backends/pytorch/common/multi_data.py +++ b/nemo/backends/pytorch/common/multi_data.py @@ -60,10 +60,8 @@ def __init__( datasets=[dl.dataset for dl in self._data_layers], combination_mode=combination_mode ) - total_num_port = sum([len(dl.output_ports) for dl in self._data_layers]) self._ports = dict() if self._port_names: - assert (len(self._port_names) == 
total_num_port, "Number of ports is does not match.") i = 0 for dl in self._data_layers: for _, port_type in dl.output_ports.items(): From d9a618561ec5543281bbdecea89def2cf54726ae Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Mon, 6 Apr 2020 15:26:39 -0700 Subject: [PATCH 6/8] fixing some docstring Signed-off-by: Yang Zhang --- nemo/backends/pytorch/common/multi_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/backends/pytorch/common/multi_data.py b/nemo/backends/pytorch/common/multi_data.py index fc0ce3417836..23962f10e285 100644 --- a/nemo/backends/pytorch/common/multi_data.py +++ b/nemo/backends/pytorch/common/multi_data.py @@ -104,7 +104,7 @@ def __init__( ): """ Datasets: list of torch.utils.data.Dataset objects. - combination_mode: str, defines how to combine the datasets, Options are ["cross_product", "zip"]. + combination_mode: DataCombination, defines how to combine the datasets, Options are [DataCombination.CROSSPRODUCT, DataCombination.ZIP]. """ self.datasets = datasets self.combination_mode = combination_mode @@ -120,7 +120,7 @@ def __init__( def __getitem__(self, i): """ - Returns tuple (x1, x2, ...xn) where x1 \in D1, x2 \in D2, ...xn\ Dn + Returns list [x1, x2, ...xn] where x1 \in D1, x2 \in D2, ...xn\ Dn """ return [x for d in self.datasets for x in d[i % len(d)]] @@ -128,7 +128,7 @@ def __getitem__(self, i): def __len__(self): """ Returns length of this dataset (int). - In case of combination_mode="cross_product" this would be prod(len(d) for d in self.datasets). - In case of combination_mode="zip" this would be min(len(d) for d in self.datasets) given that all datasets have same length. + In case of DataCombination.CROSSPRODUCT this would be prod(len(d) for d in self.datasets). + In case of DataCombination.ZIP this would be min(len(d) for d in self.datasets) given that all datasets have same length. 
""" return self.len From 38eee8b90692c5baad1a5371d65b3164cd51bc96 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Thu, 9 Apr 2020 16:04:15 -0700 Subject: [PATCH 7/8] typo fix Signed-off-by: Yang Zhang --- nemo/backends/pytorch/common/multi_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/backends/pytorch/common/multi_data.py b/nemo/backends/pytorch/common/multi_data.py index 23962f10e285..b51b23587aa1 100644 --- a/nemo/backends/pytorch/common/multi_data.py +++ b/nemo/backends/pytorch/common/multi_data.py @@ -120,7 +120,7 @@ def __init__( def __getitem__(self, i): """ - Returns list [x1, x2, ...xn] where x1 \in D1, x2 \in D2, ...xn\ Dn + Returns list [x1, x2, ...xn] where x1 \in D1, x2 \in D2, ..., xn \in Dn """ return [x for d in self.datasets for x in d[i % len(d)]] From 28a1ac2eb61954793ff3de52447f6efb88a49389 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Thu, 9 Apr 2020 16:25:26 -0700 Subject: [PATCH 8/8] rebase on master and update changelog Signed-off-by: Yang Zhang --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c307f47afbe7..ec08fef092c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,6 +70,8 @@ To release a new version, please update the changelog as followed: ## [Unreleased] ### Added +- Added multi-dataset data-layer and dataset. +([PR #538](https://github.com/NVIDIA/NeMo/pull/538)) - @yzhang123 ### Changed