Skip to content

Commit

Permalink
test for int case and torch_split_wrap fix
Browse files Browse the repository at this point in the history
  • Loading branch information
RikVoorhaar committed Feb 18, 2021
1 parent 062df87 commit adf2221
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
3 changes: 2 additions & 1 deletion autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,8 @@ def torch_split_wrap(fn):
@functools.wraps(fn)
def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
if isinstance(indices_or_sections, int):
return fn(ary, indices_or_sections, dim=axis, **kwargs)
split_size = ary.shape[axis] // indices_or_sections
return fn(ary, split_size, dim=axis, **kwargs)
else:
diff = do(
"diff",
Expand Down
16 changes: 11 additions & 5 deletions tests/test_autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,16 +548,22 @@ def test_einsum(backend):


@pytest.mark.parametrize("backend", BACKENDS)
def test_split(backend):
@pytest.mark.parametrize("int_or_section", ["int", "section"])
def test_split(backend, int_or_section):
if backend == "sparse":
pytest.xfail("sparse doesn't support split yet")
if backend == "dask":
pytest.xfail("dask doesn't support split yet")
A = ar.do("ones", (10, 20, 10), like=backend)
sections = [2, 4, 14]
splits = ar.do("split", A, sections, axis=1)
assert len(splits) == 4
assert splits[3].shape == (10, 6, 10)
if int_or_section == 'section':
sections = [2, 4, 14]
splits = ar.do("split", A, sections, axis=1)
assert len(splits) == 4
assert splits[3].shape == (10, 6, 10)
else:
splits = ar.do("split", A, 5, axis=2)
assert len(splits) == 5
assert splits[2].shape == (10, 20, 2)


@pytest.mark.parametrize("backend", BACKENDS)
Expand Down

0 comments on commit adf2221

Please sign in to comment.