Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

split and where translations #6

Merged
merged 2 commits into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,33 @@ def numpy_like(array, pad_width, mode="constant", constant_values=0):
return numpy_like


def tensorflow_where_wrap(fn):
    """Wrap ``tf.where`` to match numpy semantics for the single-argument
    form: numpy returns a tuple of per-axis index arrays, while tensorflow
    returns a single ``(n, ndim)`` tensor of coordinates.
    """

    @functools.wraps(fn)
    def numpy_like(condition, x=None, y=None, **kwargs):
        if x is None and y is None:
            # index-finding form: transpose the (n, ndim) coordinate tensor
            # to (ndim, n) and unpack into a numpy-style tuple of arrays
            return tuple(transpose(fn(condition, **kwargs)))
        # element-selection form: tf.where(cond, x, y) already matches
        # numpy.where(cond, x, y) -> pass through untouched
        return fn(condition, x, y, **kwargs)

    return numpy_like


def tensorflow_split_wrap(fn):
    """Wrap ``tf.split`` to accept numpy-style ``indices_or_sections``.

    ``tf.split`` takes either a number of equal splits (same as numpy) or a
    list of *sizes*, whereas numpy takes a list of split *indices* -> convert
    indices to sizes with a ``diff``.
    """

    @functools.wraps(fn)
    def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
        if isinstance(indices_or_sections, int):
            # number of equal sections: tf and numpy agree
            return fn(ary, indices_or_sections, axis=axis, **kwargs)
        # convert split indices -> section sizes, e.g. for length 20,
        # [2, 4, 14] -> [2, 2, 10, 6]
        diff = do(
            "diff",
            indices_or_sections,
            prepend=0,
            append=ary.shape[axis],
            like="numpy",
        )
        # forward kwargs here too, for consistency with the int branch
        return fn(ary, list(diff), axis=axis, **kwargs)

    return numpy_like


_FUNCS["tensorflow", "to_numpy"] = tensorflow_to_numpy

_SUBMODULE_ALIASES["tensorflow", "log"] = "tensorflow.math"
Expand Down Expand Up @@ -839,6 +866,8 @@ def numpy_like(array, pad_width, mode="constant", constant_values=0):
_CUSTOM_WRAPPERS["tensorflow", "tril"] = tril_to_band_part
_CUSTOM_WRAPPERS["tensorflow", "triu"] = triu_to_band_part
_CUSTOM_WRAPPERS["tensorflow", "pad"] = tensorflow_pad_wrap
_CUSTOM_WRAPPERS["tensorflow", "where"] = tensorflow_where_wrap
_CUSTOM_WRAPPERS["tensorflow", "split"] = tensorflow_split_wrap
_CUSTOM_WRAPPERS["tensorflow", "random.uniform"] = make_translator(
[
("low", ("minval", 0.0)),
Expand Down Expand Up @@ -951,6 +980,27 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0):
)


def torch_split_wrap(fn):
    """Wrap ``torch.split`` to accept numpy-style ``indices_or_sections``.

    ``torch.split`` interprets an int as the *size* of each chunk, whereas
    ``numpy.split`` interprets it as the *number* of equal sections, and
    torch takes a list of section sizes where numpy takes split indices.
    """
    # for torch >=1.8 we can use tensor_split (which matches numpy directly)
    # instead, but in the current stable release this function has not been
    # added

    @functools.wraps(fn)
    def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
        if isinstance(indices_or_sections, int):
            # numpy: number of equal sections -> torch: size of each chunk
            # (numpy requires even divisibility for the int form anyway)
            split_size = ary.shape[axis] // indices_or_sections
            return fn(ary, split_size, dim=axis, **kwargs)
        # convert split indices -> section sizes, e.g. for length 20,
        # [2, 4, 14] -> [2, 2, 10, 6]
        diff = do(
            "diff",
            indices_or_sections,
            prepend=0,
            append=ary.shape[axis],
            like="numpy",
        )
        return fn(ary, list(diff), dim=axis, **kwargs)

    return numpy_like


_FUNCS["torch", "pad"] = torch_pad
_FUNCS["torch", "real"] = torch_real
_FUNCS["torch", "imag"] = torch_imag
Expand All @@ -975,6 +1025,7 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0):

_SUBMODULE_ALIASES["torch", "linalg.qr"] = "torch"
_SUBMODULE_ALIASES["torch", "linalg.svd"] = "torch"
_SUBMODULE_ALIASES["torch", "linalg.norm"] = "torch"
_SUBMODULE_ALIASES["torch", "linalg.expm"] = "torch"
_SUBMODULE_ALIASES["torch", "random.normal"] = "torch"
_SUBMODULE_ALIASES["torch", "random.uniform"] = "torch"
Expand All @@ -983,6 +1034,7 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0):
_CUSTOM_WRAPPERS["torch", "linalg.qr"] = qr_allow_fat
_CUSTOM_WRAPPERS["torch", "random.normal"] = scale_random_normal_manually
_CUSTOM_WRAPPERS["torch", "random.uniform"] = scale_random_uniform_manually
_CUSTOM_WRAPPERS["torch", "split"] = torch_split_wrap
_CUSTOM_WRAPPERS["torch", "stack"] = make_translator(
[
("arrays", ("tensors",)),
Expand Down
28 changes: 28 additions & 0 deletions tests/test_autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,3 +545,31 @@ def test_einsum(backend):
== ar.infer_backend(C4)
== backend
)


@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.parametrize("sections", [4, [2, 4, 14]])
def test_split(backend, sections):
    """Check numpy-style ``split`` translation for both the int (number of
    equal sections) and list-of-indices forms."""
    if backend == "sparse":
        pytest.xfail("sparse doesn't support split yet")
    if backend == "dask":
        pytest.xfail("dask doesn't support split yet")
    A = ar.do("ones", (10, 20, 10), like=backend)
    splits = ar.do("split", A, sections, axis=1)
    assert len(splits) == 4
    if isinstance(sections, int):
        # 20 / 4 equal sections of width 5
        assert splits[3].shape == (10, 5, 10)
    else:
        # indices [2, 4, 14] -> widths [2, 2, 10, 6]
        assert splits[3].shape == (10, 6, 10)


@pytest.mark.parametrize("backend", BACKENDS)
def test_where(backend):
    """Check that single-argument ``where`` returns numpy-style index
    arrays for each backend."""
    if backend == "sparse":
        pytest.xfail("sparse doesn't support where yet")
    a = ar.do("arange", 10, like=backend)
    b = ar.do("arange", 10, like=backend) + 1
    stacked = ar.do("stack", [a, b])
    indices = ar.do("where", stacked < 5)
    for ix in indices:
        if backend == "dask":
            # dask index arrays have unknown chunk sizes until computed
            ix.compute_chunk_sizes()
        # 5 matches in the first row, 4 in the second -> 9 total
        assert ar.to_numpy(ix).shape == (9,)