Skip to content

Commit

Permalink
fix: Include dtype information in the Partitioning configuration.
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Zilio committed Nov 20, 2024
1 parent 7faa803 commit f2cc109
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 24 deletions.
13 changes: 7 additions & 6 deletions zcollection/merging/tests/test_merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def test_update_fs(
"""Test the _update_fs function."""
generator = data.create_test_dataset(delayed=False)
zds = next(generator)
zds_sc = dask_client.scatter(zds)

partition_folder = local_fs.root.joinpath('variable=1')

zattrs = str(partition_folder.joinpath('.zattrs'))
future = dask_client.submit(_update_fs, str(partition_folder),
dask_client.scatter(zds), local_fs.fs)
future = dask_client.submit(_update_fs, str(partition_folder), zds_sc,
local_fs.fs)
dask_client.gather(future)
assert local_fs.exists(zattrs)

Expand All @@ -60,7 +61,7 @@ def test_update_fs(
try:
future = dask_client.submit(_update_fs,
str(partition_folder),
dask_client.scatter(zds),
zds_sc,
local_fs.fs,
synchronizer=ThrowError())
dask_client.gather(future)
Expand All @@ -83,13 +84,13 @@ def test_perform(
zds = next(generator)

path = str(local_fs.root.joinpath('variable=1'))
zds_sc = dask_client.scatter(zds)

future = dask_client.submit(_update_fs, path, dask_client.scatter(zds),
local_fs.fs)
future = dask_client.submit(_update_fs, path, zds_sc, local_fs.fs)
dask_client.gather(future)

future = dask_client.submit(perform,
dask_client.scatter(zds),
zds_sc,
path,
'time',
local_fs.fs,
Expand Down
4 changes: 2 additions & 2 deletions zcollection/partitioning/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def unique_and_check_monotony(arr: ArrayLike) -> tuple[NDArray, NDArray]:
Args:
arr: Array of elements.
is_delayed: If True, the array is delayed.
Returns:
Tuple of unique elements and their indices.
"""
Expand Down Expand Up @@ -331,12 +330,13 @@ def get_config(self) -> dict[str, Any]:
Returns:
The configuration of the partitioning scheme.
"""
config: dict[str, str | None] = {'id': self.ID}
config: dict[str, str | tuple[str, ...] | None] = {'id': self.ID}
slots: Generator[tuple[str, ...]] = (getattr(
_class, '__slots__',
()) for _class in reversed(self.__class__.__mro__))
config.update((attr, getattr(self, attr)) for _class in slots
for attr in _class if not attr.startswith('_'))
config['dtype'] = self._dtype
return config

@classmethod
Expand Down
7 changes: 7 additions & 0 deletions zcollection/partitioning/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,10 @@ def decode(
py_datetime: datetime.datetime = datetime64.astype('M8[s]').item()
return tuple((UNITS[ix], getattr(py_datetime, self._attrs[ix]))
for ix in self._index)

def get_config(self) -> dict[str, Any]:
config = super().get_config()

# dtype are automatically computed by this partitioning
config.pop('dtype')
return config
38 changes: 29 additions & 9 deletions zcollection/partitioning/tests/test_date.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,43 @@ def test_construction() -> None:
Date(('dates', ), 'W')


def test_config():
RESOLUTION_DTYPE_TEST_SET = [
('Y', (('year', 'uint16'), )),
('M', (('year', 'uint16'), ('month', 'uint8'))),
('D', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'))),
('h', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'),
('hour', 'uint8'))),
('m', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'),
('hour', 'uint8'), ('minute', 'uint8'))),
('s', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'),
('hour', 'uint8'), ('minute', 'uint8'), ('second', 'uint8')))
]


@pytest.mark.parametrize('resolution, dtype', RESOLUTION_DTYPE_TEST_SET)
def test_config(resolution, dtype):
"""Test the configuration of the Date class."""
partitioning = Date(('dates', ), 'D')
assert partitioning.dtype() == (('year', 'uint16'), ('month', 'uint8'),
('day', 'uint8'))
partitioning = Date(variables=('dates', ), resolution=resolution)
assert partitioning.dtype() == dtype

config = partitioning.get_config()
partitioning = get_codecs(config)
assert isinstance(partitioning, Date)
other = get_codecs(config)

assert isinstance(other, Date)
assert other.variables == ('dates', )
assert other.dtype() == dtype


def test_pickle():
@pytest.mark.parametrize('resolution, dtype', RESOLUTION_DTYPE_TEST_SET)
def test_pickle(resolution, dtype):
"""Test the pickling of the Date class."""
partitioning = Date(('dates', ), 'D')
partitioning = Date(('dates', ), resolution=resolution)
other = pickle.loads(pickle.dumps(partitioning))

assert isinstance(other, Date)
assert other.resolution == 'D'
assert other.resolution == resolution
assert other.variables == ('dates', )
assert other.dtype() == dtype


@pytest.mark.parametrize('delayed', [False, True])
Expand Down
26 changes: 19 additions & 7 deletions zcollection/partitioning/tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,32 @@ def test_split_dataset(
list(partitioning.split_dataset(zds, 'num_lines'))


def test_config() -> None:
VARIABLES_DTYPE_TEST_SET = [(('a', ), None), (('a', ), ('uint8', )),
(('a', 'b'), None),
(('a', 'b'), ('int8', 'int16'))]


@pytest.mark.parametrize('variables, dtype', VARIABLES_DTYPE_TEST_SET)
def test_config(variables, dtype) -> None:
"""Test the configuration of the Sequence class."""
partitioning = Sequence(('cycle_number', 'pass_number'))
partitioning = Sequence(variables=variables, dtype=dtype)

config = partitioning.get_config()
partitioning = get_codecs(config) # type: ignore[assignment]
assert isinstance(partitioning, Sequence)
other = get_codecs(config) # type: ignore[assignment]

assert isinstance(other, Sequence)
assert other.dtype() == partitioning.dtype()

def test_pickle() -> None:

@pytest.mark.parametrize('variables, dtype', VARIABLES_DTYPE_TEST_SET)
def test_pickle(variables, dtype) -> None:
"""Test the pickling of the Date class."""
partitioning = Sequence(('cycle_number', 'pass_number'))
partitioning = Sequence(variables=variables, dtype=dtype)

other = pickle.loads(pickle.dumps(partitioning))

assert isinstance(other, Sequence)
assert other.variables == ('cycle_number', 'pass_number')
assert other.dtype() == partitioning.dtype()


# pylint: disable=protected-access
Expand Down

0 comments on commit f2cc109

Please sign in to comment.