diff --git a/autoray/autoray.py b/autoray/autoray.py index 886c158..9c7855d 100644 --- a/autoray/autoray.py +++ b/autoray/autoray.py @@ -805,6 +805,33 @@ def numpy_like(array, pad_width, mode="constant", constant_values=0): return numpy_like +def tensorflow_where_wrap(fn): + @functools.wraps(fn) + def numpy_like(condition, x=None, y=None, **kwargs): + return tuple(transpose(fn(condition, x, y, **kwargs))) + + return numpy_like + + +def tensorflow_split_wrap(fn): + @functools.wraps(fn) + def numpy_like(ary, indices_or_sections, axis=0, **kwargs): + if isinstance(indices_or_sections, int): + return fn(ary, indices_or_sections, axis=axis, **kwargs) + else: + diff = do( + "diff", + indices_or_sections, + prepend=0, + append=ary.shape[axis], + like="numpy", + ) + diff = list(diff) + return fn(ary, diff, axis=axis) + + return numpy_like + + _FUNCS["tensorflow", "to_numpy"] = tensorflow_to_numpy _SUBMODULE_ALIASES["tensorflow", "log"] = "tensorflow.math" @@ -839,6 +866,8 @@ def numpy_like(array, pad_width, mode="constant", constant_values=0): _CUSTOM_WRAPPERS["tensorflow", "tril"] = tril_to_band_part _CUSTOM_WRAPPERS["tensorflow", "triu"] = triu_to_band_part _CUSTOM_WRAPPERS["tensorflow", "pad"] = tensorflow_pad_wrap +_CUSTOM_WRAPPERS["tensorflow", "where"] = tensorflow_where_wrap +_CUSTOM_WRAPPERS["tensorflow", "split"] = tensorflow_split_wrap _CUSTOM_WRAPPERS["tensorflow", "random.uniform"] = make_translator( [ ("low", ("minval", 0.0)), @@ -951,6 +980,28 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0): ) +def torch_split_wrap(fn): + # for torch >=1.8 we can use tensor_split instead, but in current stable + # release this function has not been added + @functools.wraps(fn) + def numpy_like(ary, indices_or_sections, axis=0, **kwargs): + if isinstance(indices_or_sections, int): + split_size = ary.shape[axis] // indices_or_sections + return fn(ary, split_size, dim=axis, **kwargs) + else: + diff = do( + "diff", + indices_or_sections, + prepend=0, + append=ary.shape[axis], + like="numpy", + ) + diff = list(diff) + return fn(ary, diff, dim=axis) + + return numpy_like + + _FUNCS["torch", "pad"] = torch_pad _FUNCS["torch", "real"] = torch_real _FUNCS["torch", "imag"] = torch_imag @@ -975,6 +1026,7 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0): _SUBMODULE_ALIASES["torch", "linalg.qr"] = "torch" _SUBMODULE_ALIASES["torch", "linalg.svd"] = "torch" +_SUBMODULE_ALIASES["torch", "linalg.norm"] = "torch" _SUBMODULE_ALIASES["torch", "linalg.expm"] = "torch" _SUBMODULE_ALIASES["torch", "random.normal"] = "torch" _SUBMODULE_ALIASES["torch", "random.uniform"] = "torch" @@ -983,6 +1035,7 @@ def torch_pad(array, pad_width, mode="constant", constant_values=0): _CUSTOM_WRAPPERS["torch", "linalg.qr"] = qr_allow_fat _CUSTOM_WRAPPERS["torch", "random.normal"] = scale_random_normal_manually _CUSTOM_WRAPPERS["torch", "random.uniform"] = scale_random_uniform_manually +_CUSTOM_WRAPPERS["torch", "split"] = torch_split_wrap _CUSTOM_WRAPPERS["torch", "stack"] = make_translator( [ ("arrays", ("tensors",)), diff --git a/tests/test_autoray.py b/tests/test_autoray.py index aea9cd6..a5475d0 100644 --- a/tests/test_autoray.py +++ b/tests/test_autoray.py @@ -545,3 +545,37 @@ def test_einsum(backend): == ar.infer_backend(C4) == backend ) + + +@pytest.mark.parametrize("backend", BACKENDS) +@pytest.mark.parametrize("int_or_section", ["int", "section"]) +def test_split(backend, int_or_section): + if backend == "sparse": + pytest.xfail("sparse doesn't support split yet") + if backend == "dask": + pytest.xfail("dask doesn't support split yet") + A = ar.do("ones", (10, 20, 10), like=backend) + if int_or_section == 'section': + sections = [2, 4, 14] + splits = ar.do("split", A, sections, axis=1) + assert len(splits) == 4 + assert splits[3].shape == (10, 6, 10) + else: + splits = ar.do("split", A, 5, axis=2) + assert len(splits) == 5 + assert splits[2].shape == (10, 20, 2) + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_where(backend): + if backend == "sparse": + pytest.xfail("sparse doesn't support where yet") + A = ar.do("arange", 10, like=backend) + B = ar.do("arange", 10, like=backend) + 1 + C = ar.do("stack", [A, B]) + D = ar.do("where", C < 5) + if backend == "dask": + for x in D: + x.compute_chunk_sizes() + for x in D: + assert ar.to_numpy(x).shape == (9,)