diff --git a/nemo/backends/pytorch/common/__init__.py b/nemo/backends/pytorch/common/__init__.py
index adf89ab704d3..c80017b33e10 100644
--- a/nemo/backends/pytorch/common/__init__.py
+++ b/nemo/backends/pytorch/common/__init__.py
@@ -1,4 +1,5 @@
 from nemo.backends.pytorch.common.losses import *
+from nemo.backends.pytorch.common.multi_data import *
 from nemo.backends.pytorch.common.other import *
 from nemo.backends.pytorch.common.parts import *
 from nemo.backends.pytorch.common.rnn import *
diff --git a/nemo/backends/pytorch/common/multi_data.py b/nemo/backends/pytorch/common/multi_data.py
new file mode 100644
index 000000000000..ab454b67a306
--- /dev/null
+++ b/nemo/backends/pytorch/common/multi_data.py
@@ -0,0 +1,128 @@
+# ! /usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 NVIDIA. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from typing import List
+
+import numpy as np
+import torch
+
+from nemo import logging
+from nemo.backends.pytorch.nm import DataLayerNM
+from nemo.core.neural_types import *
+
+__all__ = ['MultiDataLayer']
+
+
+class MultiDataLayer(DataLayerNM):
+    def __init__(
+        self,
+        data_layers: List[DataLayerNM],
+        batch_size: int,
+        shuffle: bool = False,
+        combination_mode: str = "cross_product",
+        port_names: List[str] = None,
+    ):
+        """
+        data_layers: (list) of DataLayerNM objects
+        batch_size: (int) batch size used when the underlying dataset is loaded
+        combination_mode: (str) defines how to combine the datasets. Options are ["cross_product", "zip"].
+        shuffle: (bool) whether the underlying multi dataset should be shuffled in each epoch
+        port_names: List(str) user can override all port names if specified
+        """
+        super().__init__()
+        self._data_layers = data_layers
+        self._batch_size = batch_size
+        self._shuffle = shuffle
+        self._combination_mode = combination_mode
+        self._port_names = port_names
+        self._dataset = MultiDataset(
+            datasets=[dl.dataset for dl in self._data_layers], combination_mode=combination_mode
+        )
+
+        total_num_port = sum(len(dl.output_ports) for dl in self._data_layers)
+        self._ports = dict()
+        if self._port_names:
+            assert len(self._port_names) == total_num_port, "Number of port names does not match the number of ports."
+            i = 0
+            for dl in self._data_layers:
+                for _, port_type in dl.output_ports.items():
+                    self._ports[self._port_names[i]] = port_type
+                    i += 1
+        else:
+            for dl_idx, dl in enumerate(self._data_layers):
+                for port_name, port_type in dl.output_ports.items():
+                    if port_name in self._ports:
+                        logging.warning(f"name collision {port_name}, will rename")
+                        self._ports[f"{port_name}_{dl_idx}"] = port_type
+                    else:
+                        self._ports[port_name] = port_type
+
+    @property
+    def output_ports(self):
+        """Return: dict
+        Returns the union of all individual data layer output ports.
+        In case of a name collision, ports are resolved by renaming.
+        """
+        return self._ports
+
+    def __len__(self):
+        return len(self._dataset)
+
+    @property
+    def dataset(self):
+        return self._dataset
+
+    @property
+    def data_iterator(self):
+        return None
+
+
+class MultiDataset(torch.utils.data.Dataset):
+    def __init__(self, datasets: List[torch.utils.data.Dataset], combination_mode: str = "cross_product"):
+        """
+        datasets: list of torch.utils.data.Dataset objects.
+        combination_mode: str, defines how to combine the datasets. Options are ["cross_product", "zip"].
+        """
+        self.datasets = datasets
+        self.combination_mode = combination_mode
+        self.len = None
+
+    def __getitem__(self, i):
+        """
+        Returns a flat list (x1, x2, ..., xn) where x1 is drawn from D1, x2 from D2, ..., xn from Dn.
+        For "zip", every dataset is indexed with i directly; for "cross_product",
+        i is decomposed into one index per dataset so that every combination of
+        elements is visited exactly once.
+        """
+        if self.combination_mode == "zip":
+            indices = [i] * len(self.datasets)
+        else:
+            indices = []
+            for d in self.datasets:
+                indices.append(i % len(d))
+                i //= len(d)
+        return [x for d, idx in zip(self.datasets, indices) for x in d[idx]]
+
+    def __len__(self):
+        """
+        Returns the length of this dataset (int).
+        For combination_mode="cross_product" this is prod(len(d) for d in self.datasets).
+        For combination_mode="zip" this is min(len(d) for d in self.datasets); all
+        datasets are expected to have the same length.
+        """
+        if not self.len:
+            if self.combination_mode == "cross_product":
+                self.len = int(np.prod([len(d) for d in self.datasets]))
+            elif self.combination_mode == "zip":
+                ds_lens = [len(d) for d in self.datasets]
+                self.len = min(ds_lens)
+                if len(set(ds_lens)) != 1:
+                    logging.warning("datasets do not have equal lengths and will be pruned to the shortest length.")
+            else:
+                raise ValueError(f"combination_mode unknown: {self.combination_mode}")
+        return self.len
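
For reference, the collision-renaming rule implemented above can be sketched standalone. This is a minimal illustration, not part of the patch; plain strings stand in for NeuralType instances:

# Sketch of MultiDataLayer's port merging: on a name collision, the port
# coming from data layer number dl_idx is stored under f"{name}_{dl_idx}".
def merge_ports(port_maps):
    merged = {}
    for dl_idx, ports in enumerate(port_maps):
        for name, port_type in ports.items():
            key = f"{name}_{dl_idx}" if name in merged else name
            merged[key] = port_type
    return merged

print(list(merge_ports([{"a": "T", "b": "T"}, {"a": "T", "c": "T"}])))
# -> ['a', 'b', 'a_1', 'c'], matching the expectation in the tests below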
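
Likewise, a small usage sketch of MultiDataset itself, assuming this patch is installed; the toy TensorDatasets are made up for illustration:

import torch
from torch.utils.data import TensorDataset

from nemo.backends.pytorch.common.multi_data import MultiDataset

# Two toy datasets of unequal length (4 and 5 examples).
d1 = TensorDataset(torch.arange(4).unsqueeze(-1))
d2 = TensorDataset(torch.arange(5).unsqueeze(-1))

zipped = MultiDataset([d1, d2], combination_mode="zip")
crossed = MultiDataset([d1, d2], combination_mode="cross_product")

print(len(zipped))   # 4  -> min(4, 5); a pruning warning is logged
print(len(crossed))  # 20 -> 4 * 5
print(zipped[2])     # [tensor([2]), tensor([2])] -- one element per dataset
print(crossed[5])    # [tensor([1]), tensor([1])] -> indices 5 % 4 and (5 // 4) % 5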
+ """ + if not self.len: + if self.combination_mode == "cross_product": + self.len = np.prod([len(d) for d in self.datasets]) + elif self.combination_mode == "zip": + ds_lens = [len(d) for d in self.datasets] + self.len = np.min(ds_lens) + if not np.all(ds_lens): + logging.warning("datasets do not have equal lengths and will be pruned to the shortest length.") + else: + raise ValueError("combination_mode unknown") + return self.len diff --git a/nemo/backends/pytorch/tutorials/toys.py b/nemo/backends/pytorch/tutorials/toys.py index 25fa8bd7c277..817658cd4518 100644 --- a/nemo/backends/pytorch/tutorials/toys.py +++ b/nemo/backends/pytorch/tutorials/toys.py @@ -158,10 +158,8 @@ def __init__(self, batch_size, f_name="sin", n=1000, x_lo=-4, x_hi=4): x_data = t.tensor(np.random.uniform(low=x_lo, high=x_hi, size=self._n)).unsqueeze(-1).to(self._device) y_data = func(x_data) - - self._data_iterator = t_utils.DataLoader( - t_utils.TensorDataset(x_data.float(), y_data.float()), batch_size=self._batch_size, - ) + self._dataset = t_utils.TensorDataset(x_data.float(), y_data.float()) + self._data_iterator = t_utils.DataLoader(self._dataset, batch_size=self._batch_size,) @property def data_iterator(self): @@ -169,7 +167,7 @@ def data_iterator(self): @property def dataset(self): - return None + return self._dataset class MSELoss(LossNM): diff --git a/tests/unclassified/test_unclassified_multidataset.py b/tests/unclassified/test_unclassified_multidataset.py new file mode 100644 index 000000000000..21fe9aa81bd4 --- /dev/null +++ b/tests/unclassified/test_unclassified_multidataset.py @@ -0,0 +1,143 @@ +# ! /usr/bin/python +# -*- coding: utf-8 -*- + +# Copyright 2020 NVIDIA. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unclassified/test_unclassified_multidataset.py b/tests/unclassified/test_unclassified_multidataset.py
new file mode 100644
index 000000000000..21fe9aa81bd4
--- /dev/null
+++ b/tests/unclassified/test_unclassified_multidataset.py
@@ -0,0 +1,143 @@
+# ! /usr/bin/python
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 NVIDIA. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+import os
+import shutil
+from unittest import TestCase
+
+import pytest
+import torch
+
+import nemo
+from nemo.core import ChannelType, LabelsType, MaskType, NeuralType
+
+logging = nemo.logging
+
+
+@pytest.mark.usefixtures("neural_factory")
+class TestMultiDL(TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+
+    @pytest.mark.unclassified
+    def test_port_name_collision_handling(self):
+        batch_size = 4
+        dataset_size = 4
+        shuffle = False
+        dl_1 = nemo.backends.pytorch.common.ZerosDataLayer(
+            size=dataset_size,
+            dtype=torch.FloatTensor,
+            batch_size=batch_size,
+            output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())},
+        )
+        dl_2 = nemo.backends.pytorch.common.ZerosDataLayer(
+            size=dataset_size,
+            dtype=torch.FloatTensor,
+            batch_size=batch_size,
+            output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "c": NeuralType(('B', 'T'), ChannelType())},
+        )
+
+        data_layer = nemo.backends.pytorch.common.MultiDataLayer(
+            data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="cross_product"
+        )
+        self.assertEqual([*data_layer.output_ports], ["a", "b", "a_1", "c"])
+        self.assertEqual(len(data_layer), dataset_size * dataset_size)
+
+    @pytest.mark.unclassified
+    def test_port_renaming(self):
+        batch_size = 4
+        dataset_size = 4
+        shuffle = False
+        dl_1 = nemo.backends.pytorch.common.ZerosDataLayer(
+            size=dataset_size,
+            dtype=torch.FloatTensor,
+            batch_size=batch_size,
+            output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())},
+        )
+        dl_2 = nemo.backends.pytorch.common.ZerosDataLayer(
+            size=dataset_size,
+            dtype=torch.FloatTensor,
+            batch_size=batch_size,
+            output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())},
+        )
+
+        data_layer = nemo.backends.pytorch.common.MultiDataLayer(
+            data_layers=[dl_1, dl_2],
+            batch_size=batch_size,
+            shuffle=shuffle,
+            combination_mode="cross_product",
+            port_names=["1", "2", "3", "4"],
+        )
+        self.assertEqual([*data_layer.output_ports], ["1", "2", "3", "4"])
+
+    @pytest.mark.unclassified
+    def test_multi_dl_zip(self):
+        batch_size = 4
+        dataset_size_0 = 4
+        dataset_size_1 = 5
+        shuffle = False
+        dl_1 = nemo.backends.pytorch.common.ZerosDataLayer(
+            size=dataset_size_0,
+            dtype=torch.FloatTensor,
+            batch_size=batch_size,
+            output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "b": NeuralType(('B', 'T'), ChannelType())},
+        )
+        dl_2 = nemo.backends.pytorch.common.ZerosDataLayer(
+            size=dataset_size_1,
+            dtype=torch.FloatTensor,
+            batch_size=batch_size,
+            output_ports={"a": NeuralType(('B', 'T'), ChannelType()), "c": NeuralType(('B', 'T'), ChannelType())},
+        )
+
+        data_layer = nemo.backends.pytorch.common.MultiDataLayer(
+            data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip"
+        )
+        self.assertEqual(len(data_layer), dataset_size_0)
+
+    @pytest.mark.unclassified
+    def test_pipeline(self):
+        batch_size = 4
+        dataset_size_0 = 100
+        dataset_size_1 = 100
+        shuffle = False
+        dl_1 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_0)
+        dl_2 = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(batch_size=batch_size, n=dataset_size_1)
+
+        data_layer = nemo.backends.pytorch.common.MultiDataLayer(
+            data_layers=[dl_1, dl_2], batch_size=batch_size, shuffle=shuffle, combination_mode="zip"
+        )
+        x_0, y_0, x_1, y_1 = data_layer()
+
+        trainable_module = nemo.backends.pytorch.tutorials.TaylorNet(dim=4)
+        loss = nemo.backends.pytorch.tutorials.MSELoss()
+        combined_loss = nemo.backends.pytorch.common.losses.LossAggregatorNM(num_inputs=2)
+        pred_0 = trainable_module(x=x_0)
+        pred_1 = trainable_module(x=x_1)
+        l_0 = loss(predictions=pred_0, target=y_0)
+        l_1 = loss(predictions=pred_1, target=y_1)
+        total_loss = combined_loss(loss_1=l_0, loss_2=l_1)
+
+        callback = nemo.core.SimpleLossLoggerCallback(
+            tensors=[total_loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'),
+        )
+        # Instantiate an optimizer to perform the `train` action
+        optimizer = nemo.backends.pytorch.actions.PtActions()
+        optimizer.train(
+            tensors_to_optimize=[total_loss],
+            optimizer="sgd",
+            optimization_params={"lr": 0.0003, "num_epochs": 1},
+            callbacks=[callback],
+        )
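
Note on the unpacking order in test_pipeline: data_layer() yields tensors in the merged port order, i.e. the first layer's ports followed by the second layer's ports, renamed on collision. With RealFunctionDataLayer's "x" and "y" ports, that order is presumably x, y, x_1, y_1, which is why exactly four values are unpacked.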