Skip to content

Commit

Permalink
Merge pull request #40 from tylerjereddy/treddy_to_device_cpu_cupy
Browse files Browse the repository at this point in the history
ENH: support CuPy to_device "cpu"
  • Loading branch information
asmeurer authored Apr 28, 2023
2 parents 9ef7f72 + b9ceea9 commit 2ec609d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
5 changes: 5 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ def _cupy_to_device(x, device, /, stream=None):

if device == x.device:
return x
elif device == "cpu":
# allowing us to use `to_device(x, "cpu")`
# is useful for portable test swapping between
# host and device backends
return x.get()
elif not isinstance(device, _Device):
raise ValueError(f"Unsupported device {device!r}")
else:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from ._helpers import import_
from array_api_compat import to_device, device

import pytest
import numpy as np
from numpy.testing import assert_allclose

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
def test_to_device_host(library):
# different libraries have different semantics
# for DtoH transfers; ensure that we support a portable
# shim for common array libs
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
xp = import_('array_api_compat.' + library)
expected = np.array([1, 2, 3])
x = xp.asarray([1, 2, 3])
x = to_device(x, "cpu")
# torch will return a genuine Device object, but
# the other libs will do something different with
# a `device(x)` query; however, what's really important
# here is that we can test portably after calling
# to_device(x, "cpu") to return to host
assert_allclose(x, expected)

0 comments on commit 2ec609d

Please sign in to comment.