Skip to content

Commit

Permalink
Implement Unstack (#920)
Browse files Browse the repository at this point in the history
* Implement unstack

* Remove code does not relate to this PR

* Remove for loop on output dim; add Unstack ragged

* Add more docs

* Fix comments

* Fix docs & unit tests
  • Loading branch information
pkufool authored Feb 20, 2022
1 parent f4fefe4 commit 56edc82
Show file tree
Hide file tree
Showing 4 changed files with 767 additions and 37 deletions.
328 changes: 320 additions & 8 deletions k2/csrc/ragged_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,18 @@ RaggedShape RaggedShapeFromTotSizes(ContextPtr c, int32_t num_axes,
NVTX_RANGE(K2_FUNC);
K2_CHECK_GE(num_axes, 2);
std::vector<RaggedShapeLayer> axes(num_axes - 1);
// In future we might choose to allocate everything in one big array, to avoid
// multiple allocations, but for now just do it the simple way.
int32_t tot_size = 0;
for (int32_t axis = 1; axis < num_axes; ++axis) {
axes[axis - 1].row_splits = Array1<int32_t>(c, tot_sizes[axis - 1] + 1);
axes[axis - 1].row_ids = Array1<int32_t>(c, tot_sizes[axis]);
tot_size += tot_sizes[axis - 1] + 1 + tot_sizes[axis];
}
Array1<int32_t> buf(c, tot_size);
int32_t start = 0;
for (int32_t axis = 1; axis < num_axes; ++axis) {
axes[axis - 1].row_splits = buf.Arange(start,
start + tot_sizes[axis - 1] + 1);
start += tot_sizes[axis - 1] + 1;
axes[axis - 1].row_ids = buf.Arange(start, start + tot_sizes[axis]);
start += tot_sizes[axis];
axes[axis - 1].cached_tot_size = tot_sizes[axis];
}
// Not check here as we did not set the values of row_splits and row_ids
Expand Down Expand Up @@ -418,7 +425,6 @@ static RaggedShape IndexAxis0(RaggedShape &src, const Array1<int32_t> &new2old,
Array1<int32_t> *elem_indexes /*=nullptr*/) {
NVTX_RANGE(K2_FUNC);
ContextPtr &c = src.Context();
bool is_cpu = (c->GetDeviceType() == kCpu);
K2_CHECK(IsCompatible(src, new2old));
int32_t num_axes = src.NumAxes(), src_dim0 = src.Dim0(),
ans_dim0 = new2old.Dim();
Expand Down Expand Up @@ -470,7 +476,6 @@ static RaggedShape IndexAxis0(RaggedShape &src, const Array1<int32_t> &new2old,
tot_sizes.data[i]);
}


int32_t *elem_indexes_data = (elem_indexes != nullptr ?
elem_indexes->Data() : nullptr);

Expand Down Expand Up @@ -575,6 +580,7 @@ RaggedShape Index(RaggedShape &src, int32_t axis,
Array1<int32_t> last_row_splits(last_row_ids.Context(),
src.TotSize(num_axes - 2) + 1);
RowIdsToRowSplits(last_row_ids, &last_row_splits);

if (elem_indexes)
*elem_indexes = indexes;

Expand Down Expand Up @@ -689,7 +695,6 @@ static RaggedShape StackAxis0(int32_t num_srcs, RaggedShape **src,
int32_t num_axes_in = src[0]->NumAxes(),
num_axes_out = num_axes_in + 1;
ContextPtr c = src[0]->Context();
bool is_cpu = (c->GetDeviceType() == kCpu);

// Check if they have same num-axes and compatible context
for (int32_t i = 1; i < num_srcs; ++i) {
Expand All @@ -702,7 +707,6 @@ static RaggedShape StackAxis0(int32_t num_srcs, RaggedShape **src,
Array2<int32_t> offsets = GetOffsets(num_srcs, src);
auto offsets_acc = offsets.Accessor();


SmallVec<int32_t, 6> tot_sizes_out;
K2_CHECK(num_axes_out <= 6);
int32_t max_tot_size = 0;
Expand Down Expand Up @@ -1100,6 +1104,314 @@ RaggedShape Stack(int32_t axis, int32_t num_srcs, RaggedShape **src,
return RaggedShape(ans_layers);
}

/*
Select ragged tensor's shape on axis 0 with a two axes ragged index.
@param [in] src Source RaggedShape to select.
@param [in] indexes A **TWO** axes ragged tensor containing the indexes
into the axis 0 of src. we also support -1 as an index,
which will result in the empty list (as if it were the
index into a position in `src` that had an empty list)
i.e. with `-1 <= indexes[i] < src.TotSize(0)`.
@param [out] out The container where the output RaggedShape will write to,
MUST NOT be a nullptr. Will be reallocated and the final
size of `out` would equal to `indexes.TotSize(0)`.
Note, The `NumAxes()` of output RaggedShape is the same
as the `NumAxes()` of src.
@param [out] split_map If not nullptr will store the element-index within
src telling where the elements of split RaggedShape
come from. Will be reallocated and the final size of
`split_map` would equal to `indexes.TotSize(0)`.
Suppose indexes is `[ [ 0 3 5 ] [ 1 2 4] [ 6 -1 ] ]`, it means that we will
select elements 0,3,5 of src's axis 0 to construct the first output
RaggedShape, 1,2,4 to construct the second output RaggedShape, 6 and a empty
list to construct the third output RaggedShape.
*/
static void SelectAxis0(RaggedShape &src, const Ragged<int32_t> &indexes,
std::vector<RaggedShape> *out, std::vector<Array1<int32_t>> *split_map) {
NVTX_RANGE(K2_FUNC);
ContextPtr &c = src.Context();
K2_CHECK(IsCompatible(src, indexes));
K2_CHECK_EQ(indexes.NumAxes(), 2);
K2_CHECK(out != nullptr);
int32_t num_axes = src.NumAxes(),
out_size = indexes.Dim0(),
tot_elems = indexes.NumElements();
if (out_size == 0) {
*out = std::vector<RaggedShape>();
if (split_map) {
*split_map = std::vector<Array1<int32_t>>();
}
return;
}

Array2<int32_t> old_offsets, // num_axes by tot_elems
new_offsets; // num_axes by (tot_elems + 1).
GetOldAndNewOffsets(src, indexes.values, &old_offsets, &new_offsets);

const int32_t *indexes_row_split1_data = indexes.RowSplits(1).Data(),
*indexes_row_ids1_data = indexes.RowIds(1).Data();

// Contains the `TotSize` of each axes of each output RaggedShape
Array2<int32_t> tot_sizes(c, out_size, num_axes);
Array2Accessor<int32_t> tot_sizes_acc = tot_sizes.Accessor();
Array2Accessor<int32_t> new_offsets_acc = new_offsets.Accessor();

K2_EVAL2(c, out_size, num_axes, lambda_set_tot_sizes,
(int32_t i, int32_t j) -> void {
int32_t idx0 = indexes_row_split1_data[i],
idx0_next = indexes_row_split1_data[i + 1];
tot_sizes_acc(i, j) =
new_offsets_acc(j, idx0_next) - new_offsets_acc(j, idx0);
});

auto tot_sizes_cpu = tot_sizes.To(GetCpuContext());
auto tot_sizes_cpu_acc = tot_sizes_cpu.Accessor();
out->resize(out_size);
if (split_map != nullptr) split_map->resize(out_size);
// We can not avoid this for loop on dim0, as we want to allocate memory
// seperately, may consider using a ThreadPool later.
for (int32_t i = 0; i < out_size; ++i) {
out->at(i) = RaggedShapeFromTotSizes(c,
num_axes, tot_sizes_cpu.Row(i).Data());
if (split_map != nullptr) {
split_map->at(i) =
Array1<int32_t>(c, tot_sizes_cpu_acc(i, num_axes - 1));
};
}

// Caution: e.g. old_row_splits_acc(i) == src.RowSplits(i+1).
RowSplitsAccessor<5> old_row_splits_acc(src);
RowIdsAccessor<5> old_row_ids_acc(src);
auto old_offsets_acc = old_offsets.Accessor();

// axes_elems contains the elements number of each axes before splitting into
// different RaggedShape, it should equal to the Col sum of `tot_sizes` above.
Array1<int32_t> axes_elems =
Array1<int32_t>(new_offsets.Col(tot_elems)).To(GetCpuContext());

for (int32_t axis = 0; axis < num_axes; axis++) {
// Contains the RowSplits & RowIds pointer for current layer,
// has a dimension of dim0 * 2, the layout is splits_pointer0, ids_pointer0,
// splits_pointer1, ids_pointer1, ...
Array1<int32_t *> splits_ids_ptr(GetCpuContext(), out_size * 2);
int32_t **splits_ids_ptr_data = splits_ids_ptr.Data();

// Contains the pointers for split_map
Array1<int32_t *> split_map_ptr;
int32_t **split_map_ptr_data;

if (axis == num_axes - 1 && split_map != nullptr) {
split_map_ptr = Array1<int32_t *>(GetCpuContext(), out_size);
split_map_ptr_data = split_map_ptr.Data();
}

for (int32_t i = 0; i < out_size; ++i) {
splits_ids_ptr_data[2 * i] = axis == num_axes - 1 ? nullptr :
out->at(i).RowSplits(axis + 1).Data();

splits_ids_ptr_data[2 * i + 1] =
axis == 0 ? nullptr : out->at(i).RowIds(axis).Data();

if (axis == num_axes - 1 && split_map != nullptr) {
split_map_ptr_data[i] = split_map->at(i).Data();
}
}
// transfer to GPU if we're using a GPU
splits_ids_ptr = splits_ids_ptr.To(c);
splits_ids_ptr_data = splits_ids_ptr.Data();

// set row split1
if (axis == 0) {
K2_EVAL(c, tot_elems, lambda_set_row_split1, (int32_t idx01) {
int32_t index_idx0 = indexes_row_ids1_data[idx01],
idx0x = indexes_row_split1_data[index_idx0];
splits_ids_ptr_data[2 * index_idx0][idx01 - idx0x]
= new_offsets_acc(axis + 1, idx01) -
new_offsets_acc(axis + 1, idx0x);

// Set the last elements of row_splits1 of each output shape
if (idx01 == tot_elems - 1 ||
index_idx0 != indexes_row_ids1_data[idx01 + 1]) {
splits_ids_ptr_data[2 * index_idx0][idx01 - idx0x + 1]
= new_offsets_acc(axis + 1, idx01 + 1) -
new_offsets_acc(axis + 1, idx0x);
}
});
continue;
}

// set last element of each row_splits
// TODO: Integrate this kernel into the kernel below.
if (axis < num_axes - 1) {
K2_EVAL(c, out_size, lambda_set_last_row_splits, (int32_t idx0) {
int32_t idx0x = indexes_row_split1_data[idx0],
idx0x_next = indexes_row_split1_data[idx0 + 1],
value = new_offsets_acc(axis + 1, idx0x_next) -
new_offsets_acc(axis + 1, idx0x),
pos = tot_sizes_acc(idx0, axis);
splits_ids_ptr_data[2 * idx0][pos] = value;
});
}

if (axis == num_axes - 1 && split_map != nullptr) {
split_map_ptr = split_map_ptr.To(c);
split_map_ptr_data = split_map_ptr.Data();
}

int32_t num_elems = axes_elems[axis];

// composed_row_ids maps current idx to idx01 of indexes
Array1<int32_t> composed_row_ids(c, num_elems);
RowSplitsToRowIds(new_offsets.Row(axis), &composed_row_ids);

const int32_t *composed_row_ids_data = composed_row_ids.Data();

K2_EVAL(c, num_elems, lambda_set_row_splits_and_ids, (int32_t i) {
// tot_elems = indexes.NumElements(), so tot_idx0 can be interpreted as
// index_idx01
int32_t tot_idx0 = composed_row_ids_data[i],
index_idx0 = indexes_row_ids1_data[tot_idx0],
index_idx0x = indexes_row_split1_data[index_idx0],

begin_base = new_offsets_acc(axis, index_idx0x),
begin = new_offsets_acc(axis, tot_idx0),
this_idx0 = i - begin,
this_idx01 = i - begin_base;

K2_CHECK_GE(this_idx0, 0);
K2_CHECK_GE(this_idx01, 0);

// "prev" means for axis - 1
int32_t new_prev_offset = new_offsets_acc(axis - 1, tot_idx0),
old_prev_offset = old_offsets_acc(axis - 1, tot_idx0),
old_offset = old_offsets_acc(axis, tot_idx0),
old_idx = old_offset + this_idx0;

if (split_map != nullptr && axis == num_axes - 1)
split_map_ptr_data[index_idx0][this_idx01] = old_idx;

// set row ids
const int32_t *this_old_row_ids = old_row_ids_acc(axis - 1);
int32_t old_row_id = this_old_row_ids[old_idx],
new_row_id = old_row_id + new_prev_offset - old_prev_offset,
new_pre_offset_idx0x = new_offsets_acc(axis - 1, index_idx0x);

splits_ids_ptr_data[2 * index_idx0 + 1][this_idx01] =
new_row_id - new_pre_offset_idx0x;

// set row splits
if (axis + 1 < num_axes) {
int32_t new_next_offset = new_offsets_acc(axis + 1, tot_idx0),
old_next_offset = old_offsets_acc(axis + 1, tot_idx0),
next_offset_diff = new_next_offset - old_next_offset;
const int32_t *old_row_splits_data = old_row_splits_acc(axis);
int32_t row_split_value =
next_offset_diff + old_row_splits_data[old_idx],
new_next_offset_idx0x = new_offsets_acc(axis + 1, index_idx0x);
splits_ids_ptr_data[2 * index_idx0][this_idx01]
= row_split_value - new_next_offset_idx0x;
}
});
}
}

void Unstack(RaggedShape &src, int32_t axis, std::vector<RaggedShape> *out,
std::vector<Array1<int32_t>> *split_map) {
ContextPtr &c = src.Context();
if (axis == 0) {
if (src.NumAxes() == 2) {
auto new_src = ComposeRaggedShapes(
TrivialShape(c, src.TotSize(0)), src);
return Unstack(new_src, 1, out, split_map);
}
auto indexes = Ragged<int32_t>(RegularRaggedShape(c, src.Dim0(), 1),
Arange(c, 0, src.Dim0()));

SelectAxis0(src, indexes, out, split_map);
for (size_t i = 0; i < out->size(); ++i) {
out->at(i) = RemoveAxis(out->at(i), 0);
}
} else {
int32_t tot_size_axis_minus1 = src.TotSize(axis - 1),
tot_size_axis = src.TotSize(axis);
const int32_t *row_splits_axis = src.RowSplits(axis).Data(),
*row_ids_axis = src.RowIds(axis).Data();

// Get the number of elements of current axis on each sublist
Array1<int32_t> sublists_size(c, tot_size_axis_minus1);
int32_t *sublists_size_data = sublists_size.Data();
K2_EVAL(c, tot_size_axis_minus1, lambda_get_sublists_size, (int32_t i) {
sublists_size_data[i] = row_splits_axis[i + 1] - row_splits_axis[i];
});

// Each sublist contains the elements of axis `axis`, unstack operation will
// split all these elements in a sublist to different RaggedShapes, so the
// number of output RaggedShapes is the size of the sublist with max
// elements.
int32_t num_out = MaxValue(sublists_size);

out->resize(num_out);
if (split_map != nullptr) split_map->resize(num_out);

// We will select the elements of axis `axis` on each sublist, the number
// of sublits equals to `src.TotSize(axis - 1)`.
// Initialize with -1 here, because not all the sublists have the same size,
// -1s here mean that we don't select anything on those positions
Array1<int32_t> indexes(c, num_out * tot_size_axis_minus1, -1);
int32_t *indexes_data = indexes.Data();

// Decide the elements of axis `axis` will go to which output RaggedShape
K2_EVAL(c, tot_size_axis, lambda_set_indexes, (int32_t idx01) {
int32_t idx0 = row_ids_axis[idx01],
idx0x = row_splits_axis[idx0],
idx1 = idx01 - idx0x;
indexes_data[idx1 * tot_size_axis_minus1 + idx0] = idx01;
});

// To make `DecomposeRaggedShape` work, we add a RegularRaggedShape
// layer after axis `axis` if axis equals to `src.NumAxes() - 1`.
// Of course, we have to remove the added layer finally.
bool remove_last_axis = false;
if (axis == src.NumAxes() - 1) {
src = ComposeRaggedShapes(src,
RegularRaggedShape(c, src.NumElements(), 1));
remove_last_axis = true;
}

RaggedShape top, bottom;
DecomposeRaggedShape(src, axis, &top, &bottom);

// Unstack will remove current axis (the last axis of top after decomposing
// on axis), to make `RemoveAxis` work, we add a TrivialShape layer before
// axix 0, finally we will remove the added layer.
bool remove_axis0 = false;
if (top.NumAxes() == 2) {
top = ComposeRaggedShapes(
TrivialShape(c, top.TotSize(0)), top);
remove_axis0 = true;
}
top = RemoveAxis(top, top.NumAxes() - 1);

auto ragged_indexes = Ragged<int32_t>(RegularRaggedShape(c,
num_out, tot_size_axis_minus1), indexes);

// Select elements according to indexes into corresponding RaggedShape
SelectAxis0(bottom, ragged_indexes, out, split_map);

for (int32_t i = 0; i < num_out; ++i) {
out->at(i) = ComposeRaggedShapes(top, out->at(i));
if (remove_axis0 && !remove_last_axis)
out->at(i) = RemoveAxis(out->at(i), 0);
if (remove_last_axis) {
out->at(i) = RemoveEmptyLists(out->at(i), out->at(i).NumAxes() - 2);
out->at(i) = RemoveAxis(out->at(i), out->at(i).NumAxes() - 1);
}
}
}
}

RaggedShape Merge(int32_t num_srcs, RaggedShape **src,
const Array1<uint32_t> &merge_map,
Array1<uint32_t> *merge_map_out) {
Expand Down
Loading

0 comments on commit 56edc82

Please sign in to comment.