Skip to content

Commit

Permalink
Fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jul 22, 2024
1 parent 2363eb8 commit 25bf152
Show file tree
Hide file tree
Showing 11 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No
for k2, v2 in attrs.items():
encoded_attrs[k2] = self.encode_attribute(v2)

if coding.strings.check_vlen_dtype(dtype) == str:
if coding.strings.check_vlen_dtype(dtype) is str:
dtype = str

if self._write_empty is not None:
Expand Down
2 changes: 1 addition & 1 deletion xarray/coding/cftimeindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def __contains__(self, key):
result = self.get_loc(key)
return (
is_scalar(result)
or type(result) == slice
or isinstance(result, slice)
or (isinstance(result, np.ndarray) and result.size)
)
except (KeyError, TypeError, ValueError):
Expand Down
6 changes: 3 additions & 3 deletions xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def check_vlen_dtype(dtype):


def is_unicode_dtype(dtype):
return dtype.kind == "U" or check_vlen_dtype(dtype) == str
return dtype.kind == "U" or check_vlen_dtype(dtype) is str


def is_bytes_dtype(dtype):
return dtype.kind == "S" or check_vlen_dtype(dtype) == bytes
return dtype.kind == "S" or check_vlen_dtype(dtype) is bytes


class EncodedStringCoder(VariableCoder):
Expand Down Expand Up @@ -104,7 +104,7 @@ def encode_string_array(string_array, encoding="utf-8"):

def ensure_fixed_length_bytes(var: Variable) -> Variable:
"""Ensure that a variable with vlen bytes is converted to fixed width."""
if check_vlen_dtype(var.dtype) == bytes:
if check_vlen_dtype(var.dtype) is bytes:
dims, data, attrs, encoding = unpack_for_encoding(var)
# TODO: figure out how to handle this with dask
data = np.asarray(data, dtype=np.bytes_)
Expand Down
2 changes: 1 addition & 1 deletion xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def encode(self):
raise NotImplementedError

def decode(self, variable: Variable, name: T_Name = None) -> Variable:
if variable.dtype == object and variable.encoding.get("dtype", False) == str:
if variable.dtype.kind == "O" and variable.encoding.get("dtype", False) is str:
variable = variable.astype(variable.encoding["dtype"])
return variable
else:
Expand Down
2 changes: 1 addition & 1 deletion xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def decode_cf_variable(
var = strings.CharacterArrayCoder().decode(var, name=name)
var = strings.EncodedStringCoder().decode(var)

if original_dtype == object:
if original_dtype.kind == "O":
var = variables.ObjectVLenStringCoder().decode(var)
original_dtype = var.dtype

Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,14 +863,14 @@ def test_roundtrip_empty_vlen_string_array(self) -> None:
# checks preserving vlen dtype for empty arrays GH7862
dtype = create_vlen_dtype(str)
original = Dataset({"a": np.array([], dtype=dtype)})
assert check_vlen_dtype(original["a"].dtype) == str
assert check_vlen_dtype(original["a"].dtype) is str
with self.roundtrip(original) as actual:
assert_identical(original, actual)
if np.issubdtype(actual["a"].dtype, object):
# only check metadata for capable backends
# eg. NETCDF3 based backends do not roundtrip metadata
if actual["a"].dtype.metadata is not None:
assert check_vlen_dtype(actual["a"].dtype) == str
assert check_vlen_dtype(actual["a"].dtype) is str
else:
assert actual["a"].dtype == np.dtype("<U1")

Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_coding_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@

def test_vlen_dtype() -> None:
dtype = strings.create_vlen_dtype(str)
assert dtype.metadata["element_type"] == str
assert dtype.metadata["element_type"] is str
assert strings.is_unicode_dtype(dtype)
assert not strings.is_bytes_dtype(dtype)
assert strings.check_vlen_dtype(dtype) is str

dtype = strings.create_vlen_dtype(bytes)
assert dtype.metadata["element_type"] == bytes
assert dtype.metadata["element_type"] is bytes
assert not strings.is_unicode_dtype(dtype)
assert strings.is_bytes_dtype(dtype)
assert strings.check_vlen_dtype(dtype) is bytes
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,10 +564,10 @@ def test_encode_cf_variable_with_vlen_dtype() -> None:
)
encoded_v = conventions.encode_cf_variable(v)
assert encoded_v.data.dtype.kind == "O"
assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str
assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) is str

# empty array
v = Variable(["x"], np.array([], dtype=coding.strings.create_vlen_dtype(str)))
encoded_v = conventions.encode_cf_variable(v)
assert encoded_v.data.dtype.kind == "O"
assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) == str
assert coding.strings.check_vlen_dtype(encoded_v.data.dtype) is str
2 changes: 1 addition & 1 deletion xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def test_methods(self):
ds = create_test_data()
dt: DataTree = DataTree(data=ds)
assert ds.mean().identical(dt.ds.mean())
assert type(dt.ds.mean()) == xr.Dataset
assert isinstance(dt.ds.mean(), xr.Dataset)

def test_arithmetic(self, create_test_datatree):
dt = create_test_datatree()
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def construct_dataarray(dim_num, dtype, contains_nan, dask):
array = rng.randint(0, 10, size=shapes).astype(dtype)
elif np.issubdtype(dtype, np.bool_):
array = rng.randint(0, 1, size=shapes).astype(dtype)
elif dtype == str:
elif dtype is str:
array = rng.choice(["a", "b", "c", "d"], size=shapes)
else:
raise ValueError
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def _assertIndexedLikeNDArray(self, variable, expected_value0, expected_dtype=No
# check type or dtype is consistent for both ndarray and Variable
if expected_dtype is None:
# check output type instead of array dtype
assert type(variable.values[0]) == type(expected_value0)
assert type(variable[0].values) == type(expected_value0)
assert type(variable.values[0]) is type(expected_value0)
assert type(variable[0].values) is type(expected_value0)
elif expected_dtype is not False:
assert variable.values[0].dtype == expected_dtype
assert variable[0].values.dtype == expected_dtype
Expand Down

0 comments on commit 25bf152

Please sign in to comment.