Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type check datatree tests #9632

Merged
merged 5 commits into from
Oct 16, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
type hints for datatree ops tests
TomNicholas committed Oct 16, 2024
commit ae81de3ba97fede5b384088293e1f73c67672f81
20 changes: 10 additions & 10 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
@@ -1769,7 +1769,7 @@ def test_subtree(self):


class TestOps:
def test_unary_op(self):
def test_unary_op(self) -> None:
ds1 = xr.Dataset({"a": [5], "b": [3]})
ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]})
dt = DataTree.from_dict({"/": ds1, "/subnode": ds2})
@@ -1779,7 +1779,7 @@ def test_unary_op(self):
result = -dt
assert_equal(result, expected)

def test_unary_op_inherited_coords(self):
def test_unary_op_inherited_coords(self) -> None:
tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]}))
tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])}))
actual = -tree
@@ -1792,7 +1792,7 @@ def test_unary_op_inherited_coords(self):
expected["/foo/bar"].data = np.array([-4, -5, -6])
assert_identical(actual, expected)

def test_binary_op_on_int(self):
def test_binary_op_on_int(self) -> None:
ds1 = xr.Dataset({"a": [5], "b": [3]})
ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]})
dt = DataTree.from_dict({"/": ds1, "/subnode": ds2})
@@ -1802,7 +1802,7 @@ def test_binary_op_on_int(self):
result = dt * 5
assert_equal(result, expected)

def test_binary_op_on_dataarray(self):
def test_binary_op_on_dataarray(self) -> None:
ds1 = xr.Dataset({"a": [5], "b": [3]})
ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]})
dt = DataTree.from_dict(
@@ -1824,7 +1824,7 @@ def test_binary_op_on_dataarray(self):
result = dt * other_da
assert_equal(result, expected)

def test_binary_op_on_dataset(self):
def test_binary_op_on_dataset(self) -> None:
ds1 = xr.Dataset({"a": [5], "b": [3]})
ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]})
dt = DataTree.from_dict(
@@ -1846,7 +1846,7 @@ def test_binary_op_on_dataset(self):
result = dt * other_ds
assert_equal(result, expected)

def test_binary_op_on_datatree(self):
def test_binary_op_on_datatree(self) -> None:
ds1 = xr.Dataset({"a": [5], "b": [3]})
ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]})

@@ -1857,7 +1857,7 @@ def test_binary_op_on_datatree(self):
result = dt * dt
assert_equal(result, expected)

def test_arithmetic_inherited_coords(self):
def test_arithmetic_inherited_coords(self) -> None:
tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]}))
tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])}))
actual = 2 * tree
@@ -1869,7 +1869,7 @@ def test_arithmetic_inherited_coords(self):
expected["/foo/bar"].data = np.array([8, 10, 12])
assert_identical(actual, expected)

def test_binary_op_commutativity_with_dataset(self):
def test_binary_op_commutativity_with_dataset(self) -> None:
# regression test for #9365

ds1 = xr.Dataset({"a": [5], "b": [3]})
@@ -1893,7 +1893,7 @@ def test_binary_op_commutativity_with_dataset(self):
result = other_ds * dt
assert_equal(result, expected)

def test_inplace_binary_op(self):
def test_inplace_binary_op(self) -> None:
ds1 = xr.Dataset({"a": [5], "b": [3]})
ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]})
dt = DataTree.from_dict({"/": ds1, "/subnode": ds2})
@@ -1903,7 +1903,7 @@ def test_inplace_binary_op(self):
dt += 1
assert_equal(dt, expected)

def test_dont_broadcast_single_node_tree(self):
def test_dont_broadcast_single_node_tree(self) -> None:
# regression test for https://github.com/pydata/xarray/issues/9365#issuecomment-2291622577
ds1 = xr.Dataset({"a": [5], "b": [3]})
ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]})