Skip to content

Commit

Permalink
Declare Dataset, DataArray, Variable, GroupBy unhashable (#8392)
Browse files Browse the repository at this point in the history
* Add unhashable to generate_ops

* Regenerate _typed_ops after adding "unhashable"

* Fix variable redefinition

The previous commit revealed the following mypy errors:

xarray/core/dataset.py: note: In member "swap_dims" of class "Dataset":
xarray/core/dataset.py:4415: error: Incompatible types in assignment (expression has type "Variable", variable has type "Hashable")  [assignment]
xarray/core/dataset.py:4415: note: Following member(s) of "Variable" have conflicts:
xarray/core/dataset.py:4415: note:     __hash__: expected "Callable[[], int]", got "None"
xarray/core/dataset.py:4416: error: "Hashable" has no attribute "dims"  [attr-defined]
xarray/core/dataset.py:4419: error: "Hashable" has no attribute "to_index_variable"  [attr-defined]
xarray/core/dataset.py:4430: error: "Hashable" has no attribute "to_base_variable"  [attr-defined]
  • Loading branch information
maresb authored Nov 9, 2023
1 parent feba698 commit 15328b6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 15 deletions.
20 changes: 20 additions & 0 deletions xarray/core/_typed_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def __eq__(self, other: DsCompatible) -> Self: # type:ignore[override]
def __ne__(self, other: DsCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)

# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]

def __radd__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)

Expand Down Expand Up @@ -291,6 +295,10 @@ def __eq__(self, other: DaCompatible) -> Self: # type:ignore[override]
def __ne__(self, other: DaCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)

# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]

def __radd__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)

Expand Down Expand Up @@ -643,6 +651,10 @@ def __ne__(self, other: VarCompatible) -> Self:
def __ne__(self, other: VarCompatible) -> Self | T_DataArray:
return self._binary_op(other, nputils.array_ne)

# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]

def __radd__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)

Expand Down Expand Up @@ -851,6 +863,10 @@ def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override]
def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)

# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]

def __radd__(self, other: GroupByCompatible) -> Dataset:
return self._binary_op(other, operator.add, reflexive=True)

Expand Down Expand Up @@ -973,6 +989,10 @@ def __eq__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override]
def __ne__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)

# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]

def __radd__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.add, reflexive=True)

Expand Down
32 changes: 17 additions & 15 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4410,16 +4410,18 @@ def swap_dims(
# rename_dims() method that only renames dimensions.

dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
for k, v in dims_dict.items():
if k not in self.dims:
for current_name, new_name in dims_dict.items():
if current_name not in self.dims:
raise ValueError(
f"cannot swap from dimension {k!r} because it is "
f"cannot swap from dimension {current_name!r} because it is "
f"not one of the dimensions of this dataset {tuple(self.dims)}"
)
if v in self.variables and self.variables[v].dims != (k,):
if new_name in self.variables and self.variables[new_name].dims != (
current_name,
):
raise ValueError(
f"replacement dimension {v!r} is not a 1D "
f"variable along the old dimension {k!r}"
f"replacement dimension {new_name!r} is not a 1D "
f"variable along the old dimension {current_name!r}"
)

result_dims = {dims_dict.get(dim, dim) for dim in self.dims}
Expand All @@ -4429,24 +4431,24 @@ def swap_dims(

variables: dict[Hashable, Variable] = {}
indexes: dict[Hashable, Index] = {}
for k, v in self.variables.items():
dims = tuple(dims_dict.get(dim, dim) for dim in v.dims)
for current_name, current_variable in self.variables.items():
dims = tuple(dims_dict.get(dim, dim) for dim in current_variable.dims)
var: Variable
if k in result_dims:
var = v.to_index_variable()
if current_name in result_dims:
var = current_variable.to_index_variable()
var.dims = dims
if k in self._indexes:
indexes[k] = self._indexes[k]
variables[k] = var
if current_name in self._indexes:
indexes[current_name] = self._indexes[current_name]
variables[current_name] = var
else:
index, index_vars = create_default_index_implicit(var)
indexes.update({name: index for name in index_vars})
variables.update(index_vars)
coord_names.update(index_vars)
else:
var = v.to_base_variable()
var = current_variable.to_base_variable()
var.dims = dims
variables[k] = var
variables[current_name] = var

return self._replace_with_new_dims(variables, coord_names, indexes=indexes)

Expand Down
6 changes: 6 additions & 0 deletions xarray/util/generate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ def {method}(self) -> Self:
template_other_unary = """
def {method}(self, *args: Any, **kwargs: Any) -> Self:
return self._unary_op({func}, *args, **kwargs)"""
unhashable = """
# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]"""

# For some methods we override return type `bool` defined by base class `object`.
# We need to add "# type: ignore[override]"
Expand Down Expand Up @@ -152,6 +156,7 @@ def binops(
template_binop,
extras | {"type_ignore": _type_ignore(type_ignore_eq)},
),
([(None, None)], unhashable, extras),
(BINOPS_REFLEXIVE, template_reflexive, extras),
]

Expand Down Expand Up @@ -185,6 +190,7 @@ def binops_overload(
"overload_type_ignore": _type_ignore(type_ignore_eq),
},
),
([(None, None)], unhashable, extras),
(BINOPS_REFLEXIVE, template_reflexive, extras),
]

Expand Down

0 comments on commit 15328b6

Please sign in to comment.