Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes #3870: bug in reshape for bigint type #3907

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 45 additions & 9 deletions src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module AryUtil
use List;
use CommAggregation;
use CommPrimitives;
use BigInteger;


param bitsPerDigit = RSLSD_bitsPerDigit;
Expand Down Expand Up @@ -905,7 +906,8 @@ module AryUtil
/*
unflatten a 1D array into a multi-dimensional array of the given shape
*/
proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws {
proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws
where t!=bigint {
var unflat = makeDistArray((...shape), t);

if N == 1 {
Expand Down Expand Up @@ -952,7 +954,6 @@ module AryUtil
// flat region is spread across multiple locales, do a get for each source locale
for locInID in locInStart..locInStop {
const flatSubSlice = flatSlice[flatLocRanges[locInID]];

get(
c_ptrTo(unflat[dufc.orderToIndex(flatSubSlice.low)]),
getAddr(a[flatSubSlice.low]),
Expand All @@ -967,11 +968,30 @@ module AryUtil
return unflat;
}

proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws
where t==bigint {
var unflat = makeDistArray((...shape), t);

if N == 1 {
unflat = a;
return unflat;
}

coforall loc in Locales with (ref unflat) do on loc {
forall idx in a.localSubdomain() with (var agg = newDstAggregator(t)) {
agg.copy(unflat[unflat.domain.orderToIndex(idx)], a[idx]);
}
}

return unflat;
}

/*
flatten a multi-dimensional array into a 1D array
*/
@arkouda.registerCommand
proc flatten(const ref a: [?d] ?t): [] t throws {
@arkouda.registerCommand(ignoreWhereClause=true)
proc flatten(const ref a: [?d] ?t): [] t throws
where t!=bigint {
if a.rank == 1 then return a;

var flat = makeDistArray(d.size, t);
Expand Down Expand Up @@ -1030,6 +1050,22 @@ module AryUtil
return flat;
}


proc flatten(const ref a: [?d] ?t): [] t throws
where t==bigint {
if a.rank == 1 then return a;

var flat = makeDistArray(d.size, t);

coforall loc in Locales with (ref flat) do on loc {
forall idx in flat.localSubdomain() with (var agg = newSrcAggregator(t)) {
agg.copy(flat[idx], a[a.domain.orderToIndex(idx)]);
}
}

return flat;
}

// helper for computing an array element's index from its order
record orderer {
param rank: int;
Expand All @@ -1044,10 +1080,10 @@ module AryUtil
// index -> order for the input array's indices
// e.g., order = k + (nz * j) + (nz * ny * i)
inline proc indexToOrder(idx: rank*?t): t
where (t==int) || (t==uint(64)) {
var order : t = 0;
for param i in 0..<rank do order += idx[i] * accumRankSizes[rank - i - 1];
return order;
}
where (t==int) || (t==uint(64)) {
jaketrookman marked this conversation as resolved.
Show resolved Hide resolved
var order : t = 0;
for param i in 0..<rank do order += idx[i] * accumRankSizes[rank - i - 1];
return order;
}
}
}
39 changes: 26 additions & 13 deletions tests/pdarrayclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,47 @@
class TestPdarrayClass:

@pytest.mark.skip_if_max_rank_less_than(2)
def test_reshape(self):
a = ak.arange(4)
@pytest.mark.parametrize("dtype", DTYPES)
def test_reshape(self, dtype):
a = ak.arange(4, dtype=dtype)
r = a.reshape((2, 2))
assert r.shape == (2, 2)
assert isinstance(r, ak.pdarray)

def test_shape(self):
a = ak.arange(4)
@pytest.mark.skip_if_max_rank_less_than(3)
def test_reshape_and_flatten_bug_reproducer(self):
dtype = "bigint"
size = 10
x = ak.arange(size, dtype=dtype).reshape((1, size, 1))
ak_assert_equal(x.flatten(), ak.arange(size, dtype=dtype))

@pytest.mark.parametrize("dtype", DTYPES)
def test_shape(self,dtype):
a = ak.arange(4,dtype=dtype)
np_a = np.arange(4)
assert isinstance(a.shape, tuple)
assert a.shape == np_a.shape

@pytest.mark.skip_if_max_rank_less_than(2)
def test_shape_multidim(self):
a = ak.arange(4).reshape((2, 2))
np_a = np.arange(4).reshape((2, 2))
@pytest.mark.parametrize("dtype", list(set(DTYPES) - set(["bool"])))
def test_shape_multidim(self,dtype):
a = ak.arange(4,dtype=dtype).reshape((2, 2))
np_a = np.arange(4,dtype=dtype).reshape((2, 2))
assert isinstance(a.shape, tuple)
assert a.shape == np_a.shape

@pytest.mark.parametrize("size", pytest.prob_size)
def test_flatten(self, size):
a = ak.arange(size)
@pytest.mark.parametrize("dtype", DTYPES)
def test_flatten(self, size,dtype):
a = ak.arange(size,dtype=dtype)
ak_assert_equal(a.flatten(), a)

@pytest.mark.skip_if_max_rank_less_than(3)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("size", pytest.prob_size)
def test_flatten(self, size):
def test_flatten(self, size,dtype):
size = size - (size % 4)
a = ak.arange(size)
a = ak.arange(size,dtype=dtype)
b = a.reshape((2, 2, size / 4))
ak_assert_equal(b.flatten(), a)

Expand Down Expand Up @@ -104,11 +116,12 @@ def test_is_locally_sorted(self, size, dtype, axis):
@pytest.mark.skip_if_nl_greater_than(2)
@pytest.mark.skip_if_nl_less_than(2)
@pytest.mark.parametrize("size", pytest.prob_size)
def test_is_locally_sorted_multi_locale(self, size):
@pytest.mark.parametrize("dtype", DTYPES)
def test_is_locally_sorted_multi_locale(self, size,dtype):
from arkouda.pdarrayclass import is_locally_sorted, is_sorted

size = size // 2
a = ak.concatenate([ak.arange(size), ak.arange(size)])
a = ak.concatenate([ak.arange(size,dtype=dtype), ak.arange(size,dtype=dtype)])
assert is_locally_sorted(a)
assert not is_sorted(a)

Expand Down
Loading