Skip to content

Commit

Permalink
multi-dataset (#538)
Browse files Browse the repository at this point in the history
* added multi-data and pytests

* fix toy real function data layer

* rebase on master and update changelog

Signed-off-by: Yang Zhang <[email protected]>
  • Loading branch information
yzhang123 authored Apr 10, 2020
1 parent 3749e57 commit cbf16b6
Show file tree
Hide file tree
Showing 6 changed files with 375 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ To release a new version, please update the changelog as followed:
## [Unreleased]

### Added
- Added multi-dataset data-layer and dataset.
([PR #538](https://github.com/NVIDIA/NeMo/pull/538)) - @yzhang123

### Changed

Expand Down
1 change: 1 addition & 0 deletions nemo/backends/pytorch/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from nemo.backends.pytorch.common.losses import *
from nemo.backends.pytorch.common.multi_data import *
from nemo.backends.pytorch.common.other import *
from nemo.backends.pytorch.common.parts import *
from nemo.backends.pytorch.common.rnn import *
Expand Down
134 changes: 134 additions & 0 deletions nemo/backends/pytorch/common/multi_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

from enum import Enum
from typing import List

import numpy as np
import torch

from nemo import logging
from nemo.backends.pytorch.nm import DataLayerNM
from nemo.core.neural_types import *

__all__ = ['MultiDataLayer', 'DataCombination']


class DataCombination(Enum):
    """Strategies for combining several datasets into one multi-dataset."""

    # Combined length is the product of the individual dataset lengths.
    CROSSPRODUCT = 1
    # Datasets are paired index-wise; all datasets must have the same length.
    ZIP = 2


class MultiDataLayer(DataLayerNM):
    """Data layer that wraps several data layers and exposes them as one.

    The datasets of the wrapped layers are combined into a single
    ``MultiDataset``; the output ports of this layer are the concatenation of
    the output ports of the wrapped layers, in the order the layers are given.
    """

    def __init__(
        self,
        data_layers: List[DataLayerNM],
        batch_size: int,
        shuffle: bool = False,
        combination_mode: DataCombination = DataCombination.CROSSPRODUCT,
        port_names: List[str] = None,
    ):
        """
        data_layers: (list) of DataLayerNM objects
        batch_size: (int) batchsize when the underlying dataset is loaded
        combination_mode: (DataCombination) defines how to combine the datasets.
        shuffle: (bool) whether underlying multi dataset should be shuffled in each epoch
        port_names: List(str) user can override all port names if specified

        Raises:
            ValueError: if ``port_names`` is given but has fewer entries than
                there are output ports, or if the names used are not unique.
        """
        super().__init__()
        self._data_layers = data_layers
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._combination_mode = combination_mode
        self._port_names = port_names
        self._dataset = MultiDataset(
            datasets=[dl.dataset for dl in self._data_layers], combination_mode=combination_mode
        )

        self._ports = dict()
        if self._port_names:
            # Flatten the port types of all wrapped layers in declaration order;
            # the user-supplied names are matched to them positionally.
            port_types = [port_type for dl in self._data_layers for port_type in dl.output_ports.values()]
            num_ports = len(port_types)
            if len(self._port_names) < num_ports:
                raise ValueError(
                    f"port_names has {len(self._port_names)} entries but the data layers "
                    f"expose {num_ports} output ports"
                )
            if len(self._port_names) > num_ports:
                # Surplus names were always ignored; keep that, but say so.
                logging.warning(
                    f"port_names has {len(self._port_names)} entries but only {num_ports} "
                    f"output ports exist, extra names are ignored"
                )
            used_names = list(self._port_names[:num_ports])
            if len(set(used_names)) != len(used_names):
                # Duplicates would silently overwrite each other and leave fewer
                # declared ports than the dataset actually emits.
                raise ValueError("port_names must be unique")
            for name, port_type in zip(used_names, port_types):
                self._ports[name] = port_type
        else:
            for dl_idx, dl in enumerate(self._data_layers):
                for port_name, port_type in dl.output_ports.items():
                    if port_name in self._ports:
                        # Disambiguate by appending the index of the owning layer.
                        logging.warning(f"name collision {port_name}, will rename")
                        self._ports[f"{port_name}_{dl_idx}"] = port_type
                    else:
                        self._ports[port_name] = port_type

    @property
    def output_ports(self):
        """Return: dict
        Returns union of all individual data_layer output ports
        In case of name collision, resolve by renaming
        """
        return self._ports

    def __len__(self):
        """Return the length of the combined dataset."""
        return len(self._dataset)

    @property
    def dataset(self):
        """Return the combined ``MultiDataset`` built from the wrapped layers."""
        return self._dataset

    @property
    def data_iterator(self):
        """Always ``None``: this layer exposes a dataset, not an iterator."""
        return None


class MultiDataset(torch.utils.data.Dataset):
    """Dataset that combines several datasets into a single one.

    Every item is the flat concatenation of one sample from each underlying
    dataset. Which samples are paired is controlled by ``combination_mode``.
    """

    def __init__(
        self,
        datasets: List[torch.utils.data.Dataset],
        combination_mode: DataCombination = DataCombination.CROSSPRODUCT,
    ):
        """
        Datasets: list of torch.utils.data.Dataset objects.
        combination_mode: DataCombination, defines how to combine the datasets,
            Options are [DataCombination.CROSSPRODUCT, DataCombination.ZIP].

        Raises:
            ValueError: if ``datasets`` is empty, if ``combination_mode`` is
                unknown, or if ZIP is requested for datasets of unequal length.
        """
        self.datasets = datasets
        self.combination_mode = combination_mode
        if len(self.datasets) == 0:
            raise ValueError("datasets must not be empty")
        if self.combination_mode == DataCombination.CROSSPRODUCT:
            # Plain Python integers: no silent int64 overflow for large
            # products and always a valid return value for ``__len__``.
            total = 1
            for d in self.datasets:
                total *= len(d)
            self.len = total
        elif self.combination_mode == DataCombination.ZIP:
            ds_lens = [len(d) for d in self.datasets]
            if len(set(ds_lens)) != 1:
                raise ValueError("datasets do not have equal lengths.")
            self.len = ds_lens[0]
        else:
            raise ValueError("combination_mode unknown")

    def __getitem__(self, i):
        """
        Returns list [x1, x2, ...xn] where x1 in D1, x2 in D2, ..., xn in Dn.

        Each underlying dataset item is expected to be iterable; its elements
        are appended to the returned flat list.
        """
        if self.combination_mode == DataCombination.CROSSPRODUCT:
            # Decompose ``i`` as a mixed-radix number with one digit per dataset
            # (first dataset varies fastest). Using ``i % len(d)`` for every
            # dataset would visit only a subset of the pairs whenever the
            # dataset lengths share a common factor.
            sample = []
            remainder = i
            for d in self.datasets:
                remainder, idx = divmod(remainder, len(d))
                sample.extend(d[idx])
            return sample

        # ZIP: every dataset is read at the same (wrapped) position.
        return [x for d in self.datasets for x in d[i % len(d)]]

    def __len__(self):
        """
        Returns length of this dataset (int).
        In case of DataCombination.CROSSPRODUCT this would be prod(len(d) for d in self.datasets).
        In case of DataCombination.ZIP this would be min(len(d) for d in self.datasets) given that all datasets have same length.
        """
        return self.len
11 changes: 4 additions & 7 deletions nemo/backends/pytorch/tutorials/toys.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,19 @@ def __init__(self, batch_size, f_name="sin", n=1000, x_lo=-4, x_hi=4):

self._n = n
self._batch_size = batch_size
self._device = t.device("cuda" if self.placement == DeviceType.GPU else "cpu")

x_data = t.tensor(np.random.uniform(low=x_lo, high=x_hi, size=self._n)).unsqueeze(-1).to(self._device)
x_data = t.tensor(np.random.uniform(low=x_lo, high=x_hi, size=self._n)).unsqueeze(-1)
y_data = func(x_data)

self._data_iterator = t_utils.DataLoader(
t_utils.TensorDataset(x_data.float(), y_data.float()), batch_size=self._batch_size,
)
self._dataset = t_utils.TensorDataset(x_data.float(), y_data.float())
self._data_iterator = t_utils.DataLoader(self._dataset, batch_size=self._batch_size,)

@property
def data_iterator(self):
return self._data_iterator

@property
def dataset(self):
return None
return self._dataset


class MSELoss(LossNM):
Expand Down
69 changes: 69 additions & 0 deletions tests/integration/test_integration_multidataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# ! /usr/bin/python
# -*- coding: utf-8 -*-

# Copyright 2020 NVIDIA. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

import os
import shutil
from unittest import TestCase

import pytest
import torch

import nemo
from nemo.backends.pytorch.common import DataCombination
from nemo.core import ChannelType, NeuralType

logging = nemo.logging


@pytest.mark.usefixtures("neural_factory")
class TestMultiDLIntegration(TestCase):
    """Integration test: train a toy model on two zipped toy data layers."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

    @pytest.mark.integration
    def test_pipeline(self):
        """Build a two-branch graph on a MultiDataLayer and run two SGD steps."""
        batch_size = 4
        dataset_size_0 = 100
        dataset_size_1 = 100
        shuffle = False

        # Two independent toy data layers of equal size, as required by ZIP.
        first_layer = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(
            batch_size=batch_size, n=dataset_size_0
        )
        second_layer = nemo.backends.pytorch.tutorials.RealFunctionDataLayer(
            batch_size=batch_size, n=dataset_size_1
        )

        multi_layer = nemo.backends.pytorch.common.MultiDataLayer(
            data_layers=[first_layer, second_layer],
            batch_size=batch_size,
            shuffle=shuffle,
            combination_mode=DataCombination.ZIP,
        )
        x_0, y_0, x_1, y_1 = multi_layer()

        # One shared model and loss module, applied to each input branch.
        model = nemo.backends.pytorch.tutorials.TaylorNet(dim=4)
        mse = nemo.backends.pytorch.tutorials.MSELoss()
        aggregator = nemo.backends.pytorch.common.losses.LossAggregatorNM(num_inputs=2)
        pred_0 = model(x=x_0)
        pred_1 = model(x=x_1)
        loss_0 = mse(predictions=pred_0, target=y_0)
        loss_1 = mse(predictions=pred_1, target=y_1)
        total_loss = aggregator(loss_1=loss_0, loss_2=loss_1)

        callback = nemo.core.SimpleLossLoggerCallback(
            tensors=[total_loss], print_func=lambda x: logging.info(f'Train Loss: {str(x[0].item())}'),
        )
        # Instantiate an optimizer to perform `train` action
        trainer = nemo.backends.pytorch.actions.PtActions()
        trainer.train(
            tensors_to_optimize=[total_loss],
            optimizer="sgd",
            optimization_params={"lr": 0.0003, "max_steps": 2},
        )
Loading

0 comments on commit cbf16b6

Please sign in to comment.