[ROCm] Enable Sampling Op UT on AMD #14581

Merged: 1 commit (Feb 7, 2023)
13 changes: 0 additions & 13 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -11,10 +11,6 @@ set(contrib_ops_excluded_files
"bert/attention_softmax.h"
"bert/multihead_attention.cc"
"bert/multihead_attention.h"
"bert/embed_layer_norm.cc"
"bert/embed_layer_norm.h"
"bert/embed_layer_norm_impl.cu"
"bert/embed_layer_norm_impl.h"
"bert/fast_gelu_impl.cu"
"bert/fast_gelu_impl.h"
"bert/fast_gelu.cc"
@@ -85,17 +81,8 @@ set(contrib_ops_excluded_files
"tensor/image_scaler_impl.h"
"transformers/beam_search.cc"
"transformers/beam_search.h"
"transformers/generation_device_helper.cc"
"transformers/generation_device_helper.h"
"transformers/generation_cuda_impl.cu"
"transformers/generation_cuda_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
"transformers/sampling.cc"
"transformers/sampling.h"
"transformers/sampling_cuda_helper.h"
"transformers/dump_cuda_tensor.cc"
"transformers/dump_cuda_tensor.h"
"conv_transpose_with_dynamic_pads.cc"
"conv_transpose_with_dynamic_pads.h"
"cuda_contrib_kernels.cc"
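For context: entries in contrib_ops_excluded_files are CUDA sources that the ROCm build skips when hipifying contrib ops, so deleting these lines means the generation, sampling, and tensor-dump sources listed above are now hipified and built for AMD GPUs. A rough sketch of the kind of renaming hipify performs on such a file (illustrative only, not the actual tool output):

#include <cuda_runtime.h>  // a hipified copy would include <hip/hip_runtime.h> instead

// Minimal sketch: the same logical call before and after hipify (names illustrative).
inline void CopyToDeviceAsync(void* dst, const void* src, size_t bytes, cudaStream_t stream) {
  // CUDA spelling, as in the files that are no longer excluded:
  cudaMemcpyAsync(dst, src, bytes, cudaMemcpyHostToDevice, stream);
  // After hipify, the same call would read:
  //   hipMemcpyAsync(dst, src, bytes, hipMemcpyHostToDevice, stream);
}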
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
@@ -116,7 +116,9 @@ const IExecutionProvider* Subgraph::GetProvider() const {
const ExecutionProviders& providers = session_state_->GetExecutionProviders();
const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider);
const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider);
const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider;
const IExecutionProvider* rocm_provider = providers.Get(onnxruntime::kRocmExecutionProvider);
const IExecutionProvider* gpu_provider = cuda_provider ? cuda_provider : rocm_provider;
const IExecutionProvider* provider = gpu_provider ? gpu_provider : cpu_provider;
return provider;
}

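The new fallback chain prefers whichever GPU provider is registered, CUDA or ROCm, and only falls back to the CPU provider when neither is present, so the generation subgraphs pick up the ROCm provider on AMD builds. A hedged sketch of how calling code can branch on the selected provider's type string (the comparison below illustrates the pattern; it is not code from this PR):

#include <string>

// Sketch only: the provider type strings onnxruntime uses for the two GPU providers.
inline bool IsGpuProvider(const std::string& provider_type) {
  return provider_type == "CUDAExecutionProvider" ||   // onnxruntime::kCudaExecutionProvider
         provider_type == "ROCMExecutionProvider";     // onnxruntime::kRocmExecutionProvider
}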
17 changes: 9 additions & 8 deletions onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -17,12 +17,12 @@ class PinnedHostBuffer {
public:
PinnedHostBuffer(size_t length)
: buffer_(nullptr) {
cudaHostAlloc(&buffer_, length * sizeof(T), cudaHostAllocDefault);
CUDA_CALL_THROW(cudaHostAlloc((void**)&buffer_, length * sizeof(T), cudaHostAllocDefault));
}

virtual ~PinnedHostBuffer() {
if (buffer_) {
cudaFreeHost(buffer_);
CUDA_CALL_THROW(cudaFreeHost(buffer_));
}
}

@@ -46,8 +46,9 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool i
// In that case, we copy tensor data as well. It is not needed, but it keeps code simple.
int num_items = dim0 * dim1;
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
cudaDeviceSynchronize();
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
CUDA_CALL_THROW(cudaDeviceSynchronize());
CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));


if (nullptr != name) {
std::cout << std::string(name) << std::endl;
@@ -64,8 +65,8 @@ template <typename T>
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, bool is_gpu_tensor) {
int num_items = dim0 * dim1 * dim2;
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
cudaDeviceSynchronize();
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
CUDA_CALL_THROW(cudaDeviceSynchronize());
CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));

if (nullptr != name) {
std::cout << std::string(name) << std::endl;
@@ -82,8 +83,8 @@ template <typename T>
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, int dim3, bool is_gpu_tensor) {
int num_items = dim0 * dim1 * dim2 * dim3;
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
cudaDeviceSynchronize();
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
CUDA_CALL_THROW(cudaDeviceSynchronize());
CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));

if (nullptr != name) {
std::cout << std::string(name) << std::endl;
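The edits above wrap every raw CUDA runtime call in CUDA_CALL_THROW so a failing call surfaces as an error instead of being silently dropped. A minimal sketch of the check such a macro performs; this is not onnxruntime's actual definition, which also records the failing expression, file, and line:

#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

// Sketch of the error-checking pattern (illustrative macro name).
#define CUDA_CALL_THROW_SKETCH(expr)                                 \
  do {                                                               \
    cudaError_t status__ = (expr);                                   \
    if (status__ != cudaSuccess) {                                   \
      throw std::runtime_error(std::string("CUDA failure: ") +       \
                               cudaGetErrorString(status__));        \
    }                                                                \
  } while (0)

// Usage mirroring the diff above:
//   CUDA_CALL_THROW_SKETCH(cudaDeviceSynchronize());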
108 changes: 54 additions & 54 deletions onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -320,33 +320,33 @@ void GetTempStorageSize(const T* d_keys_in,
bool is_descending,
size_t& temp_storage_bytes) {
if (is_descending) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
} else {
cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
}
}

@@ -412,33 +412,33 @@ void LaunchSortPairs(void* d_temp_storage,
cudaStream_t stream,
bool is_descending) {
if (is_descending) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
} else {
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
}
}
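GetTempStorageSize and LaunchSortPairs follow CUB's two-phase convention: the first call passes a null temp-storage pointer and only writes the required byte count, the caller allocates that buffer, and the second call performs the actual segmented sort. A self-contained sketch of the pattern (allocation and error handling simplified for illustration):

#include <cub/cub.cuh>

// Sketch only: segmented radix sort of (key, value) pairs using CUB's two-phase API.
void SortPairsSketch(const float* d_keys_in, float* d_keys_out,
                     const int* d_values_in, int* d_values_out,
                     const int* d_offsets, int num_items, int num_segments,
                     cudaStream_t stream) {
  // Phase 1: with a null storage pointer, CUB only reports how many scratch bytes it needs.
  size_t temp_storage_bytes = 0;
  cub::DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes,
                                           d_keys_in, d_keys_out, d_values_in, d_values_out,
                                           num_items, num_segments, d_offsets, d_offsets + 1,
                                           0, sizeof(float) * 8, stream);
  // Phase 2: allocate the scratch buffer and run the real sort.
  void* d_temp_storage = nullptr;
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
                                           d_keys_in, d_keys_out, d_values_in, d_values_out,
                                           num_items, num_segments, d_offsets, d_offsets + 1,
                                           0, sizeof(float) * 8, stream);
  cudaFree(d_temp_storage);
}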

@@ -721,9 +721,9 @@ void TorchMultinomialKernelLauncher(float* d_input,
cudaStream_t stream) {
// Store the props in class variables
int device;
cudaGetDevice(&device);
CUDA_CALL_THROW(cudaGetDevice(&device));
cudaDeviceProp props;
cudaGetDeviceProperties(&props, device);
CUDA_CALL_THROW(cudaGetDeviceProperties(&props, device));

int numSM = props.multiProcessorCount;
int maxThreads = props.maxThreadsPerBlock;
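TorchMultinomialKernelLauncher reads the device's multiprocessor count and maximum threads per block before launching. Launchers typically use those two values to size the grid so it covers the work without exceeding per-block limits; the computation below is an illustrative sketch, not this kernel's actual launch math:

#include <algorithm>
#include <cuda_runtime.h>

// Sketch only: derive a launch configuration from device properties.
void PickLaunchDims(int work_items, dim3& grid, dim3& block) {
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, device);

  int threads = std::min(props.maxThreadsPerBlock, 256);      // threads per block
  int blocks = (work_items + threads - 1) / threads;          // enough blocks to cover the work
  blocks = std::min(blocks, props.multiProcessorCount * 32);  // cap the grid size
  grid = dim3(blocks);
  block = dim3(threads);
}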
onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -17,14 +17,23 @@
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
#include "contrib_ops/cuda/transformers/beam_search_topk.h"
#include "contrib_ops/cuda/transformers/greedy_search_top_one.h"

// the includes would be dummy for ROCm, we will ignore them for now
#ifdef ENABLE_NVTX_PROFILE
#include "core/providers/cuda/nvtx_profile.h"
#include "core/providers/cuda/nvtx_profile_context.h"
#endif

#include "sampling_cuda_helper.h"

#ifdef DEBUG_GENERATION
#include <iostream>
#endif

using onnxruntime::cuda::ToCudaType;
using onnxruntime::cuda::TArray;
using onnxruntime::cuda::TopKImpl;

namespace onnxruntime {
namespace concurrency {
class ThreadPool;
@@ -203,21 +212,22 @@ void InitBeamState(transformers::IBeamSearchState<T>* beam_state,

// TODO(tianleiwu): we can use another stream to avoid blocking subgraph execution.
cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_scores.data(), 0, beam_state->next_scores.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->topk_buffer.data(), 0, beam_state->topk_buffer.size_bytes(), cuda_stream);
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_scores.data(), 0, beam_state->next_scores.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(beam_state->topk_buffer.data(), 0, beam_state->topk_buffer.size_bytes(), cuda_stream));


// Initialize score of first beam of each group with 0 and the rest with -1e9.
cuda::LaunchInitKernel(beam_state->beam_scores.data(), batch_size, num_beams, cuda_stream);

// copy sequence lengths to GPU
// since next_positions is only needed to update feeds after subgraph execution, so it is fine to use Async here.
if (!beam_state->next_positions.empty()) { // next_positions is empty for T5
cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream);
CUDA_CALL_THROW(cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream));
}

#ifdef ENABLE_NVTX_PROFILE
@@ -234,12 +244,12 @@ void InitGreedyState(transformers::IGreedySearchState<T>* greedy_state,
initStateRange.Begin();
#endif

cudaStream_t cuda_stream = ort_stream ? reinterpret_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
cudaMemsetAsync(greedy_state->next_token_scores.data(), 0, greedy_state->next_token_scores.size_bytes(), cuda_stream);
cudaMemsetAsync(greedy_state->next_positions.data(), 0, greedy_state->next_positions.size_bytes(), cuda_stream);
cudaStream_t cuda_stream = ort_stream ? reinterpret_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
CUDA_CALL_THROW(cudaMemsetAsync(greedy_state->next_token_scores.data(), 0, greedy_state->next_token_scores.size_bytes(), cuda_stream));
CUDA_CALL_THROW(cudaMemsetAsync(greedy_state->next_positions.data(), 0, greedy_state->next_positions.size_bytes(), cuda_stream));

cudaMemcpyAsync(greedy_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream);
CUDA_CALL_THROW(cudaMemcpyAsync(greedy_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream));

#ifdef ENABLE_NVTX_PROFILE
initStateRange.End();
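The InitBeamState and InitGreedyState changes keep the memsets and the sequence-length copy asynchronous on the subgraph's stream, which is safe because work enqueued on a single stream executes in issue order. A minimal sketch of that ordering guarantee (kernel and sizes illustrative):

#include <cuda_runtime.h>

// Sketch: a kernel that reads a buffer zeroed by an earlier async memset on the same stream.
__global__ void ReadBuffer(const int* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // buf[i] is 0 here because the memset was enqueued on this stream before the kernel.
    (void)buf[i];
  }
}

void ZeroThenRead(int* d_buf, int n, cudaStream_t stream) {
  cudaMemsetAsync(d_buf, 0, n * sizeof(int), stream);        // enqueued first
  ReadBuffer<<<(n + 255) / 256, 256, 0, stream>>>(d_buf, n);  // starts after the memset completes
}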
onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
@@ -11,6 +11,9 @@
#include <iostream>
#endif

using onnxruntime::cuda::ToCudaType;
using onnxruntime::cuda::dispatch_blockwise_softmax_forward;

namespace onnxruntime {
namespace contrib {
namespace SamplingCudaHelper {