Skip to content

Commit

Permalink
Add an implementation of np.vdot to PytatoPyOpenCLArrayContext (induc…
Browse files Browse the repository at this point in the history
…er#299)

* Add an implementation of vdot to the PytatoPyOpenCLArrayContext np namespace.

* Remove the tests that are just skipped for scalars.

* Respond to comments.

* Ruff version needed to be updated locally.
  • Loading branch information
nkoskelo authored Jan 9, 2025
1 parent a88f08d commit bc7139f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
3 changes: 3 additions & 0 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,7 @@ def amin(self, a, axis=None):
def absolute(self, a):
return self.abs(a)

def vdot(self, a: Array, b: Array):

return rec_multimap_array_container(pt.vdot, a, b)
# }}}
9 changes: 4 additions & 5 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,10 @@ def evaluate(np_, *args_):

assert_close_to_numpy_in_containers(actx, evaluate, args)

if sym_name in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]:
pytest.skip(f"'{sym_name}' not supported on scalars")

args = [randn(0, dtype)[()] for i in range(n_args)]
assert_close_to_numpy(actx, evaluate, args)
if sym_name not in ["where", "min", "max", "any", "all", "conj", "vdot", "sum"]:
# Scalar arguments are supported.
args = [randn(0, dtype)[()] for i in range(n_args)]
assert_close_to_numpy(actx, evaluate, args)


@pytest.mark.parametrize(("sym_name", "n_args", "dtype"), [
Expand Down
4 changes: 2 additions & 2 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
THE SOFTWARE.
"""
import logging
from typing import Optional, cast
from typing import cast

import numpy as np
import pytest
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_dataclass_array_container() -> None:
class ArrayContainerWithOptional:
x: np.ndarray
# Deliberately left as Optional to test compatibility.
y: Optional[np.ndarray] # noqa: UP007
y: np.ndarray | None

with pytest.raises(TypeError, match="Field 'y' union contains non-array"):
# NOTE: cannot have wrapped annotations (here by `Optional`)
Expand Down

0 comments on commit bc7139f

Please sign in to comment.