Skip to content

Commit

Permalink
Fix GetTransposeReordering. (#586)
Browse files Browse the repository at this point in the history
Avoid reusing memory in cub device radix sort.
  • Loading branch information
csukuangfj authored Jan 13, 2021
1 parent fe00848 commit 7f01b9c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
24 changes: 10 additions & 14 deletions k2/csrc/ragged_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ RaggedShape ComposeRaggedShapes(const RaggedShape &a, const RaggedShape &b) {
return RaggedShape(axes, validate);
}


RaggedShape ComposeRaggedShapes3(const RaggedShape &a, const RaggedShape &b,
const RaggedShape &c) {
NVTX_RANGE(K2_FUNC);
Expand All @@ -171,8 +170,7 @@ RaggedShape ComposeRaggedShapes3(const RaggedShape &a, const RaggedShape &b,
const auto &a_axes = a.Layers();
const auto &b_axes = b.Layers();
const auto &c_axes = c.Layers();
std::size_t a_size = a_axes.size(),
b_size = b_axes.size(),
std::size_t a_size = a_axes.size(), b_size = b_axes.size(),
c_size = c_axes.size();
std::vector<RaggedShapeLayer> axes;
axes.reserve(a_size + b_size + c_size);
Expand All @@ -183,7 +181,6 @@ RaggedShape ComposeRaggedShapes3(const RaggedShape &a, const RaggedShape &b,
return RaggedShape(axes, validate);
}


RaggedShape RaggedShape3(Array1<int32_t> *row_splits1,
Array1<int32_t> *row_ids1, int32_t cached_tot_size1,
Array1<int32_t> *row_splits2,
Expand Down Expand Up @@ -1188,24 +1185,23 @@ Array1<int32_t> GetTransposeReordering(Ragged<int32_t> &src, int32_t num_cols) {
int32_t num_elements = src.values.Dim();
int32_t log_buckets = static_cast<int32_t>(ceilf(log2f(num_buckets)));

Array1<int32_t> ans = Range(context, num_elements, 0);
Array1<int32_t> order = Range(context, num_elements, 0);
Array1<int32_t> src_tmp_out(context, num_elements);
Array1<int32_t> ans(context, num_elements);

cudaStream_t stream = context->GetCudaStream();

size_t temp_storage_bytes = 0;
K2_CUDA_SAFE_CALL(cub::DeviceRadixSort::SortPairs(
nullptr, temp_storage_bytes, src.values.Data(),
static_cast<int32_t *>(nullptr), ans.Data(), ans.Data(), num_elements, 0,
log_buckets, stream));
nullptr, temp_storage_bytes, src.values.Data(), src_tmp_out.Data(),
order.Data(), ans.Data(), num_elements, 0, log_buckets, stream));

Array1<int8_t> d_temp_storage(
context, temp_storage_bytes + num_elements * sizeof(int32_t));
Array1<int8_t> d_temp_storage(context, temp_storage_bytes);

K2_CUDA_SAFE_CALL(cub::DeviceRadixSort::SortPairs(
d_temp_storage.Data() + sizeof(int32_t) * num_elements,
temp_storage_bytes, src.values.Data(),
reinterpret_cast<int32_t *>(d_temp_storage.Data()), ans.Data(),
ans.Data(), num_elements, 0, log_buckets, stream));
d_temp_storage.Data(), temp_storage_bytes, src.values.Data(),
src_tmp_out.Data(), order.Data(), ans.Data(), num_elements, 0,
log_buckets, stream));

return ans;
#else
Expand Down
29 changes: 29 additions & 0 deletions k2/csrc/ragged_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "k2/csrc/array.h"
#include "k2/csrc/array_ops.h"
#include "k2/csrc/context.h"
#include "k2/csrc/fsa_utils.h"
#include "k2/csrc/math.h"
#include "k2/csrc/ragged.h"
#include "k2/csrc/ragged_ops.h"
Expand Down Expand Up @@ -1494,6 +1495,34 @@ TEST(GetTransposeReordering, WithDuplicatesThreeAxes) {
}
}

TEST(GetTransposeReordering, RandomFsaVecTest) {
for (int32_t iter = 0; iter != 8; ++iter) {
for (auto &context : {GetCpuContext(), GetCudaContext()}) {
int n = RandInt(100, 200);
int32_t min_num_fsas = n;
int32_t max_num_fsas = n * 2;
bool acyclic = false;
int32_t max_symbol = 100;
int32_t min_num_arcs = min_num_fsas * 10;
int32_t max_num_arcs = max_num_fsas * 20;

FsaVec fsas = RandomFsaVec(min_num_fsas, max_num_fsas, acyclic,
max_symbol, min_num_arcs, max_num_arcs);
fsas = fsas.To(context);
Array1<int32_t> dest_states = GetDestStates(fsas, true);
Ragged<int32_t> dest_states_tensor(fsas.shape, dest_states);
int32_t num_states = fsas.TotSize(1);
int32_t num_arcs = fsas.TotSize(2);
Array1<int32_t> order =
GetTransposeReordering(dest_states_tensor, num_states);
Sort(&order);
ASSERT_EQ(order.Dim(), num_arcs);
Array1<int32_t> expected = Range<int32_t>(context, num_arcs, 0);
CheckArrayData(order, expected);
}
}
}

TEST(ChangeSublistSize, TwoAxes) {
for (auto &context : {GetCpuContext(), GetCudaContext()}) {
Array1<int32_t> row_splits1(context, std::vector<int32_t>{0, 2, 5});
Expand Down

0 comments on commit 7f01b9c

Please sign in to comment.