Reverse type checks for better type inheritance (#8313)
* Reverse type checks

* remove ignores not needed yet
Illviljan authored Oct 16, 2023
1 parent 4520ce9 commit bac1265
Showing 1 changed file with 23 additions and 11 deletions.
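
Note on the pattern: throughout the diff, `if isinstance(x, Wrapper): ... else: x` becomes `if not isinstance(x, Wrapper): x ... else: ...`. With the pass-through branch first, the variable keeps its precise static type (including any subclass or type-variable information) on that path, while only the `else` branch is narrowed to the concrete wrapper class, where attribute access like `.array` type-checks without an ignore comment. A minimal, hedged sketch of the idiom — the classes below are simplified stand-ins, not the real `xarray.core.indexing` types:

    from typing import TypeVar

    class DuckArray:
        """Stand-in base array type, for illustration only."""

    class CachedArray(DuckArray):
        """Stand-in for a caching wrapper like indexing.MemoryCachedArray."""

        def __init__(self, array: DuckArray) -> None:
            self.array = array

    T = TypeVar("T", bound=DuckArray)

    def copy_data(data_old: T) -> T | CachedArray:
        if not isinstance(data_old, CachedArray):
            # Pass-through branch first: `data_old` keeps its precise
            # type T, so subclasses of DuckArray survive with their own type.
            ndata: T | CachedArray = data_old
        else:
            # Here `data_old` is narrowed to CachedArray, so `.array`
            # type-checks without a "# type: ignore" comment.
            ndata = CachedArray(data_old.array)
        return ndata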
xarray/core/variable.py
@@ -901,20 +901,20 @@ def _copy(
         if data is None:
             data_old = self._data
 
-            if isinstance(data_old, indexing.MemoryCachedArray):
+            if not isinstance(data_old, indexing.MemoryCachedArray):
+                ndata = data_old
+            else:
                 # don't share caching between copies
                 ndata = indexing.MemoryCachedArray(data_old.array)
-            else:
-                ndata = data_old
 
             if deep:
                 ndata = copy.deepcopy(ndata, memo)
 
         else:
             ndata = as_compatible_data(data)
-            if self.shape != ndata.shape:  # type: ignore[attr-defined]
+            if self.shape != ndata.shape:
                 raise ValueError(
-                    f"Data shape {ndata.shape} must match shape of object {self.shape}"  # type: ignore[attr-defined]
+                    f"Data shape {ndata.shape} must match shape of object {self.shape}"
                 )
 
         attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
@@ -1043,7 +1043,9 @@ def chunk(
         if chunkmanager.is_chunked_array(data_old):
             data_chunked = chunkmanager.rechunk(data_old, chunks)
         else:
-            if isinstance(data_old, indexing.ExplicitlyIndexed):
+            if not isinstance(data_old, indexing.ExplicitlyIndexed):
+                ndata = data_old
+            else:
                 # Unambiguously handle array storage backends (like NetCDF4 and h5py)
                 # that can't handle general array indexing. For example, in netCDF4 you
                 # can do "outer" indexing along two dimensions independent, which works
@@ -1055,8 +1057,6 @@
                 ndata = indexing.ImplicitToExplicitIndexingAdapter(
                     data_old, indexing.OuterIndexer
                 )
-            else:
-                ndata = data_old
 
             if utils.is_dict_like(chunks):
                 chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape))
@@ -1504,7 +1504,9 @@ def _stack_once(self, dims: list[Hashable], new_dim: Hashable):
         new_data = duck_array_ops.reshape(reordered.data, new_shape)
         new_dims = reordered.dims[: len(other_dims)] + (new_dim,)
 
-        return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True)
+        return type(self)(
+            new_dims, new_data, self._attrs, self._encoding, fastpath=True
+        )
 
     def stack(self, dimensions=None, **dimensions_kwargs):
         """
@@ -2760,7 +2762,7 @@ def concat(
 
         return cls(first_var.dims, data, attrs)
 
-    def copy(self, deep: bool = True, data: ArrayLike | None = None):
+    def copy(self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None):
         """Returns a copy of this object.
 
         `deep` is ignored since data is stored in the form of
@@ -2785,7 +2787,17 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None):
             data copied from original.
         """
         if data is None:
-            ndata = self._data.copy(deep=deep)
+            data_old = self._data
+
+            if not isinstance(data_old, indexing.MemoryCachedArray):
+                ndata = data_old
+            else:
+                # don't share caching between copies
+                ndata = indexing.MemoryCachedArray(data_old.array)
+
+            if deep:
+                ndata = copy.deepcopy(ndata, None)
+
         else:
             ndata = as_compatible_data(data)
             if self.shape != ndata.shape:

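The `_stack_once` hunk is the "better type inheritance" half of the title: constructing the result with `type(self)(...)` instead of a hard-coded `Variable(...)` means a subclass that calls the inherited method gets an instance of its own class back. A minimal sketch of that idiom, with hypothetical classes:

    class Base:
        def clone(self) -> "Base":
            # A hard-coded `Base()` here would always return a plain Base,
            # even when called on a subclass; type(self) re-uses the
            # caller's own class.
            return type(self)()

    class Child(Base):
        pass

    assert type(Child().clone()) is Child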
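The `IndexVariable.copy` hunk replaces a plain `self._data.copy(deep=deep)` with the same branch structure used in `_copy` above. The diff's comment states the intent: a copy must not share a `MemoryCachedArray`'s cache, so the underlying `.array` is re-wrapped in a fresh wrapper. A rough illustration of why re-wrapping matters — the cache class below is a hypothetical stand-in, not the real implementation:

    class CachingWrapper:
        """Hypothetical: materializes the wrapped array once, then caches it."""

        def __init__(self, array):
            self.array = array
            self._cache = None

        def load(self):
            if self._cache is None:
                self._cache = list(self.array)  # stand-in for an expensive read
            return self._cache

    original = CachingWrapper(range(3))
    original.load()

    # Re-wrapping .array (as the diff does) gives the copy its own empty
    # cache; copying the wrapper object itself would have shared _cache.
    duplicate = CachingWrapper(original.array)
    assert duplicate._cache is None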