Skip to content

Commit

Permalink
Protect combine() from mixing 1/N-dimensional arguments.
Browse files Browse the repository at this point in the history
  • Loading branch information
LTLA committed Nov 8, 2023
1 parent 3186a18 commit 804458c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/biocutils/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@ def combine(*x: Any):
Returns:
A combined object, typically the same type as the first element in ``x``.
"""
if hasattr(x[0], "shape") and len(x[0].shape) > 1:
has_1d = False
has_nd = False
for y in x:
if hasattr(y, "shape") and len(y.shape) > 1:
has_nd = True
else:
has_1d = True

if has_nd and has_1d:
raise ValueError("cannot mix 1-dimensional and higher-dimensional objects in `combine`")
if has_nd:
return combine_rows(*x)
else:
return combine_sequences(*x)
4 changes: 4 additions & 0 deletions tests/test_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ def test_basic_mixed_dense_array():
y = np.array([4, 5, 6, 7]).reshape((2,2))
zcomb = combine(x, y)
assert zcomb.shape == (4, 2)

with pytest.raises(ValueError) as ex:
combine(x, [1,2,3,4])
assert str(ex.value).find("cannot mix") >= 0

0 comments on commit 804458c

Please sign in to comment.