diff --git a/.github/workflows/run_python_tests.yml b/.github/workflows/run_python_tests.yml
index 5f985ee..e8239cc 100644
--- a/.github/workflows/run_python_tests.yml
+++ b/.github/workflows/run_python_tests.yml
@@ -15,7 +15,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.7", "3.8", "3.9", "3.10"]
+        python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
 
     steps:
     - uses: actions/checkout@v2
diff --git a/lazy_dataset/__init__.py b/lazy_dataset/__init__.py
index d7fecab..a0ca21b 100644
--- a/lazy_dataset/__init__.py
+++ b/lazy_dataset/__init__.py
@@ -1,3 +1,4 @@
+from . import core
 from .core import (
     new,
     concatenate,
@@ -7,6 +8,7 @@
     from_dict,
     from_list,
     from_dataset,
+    from_file,
     FilterException,
 )
-from.core import _zip as zip
+from .core import _zip as zip
diff --git a/lazy_dataset/core.py b/lazy_dataset/core.py
index b9ab2d2..549159a 100644
--- a/lazy_dataset/core.py
+++ b/lazy_dataset/core.py
@@ -201,6 +201,20 @@ def from_dataset(
     >>> ds = from_dataset(new({'a': 1, 'b': 2, 'c': 3, 'd': 4}).filter(lambda x: x%2))
     >>> dict(ds)
     {'a': 1, 'c': 3}
+
+    # Works with concatenated datasets and duplicated keys
+    >>> ds = new({'a': 1, 'b': 2})
+    >>> ds = concatenate(ds, ds)
+    >>> ds
+        DictDataset(len=2)
+      MapDataset(_pickle.loads)
+        DictDataset(len=2)
+      MapDataset(_pickle.loads)
+    ConcatenateDataset()
+    >>> from_dataset(ds)
+      ListDataset(len=4)
+    MapDataset(_pickle.loads)
+
     """
     try:
         items = list(examples.items())
@@ -208,8 +222,14 @@
         return from_list(list(examples),
                          immutable_warranty=immutable_warranty, name=name)
     else:
-        return from_dict(dict(items),
-                         immutable_warranty=immutable_warranty, name=name)
+        new = dict(items)
+        if len(new) == len(items):
+            return from_dict(new,
+                             immutable_warranty=immutable_warranty, name=name)
+        else:
+            # Duplicates in keys
+            return from_list(list(map(operator.itemgetter(1), items)),
+                             immutable_warranty=immutable_warranty, name=name)
 
 
 def concatenate(*datasets):
@@ -417,7 +437,10 @@ def copy(self, freeze: bool = False) -> 'Dataset':
         Returns:
             A copy of this dataset
         """
-        raise NotImplementedError
+        raise NotImplementedError(
+            f'copy is not implemented for {self.__class__}.\n'
+            f'self: \n{repr(self)}'
+        )
 
     def __iter__(self, with_key=False):
         if with_key:
@@ -2973,6 +2996,7 @@ def __init__(self, *input_datasets):
                 ]
                 raise AssertionError(
                     f'Expect that all input_datasets have the same keys. '
+                    f'Missing: {lengths} of {len(keys)}\n'
                     f'Missing keys: '
                     f'{missing_keys}\n{self.input_datasets}'
                 )
@@ -3067,8 +3091,8 @@ class ItemsDataset(Dataset):
     >>> ds_nokeys_rng = ds_plain.shuffle(True, rng=np.random.RandomState(0))  # No keys
     >>> list(ds_nokeys.map(lambda x: x + 10).items())
     [('a', 11), ('b', 12), ('c', 13)]
-    >>> list(ds_nokeys.concatenate(ds_plain).items())
-    [('a', 1), ('b', 2), ('c', 3), ('a', 1), ('b', 2), ('c', 3)]
+    >>> list(ds_nokeys.map(lambda x: x + 10).concatenate(ds_plain).filter(lambda x: x in [1, 12, 13]).items())
+    [('b', 12), ('c', 13), ('a', 1)]
     >>> list(ds_nokeys_rng.intersperse(ds_nokeys_rng).items())
     [('c', 3), ('a', 1), ('c', 3), ('c', 3), ('b', 2), ('b', 2)]
     >>> list(ds_plain.key_zip(ds_plain).items())
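
The core change above is the duplicate-key fallback in `from_dataset`: keyed datasets round-trip through `dict` unless keys collide (e.g. after `concatenate(ds, ds)`), in which case only the values survive, as an unkeyed list. The dispatch can be sketched standalone; this is a minimal illustration, not the library code, and `keyed_or_unkeyed` is a hypothetical helper name:

    import operator

    def keyed_or_unkeyed(items):
        # items: list of (key, value) pairs, e.g. from dataset.items().
        as_dict = dict(items)
        if len(as_dict) == len(items):
            # All keys unique: keep the key -> example mapping
            # (corresponds to the from_dict branch in the patch).
            return as_dict
        # Duplicated keys: dict() silently dropped entries, so fall back
        # to a plain list of the values (the from_list branch).
        return list(map(operator.itemgetter(1), items))

    assert keyed_or_unkeyed([('a', 1), ('b', 2)]) == {'a': 1, 'b': 2}
    assert keyed_or_unkeyed([('a', 1), ('a', 2), ('a', 3)]) == [1, 2, 3]

Comparing `len(dict(items))` against `len(items)` detects collisions in one pass, which is why the patch builds the dict first and only then decides between `from_dict` and `from_list`.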