From dfdffafa451c9581b79fc81b8cd1d980437d41b4 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Sat, 4 Feb 2023 05:49:06 +0000
Subject: [PATCH] first milestone

---
 cmake/onnxruntime_rocm_hipify.cmake           |  13 ---
 .../cpu/transformers/subgraph_base.cc         |   4 +-
 .../cuda/transformers/dump_cuda_tensor.cc     |  17 +--
 .../cuda/transformers/generation_cuda_impl.cu | 108 +++++++++---------
 .../transformers/generation_device_helper.cc  |  36 +++---
 .../cuda/transformers/sampling_cuda_helper.h  |   3 +
 .../contrib_ops/rocm/rocm_contrib_kernels.cc  |   6 +-
 onnxruntime/core/framework/session_state.cc   |   2 +-
 .../core/providers/rocm/math/softmax_impl.cu  |   3 +
 .../contrib_ops/embed_layer_norm_op_test.cc   |   7 +-
 onnxruntime/test/contrib_ops/sampling_test.cc |  45 ++++----
 tools/ci_build/amd_hipify.py                  |   6 +
 12 files changed, 137 insertions(+), 113 deletions(-)

diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 92a3260714a36..6eb315c59bc80 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -11,10 +11,6 @@ set(contrib_ops_excluded_files
   "bert/attention_softmax.h"
   "bert/multihead_attention.cc"
   "bert/multihead_attention.h"
-  "bert/embed_layer_norm.cc"
-  "bert/embed_layer_norm.h"
-  "bert/embed_layer_norm_impl.cu"
-  "bert/embed_layer_norm_impl.h"
   "bert/fast_gelu_impl.cu"
   "bert/fast_gelu_impl.h"
   "bert/fast_gelu.cc"
@@ -85,17 +81,8 @@ set(contrib_ops_excluded_files
   "tensor/image_scaler_impl.h"
   "transformers/beam_search.cc"
   "transformers/beam_search.h"
-  "transformers/generation_device_helper.cc"
-  "transformers/generation_device_helper.h"
-  "transformers/generation_cuda_impl.cu"
-  "transformers/generation_cuda_impl.h"
   "transformers/greedy_search.cc"
   "transformers/greedy_search.h"
-  "transformers/sampling.cc"
-  "transformers/sampling.h"
-  "transformers/sampling_cuda_helper.h"
-  "transformers/dump_cuda_tensor.cc"
-  "transformers/dump_cuda_tensor.h"
   "conv_transpose_with_dynamic_pads.cc"
   "conv_transpose_with_dynamic_pads.h"
   "cuda_contrib_kernels.cc"
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
index c7a2b8f0c0fc1..c8be36a41e944 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
@@ -116,7 +116,9 @@ const IExecutionProvider* Subgraph::GetProvider() const {
   const ExecutionProviders& providers = session_state_->GetExecutionProviders();
   const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider);
   const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider);
-  const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider;
+  const IExecutionProvider* rocm_provider = providers.Get(onnxruntime::kRocmExecutionProvider);
+  const IExecutionProvider* gpu_provider = cuda_provider ? cuda_provider : rocm_provider;
+  const IExecutionProvider* provider = gpu_provider ? gpu_provider : cpu_provider;
   return provider;
 }
 
diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
index 741f9ac259da1..3046a58040635 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc
@@ -17,12 +17,12 @@ class PinnedHostBuffer {
  public:
   PinnedHostBuffer(size_t length)
       : buffer_(nullptr) {
-    cudaHostAlloc(&buffer_, length * sizeof(T), cudaHostAllocDefault);
+    CUDA_CALL_THROW(cudaHostAlloc((void**)&buffer_, length * sizeof(T), cudaHostAllocDefault));
   }
 
   virtual ~PinnedHostBuffer() {
     if (buffer_) {
-      cudaFreeHost(buffer_);
+      CUDA_CALL_THROW(cudaFreeHost(buffer_));
     }
   }
@@ -46,8 +46,9 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool is_gpu_tensor) {
   // In that case, we copy tensor data as well. It is not needed, but it keeps code simple.
   int num_items = dim0 * dim1;
   auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
-  cudaDeviceSynchronize();
-  cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
+  CUDA_CALL_THROW(cudaDeviceSynchronize());
+  CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));
+
 
   if (nullptr != name) {
     std::cout << std::string(name) << std::endl;
@@ -64,8 +65,8 @@ template <typename T>
 void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, bool is_gpu_tensor) {
   int num_items = dim0 * dim1 * dim2;
   auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
-  cudaDeviceSynchronize();
-  cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
+  CUDA_CALL_THROW(cudaDeviceSynchronize());
+  CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));
 
   if (nullptr != name) {
     std::cout << std::string(name) << std::endl;
@@ -82,8 +83,8 @@ template <typename T>
 void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, int dim3, bool is_gpu_tensor) {
   int num_items = dim0 * dim1 * dim2 * dim3;
   auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
-  cudaDeviceSynchronize();
-  cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
+  CUDA_CALL_THROW(cudaDeviceSynchronize());
+  CUDA_CALL_THROW(cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost));
 
   if (nullptr != name) {
     std::cout << std::string(name) << std::endl;
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
index 523603a550be9..90c91228204b6 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -320,33 +320,33 @@ void GetTempStorageSize(const T* d_keys_in,
                         bool is_descending,
                         size_t& temp_storage_bytes) {
   if (is_descending) {
-    cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
-                                                       temp_storage_bytes,
-                                                       d_keys_in,
-                                                       (T*)nullptr,
-                                                       d_values_in,
-                                                       (int*)nullptr,
-                                                       num_items,
-                                                       num_segments,
-                                                       d_offsets,
-                                                       d_offsets + 1,
-                                                       0,
-                                                       sizeof(T) * 8,
-                                                       stream);
+    CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
+                                                                       temp_storage_bytes,
+                                                                       d_keys_in,
+                                                                       (T*)nullptr,
+                                                                       d_values_in,
+                                                                       (int*)nullptr,
+                                                                       num_items,
+                                                                       num_segments,
+                                                                       d_offsets,
+                                                                       d_offsets + 1,
+                                                                       0,
+                                                                       sizeof(T) * 8,
+                                                                       stream));
   } else {
-    cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
-                                             temp_storage_bytes,
-                                             d_keys_in,
-                                             (T*)nullptr,
-                                             d_values_in,
-                                             (int*)nullptr,
-                                             num_items,
-                                             num_segments,
-                                             d_offsets,
-                                             d_offsets + 1,
-                                             0,
-                                             sizeof(T) * 8,
-                                             stream);
+    CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
+                                                             temp_storage_bytes,
+                                                             d_keys_in,
+                                                             (T*)nullptr,
+                                                             d_values_in,
+                                                             (int*)nullptr,
+                                                             num_items,
+                                                             num_segments,
+                                                             d_offsets,
+                                                             d_offsets + 1,
+                                                             0,
+                                                             sizeof(T) * 8,
+                                                             stream));
   }
 }
 
@@ -412,33 +412,33 @@ void LaunchSortPairs(void* d_temp_storage,
                      cudaStream_t stream,
                      bool is_descending) {
   if (is_descending) {
-    cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
-                                                       temp_storage_bytes,
-                                                       d_keys_in,
-                                                       d_keys_out,
-                                                       d_values_in,
-                                                       d_values_out,
-                                                       num_items,
-                                                       num_segments,
-                                                       d_offsets,
-                                                       d_offsets + 1,
-                                                       0,
-                                                       sizeof(T) * 8,
-                                                       stream);
+    CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
+                                                                       temp_storage_bytes,
+                                                                       d_keys_in,
+                                                                       d_keys_out,
+                                                                       d_values_in,
+                                                                       d_values_out,
+                                                                       num_items,
+                                                                       num_segments,
+                                                                       d_offsets,
+                                                                       d_offsets + 1,
+                                                                       0,
+                                                                       sizeof(T) * 8,
+                                                                       stream));
   } else {
-    cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
-                                             temp_storage_bytes,
-                                             d_keys_in,
-                                             d_keys_out,
-                                             d_values_in,
-                                             d_values_out,
-                                             num_items,
-                                             num_segments,
-                                             d_offsets,
-                                             d_offsets + 1,
-                                             0,
-                                             sizeof(T) * 8,
-                                             stream);
+    CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
+                                                             temp_storage_bytes,
+                                                             d_keys_in,
+                                                             d_keys_out,
+                                                             d_values_in,
+                                                             d_values_out,
+                                                             num_items,
+                                                             num_segments,
+                                                             d_offsets,
+                                                             d_offsets + 1,
+                                                             0,
+                                                             sizeof(T) * 8,
+                                                             stream));
   }
 }
 
@@ -721,9 +721,9 @@ void TorchMultinomialKernelLauncher(float* d_input,
                                     cudaStream_t stream) {
   // Store the props in class variables
   int device;
-  cudaGetDevice(&device);
+  CUDA_CALL_THROW(cudaGetDevice(&device));
   cudaDeviceProp props;
-  cudaGetDeviceProperties(&props, device);
+  CUDA_CALL_THROW(cudaGetDeviceProperties(&props, device));
   int numSM = props.multiProcessorCount;
   int maxThreads = props.maxThreadsPerBlock;
 
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index 1a5a9ac5d97b2..c5983e0260650 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -17,14 +17,23 @@
 #include "contrib_ops/cpu/transformers/subgraph_gpt.h"
 #include "contrib_ops/cuda/transformers/beam_search_topk.h"
"contrib_ops/cuda/transformers/beam_search_topk.h" #include "contrib_ops/cuda/transformers/greedy_search_top_one.h" + +// the includes would be dummy for ROCm, we will ignore them for now +#ifdef ENABLE_NVTX_PROFILE #include "core/providers/cuda/nvtx_profile.h" #include "core/providers/cuda/nvtx_profile_context.h" +#endif + #include "sampling_cuda_helper.h" #ifdef DEBUG_GENERATION #include #endif +using onnxruntime::cuda::ToCudaType; +using onnxruntime::cuda::TArray; +using onnxruntime::cuda::TopKImpl; + namespace onnxruntime { namespace concurrency { class ThreadPool; @@ -203,12 +212,13 @@ void InitBeamState(transformers::IBeamSearchState* beam_state, // TODO(tianleiwu): we can use another stream to avoid blocking subgraph execution. cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->next_scores.data(), 0, beam_state->next_scores.size_bytes(), cuda_stream); - cudaMemsetAsync(beam_state->topk_buffer.data(), 0, beam_state->topk_buffer.size_bytes(), cuda_stream); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->next_scores.data(), 0, beam_state->next_scores.size_bytes(), cuda_stream)); + CUDA_CALL_THROW(cudaMemsetAsync(beam_state->topk_buffer.data(), 0, beam_state->topk_buffer.size_bytes(), cuda_stream)); + // Initialize score of first beam of each group with 0 and the rest with -1e9. cuda::LaunchInitKernel(beam_state->beam_scores.data(), batch_size, num_beams, cuda_stream); @@ -216,8 +226,8 @@ void InitBeamState(transformers::IBeamSearchState* beam_state, // copy sequence lengths to GPU // since next_positions is only needed to update feeds after subgraph execution, so it is fine to use Async here. if (!beam_state->next_positions.empty()) { // next_positions is empty for T5 - cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), - cudaMemcpyHostToDevice, cuda_stream); + CUDA_CALL_THROW(cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), + cudaMemcpyHostToDevice, cuda_stream)); } #ifdef ENABLE_NVTX_PROFILE @@ -234,12 +244,12 @@ void InitGreedyState(transformers::IGreedySearchState* greedy_state, initStateRange.Begin(); #endif - cudaStream_t cuda_stream = ort_stream ? 
-  cudaStream_t cuda_stream = ort_stream ? reinterpret_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
-  cudaMemsetAsync(greedy_state->next_token_scores.data(), 0, greedy_state->next_token_scores.size_bytes(), cuda_stream);
-  cudaMemsetAsync(greedy_state->next_positions.data(), 0, greedy_state->next_positions.size_bytes(), cuda_stream);
+  cudaStream_t cuda_stream = ort_stream ? reinterpret_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
+  CUDA_CALL_THROW(cudaMemsetAsync(greedy_state->next_token_scores.data(), 0, greedy_state->next_token_scores.size_bytes(), cuda_stream));
+  CUDA_CALL_THROW(cudaMemsetAsync(greedy_state->next_positions.data(), 0, greedy_state->next_positions.size_bytes(), cuda_stream));
 
-  cudaMemcpyAsync(greedy_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
-                  cudaMemcpyHostToDevice, cuda_stream);
+  CUDA_CALL_THROW(cudaMemcpyAsync(greedy_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
+                                  cudaMemcpyHostToDevice, cuda_stream));
 
 #ifdef ENABLE_NVTX_PROFILE
   initStateRange.End();
diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
index d82648890f94f..67b64fee21895 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
@@ -11,6 +11,9 @@
 #include <iostream>
 #endif
 
+using onnxruntime::cuda::ToCudaType;
+using onnxruntime::cuda::dispatch_blockwise_softmax_forward;
+
 namespace onnxruntime {
 namespace contrib {
 namespace SamplingCudaHelper {
diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
index b92efc3a6109a..e056c8cbfb64d 100644
--- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -69,6 +69,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Sampling);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
@@ -166,8 +167,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
                                                                  1, MLFloat16, DecoderAttention)>,
     // BuildKernelCreateInfo,
     // BuildKernelCreateInfo,
-    // BuildKernelCreateInfo,
-    // BuildKernelCreateInfo,
+    BuildKernelCreateInfo,
+    BuildKernelCreateInfo,
     // BuildKernelCreateInfo,
     // BuildKernelCreateInfo,
     // BuildKernelCreateInfo,
@@ -178,6 +179,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
     // BuildKernelCreateInfo,
    // BuildKernelCreateInfo,
     // BuildKernelCreateInfo,
+    BuildKernelCreateInfo,
     // BuildKernelCreateInfo,
     // BuildKernelCreateInfo,
     // BuildKernelCreateInfo,
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index facce93cde798..d9c02702d733e 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -1004,7 +1004,7 @@ Status SessionState::CreateSubgraphSessionState() {
   for (auto& node : graph_.Nodes()) {
     for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
       const auto& ep = node.GetExecutionProviderType();
-      if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider) {
+      if (!ep.empty() && ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && ep != kRocmExecutionProvider) {
         // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
         // node containing the subgraph it will create whatever state it needs internally.
         continue;
diff --git a/onnxruntime/core/providers/rocm/math/softmax_impl.cu b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
index f5a26ef045881..9431e4327d770 100644
--- a/onnxruntime/core/providers/rocm/math/softmax_impl.cu
+++ b/onnxruntime/core/providers/rocm/math/softmax_impl.cu
@@ -129,6 +129,9 @@ SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, half, float)
 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(double, double, double)
 SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(BFloat16, BFloat16, float)
 
+#ifndef DISABLE_CONTRIB_OPS
+SPECIALIZED_BLOCKWISE_SOFTMAX_IMPL(half, float, float)  // used by BeamSearch op
+#endif
 }
 }
 
diff --git a/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc
index be384a20a3190..884f4422d5d8b 100644
--- a/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc
+++ b/onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc
@@ -17,10 +17,11 @@ static void RunTest(const embedlayernorm::OpData& data,
   int min_cuda_architecture = use_float16 ? 530 : 0;
 
   bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+  bool enable_rocm = DefaultRocmExecutionProvider().get() != nullptr;
   bool enable_dml = DefaultDmlExecutionProvider().get() != nullptr;
   bool enable_cpu = !use_float16;
 
-  if (enable_cpu || enable_cuda || enable_dml) {
+  if (enable_cpu || enable_cuda || enable_dml || enable_rocm) {
     // Input and output shapes
     //   Input 0 - input_ids : (batch_size, sequence_size)
     //   Input 1 - segment_ids : (batch_size, sequence_size)
@@ -143,6 +144,10 @@ static void RunTest(const embedlayernorm::OpData& data,
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultCudaExecutionProvider());
     tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  } else if (enable_rocm) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultRocmExecutionProvider());
+    tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
   } else if (enable_dml) {
     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
     execution_providers.push_back(DefaultDmlExecutionProvider());
diff --git a/onnxruntime/test/contrib_ops/sampling_test.cc b/onnxruntime/test/contrib_ops/sampling_test.cc
index 48992f24a3234..e0ad20415ca5f 100644
--- a/onnxruntime/test/contrib_ops/sampling_test.cc
+++ b/onnxruntime/test/contrib_ops/sampling_test.cc
@@ -14,8 +14,8 @@ namespace onnxruntime {
 namespace test {
 
 #if defined(__linux__) && !defined(__ANDROID__)
-#ifdef USE_CUDA
-TEST(SamplingTest, Gpt2Sampling_CUDA) {
+#if defined(USE_CUDA) || defined(USE_ROCM)
+TEST(SamplingTest, Gpt2Sampling_GPU) {
   std::vector<int32_t> input_ids{
       0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620,
       41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572,
@@ -25,7 +25,6 @@ TEST(SamplingTest, Gpt2Sampling_CUDA) {
 
   std::vector<int32_t> min_length{1};
   std::vector<float> repetition_penalty{1.0f};
-
   std::vector<int32_t> expected_output{
       0, 0, 0, 0, 0, 52, 195, 731, 321, 301, 734, 620, 125, 543, 668,
       41, 554, 74, 622, 206, 222, 75, 223, 221, 198, 224, 572, 776, 213, 697,
@@ -35,9 +34,7 @@ TEST(SamplingTest, Gpt2Sampling_CUDA) {
 
   const int64_t sequence_length = 12;
   std::vector<int64_t> input_ids_shape{batch_size, sequence_length};
-
   std::vector<int64_t> parameter_shape{1};
-
   std::vector<int64_t> expected_output_shape{input_ids_shape[0], max_length[0]};
 
   Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
@@ -62,28 +59,36 @@ TEST(SamplingTest, Gpt2Sampling_CUDA) {
   const char* const output_names[] = {"sequences"};
 
   Ort::SessionOptions session_options;
+#ifdef USE_CUDA
   constexpr int min_cuda_architecture = 530;
-  if (HasCudaEnvironment(min_cuda_architecture)) {
-    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+  if (!HasCudaEnvironment(min_cuda_architecture)) {
+    LOGS_DEFAULT(WARNING) << "Hardware does NOT support the current architecture";
+    return;
+  }
+  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
+#else  // USE_ROCM
+  OrtROCMProviderOptions rocm_options;
+  // TODO - verify the default settings
+  session_options.AppendExecutionProvider_ROCM(rocm_options);
+#endif
 
-    Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_sampling.onnx"), session_options);
+  Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_sampling.onnx"), session_options);
 
-    auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
-                                   output_names, 1);
+  auto ort_outputs = session.Run(Ort::RunOptions{}, input_names, ort_inputs.data(), ort_inputs.size(),
+                                 output_names, 1);
 
-    ASSERT_EQ(ort_outputs.size(), 1U);
-    const auto& sequences = ort_outputs[0];
-    ASSERT_TRUE(sequences.IsTensor());
+  ASSERT_EQ(ort_outputs.size(), 1U);
+  const auto& sequences = ort_outputs[0];
+  ASSERT_TRUE(sequences.IsTensor());
 
-    auto result_ts = sequences.GetTensorTypeAndShapeInfo();
-    ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, result_ts.GetElementType());
+  auto result_ts = sequences.GetTensorTypeAndShapeInfo();
+  ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, result_ts.GetElementType());
 
-    ASSERT_EQ(expected_output_shape, result_ts.GetShape());
-    const auto* result_vals = sequences.GetTensorData<int32_t>();
-    auto result_span = gsl::make_span(result_vals, expected_output.size());
+  ASSERT_EQ(expected_output_shape, result_ts.GetShape());
+  const auto* result_vals = sequences.GetTensorData<int32_t>();
+  auto result_span = gsl::make_span(result_vals, expected_output.size());
 
-    ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
-  }
+  ASSERT_TRUE(std::equal(expected_output.cbegin(), expected_output.cend(), result_span.begin(), result_span.end()));
 }
 #endif
diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py
index e6a5f8f8cc38a..08b258406ffd3 100644
--- a/tools/ci_build/amd_hipify.py
+++ b/tools/ci_build/amd_hipify.py
@@ -59,6 +59,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
         s = s.replace("GPU_WARP_SIZE = 32", "GPU_WARP_SIZE = 64")
         s = s.replace("std::exp", "expf")
         s = s.replace("std::log", "logf")
+        s = s.replace("WaitCudaNotificationOnDevice", "WaitRocmNotificationOnDevice")
+        s = s.replace("hipHostAlloc", "hipHostMalloc")
         s = s.replace(
             "#include ",
             "#include \n#include ",
         )
@@ -67,6 +69,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
         s = s.replace(
             '#include "cub/device/device_radix_sort.cuh"',
             "#include \n#include ",
         )
+        s = s.replace(
"#include ", + "#include ", + ) s = s.replace( "#include ", "#include " )