Replacing CudaAsyncBuffer with TArray to improve perf #3303

Merged 6 commits on Mar 24, 2020
Changes from 1 commit
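Every file touched below follows the same pattern: a CudaAsyncBuffer<T>, which keeps a host-side staging copy plus a device allocation and needs an explicit CopyToGpu() before each launch, is replaced by a TArray<T>, a small fixed-capacity struct that is handed to the kernel by value so its contents travel inside the kernel's launch arguments. The sketch below illustrates the mechanism only; SmallArray and SumDims are hypothetical names, not onnxruntime code, and the capacity default is arbitrary.

#include <cstdint>
#include <cstring>
#include <vector>
#include <cuda_runtime.h>

// Hypothetical stand-in for onnxruntime's TArray: a trivially copyable,
// fixed-capacity array that is filled on the host and passed to kernels by value.
// (The real class would also guard against vec.size() exceeding Capacity.)
template <typename T, int32_t Capacity = 8>
struct SmallArray {
  SmallArray() = default;
  explicit SmallArray(const std::vector<T>& vec)
      : size_(static_cast<int32_t>(vec.size())) {
    std::memcpy(data_, vec.data(), vec.size() * sizeof(T));
  }
  __host__ __device__ __forceinline__ T& operator[](int32_t i) { return data_[i]; }
  __host__ __device__ __forceinline__ const T& operator[](int32_t i) const { return data_[i]; }
  int32_t size_ = 0;
  T data_[Capacity];
};

// Because the argument is passed by value, its bytes are copied into the kernel's
// launch-parameter buffer; no cudaMemcpyAsync, no extra device allocation, and no
// GpuPtr() indirection are needed for a handful of shape values.
template <typename T>
__global__ void SumDims(const SmallArray<T> dims, T* out) {
  T acc = 0;
  for (int32_t i = 0; i < dims.size_; ++i) acc += dims[i];
  *out = acc;
}

Host-side wrappers such as TopKImpl below still take the array by const reference; only the __global__ kernels take it by value, so the copy to the device happens at launch time, which is presumably where the perf win over the per-launch CopyToGpu() comes from.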
5 changes: 2 additions & 3 deletions onnxruntime/contrib_ops/cuda/activation/activations.cc
@@ -27,12 +27,11 @@ namespace cuda {
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
UnaryElementwise::Prepare(context, &p); \
CudaAsyncBuffer<Ctx##x> func_ctx(this, MakeFuncCtx(), 1); \
if (!std::is_same<CtxNull, Ctx##x>::value) ORT_RETURN_IF_ERROR(func_ctx.CopyToGpu()); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToCudaType<T>::MappedType>( \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(p.input_tensor->template Data<T>()), \
reinterpret_cast<typename ToCudaType<T>::MappedType*>(p.output_tensor->template MutableData<T>()), \
func_ctx.GpuPtr(), p.output_tensor->Shape().Size()); \
&func_ctx, p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/cuda/activation/activations.cc
@@ -23,12 +23,11 @@ namespace cuda {
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
UnaryElementwise::Prepare(context, &p); \
CudaAsyncBuffer<Ctx##x> func_ctx(this, MakeFuncCtx(), 1); \
if (!std::is_same<CtxNull, Ctx##x>::value) ORT_RETURN_IF_ERROR(func_ctx.CopyToGpu()); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToCudaType<T>::MappedType>( \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(p.input_tensor->template Data<T>()), \
reinterpret_cast<typename ToCudaType<T>::MappedType*>(p.output_tensor->template MutableData<T>()), \
func_ctx.GpuPtr(), p.output_tensor->Shape().Size()); \
&func_ctx, p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
@@ -13,7 +13,7 @@ template <typename InT, typename OutT, typename FuncT, int NumThreadsPerBlock, i
__global__ void _UnaryElementWise(
const InT* input_data,
OutT* output_data,
const FuncT& functor,
const FuncT functor,
CUDA_LONG N) {
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
InT value[NumElementsPerThread];
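The unary elementwise kernel changes in the same spirit: the functor is now taken by value, so its state (for the activations above, the small Ctx##x struct) is marshalled with the other launch arguments instead of being read through a separately prepared device buffer. A minimal sketch of that calling convention; AddBias and UnaryElementWiseSketch are hypothetical names used for illustration.

#include <cuda_runtime.h>

// Hypothetical functor standing in for the Ctx##x activation contexts; its one
// float of state rides along in the kernel argument buffer.
struct AddBias {
  float bias;
  __device__ float operator()(float x) const { return x + bias; }
};

template <typename InT, typename OutT, typename FuncT>
__global__ void UnaryElementWiseSketch(const InT* input_data, OutT* output_data,
                                       const FuncT functor,  // by value, as in this hunk
                                       int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) output_data[i] = static_cast<OutT>(functor(input_data[i]));
}

// Example launch, assuming in and out are device pointers to n floats:
//   UnaryElementWiseSketch<<<(n + 255) / 256, 256>>>(in, out, AddBias{0.5f}, n);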
9 changes: 4 additions & 5 deletions onnxruntime/core/providers/cuda/math/topk.cc
@@ -45,16 +45,16 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
#define TOPKIMPL(T) TopKImpl<T>(this, tensor_X->Data<T>(), \
static_cast<T*>(tensor_V->MutableDataRaw()), \
static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
elem_nums_cuda.GpuPtr(), \
elem_nums_cuda, \
elem_nums.size(), \
axis, K_, largest_, sorted_, N, dimension)

template <bool inputk>
Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
auto tensor_X = ctx->Input<Tensor>(0);
ORT_ENFORCE(nullptr != tensor_X);
auto rank = static_cast<int64_t>(tensor_X->Shape().NumDimensions());
auto axis = axis_ < 0 ? rank + axis_ : axis_;
int32_t rank = static_cast<int32_t>(tensor_X->Shape().NumDimensions());
int32_t axis = static_cast<int32_t>(axis_ < 0 ? rank + axis_ : axis_);
ORT_ENFORCE(axis > -1 && axis < rank);

if (inputk) {
@@ -80,8 +80,7 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
}

auto N = elem_nums[0] / dimension;
CudaAsyncBuffer<int64_t> elem_nums_cuda(this, elem_nums);
ORT_RETURN_IF_ERROR(elem_nums_cuda.CopyToGpu());
TArray<int64_t> elem_nums_cuda(elem_nums);

auto prim_type = tensor_X->DataType()->AsPrimitiveDataType();
if (prim_type == nullptr) {
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/cuda/math/topk_impl.cu
@@ -32,7 +32,7 @@ struct KV {
#define LESS(n, m) ((n) <= (m) ? (n) : (m))

template <typename T>
__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted, int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) {
__global__ void BitonicTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t aligned_K, int64_t largest, int64_t sorted, int64_t dimension, int64_t aligned_dimension, T type_min, T type_max) {
auto tid = threadIdx.x;
auto bid = blockIdx.x;
extern __shared__ char shared_mem[];
@@ -192,7 +192,7 @@ __device__ void SetByte(double* d, int64_t byte) {
}

template<typename T, int64_t THREADS, int64_t KPT>
__global__ void RadixTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT, T type_min, T type_max) {
__global__ void RadixTopK(const T* X, T* V, int64_t* I, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t dimension, int64_t XPT, T type_min, T type_max) {
auto tid = threadIdx.x;
auto bid = blockIdx.x;
extern __shared__ char shared_mem[];
@@ -342,7 +342,7 @@ __global__ void RadixTopK(const T* X, T* V, int64_t* I, const int64_t* elem_nums
}

template <typename T>
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
__global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t offset, int64_t dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, dimension);
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis];
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
@@ -352,7 +352,7 @@ __global__ void FillInput(const T* input_x, T* output_v, int64_t* output_i, cons
}

template <typename T>
__global__ void FillOutput(const T* input_v, const int64_t* input_i, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t offset, int64_t dimension) {
__global__ void FillOutput(const T* input_v, const int64_t* input_i, T* output_v, int64_t* output_i, const TArray<int64_t> elem_nums, size_t size, int32_t axis, int64_t K, int64_t offset, int64_t dimension) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, K);
auto left = offset / (axis == size - 1 ? 1 : elem_nums[axis + 1]) * elem_nums[axis] * K / dimension;
auto right = axis == size - 1 ? 0 : offset % elem_nums[axis + 1];
Expand All @@ -369,7 +369,7 @@ __global__ void ExcludeOutput(int64_t* output_i, int64_t K, int64_t dimension) {
}

template <typename T>
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension) {
auto aligned_K = ALIGN(K);
auto aligned_dimension = ALIGN(dimension);
if (aligned_dimension <= GridDim::maxThreadsPerBlock) {
@@ -419,9 +419,9 @@ Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t
const T* input_x, \
T* output_v, \
int64_t* output_i, \
const int64_t* elem_nums, \
const TArray<int64_t>& elem_nums, \
size_t size, \
int64_t axis, \
int32_t axis, \
int64_t K, \
int64_t largest, \
int64_t sorted, \
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/math/topk_impl.h
@@ -11,7 +11,7 @@ namespace onnxruntime {
namespace cuda {

template <typename T>
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const int64_t* elem_nums, size_t size, int64_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);
Status TopKImpl(const CudaKernel* kernel, const T* input_x, T* output_v, int64_t* output_i, const TArray<int64_t>& elem_nums, size_t size, int32_t axis, int64_t K, int64_t largest, int64_t sorted, int64_t N, int64_t dimension);

} // namespace cuda
} // namespace onnxruntime
@@ -94,40 +94,35 @@ Status NonMaxSuppression::ComputeInternal(OpKernelContext* ctx) const {
Tensor* output = ctx->Output(0, {static_cast<int64_t>(total_num_saved_outputs), last_dim});
ORT_ENFORCE(output != nullptr);
int64_t* dst = output->MutableData<int64_t>();
size_t count = all_selected_indices.size();
int32_t count = static_cast<int32_t>(all_selected_indices.size());

CudaAsyncBuffer<const void*> input_ptr(this, count);
CudaAsyncBuffer<int64_t> concat_sizes_gpu(this, count);
CudaAsyncBuffer<int64_t> concat_sizes_range_gpu(this, count);
CudaAsyncBuffer<int64_t> axis_dimension_input_output_mapping_gpu(this, total_num_saved_outputs);
TArray<const void*> input_ptr(count);
TArray<int64_t> concat_sizes_gpu(count);
TArray<int64_t> concat_sizes_range_gpu(count);
TArray<int64_t> axis_dimension_input_output_mapping_gpu(total_num_saved_outputs);

int index = 0;
for (size_t i = 0; i < count; i++) {
for (int32_t i = 0; i < count; i++) {
auto& it = all_selected_indices[i];
auto src = std::get<0>(it).get();
auto size = std::get<1>(it);

input_ptr.CpuPtr()[i] = src;
concat_sizes_gpu.CpuPtr()[i] = size;
concat_sizes_range_gpu.CpuPtr()[i] = (i == 0) ? size : size + concat_sizes_range_gpu.CpuPtr()[i - 1];
input_ptr[i] = src;
concat_sizes_gpu[i] = size;
concat_sizes_range_gpu[i] = (i == 0) ? size : size + concat_sizes_range_gpu[i - 1];
for (int j = 0; j < size; j++) {
axis_dimension_input_output_mapping_gpu.CpuPtr()[index++] = i;
axis_dimension_input_output_mapping_gpu[index++] = i;
}
}

concat_sizes_gpu.CopyToGpu();
axis_dimension_input_output_mapping_gpu.CopyToGpu();
concat_sizes_range_gpu.CopyToGpu();
input_ptr.CopyToGpu();

ORT_RETURN_IF_ERROR(ConcatImpl(sizeof(int64_t),
num_elements,
last_dim,
concat_sizes_gpu.GpuPtr(),
concat_sizes_range_gpu.GpuPtr(),
axis_dimension_input_output_mapping_gpu.GpuPtr(),
concat_sizes_gpu,
concat_sizes_range_gpu,
axis_dimension_input_output_mapping_gpu,
dst,
input_ptr.GpuPtr(),
input_ptr,
static_cast<size_t>(num_elements)));
}

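The NonMaxSuppression concat path shows the host-side half of the pattern most clearly: the four arrays are filled through TArray::operator[] and passed straight into ConcatImpl, with no CpuPtr()/CopyToGpu()/GpuPtr() triple. A simplified sketch of one such fill loop; the include path and the onnxruntime::cuda namespace are assumed here, and BuildConcatSizeRanges is an illustrative name.

#include <cstdint>
#include <vector>
#include "core/providers/cuda/shared_inc/cuda_utils.h"  // assumed include path for TArray

using onnxruntime::cuda::TArray;  // assumed namespace

// Running-sum fill, as done for concat_sizes_range_gpu above: plain host writes;
// the host-to-device transfer happens implicitly when ConcatImpl launches its kernel.
TArray<int64_t> BuildConcatSizeRanges(const std::vector<int64_t>& sizes) {
  const int32_t count = static_cast<int32_t>(sizes.size());
  TArray<int64_t> ranges(count);  // size constructor, as used for input_ptr above
  for (int32_t i = 0; i < count; ++i) {
    ranges[i] = (i == 0) ? sizes[i] : sizes[i] + ranges[i - 1];
  }
  return ranges;
}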
13 changes: 5 additions & 8 deletions onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc
@@ -348,14 +348,13 @@ Status CudnnRnnBase<T>::ComputeInternal(OpKernelContext* ctx) const {
}

if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) {
CudaAsyncBuffer<int32_t> sequence_lens_buffer(this, batch_size);
memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t));
ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu());
TArray<int32_t> sequence_lens_buffer(batch_size);
memcpy(sequence_lens_buffer.data_, sequence_lens_data, batch_size * sizeof(int32_t));
RnnMaskImpl(gsl::narrow_cast<int32_t>(num_directions_),
gsl::narrow_cast<int32_t>(seq_length),
gsl::narrow_cast<int32_t>(batch_size),
gsl::narrow_cast<int32_t>(hidden_size_),
sequence_lens_buffer.GpuPtr(),
sequence_lens_buffer,
reinterpret_cast<CudaT*>(y_data),
reinterpret_cast<CudaT*>(y_h_data),
output_size);
@@ -371,14 +370,12 @@ void CudnnRnnBase<T>::SetZeroSequences(const int64_t zero_seq_index_cache_size,
T* y_h_data,
T* y_c_data) const {
typedef typename ToCudaType<T>::MappedType CudaT;
CudaAsyncBuffer<int32_t> zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size);
memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t));
ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu());
TArray<int32_t> zero_seq_index_cache_async_buffer(zero_seq_index_cache);
MaskZeroSequences(gsl::narrow_cast<int32_t>(hidden_size_),
reinterpret_cast<CudaT*>(y_data),
reinterpret_cast<CudaT*>(y_h_data),
reinterpret_cast<CudaT*>(y_c_data),
zero_seq_index_cache_async_buffer.GpuPtr(),
zero_seq_index_cache_async_buffer,
static_cast<int64_t>(zero_seq_index_cache_size));
}

12 changes: 6 additions & 6 deletions onnxruntime/core/providers/cuda/rnn/rnn_impl.cu
@@ -84,7 +84,7 @@ template <typename T>
__global__ void _RnnMaskKernel(const int32_t seq_length,
const int32_t batch_size,
const int32_t hidden_size,
const int32_t* sequence_lens,
const TArray<int32_t> sequence_lens,
const fast_divmod div_seq_block,
const fast_divmod div_dir_block,
const fast_divmod div_batch_block,
@@ -120,7 +120,7 @@ void RnnMaskImpl(const int32_t num_directions,
const int32_t seq_length,
const int32_t batch_size,
const int32_t hidden_size,
const int32_t* sequence_lens,
const TArray<int32_t>& sequence_lens,
T* y_output_data,
T* y_h_output_data,
const size_t N) {
@@ -138,7 +138,7 @@ __global__ void _MaskZeroSequences(const int32_t hidden_size,
T* y_output_data,
T* y_h_output_data,
T* y_c_output_data,
const int32_t* zeor_seq_index_cache,
const TArray<int32_t> zeor_seq_index_cache,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

@@ -168,7 +168,7 @@ void MaskZeroSequences(const int32_t hidden_size,
T* y_output_data,
T* y_h_output_data,
T* y_c_output_data,
const int32_t* zeor_seq_index_cache,
const TArray<int32_t>& zeor_seq_index_cache,
const size_t N) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
_MaskZeroSequences<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
@@ -180,7 +180,7 @@ void MaskZeroSequences(const int32_t hidden_size,
const int32_t seq_length, \
const int32_t batch_size, \
const int32_t hidden_size, \
const int32_t* sequence_lens, \
const TArray<int32_t>& sequence_lens, \
T* y_output_data, \
T* y_h_output_data, \
const size_t N); \
@@ -200,7 +200,7 @@ template void MaskZeroSequences<T>(const int32_t hidden_size,
T* y_output_data, \
T* y_h_output_data, \
T* y_c_output_data, \
const int32_t* zeor_seq_index_cache, \
const TArray<int32_t>& zeor_seq_index_cache, \
const size_t N);

SPECIALIZED_RNN_IMPL(half)
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/rnn/rnn_impl.h
@@ -29,7 +29,7 @@ void RnnMaskImpl(const int32_t num_directions,
const int32_t seq_length,
const int32_t batch_size,
const int32_t hidden_size,
const int32_t* sequence_lens,
const TArray<int32_t>& sequence_lens,
T* y_output_data,
T* y_h_output_data,
const size_t N);
@@ -39,7 +39,7 @@ void MaskZeroSequences(const int32_t hidden_size,
T* y_output_data,
T* y_h_output_data,
T* y_c_output_data,
const int32_t* zeor_seq_index_cache_async_buffer,
const TArray<int32_t>& zeor_seq_index_cache_async_buffer,
const size_t N);
} // namespace cuda
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
@@ -52,7 +52,7 @@ struct TArray {
memcpy(data_, vec.data(), vec.size() * sizeof(T));
}

T& operator[](int32_t index) {
__host__ __device__ __forceinline__ T& operator[](int32_t index) {
return data_[index];
}

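The one-line change to cuda_utils.h above is what ties the two halves together: operator[] is now compiled for both execution spaces, because the same accessor fills a TArray on the host (NonMaxSuppression, CuDNN RNN) and reads it inside kernels (TopK, RNN masking). A tiny illustration of the two call sites, under the same include/namespace assumptions as the earlier sketches; ReadFirstDim and HostFillAndLaunch are illustrative names.

#include <cstdint>
#include <vector>
#include <cuda_runtime.h>
#include "core/providers/cuda/shared_inc/cuda_utils.h"  // assumed include path for TArray

using onnxruntime::cuda::TArray;  // assumed namespace

__global__ void ReadFirstDim(TArray<int64_t> dims, int64_t* out) {
  *out = dims[0];  // device-side call of the now __host__ __device__ operator[]
}

void HostFillAndLaunch(int64_t* device_out) {
  TArray<int64_t> dims(std::vector<int64_t>{2, 3, 4});  // vector constructor shown in the hunk above
  dims[0] = 6;                                          // host-side call of the same operator[]
  ReadFirstDim<<<1, 1>>>(dims, device_out);             // struct copied by value at launch
}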