reduce cuda library binary size (microsoft#14555)
### Description
Reduce the CUDA library binary size by:
1. Refactoring beam_search_top_k to reduce template instantiations; saves ~56MB (see the sketch below).
2. Opting TopK out of the uint*, int8_t, and int16_t element types; saves ~50MB.
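
For context on item 1, here is a minimal CUDA sketch of why instantiation count dominates binary size (the kernel name and signature are invented for illustration; this is not the onnxruntime code): every explicit template instantiation becomes its own device function, and nvcc compiles each one for every GPU architecture embedded in the fatbin.

```cuda
// Hypothetical kernel for illustration only: each (T, K) pair instantiated
// below becomes a separate device function in the compiled library,
// replicated once per target SM architecture in the fatbin.
#include <cuda_fp16.h>

template <typename T, int K>
__global__ void TopKSketch(const T* input, T* values, int* indices, int n) {
  // selection logic elided; this body is what gets duplicated per instantiation
}

// 2 element types x 2 K buckets = 4 compiled kernels (x N architectures).
template __global__ void TopKSketch<float, 32>(const float*, float*, int*, int);
template __global__ void TopKSketch<float, 64>(const float*, float*, int*, int);
template __global__ void TopKSketch<half, 32>(const half*, half*, int*, int);
template __global__ void TopKSketch<half, 64>(const half*, half*, int*, int);
```

Trimming either axis (fewer K buckets, fewer element/index types) removes whole kernels from the binary, which is what both changes below do.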


yufenglee authored and preetha-intel committed Feb 15, 2023
1 parent 0d58dce commit 7866fca
Showing 12 changed files with 61 additions and 129 deletions.
1 change: 1 addition & 0 deletions cmake/CMakeLists.txt
@@ -603,6 +603,7 @@ if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
endif()
+
endif()
if (onnxruntime_USE_VITISAI)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1)
6 changes: 3 additions & 3 deletions docs/OperatorKernels.md
@@ -768,9 +768,9 @@ Do not modify directly.*
|||1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Tile|*in* input:**T**<br> *in* repeats:**T1**<br> *out* output:**T**<br><br>or<br><br>*in* input:**T**<br> *in* tiles:**T**<br> *in* axis:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)<br/> **T1** = tensor(int64)|
- |TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
- |||10|**I** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
- |||[1, 9]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+ |TopK|*in* X:**T**<br> *in* K:**tensor(int64)**<br> *out* Values:**T**<br> *out* Indices:**I**<br><br>or<br><br>*in* X:**T**<br> *out* Values:**T**<br> *out* Indices:**I**|11+|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+ |||10|**I** = tensor(int64)<br/> **T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
+ |||[1, 9]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)|
|Transpose|*in* data:**T**<br> *out* transposed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Trilu|*in* input:**T**<br> *in* k:**tensor(int64)**<br> *out* output:**T**|14+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
53 changes: 2 additions & 51 deletions onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.cu
@@ -291,7 +291,7 @@ void LaunchBatchTopKKernel(const T* topk_scores,
int32_t num_beams,
int32_t k,
cudaStream_t stream) {
- ORT_ENFORCE(k <= 256, "LaunchBatchTopKKernel doesn't support k >= 256");
+ ORT_ENFORCE(k <= 64, "LaunchBatchTopKKernel doesn't support k >= 64");

#define BatchTopKKernelLauncher(K) \
BatchTopKKernel<T, I, K, 32><<<batch_size, 32, 0, stream>>>(topk_scores, \
@@ -311,12 +311,8 @@ void LaunchBatchTopKKernel(const T* topk_scores,
BatchTopKKernelLauncher(16);
} else if (k <= 32) {
BatchTopKKernelLauncher(32);
- } else if (k <= 64) {
- BatchTopKKernelLauncher(64);
- } else if (k <= 128) {
- BatchTopKKernelLauncher(128);
} else {
- BatchTopKKernelLauncher(256);
+ BatchTopKKernelLauncher(64);
}
}

@@ -330,36 +326,6 @@ template void LaunchBatchTopKKernel(const float* topk_scores,
int32_t k,
cudaStream_t stream);

- template void LaunchBatchTopKKernel(const float* topk_scores,
- const int64_t* topk_tokens,
- int32_t* next_indices,
- int32_t* next_tokens,
- float* next_scores,
- int32_t batch_size,
- int32_t num_beams,
- int32_t k,
- cudaStream_t stream);
-
- template void LaunchBatchTopKKernel(const half* topk_scores,
- const int32_t* topk_tokens,
- int32_t* next_indices,
- int32_t* next_tokens,
- half* next_scores,
- int32_t batch_size,
- int32_t num_beams,
- int32_t k,
- cudaStream_t stream);
-
- template void LaunchBatchTopKKernel(const half* topk_scores,
- const int64_t* topk_tokens,
- int32_t* next_indices,
- int32_t* next_tokens,
- half* next_scores,
- int32_t batch_size,
- int32_t num_beams,
- int32_t k,
- cudaStream_t stream);
-
template <typename T>
void BeamSearchTopK(
const T* input,
@@ -426,21 +392,6 @@ template void BeamSearchTopK(
int32_t* output_indices,
cudaStream_t stream);

- template void BeamSearchTopK(
- const half* input,
- int32_t batch_size,
- int32_t num_beams,
- int32_t vocab_size,
- int32_t k,
- half* tmp_values_1st_stage,
- int32_t* tmp_indices_1st_stage,
- half* tmp_values_2st_stage,
- int32_t* tmp_indices_2st_stage,
- half* output_values,
- int32_t* output_tokens,
- int32_t* output_indices,
- cudaStream_t stream);
-
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
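
Why the launcher dispatches on compile-time K buckets at all: K sizes per-thread storage, so it must be a template parameter, and every bucket reachable from the if/else ladder is a separately compiled kernel. Below is a simplified sketch under that assumption (not the real BatchTopKKernel layout). Since the callers in this PR pass k = 2 * num_beams, the new k <= 64 cap still covers num_beams up to 32.

```cuda
// Simplified sketch, not the real BatchTopKKernel: K must be a compile-time
// constant because it sizes per-thread arrays, so every K bucket the launcher
// can dispatch to exists as a distinct kernel in the binary. Dropping the
// 128 and 256 buckets, plus the now-unused int64/half launcher
// instantiations, is where the ~56MB of savings comes from.
template <typename T, int K, int kBlockSize>
__global__ void BatchTopKSketch(const T* scores, T* best_scores, int n) {
  T local_best[K];  // compile-time size: kept in registers/local memory
  // per-thread selection and cross-thread merge elided
}
```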
12 changes: 0 additions & 12 deletions onnxruntime/contrib_ops/cuda/transformers/beam_search_topk.h
@@ -11,18 +11,6 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

- template <typename T, typename I>
- void LaunchBatchTopKKernel(
- const T* topk_scores,
- const I* topk_indices,
- int32_t* next_indices,
- int32_t* next_tokens,
- T* next_scores,
- int32_t batch_size,
- int32_t num_beams,
- int32_t k,
- cudaStream_t stream);
-
template <typename T>
void BeamSearchTopK(
const T* input,
47 changes: 28 additions & 19 deletions onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -440,12 +440,16 @@ Status ProcessLogits(const OrtValue& logits, //
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
#endif
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
+ beam_state->next_scores.data(),
+ beam_state->next_scores.size_bytes(),
+ cudaMemcpyDeviceToHost,
+ cuda_stream));
} else {
// Apply top-k selection like the following:
// next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
// next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
- // int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size};
- int64_t next_token_scores_dims[] = {batch_size * num_beams, vocab_size};
+ int64_t next_token_scores_dims[] = {batch_size, num_beams * vocab_size};

TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
auto element_type = DataTypeImpl::GetType<float>();
@@ -460,31 +464,36 @@ Status ProcessLogits(const OrtValue& logits, //
constexpr bool sorted = true; // results returned in sorted order.

std::unique_ptr<Tensor> topk_scores = Tensor::CreateDefault();
- std::unique_ptr<Tensor> topk_tokens = Tensor::CreateDefault();
+ std::unique_ptr<Tensor> topk_indices = Tensor::CreateDefault();
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, ort_stream, thread_pool,
- *topk_scores, *topk_tokens));
+ *topk_scores, *topk_indices));

#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", *(topk_scores.get()));
- dumper->Print("topk_tokens", *(topk_tokens.get()));
+ dumper->Print("topk_indices", *(topk_indices.get()));
#endif

+ // Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following:
+ // next_indices = (next_tokens / vocab_size).long()
+ // next_tokens = next_tokens % vocab_size
+ const int64_t* next_token_indices = topk_indices->Data<int64_t>();
+ cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
+ batch_size, top_k, vocab_size, cuda_stream);
+
+ const float* data = topk_scores->Data<float>();
+ #ifdef DEBUG_GENERATION
+ dumper->Print("next_scores before scorer", data, batch_size, top_k);
+ dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k);
+ dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
+ #endif
+
- cuda::LaunchBatchTopKKernel(topk_scores->Data<float>(),
- topk_tokens->Data<int64_t>(),
- beam_state->next_indices.data(),
- beam_state->next_tokens.data(),
- beam_state->next_scores.data(),
- batch_size,
- num_beams,
- 2 * num_beams,
- cuda_stream);
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
+ data,
+ topk_scores->SizeInBytes(),
+ cudaMemcpyDeviceToHost,
+ cuda_stream));
}

- CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
- beam_state->next_scores.data(),
- beam_state->next_scores.size_bytes(),
- cudaMemcpyDeviceToHost,
- cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(),
beam_state->next_tokens.data(),
beam_state->next_tokens.size_bytes(),
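
To make the rewritten else branch concrete: TopK now runs over shape (batch_size, num_beams * vocab_size), so each returned index encodes both a beam and a token, and cuda::LaunchNextTokenKernel splits them apart on the GPU. Here is a host-side illustration of that split (a hypothetical helper written for clarity, not code from this PR):

```cpp
#include <cstdint>

// Mirrors the conversion described in the diff's comment:
//   next_indices = (next_tokens / vocab_size).long()
//   next_tokens  = next_tokens % vocab_size
void DecodeNextTokens(const int64_t* topk_indices,  // from TopK, in [0, num_beams * vocab_size)
                      int32_t* next_indices,        // out: source beam per candidate
                      int32_t* next_tokens,         // out: token id per candidate
                      int batch_size, int k, int vocab_size) {
  for (int i = 0; i < batch_size * k; ++i) {
    next_indices[i] = static_cast<int32_t>(topk_indices[i] / vocab_size);
    next_tokens[i] = static_cast<int32_t>(topk_indices[i] % vocab_size);
  }
}
```

With k = 2 * num_beams, this keeps twice as many candidates per batch entry as there are beams, the same count the removed LaunchBatchTopKKernel call produced, so the scorer downstream is unchanged.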
41 changes: 27 additions & 14 deletions onnxruntime/core/providers/cuda/math/topk.cc
@@ -12,23 +12,42 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kOnnxDomain,
1, 9,
kCudaExecutionProvider,
- (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
+ DataTypeImpl::GetTensorType<float>(),
+ DataTypeImpl::GetTensorType<double>(),
+ DataTypeImpl::GetTensorType<int32_t>(),
+ DataTypeImpl::GetTensorType<int64_t>()}),
TopK<false>);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
TopK,
kOnnxDomain,
10, 10,
kCudaExecutionProvider,
- (*KernelDefBuilder::Create()).InputMemoryType(OrtMemTypeCPUInput, 1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()).TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
+ (*KernelDefBuilder::Create())
+ .InputMemoryType(OrtMemTypeCPUInput, 1)
+ .TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
+ DataTypeImpl::GetTensorType<float>(),
+ DataTypeImpl::GetTensorType<double>(),
+ DataTypeImpl::GetTensorType<int32_t>(),
+ DataTypeImpl::GetTensorType<int64_t>()})
+ .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
TopK<true>);

ONNX_OPERATOR_KERNEL_EX(
TopK,
kOnnxDomain,
11,
kCudaExecutionProvider,
- (*KernelDefBuilder::Create()).InputMemoryType(OrtMemTypeCPUInput, 1).TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()).TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
+ (*KernelDefBuilder::Create())
+ .InputMemoryType(OrtMemTypeCPUInput, 1)
+ .TypeConstraint("T", {DataTypeImpl::GetTensorType<MLFloat16>(),
+ DataTypeImpl::GetTensorType<float>(),
+ DataTypeImpl::GetTensorType<double>(),
+ DataTypeImpl::GetTensorType<int32_t>(),
+ DataTypeImpl::GetTensorType<int64_t>()})
+ .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()),
TopK<true>);

template <bool inputk>
Expand All @@ -42,11 +61,11 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
}

#define IS_PRIM_TYPE(T) utils::IsPrimitiveDataType<T>(prim_type)
- #define TOPKIMPL(T) TopKImpl<T>(this, ctx->GetComputeStream(), tensor_X->Data<T>(), \
- static_cast<T*>(tensor_V->MutableDataRaw()), \
- static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
- elem_nums_cuda, \
- elem_nums.size(), \
+ #define TOPKIMPL(T) TopKImpl<T>(this, ctx->GetComputeStream(), tensor_X->Data<T>(), \
+ static_cast<T*>(tensor_V->MutableDataRaw()), \
+ static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
+ elem_nums_cuda, \
+ elem_nums.size(), \
axis, K_, largest_, sorted_, N, dimension)

template <bool inputk>
@@ -87,12 +106,6 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for TopK operator");
}

- if (IS_PRIM_TYPE(uint8_t)) return TOPKIMPL(uint8_t);
- if (IS_PRIM_TYPE(uint16_t)) return TOPKIMPL(uint16_t);
- if (IS_PRIM_TYPE(uint32_t)) return TOPKIMPL(uint32_t);
- if (IS_PRIM_TYPE(uint64_t)) return TOPKIMPL(uint64_t);
- if (IS_PRIM_TYPE(int8_t)) return TOPKIMPL(int8_t);
- if (IS_PRIM_TYPE(int16_t)) return TOPKIMPL(int16_t);
if (IS_PRIM_TYPE(int32_t)) return TOPKIMPL(int32_t);
if (IS_PRIM_TYPE(int64_t)) return TOPKIMPL(int64_t);
if (IS_PRIM_TYPE(MLFloat16)) return TOPKIMPL(MLFloat16);
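
Because TopKImpl is explicitly instantiated once per element type, each in its own translation unit, removing a type from the registration list and from this dispatch chain lets the whole per-type .cu file be deleted, as the six deletions below show. A hypothetical reconstruction of one such file (contents assumed from the pattern; the header and macro names are not quoted from the repo):

```cpp
// topk_impl_i8.cu (deleted) -- assumed shape of a per-type translation unit
// that existed only to instantiate the shared TopKImpl template for one type.
#include "topk_impl.cuh"  // assumed shared implementation header

namespace onnxruntime {
namespace cuda {
TOPKIMPLE(int8_t);  // assumed instantiation macro provided by the header
}  // namespace cuda
}  // namespace onnxruntime
```

Models that still run TopK on the dropped integer types are not broken by this: with no CUDA kernel registered for them, onnxruntime falls back to the CPU execution provider for those nodes.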
5 changes: 0 additions & 5 deletions onnxruntime/core/providers/cuda/math/topk_impl_i16.cu

This file was deleted.

5 changes: 0 additions & 5 deletions onnxruntime/core/providers/cuda/math/topk_impl_i8.cu

This file was deleted.

5 changes: 0 additions & 5 deletions onnxruntime/core/providers/cuda/math/topk_impl_u16.cu

This file was deleted.

5 changes: 0 additions & 5 deletions onnxruntime/core/providers/cuda/math/topk_impl_u32.cu

This file was deleted.

5 changes: 0 additions & 5 deletions onnxruntime/core/providers/cuda/math/topk_impl_u64.cu

This file was deleted.

5 changes: 0 additions & 5 deletions onnxruntime/core/providers/cuda/math/topk_impl_u8.cu

This file was deleted.
