Skip to content

Commit

Permalink
Improve matmul/_contract coverage (nv-legate#908)
Browse files Browse the repository at this point in the history
  • Loading branch information
yimoj authored May 2, 2023
1 parent dd27be6 commit e9e7fe0
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions tests/integration/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,23 @@ def test_out_invalid_shape_DIVERGENCE(self):
out = num.zeros(shape)
num.matmul(A, B, out=out)

def test_out_invalid_dtype(self):
@pytest.mark.parametrize(
("dtype", "out_dtype", "casting"),
((None, np.int64, "same_kind"), (float, str, "safe")),
ids=("direct", "intermediate"),
)
def test_out_invalid_dtype(self, dtype, out_dtype, casting):
expected_exc = TypeError
A_np = num.ones((3, 2, 4))
B_np = num.ones((3, 4, 3))
A_np = np.ones((3, 2, 4))
B_np = np.ones((3, 4, 3))
A_num = num.ones((3, 2, 4))
B_num = num.ones((3, 4, 3))
dtype = np.int64
out_np = np.zeros((3, 2, 3), dtype=dtype)
out_num = num.zeros((3, 2, 3), dtype=dtype)
out_np = np.zeros((3, 2, 3), dtype=out_dtype)
out_num = num.zeros((3, 2, 3), dtype=out_dtype)
with pytest.raises(expected_exc):
np.matmul(A_np, B_np, out=out_np)
np.matmul(A_np, B_np, dtype=dtype, out=out_np, casting=casting)
with pytest.raises(expected_exc):
num.matmul(A_num, B_num, out=out_num)
num.matmul(A_num, B_num, dtype=dtype, out=out_num, casting=casting)

@pytest.mark.parametrize(
"casting_dtype",
Expand All @@ -183,18 +187,20 @@ def test_invalid_casting_dtype(self, casting_dtype):
with pytest.raises(expected_exc):
num.matmul(A_num, B_num, casting=casting, dtype=dtype)

@pytest.mark.xfail
def test_invalid_casting(self):
# In Numpy, raise ValueError
# In cuNumeric, pass
@pytest.mark.parametrize(
"dtype", (str, pytest.param(float, marks=pytest.mark.xfail)), ids=str
)
def test_invalid_casting(self, dtype):
expected_exc = ValueError
casting = "unknown"
A_np = np.ones((2, 4))
B_np = np.ones((4, 3))
B_np = np.ones((4, 3), dtype=dtype)
A_num = num.ones((2, 4))
B_num = num.ones((4, 3))
B_num = num.ones((4, 3), dtype=dtype)
# In Numpy, raise ValueError
with pytest.raises(expected_exc):
np.matmul(A_np, B_np, casting=casting)
# cuNumeric does not check casting when A and B are of the same dtype
with pytest.raises(expected_exc):
num.matmul(A_num, B_num, casting=casting)

Expand Down

0 comments on commit e9e7fe0

Please sign in to comment.