Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use array_api compliant dtype #8933

Draft
wants to merge 52 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
8e23089
Update _typing.py
Illviljan Apr 12, 2024
7f50c5a
Update _typing.py
Illviljan Apr 12, 2024
6636854
Update core.py
Illviljan Apr 12, 2024
bff3437
Use strict dtype, let's see how much breaks
Illviljan Apr 12, 2024
7028a53
Update _typing.py
Illviljan Apr 12, 2024
8a4f62c
Update _typing.py
Illviljan Apr 12, 2024
adda1ba
Update _typing.py
Illviljan Apr 12, 2024
cd095f2
Update _typing.py
Illviljan Apr 12, 2024
519702c
Update _typing.py
Illviljan Apr 12, 2024
6c64a4a
Update pycompat.py
Illviljan Apr 12, 2024
65f1c78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2024
64ef10b
Type hint run on parallelcompat
Illviljan Apr 14, 2024
937723f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
e83da1f
more
Illviljan Apr 14, 2024
5c4ecab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
c2fb035
Update daskmanager.py
Illviljan Apr 14, 2024
6b525ae
Update parallelcompat.py
Illviljan Apr 14, 2024
587a73a
Update daskmanager.py
Illviljan Apr 14, 2024
f25581a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2024
f551679
Merge branch 'main' into namedarray_dtype_type
Illviljan Apr 30, 2024
5492336
test
Illviljan May 15, 2024
4771672
Merge branch 'main' into namedarray_dtype_type
Illviljan May 15, 2024
fb7ccfa
Update pycompat.py
Illviljan May 15, 2024
edab4ba
Merge branch 'main' into namedarray_dtype_type
Illviljan Jun 5, 2024
9140f36
Update daskmanager.py
Illviljan Jun 7, 2024
d1ed614
Merge branch 'main' into namedarray_dtype_type
Illviljan Jul 9, 2024
bbfbca3
Update pycompat.py
Illviljan Jul 9, 2024
e57d85d
Update pycompat.py
Illviljan Jul 9, 2024
d601942
Update pycompat.py
Illviljan Jul 9, 2024
b8da922
Update pycompat.py
Illviljan Jul 9, 2024
7222ae4
Let's try undoing
Illviljan Jul 9, 2024
dc784eb
Update _typing.py
Illviljan Jul 9, 2024
ce879ca
Update _typing.py
Illviljan Jul 9, 2024
6cc75cc
Update _typing.py
Illviljan Jul 9, 2024
c2ee93c
chunkmanager doesn't use generic anymore
Illviljan Jul 9, 2024
21f9049
Update core.py
Illviljan Jul 9, 2024
d2e360d
fix
Illviljan Jul 9, 2024
46b0218
Update times.py
Illviljan Jul 9, 2024
c17adf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2024
17d49e5
use duck_array_ops.ravel
Illviljan Jul 9, 2024
7a659f2
fix
Illviljan Jul 9, 2024
7ddef25
Update dataset.py
Illviljan Jul 9, 2024
121774d
Update test_parallelcompat.py
Illviljan Jul 9, 2024
9c619a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2024
36831ce
Update test_parallelcompat.py
Illviljan Jul 9, 2024
0ec9be7
Update times.py
Illviljan Jul 9, 2024
1e35da3
Update _typing.py
Illviljan Jul 9, 2024
6930c22
Update test_coding_times.py
Illviljan Jul 9, 2024
96dde9f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2024
229c183
Merge branch 'main' into namedarray_dtype_type
Illviljan Jul 10, 2024
6f8a29c
use duckarray assertions instead
Illviljan Jul 10, 2024
f2903f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 64 additions & 30 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Hashable
from datetime import datetime, timedelta
from functools import partial
from typing import TYPE_CHECKING, Callable, Union
from typing import TYPE_CHECKING, Callable, Union, overload

import numpy as np
import pandas as pd
Expand All @@ -22,13 +22,19 @@
)
from xarray.core import indexing
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
from xarray.core.duck_array_ops import asarray
from xarray.core.duck_array_ops import asarray, ravel
from xarray.core.formatting import first_n_items, format_timestamp, last_item
from xarray.core.pdcompat import nanosecond_precision_timestamp
from xarray.core.utils import emit_user_level_warning
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import T_ChunkedArray, get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
from xarray.namedarray._typing import (
_chunkedarrayfunction_or_api,
chunkedduckarray,
duckarray,
)
from xarray.namedarray.parallelcompat import get_chunked_array_type

# from xarray.namedarray.pycompat import is_chunked_array
from xarray.namedarray.utils import is_duck_dask_array

try:
Expand All @@ -37,7 +43,7 @@
cftime = None

if TYPE_CHECKING:
from xarray.core.types import CFCalendar, T_DuckArray
from xarray.core.types import CFCalendar

T_Name = Union[Hashable, None]

Expand Down Expand Up @@ -315,7 +321,7 @@ def decode_cf_datetime(
cftime.num2date
"""
num_dates = np.asarray(num_dates)
flat_num_dates = num_dates.ravel()
flat_num_dates = ravel(num_dates)
if calendar is None:
calendar = "standard"

Expand Down Expand Up @@ -369,7 +375,7 @@ def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray:
"""
num_timedeltas = np.asarray(num_timedeltas)
units = _netcdf_to_numpy_timeunit(units)
result = to_timedelta_unboxed(num_timedeltas.ravel(), unit=units)
result = to_timedelta_unboxed(ravel(num_timedeltas), unit=units)
return result.reshape(num_timedeltas.shape)


Expand Down Expand Up @@ -428,7 +434,7 @@ def infer_datetime_units(dates) -> str:
'hours', 'minutes' or 'seconds' (the first one that can evenly divide all
unique time deltas in `dates`)
"""
dates = np.asarray(dates).ravel()
dates = ravel(np.asarray(dates))
if np.asarray(dates).dtype == "datetime64[ns]":
dates = to_datetime_unboxed(dates)
dates = dates[pd.notnull(dates)]
Expand Down Expand Up @@ -456,7 +462,7 @@ def infer_timedelta_units(deltas) -> str:
{'days', 'hours', 'minutes', 'seconds'} (the first one that can evenly
divide all unique time deltas in `deltas`)
"""
deltas = to_timedelta_unboxed(np.asarray(deltas).ravel())
deltas = to_timedelta_unboxed(ravel(np.asarray(deltas)))
unique_timedeltas = np.unique(deltas[pd.notnull(deltas)])
return _infer_time_units_from_diff(unique_timedeltas)

Expand Down Expand Up @@ -643,7 +649,7 @@ def encode_datetime(d):
except TypeError:
return np.nan if d is None else cftime.date2num(d, units, calendar)

return np.array([encode_datetime(d) for d in dates.ravel()]).reshape(dates.shape)
return np.array([encode_datetime(d) for d in ravel(dates)]).reshape(dates.shape)


def cast_to_int_if_safe(num) -> np.ndarray:
Expand Down Expand Up @@ -700,12 +706,26 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
return cast_num


@overload
def encode_cf_datetime(
dates: chunkedduckarray,
units: str | None = None,
calendar: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[chunkedduckarray, str, str]: ...
@overload
def encode_cf_datetime(
dates: duckarray,
units: str | None = None,
calendar: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[duckarray, str, str]: ...
def encode_cf_datetime(
dates: T_DuckArray, # type: ignore
dates: duckarray | chunkedduckarray,
units: str | None = None,
calendar: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[T_DuckArray, str, str]:
) -> tuple[duckarray | chunkedduckarray, str, str]:
"""Given an array of datetime objects, returns the tuple `(num, units,
calendar)` suitable for a CF compliant time variable.

Expand All @@ -716,19 +736,19 @@ def encode_cf_datetime(
cftime.date2num
"""
dates = asarray(dates)
if is_chunked_array(dates):
if isinstance(dates, _chunkedarrayfunction_or_api):
return _lazily_encode_cf_datetime(dates, units, calendar, dtype)
else:
return _eagerly_encode_cf_datetime(dates, units, calendar, dtype)


def _eagerly_encode_cf_datetime(
dates: T_DuckArray, # type: ignore
dates: duckarray,
units: str | None = None,
calendar: str | None = None,
dtype: np.dtype | None = None,
allow_units_modification: bool = True,
) -> tuple[T_DuckArray, str, str]:
) -> tuple[duckarray, str, str]:
dates = asarray(dates)

data_units = infer_datetime_units(dates)
Expand All @@ -753,7 +773,7 @@ def _eagerly_encode_cf_datetime(
# Wrap the dates in a DatetimeIndex to do the subtraction to ensure
# an OverflowError is raised if the ref_date is too far away from
# dates to be encoded (GH 2272).
dates_as_index = pd.DatetimeIndex(dates.ravel())
dates_as_index = pd.DatetimeIndex(ravel(dates))
time_deltas = dates_as_index - ref_date

# retrieve needed units to faithfully encode to int64
Expand Down Expand Up @@ -806,23 +826,23 @@ def _eagerly_encode_cf_datetime(


def _encode_cf_datetime_within_map_blocks(
dates: T_DuckArray, # type: ignore
dates: duckarray,
units: str,
calendar: str,
dtype: np.dtype,
) -> T_DuckArray:
) -> duckarray:
num, *_ = _eagerly_encode_cf_datetime(
dates, units, calendar, dtype, allow_units_modification=False
)
return num


def _lazily_encode_cf_datetime(
dates: T_ChunkedArray,
dates: chunkedduckarray,
units: str | None = None,
calendar: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[T_ChunkedArray, str, str]:
) -> tuple[chunkedduckarray, str, str]:
if calendar is None:
# This will only trigger minor compute if dates is an object dtype array.
calendar = infer_calendar_name(dates)
Expand Down Expand Up @@ -855,31 +875,43 @@ def _lazily_encode_cf_datetime(
return num, units, calendar


@overload
def encode_cf_timedelta(
timedeltas: T_DuckArray, # type: ignore
timedeltas: chunkedduckarray,
units: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[T_DuckArray, str]:
) -> tuple[chunkedduckarray, str]: ...
@overload
def encode_cf_timedelta(
timedeltas: duckarray,
units: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[duckarray, str]: ...
def encode_cf_timedelta(
timedeltas: chunkedduckarray | duckarray,
units: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[chunkedduckarray | duckarray, str]:
timedeltas = asarray(timedeltas)
if is_chunked_array(timedeltas):
if isinstance(timedeltas, _chunkedarrayfunction_or_api):
return _lazily_encode_cf_timedelta(timedeltas, units, dtype)
else:
return _eagerly_encode_cf_timedelta(timedeltas, units, dtype)


def _eagerly_encode_cf_timedelta(
timedeltas: T_DuckArray, # type: ignore
timedeltas: duckarray,
units: str | None = None,
dtype: np.dtype | None = None,
allow_units_modification: bool = True,
) -> tuple[T_DuckArray, str]:
) -> tuple[duckarray, str]:
data_units = infer_timedelta_units(timedeltas)

if units is None:
units = data_units

time_delta = _time_units_to_timedelta64(units)
time_deltas = pd.TimedeltaIndex(timedeltas.ravel())
time_deltas = pd.TimedeltaIndex(ravel(timedeltas))

# retrieve needed units to faithfully encode to int64
needed_units = data_units
Expand Down Expand Up @@ -920,19 +952,21 @@ def _eagerly_encode_cf_timedelta(


def _encode_cf_timedelta_within_map_blocks(
timedeltas: T_DuckArray, # type:ignore
timedeltas: duckarray,
units: str,
dtype: np.dtype,
) -> T_DuckArray:
) -> duckarray:
num, _ = _eagerly_encode_cf_timedelta(
timedeltas, units, dtype, allow_units_modification=False
)
return num


def _lazily_encode_cf_timedelta(
timedeltas: T_ChunkedArray, units: str | None = None, dtype: np.dtype | None = None
) -> tuple[T_ChunkedArray, str]:
timedeltas: chunkedduckarray,
units: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[chunkedduckarray, str]:
if units is None and dtype is None:
units = "nanoseconds"
dtype = np.dtype("int64")
Expand Down
2 changes: 1 addition & 1 deletion xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike):
if is_chunked_array(array):
chunkmanager = get_chunked_array_type(array)

return chunkmanager.map_blocks(func, array, dtype=dtype) # type: ignore[arg-type]
return chunkmanager.map_blocks(func, array, dtype=dtype)
else:
return _ElementwiseFunctionArray(array, func, dtype)

Expand Down
2 changes: 2 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from xarray.core.types import Dims, T_DataArray
from xarray.core.utils import is_dict_like, is_duck_dask_array, is_scalar, parse_dims
from xarray.core.variable import Variable
from xarray.namedarray._typing import chunkedduckarray
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
from xarray.util.deprecation_helpers import deprecate_dims
Expand Down Expand Up @@ -795,6 +796,7 @@ def apply_variable_ufunc(
)

def func(*arrays):
res: chunkedduckarray | tuple[chunkedduckarray, ...]
res = chunkmanager.apply_gufunc(
numpy_func,
signature.to_gufunc_string(exclude_dims),
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@
T_Xarray,
)
from xarray.core.weighted import DatasetWeighted
from xarray.namedarray._typing import duckarray
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint


Expand Down Expand Up @@ -860,7 +861,7 @@ def load(self, **kwargs) -> Self:
chunkmanager = get_chunked_array_type(*lazy_data.values())

# evaluate all the chunked arrays simultaneously
evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
evaluated_data: tuple[duckarray[Any, Any], ...] = chunkmanager.compute(
*lazy_data.values(), **kwargs
)

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ def _arrayize_vectorized_indexer(


def _chunked_array_with_chunks_hint(
array, chunks, chunkmanager: ChunkManagerEntrypoint[Any]
array, chunks, chunkmanager: ChunkManagerEntrypoint
):
"""Create a chunked array using the chunks hint for dimensions of size > 1."""

Expand Down
2 changes: 1 addition & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2526,7 +2526,7 @@ def chunk( # type: ignore[override]
name: str | None = None,
lock: bool | None = None,
inline_array: bool | None = None,
chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None,
chunked_array_type: str | ChunkManagerEntrypoint | None = None,
from_array_kwargs: Any = None,
**chunks_kwargs: Any,
) -> Self:
Expand Down
Loading
Loading