Skip to content

Commit

Permalink
Merge pull request #254 from tofunlp/feature/return-slice-iterabledat…
Browse files Browse the repository at this point in the history
…aset

A subset (slice) of an IterableDataset should be a Dataset, not a list.
  • Loading branch information
yasufumy authored Oct 31, 2021
2 parents 5db77b3 + bfe6f55 commit 6da749d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
6 changes: 6 additions & 0 deletions lineflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,12 @@ def get_example(self, i: int) -> Any:
def __len__(self) -> int:
    """Return the length reported by the parent class's ``__len__``."""
    length = super(IterableDataset, self).__len__()
    return length

def __getitem__(self, index: Union[int, slice]) -> Union[Any, "Dataset"]:
    """Return the item(s) at ``index``, wrapping list results in ``Dataset``.

    The lookup itself is delegated to the parent class. When the parent
    hands back a plain ``list`` (presumably the slice case -- confirm
    against the parent implementation), that list is wrapped in
    ``Dataset`` so callers receive a dataset object rather than a bare
    list. Any other result is returned untouched.

    Args:
        index: Position of a single example, or a slice selecting several.

    Returns:
        The parent's result, except that a ``list`` is returned as a
        ``Dataset`` built from that list.
    """
    out = super(IterableDataset, self).__getitem__(index)
    if isinstance(out, list):
        # Never leak a raw list: subsets must keep the dataset interface.
        return Dataset(out)
    return out


class ConcatDataset(Dataset):
def __init__(self, *datasets: List[DatasetMixin]) -> None:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ def test_dunder_len(self):
self.assertEqual(self.data._length, len(self.base))
self.assertTrue(self.data._computed)

def test_dunder_slice(self):
    """Check the result types of slice access and integer access."""
    # Subset of IterableDataset should be Dataset.
    self.assertIsInstance(self.data[:10], Dataset)

    # An integer index must still yield the bare sample itself.
    self.assertIsInstance(self.data[0], int)


class DatasetTestCase(TestCase):

Expand Down

0 comments on commit 6da749d

Please sign in to comment.