Merge pull request #6 from RikVoorhaar/master
split and where translations
jcmgray authored Feb 18, 2021
2 parents 5fdcbdf + adf2221 commit 6f14268
Showing 2 changed files with 87 additions and 0 deletions.
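
The diff below adds numpy-style `split` and `where` translations for the tensorflow and torch backends, plus tests for both. A minimal usage sketch of what this enables (illustrative only; it mirrors the new tests and assumes tensorflow is installed):

import autoray as ar

A = ar.do("ones", (10, 20, 10), like="tensorflow")
pieces = ar.do("split", A, [2, 4, 14], axis=1)   # numpy-style cut indices
mask = ar.do("arange", 10, like="tensorflow") < 5
coords = ar.do("where", mask)                    # tuple of index arrays, numpy-style
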
53 changes: 53 additions & 0 deletions autoray/autoray.py
@@ -805,6 +805,33 @@ def numpy_like(array, pad_width, mode="constant", constant_values=0):
    return numpy_like


def tensorflow_where_wrap(fn):
    @functools.wraps(fn)
    def numpy_like(condition, x=None, y=None, **kwargs):
        return tuple(transpose(fn(condition, x, y, **kwargs)))

    return numpy_like
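
For context on the `where` wrapper above: called with only a condition, `tf.where` returns a single (n_true, ndim) tensor of coordinates, whereas `np.where` returns a tuple of ndim 1-D index arrays, hence the transpose-and-tuple. A minimal sketch of the mismatch (illustrative only; assumes numpy and tensorflow are installed):

import numpy as np
import tensorflow as tf

cond = np.array([[True, False], [False, True]])
np_idx = np.where(cond)                       # (array([0, 1]), array([0, 1]))
tf_idx = tuple(tf.transpose(tf.where(cond)))  # two length-2 tensors, matching numpy
assert all((a == b.numpy()).all() for a, b in zip(np_idx, tf_idx))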


def tensorflow_split_wrap(fn):
    @functools.wraps(fn)
    def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
        if isinstance(indices_or_sections, int):
            return fn(ary, indices_or_sections, axis=axis, **kwargs)
        else:
            diff = do(
                "diff",
                indices_or_sections,
                prepend=0,
                append=ary.shape[axis],
                like="numpy",
            )
            diff = list(diff)
            return fn(ary, diff, axis=axis)

    return numpy_like
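
Similarly for `split`: `np.split` accepts either a number of equal sections or a list of cut indices, while `tf.split` expects section sizes, which is what the `diff` call above computes. A worked example of the conversion (illustrative only):

import numpy as np

indices = [2, 4, 14]                            # numpy-style cut points, axis of length 20
sizes = np.diff(indices, prepend=0, append=20)  # section sizes for tf.split
print(sizes.tolist())                           # [2, 2, 10, 6]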


_FUNCS["tensorflow", "to_numpy"] = tensorflow_to_numpy

_SUBMODULE_ALIASES["tensorflow", "log"] = "tensorflow.math"
@@ -839,6 +866,8 @@ def numpy_like(array, pad_width, mode="constant", constant_values=0):
_CUSTOM_WRAPPERS["tensorflow", "tril"] = tril_to_band_part
_CUSTOM_WRAPPERS["tensorflow", "triu"] = triu_to_band_part
_CUSTOM_WRAPPERS["tensorflow", "pad"] = tensorflow_pad_wrap
_CUSTOM_WRAPPERS["tensorflow", "where"] = tensorflow_where_wrap
_CUSTOM_WRAPPERS["tensorflow", "split"] = tensorflow_split_wrap
_CUSTOM_WRAPPERS["tensorflow", "random.uniform"] = make_translator(
    [
        ("low", ("minval", 0.0)),
@@ -951,6 +980,28 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0):
)


def torch_split_wrap(fn):
    # for torch >= 1.8 we could use tensor_split instead, but it has not yet
    # been added in the current stable release
    @functools.wraps(fn)
    def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
        if isinstance(indices_or_sections, int):
            split_size = ary.shape[axis] // indices_or_sections
            return fn(ary, split_size, dim=axis, **kwargs)
        else:
            diff = do(
                "diff",
                indices_or_sections,
                prepend=0,
                append=ary.shape[axis],
                like="numpy",
            )
            diff = list(diff)
            return fn(ary, diff, dim=axis)

    return numpy_like
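
For the torch version: `torch.split` takes a chunk size (or a list of sizes) rather than a number of sections, hence the integer division above; the index case is converted with the same `diff` trick. A minimal sketch (illustrative only; assumes torch is installed):

import torch

t = torch.ones(10, 20)
chunks = torch.split(t, 20 // 5, dim=1)   # behaves like np.split(t, 5, axis=1)
assert len(chunks) == 5 and chunks[0].shape == (10, 4)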


_FUNCS["torch", "pad"] = torch_pad
_FUNCS["torch", "real"] = torch_real
_FUNCS["torch", "imag"] = torch_imag
@@ -975,6 +1026,7 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0):

_SUBMODULE_ALIASES["torch", "linalg.qr"] = "torch"
_SUBMODULE_ALIASES["torch", "linalg.svd"] = "torch"
_SUBMODULE_ALIASES["torch", "linalg.norm"] = "torch"
_SUBMODULE_ALIASES["torch", "linalg.expm"] = "torch"
_SUBMODULE_ALIASES["torch", "random.normal"] = "torch"
_SUBMODULE_ALIASES["torch", "random.uniform"] = "torch"
@@ -983,6 +1035,7 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0):
_CUSTOM_WRAPPERS["torch", "linalg.qr"] = qr_allow_fat
_CUSTOM_WRAPPERS["torch", "random.normal"] = scale_random_normal_manually
_CUSTOM_WRAPPERS["torch", "random.uniform"] = scale_random_uniform_manually
_CUSTOM_WRAPPERS["torch", "split"] = torch_split_wrap
_CUSTOM_WRAPPERS["torch", "stack"] = make_translator(
    [
        ("arrays", ("tensors",)),
34 changes: 34 additions & 0 deletions tests/test_autoray.py
@@ -545,3 +545,37 @@ def test_einsum(backend):
        == ar.infer_backend(C4)
        == backend
    )


@pytest.mark.parametrize("backend", BACKENDS)
@pytest.mark.parametrize("int_or_section", ["int", "section"])
def test_split(backend, int_or_section):
if backend == "sparse":
pytest.xfail("sparse doesn't support split yet")
if backend == "dask":
pytest.xfail("dask doesn't support split yet")
A = ar.do("ones", (10, 20, 10), like=backend)
if int_or_section == 'section':
sections = [2, 4, 14]
splits = ar.do("split", A, sections, axis=1)
assert len(splits) == 4
assert splits[3].shape == (10, 6, 10)
else:
splits = ar.do("split", A, 5, axis=2)
assert len(splits) == 5
assert splits[2].shape == (10, 20, 2)


@pytest.mark.parametrize("backend", BACKENDS)
def test_where(backend):
if backend == "sparse":
pytest.xfail("sparse doesn't support where yet")
A = ar.do("arange", 10, like=backend)
B = ar.do("arange", 10, like=backend) + 1
C = ar.do("stack", [A, B])
D = ar.do("where", C < 5)
if backend == "dask":
for x in D:
x.compute_chunk_sizes()
for x in D:
assert ar.to_numpy(x).shape == (9,)
