[ROCm] Enable Sampling Op UT on AMD (#14581)
Basic porting effort to run the Sampling op unit tests on the ROCm execution provider, based on the following commits:

#13426
#14218

1. Enable the EmbedLayerNorm op.
2. Enable the Sampling op.
3. Enable the helpers that copy data from CPU to GPU for the subgraph.

This task is the first checkpoint; other ops may still be missing when testing a real model.
We will migrate more code onto ROCm as needed.
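
For orientation, here is a minimal, hypothetical sketch of how a test might point a sampling/generation model at the ROCm execution provider through the public C++ API; the model path "sampling.onnx", the environment name, and the default provider options are placeholders, not the actual unit-test code.

#include <onnxruntime_cxx_api.h>

// Hypothetical sketch: load a sampling/generation model on the ROCm EP.
int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "sampling_rocm_sketch");
  Ort::SessionOptions session_options;

  OrtROCMProviderOptions rocm_options{};  // default provider options (device 0)
  session_options.AppendExecutionProvider_ROCM(rocm_options);

  Ort::Session session(env, "sampling.onnx", session_options);  // placeholder model path
  // Build input tensors and call session.Run(...) just as the CUDA path does.
  return 0;
}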

Co-authored-by: Ubuntu <ettao@ettao-amd-dev1.zvflicr54joexhdgnhvmxrxygg.phxx.internal.cloudapp.net>
2 people authored and rui-ren committed Feb 7, 2023
1 parent 7d9c53b commit c0c8e11
Showing 12 changed files with 137 additions and 113 deletions.
13 changes: 0 additions & 13 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -11,10 +11,6 @@ set(contrib_ops_excluded_files
"bert/attention_softmax.h"
"bert/multihead_attention.cc"
"bert/multihead_attention.h"
"bert/embed_layer_norm.cc"
"bert/embed_layer_norm.h"
"bert/embed_layer_norm_impl.cu"
"bert/embed_layer_norm_impl.h"
"bert/fast_gelu_impl.cu"
"bert/fast_gelu_impl.h"
"bert/fast_gelu.cc"
@@ -85,17 +81,8 @@ set(contrib_ops_excluded_files
"tensor/image_scaler_impl.h"
"transformers/beam_search.cc"
"transformers/beam_search.h"
"transformers/generation_device_helper.cc"
"transformers/generation_device_helper.h"
"transformers/generation_cuda_impl.cu"
"transformers/generation_cuda_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
"transformers/sampling.cc"
"transformers/sampling.h"
"transformers/sampling_cuda_helper.h"
"transformers/dump_cuda_tensor.cc"
"transformers/dump_cuda_tensor.h"
"conv_transpose_with_dynamic_pads.cc"
"conv_transpose_with_dynamic_pads.h"
"cuda_contrib_kernels.cc"
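Removing the entries above from contrib_ops_excluded_files means the hipify step now converts those CUDA sources to HIP automatically. As a rough, hypothetical illustration of that mechanical rewrite (not actual tool output), a CUDA async-copy helper would come out roughly as:

#include <hip/hip_runtime.h>

// Illustrative only: hipify maps cudaMemcpyAsync -> hipMemcpyAsync,
// cudaMemcpyHostToDevice -> hipMemcpyHostToDevice, cudaStream_t -> hipStream_t.
hipError_t CopyToDeviceAsync(void* dst, const void* src, size_t bytes, hipStream_t stream) {
  return hipMemcpyAsync(dst, src, bytes, hipMemcpyHostToDevice, stream);
}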
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
@@ -116,7 +116,9 @@ const IExecutionProvider* Subgraph::GetProvider() const {
const ExecutionProviders& providers = session_state_->GetExecutionProviders();
const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider);
const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider);
const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider;
const IExecutionProvider* rocm_provider = providers.Get(onnxruntime::kRocmExecutionProvider);
const IExecutionProvider* gpu_provider = cuda_provider ? cuda_provider : rocm_provider;
const IExecutionProvider* provider = gpu_provider ? gpu_provider : cpu_provider;
return provider;
}

17 changes: 9 additions & 8 deletions onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -17,12 +17,12 @@ class PinnedHostBuffer {
public:
PinnedHostBuffer(size_t length)
: buffer_(nullptr) {
cudaHostAlloc(&buffer_, length * sizeof(T), cudaHostAllocDefault);
CUDA_CALL_THROW(cudaHostAlloc((void**)&buffer_, length * sizeof(T), cudaHostAllocDefault));
}

virtual ~PinnedHostBuffer() {
if (buffer_) {
cudaFreeHost(buffer_);
CUDA_CALL_THROW(cudaFreeHost(buffer_));
}
}

@@ -46,8 +46,9 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool i
// In that case, we copy tensor data as well. It is not needed, but it keeps code simple.
int num_items = dim0 * dim1;
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
cudaDeviceSynchronize();
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
CUDA_CALL_THROW(cudaDeviceSynchronize());
CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));


if (nullptr != name) {
std::cout << std::string(name) << std::endl;
@@ -64,8 +65,8 @@ template <typename T>
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, bool is_gpu_tensor) {
int num_items = dim0 * dim1 * dim2;
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
cudaDeviceSynchronize();
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
CUDA_CALL_THROW(cudaDeviceSynchronize());
CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));

if (nullptr != name) {
std::cout << std::string(name) << std::endl;
@@ -82,8 +83,8 @@ template <typename T>
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, int dim3, bool is_gpu_tensor) {
int num_items = dim0 * dim1 * dim2 * dim3;
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
cudaDeviceSynchronize();
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
CUDA_CALL_THROW(cudaDeviceSynchronize());
CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));

if (nullptr != name) {
std::cout << std::string(name) << std::endl;
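The dump helpers above now route every CUDA runtime call through CUDA_CALL_THROW so failures surface as exceptions instead of being silently dropped. A minimal sketch of that wrapper pattern, using a hypothetical CHECK_CUDA_THROW macro rather than onnxruntime's actual implementation:

#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for CUDA_CALL_THROW: check the status, throw on error.
#define CHECK_CUDA_THROW(expr)                                   \
  do {                                                           \
    cudaError_t _status = (expr);                                \
    if (_status != cudaSuccess) {                                \
      throw std::runtime_error(std::string("CUDA failure: ") +   \
                               cudaGetErrorString(_status));     \
    }                                                            \
  } while (0)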
108 changes: 54 additions & 54 deletions onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -320,33 +320,33 @@ void GetTempStorageSize(const T* d_keys_in,
bool is_descending,
size_t& temp_storage_bytes) {
if (is_descending) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
} else {
cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
}
}

@@ -412,33 +412,33 @@ void LaunchSortPairs(void* d_temp_storage,
cudaStream_t stream,
bool is_descending) {
if (is_descending) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
} else {
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
}
}
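
GetTempStorageSize and LaunchSortPairs follow CUB's two-pass convention: the first call passes a null temp-storage pointer so CUB only reports the scratch size, and the second call performs the sort. A simplified, single-segment sketch of the pattern (hypothetical names, error handling elided):

#include <cub/cub.cuh>

// Sketch only: sort float keys (descending) with int payloads in two passes.
void SortPairsDescendingSketch(const float* d_keys_in, float* d_keys_out,
                               const int* d_values_in, int* d_values_out,
                               int num_items, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Pass 1: null scratch pointer, so CUB only fills in temp_bytes.
  cub::DeviceRadixSort::SortPairsDescending(nullptr, temp_bytes,
                                            d_keys_in, d_keys_out,
                                            d_values_in, d_values_out,
                                            num_items, 0, sizeof(float) * 8, stream);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: same arguments with real scratch space; this performs the sort.
  cub::DeviceRadixSort::SortPairsDescending(d_temp, temp_bytes,
                                            d_keys_in, d_keys_out,
                                            d_values_in, d_values_out,
                                            num_items, 0, sizeof(float) * 8, stream);
  cudaStreamSynchronize(stream);  // ensure the sort finished before freeing scratch
  cudaFree(d_temp);
}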

@@ -721,9 +721,9 @@ void TorchMultinomialKernelLauncher(float* d_input,
cudaStream_t stream) {
// Store the props in class variables
int device;
cudaGetDevice(&device);
CUDA_CALL_THROW(cudaGetDevice(&device));
cudaDeviceProp props;
cudaGetDeviceProperties(&props, device);
CUDA_CALL_THROW(cudaGetDeviceProperties(&props, device));

int numSM = props.multiProcessorCount;
int maxThreads = props.maxThreadsPerBlock;
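TorchMultinomialKernelLauncher queries the device properties so the launch shape can be derived from the SM count and the per-block thread limit. A hedged sketch of that sizing step (the kernel launcher's actual grid/block math may differ):

#include <cuda_runtime.h>
#include <algorithm>

// Sketch only: derive a conservative launch shape from the device properties.
void QueryLaunchShape(int* grid, int* block) {
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp props{};
  cudaGetDeviceProperties(&props, device);
  *block = std::min(256, props.maxThreadsPerBlock);
  *grid = props.multiProcessorCount;  // roughly one block per SM as a starting point
}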
@@ -17,14 +17,23 @@
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
#include "contrib_ops/cuda/transformers/beam_search_topk.h"
#include "contrib_ops/cuda/transformers/greedy_search_top_one.h"

// These includes are no-ops for ROCm; we ignore them for now
#ifdef ENABLE_NVTX_PROFILE
#include "core/providers/cuda/nvtx_profile.h"
#include "core/providers/cuda/nvtx_profile_context.h"
#endif

#include "sampling_cuda_helper.h"

#ifdef DEBUG_GENERATION
#include <iostream>
#endif

using onnxruntime::cuda::ToCudaType;
using onnxruntime::cuda::TArray;
using onnxruntime::cuda::TopKImpl;

namespace onnxruntime {
namespace concurrency {
class ThreadPool;
@@ -203,21 +212,22 @@ void InitBeamState(transformers::IBeamSearchState<T>* beam_state,

// TODO(tianleiwu): we can use another stream to avoid blocking subgraph execution.
cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_scores.data(), 0, beam_state->next_scores.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->topk_buffer.data(), 0, beam_state->topk_buffer.size_bytes(), cuda_stream);
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_scores.data(), 0, beam_state->next_scores.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->topk_buffer.data(), 0, beam_state->topk_buffer.size_bytes(), cuda_stream));


// Initialize score of first beam of each group with 0 and the rest with -1e9.
cuda::LaunchInitKernel(beam_state->beam_scores.data(), batch_size, num_beams, cuda_stream);

// copy sequence lengths to GPU
// since next_positions is only needed to update feeds after subgraph execution, so it is fine to use Async here.
if (!beam_state->next_positions.empty()) { // next_positions is empty for T5
cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream);
CUDA_CALL_THROW(cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream));
}

#ifdef ENABLE_NVTX_PROFILE
@@ -234,12 +244,12 @@ void InitGreedyState(transformers::IGreedySearchState<T>* greedy_state,
initStateRange.Begin();
#endif

cudaStream_t cuda_stream = ort_stream ? reinterpret_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
cudaMemsetAsync(greedy_state->next_token_scores.data(), 0, greedy_state->next_token_scores.size_bytes(), cuda_stream);
cudaMemsetAsync(greedy_state->next_positions.data(), 0, greedy_state->next_positions.size_bytes(), cuda_stream);
cudaStream_t cuda_stream = ort_stream ? reinterpret_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
CUDA_CALL_THROW(cudaMemsetAsync(greedy_state->next_token_scores.data(), 0, greedy_state->next_token_scores.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(greedy_state->next_positions.data(), 0, greedy_state->next_positions.size_bytes(), cuda_stream));

cudaMemcpyAsync(greedy_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream);
CUDA_CALL_THROW(cudaMemcpyAsync(greedy_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream));

#ifdef ENABLE_NVTX_PROFILE
initStateRange.End();
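The memsets and the sequence-length copy in InitBeamState/InitGreedyState are all queued on the same stream, so later kernels on that stream see the initialized buffers without an extra synchronization. A stripped-down sketch of that ordering, with hypothetical buffer names instead of the beam/greedy state spans:

#include <cuda_runtime.h>
#include <cstdint>

// Sketch only: zero a device buffer, then overwrite it with host data, in stream order.
void InitPositionsSketch(int32_t* d_next_positions, const int32_t* h_sequence_lengths,
                         size_t count, cudaStream_t stream) {
  // Queued first on `stream`.
  cudaMemsetAsync(d_next_positions, 0, count * sizeof(int32_t), stream);
  // Queued second on the same stream, so it runs after the memset completes.
  cudaMemcpyAsync(d_next_positions, h_sequence_lengths, count * sizeof(int32_t),
                  cudaMemcpyHostToDevice, stream);
  // Any kernel launched on `stream` after this call sees the copied values.
}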
@@ -11,6 +11,9 @@
#include <iostream>
#endif

using onnxruntime::cuda::ToCudaType;
using onnxruntime::cuda::dispatch_blockwise_softmax_forward;

namespace onnxruntime {
namespace contrib {
namespace SamplingCudaHelper {