Skip to content

Commit

Permalink
Merge pull request #66 from boeddeker/master
Browse files Browse the repository at this point in the history
Fix `items` bug for datasets that don't support `items`
  • Loading branch information
boeddeker authored Dec 3, 2024
2 parents 357b4c6 + 7d7e8d8 commit 0c32ff7
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand Down
4 changes: 3 additions & 1 deletion lazy_dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import core
from .core import (
new,
concatenate,
Expand All @@ -7,6 +8,7 @@
from_dict,
from_list,
from_dataset,
from_file,
FilterException,
)
from.core import _zip as zip
from .core import _zip as zip
34 changes: 29 additions & 5 deletions lazy_dataset/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,35 @@ def from_dataset(
>>> ds = from_dataset(new({'a': 1, 'b': 2, 'c': 3, 'd': 4}).filter(lambda x: x%2))
>>> dict(ds)
{'a': 1, 'c': 3}
# Works with concatenated datasets and duplicated keys
>>> ds = new({'a': 1, 'b': 2})
>>> ds = concatenate(ds, ds)
>>> ds
DictDataset(len=2)
MapDataset(_pickle.loads)
DictDataset(len=2)
MapDataset(_pickle.loads)
ConcatenateDataset()
>>> from_dataset(ds)
ListDataset(len=4)
MapDataset(_pickle.loads)
"""
try:
items = list(examples.items())
except ItemsNotDefined:
return from_list(list(examples),
immutable_warranty=immutable_warranty, name=name)
else:
return from_dict(dict(items),
immutable_warranty=immutable_warranty, name=name)
new = dict(items)
if len(new) == len(items):
return from_dict(new,
immutable_warranty=immutable_warranty, name=name)
else:
# Duplicates in keys
return from_list(list(map(operator.itemgetter(1), items)),
immutable_warranty=immutable_warranty, name=name)


def concatenate(*datasets):
Expand Down Expand Up @@ -417,7 +437,10 @@ def copy(self, freeze: bool = False) -> 'Dataset':
Returns:
A copy of this dataset
"""
raise NotImplementedError
raise NotImplementedError(
f'copy is not implemented for {self.__class__}.\n'
f'self: \n{repr(self)}'
)

def __iter__(self, with_key=False):
if with_key:
Expand Down Expand Up @@ -2973,6 +2996,7 @@ def __init__(self, *input_datasets):
]
raise AssertionError(
f'Expect that all input_datasets have the same keys. '
f'Missing: {lengths} of {len(keys)}\n'
f'Missing keys: '
f'{missing_keys}\n{self.input_datasets}'
)
Expand Down Expand Up @@ -3067,8 +3091,8 @@ class ItemsDataset(Dataset):
>>> ds_nokeys_rng = ds_plain.shuffle(True, rng=np.random.RandomState(0)) # No keys
>>> list(ds_nokeys.map(lambda x: x + 10).items())
[('a', 11), ('b', 12), ('c', 13)]
>>> list(ds_nokeys.concatenate(ds_plain).items())
[('a', 1), ('b', 2), ('c', 3), ('a', 1), ('b', 2), ('c', 3)]
>>> list(ds_nokeys.map(lambda x: x + 10).concatenate(ds_plain).filter(lambda x: x in [1, 12, 13]).items())
[('b', 12), ('c', 13), ('a', 1)]
>>> list(ds_nokeys_rng.intersperse(ds_nokeys_rng).items())
[('c', 3), ('a', 1), ('c', 3), ('c', 3), ('b', 2), ('b', 2)]
>>> list(ds_plain.key_zip(ds_plain).items())
Expand Down

0 comments on commit 0c32ff7

Please sign in to comment.