Skip to content

Commit

Permalink
fix miscellaneous numpy=2.0 errors (#8117)
Browse files Browse the repository at this point in the history
* replace `np.unicode_` with `np.str_`

* replace `np.NaN` with `np.nan`

* replace more instances of `np.unicode_`

note that with more modern versions of `numpy` the `.astype(np.str_)`
don't actually change the dtype, so maybe we can remove those.

* more instances of renamed / removed dtypes

* more dtype replacements

* use `str.encode(encoding)` instead of `bytes(str, encoding)`

* explicitly import `RankWarning`

* left-over `np.RankWarning`

* use `float` instead of the removed `np.float_`

* ignore missing stubs for `numpy.exceptions`

---------

Co-authored-by: Kai Mühlbauer <[email protected]>
Co-authored-by: Mathias Hauser <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
5 people authored Sep 11, 2023
1 parent 0b3b20a commit 2951ce0
Show file tree
Hide file tree
Showing 18 changed files with 158 additions and 134 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ module = [
"sparse.*",
"toolz.*",
"zarr.*",
"numpy.exceptions.*", # remove once support for `numpy<2.0` has been dropped
]

[[tool.mypy.overrides]]
Expand Down
2 changes: 1 addition & 1 deletion xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _ensure_fill_value_valid(data, attributes):
# work around for netCDF4/scipy issue where _FillValue has the wrong type:
# https://github.com/Unidata/netcdf4-python/issues/271
if data.dtype.kind == "S" and "_FillValue" in attributes:
attributes["_FillValue"] = np.string_(attributes["_FillValue"])
attributes["_FillValue"] = np.bytes_(attributes["_FillValue"])


def _force_native_endianness(var):
Expand Down
6 changes: 3 additions & 3 deletions xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def ensure_fixed_length_bytes(var):
dims, data, attrs, encoding = unpack_for_encoding(var)
if check_vlen_dtype(data.dtype) == bytes:
# TODO: figure out how to handle this with dask
data = np.asarray(data, dtype=np.string_)
data = np.asarray(data, dtype=np.bytes_)
return Variable(dims, data, attrs, encoding)


Expand Down Expand Up @@ -151,7 +151,7 @@ def bytes_to_char(arr):
def _numpy_bytes_to_char(arr):
"""Like netCDF4.stringtochar, but faster and more flexible."""
# ensure the array is contiguous
arr = np.array(arr, copy=False, order="C", dtype=np.string_)
arr = np.array(arr, copy=False, order="C", dtype=np.bytes_)
return arr.reshape(arr.shape + (1,)).view("S1")


Expand All @@ -168,7 +168,7 @@ def char_to_bytes(arr):

if not size:
# can't make an S0 dtype
return np.zeros(arr.shape[:-1], dtype=np.string_)
return np.zeros(arr.shape[:-1], dtype=np.bytes_)

if is_chunked_array(arr):
chunkmanager = get_chunked_array_type(arr)
Expand Down
4 changes: 2 additions & 2 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def convert_times(times, date_type, raise_on_invalid: bool = True) -> np.ndarray
Useful to convert between calendars in numpy and cftime or between cftime calendars.
If raise_on_valid is True (default), invalid dates trigger a ValueError.
Otherwise, the invalid element is replaced by np.NaN for cftime types and np.NaT for np.datetime64.
Otherwise, the invalid element is replaced by np.nan for cftime types and np.NaT for np.datetime64.
"""
if date_type in (pd.Timestamp, np.datetime64) and not is_np_datetime_like(
times.dtype
Expand All @@ -489,7 +489,7 @@ def convert_times(times, date_type, raise_on_invalid: bool = True) -> np.ndarray
f"{date_type(2000, 1, 1).calendar} calendar. Reason: {e}."
)
else:
dt = np.NaN
dt = np.nan

new[i] = dt
return new
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/accessor_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def cat(self, *others, sep: str | bytes | Any = "") -> T_DataArray:
... )
>>> values_2 = np.array(3.4)
>>> values_3 = ""
>>> values_4 = np.array("test", dtype=np.unicode_)
>>> values_4 = np.array("test", dtype=np.str_)
Determine the separator to use
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5772,8 +5772,8 @@ def idxmin(
>>> array = xr.DataArray(
... [
... [2.0, 1.0, 2.0, 0.0, -2.0],
... [-4.0, np.NaN, 2.0, np.NaN, -2.0],
... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN],
... [-4.0, np.nan, 2.0, np.nan, -2.0],
... [np.nan, np.nan, 1.0, np.nan, np.nan],
... ],
... dims=["y", "x"],
... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2},
Expand Down Expand Up @@ -5868,8 +5868,8 @@ def idxmax(
>>> array = xr.DataArray(
... [
... [2.0, 1.0, 2.0, 0.0, -2.0],
... [-4.0, np.NaN, 2.0, np.NaN, -2.0],
... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN],
... [-4.0, np.nan, 2.0, np.nan, -2.0],
... [np.nan, np.nan, 1.0, np.nan, np.nan],
... ],
... dims=["y", "x"],
... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2},
Expand Down
19 changes: 13 additions & 6 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload

import numpy as np

# remove once numpy 2.0 is the oldest supported version
try:
from numpy.exceptions import RankWarning
except ImportError:
from numpy import RankWarning

import pandas as pd

from xarray.coding.calendar_ops import convert_calendar, interp_calendar
Expand Down Expand Up @@ -8785,9 +8792,9 @@ def polyfit(

with warnings.catch_warnings():
if full: # Copy np.polyfit behavior
warnings.simplefilter("ignore", np.RankWarning)
warnings.simplefilter("ignore", RankWarning)
else: # Raise only once per variable
warnings.simplefilter("once", np.RankWarning)
warnings.simplefilter("once", RankWarning)

coeffs, residuals = duck_array_ops.least_squares(
lhs, rhs.data, rcond=rcond, skipna=skipna_da
Expand Down Expand Up @@ -9077,8 +9084,8 @@ def idxmin(
>>> array2 = xr.DataArray(
... [
... [2.0, 1.0, 2.0, 0.0, -2.0],
... [-4.0, np.NaN, 2.0, np.NaN, -2.0],
... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN],
... [-4.0, np.nan, 2.0, np.nan, -2.0],
... [np.nan, np.nan, 1.0, np.nan, np.nan],
... ],
... dims=["y", "x"],
... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]},
Expand Down Expand Up @@ -9174,8 +9181,8 @@ def idxmax(
>>> array2 = xr.DataArray(
... [
... [2.0, 1.0, 2.0, 0.0, -2.0],
... [-4.0, np.NaN, 2.0, np.NaN, -2.0],
... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN],
... [-4.0, np.nan, 2.0, np.nan, -2.0],
... [np.nan, np.nan, 1.0, np.nan, np.nan],
... ],
... dims=["y", "x"],
... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]},
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __eq__(self, other):
PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
(np.number, np.character), # numpy promotes to character
(np.bool_, np.character), # numpy promotes to character
(np.bytes_, np.unicode_), # numpy promotes to unicode
(np.bytes_, np.str_), # numpy promotes to unicode
)


Expand Down
2 changes: 1 addition & 1 deletion xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
# scipy.interpolate.interp1d always forces to float.
# Use the same check for blockwise as well:
if not issubclass(var.dtype.type, np.inexact):
dtype = np.float_
dtype = float
else:
dtype = var.dtype

Expand Down
8 changes: 7 additions & 1 deletion xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
import pandas as pd
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]

# remove once numpy 2.0 is the oldest supported version
try:
from numpy.exceptions import RankWarning
except ImportError:
from numpy import RankWarning

from xarray.core.options import OPTIONS
from xarray.core.pycompat import is_duck_array

Expand Down Expand Up @@ -194,7 +200,7 @@ def _nanpolyfit_1d(arr, x, rcond=None):

def warn_on_deficient_rank(rank, order):
if rank != order:
warnings.warn("Polyfit may be poorly conditioned", np.RankWarning, stacklevel=2)
warnings.warn("Polyfit may be poorly conditioned", RankWarning, stacklevel=2)


def least_squares(lhs, rhs, rcond=None, skipna=False):
Expand Down
Loading

0 comments on commit 2951ce0

Please sign in to comment.