Enable pandas type checking #9213

Merged: 28 commits, Jul 17, 2024

Changes from 1 commit

Commits (28):
678536c: remove pandas from ignore missing imports (headtr1ck, Jul 1, 2024)
20066c4: add Any to dim arg of concat as placeholder (headtr1ck, Jul 1, 2024)
9053b60: allow sequence of np.ndarrays as coords in dataArray constructor (headtr1ck, Jul 1, 2024)
27bd0ab: fix several typing issues in tests (headtr1ck, Jul 1, 2024)
0dcb2a1: fix more types (headtr1ck, Jul 1, 2024)
a2b5b4c: more fixes (headtr1ck, Jul 2, 2024)
6a75746: more typing... (headtr1ck, Jul 3, 2024)
ce657e4: we are getting there? (headtr1ck, Jul 3, 2024)
5589156: who might have guessed it... more typing (headtr1ck, Jul 4, 2024)
e45d28b: continue fixing typing issues (headtr1ck, Jul 7, 2024)
faa8bcc: fix some typed_ops (headtr1ck, Jul 7, 2024)
2b678b1: fix last non-typed-ops errors (headtr1ck, Jul 7, 2024)
7b08a94: Merge branch 'pydata:main' into pandas-stubs (headtr1ck, Jul 8, 2024)
69b93dd: Merge branch 'main' into pandas-stubs (max-sixty, Jul 12, 2024)
ab4dbf5: update typed ops (headtr1ck, Jul 12, 2024)
27db395: remove useless DaskArray type in scalar or array type (headtr1ck, Jul 12, 2024)
5977107: fix missing import in type_checking (headtr1ck, Jul 12, 2024)
64765c8: fix import (headtr1ck, Jul 12, 2024)
5b1d98c: improve cftime offsets typing (headtr1ck, Jul 12, 2024)
16bee67: fix classvars (headtr1ck, Jul 12, 2024)
602c3cf: fix some checks (headtr1ck, Jul 12, 2024)
07a52ae: fix a broken test (headtr1ck, Jul 13, 2024)
a93bba6: Merge branch 'main' into pandas-stubs (headtr1ck, Jul 13, 2024)
f2f5dfb: Merge branch 'main' into pandas-stubs (headtr1ck, Jul 15, 2024)
a5eaa82: improve typing of test_concat (headtr1ck, Jul 15, 2024)
2f12f88: fix broken concat (headtr1ck, Jul 15, 2024)
4bfec5e: add whats-new (headtr1ck, Jul 15, 2024)
2f95dbd: Merge branch 'main' into pandas-stubs (headtr1ck, Jul 15, 2024)
Commit 6a75746ec0684ef669f4fdcfc8aef2b5db81e14e: "more typing..." (committed by headtr1ck, Jul 3, 2024)
22 changes: 12 additions & 10 deletions xarray/coding/cftime_offsets.py
@@ -759,16 +759,16 @@ def _emit_freq_deprecation_warning(deprecated_freq):
     emit_user_level_warning(message, FutureWarning)


-def to_offset(freq, warn=True):
+def to_offset(freq: BaseCFTimeOffset | str, warn: bool = True) -> BaseCFTimeOffset:
     """Convert a frequency string to the appropriate subclass of
     BaseCFTimeOffset."""
     if isinstance(freq, BaseCFTimeOffset):
         return freq
-    else:
-        try:
-            freq_data = re.match(_PATTERN, freq).groupdict()
-        except AttributeError:
-            raise ValueError("Invalid frequency string provided")
+
+    match = re.match(_PATTERN, freq)
+    if match is None:
+        raise ValueError("Invalid frequency string provided")
+    freq_data = match.groupdict()

     freq = freq_data["freq"]
     if warn and freq in _DEPRECATED_FREQUENICES:
@@ -909,17 +909,19 @@ def _translate_closed_to_inclusive(closed):
     return inclusive


-def _infer_inclusive(closed, inclusive):
+def _infer_inclusive(
+    closed: NoDefault | SideOptions, inclusive: InclusiveOptions | None
+) -> InclusiveOptions:
     """Follows code added in pandas #43504."""
     if closed is not no_default and inclusive is not None:
         raise ValueError(
             "Following pandas, deprecated argument `closed` cannot be "
             "passed if argument `inclusive` is not None."
         )
     if closed is not no_default:
-        inclusive = _translate_closed_to_inclusive(closed)
-    elif inclusive is None:
-        inclusive = "both"
+        return _translate_closed_to_inclusive(closed)
+    if inclusive is None:
+        return "both"
     return inclusive

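A note on the `to_offset` hunk above: `re.match` is typed as returning `re.Match[str] | None`, so the old `try/except AttributeError` hid the `None` case from mypy, while the explicit check narrows the type. A minimal standalone sketch of the pattern (the regex and function name here are illustrative, not xarray's actual `_PATTERN`):

```python
import re

_FREQ_PATTERN = r"(?P<multiple>\d*)(?P<freq>[A-Za-z]+)"  # illustrative, not xarray's

def parse_freq(freq: str) -> dict[str, str]:
    # re.match returns re.Match[str] | None, so mypy requires an explicit
    # None check before .groupdict() may be called.
    match = re.match(_FREQ_PATTERN, freq)
    if match is None:
        raise ValueError("Invalid frequency string provided")
    return match.groupdict()

print(parse_freq("3h"))  # {'multiple': '3', 'freq': 'h'}
```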
61 changes: 38 additions & 23 deletions xarray/coding/cftimeindex.py
@@ -45,6 +45,7 @@
 import re
 import warnings
 from datetime import timedelta
+from typing import TYPE_CHECKING, Any

 import numpy as np
 import pandas as pd
@@ -64,6 +65,10 @@
 except ImportError:
     cftime = None

+if TYPE_CHECKING:
+    from xarray.coding.cftime_offsets import BaseCFTimeOffset
+    from xarray.core.types import Self
+

 # constants for cftimeindex.repr
 CFTIME_REPR_LENGTH = 19
@@ -495,34 +500,38 @@ def get_value(self, series, key):
         else:
             return series.iloc[self.get_loc(key)]

-    def __contains__(self, key):
+    def __contains__(self, key: Any) -> bool:
         """Adapted from
         pandas.tseries.base.DatetimeIndexOpsMixin.__contains__"""
         try:
             result = self.get_loc(key)
             return (
                 is_scalar(result)
                 or type(result) == slice
-                or (isinstance(result, np.ndarray) and result.size)
+                or (isinstance(result, np.ndarray) and result.size > 0)
             )
         except (KeyError, TypeError, ValueError):
             return False

-    def contains(self, key):
+    def contains(self, key: Any) -> bool:
         """Needed for .loc based partial-string indexing"""
         return self.__contains__(key)

-    def shift(self, n: int | float, freq: str | timedelta):
+    def shift(  # type: ignore[override]  # freq is typed Any, we are more precise
+        self,
+        periods: int | float,
+        freq: str | timedelta | BaseCFTimeOffset | None = None,
+    ) -> Self:
         """Shift the CFTimeIndex a multiple of the given frequency.

         See the documentation for :py:func:`~xarray.cftime_range` for a
         complete listing of valid frequency strings.

         Parameters
         ----------
-        n : int, float if freq of days or below
+        periods : int, float if freq of days or below
             Periods to shift by
-        freq : str or datetime.timedelta
+        freq : str, datetime.timedelta or BaseCFTimeOffset
             A frequency string or datetime.timedelta object to shift by

         Returns
@@ -546,42 +555,48 @@ def shift(self, n: int | float, freq: str | timedelta):
         CFTimeIndex([2000-02-01 12:00:00],
                     dtype='object', length=1, calendar='standard', freq=None)
         """
+        if freq is None:
+            # None type is required to be compatible with base pd.Index class
+            raise TypeError(
+                f"`freq` argument cannot be None for {type(self).__name__}.shift"
+            )
+
         if isinstance(freq, timedelta):
-            return self + n * freq
-        elif isinstance(freq, str):
+            return self + periods * freq
+
+        if isinstance(freq, (str, BaseCFTimeOffset)):
             from xarray.coding.cftime_offsets import to_offset

-            return self + n * to_offset(freq)
-        else:
-            raise TypeError(
-                f"'freq' must be of type str or datetime.timedelta, got {freq}."
-            )
+            return self + periods * to_offset(freq)
+
+        raise TypeError(
+            f"'freq' must be of type str or datetime.timedelta, got {freq}."
+        )

-    def __add__(self, other):
+    def __add__(self, other) -> Self:
         if isinstance(other, pd.TimedeltaIndex):
             other = other.to_pytimedelta()
-        return CFTimeIndex(np.array(self) + other)
+        return type(self)(np.array(self) + other)

-    def __radd__(self, other):
+    def __radd__(self, other) -> Self:
         if isinstance(other, pd.TimedeltaIndex):
             other = other.to_pytimedelta()
-        return CFTimeIndex(other + np.array(self))
+        return type(self)(other + np.array(self))

     def __sub__(self, other):
         if _contains_datetime_timedeltas(other):
-            return CFTimeIndex(np.array(self) - other)
-        elif isinstance(other, pd.TimedeltaIndex):
-            return CFTimeIndex(np.array(self) - other.to_pytimedelta())
-        elif _contains_cftime_datetimes(np.array(other)):
+            return type(self)(np.array(self) - other)
+        if isinstance(other, pd.TimedeltaIndex):
+            return type(self)(np.array(self) - other.to_pytimedelta())
+        if _contains_cftime_datetimes(np.array(other)):
             try:
                 return pd.TimedeltaIndex(np.array(self) - np.array(other))
             except OUT_OF_BOUNDS_TIMEDELTA_ERRORS:
                 raise ValueError(
                     "The time difference exceeds the range of values "
                     "that can be expressed at the nanosecond resolution."
                 )
-        else:
-            return NotImplemented
+        return NotImplemented

     def __rsub__(self, other):
         try:
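The switch from hard-coded `CFTimeIndex(...)` to `type(self)(...)` pairs with the `Self` return annotation: a subclass then gets its own type back from arithmetic and `shift`. A minimal sketch of the idea (standalone class, not xarray's):

```python
from typing import Self  # Python >= 3.11; older versions can use typing_extensions

class Base:
    def __init__(self, values: list[int]) -> None:
        self.values = values

    def shifted(self, n: int) -> Self:
        # type(self) preserves the runtime class, so a subclass gets its own
        # type back; the Self annotation tells the type checker the same.
        return type(self)([v + n for v in self.values])

class Sub(Base):
    pass

s: Sub = Sub([1, 2]).shifted(3)  # type-checks: Self resolves to Sub here
```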
34 changes: 19 additions & 15 deletions xarray/coding/times.py
@@ -5,7 +5,7 @@
 from collections.abc import Hashable
 from datetime import datetime, timedelta
 from functools import partial
-from typing import TYPE_CHECKING, Callable, Union
+from typing import TYPE_CHECKING, Callable, Literal, Union, cast

 import numpy as np
 import pandas as pd
@@ -37,7 +37,7 @@
     cftime = None

 if TYPE_CHECKING:
-    from xarray.core.types import CFCalendar, T_DuckArray
+    from xarray.core.types import CFCalendar, NPDatetimeUnitOptions, T_DuckArray

 T_Name = Union[Hashable, None]

@@ -111,22 +111,25 @@ def _is_numpy_compatible_time_range(times):
     return True


-def _netcdf_to_numpy_timeunit(units: str) -> str:
+def _netcdf_to_numpy_timeunit(units: str) -> NPDatetimeUnitOptions:
     units = units.lower()
     if not units.endswith("s"):
         units = f"{units}s"
-    return {
-        "nanoseconds": "ns",
-        "microseconds": "us",
-        "milliseconds": "ms",
-        "seconds": "s",
-        "minutes": "m",
-        "hours": "h",
-        "days": "D",
-    }[units]
+    return cast(
+        NPDatetimeUnitOptions,
+        {
+            "nanoseconds": "ns",
+            "microseconds": "us",
+            "milliseconds": "ms",
+            "seconds": "s",
+            "minutes": "m",
+            "hours": "h",
+            "days": "D",
+        }[units],
+    )


-def _numpy_to_netcdf_timeunit(units: str) -> str:
+def _numpy_to_netcdf_timeunit(units: NPDatetimeUnitOptions) -> str:
     return {
         "ns": "nanoseconds",
         "us": "microseconds",
@@ -252,12 +255,12 @@ def _decode_datetime_with_pandas(
             "pandas."
         )

-    time_units, ref_date = _unpack_netcdf_time_units(units)
+    time_units, ref_date_str = _unpack_netcdf_time_units(units)
     time_units = _netcdf_to_numpy_timeunit(time_units)
     try:
         # TODO: the strict enforcement of nanosecond precision Timestamps can be
         # relaxed when addressing GitHub issue #7493.
-        ref_date = nanosecond_precision_timestamp(ref_date)
+        ref_date = nanosecond_precision_timestamp(ref_date_str)
     except ValueError:
         # ValueError is raised by pd.Timestamp for non-ISO timestamp
         # strings, in which case we fall back to using cftime
@@ -471,6 +474,7 @@ def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray:
     # TODO: the strict enforcement of nanosecond precision datetime values can
     # be relaxed when addressing GitHub issue #7493.
     new = np.empty(times.shape, dtype="M8[ns]")
+    dt: pd.Timestamp | Literal["NaT"]
     for i, t in np.ndenumerate(times):
         try:
             # Use pandas.Timestamp in place of datetime.datetime, because
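Why the `cast` in `_netcdf_to_numpy_timeunit`: indexing a dict literal gives mypy a plain `str`, not the narrower `Literal` type, so the looked-up value has to be cast to the declared alias. A self-contained sketch of the pattern (alias reused from the diff, dict shortened):

```python
from typing import Literal, cast

NPDatetimeUnitOptions = Literal["D", "h", "m", "s", "ms", "us", "ns"]

def to_numpy_unit(units: str) -> NPDatetimeUnitOptions:
    # The dict's values are inferred as str, so mypy would reject returning
    # them as NPDatetimeUnitOptions without the cast.
    return cast(
        NPDatetimeUnitOptions,
        {"seconds": "s", "hours": "h", "days": "D"}[units],
    )
```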
42 changes: 23 additions & 19 deletions xarray/core/concat.py
@@ -304,7 +304,7 @@ def _calc_concat_dim_index(

     dim: Hashable | None

-    if isinstance(dim_or_data, str):
+    if utils.hashable(dim_or_data):
         dim = dim_or_data
         index = None
     else:
@@ -475,7 +475,7 @@


 def _dataset_concat(
-    datasets: list[T_Dataset],
+    datasets: Iterable[T_Dataset],
     dim: str | T_Variable | T_DataArray | pd.Index,
     data_vars: T_DataVars,
     coords: str | list[str],
@@ -506,12 +506,14 @@
     else:
         dim_var = None

-    dim, index = _calc_concat_dim_index(dim)
+    dim_name, index = _calc_concat_dim_index(dim)

     # Make sure we're working on a copy (we'll be loading variables)
     datasets = [ds.copy() for ds in datasets]
     datasets = list(
-        align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value)
+        align(
+            *datasets, join=join, copy=False, exclude=[dim_name], fill_value=fill_value
+        )
     )

     dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets(
@@ -525,19 +527,21 @@
             f"{both_data_and_coords!r} is a coordinate in some datasets but not others."
         )
     # we don't want the concat dimension in the result dataset yet
-    dim_coords.pop(dim, None)
-    dims_sizes.pop(dim, None)
+    dim_coords.pop(dim_name, None)
+    dims_sizes.pop(dim_name, None)

     # case where concat dimension is a coordinate or data_var but not a dimension
-    if (dim in coord_names or dim in data_names) and dim not in dim_names:
+    if (
+        dim_name in coord_names or dim_name in data_names
+    ) and dim_name not in dim_names:
         datasets = [
-            ds.expand_dims(dim, create_index_for_new_dim=create_index_for_new_dim)
+            ds.expand_dims(dim_name, create_index_for_new_dim=create_index_for_new_dim)
             for ds in datasets
         ]

     # determine which variables to concatenate
     concat_over, equals, concat_dim_lengths = _calc_concat_over(
-        datasets, dim, dim_names, data_vars, coords, compat
+        datasets, dim_name, dim_names, data_vars, coords, compat
     )
@@ -577,8 +581,8 @@ def ensure_common_dims(vars, concat_dim_lengths):
         # dimensions and the same shape for all of them except along the
         # concat dimension
         common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims))
-        if dim not in common_dims:
-            common_dims = (dim,) + common_dims
+        if dim_name not in common_dims:
+            common_dims = (dim_name,) + common_dims
         for var, dim_len in zip(vars, concat_dim_lengths):
             if var.dims != common_dims:
                 common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims)
@@ -597,9 +601,9 @@ def get_indexes(name):
         elif name == dim:
             var = ds._variables[name]
             if not var.dims:
-                data = var.set_dims(dim).values
+                data = var.set_dims(dim_name).values
                 if create_index_for_new_dim:
-                    yield PandasIndex(data, dim, coord_dtype=var.dtype)
+                    yield PandasIndex(data, dim_name, coord_dtype=var.dtype)

     # create concatenation index, needed for later reindexing
     file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
@@ -645,7 +649,7 @@ def get_indexes(name):
                 f"{name!r} must have either an index or no index in all datasets, "
                 f"found {len(indexes)}/{len(datasets)} datasets with an index."
             )
-        combined_idx = indexes[0].concat(indexes, dim, positions)
+        combined_idx = indexes[0].concat(indexes, dim_name, positions)
         if name in datasets[0]._indexes:
             idx_vars = datasets[0].xindexes.get_all_coords(name)
         else:
@@ -661,14 +665,14 @@ def get_indexes(name):
                 result_vars[k] = v
         else:
             combined_var = concat_vars(
-                vars, dim, positions, combine_attrs=combine_attrs
+                vars, dim_name, positions, combine_attrs=combine_attrs
             )
             # reindex if variable is not present in all datasets
             if len(variable_index) < concat_index_size:
                 combined_var = reindex_variables(
                     variables={name: combined_var},
                     dim_pos_indexers={
-                        dim: pd.Index(variable_index).get_indexer(concat_index)
+                        dim_name: pd.Index(variable_index).get_indexer(concat_index)
                     },
                     fill_value=fill_value,
                 )[name]
@@ -694,12 +698,12 @@ def get_indexes(name):

     if index is not None:
         if dim_var is not None:
-            index_vars = index.create_variables({dim: dim_var})
+            index_vars = index.create_variables({dim_name: dim_var})
         else:
             index_vars = index.create_variables()

-        coord_vars[dim] = index_vars[dim]
-        result_indexes[dim] = index
+        coord_vars[dim_name] = index_vars[dim_name]
+        result_indexes[dim_name] = index

     coords_obj = Coordinates(coord_vars, indexes=result_indexes)

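The pervasive `dim` to `dim_name` renaming in `_dataset_concat` avoids rebinding one variable to a different type, which strict mypy flags: the incoming `dim` may be a str, Variable, DataArray, or pd.Index, while `_calc_concat_dim_index` returns a plain hashable name. A reduced illustration of the idea (hypothetical function, not xarray's):

```python
from collections.abc import Hashable

import pandas as pd

def concat_along(dim: str | pd.Index) -> None:
    # Rebinding `dim` itself to the extracted name would change its type
    # mid-function, which strict mypy rejects; a second name keeps both precise.
    dim_name: Hashable = dim.name if isinstance(dim, pd.Index) else dim
    print(f"concatenating along {dim_name!r}")

concat_along("time")
concat_along(pd.Index([1, 2], name="x"))
```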
1 change: 1 addition & 0 deletions xarray/core/types.py
@@ -219,6 +219,7 @@ def copy(
 DatetimeUnitOptions = Literal[
     "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "μs", "ns", "ps", "fs", "as", None
 ]
+NPDatetimeUnitOptions = Literal["D", "h", "m", "s", "ms", "us", "ns"]

 QueryEngineOptions = Literal["python", "numexpr", None]
 QueryParserOptions = Literal["pandas", "python"]
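The new `NPDatetimeUnitOptions` alias restricts functions to the datetime64 units these conversions actually handle. A small usage sketch (the `as_datetime64` helper is hypothetical, not part of xarray):

```python
from typing import Literal

import numpy as np

NPDatetimeUnitOptions = Literal["D", "h", "m", "s", "ms", "us", "ns"]

def as_datetime64(value: str, unit: NPDatetimeUnitOptions) -> np.datetime64:
    # With the Literal alias, a typo such as "min" is caught by the type
    # checker instead of surfacing as a runtime error from numpy.
    return np.datetime64(value, unit)

ts = as_datetime64("2024-07-03", "ns")
```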
2 changes: 1 addition & 1 deletion xarray/core/variable.py
@@ -1504,7 +1504,7 @@ def _unstack_once(
         # Potentially we could replace `len(other_dims)` with just `-1`
         other_dims = [d for d in self.dims if d != dim]
         new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes)
-        new_dims = reordered.dims[: len(other_dims)] + new_dim_names
+        new_dims = reordered.dims[: len(other_dims)] + tuple(new_dim_names)

         create_template: Callable
         if fill_value is dtypes.NA:
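The `tuple(new_dim_names)` wrapper in the last hunk fixes a type mismatch: slicing `reordered.dims` yields a tuple, and concatenating a tuple with a non-tuple sequence is rejected by mypy (and raises at runtime if a plain list is passed). A quick sketch:

```python
dims: tuple[str, ...] = ("x", "y")
new_dim_names = ["a", "b"]  # may arrive as a list

# dims + new_dim_names would fail:
# TypeError: can only concatenate tuple (not "list") to tuple
new_dims = dims + tuple(new_dim_names)
print(new_dims)  # ('x', 'y', 'a', 'b')
```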