From 0e71e72479f05cfc5635b5bd9765d2f5472e930c Mon Sep 17 00:00:00 2001
From: Andrew Woloszyn
Date: Wed, 11 Dec 2024 08:29:49 -0800
Subject: [PATCH] [hip] Update the hip runtime so that multiple physical
 devices can back a logical device. (#18790)

This gives us an interface for creating a logical device from a set of
physical HIP devices. Each physical device in the new logical device is
represented as a queue in the logical device, and can therefore be
addressed by specifying a queue affinity for operations and buffers.
A usage sketch appears below.

This is a substantial change to how HIP devices are handled. In general:
1. Synchronization is updated so that we no longer have to do a CPU
   round trip for GPU<->GPU synchronization.
2. The deferred work queue is no longer used in HIP at all.
3. We now have two much simpler helper threads per device that keep
   high-cost operations off the main thread: specifically, async memory
   operations and queue executions, which are emulated in HIP and have a
   relatively high cost.

Signed-off-by: Andrew Woloszyn
---
 .../iree/hal/cts/semaphore_submission_test.h | 23 +-
 .../src/iree/hal/drivers/cuda/cuda_device.c | 4 +-
 .../src/iree/hal/drivers/hip/CMakeLists.txt | 16 +-
 .../src/iree/hal/drivers/hip/cleanup_thread.c | 170 ++
 .../src/iree/hal/drivers/hip/cleanup_thread.h | 40 +
 .../src/iree/hal/drivers/hip/context_util.h | 4 +-
 .../iree/hal/drivers/hip/cts/CMakeLists.txt | 161 ++
 .../hip/cts/multi_queue_device_creation.h | 72 +
 .../iree/hal/drivers/hip/dispatch_thread.c | 175 ++
 .../iree/hal/drivers/hip/dispatch_thread.h | 58 +
 .../hal/drivers/hip/dynamic_symbol_tables.h | 5 +-
 runtime/src/iree/hal/drivers/hip/event_pool.c | 65 +-
 runtime/src/iree/hal/drivers/hip/event_pool.h | 12 +-
 .../iree/hal/drivers/hip/event_semaphore.c | 1133 ++++++----
 .../iree/hal/drivers/hip/event_semaphore.h | 73 +-
 .../hal/drivers/hip/graph_command_buffer.c | 19 +-
 .../hal/drivers/hip/graph_command_buffer.h | 11 +-
 .../src/iree/hal/drivers/hip/hip_allocator.c | 136 +-
 .../src/iree/hal/drivers/hip/hip_allocator.h | 13 +-
 runtime/src/iree/hal/drivers/hip/hip_buffer.c | 1 +
 runtime/src/iree/hal/drivers/hip/hip_buffer.h | 8 -
 runtime/src/iree/hal/drivers/hip/hip_device.c | 1836 ++++++++++++-----
 runtime/src/iree/hal/drivers/hip/hip_device.h | 29 +-
 runtime/src/iree/hal/drivers/hip/hip_driver.c | 249 ++-
 .../hip/hip_multi_queue_command_buffer.c | 356 ++++
 .../hip/hip_multi_queue_command_buffer.h | 48 +
 .../src/iree/hal/drivers/hip/memory_pools.c | 44 +-
 .../src/iree/hal/drivers/hip/memory_pools.h | 10 -
 .../iree/hal/drivers/hip/native_executable.c | 327 +--
 .../iree/hal/drivers/hip/native_executable.h | 15 +-
 .../hal/drivers/hip/nop_executable_cache.c | 15 +-
 .../hal/drivers/hip/nop_executable_cache.h | 13 +-
 .../hal/drivers/hip/per_device_information.h | 37 +
 .../src/iree/hal/drivers/hip/rccl_channel.h | 8 -
 .../iree/hal/drivers/hip/rccl_status_util.h | 8 -
 .../src/iree/hal/drivers/hip/status_util.h | 20 +-
 .../hal/drivers/hip/stream_command_buffer.c | 62 +-
 .../hal/drivers/hip/stream_command_buffer.h | 20 +-
 .../src/iree/hal/drivers/hip/timepoint_pool.c | 352 ----
 .../src/iree/hal/drivers/hip/timepoint_pool.h | 119 --
 .../iree/hal/drivers/hip/util/CMakeLists.txt | 41 +
 runtime/src/iree/hal/drivers/hip/util/queue.c | 88 +
 runtime/src/iree/hal/drivers/hip/util/queue.h | 100 +
 .../iree/hal/drivers/hip/util/queue_test.cc | 111 +
 runtime/src/iree/hal/drivers/hip/util/tree.c | 578 ++++++
 runtime/src/iree/hal/drivers/hip/util/tree.h | 151 ++
 .../iree/hal/drivers/hip/util/tree_test.cc | 180 ++
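As a usage sketch (not part of the patch), this is roughly what the new
interface enables, assuming a registered HIP driver handle, two physical
devices that enumerate as ordinals 0 and 1, and an existing semaphore list;
the "0,1" device path and queue index 1 are illustrative only:

    // Sketch only: builds one logical device over two physical devices and
    // submits a barrier to queue 1 (i.e. the second physical device).
    static iree_status_t create_and_use_multi_queue_device(
        iree_hal_driver_t* driver, iree_allocator_t host_allocator,
        iree_hal_semaphore_list_t signal_semaphores) {
      iree_hal_device_t* device = NULL;
      IREE_RETURN_IF_ERROR(iree_hal_driver_create_device_by_path(
          driver, IREE_SV("hip"), IREE_SV("0,1"), /*param_count=*/0,
          /*params=*/NULL, host_allocator, &device));
      // Each physical device is a queue of the logical device; select it with
      // an affinity bit (IREE_HAL_QUEUE_AFFINITY_ANY lets the runtime choose).
      iree_hal_queue_affinity_t queue_affinity =
          (iree_hal_queue_affinity_t)1 << 1;
      iree_status_t status = iree_hal_device_queue_barrier(
          device, queue_affinity, iree_hal_semaphore_list_empty(),
          signal_semaphores);
      iree_hal_device_release(device);
      return status;
    }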
.../iree/hal/drivers/local_sync/sync_device.c | 4 +- .../iree/hal/drivers/local_task/task_device.c | 4 +- .../src/iree/hal/drivers/metal/metal_device.m | 4 +- .../iree/hal/drivers/vulkan/vulkan_device.cc | 2 +- runtime/src/iree/hal/queue.h | 1 + .../iree/hal/utils/deferred_command_buffer.c | 6 +- .../iree/hal/utils/deferred_command_buffer.h | 4 +- runtime/src/iree/hal/utils/stream_tracing.c | 116 +- runtime/src/iree/hal/utils/stream_tracing.h | 13 +- third_party/llvm-project | 2 +- 57 files changed, 5146 insertions(+), 2026 deletions(-) create mode 100644 runtime/src/iree/hal/drivers/hip/cleanup_thread.c create mode 100644 runtime/src/iree/hal/drivers/hip/cleanup_thread.h create mode 100644 runtime/src/iree/hal/drivers/hip/cts/multi_queue_device_creation.h create mode 100644 runtime/src/iree/hal/drivers/hip/dispatch_thread.c create mode 100644 runtime/src/iree/hal/drivers/hip/dispatch_thread.h create mode 100644 runtime/src/iree/hal/drivers/hip/hip_multi_queue_command_buffer.c create mode 100644 runtime/src/iree/hal/drivers/hip/hip_multi_queue_command_buffer.h create mode 100644 runtime/src/iree/hal/drivers/hip/per_device_information.h delete mode 100644 runtime/src/iree/hal/drivers/hip/timepoint_pool.c delete mode 100644 runtime/src/iree/hal/drivers/hip/timepoint_pool.h create mode 100644 runtime/src/iree/hal/drivers/hip/util/CMakeLists.txt create mode 100644 runtime/src/iree/hal/drivers/hip/util/queue.c create mode 100644 runtime/src/iree/hal/drivers/hip/util/queue.h create mode 100644 runtime/src/iree/hal/drivers/hip/util/queue_test.cc create mode 100644 runtime/src/iree/hal/drivers/hip/util/tree.c create mode 100644 runtime/src/iree/hal/drivers/hip/util/tree.h create mode 100644 runtime/src/iree/hal/drivers/hip/util/tree_test.cc diff --git a/runtime/src/iree/hal/cts/semaphore_submission_test.h b/runtime/src/iree/hal/cts/semaphore_submission_test.h index 094368190aac..62c7ae442af2 100644 --- a/runtime/src/iree/hal/cts/semaphore_submission_test.h +++ b/runtime/src/iree/hal/cts/semaphore_submission_test.h @@ -31,10 +31,9 @@ TEST_F(SemaphoreSubmissionTest, SubmitWithNoCommandBuffers) { signal_payload_values, }; - IREE_ASSERT_OK(iree_hal_device_queue_barrier(device_, - /*queue_affinity=*/0, - iree_hal_semaphore_list_empty(), - signal_semaphores)); + IREE_ASSERT_OK(iree_hal_device_queue_barrier( + device_, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(), + signal_semaphores)); IREE_ASSERT_OK( iree_hal_semaphore_wait(signal_semaphore, 1, iree_infinite_timeout())); @@ -54,9 +53,9 @@ TEST_F(SemaphoreSubmissionTest, SubmitAndSignal) { }; IREE_ASSERT_OK(iree_hal_device_queue_execute( - device_, - /*queue_affinity=*/0, iree_hal_semaphore_list_empty(), signal_semaphores, - command_buffer, iree_hal_buffer_binding_table_empty())); + device_, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(), + signal_semaphores, command_buffer, + iree_hal_buffer_binding_table_empty())); IREE_ASSERT_OK( iree_hal_semaphore_wait(signal_semaphore, 1, iree_infinite_timeout())); @@ -87,9 +86,8 @@ TEST_F(SemaphoreSubmissionTest, SubmitWithWait) { }; IREE_ASSERT_OK(iree_hal_device_queue_execute( - device_, - /*queue_affinity=*/0, wait_semaphores, signal_semaphores, command_buffer, - iree_hal_buffer_binding_table_empty())); + device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores, signal_semaphores, + command_buffer, iree_hal_buffer_binding_table_empty())); // Work shouldn't start until the wait semaphore reaches its payload value. 
CheckSemaphoreValue(signal_semaphore, 100); @@ -130,9 +128,8 @@ TEST_F(SemaphoreSubmissionTest, SubmitWithMultipleSemaphores) { }; IREE_ASSERT_OK(iree_hal_device_queue_execute( - device_, - /*queue_affinity=*/0, wait_semaphores, signal_semaphores, command_buffer, - iree_hal_buffer_binding_table_empty())); + device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores, signal_semaphores, + command_buffer, iree_hal_buffer_binding_table_empty())); // Work shouldn't start until all wait semaphores reach their payload values. CheckSemaphoreValue(signal_semaphore_1, 0); diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c index 9c9c7c61e02c..5ff3cdea0cce 100644 --- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c +++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c @@ -854,7 +854,7 @@ static iree_status_t iree_hal_cuda_device_create_command_buffer( if (binding_capacity > 0) { return iree_hal_deferred_command_buffer_create( iree_hal_device_allocator(base_device), mode, command_categories, - binding_capacity, &device->block_pool, + queue_affinity, binding_capacity, &device->block_pool, iree_hal_device_host_allocator(base_device), out_command_buffer); } else { return iree_hal_cuda_graph_command_buffer_create( @@ -867,7 +867,7 @@ static iree_status_t iree_hal_cuda_device_create_command_buffer( case IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM: { return iree_hal_deferred_command_buffer_create( iree_hal_device_allocator(base_device), mode, command_categories, - binding_capacity, &device->block_pool, + queue_affinity, binding_capacity, &device->block_pool, iree_hal_device_host_allocator(base_device), out_command_buffer); } default: { diff --git a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt index 7fd5abf07294..15e375c1fd2c 100644 --- a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt @@ -21,7 +21,10 @@ iree_cc_library( "api.h" SRCS "api.h" - "context_util.h" + "cleanup_thread.c" + "cleanup_thread.h" + "dispatch_thread.c" + "dispatch_thread.h" "event_pool.c" "event_pool.h" "event_semaphore.c" @@ -32,21 +35,22 @@ iree_cc_library( "hip_allocator.h" "hip_buffer.c" "hip_buffer.h" - "hip_device.c" - "hip_device.h" "hip_driver.c" + "hip_device.h" + "hip_device.c" + "hip_multi_queue_command_buffer.h" + "hip_multi_queue_command_buffer.c" "memory_pools.c" "memory_pools.h" "native_executable.c" "native_executable.h" "nop_executable_cache.c" "nop_executable_cache.h" + "per_device_information.h" "rccl_channel.c" "rccl_channel.h" "stream_command_buffer.c" "stream_command_buffer.h" - "timepoint_pool.c" - "timepoint_pool.h" INCLUDES "${HIP_API_HEADERS_ROOT}" DEPS @@ -65,12 +69,12 @@ iree_cc_library( iree::hal::utils::collective_batch iree::hal::utils::executable_debug_info iree::hal::utils::deferred_command_buffer - iree::hal::utils::deferred_work_queue iree::hal::utils::file_transfer iree::hal::utils::memory_file iree::hal::utils::resource_set iree::hal::utils::semaphore_base iree::hal::utils::stream_tracing + iree::hal::drivers::hip::util::hip_util iree::schemas::executable_debug_info_c_fbs iree::schemas::hip_executable_def_c_fbs PUBLIC diff --git a/runtime/src/iree/hal/drivers/hip/cleanup_thread.c b/runtime/src/iree/hal/drivers/hip/cleanup_thread.c new file mode 100644 index 000000000000..95c7e206b7b9 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/cleanup_thread.c @@ -0,0 +1,170 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/drivers/hip/cleanup_thread.h" + +#include "iree/base/internal/synchronization.h" +#include "iree/base/internal/threading.h" +#include "iree/hal/drivers/hip/event_pool.h" +#include "iree/hal/drivers/hip/status_util.h" +#include "iree/hal/drivers/hip/util/queue.h" + +#define iree_hal_hip_cleanup_thread_default_queue_size 64 + +typedef struct iree_hal_hip_cleanup_thread_callback_t { + iree_hal_hip_cleanup_callback_t callback; + void* user_data; + iree_hal_hip_event_t* event; +} iree_hal_hip_cleanup_thread_callback_t; + +IREE_HAL_HIP_UTIL_TYPED_QUEUE_WRAPPER( + iree_hal_hip_callback_queue, iree_hal_hip_cleanup_thread_callback_t, + iree_hal_hip_cleanup_thread_default_queue_size); + +typedef struct iree_hal_hip_cleanup_thread_t { + iree_thread_t* thread; + iree_allocator_t host_allocator; + const iree_hal_hip_dynamic_symbols_t* symbols; + iree_slim_mutex_t mutex; + + iree_hal_hip_callback_queue_t queue; + iree_status_t failure_status; + iree_notification_t notification; + bool do_exit; +} iree_hal_hip_cleanup_thread_t; + +static bool iree_hal_hip_cleanup_thread_has_request(void* user_data) { + iree_hal_hip_cleanup_thread_t* thread = + (iree_hal_hip_cleanup_thread_t*)user_data; + iree_slim_mutex_lock(&thread->mutex); + bool has_request = !iree_hal_hip_callback_queue_empty(&thread->queue); + has_request |= thread->do_exit; + iree_slim_mutex_unlock(&thread->mutex); + return has_request; +} + +static int iree_hal_hip_cleanup_thread_main(void* param) { + iree_hal_hip_cleanup_thread_t* thread = (iree_hal_hip_cleanup_thread_t*)param; + bool exit = false; + while (true) { + iree_notification_await(&thread->notification, + &iree_hal_hip_cleanup_thread_has_request, thread, + iree_infinite_timeout()); + + iree_slim_mutex_lock(&thread->mutex); + exit |= thread->do_exit; + iree_status_t status = thread->failure_status; + while (!iree_hal_hip_callback_queue_empty(&thread->queue)) { + iree_hal_hip_cleanup_thread_callback_t callback = + iree_hal_hip_callback_queue_at(&thread->queue, 0); + iree_hal_hip_callback_queue_pop_front(&thread->queue, 1); + iree_slim_mutex_unlock(&thread->mutex); + + if (iree_status_is_ok(status)) { + status = IREE_HIP_CALL_TO_STATUS( + thread->symbols, + hipEventSynchronize(iree_hal_hip_event_handle(callback.event))); + } + + status = iree_status_join( + status, + callback.callback(callback.user_data, callback.event, status)); + iree_slim_mutex_lock(&thread->mutex); + if (!iree_status_is_ok(status)) { + thread->failure_status = status; + } + if (!iree_status_is_ok(thread->failure_status)) { + status = iree_status_clone(thread->failure_status); + } + } + iree_slim_mutex_unlock(&thread->mutex); + + if (!iree_status_is_ok(status) || exit) { + break; + } + } + return 0; +} + +iree_status_t iree_hal_hip_cleanup_thread_initialize( + const iree_hal_hip_dynamic_symbols_t* symbols, + iree_allocator_t host_allocator, + iree_hal_hip_cleanup_thread_t** out_thread) { + IREE_TRACE_ZONE_BEGIN(z0); + *out_thread = NULL; + iree_hal_hip_cleanup_thread_t* thread = NULL; + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, sizeof(*thread), (void**)&thread)); + + thread->symbols = symbols; + thread->do_exit = false; + iree_slim_mutex_initialize(&thread->mutex); + iree_hal_hip_callback_queue_initialize(host_allocator, &thread->queue); + thread->failure_status = iree_ok_status(); + 
thread->host_allocator = host_allocator; + iree_notification_initialize(&thread->notification); + + iree_thread_create_params_t params; + memset(¶ms, 0x00, sizeof(params)); + params.name = iree_make_cstring_view("iree-hal-hip-cleanup"); + iree_status_t status = + iree_thread_create((iree_thread_entry_t)iree_hal_hip_cleanup_thread_main, + thread, params, host_allocator, &thread->thread); + if (iree_status_is_ok(status)) { + *out_thread = thread; + } else { + iree_hal_hip_callback_queue_deinitialize(&thread->queue); + iree_slim_mutex_deinitialize(&thread->mutex); + iree_allocator_free(host_allocator, thread); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +void iree_hal_hip_cleanup_thread_deinitialize( + iree_hal_hip_cleanup_thread_t* thread) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_slim_mutex_lock(&thread->mutex); + thread->do_exit = true; + iree_slim_mutex_unlock(&thread->mutex); + + iree_notification_post(&thread->notification, IREE_ALL_WAITERS); + // There is only one owner for the thread, so this also joins the thread. + iree_thread_release(thread->thread); + + iree_hal_hip_callback_queue_deinitialize(&thread->queue); + iree_slim_mutex_deinitialize(&thread->mutex); + iree_allocator_free(thread->host_allocator, thread); + IREE_TRACE_ZONE_END(z0); +} + +iree_status_t iree_hal_hip_cleanup_thread_add_cleanup( + iree_hal_hip_cleanup_thread_t* thread, iree_hal_hip_event_t* event, + iree_hal_hip_cleanup_callback_t callback, void* user_data) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_slim_mutex_lock(&thread->mutex); + if (!iree_status_is_ok(thread->failure_status)) { + IREE_TRACE_ZONE_END(z0); + iree_slim_mutex_unlock(&thread->mutex); + return thread->failure_status; + } + + iree_hal_hip_cleanup_thread_callback_t callback_data = { + .callback = callback, + .user_data = user_data, + .event = event, + }; + iree_hal_hip_callback_queue_push_back(&thread->queue, callback_data); + iree_slim_mutex_unlock(&thread->mutex); + iree_notification_post(&thread->notification, IREE_ALL_WAITERS); + + IREE_TRACE_ZONE_END(z0); + + return iree_ok_status(); +} diff --git a/runtime/src/iree/hal/drivers/hip/cleanup_thread.h b/runtime/src/iree/hal/drivers/hip/cleanup_thread.h new file mode 100644 index 000000000000..321d12953ac1 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/cleanup_thread.h @@ -0,0 +1,40 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_HIP_CLEANUP_THREAD_H_ +#define IREE_HAL_DRIVERS_HIP_CLEANUP_THREAD_H_ + +#include "iree/base/api.h" +#include "iree/hal/api.h" +#include "iree/hal/drivers/hip/dynamic_symbols.h" + +typedef struct iree_hal_hip_cleanup_thread_t iree_hal_hip_cleanup_thread_t; +typedef struct iree_hal_hip_event_t iree_hal_hip_event_t; + +typedef iree_status_t (*iree_hal_hip_cleanup_callback_t)( + void* user_data, iree_hal_hip_event_t* event, iree_status_t status); + +// Initializes the cleanup thread for HIP driver. +iree_status_t iree_hal_hip_cleanup_thread_initialize( + const iree_hal_hip_dynamic_symbols_t* symbols, + iree_allocator_t host_allocator, + iree_hal_hip_cleanup_thread_t** out_thread); + +// Deinitializes the cleanup thread for HIP driver. +void iree_hal_hip_cleanup_thread_deinitialize( + iree_hal_hip_cleanup_thread_t* thread); + +// Adds a pending cleanup to the thread. 
+// +// The thread will wait on the event and fire the callback, +// once the event has completed. +// |user_data| must remain valid until the callback is called, +// and it is up to the callee to clean up user_data if required. +iree_status_t iree_hal_hip_cleanup_thread_add_cleanup( + iree_hal_hip_cleanup_thread_t* thread, iree_hal_hip_event_t* event, + iree_hal_hip_cleanup_callback_t callback, void* user_data); + +#endif // IREE_HAL_DRIVERS_HIP_CLEANUP_THREAD_H_ diff --git a/runtime/src/iree/hal/drivers/hip/context_util.h b/runtime/src/iree/hal/drivers/hip/context_util.h index 1aa1d79b4c28..e7bd3ed0d438 100644 --- a/runtime/src/iree/hal/drivers/hip/context_util.h +++ b/runtime/src/iree/hal/drivers/hip/context_util.h @@ -23,12 +23,12 @@ static inline iree_status_t iree_hal_hip_set_context( if (current_context != hip_context) { IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_set_context_switch"); iree_status_t status = - IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); + IREE_HIP_CALL_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); IREE_TRACE_ZONE_END(z0); return status; } }); - return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); + return IREE_HIP_CALL_TO_STATUS(syms, hipCtxSetCurrent(hip_context)); } #endif // IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/cts/CMakeLists.txt b/runtime/src/iree/hal/drivers/hip/cts/CMakeLists.txt index ea1c7a3726ac..ed4dbdd7461e 100644 --- a/runtime/src/iree/hal/drivers/hip/cts/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/hip/cts/CMakeLists.txt @@ -35,6 +35,13 @@ iree_hal_cts_test_suite( EXCLUDED_TESTS # HAL event is unimplemented for now. "event" + # These tests fail with: + # UNAVAILABLE; missing hipDrvGraphAddMemcpyNode symbol; + # cannot use graph-based command buffer + "command_buffer_copy_buffer" + "command_buffer_dispatch" + "command_buffer_update_buffer" + "file" LABELS driver=hip requires-gpu-amd @@ -60,6 +67,8 @@ iree_hal_cts_test_suite( DEPS iree::hal::drivers::hip::registration EXCLUDED_TESTS + # HAL event is unimplemented for now. + "event" # These tests fail with: # UNAVAILABLE; missing hipDrvGraphAddMemcpyNode symbol; # cannot use graph-based command buffer @@ -67,8 +76,160 @@ iree_hal_cts_test_suite( "command_buffer_dispatch" "command_buffer_update_buffer" "file" + LABELS + driver=hip +) + +iree_hal_cts_test_suite( + DRIVER_NAME + hip + VARIANT_SUFFIX + multi_queue_stream + DRIVER_REGISTRATION_HDR + "runtime/src/iree/hal/drivers/hip/registration/driver_module.h" + DRIVER_REGISTRATION_FN + "iree_hal_hip_driver_module_register" + DEVICE_CREATION_HDR + "runtime/src/iree/hal/drivers/hip/cts/multi_queue_device_creation.h" + DEFAULT_DEVICE_CREATION_FN + "iree_hal_drivers_hip_cts_default_multi_queue_create" + COMPILER_TARGET_BACKEND + "rocm" + EXECUTABLE_FORMAT + "\"HSACO\"" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + ARGS + "--hip_use_streams=true" + DEPS + iree::hal::drivers::hip::registration + EXCLUDED_TESTS # HAL event is unimplemented for now. 
"event" + # These tests fail with: + # UNAVAILABLE; missing hipDrvGraphAddMemcpyNode symbol; + # cannot use graph-based command buffer + "command_buffer_copy_buffer" + "command_buffer_dispatch" + "command_buffer_update_buffer" + "file" + LABELS + driver=hip + requires-gpu-amd +) + +iree_hal_cts_test_suite( + DRIVER_NAME + hip + VARIANT_SUFFIX + multi_queue_graph + DRIVER_REGISTRATION_HDR + "runtime/src/iree/hal/drivers/hip/registration/driver_module.h" + DRIVER_REGISTRATION_FN + "iree_hal_hip_driver_module_register" + DEVICE_CREATION_HDR + "runtime/src/iree/hal/drivers/hip/cts/multi_queue_device_creation.h" + DEFAULT_DEVICE_CREATION_FN + "iree_hal_drivers_hip_cts_default_multi_queue_create" + COMPILER_TARGET_BACKEND + "rocm" + EXECUTABLE_FORMAT + "\"HSACO\"" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + ARGS + "--hip_use_streams=false" + DEPS + iree::hal::drivers::hip::registration + EXCLUDED_TESTS + # HAL event is unimplemented for now. + "event" + # These tests fail with: + # UNAVAILABLE; missing hipDrvGraphAddMemcpyNode symbol; + # cannot use graph-based command buffer + "command_buffer_copy_buffer" + "command_buffer_dispatch" + "command_buffer_update_buffer" + "file" + LABELS + driver=hip +) + +iree_hal_cts_test_suite( + DRIVER_NAME + hip + VARIANT_SUFFIX + multi_queue_stream_queue_1 + DRIVER_REGISTRATION_HDR + "runtime/src/iree/hal/drivers/hip/registration/driver_module.h" + DRIVER_REGISTRATION_FN + "iree_hal_hip_driver_module_register" + DEVICE_CREATION_HDR + "runtime/src/iree/hal/drivers/hip/cts/multi_queue_device_creation.h" + DEFAULT_DEVICE_CREATION_FN + "iree_hal_drivers_hip_cts_default_multi_queue_create" + COMPILER_TARGET_BACKEND + "rocm" + EXECUTABLE_FORMAT + "\"HSACO\"" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + DEFAULT_SUBMIT_QUEUE_AFFINITY + "(1 << 1)" + ARGS + "--hip_use_streams=true" + DEPS + iree::hal::drivers::hip::registration + EXCLUDED_TESTS + # HAL event is unimplemented for now. + "event" + # These tests fail with: + # UNAVAILABLE; missing hipDrvGraphAddMemcpyNode symbol; + # cannot use graph-based command buffer + "command_buffer_copy_buffer" + "command_buffer_dispatch" + "command_buffer_update_buffer" + "file" + LABELS + driver=hip + requires-gpu-amd +) + +iree_hal_cts_test_suite( + DRIVER_NAME + hip + VARIANT_SUFFIX + multi_queue_graph_queue_1 + DRIVER_REGISTRATION_HDR + "runtime/src/iree/hal/drivers/hip/registration/driver_module.h" + DRIVER_REGISTRATION_FN + "iree_hal_hip_driver_module_register" + DEVICE_CREATION_HDR + "runtime/src/iree/hal/drivers/hip/cts/multi_queue_device_creation.h" + DEFAULT_DEVICE_CREATION_FN + "iree_hal_drivers_hip_cts_default_multi_queue_create" + DEFAULT_SUBMIT_QUEUE_AFFINITY + "(1 << 1)" + COMPILER_TARGET_BACKEND + "rocm" + EXECUTABLE_FORMAT + "\"HSACO\"" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + ARGS + "--hip_use_streams=false" + DEPS + iree::hal::drivers::hip::registration + EXCLUDED_TESTS + # HAL event is unimplemented for now. 
+ "event" + # These tests fail with: + # UNAVAILABLE; missing hipDrvGraphAddMemcpyNode symbol; + # cannot use graph-based command buffer + "command_buffer_copy_buffer" + "command_buffer_dispatch" + "command_buffer_update_buffer" + "file" LABELS driver=hip ) diff --git a/runtime/src/iree/hal/drivers/hip/cts/multi_queue_device_creation.h b/runtime/src/iree/hal/drivers/hip/cts/multi_queue_device_creation.h new file mode 100644 index 000000000000..c24335afff69 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/cts/multi_queue_device_creation.h @@ -0,0 +1,72 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_HIP_REGISTRATION_MULTI_QUEUE_H_ +#define IREE_HAL_DRIVERS_HIP_REGISTRATION_MULTI_QUEUE_H_ + +#include + +#include "iree/hal/driver.h" +#include "iree/testing/status_matchers.h" + +inline iree_status_t iree_hal_drivers_hip_cts_default_multi_queue_create( + iree_hal_driver_t* driver, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + *out_device = NULL; + std::multimap grouped_devices; + + iree_host_size_t device_info_count = 0; + iree_hal_device_info_t* device_infos = NULL; + IREE_RETURN_IF_ERROR(iree_hal_driver_query_available_devices( + driver, iree_allocator_system(), &device_info_count, &device_infos)); + + for (iree_host_size_t i = 0; i < device_info_count; ++i) { + const char* nm = device_infos[i].name.data; + iree_host_size_t size = device_infos[i].name.size; + + std::string name(nm, size); + grouped_devices.insert(std::make_pair(name, i)); + } + + std::string path; + iree_host_size_t max_valid_devices = 0; + for (auto it = grouped_devices.begin(); it != grouped_devices.end(); + /*empty on purpose*/) { + iree_host_size_t device_count = grouped_devices.count(it->first); + if (device_count == 1) { + ++it; + continue; + } + if (device_count <= max_valid_devices) { + for (iree_host_size_t j = 0; j < device_count; ++j) { + // No += for multimap iterator. + ++it; + } + continue; + } + path = ""; + for (iree_host_size_t i = 0; i < device_count; ++i) { + if (i > 0) { + path += ","; + } + path += std::to_string(it->second); + ++it; + } + max_valid_devices = device_count; + } + + iree_allocator_free(iree_allocator_system(), device_infos); + + if (!max_valid_devices) { + return iree_make_status(IREE_STATUS_NOT_FOUND, + "No device group found on the system"); + } + return iree_hal_driver_create_device_by_path( + driver, IREE_SV("hip"), IREE_SV(path.c_str()), /*param_count=*/0, + /*params=*/NULL, iree_allocator_system(), out_device); +} + +#endif // IREE_HAL_DRIVERS_HIP_REGISTRATION_MULTI_QUEUE_H_ diff --git a/runtime/src/iree/hal/drivers/hip/dispatch_thread.c b/runtime/src/iree/hal/drivers/hip/dispatch_thread.c new file mode 100644 index 000000000000..23313a68836f --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/dispatch_thread.c @@ -0,0 +1,175 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/drivers/hip/dispatch_thread.h" + +#include "iree/base/internal/synchronization.h" +#include "iree/base/internal/threading.h" +#include "iree/hal/drivers/hip/event_pool.h" +#include "iree/hal/drivers/hip/status_util.h" +#include "iree/hal/drivers/hip/util/queue.h" + +#define iree_hal_hip_dispatch_thread_default_queue_size 64 + +typedef struct iree_hal_hip_dispatch_thread_dispatch_t { + iree_hal_hip_dispatch_callback_t dispatch; + void* user_data; +} iree_hal_hip_dispatch_thread_dispatch_t; + +IREE_HAL_HIP_UTIL_TYPED_QUEUE_WRAPPER( + iree_hal_hip_dispatch_queue, iree_hal_hip_dispatch_thread_dispatch_t, + iree_hal_hip_dispatch_thread_default_queue_size); + +typedef struct iree_hal_hip_dispatch_thread_t { + iree_thread_t* thread; + iree_allocator_t host_allocator; + iree_slim_mutex_t mutex; + + iree_hal_hip_dispatch_queue_t queue; + iree_status_t failure_status; + iree_notification_t notification; + bool do_exit; +} iree_hal_hip_dispatch_thread_t; + +static bool iree_hal_hip_dispatch_thread_has_request(void* user_data) { + iree_hal_hip_dispatch_thread_t* thread = + (iree_hal_hip_dispatch_thread_t*)user_data; + iree_slim_mutex_lock(&thread->mutex); + bool has_request = !iree_hal_hip_dispatch_queue_empty(&thread->queue); + has_request |= thread->do_exit; + iree_slim_mutex_unlock(&thread->mutex); + return has_request; +} + +static int iree_hal_hip_dispatch_thread_main(void* param) { + iree_hal_hip_dispatch_thread_t* thread = + (iree_hal_hip_dispatch_thread_t*)param; + bool exit = false; + while (true) { + iree_notification_await(&thread->notification, + &iree_hal_hip_dispatch_thread_has_request, thread, + iree_infinite_timeout()); + + iree_slim_mutex_lock(&thread->mutex); + exit |= thread->do_exit; + iree_status_t status = iree_status_clone(thread->failure_status); + while (!iree_hal_hip_dispatch_queue_empty(&thread->queue)) { + iree_hal_hip_dispatch_thread_dispatch_t dispatch = + iree_hal_hip_dispatch_queue_at(&thread->queue, 0); + iree_hal_hip_dispatch_queue_pop_front(&thread->queue, 1); + iree_slim_mutex_unlock(&thread->mutex); + + status = iree_status_join(status, + dispatch.dispatch(dispatch.user_data, status)); + iree_slim_mutex_lock(&thread->mutex); + if (!iree_status_is_ok(status)) { + // We don't join here as the failure status was already + // included here. + iree_status_ignore(thread->failure_status); + thread->failure_status = iree_status_clone(status); + } + } + iree_slim_mutex_unlock(&thread->mutex); + + if (!iree_status_is_ok(status) || exit) { + // Drop the status as it was cloned into thread->failure_status + // if needed. 
+ iree_status_ignore(status); + break; + } + } + return 0; +} + +iree_status_t iree_hal_hip_dispatch_thread_initialize( + iree_allocator_t host_allocator, + iree_hal_hip_dispatch_thread_t** out_thread) { + IREE_TRACE_ZONE_BEGIN(z0); + *out_thread = NULL; + iree_hal_hip_dispatch_thread_t* thread = NULL; + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_allocator_malloc(host_allocator, sizeof(*thread), (void**)&thread)); + + thread->do_exit = false; + iree_slim_mutex_initialize(&thread->mutex); + iree_hal_hip_dispatch_queue_initialize(host_allocator, &thread->queue); + thread->failure_status = iree_ok_status(); + thread->host_allocator = host_allocator; + iree_notification_initialize(&thread->notification); + + iree_thread_create_params_t params; + memset(¶ms, 0x00, sizeof(params)); + params.name = iree_make_cstring_view("iree-hal-hip-dispatch"); + iree_status_t status = + iree_thread_create((iree_thread_entry_t)iree_hal_hip_dispatch_thread_main, + thread, params, host_allocator, &thread->thread); + + if (iree_status_is_ok(status)) { + *out_thread = thread; + } else { + iree_hal_hip_dispatch_queue_deinitialize(&thread->queue); + iree_slim_mutex_deinitialize(&thread->mutex); + iree_allocator_free(host_allocator, thread); + } + IREE_TRACE_ZONE_END(z0); + return status; +} + +void iree_hal_hip_dispatch_thread_deinitialize( + iree_hal_hip_dispatch_thread_t* thread) { + if (!thread) { + return; + } + IREE_TRACE_ZONE_BEGIN(z0); + + iree_slim_mutex_lock(&thread->mutex); + thread->do_exit = true; + iree_slim_mutex_unlock(&thread->mutex); + + iree_notification_post(&thread->notification, IREE_ALL_WAITERS); + // There is only one owner for the thread, so this also joins the thread. + iree_thread_release(thread->thread); + + iree_status_ignore(thread->failure_status); + iree_hal_hip_dispatch_queue_deinitialize(&thread->queue); + iree_slim_mutex_deinitialize(&thread->mutex); + iree_allocator_free(thread->host_allocator, thread); + IREE_TRACE_ZONE_END(z0); +} + +iree_status_t iree_hal_hip_dispatch_thread_add_dispatch( + iree_hal_hip_dispatch_thread_t* thread, + iree_hal_hip_dispatch_callback_t dispatch, void* user_data) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_slim_mutex_lock(&thread->mutex); + iree_status_t status = iree_status_clone(thread->failure_status); + + iree_hal_hip_dispatch_thread_dispatch_t dispatch_data = { + .dispatch = dispatch, + .user_data = user_data, + }; + + if (iree_status_is_ok(status)) { + status = + iree_hal_hip_dispatch_queue_push_back(&thread->queue, dispatch_data); + } + if (!iree_status_is_ok(status)) { + iree_status_ignore(thread->failure_status); + thread->failure_status = iree_status_clone(status); + } + iree_slim_mutex_unlock(&thread->mutex); + iree_notification_post(&thread->notification, IREE_ALL_WAITERS); + + if (!iree_status_is_ok(status)) { + iree_status_ignore(dispatch(user_data, iree_status_clone(status))); + } + IREE_TRACE_ZONE_END(z0); + + // If this was a failure then it was put into thread->failure_status. + return status; +} diff --git a/runtime/src/iree/hal/drivers/hip/dispatch_thread.h b/runtime/src/iree/hal/drivers/hip/dispatch_thread.h new file mode 100644 index 000000000000..1f715b094871 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/dispatch_thread.h @@ -0,0 +1,58 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_HIP_DISPATCH_THREAD_H_
+#define IREE_HAL_DRIVERS_HIP_DISPATCH_THREAD_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/hip/dynamic_symbols.h"
+
+// iree_hal_hip_dispatch_thread is used to get work off of the main thread.
+// There are two types of command buffer that we use in HIP. One is a
+// pre-recorded command buffer, iree_hal_deferred_command_buffer_t, which
+// when executed issues all of the associated hipStream-based commands.
+// The other is iree_hal_hip_graph_command_buffer_t, which when executed
+// calls hipGraphLaunch; in practice, hipGraphLaunch calls the associated
+// stream API for each node in the graph under the hood. Either way,
+// execution blocks the main thread for quite a lot of time. If a host
+// program wants to execute command buffers on multiple GPUs from the same
+// thread, blocking that thread will cause stalls, so this thread exists to
+// move that work off of the main thread. One caveat is that async
+// allocations and deallocations have to move to this thread as well, as
+// they need to remain in order.
+typedef struct iree_hal_hip_dispatch_thread_t iree_hal_hip_dispatch_thread_t;
+
+typedef struct iree_hal_hip_event_t iree_hal_hip_event_t;
+
+typedef iree_status_t (*iree_hal_hip_dispatch_callback_t)(void* user_data,
+                                                          iree_status_t status);
+
+// Initializes the dispatch thread for the HIP driver.
+iree_status_t iree_hal_hip_dispatch_thread_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_hip_dispatch_thread_t** out_thread);
+
+// Deinitializes the dispatch thread for the HIP driver.
+void iree_hal_hip_dispatch_thread_deinitialize(
+    iree_hal_hip_dispatch_thread_t* thread);
+
+// Adds a dispatch to the thread; dispatches are executed in order.
+//
+// |user_data| must remain valid until the callback is called, and it is up
+// to the callee to clean up |user_data| if required.
+// The callback will always be called regardless of whether or not this
+// function returns an error. An error indicates there was an asynchronous
+// failure on the thread or on a semaphore.
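A minimal sketch of the call pattern for this API, whose declaration
continues below; my_launch_state_t and my_do_stream_launch are hypothetical
names standing in for whatever state the caller captures:

    // Sketch only: runs on the dispatch thread, in submission order.
    // |status| carries any earlier asynchronous failure; the callback is
    // responsible for cleaning up |user_data|.
    static iree_status_t my_launch_callback(void* user_data,
                                            iree_status_t status) {
      my_launch_state_t* state = (my_launch_state_t*)user_data;
      if (iree_status_is_ok(status)) {
        status = my_do_stream_launch(state);  // e.g. wraps hipGraphLaunch.
      }
      iree_allocator_free(state->host_allocator, state);
      return status;
    }

The main thread then calls iree_hal_hip_dispatch_thread_add_dispatch(thread,
my_launch_callback, state) and returns immediately instead of blocking on
the stream or graph launch.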
+iree_status_t iree_hal_hip_dispatch_thread_add_dispatch( + iree_hal_hip_dispatch_thread_t* thread, + iree_hal_hip_dispatch_callback_t callback, void* user_data); + +#endif // IREE_HAL_DRIVERS_HIP_DISPATCH_THREAD_H_ diff --git a/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h b/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h index 85d33740c83d..28b16542f266 100644 --- a/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h +++ b/runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h @@ -8,8 +8,11 @@ // HIP symbols //===----------------------------------------------------------------------===// -IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxGetCurrent, hipCtx_t *) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxSetCurrent, hipCtx_t) +IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxGetCurrent, hipCtx_t *) +IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxPushCurrent, hipCtx_t) +IREE_HAL_HIP_REQUIRED_PFN_DECL(hipCtxPopCurrent, hipCtx_t *) +IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceEnablePeerAccess, int, unsigned int) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGet, hipDevice_t *, int) IREE_HAL_HIP_REQUIRED_PFN_DECL(hipDeviceGetAttribute, int *, hipDeviceAttribute_t, int) diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.c b/runtime/src/iree/hal/drivers/hip/event_pool.c index 010cfd67fd2e..95232902b803 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.c +++ b/runtime/src/iree/hal/drivers/hip/event_pool.c @@ -37,10 +37,6 @@ struct iree_hal_hip_event_t { // The event pool that owns this event. This cannot be NULL. We retain it to // make sure the event outlive the pool. iree_hal_hip_event_pool_t* pool; - - // The context to use to free this event, it must be the same - // context as was used when allocating the event. - hipCtx_t hip_context; // The underlying hipEvent_t object. hipEvent_t hip_event; }; @@ -53,8 +49,6 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { iree_allocator_t host_allocator = event->host_allocator; const iree_hal_hip_dynamic_symbols_t* symbols = event->symbols; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR( - iree_hal_hip_set_context(event->symbols, event->hip_context)); IREE_ASSERT_REF_COUNT_ZERO(&event->ref_count); IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(event->hip_event)); @@ -65,8 +59,8 @@ static inline void iree_hal_hip_event_destroy(iree_hal_hip_event_t* event) { static inline iree_status_t iree_hal_hip_event_create( const iree_hal_hip_dynamic_symbols_t* symbols, - iree_hal_hip_event_pool_t* pool, hipCtx_t context, - iree_allocator_t host_allocator, iree_hal_hip_event_t** out_event) { + iree_hal_hip_event_pool_t* pool, iree_allocator_t host_allocator, + iree_hal_hip_event_t** out_event) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(pool); IREE_ASSERT_ARGUMENT(out_event); @@ -82,9 +76,8 @@ static inline iree_status_t iree_hal_hip_event_create( event->symbols = symbols; event->pool = pool; event->hip_event = NULL; - event->hip_context = context; - iree_status_t status = IREE_HIP_RESULT_TO_STATUS( + iree_status_t status = IREE_HIP_CALL_TO_STATUS( symbols, hipEventCreateWithFlags(&event->hip_event, hipEventDisableTiming), "hipEventCreateWithFlags"); @@ -108,6 +101,9 @@ static void iree_hal_hip_event_pool_release_event( iree_hal_hip_event_t** events); void iree_hal_hip_event_release(iree_hal_hip_event_t* event) { + if (!event) { + return; + } if (iree_atomic_ref_count_dec(&event->ref_count) == 1) { iree_hal_hip_event_pool_t* pool = event->pool; // Release back to the pool if the reference count becomes 0. 
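A minimal sketch of the event pool lifecycle these hunks implement, assuming
a loaded |symbols| table and the device's |device_context|; use_event_pool is
a hypothetical caller and error handling is simplified:

    // Sketch only: allocate a pool, acquire an event, release both.
    static iree_status_t use_event_pool(
        const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t device_context,
        iree_allocator_t host_allocator) {
      iree_hal_hip_event_pool_t* pool = NULL;
      IREE_RETURN_IF_ERROR(iree_hal_hip_event_pool_allocate(
          symbols, /*available_capacity=*/32, host_allocator, device_context,
          &pool));
      // Acquire recycles pooled events and creates extras beyond capacity.
      iree_hal_hip_event_t* event = NULL;
      iree_status_t status = iree_hal_hip_event_pool_acquire(pool, 1, &event);
      if (iree_status_is_ok(status)) {
        // ... record iree_hal_hip_event_handle(event) on a hipStream_t ...
        // Dropping the last reference returns the event to the pool (or
        // destroys it when the pool is already at capacity).
        iree_hal_hip_event_release(event);
      }
      iree_hal_hip_event_pool_release(pool);
      return status;
    }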
@@ -130,9 +126,7 @@ struct iree_hal_hip_event_pool_t { // The symbols used to create and destroy hipEvent_t objects. const iree_hal_hip_dynamic_symbols_t* symbols; - // The context for this event pool to use to allocate - // events. - hipCtx_t hip_context; + hipCtx_t device_context; // Guards event related fields in the pool. We don't expect a performant // program to frequently allocate events for synchronization purposes; the @@ -154,9 +148,9 @@ struct iree_hal_hip_event_pool_t { static void iree_hal_hip_event_pool_free(iree_hal_hip_event_pool_t* event_pool); iree_status_t iree_hal_hip_event_pool_allocate( - const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context, + const iree_hal_hip_dynamic_symbols_t* symbols, iree_host_size_t available_capacity, iree_allocator_t host_allocator, - iree_hal_hip_event_pool_t** out_event_pool) { + hipCtx_t device_context, iree_hal_hip_event_pool_t** out_event_pool) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(out_event_pool); *out_event_pool = NULL; @@ -175,14 +169,16 @@ iree_status_t iree_hal_hip_event_pool_allocate( iree_slim_mutex_initialize(&event_pool->event_mutex); event_pool->available_capacity = available_capacity; event_pool->available_count = 0; - event_pool->hip_context = hip_context; - - iree_status_t status = iree_ok_status(); - for (iree_host_size_t i = 0; i < available_capacity; ++i) { - status = iree_hal_hip_event_create( - symbols, event_pool, hip_context, host_allocator, - &event_pool->available_list[event_pool->available_count++]); - if (!iree_status_is_ok(status)) break; + event_pool->device_context = device_context; + + iree_status_t status = iree_hal_hip_set_context(symbols, device_context); + if (iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < available_capacity; ++i) { + status = iree_hal_hip_event_create( + symbols, event_pool, host_allocator, + &event_pool->available_list[event_pool->available_count++]); + if (!iree_status_is_ok(status)) break; + } } if (iree_status_is_ok(status)) { @@ -217,6 +213,9 @@ void iree_hal_hip_event_pool_retain(iree_hal_hip_event_pool_t* event_pool) { } void iree_hal_hip_event_pool_release(iree_hal_hip_event_pool_t* event_pool) { + if (!event_pool) { + return; + } if (iree_atomic_ref_count_dec(&event_pool->ref_count) == 1) { iree_hal_hip_event_pool_free(event_pool); } @@ -250,22 +249,24 @@ iree_status_t iree_hal_hip_event_pool_acquire( // Allocate the rest of the events. if (remaining_count > 0) { - IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-acquire"); - iree_status_t status = iree_ok_status(); + IREE_TRACE_ZONE_APPEND_TEXT(z0, "unpooled acquire"); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)remaining_count); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(event_pool->symbols, + event_pool->device_context)); for (iree_host_size_t i = 0; i < remaining_count; ++i) { - status = iree_hal_hip_event_create( - event_pool->symbols, event_pool, event_pool->hip_context, - event_pool->host_allocator, &out_events[from_pool_count + i]); + iree_status_t status = iree_hal_hip_event_create( + event_pool->symbols, event_pool, event_pool->host_allocator, + &out_events[from_pool_count + i]); if (!iree_status_is_ok(status)) { // Must release all events we've acquired so far. iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i, out_events); - IREE_TRACE_ZONE_END(z1); IREE_TRACE_ZONE_END(z0); return status; } } - IREE_TRACE_ZONE_END(z1); } // Retain a reference to a pool when we pass event to the caller. 
When the @@ -311,11 +312,11 @@ static void iree_hal_hip_event_pool_release_event( // Deallocate the rest of the events. We don't bother resetting them as we are // getting rid of them. if (remaining_count > 0) { - IREE_TRACE_ZONE_BEGIN_NAMED(z1, "event-pool-unpooled-release"); + IREE_TRACE_ZONE_APPEND_TEXT(z0, "unpooled release"); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)remaining_count); for (iree_host_size_t i = 0; i < remaining_count; ++i) { iree_hal_hip_event_destroy(events[to_pool_count + i]); } - IREE_TRACE_ZONE_END(z1); } IREE_TRACE_ZONE_END(z0); } diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.h b/runtime/src/iree/hal/drivers/hip/event_pool.h index 0683714d9724..d3d4c497bb87 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.h +++ b/runtime/src/iree/hal/drivers/hip/event_pool.h @@ -10,10 +10,6 @@ #include "iree/base/api.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - //===----------------------------------------------------------------------===// // iree_hal_hip_event_t //===----------------------------------------------------------------------===// @@ -52,9 +48,9 @@ typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t; // Extra events requested beyond the capability are directly created and // destroyed without pooling. iree_status_t iree_hal_hip_event_pool_allocate( - const iree_hal_hip_dynamic_symbols_t* symbols, hipCtx_t hip_context, + const iree_hal_hip_dynamic_symbols_t* symbols, iree_host_size_t available_capacity, iree_allocator_t host_allocator, - iree_hal_hip_event_pool_t** out_event_pool); + hipCtx_t device_context, iree_hal_hip_event_pool_t** out_event_pool); // Retains the given |event_pool| by increasing its reference count. void iree_hal_hip_event_pool_retain(iree_hal_hip_event_pool_t* event_pool); @@ -73,8 +69,4 @@ iree_status_t iree_hal_hip_event_pool_acquire( iree_hal_hip_event_pool_t* event_pool, iree_host_size_t event_count, iree_hal_hip_event_t** out_events); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_EVENT_POOL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index 686e23015b8f..276290c5f006 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -9,42 +9,91 @@ #include "iree/base/internal/synchronization.h" #include "iree/base/internal/wait_handle.h" #include "iree/base/status.h" -#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" +#include "iree/hal/drivers/hip/event_pool.h" #include "iree/hal/drivers/hip/status_util.h" -#include "iree/hal/drivers/hip/timepoint_pool.h" +#include "iree/hal/drivers/hip/util/tree.h" #include "iree/hal/utils/semaphore_base.h" +typedef struct iree_hal_hip_cpu_event_t { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + iree_event_t event; +} iree_hal_hip_cpu_event_t; + +static void iree_hal_hip_cpu_event_destroy(iree_hal_resource_t* resource) { + iree_hal_hip_cpu_event_t* event = (iree_hal_hip_cpu_event_t*)(resource); + iree_event_deinitialize(&event->event); + iree_allocator_free(event->host_allocator, event); +} + +static const iree_hal_resource_vtable_t iree_hal_hip_cpu_event_vtable = { + .destroy = &iree_hal_hip_cpu_event_destroy, +}; + +typedef struct iree_hal_hip_cpu_event_vtable_t { + void(IREE_API_PTR* destroy)(iree_hal_resource_t* resource); +} 
iree_hal_hip_cpu_event_vtable_t;
+
+typedef struct iree_hal_hip_semaphore_work_item_t {
+  iree_hal_hip_event_semaphore_scheduled_callback_t scheduled_callback;
+  void* user_data;
+  struct iree_hal_hip_semaphore_work_item_t* next;
+} iree_hal_hip_semaphore_work_item_t;
+
+// Work associated with a particular point in the semaphore timeline.
+//
+// The |work_item| is a set of callbacks to be made when the semaphore
+// is guaranteed to make forward progress past the associated key value.
+// They will also be cleaned up at this time. If the semaphore is failed,
+// the callbacks will be called with the status code of the failure.
+// If the semaphore is destroyed while callbacks are active,
+// they will be called with the CANCELLED error.
+// The |cpu_event| is an event for the CPU to wait on when
+// we may not have to wait infinitely, for example with a multi-wait
+// or a non-infinite timeout.
+// The |event| is a HIP event that is used for GPU waits or
+// infinite CPU waits.
+typedef struct iree_hal_hip_semaphore_queue_item_t {
+  iree_hal_hip_event_t* event;
+  iree_hal_hip_cpu_event_t* cpu_event;
+  iree_hal_hip_semaphore_work_item_t* work_item;
+} iree_hal_hip_semaphore_queue_item_t;
+
 typedef struct iree_hal_hip_semaphore_t {
   // Abstract resource used for injecting reference counting and vtable;
   // must be at offset 0.
-  iree_hal_semaphore_t base;
+  iree_hal_resource_t base;
   // The allocator used to create this semaphore.
   iree_allocator_t host_allocator;
   // The symbols used to issue HIP API calls.
   const iree_hal_hip_dynamic_symbols_t* symbols;
-  // The timepoint pool to acquire timepoint objects.
-  iree_hal_hip_timepoint_pool_t* timepoint_pool;
-
-  // The list of actions that this semaphore may need to advance on
-  // new signaled values.
-  iree_hal_deferred_work_queue_t* work_queue;
+  // This queue represents the values in the timeline.
+  // The keys in the queue are the timeline values that
+  // are being signaled/waited on in the semaphore.
+  // The values are |iree_hal_hip_semaphore_queue_item_t| values.
+  struct {
+    iree_hal_hip_util_tree_t tree;
+    // Inline storage for this tree. We expect the normal number of
+    // nodes in use for a single semaphore to be relatively small.
+    uint8_t inline_storage[sizeof(iree_hal_hip_util_tree_node_t) * 16];
+  } event_queue;
+
+  // Notify any potential CPU waiters that this semaphore
+  // has changed state.
+  iree_notification_t state_notification;
 
-  hipCtx_t hip_context;
-
-  // Guards value and status. We expect low contention on semaphores and since
-  // iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
-  // than trying to make the entire structure lock-free.
-  // If we need to hold mutex and base.timepoint_mutex locked together, the
-  // locking order must be (mutex, base.timepoint_mutex).
   iree_slim_mutex_t mutex;
+  // The maximum value that this semaphore has been signaled to.
+  // This means this semaphore is guaranteed to make forward progress
+  // until that value is hit, as all signaling operations have
+  // been made available.
+  uint64_t max_value_to_be_signaled IREE_GUARDED_BY(mutex);
 
-  // Current signaled value. May be IREE_HAL_SEMAPHORE_FAILURE_VALUE to
-  // indicate that the semaphore has been signaled for failure and
-  // |failure_status| contains the error.
-  uint64_t current_value IREE_GUARDED_BY(mutex);
+  // The largest value that has been observed by the host.
+  uint64_t current_visible_value IREE_GUARDED_BY(mutex);
 
   // OK or the status passed to iree_hal_semaphore_fail. Owned by the semaphore.
iree_status_t failure_status IREE_GUARDED_BY(mutex); @@ -60,34 +109,34 @@ static iree_hal_hip_semaphore_t* iree_hal_hip_semaphore_cast( iree_status_t iree_hal_hip_event_semaphore_create( uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols, - hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator, - iree_hal_semaphore_t** out_semaphore) { + iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { IREE_ASSERT_ARGUMENT(symbols); - IREE_ASSERT_ARGUMENT(timepoint_pool); - IREE_ASSERT_ARGUMENT(work_queue); IREE_ASSERT_ARGUMENT(out_semaphore); IREE_TRACE_ZONE_BEGIN(z0); + *out_semaphore = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(symbols, hip_context)); iree_hal_hip_semaphore_t* semaphore = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, sizeof(*semaphore), (void**)&semaphore)); iree_hal_semaphore_initialize(&iree_hal_hip_semaphore_vtable, - &semaphore->base); + (iree_hal_semaphore_t*)semaphore); semaphore->host_allocator = host_allocator; semaphore->symbols = symbols; - semaphore->timepoint_pool = timepoint_pool; - semaphore->work_queue = work_queue; + iree_hal_hip_util_tree_initialize( + host_allocator, sizeof(iree_hal_hip_semaphore_queue_item_t), + semaphore->event_queue.inline_storage, + sizeof(semaphore->event_queue.inline_storage), + &semaphore->event_queue.tree); + iree_notification_initialize(&semaphore->state_notification); + iree_slim_mutex_initialize(&semaphore->mutex); - semaphore->current_value = initial_value; + semaphore->current_visible_value = initial_value; + semaphore->max_value_to_be_signaled = initial_value; semaphore->failure_status = iree_ok_status(); - semaphore->hip_context = hip_context; - *out_semaphore = &semaphore->base; + *out_semaphore = (iree_hal_semaphore_t*)semaphore; IREE_TRACE_ZONE_END(z0); return iree_ok_status(); @@ -99,484 +148,790 @@ static void iree_hal_hip_semaphore_destroy( iree_hal_hip_semaphore_cast(base_semaphore); iree_allocator_t host_allocator = semaphore->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR( - iree_hal_hip_set_context(semaphore->symbols, semaphore->hip_context)); iree_status_ignore(semaphore->failure_status); iree_slim_mutex_deinitialize(&semaphore->mutex); - iree_hal_semaphore_deinitialize(&semaphore->base); + iree_notification_deinitialize(&semaphore->state_notification); + for (iree_hal_hip_util_tree_node_t* i = + iree_hal_hip_util_tree_first(&semaphore->event_queue.tree); + i != NULL; i = iree_hal_hip_util_tree_node_next(i)) { + iree_hal_hip_semaphore_queue_item_t* queue_item = + (iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(i); + iree_hal_hip_event_release(queue_item->event); + iree_hal_resource_release(queue_item->cpu_event); + iree_hal_hip_semaphore_work_item_t* work_item = queue_item->work_item; + while (work_item) { + work_item->scheduled_callback( + work_item->user_data, base_semaphore, + iree_make_status( + IREE_STATUS_CANCELLED, + "semaphore was destroyed while callback is in flight")); + iree_hal_hip_semaphore_work_item_t* next = work_item->next; + iree_allocator_free(host_allocator, work_item); + work_item = next; + } + } + iree_hal_hip_util_tree_deinitialize(&semaphore->event_queue.tree); iree_allocator_free(host_allocator, semaphore); IREE_TRACE_ZONE_END(z0); } -static iree_status_t iree_hal_hip_semaphore_query( - iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { +static 
iree_status_t iree_hal_hip_semaphore_get_cpu_event(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_hal_hip_cpu_event_t** out_event) {
+  IREE_ASSERT_ARGUMENT(out_event);
+  *out_event = NULL;
   iree_hal_hip_semaphore_t* semaphore =
       iree_hal_hip_semaphore_cast(base_semaphore);
-  IREE_TRACE_ZONE_BEGIN(z0);
-
   iree_slim_mutex_lock(&semaphore->mutex);
-
-  *out_value = semaphore->current_value;
-
+  if (value <= semaphore->current_visible_value) {
+    iree_slim_mutex_unlock(&semaphore->mutex);
+    return iree_ok_status();
+  }
   iree_status_t status = iree_ok_status();
-  if (*out_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
-    status = iree_status_clone(semaphore->failure_status);
+  iree_hal_hip_util_tree_node_t* node =
+      iree_hal_hip_util_tree_get(&semaphore->event_queue.tree, value);
+  if (!node) {
+    status = iree_hal_hip_util_tree_insert(&semaphore->event_queue.tree, value,
+                                           &node);
   }
-  iree_slim_mutex_unlock(&semaphore->mutex);
+  iree_hal_hip_semaphore_queue_item_t* item = NULL;
+  if (iree_status_is_ok(status)) {
+    item = (iree_hal_hip_semaphore_queue_item_t*)
+        iree_hal_hip_util_tree_node_get_value(node);
+
+    if (!item->cpu_event) {
+      status = iree_allocator_malloc(semaphore->host_allocator,
+                                     sizeof(*item->cpu_event),
+                                     (void**)&item->cpu_event);
+      if (iree_status_is_ok(status)) {
+        iree_hal_resource_initialize(&iree_hal_hip_cpu_event_vtable,
+                                     (iree_hal_resource_t*)item->cpu_event);
+        item->cpu_event->host_allocator = semaphore->host_allocator;
+
+        status = iree_event_initialize(false, &item->cpu_event->event);
+        if (!iree_status_is_ok(status)) {
+          // Clear out the cpu_event here, so that we don't have to
+          // special-case cleanup later.
+          iree_allocator_free(semaphore->host_allocator, item->cpu_event);
+          item->cpu_event = NULL;
+        }
+      }
+    }
+
+    if (iree_status_is_ok(status)) {
+      iree_hal_resource_retain(&item->cpu_event->resource);
+      *out_event = item->cpu_event;
+    }
+  }
+  iree_slim_mutex_unlock(&semaphore->mutex);
-
-  IREE_TRACE_ZONE_END(z0);
+  if (!iree_status_is_ok(status)) {
+    if (item && item->cpu_event) {
+      iree_event_deinitialize(&item->cpu_event->event);
+      iree_allocator_free(semaphore->host_allocator, item->cpu_event);
+    }
+  }
   return status;
 }
 
-static iree_status_t iree_hal_hip_semaphore_signal(
-    iree_hal_semaphore_t* base_semaphore, uint64_t new_value) {
+static bool iree_hal_hip_semaphore_is_aborted(
+    iree_hal_semaphore_t* base_semaphore) {
   iree_hal_hip_semaphore_t* semaphore =
       iree_hal_hip_semaphore_cast(base_semaphore);
-  IREE_TRACE_ZONE_BEGIN(z0);
   iree_slim_mutex_lock(&semaphore->mutex);
+  bool aborted =
+      semaphore->current_visible_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE;
+  iree_slim_mutex_unlock(&semaphore->mutex);
+  return aborted;
+}
 
-  if (new_value <= semaphore->current_value) {
-    uint64_t current_value IREE_ATTRIBUTE_UNUSED = semaphore->current_value;
-    iree_slim_mutex_unlock(&semaphore->mutex);
+iree_status_t iree_hal_hip_semaphore_multi_wait(
+    const iree_hal_semaphore_list_t semaphore_list,
+    iree_hal_wait_mode_t wait_mode, iree_timeout_t timeout,
+    iree_allocator_t host_allocator) {
+  if (semaphore_list.count == 0) return iree_ok_status();
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  const iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
+  iree_status_t status = iree_ok_status();
+
+  // If we have to wait on "all" semaphores then we can
+  // fast-path this to just a normal wait.
+  if (semaphore_list.count == 1 || wait_mode == IREE_HAL_WAIT_MODE_ALL) {
+    // Wait on each semaphore in turn until all of them are satisfied.
+    for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
+      iree_timeout_t t = iree_make_deadline(deadline_ns);
+      status = iree_status_join(
+          status, iree_hal_semaphore_wait(semaphore_list.semaphores[i],
+                                          semaphore_list.payload_values[i], t));
+      if (!iree_status_is_ok(status)) {
+        break;
+      }
+    }
     IREE_TRACE_ZONE_END(z0);
-  return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
-                          "semaphore values must be monotonically "
-                          "increasing; current_value=%" PRIu64
-                          ", new_value=%" PRIu64,
-                          current_value, new_value);
+    return status;
+  }
+
+  iree_hal_hip_cpu_event_t** cpu_events = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(host_allocator,
+                                semaphore_list.count * sizeof(*cpu_events),
+                                (void**)&cpu_events));
+  bool semaphore_hit = false;
+  if (iree_status_is_ok(status)) {
+    for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
+      status = iree_hal_hip_semaphore_get_cpu_event(
+          semaphore_list.semaphores[i], semaphore_list.payload_values[i],
+          &cpu_events[i]);
+      if (!iree_status_is_ok(status)) {
+        break;
+      }
+      // If we cannot get a CPU event for a given value but the call returns
+      // success, it is because the semaphore has already been signaled to
+      // that value.
+      if (!cpu_events[i]) {
+        semaphore_hit = true;
+        break;
+      }
+    }
   }
 
-  semaphore->current_value = new_value;
+  iree_wait_set_t* wait_set = NULL;
+  if (iree_status_is_ok(status) && !semaphore_hit) {
+    status =
+        iree_wait_set_allocate(semaphore_list.count, host_allocator, &wait_set);
+  }
 
-  iree_slim_mutex_unlock(&semaphore->mutex);
+  if (iree_status_is_ok(status) && !semaphore_hit) {
+    for (iree_host_size_t i = 0;
+         i < semaphore_list.count && iree_status_is_ok(status); ++i) {
+      status = iree_wait_set_insert(wait_set, cpu_events[i]->event);
+    }
+  }
 
-  // Notify timepoints - note that this must happen outside the lock.
-  iree_hal_semaphore_notify(&semaphore->base, new_value, IREE_STATUS_OK);
+  if (iree_status_is_ok(status) && !semaphore_hit) {
+    status = iree_wait_any(wait_set, deadline_ns, NULL);
+    iree_wait_set_free(wait_set);
+    if (iree_status_is_ok(status)) {
+      // Now we have to walk all of the semaphores to propagate
+      // any errors that we find.
+      for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
+        if (iree_hal_hip_semaphore_is_aborted(semaphore_list.semaphores[i])) {
+          status = iree_make_status(IREE_STATUS_ABORTED,
+                                    "the semaphore was aborted");
+          break;
+        }
+      }
+    }
+  }
 
-  // Advance the deferred work queue if possible. This also must happen
-  // outside the lock to avoid nesting.
-  iree_status_t status =
-      iree_hal_deferred_work_queue_issue(semaphore->work_queue);
+  for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
+    iree_hal_resource_release(&cpu_events[i]->resource);
+  }
+  iree_allocator_free(host_allocator, cpu_events);
 
   IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
-static void iree_hal_hip_semaphore_fail(iree_hal_semaphore_t* base_semaphore,
-                                        iree_status_t status) {
+static iree_status_t iree_hal_hip_event_semaphore_run_scheduled_callbacks(
+    iree_hal_semaphore_t* base_semaphore) {
   iree_hal_hip_semaphore_t* semaphore =
       iree_hal_hip_semaphore_cast(base_semaphore);
   IREE_TRACE_ZONE_BEGIN(z0);
-  const iree_status_code_t status_code = iree_status_code(status);
 
-  iree_slim_mutex_lock(&semaphore->mutex);
+  iree_hal_hip_semaphore_work_item_t* work_item = NULL;
+  iree_hal_hip_semaphore_work_item_t* last_work_item = NULL;
 
-  // Try to set our local status - we only preserve the first failure so only
-  // do this if we are going from a valid semaphore to a failed one.
- if (!iree_status_is_ok(semaphore->failure_status)) {
-    // Previous status was not OK; drop our new status.
-    IREE_IGNORE_ERROR(status);
+  // Take out all of the values from the queue that are less than the
+  // current visible value, and make sure we advance any work needed
+  // on them.
+  do {
+    iree_slim_mutex_lock(&semaphore->mutex);
+    iree_hal_hip_util_tree_node_t* node =
+        iree_hal_hip_util_tree_first(&semaphore->event_queue.tree);
+    if (node == NULL) {
+      iree_slim_mutex_unlock(&semaphore->mutex);
+      break;
+    }
+    if (iree_hal_hip_util_tree_node_get_key(node) >
+        semaphore->current_visible_value) {
+      iree_slim_mutex_unlock(&semaphore->mutex);
+      break;
+    }
+    iree_hal_hip_semaphore_queue_item_t copy =
+        *(iree_hal_hip_semaphore_queue_item_t*)
+            iree_hal_hip_util_tree_node_get_value(node);
+    iree_hal_hip_util_tree_erase(&semaphore->event_queue.tree, node);
     iree_slim_mutex_unlock(&semaphore->mutex);
-    IREE_TRACE_ZONE_END(z0);
-    return;
-  }
-
-  // Signal to our failure sentinel value.
-  semaphore->current_value = IREE_HAL_SEMAPHORE_FAILURE_VALUE;
-  semaphore->failure_status = status;
-
-  iree_slim_mutex_unlock(&semaphore->mutex);
+    iree_hal_hip_event_release(copy.event);
+    if (copy.cpu_event) {
+      iree_event_set(&copy.cpu_event->event);
+      iree_hal_resource_release(&copy.cpu_event->resource);
+    }
 
-  // Notify timepoints - note that this must happen outside the lock.
-  iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE,
-                            status_code);
+    iree_hal_hip_semaphore_work_item_t* next_work_item = copy.work_item;
+    while (next_work_item) {
+      if (!work_item) {
+        work_item = next_work_item;
+      }
+      if (last_work_item && !last_work_item->next) {
+        last_work_item->next = next_work_item;
+      }
+      last_work_item = next_work_item;
+      next_work_item = next_work_item->next;
+    }
+  } while (true);
 
-  // Advance the deferred work queue if possible. This also must happen
-  // outside the lock to avoid nesting.
-  status = iree_hal_deferred_work_queue_issue(semaphore->work_queue);
-  iree_status_ignore(status);
+  iree_slim_mutex_lock(&semaphore->mutex);
+  semaphore->max_value_to_be_signaled = iree_max(
+      semaphore->max_value_to_be_signaled, semaphore->current_visible_value);
+  iree_status_t status = iree_status_clone(semaphore->failure_status);
 
-  IREE_TRACE_ZONE_END(z0);
-}
+  iree_slim_mutex_unlock(&semaphore->mutex);
+  // Now that we have accumulated all of the work items, and we have
+  // unlocked the semaphore, start running through the work items.
+  while (work_item) {
+    iree_hal_hip_semaphore_work_item_t* next_work_item = work_item->next;
+    iree_status_ignore(work_item->scheduled_callback(
+        work_item->user_data, base_semaphore, iree_status_clone(status)));
+    iree_allocator_free(semaphore->host_allocator, work_item);
+    work_item = next_work_item;
+  }
 
-// Handles host wait timepoints on the host when the |semaphore| timeline
-// advances past the given |value|.
-//
-// Note that this callback is invoked by the a host thread.
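The loop above splices each timepoint's work-item list into a single chain while holding the lock and only invokes the callbacks after unlocking, so callbacks may re-enter semaphore APIs without deadlocking. A distilled sketch of that pattern, with a hypothetical `work_t` standing in for iree_hal_hip_semaphore_work_item_t (assumes iree/base/internal/synchronization.h):

    // Distilled pattern: take ownership of the chain under the lock, run it
    // outside the lock.
    typedef struct work_t {
      struct work_t* next;
      void (*fn)(void* user_data);
      void* user_data;
    } work_t;

    static void drain(iree_slim_mutex_t* mutex, work_t** queue) {
      iree_slim_mutex_lock(mutex);
      work_t* head = *queue;  // steal the whole chain
      *queue = NULL;
      iree_slim_mutex_unlock(mutex);
      while (head) {
        work_t* next = head->next;
        head->fn(head->user_data);  // safe: no locks held here
        head = next;
      }
    }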
-static iree_status_t iree_hal_hip_semaphore_timepoint_host_wait_callback(
-    void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
-    iree_status_code_t status_code) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-  iree_hal_hip_timepoint_t* timepoint = (iree_hal_hip_timepoint_t*)user_data;
-  iree_event_set(&timepoint->timepoint.host_wait);
+  iree_notification_post(&semaphore->state_notification, IREE_ALL_WAITERS);
   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }
 
-// Acquires a timepoint to wait the timeline to reach at least the given
-// |min_value| from the host.
-static iree_status_t iree_hal_hip_semaphore_acquire_timepoint_host_wait(
-    iree_hal_hip_semaphore_t* semaphore, uint64_t min_value,
-    iree_timeout_t timeout, iree_hal_hip_timepoint_t** out_timepoint) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_hip_timepoint_pool_acquire_host_wait(
-              semaphore->timepoint_pool, 1, out_timepoint));
-  // Initialize the timepoint with the value and callback, and connect it to
-  // this semaphore.
-  iree_hal_semaphore_acquire_timepoint(
-      &semaphore->base, min_value, timeout,
-      (iree_hal_semaphore_callback_t){
-          .fn = iree_hal_hip_semaphore_timepoint_host_wait_callback,
-          .user_data = *out_timepoint,
-      },
-      &(*out_timepoint)->base);
+iree_status_t iree_hal_hip_semaphore_notify_work(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_hal_hip_event_pool_t* event_pool,
+    iree_hal_hip_event_semaphore_scheduled_callback_t callback,
+    void* user_data) {
+  iree_hal_hip_semaphore_t* semaphore =
+      iree_hal_hip_semaphore_cast(base_semaphore);
+  iree_slim_mutex_lock(&semaphore->mutex);
+  iree_status_t status = iree_status_clone(semaphore->failure_status);
+
+  if (iree_status_is_ok(status) &&
+      value > semaphore->max_value_to_be_signaled) {
+    iree_hal_hip_util_tree_node_t* node =
+        iree_hal_hip_util_tree_get(&semaphore->event_queue.tree, value);
+    if (node == NULL) {
+      status = iree_hal_hip_util_tree_insert(&semaphore->event_queue.tree,
+                                             value, &node);
+      if (iree_status_is_ok(status)) {
+        iree_hal_hip_semaphore_queue_item_t* item =
+            (iree_hal_hip_semaphore_queue_item_t*)
+                iree_hal_hip_util_tree_node_get_value(node);
+        item->event = NULL;
+        item->cpu_event = NULL;
+        item->work_item = NULL;
+      }
+    }
+    if (iree_status_is_ok(status)) {
+      iree_hal_hip_semaphore_queue_item_t* item =
+          (iree_hal_hip_semaphore_queue_item_t*)
+              iree_hal_hip_util_tree_node_get_value(node);
+      iree_hal_hip_semaphore_work_item_t* work_item = NULL;
+      status = iree_allocator_malloc(semaphore->host_allocator,
+                                     sizeof(*work_item), (void**)&work_item);
+      if (iree_status_is_ok(status)) {
+        work_item->scheduled_callback = callback;
+        work_item->user_data = user_data;
+        work_item->next = item->work_item;
+        item->work_item = work_item;
+        callback = NULL;
+      }
+    }
+  }
+  iree_slim_mutex_unlock(&semaphore->mutex);
 
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  // If this semaphore requirement has already been satisfied,
+  // or if this semaphore has failed, then we can just run the callback
+  // right now.
+  if (callback) {
+    status = callback(user_data, base_semaphore, status);
+  }
+  return status;
 }
 
-bool iree_hal_hip_semaphore_acquire_event_host_wait(
-    iree_hal_semaphore_t* base_semaphore, uint64_t min_value,
-    iree_hal_hip_event_t** out_event) {
-  *out_event = NULL;
-  IREE_TRACE_ZONE_BEGIN(z0);
+iree_status_t iree_hal_hip_semaphore_notify_forward_progress_to(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value) {
   iree_hal_hip_semaphore_t* semaphore =
       iree_hal_hip_semaphore_cast(base_semaphore);
-
-  // Scan through the timepoint list and try to find a device event signal to
-  // wait on. We need to lock with the timepoint list mutex here.
-  iree_slim_mutex_lock(&semaphore->base.timepoint_mutex);
-  for (iree_hal_semaphore_timepoint_t* tp = semaphore->base.timepoint_list.head;
-       tp != NULL; tp = tp->next) {
-    iree_hal_hip_timepoint_t* signal_timepoint = (iree_hal_hip_timepoint_t*)tp;
-    if (signal_timepoint->kind == IREE_HAL_HIP_TIMEPOINT_KIND_DEVICE_SIGNAL &&
-        signal_timepoint->base.minimum_value >= min_value) {
-      *out_event = signal_timepoint->timepoint.device_signal;
-      iree_hal_hip_event_retain(*out_event);
-      break;
+  iree_slim_mutex_lock(&semaphore->mutex);
+  iree_status_t status = iree_status_clone(semaphore->failure_status);
+  if (!iree_status_is_ok(status)) {
+    iree_slim_mutex_unlock(&semaphore->mutex);
+    return status;
+  }
+  iree_hal_hip_semaphore_work_item_t* work_item = NULL;
+  iree_hal_hip_semaphore_work_item_t* last_work_item = NULL;
+  if (value > semaphore->max_value_to_be_signaled) {
+    iree_hal_hip_util_tree_node_t* node = iree_hal_hip_util_tree_upper_bound(
+        &semaphore->event_queue.tree, semaphore->max_value_to_be_signaled);
+    // Collect all of the things to schedule now that we know we can safely make
+    // it to a given value.
+    while (node && iree_hal_hip_util_tree_node_get_key(node) <= value) {
+      iree_hal_hip_semaphore_queue_item_t* queue_item =
+          (iree_hal_hip_semaphore_queue_item_t*)
+              iree_hal_hip_util_tree_node_get_value(node);
+      iree_hal_hip_semaphore_work_item_t* next_work_item =
+          queue_item->work_item;
+      while (next_work_item) {
+        if (!work_item) {
+          work_item = next_work_item;
+        }
+        if (last_work_item && !last_work_item->next) {
+          last_work_item->next = next_work_item;
+        }
+        last_work_item = next_work_item;
+        next_work_item = next_work_item->next;
+      }
+      queue_item->work_item = NULL;
+      iree_hal_hip_util_tree_node_t* last_node = node;
+      node = iree_hal_hip_util_tree_node_next(node);
+      if (!queue_item->event) {
+        iree_hal_hip_util_tree_erase(&semaphore->event_queue.tree, last_node);
+      }
     }
   }
-  iree_slim_mutex_unlock(&semaphore->base.timepoint_mutex);
-  IREE_TRACE_ZONE_END(z0);
-  return *out_event != NULL;
+  semaphore->max_value_to_be_signaled =
+      iree_max(semaphore->max_value_to_be_signaled, value);
+  iree_slim_mutex_unlock(&semaphore->mutex);
+
+  // Now that we have accumulated all of the work items, and we have
+  // unlocked the semaphore, start running through the work items.
+  while (work_item) {
+    iree_hal_hip_semaphore_work_item_t* next_work_item = work_item->next;
+    iree_status_ignore(work_item->scheduled_callback(
+        work_item->user_data, base_semaphore, iree_status_clone(status)));
+    iree_allocator_free(semaphore->host_allocator, work_item);
+    work_item = next_work_item;
+  }
+  return status;
 }
 
-// Checks if the semaphore has to wait to reach `value`.
-// If it has to wait, then acquires a wait timepoint and returns it.
-// If we don't need to wait, then *out_timepoint is set to NULL.
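Taken together, the two entry points above form a handshake: consumers register callbacks against a timepoint with notify_work, and the submission logic releases them with notify_forward_progress_to once it has enqueued the GPU work that will signal that timepoint. A hedged sketch of that flow, where `on_ready`, the payload value 7, and `device_event_pool` are all hypothetical:

    // Hypothetical callback: launch follow-up work once the semaphore is
    // guaranteed to reach the waited-on value.
    static iree_status_t on_ready(void* user_data,
                                  iree_hal_semaphore_t* semaphore,
                                  iree_status_t semaphore_status) {
      if (!iree_status_is_ok(semaphore_status)) return semaphore_status;
      // ... enqueue dependent work here ...
      return iree_ok_status();
    }

    // Consumer side: run |on_ready| once forward progress to value 7 is
    // guaranteed (or immediately, if it already is).
    IREE_RETURN_IF_ERROR(iree_hal_hip_semaphore_notify_work(
        semaphore, /*value=*/7, device_event_pool, on_ready,
        /*user_data=*/NULL));

    // Submission side: after enqueuing the GPU work that signals value 7:
    IREE_RETURN_IF_ERROR(
        iree_hal_hip_semaphore_notify_forward_progress_to(semaphore,
                                                          /*value=*/7));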
-static iree_status_t iree_hal_hip_semaphore_try_wait_or_acquire_wait_timepoint( +iree_status_t iree_hal_hip_semaphore_get_hip_event( iree_hal_semaphore_t* base_semaphore, uint64_t value, - iree_timeout_t timeout, iree_hal_hip_timepoint_t** out_timepoint) { - *out_timepoint = NULL; + iree_hal_hip_event_pool_t* event_pool, + iree_hal_hip_event_t** out_hip_event) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_TRACE_ZONE_BEGIN(z0); - + *out_hip_event = NULL; iree_slim_mutex_lock(&semaphore->mutex); - if (!iree_status_is_ok(semaphore->failure_status)) { - // Fastest path: failed; return an error to tell callers to query for it. + if (value <= semaphore->current_visible_value) { iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); - return iree_status_from_code(IREE_STATUS_ABORTED); - } - if (semaphore->current_value >= value) { - // Fast path: already satisfied. - iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } - if (iree_timeout_is_immediate(timeout)) { - // Not satisfied but a poll, so can avoid the expensive wait handle work. - iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); - return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); - } - - // Slow path: acquire a timepoint. This should happen inside of the lock too. - // If not locked the semaphore may be signal before acquiring a timepoint. - // Then we would miss the signal. - iree_status_t status = iree_hal_hip_semaphore_acquire_timepoint_host_wait( - semaphore, value, timeout, out_timepoint); + iree_status_t status = iree_status_clone(semaphore->failure_status); + if (iree_status_is_ok(status)) { + iree_hal_hip_util_tree_node_t* node = + iree_hal_hip_util_tree_get(&semaphore->event_queue.tree, value); + + if (node == NULL) { + status = iree_hal_hip_util_tree_insert(&semaphore->event_queue.tree, + value, &node); + if (iree_status_is_ok(status)) { + iree_hal_hip_semaphore_queue_item_t* item = + (iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node); + item->cpu_event = NULL; + item->work_item = NULL; + } + } + if (iree_status_is_ok(status)) { + iree_hal_hip_event_t* event = + ((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event; + if (!event) { + do { + node = iree_hal_hip_util_tree_node_next(node); + if (!node) { + status = iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "there was no event that could be valid"); + break; + } + event = ((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event; + } while (!event); + } + if (event) { + iree_hal_hip_event_retain(event); + } + *out_hip_event = event; + } + } iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); return status; } -static iree_status_t iree_hal_hip_semaphore_wait( +iree_status_t iree_hal_hip_semaphore_create_event_and_record_if_necessary( iree_hal_semaphore_t* base_semaphore, uint64_t value, - iree_timeout_t timeout) { + hipStream_t dispatch_stream, iree_hal_hip_event_pool_t* event_pool) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_TRACE_ZONE_BEGIN(z0); - - iree_hal_hip_timepoint_t* timepoint; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_semaphore_try_wait_or_acquire_wait_timepoint( - base_semaphore, value, timeout, &timepoint)); - if (!timepoint) { - // We don't need to wait on a timepoint. - // The wait condition is satisfied. 
- IREE_TRACE_ZONE_END(z0); + iree_slim_mutex_lock(&semaphore->mutex); + if (value <= semaphore->current_visible_value) { + iree_slim_mutex_unlock(&semaphore->mutex); return iree_ok_status(); } + iree_status_t status = iree_status_clone(semaphore->failure_status); + if (iree_status_is_ok(status)) { + iree_hal_hip_util_tree_node_t* node = + iree_hal_hip_util_tree_get(&semaphore->event_queue.tree, value); + + if (node == NULL) { + status = iree_hal_hip_util_tree_insert(&semaphore->event_queue.tree, + value, &node); + if (iree_status_is_ok(status)) { + iree_hal_hip_semaphore_queue_item_t* item = + (iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node); + item->cpu_event = NULL; + item->work_item = NULL; + } + } - iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { - iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); - return iree_make_status(IREE_STATUS_ABORTED); + if (iree_status_is_ok(status)) { + iree_hal_hip_event_t* event = + ((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event; + if (!event) { + status = iree_hal_hip_event_pool_acquire( + event_pool, 1, + &((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event); + if (iree_status_is_ok(status)) { + event = ((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event; + status = IREE_HIP_CALL_TO_STATUS( + semaphore->symbols, + hipEventRecord(iree_hal_hip_event_handle(event), + dispatch_stream)); + } + } + } } iree_slim_mutex_unlock(&semaphore->mutex); - // Wait until the timepoint resolves. - // If satisfied the timepoint is automatically cleaned up and we are done. If - // the deadline is reached before satisfied then we have to clean it up. 
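The record-if-necessary helper above leans on the basic HIP event contract: an event recorded on a stream signals once all work enqueued on that stream before the record has completed, which is what lets one recorded event stand in for "the semaphore can reach this value". A standalone illustration in raw HIP, with error checks elided and `stream` assumed to exist:

    // Standalone HIP illustration: record an event after enqueued work, then
    // either poll it (hipEventQuery) or block on it (hipEventSynchronize).
    hipEvent_t event;
    hipEventCreate(&event);
    hipEventRecord(event, stream);  // completes when prior stream work does
    if (hipEventQuery(event) == hipErrorNotReady) {
      // Not signaled yet; a blocking wait is also possible:
      hipEventSynchronize(event);
    }
    hipEventDestroy(event);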
- iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout); - iree_status_t status = - iree_wait_one(&timepoint->timepoint.host_wait, deadline_ns); - if (!iree_status_is_ok(status)) { - iree_hal_semaphore_cancel_timepoint(&semaphore->base, &timepoint->base); - } - iree_hal_hip_timepoint_pool_release(semaphore->timepoint_pool, 1, &timepoint); - if (!iree_status_is_ok(status)) { - IREE_TRACE_ZONE_END(z0); - return status; + return status; +} + +static iree_status_t iree_hal_hip_semaphore_query_locked( + iree_hal_hip_semaphore_t* semaphore, uint64_t* out_value) { + iree_status_t status = iree_ok_status(); + *out_value = semaphore->current_visible_value; + iree_hal_hip_util_tree_node_t* node = + iree_hal_hip_util_tree_first(&semaphore->event_queue.tree); + while (node) { + if (!((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event) { + node = iree_hal_hip_util_tree_node_next(node); + continue; + } + + hipError_t err = + semaphore->symbols->hipEventQuery(iree_hal_hip_event_handle( + ((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event)); + if (err == hipErrorNotReady) { + break; + } + if (err != hipSuccess) { + status = IREE_HIP_RESULT_TO_STATUS(semaphore->symbols, err); + break; + } + + *out_value = iree_hal_hip_util_tree_node_get_key(node); + node = iree_hal_hip_util_tree_node_next(node); } - iree_slim_mutex_lock(&semaphore->mutex); - if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { - status = iree_make_status(IREE_STATUS_ABORTED); + if (iree_status_is_ok(status)) { + if (semaphore->current_visible_value < *out_value) { + semaphore->current_visible_value = *out_value; + iree_notification_post(&semaphore->state_notification, IREE_ALL_WAITERS); + } + + if (*out_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + status = + iree_make_status(IREE_STATUS_ABORTED, "the semaphore was aborted"); + } } - iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); return status; } -iree_status_t iree_hal_hip_semaphore_multi_wait( - const iree_hal_semaphore_list_t semaphore_list, - iree_hal_wait_mode_t wait_mode, iree_timeout_t timeout, - iree_arena_block_pool_t* block_pool) { - if (semaphore_list.count == 0) return iree_ok_status(); +static iree_status_t iree_hal_hip_semaphore_query( + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { + iree_hal_hip_semaphore_t* semaphore = + iree_hal_hip_semaphore_cast(base_semaphore); + + iree_slim_mutex_lock(&semaphore->mutex); + *out_value = semaphore->current_visible_value; - if (semaphore_list.count == 1) { - // Fast-path for a single semaphore. - return iree_hal_semaphore_wait(semaphore_list.semaphores[0], - semaphore_list.payload_values[0], timeout); + iree_status_t status = + iree_hal_hip_semaphore_query_locked(semaphore, out_value); + iree_slim_mutex_unlock(&semaphore->mutex); + // If the status is aborted, we will pick up the real status from + // semaphore_advance. + if (iree_status_is_aborted(status)) { + iree_status_ignore(status); + status = iree_ok_status(); } - IREE_TRACE_ZONE_BEGIN(z0); + return iree_status_join( + status, + iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore)); +} - iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout); +iree_status_t iree_hal_hip_event_semaphore_advance( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_hip_semaphore_t* semaphore = + iree_hal_hip_semaphore_cast(base_semaphore); - // Avoid heap allocations by using the device block pool for the wait set. 
- iree_arena_allocator_t arena;
-  iree_arena_initialize(block_pool, &arena);
-  iree_wait_set_t* wait_set = NULL;
-  iree_status_t status = iree_wait_set_allocate(
-      semaphore_list.count, iree_arena_allocator(&arena), &wait_set);
-
-  // Acquire a host wait handle for each semaphore timepoint we are to wait on.
-  iree_host_size_t timepoint_count = 0;
-  iree_hal_hip_timepoint_t** timepoints = NULL;
-  iree_host_size_t total_timepoint_size =
-      semaphore_list.count * sizeof(timepoints[0]);
-  bool needs_wait = true;
-  status =
-      iree_arena_allocate(&arena, total_timepoint_size, (void**)&timepoints);
-  if (iree_status_is_ok(status)) {
-    memset(timepoints, 0, total_timepoint_size);
-    for (iree_host_size_t i = 0; i < semaphore_list.count && needs_wait; ++i) {
-      iree_hal_hip_timepoint_t* timepoint;
-      status = iree_hal_hip_semaphore_try_wait_or_acquire_wait_timepoint(
-          semaphore_list.semaphores[i], semaphore_list.payload_values[i],
-          timeout, &timepoint);
-      if (!iree_status_is_ok(status)) break;
-      if (!timepoint) {
-        // We don't need to wait on a timepoint.
-        // The wait condition is satisfied.
-        if (wait_mode == IREE_HAL_WAIT_MODE_ANY) {
-          needs_wait = false;
-          break;
-        }
-        continue;
-      }
+  iree_slim_mutex_lock(&semaphore->mutex);
 
-      timepoints[timepoint_count++] = timepoint;
-      status = iree_wait_set_insert(wait_set, timepoint->timepoint.host_wait);
-      if (!iree_status_is_ok(status)) break;
+  iree_status_t status = iree_ok_status();
+  iree_hal_hip_util_tree_node_t* node =
+      iree_hal_hip_util_tree_first(&semaphore->event_queue.tree);
+
+  uint64_t highest_value = 0;
+  while (node) {
+    if (!((iree_hal_hip_semaphore_queue_item_t*)
+              iree_hal_hip_util_tree_node_get_value(node))
+             ->event) {
+      node = iree_hal_hip_util_tree_node_next(node);
+      continue;
     }
-  }
 
-  // Perform the wait.
-  if (iree_status_is_ok(status) && needs_wait) {
-    if (wait_mode == IREE_HAL_WAIT_MODE_ANY) {
-      status = iree_wait_any(wait_set, deadline_ns, /*out_wake_handle=*/NULL);
-    } else {
-      status = iree_wait_all(wait_set, deadline_ns);
+    hipError_t err =
+        semaphore->symbols->hipEventQuery(iree_hal_hip_event_handle(
+            ((iree_hal_hip_semaphore_queue_item_t*)
+                 iree_hal_hip_util_tree_node_get_value(node))
+                ->event));
+    if (err == hipErrorNotReady) {
+      break;
     }
+    if (err != hipSuccess) {
+      status = IREE_HIP_RESULT_TO_STATUS(semaphore->symbols, err);
+      break;
+    }
+
+    highest_value = iree_hal_hip_util_tree_node_get_key(node);
+    node = iree_hal_hip_util_tree_node_next(node);
   }
 
-  for (iree_host_size_t i = 0; i < timepoint_count; ++i) {
-    iree_hal_hip_timepoint_t* timepoint = timepoints[i];
-    iree_hal_semaphore_t* semaphore = timepoint->base.semaphore;
-    // Cancel if this is still an unresolved host wait.
-    if (semaphore) {
-      iree_hal_semaphore_cancel_timepoint(semaphore, &timepoint->base);
+  if (iree_status_is_ok(status)) {
+    if (semaphore->current_visible_value < highest_value) {
+      semaphore->current_visible_value = highest_value;
+      iree_notification_post(&semaphore->state_notification, IREE_ALL_WAITERS);
+    }
+
+    if (highest_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+      status =
+          iree_make_status(IREE_STATUS_ABORTED, "the semaphore was aborted");
     }
-    iree_hal_hip_timepoint_pool_release(timepoint->pool, 1, &timepoint);
   }
-  iree_wait_set_free(wait_set);
-  iree_arena_deinitialize(&arena);
 
-  if (!iree_status_is_ok(status)) {
-    IREE_TRACE_ZONE_END(z0);
-    return status;
+  iree_slim_mutex_unlock(&semaphore->mutex);
+  // If the status is aborted, we will pick up the real status from
+  // iree_hal_hip_event_semaphore_run_scheduled_callbacks.
+  if (iree_status_is_aborted(status)) {
+    iree_status_ignore(status);
+    status = iree_ok_status();
   }
+  status = iree_status_join(
+      status,
+      iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore));
+  return status;
+}
 
-  for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
-    iree_hal_hip_semaphore_t* semaphore =
-        iree_hal_hip_semaphore_cast(semaphore_list.semaphores[i]);
-    iree_slim_mutex_lock(&semaphore->mutex);
-    if (semaphore->current_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
-      iree_slim_mutex_unlock(&semaphore->mutex);
-      status = iree_make_status(IREE_STATUS_ABORTED);
-      break;
-    }
+static iree_status_t iree_hal_hip_semaphore_signal(
+    iree_hal_semaphore_t* base_semaphore, uint64_t new_value) {
+  iree_hal_hip_semaphore_t* semaphore =
+      iree_hal_hip_semaphore_cast(base_semaphore);
+  iree_slim_mutex_lock(&semaphore->mutex);
+
+  iree_status_t status = iree_ok_status();
+  if (new_value <= semaphore->current_visible_value) {
+    uint64_t current_value IREE_ATTRIBUTE_UNUSED =
+        semaphore->current_visible_value;
     iree_slim_mutex_unlock(&semaphore->mutex);
+    // Return here so that the unconditional unlock below cannot unlock the
+    // mutex a second time.
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "semaphore values must be monotonically "
+                            "increasing; current_value=%" PRIu64
+                            ", new_value=%" PRIu64,
+                            current_value, new_value);
   }
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
 
+  if (iree_status_is_ok(status)) {
+    semaphore->current_visible_value = new_value;
+  }
 
-// Handles device signal timepoints on the host when the |semaphore| timeline
-// advances past the given |value|.
-//
-// Note that this callback is invoked by the a host thread after the HIP host
-// function callback function is triggered in the HIP driver.
-static iree_status_t iree_hal_hip_semaphore_timepoint_device_signal_callback(
-    void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
-    iree_status_code_t status_code) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-  iree_hal_hip_timepoint_t* timepoint = (iree_hal_hip_timepoint_t*)user_data;
-  // Just release the timepoint back to the pool. This will decrease the
-  // reference count of the underlying HIP event internally.
-  iree_hal_hip_timepoint_pool_release(timepoint->pool, 1, &timepoint);
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  iree_slim_mutex_unlock(&semaphore->mutex);
+
+  if (iree_status_is_ok(status)) {
+    status =
+        iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore);
+  }
+  return status;
 }
 
-// Acquires a timepoint to signal the timeline to the given |to_value| from the
-// device.
-iree_status_t iree_hal_hip_event_semaphore_acquire_timepoint_device_signal(
-    iree_hal_semaphore_t* base_semaphore, uint64_t to_value,
-    hipEvent_t* out_event) {
+static void iree_hal_hip_semaphore_fail(iree_hal_semaphore_t* base_semaphore,
+                                        iree_status_t status) {
   iree_hal_hip_semaphore_t* semaphore =
       iree_hal_hip_semaphore_cast(base_semaphore);
-  iree_hal_hip_timepoint_t* signal_timepoint = NULL;
-  IREE_TRACE_ZONE_BEGIN(z0);
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_hip_timepoint_pool_acquire_device_signal(
-              semaphore->timepoint_pool, 1, &signal_timepoint));
-
-  // Initialize the timepoint with the value and callback, and connect it to
-  // this semaphore.
- iree_hal_semaphore_acquire_timepoint(
-      &semaphore->base, to_value, iree_infinite_timeout(),
-      (iree_hal_semaphore_callback_t){
-          .fn = iree_hal_hip_semaphore_timepoint_device_signal_callback,
-          .user_data = signal_timepoint,
-      },
-      &signal_timepoint->base);
-  iree_hal_hip_event_t* event = signal_timepoint->timepoint.device_signal;
-
-  // Scan through the timepoint list and update device wait timepoints to wait
-  // for this device signal when possible. We need to lock with the timepoint
-  // list mutex here.
-  iree_slim_mutex_lock(&semaphore->base.timepoint_mutex);
-  for (iree_hal_semaphore_timepoint_t* tp = semaphore->base.timepoint_list.head;
-       tp != NULL; tp = tp->next) {
-    iree_hal_hip_timepoint_t* wait_timepoint = (iree_hal_hip_timepoint_t*)tp;
-    if (wait_timepoint->kind == IREE_HAL_HIP_TIMEPOINT_KIND_DEVICE_WAIT &&
-        wait_timepoint->timepoint.device_wait == NULL &&
-        wait_timepoint->base.minimum_value <= to_value) {
-      iree_hal_hip_event_retain(event);
-      wait_timepoint->timepoint.device_wait = event;
-    }
+  iree_slim_mutex_lock(&semaphore->mutex);
+
+  // Try to set our local status - we only preserve the first failure so only
+  // do this if we are going from a valid semaphore to a failed one.
+  if (!iree_status_is_ok(semaphore->failure_status)) {
+    // Previous status was not OK; drop our new status.
+    iree_status_ignore(status);
+    iree_slim_mutex_unlock(&semaphore->mutex);
+    return;
   }
-  iree_slim_mutex_unlock(&semaphore->base.timepoint_mutex);
 
-  *out_event = iree_hal_hip_event_handle(event);
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
-}
+  // Signal to our failure sentinel value.
+  semaphore->current_visible_value = IREE_HAL_SEMAPHORE_FAILURE_VALUE;
+  semaphore->failure_status = status;
 
-// Handles device wait timepoints on the host when the |semaphore| timeline
-// advances past the given |value|.
-//
-// Note that this callback is invoked by the a host thread.
-static iree_status_t iree_hal_hip_semaphore_timepoint_device_wait_callback(
-    void* user_data, iree_hal_semaphore_t* semaphore, uint64_t value,
-    iree_status_code_t status_code) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-  iree_hal_hip_timepoint_t* timepoint = (iree_hal_hip_timepoint_t*)user_data;
-  // Just release the timepoint back to the pool. This will decrease the
-  // reference count of the underlying HIP event internally.
-  iree_hal_hip_timepoint_pool_release(timepoint->pool, 1, &timepoint);
-  IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  iree_slim_mutex_unlock(&semaphore->mutex);
+  iree_status_ignore(
+      iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore));
 }
 
-// Acquires a timepoint to wait the timeline to reach at least the given
-// |min_value| on the device.
-iree_status_t iree_hal_hip_event_semaphore_acquire_timepoint_device_wait(
-    iree_hal_semaphore_t* base_semaphore, uint64_t min_value,
-    hipEvent_t* out_event) {
+static iree_status_t iree_hal_hip_semaphore_wait(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_timeout_t timeout) {
   iree_hal_hip_semaphore_t* semaphore =
       iree_hal_hip_semaphore_cast(base_semaphore);
-  iree_hal_hip_timepoint_t* wait_timepoint = NULL;
   IREE_TRACE_ZONE_BEGIN(z0);
-  IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_hip_timepoint_pool_acquire_device_wait(
-              semaphore->timepoint_pool, 1, &wait_timepoint));
-
-  // Initialize the timepoint with the value and callback, and connect it to
-  // this semaphore.
- iree_hal_semaphore_acquire_timepoint(
-      &semaphore->base, min_value, iree_infinite_timeout(),
-      (iree_hal_semaphore_callback_t){
-          .fn = iree_hal_hip_semaphore_timepoint_device_wait_callback,
-          .user_data = wait_timepoint,
-      },
-      &wait_timepoint->base);
-
-  iree_hal_hip_event_t* wait_event = NULL;
-  if (iree_hal_hip_semaphore_acquire_event_host_wait(&semaphore->base,
-                                                     min_value, &wait_event)) {
-    // We've found an existing signal timepoint to wait on; we don't need a
-    // standalone wait timepoint anymore. Decrease its refcount before
-    // overwriting it to return it back to the pool and retain the existing one.
-    iree_hal_hip_event_release(wait_timepoint->timepoint.device_wait);
-    wait_timepoint->timepoint.device_wait = wait_event;
-  }
-
-  *out_event = iree_hal_hip_event_handle(wait_timepoint->timepoint.device_wait);
+  const iree_time_t deadline_ns = iree_timeout_as_deadline_ns(timeout);
+  iree_slim_mutex_lock(&semaphore->mutex);
+  uint64_t current_value = 0;
+
+  // query_locked to make sure our count is up to date.
+  iree_status_t status =
+      iree_hal_hip_semaphore_query_locked(semaphore, &current_value);
+
+  if (iree_status_is_ok(status)) {
+    while (semaphore->max_value_to_be_signaled < value) {
+      if (iree_time_now() > deadline_ns) {
+        status = iree_make_status(IREE_STATUS_DEADLINE_EXCEEDED);
+        break;
+      }
+      iree_wait_token_t wait =
+          iree_notification_prepare_wait(&semaphore->state_notification);
+      iree_slim_mutex_unlock(&semaphore->mutex);
+
+      // We are going to pick up the correct status from query_locked below.
+      iree_status_ignore(
+          iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore));
+
+      // We have to wait for the semaphore to catch up.
+      bool committed =
+          iree_notification_commit_wait(&semaphore->state_notification, wait,
+                                        IREE_DURATION_ZERO, deadline_ns);
+
+      iree_slim_mutex_lock(&semaphore->mutex);
+      if (!committed) {
+        status = iree_make_status(IREE_STATUS_DEADLINE_EXCEEDED);
+        break;
+      }
+
+      // query_locked to make sure our count is up to date.
+      status = iree_hal_hip_semaphore_query_locked(semaphore, &current_value);
+      if (!iree_status_is_ok(status)) {
+        break;
+      }
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    // The value visible on the semaphore has reached the requested value,
+    // so we can return.
+    if (semaphore->current_visible_value >= value) {
+      iree_slim_mutex_unlock(&semaphore->mutex);
+      iree_status_ignore(
+          iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore));
+      iree_slim_mutex_lock(&semaphore->mutex);
+    } else if (iree_timeout_is_infinite(timeout)) {
+      // This is the fast-path. Since we have an infinite timeout, we can
+      // wait directly on the hip event.
+
+      // The current value is not enough, but we have at least submitted
+      // the work that will increment the semaphore to the value we need.
+      // Use iree_hal_hip_util_tree_lower_bound to find the first element in the
+      // tree that would signal our semaphore to at least the given value.
+      iree_hal_hip_util_tree_node_t* node = iree_hal_hip_util_tree_lower_bound(
+          &semaphore->event_queue.tree, value);
+      IREE_ASSERT(
+          node,
+          "We really should either have an event in the queue that will "
+          "satisfy this semaphore (we checked max_value_to_be_signaled above) "
+          "or we should already have signaled (current_visible_value above)");
+      iree_hal_hip_semaphore_queue_item_t* item =
+          (iree_hal_hip_semaphore_queue_item_t*)
+              iree_hal_hip_util_tree_node_get_value(node);
+
+      iree_hal_hip_event_t* event = item->event;
+
+      // Retain the event, as the event may be removed from the tree
+      // while we sleep on the event.
+      iree_hal_hip_event_retain(event);
+      iree_slim_mutex_unlock(&semaphore->mutex);
+      iree_status_ignore(
+          iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore));
+      status = IREE_HIP_CALL_TO_STATUS(
+          semaphore->symbols,
+          hipEventSynchronize(iree_hal_hip_event_handle(event)));
+      iree_hal_hip_event_release(event);
+      iree_slim_mutex_lock(&semaphore->mutex);
+    } else {
+      // If we have a non-infinite timeout this is the slow path, because we
+      // will end up having to wait for either the cleanup thread or someone
+      // else to advance the semaphore.
+      iree_slim_mutex_unlock(&semaphore->mutex);
+      iree_hal_hip_cpu_event_t* cpu_event = NULL;
+      status = iree_hal_hip_semaphore_get_cpu_event(base_semaphore, value,
+                                                    &cpu_event);
+      if (iree_status_is_ok(status)) {
+        // If there is no cpu event the semaphore has hit the value already.
+        if (cpu_event) {
+          status = iree_wait_one(&cpu_event->event, deadline_ns);
+          iree_hal_resource_release(&cpu_event->resource);
+        }
+      }
+      iree_slim_mutex_lock(&semaphore->mutex);
+    }
+  }
+
+  if (iree_status_is_ok(status)) {
+    if (semaphore->current_visible_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+      status =
+          iree_make_status(IREE_STATUS_ABORTED, "the semaphore was aborted");
+    }
+  }
+  iree_slim_mutex_unlock(&semaphore->mutex);
   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }
 
 static const iree_hal_semaphore_vtable_t iree_hal_hip_semaphore_vtable = {
diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.h b/runtime/src/iree/hal/drivers/hip/event_semaphore.h
index a5d8ff95b369..ebbf23992dcb 100644
--- a/runtime/src/iree/hal/drivers/hip/event_semaphore.h
+++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.h
@@ -12,12 +12,12 @@
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
 #include "iree/hal/drivers/hip/dynamic_symbols.h"
-#include "iree/hal/drivers/hip/timepoint_pool.h"
-#include "iree/hal/utils/deferred_work_queue.h"
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
+typedef struct iree_hal_hip_event_t iree_hal_hip_event_t;
+typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t;
+typedef iree_status_t (*iree_hal_hip_event_semaphore_scheduled_callback_t)(
+    void* user_data, iree_hal_semaphore_t* semaphore,
+    iree_status_t semaphore_status);
 
 // Creates an IREE HAL semaphore with the given |initial_value|.
 //
@@ -25,41 +25,56 @@ extern "C" {
-// different timepoints along the timeline under the hood. Those timepoints will
-// be allocated from the |timepoint_pool|.
+// different timepoints along the timeline under the hood.
 //
-// This semaphore is meant to be used together with a pending queue actions; it
-// may advance the given |work_queue| if new values are signaled.
-//
 // Thread-safe; multiple threads may signal/wait values on the same semaphore.
iree_status_t iree_hal_hip_event_semaphore_create(
     uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols,
-    hipCtx_t hip_context, iree_hal_hip_timepoint_pool_t* timepoint_pool,
-    iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
-    iree_hal_semaphore_t** out_semaphore);
-
-// Acquires a timepoint to signal the timeline to the given |to_value| from the
-// device. The underlying HIP event is written into |out_event| for interacting
-// with HIP APIs.
-iree_status_t iree_hal_hip_event_semaphore_acquire_timepoint_device_signal(
-    iree_hal_semaphore_t* base_semaphore, uint64_t to_value,
-    hipEvent_t* out_event);
-
-// Acquires an iree_hal_hip_event_t object to wait on the host for the
-// timeline to reach at least the given |min_value| on the device.
-// Returns true and writes to |out_event| if we can find such an event;
-// returns false otherwise.
-// The caller should release the |out_event| once done.
-bool iree_hal_hip_semaphore_acquire_event_host_wait(
-    iree_hal_semaphore_t* base_semaphore, uint64_t min_value,
-    iree_hal_hip_event_t** out_event);
+    iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore);
 
 // Performs a multi-wait on one or more semaphores. Returns
 // IREE_STATUS_DEADLINE_EXCEEDED if the wait does not complete before |timeout|.
 iree_status_t iree_hal_hip_semaphore_multi_wait(
     const iree_hal_semaphore_list_t semaphore_list,
     iree_hal_wait_mode_t wait_mode, iree_timeout_t timeout,
-    iree_arena_block_pool_t* block_pool);
+    iree_allocator_t host_allocator);
+
+// Adds a work item to be executed once we have a forward-progress
+// guarantee that this semaphore will reach a particular value.
+// |event_pool| must be the event pool specific to the queue
+// that will be doing the work.
+iree_status_t iree_hal_hip_semaphore_notify_work(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_hal_hip_event_pool_t* event_pool,
+    iree_hal_hip_event_semaphore_scheduled_callback_t callback,
+    void* user_data);
+
+// Notifies this semaphore that forward progress is guaranteed
+// until the given value is reached.
+iree_status_t iree_hal_hip_semaphore_notify_forward_progress_to(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value);
+
+// Returns the hip event that needs to be signaled in order
+// for the semaphore to reach a given value.
+// The semaphore *must* have previously been notified of forward progress
+// to at least |value| via iree_hal_hip_semaphore_notify_forward_progress_to.
+// If the returned status is iree_ok_status() and |out_hip_event| is NULL,
+// the semaphore has already been signaled to |value| and the result
+// is visible on the host.
+// The refcount of the returned event is incremented; the caller
+// must release it when no longer needed.
+iree_status_t iree_hal_hip_semaphore_get_hip_event(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    iree_hal_hip_event_pool_t* event_pool,
+    iree_hal_hip_event_t** out_hip_event);
+
+// Ensures a hip event exists for |value| on |base_semaphore|, acquiring one
+// from |event_pool| and recording it on |dispatch_stream| if it has not
+// already been recorded.
+iree_status_t iree_hal_hip_semaphore_create_event_and_record_if_necessary(
+    iree_hal_semaphore_t* base_semaphore, uint64_t value,
+    hipStream_t dispatch_stream, iree_hal_hip_event_pool_t* event_pool);
 
-#ifdef __cplusplus
-}  // extern "C"
-#endif  // __cplusplus
+// Advances the semaphore's visible value as far as the recorded hip events
+// allow and runs any scheduled callbacks that have become ready.
+iree_status_t iree_hal_hip_event_semaphore_advance(
+    iree_hal_semaphore_t* semaphore);
 
 #endif  // IREE_HAL_DRIVERS_HIP_EVENT_SEMAPHORE_H_
diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
index 67ab8774432a..ae68f8829b18 100644
--- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c
@@ -371,7 +371,7 @@ static iree_status_t iree_hal_hip_graph_command_buffer_end(
 
   // Compile the graph.
   hipGraphNode_t error_node = NULL;
-  iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
+  iree_status_t status = IREE_HIP_CALL_TO_STATUS(
       command_buffer->symbols,
       hipGraphInstantiate(&command_buffer->hip_exec, command_buffer->hip_graph,
                           &error_node,
@@ -402,7 +402,6 @@ static iree_status_t iree_hal_hip_graph_command_buffer_begin_debug_group(
       location ? location->file.data : NULL,
       location ? location->file.size : 0, location ? location->line : 0,
       /*func_name=*/NULL, 0, label.data, label.size);
-
   return iree_ok_status();
 }
 
@@ -490,9 +489,8 @@ static iree_status_t iree_hal_hip_graph_command_buffer_wait_events(
 }
 
 static iree_status_t iree_hal_hip_graph_command_buffer_advise_buffer(
-    iree_hal_command_buffer_t* base_command_buffer,
-    iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags,
-    uint64_t arg0, uint64_t arg1) {
+    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t buffer_ref,
+    iree_hal_memory_advise_flags_t flags, uint64_t arg0, uint64_t arg1) {
   // We could mark the memory as invalidated so that if this is a managed buffer
   // HIP does not try to copy it back to the host.
return iree_ok_status(); @@ -746,7 +744,8 @@ static iree_status_t iree_hal_hip_graph_command_buffer_dispatch( const iree_hal_hip_kernel_params_t* kernel_params = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_native_executable_lookup_kernel_params( - executable, entry_point, &kernel_params)); + executable, entry_point, command_buffer->base.queue_affinity, + &kernel_params)); IREE_HIP_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer, IREE_HAL_STREAM_TRACING_VERBOSITY_FINE, @@ -849,6 +848,14 @@ static iree_status_t iree_hal_hip_graph_command_buffer_dispatch_indirect( "indirect dispatch not yet implemented"); } +iree_hal_stream_tracing_context_event_list_t +iree_hal_hip_graph_command_buffer_tracing_events( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_hip_graph_command_buffer_t* command_buffer = + iree_hal_hip_graph_command_buffer_cast(base_command_buffer); + return command_buffer->tracing_event_list; +} + static const iree_hal_command_buffer_vtable_t iree_hal_hip_graph_command_buffer_vtable = { .destroy = iree_hal_hip_graph_command_buffer_destroy, diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.h b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.h index 424d780521c8..c02eebe24b47 100644 --- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.h +++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.h @@ -11,10 +11,7 @@ #include "iree/hal/api.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_headers.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus +#include "iree/hal/utils/stream_tracing.h" // NOTE: hipGraph API used in this module is marked as beta in the HIP // documentation, meaning, while this is feature complete it is still open to @@ -51,8 +48,8 @@ hipGraphExec_t iree_hal_hip_graph_command_buffer_handle( void iree_hal_hip_graph_tracing_notify_submitted_commands( iree_hal_command_buffer_t* command_buffer); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus +iree_hal_stream_tracing_context_event_list_t +iree_hal_hip_graph_command_buffer_tracing_events( + iree_hal_command_buffer_t* base_command_buffer); #endif // IREE_HAL_DRIVERS_HIP_GRAPH_COMMAND_BUFFER_H_ diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.c b/runtime/src/iree/hal/drivers/hip/hip_allocator.c index c8bc93f52aa4..eb618abe68fb 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.c +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.c @@ -9,10 +9,11 @@ #include #include "iree/base/api.h" +#include "iree/base/internal/math.h" #include "iree/base/tracing.h" -#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_buffer.h" +#include "iree/hal/drivers/hip/per_device_information.h" #include "iree/hal/drivers/hip/status_util.h" #if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_ALLOCATION_TRACKING @@ -27,16 +28,9 @@ typedef struct iree_hal_hip_allocator_t { // Parent device that this allocator is associated with. Unowned. iree_hal_device_t* parent_device; - // The device that this allocator allocates memory from. - hipDevice_t device; + iree_hal_hip_device_topology_t topology; - // The HIP stream that allocations should be used in. - hipStream_t stream; - - hipCtx_t hip_context; - - // NOTE: optional depending on device support. 
- iree_hal_hip_memory_pools_t* pools; + bool supports_memory_pools; const iree_hal_hip_dynamic_symbols_t* symbols; @@ -60,16 +54,19 @@ static iree_hal_hip_allocator_t* iree_hal_hip_allocator_cast( iree_status_t iree_hal_hip_allocator_create( iree_hal_device_t* parent_device, - const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device, - hipCtx_t hip_context, hipStream_t stream, - iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator, - iree_hal_allocator_t** out_allocator) { + const iree_hal_hip_dynamic_symbols_t* hip_symbols, + iree_hal_hip_device_topology_t topology, bool supports_memory_pools, + iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator) { IREE_ASSERT_ARGUMENT(parent_device); IREE_ASSERT_ARGUMENT(hip_symbols); IREE_ASSERT_ARGUMENT(out_allocator); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(hip_symbols, hip_context)); + *out_allocator = NULL; + if (topology.count < 1) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "at least one device must be specified"); + } // To support device-local + host-visible memory we need concurrent managed // access indicating that the host and devices can concurrently access the @@ -79,11 +76,11 @@ iree_status_t iree_hal_hip_allocator_create( // buffers except for readback staging buffers. int supports_concurrent_managed_access = 0; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, IREE_HIP_RESULT_TO_STATUS( + z0, IREE_HIP_CALL_TO_STATUS( hip_symbols, hipDeviceGetAttribute(&supports_concurrent_managed_access, hipDeviceAttributeConcurrentManagedAccess, - device), + topology.devices[0].hip_device), "hipDeviceGetAttribute")); IREE_TRACE_ZONE_APPEND_TEXT( @@ -99,14 +96,12 @@ iree_status_t iree_hal_hip_allocator_create( iree_hal_resource_initialize(&iree_hal_hip_allocator_vtable, &allocator->resource); allocator->parent_device = parent_device; - allocator->device = device; - allocator->stream = stream; - allocator->pools = pools; + allocator->supports_memory_pools = supports_memory_pools; allocator->symbols = hip_symbols; allocator->host_allocator = host_allocator; allocator->supports_concurrent_managed_access = supports_concurrent_managed_access != 0; - allocator->hip_context = hip_context; + allocator->topology = topology; *out_allocator = (iree_hal_allocator_t*)allocator; IREE_TRACE_ZONE_END(z0); @@ -147,9 +142,12 @@ static void iree_hal_hip_allocator_query_statistics( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); memcpy(out_statistics, &allocator->statistics, sizeof(*out_statistics)); - if (allocator->pools) { - iree_hal_hip_memory_pools_merge_statistics(allocator->pools, - out_statistics); + + if (allocator->supports_memory_pools) { + for (iree_host_size_t i = 0; i < allocator->topology.count; ++i) { + iree_hal_hip_memory_pools_merge_statistics( + &allocator->topology.devices[i].memory_pools, out_statistics); + } } }); } @@ -329,6 +327,8 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( const iree_hal_buffer_params_t* IREE_RESTRICT params, iree_device_size_t allocation_size, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + IREE_ASSERT_ARGUMENT(out_buffer); + *out_buffer = NULL; iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); @@ -365,10 +365,19 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( void* host_ptr = NULL; hipDeviceptr_t device_ptr = NULL; IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_buffer_allocate"); + 
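Because each physical device behind the logical device is exposed as one queue-affinity bit, the allocator resolves an affinity mask to a backing device by taking the index of the lowest set bit, as the allocation path below does. A self-contained sketch of that mapping, using the same iree_math_count_trailing_zeros_u64 helper (from iree/base/internal/math.h); the wrapper function name is hypothetical:

    // Affinity bit i set means "may use queue/device i"; buffers are placed
    // on the lowest-indexed allowed device. An affinity of 0 (ANY) defaults
    // to device 0.
    static int select_device_ordinal(uint64_t queue_affinity) {
      if (!queue_affinity) return 0;
      return iree_math_count_trailing_zeros_u64(queue_affinity);
    }
    // select_device_ordinal(0x1) == 0; select_device_ordinal(0x4) == 2;
    // select_device_ordinal(0x6) == 1 (lowest of devices 1 and 2).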
IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size); + + int device_ordinal = 0; + if (params->queue_affinity) { + device_ordinal = iree_math_count_trailing_zeros_u64(params->queue_affinity); + } + IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); + z0, IREE_HIP_CALL_TO_STATUS( + allocator->symbols, + hipCtxPushCurrent( + allocator->topology.devices[device_ordinal].hip_context))); - IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, allocation_size); if (iree_all_bits_set(compat_params.type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { // Device local case. @@ -376,23 +385,26 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { // Device local and host visible. buffer_type = IREE_HAL_HIP_BUFFER_TYPE_DEVICE; - status = IREE_HIP_RESULT_TO_STATUS( + status = IREE_HIP_CALL_TO_STATUS( allocator->symbols, hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal)); if (iree_status_is_ok(status) && allocator->supports_concurrent_managed_access) { // Prefetch the buffer on the GPU device. - status = IREE_HIP_RESULT_TO_STATUS( + status = IREE_HIP_CALL_TO_STATUS( allocator->symbols, - hipMemPrefetchAsync(device_ptr, allocation_size, allocator->device, - allocator->stream)); + hipMemPrefetchAsync( + device_ptr, allocation_size, + allocator->topology.devices[device_ordinal].hip_device, + allocator->topology.devices[device_ordinal] + .hip_dispatch_stream)); } host_ptr = (void*)device_ptr; } else { // Device only. buffer_type = IREE_HAL_HIP_BUFFER_TYPE_DEVICE; - status = IREE_HIP_RESULT_TO_STATUS( - allocator->symbols, hipMalloc(&device_ptr, allocation_size)); + status = IREE_HIP_CALL_TO_STATUS(allocator->symbols, + hipMalloc(&device_ptr, allocation_size)); } } else { // Host local case. @@ -402,10 +414,10 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( IREE_HAL_MEMORY_TYPE_HOST_CACHED)) { flags |= hipHostMallocWriteCombined; } - status = IREE_HIP_RESULT_TO_STATUS( + status = IREE_HIP_CALL_TO_STATUS( allocator->symbols, hipHostMalloc(&host_ptr, allocation_size, flags)); if (iree_status_is_ok(status)) { - status = IREE_HIP_RESULT_TO_STATUS( + status = IREE_HIP_CALL_TO_STATUS( allocator->symbols, hipHostGetDevicePointer(&device_ptr, host_ptr, /*flags=*/0)); } @@ -444,7 +456,10 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( iree_hal_buffer_release(buffer); } } - return status; + + return iree_status_join( + status, + IREE_HIP_CALL_TO_STATUS(allocator->symbols, hipCtxPopCurrent(NULL))); } static void iree_hal_hip_allocator_deallocate_buffer( @@ -453,9 +468,6 @@ static void iree_hal_hip_allocator_deallocate_buffer( iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); - IREE_IGNORE_ERROR( - iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); - const iree_hal_hip_buffer_type_t buffer_type = iree_hal_hip_buffer_type(base_buffer); @@ -488,12 +500,11 @@ static iree_status_t iree_hal_hip_allocator_import_buffer( iree_hal_external_buffer_t* IREE_RESTRICT external_buffer, iree_hal_buffer_release_callback_t release_callback, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + IREE_ASSERT_ARGUMENT(out_buffer); + *out_buffer = NULL; iree_hal_hip_allocator_t* allocator = iree_hal_hip_allocator_cast(base_allocator); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(allocator->symbols, allocator->hip_context)); - // Coerce options into those required by the current device. 
iree_hal_buffer_params_t compat_params = *params;
   iree_device_size_t allocation_size = external_buffer->size;
@@ -523,6 +534,16 @@ static iree_status_t iree_hal_hip_allocator_import_buffer(
 #endif  // IREE_STATUS_MODE
   }
 
+  int device_ordinal = 0;
+  if (params->queue_affinity) {
+    device_ordinal = iree_math_count_trailing_zeros_u64(params->queue_affinity);
+  }
+
+  IREE_RETURN_IF_ERROR(IREE_HIP_CALL_TO_STATUS(
+      allocator->symbols,
+      hipCtxPushCurrent(
+          allocator->topology.devices[device_ordinal].hip_context)));
+
   iree_status_t status = iree_ok_status();
   iree_hal_hip_buffer_type_t buffer_type = IREE_HAL_HIP_BUFFER_TYPE_DEVICE;
   void* host_ptr = NULL;
@@ -539,12 +560,12 @@ static iree_status_t iree_hal_hip_allocator_import_buffer(
       buffer_type = IREE_HAL_HIP_BUFFER_TYPE_HOST_REGISTERED;
       host_ptr = external_buffer->handle.host_allocation.ptr;
       uint32_t register_flags = hipHostRegisterMapped;
-      status = IREE_HIP_RESULT_TO_STATUS(
+      status = IREE_HIP_CALL_TO_STATUS(
          allocator->symbols,
          hipHostRegister(host_ptr, external_buffer->size, register_flags),
          "hipHostRegister");
       if (iree_status_is_ok(status)) {
-        status = IREE_HIP_RESULT_TO_STATUS(
+        status = IREE_HIP_CALL_TO_STATUS(
            allocator->symbols,
            hipHostGetDevicePointer(&device_ptr, host_ptr, 0),
            "hipHostGetDevicePointer");
@@ -593,7 +614,10 @@ static iree_status_t iree_hal_hip_allocator_import_buffer(
       iree_hal_buffer_release(buffer);
     }
   }
-  return status;
+
+  return iree_status_join(
+      status,
+      IREE_HIP_CALL_TO_STATUS(allocator->symbols, hipCtxPopCurrent(NULL)));
 }
 
 static iree_status_t iree_hal_hip_allocator_export_buffer(
@@ -634,15 +658,16 @@ iree_status_t iree_hal_hip_allocator_alloc_async(
   iree_hal_hip_allocator_t* allocator =
       iree_hal_hip_allocator_cast(base_allocator);
 
-  IREE_RETURN_IF_ERROR(
-      iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));
-
+  // In an ideal world we would use hipMallocAsync/hipFreeAsync; however the
+  // caching inside those calls can hold on to so much memory (slack) that,
+  // depending on the allocation patterns of the host program, they become
+  // unusable, so we simply use hipMalloc/hipFree instead.
hipDeviceptr_t ptr = NULL;
-  iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
+  iree_status_t status = IREE_HIP_CALL_TO_STATUS(
       allocator->symbols,
-      hipMallocAsync(&ptr, (size_t)iree_hal_buffer_allocation_size(buffer),
-                     stream),
-      "hipMallocAsync");
+      hipMalloc(&ptr, (size_t)iree_hal_buffer_allocation_size(buffer)),
+      "hipMalloc");
+
   if (iree_status_is_ok(status)) {
     iree_hal_hip_buffer_set_device_pointer(buffer, ptr);
     IREE_TRACE_ALLOC_NAMED(IREE_HAL_HIP_ALLOCATOR_ID, (void*)ptr,
@@ -662,16 +687,13 @@ iree_status_t iree_hal_hip_allocator_free_async(
     iree_hal_buffer_t* buffer) {
   iree_hal_hip_allocator_t* allocator =
       iree_hal_hip_allocator_cast(base_allocator);
-  IREE_RETURN_IF_ERROR(
-      iree_hal_hip_set_context(allocator->symbols, allocator->hip_context));
-
   hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer);
   if (!device_ptr) {
     return iree_ok_status();
   }
 
-  IREE_RETURN_IF_ERROR(IREE_HIP_RESULT_TO_STATUS(
-      allocator->symbols, hipFreeAsync(device_ptr, stream), "hipFreeAsync"));
+  IREE_RETURN_IF_ERROR(IREE_HIP_CALL_TO_STATUS(allocator->symbols,
+                                               hipFree(device_ptr), "hipFree"));
 
   iree_hal_hip_buffer_set_allocation_empty(buffer);
   IREE_TRACE_FREE_NAMED(IREE_HAL_HIP_ALLOCATOR_ID, (void*)device_ptr);
diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.h b/runtime/src/iree/hal/drivers/hip/hip_allocator.h
index 5c19a7a957df..4b85259fcf0a 100644
--- a/runtime/src/iree/hal/drivers/hip/hip_allocator.h
+++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.h
@@ -12,20 +12,17 @@
 #include "iree/hal/drivers/hip/memory_pools.h"
 #include "iree/hal/drivers/hip/status_util.h"
 
-#ifdef __cplusplus
-extern "C" {
-#endif  // __cplusplus
+typedef struct iree_hal_hip_device_topology_t iree_hal_hip_device_topology_t;
 
 // Creates a HIP memory allocator.
-// |device| |hip_context| and |stream| will be used for management operations.
-// |pools| provides memory pools that may be shared across multiple allocators
-// and the pointer must remain valid for the lifetime of the allocator.
+// |topology| provides the per-physical-device contexts and streams used for
+// management operations. |supports_memory_pools| indicates whether the
+// devices support HIP memory pools.
iree_status_t iree_hal_hip_allocator_create( iree_hal_device_t* parent_device, - const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t device, - hipCtx_t hip_context, hipStream_t stream, - iree_hal_hip_memory_pools_t* pools, iree_allocator_t host_allocator, - iree_hal_allocator_t** out_allocator); + const iree_hal_hip_dynamic_symbols_t* hip_symbols, + iree_hal_hip_device_topology_t topology, bool supports_memory_pools, + iree_allocator_t host_allocator, iree_hal_allocator_t** out_allocator); bool iree_hal_hip_allocator_isa(iree_hal_allocator_t* base_value); diff --git a/runtime/src/iree/hal/drivers/hip/hip_buffer.c b/runtime/src/iree/hal/drivers/hip/hip_buffer.c index a0efa9a60d8f..16e1cc210df9 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/hip_buffer.c @@ -49,6 +49,7 @@ iree_status_t iree_hal_hip_buffer_wrap( void* host_ptr, iree_hal_buffer_release_callback_t release_callback, iree_allocator_t host_allocator, iree_hal_buffer_t** out_buffer) { IREE_ASSERT_ARGUMENT(out_buffer); + *out_buffer = NULL; if (!host_ptr && iree_any_bit_set(allowed_usage, IREE_HAL_BUFFER_USAGE_MAPPING_PERSISTENT | IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED)) { diff --git a/runtime/src/iree/hal/drivers/hip/hip_buffer.h b/runtime/src/iree/hal/drivers/hip/hip_buffer.h index 3a18956effde..4ea806955f16 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_buffer.h +++ b/runtime/src/iree/hal/drivers/hip/hip_buffer.h @@ -11,10 +11,6 @@ #include "iree/hal/api.h" #include "iree/hal/drivers/hip/hip_headers.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - typedef enum iree_hal_hip_buffer_type_e { // Device local buffer; allocated with hipMalloc/hipMallocManaged, freed // with hipFree. @@ -69,8 +65,4 @@ void* iree_hal_hip_buffer_host_pointer(const iree_hal_buffer_t* buffer); // this call returns and the caller has released its reference. 
void iree_hal_hip_buffer_drop_release_callback(iree_hal_buffer_t* buffer); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_BUFFER_H_ diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c index c6fe92967cd6..bfb58f5fad41 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.c +++ b/runtime/src/iree/hal/drivers/hip/hip_device.c @@ -14,22 +14,23 @@ #include "iree/base/internal/event_pool.h" #include "iree/base/internal/math.h" #include "iree/base/tracing.h" -#include "iree/hal/drivers/hip/context_util.h" +#include "iree/hal/drivers/hip/cleanup_thread.h" +#include "iree/hal/drivers/hip/dispatch_thread.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/event_pool.h" #include "iree/hal/drivers/hip/event_semaphore.h" #include "iree/hal/drivers/hip/graph_command_buffer.h" #include "iree/hal/drivers/hip/hip_allocator.h" #include "iree/hal/drivers/hip/hip_buffer.h" +#include "iree/hal/drivers/hip/hip_multi_queue_command_buffer.h" #include "iree/hal/drivers/hip/memory_pools.h" #include "iree/hal/drivers/hip/nop_executable_cache.h" +#include "iree/hal/drivers/hip/per_device_information.h" #include "iree/hal/drivers/hip/rccl_channel.h" #include "iree/hal/drivers/hip/rccl_dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" #include "iree/hal/drivers/hip/stream_command_buffer.h" -#include "iree/hal/drivers/hip/timepoint_pool.h" #include "iree/hal/utils/deferred_command_buffer.h" -#include "iree/hal/utils/deferred_work_queue.h" #include "iree/hal/utils/file_transfer.h" #include "iree/hal/utils/memory_file.h" #include "iree/hal/utils/stream_tracing.h" @@ -38,6 +39,11 @@ // iree_hal_hip_device_t //===----------------------------------------------------------------------===// +typedef enum iree_hip_device_commandbuffer_type_t { + IREE_HAL_HIP_DEVICE_COMMAND_BUFFER_TYPE_STREAM, + IREE_HAL_HIP_DEVICE_COMMAND_BUFFER_TYPE_GRAPH, +} iree_hip_device_commandbuffer_type_t; + typedef struct iree_hal_hip_device_t { // Abstract resource used for injecting reference counting and vtable; // must be at offset 0. @@ -58,261 +64,50 @@ typedef struct iree_hal_hip_device_t { // Parameters used to control device behavior. iree_hal_hip_device_params_t params; - hipCtx_t hip_context; - hipDevice_t hip_device; - // TODO: Support multiple device streams. - // The hipStream_t used to issue device kernels and allocations. - hipStream_t hip_dispatch_stream; - - iree_hal_stream_tracing_context_t* tracing_context; - iree_allocator_t host_allocator; // Host/device event pools, used for backing semaphore timepoints. iree_event_pool_t* host_event_pool; - iree_hal_hip_event_pool_t* device_event_pool; - // Timepoint pools, shared by various semaphores. - iree_hal_hip_timepoint_pool_t* timepoint_pool; - - // A queue to order device workloads and relase to the GPU when constraints - // are met. It buffers submissions and allocations internally before they - // are ready. This queue couples with HAL semaphores backed by iree_event_t - // and hipEvent_t objects. - iree_hal_deferred_work_queue_t* work_queue; // Device memory pools and allocators. bool supports_memory_pools; - iree_hal_hip_memory_pools_t memory_pools; - iree_hal_allocator_t* device_allocator; // Optional provider used for creating/configuring collective channels. 
iree_hal_channel_provider_t* channel_provider; -} iree_hal_hip_device_t; - -static iree_hal_hip_device_t* iree_hal_hip_device_cast( - iree_hal_device_t* base_value); - -static const iree_hal_device_vtable_t iree_hal_hip_device_vtable; -static const iree_hal_deferred_work_queue_device_interface_vtable_t - iree_hal_hip_deferred_work_queue_device_interface_vtable; - -// We put a hipEvent_t into a iree_hal_deferred_work_queue_native_event_t. -static_assert(sizeof(hipEvent_t) <= - sizeof(iree_hal_deferred_work_queue_native_event_t), - "Unexpected event size"); -typedef struct iree_hal_hip_deferred_work_queue_device_interface_t { - iree_hal_deferred_work_queue_device_interface_t base; - iree_hal_device_t* device; - hipDevice_t hip_device; - hipCtx_t hip_context; - hipStream_t dispatch_hip_stream; - iree_allocator_t host_allocator; - const iree_hal_hip_dynamic_symbols_t* hip_symbols; -} iree_hal_hip_deferred_work_queue_device_interface_t; - -static void iree_hal_hip_deferred_work_queue_device_interface_destroy( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - iree_allocator_free(device_interface->host_allocator, device_interface); -} - -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_bind_to_thread( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - return IREE_HIP_RESULT_TO_STATUS( - device_interface->hip_symbols, - hipCtxSetCurrent(device_interface->hip_context), "hipCtxSetCurrent"); -} - -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_wait_native_event( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_deferred_work_queue_native_event_t event) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - return IREE_HIP_RESULT_TO_STATUS( - device_interface->hip_symbols, - hipStreamWaitEvent(device_interface->dispatch_hip_stream, - (hipEvent_t)event, 0), - "hipStreamWaitEvent"); -} - -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_create_native_event( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_deferred_work_queue_native_event_t* out_event) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - return IREE_HIP_RESULT_TO_STATUS(device_interface->hip_symbols, - hipEventCreate((hipEvent_t*)out_event), - "hipEventCreate"); -} -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_record_native_event( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_deferred_work_queue_native_event_t event) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - return IREE_HIP_RESULT_TO_STATUS( - device_interface->hip_symbols, - hipEventRecord((hipEvent_t)event, device_interface->dispatch_hip_stream), - "hipEventRecord"); -} - -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_synchronize_native_event( - 
iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_deferred_work_queue_native_event_t event) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - return IREE_HIP_RESULT_TO_STATUS(device_interface->hip_symbols, - hipEventSynchronize((hipEvent_t)event)); -} - -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_destroy_native_event( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_deferred_work_queue_native_event_t event) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - return IREE_HIP_RESULT_TO_STATUS(device_interface->hip_symbols, - hipEventDestroy((hipEvent_t)event)); -} - -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event( - iree_hal_deferred_work_queue_device_interface_t* device_interface, - struct iree_hal_semaphore_t* semaphore, uint64_t value, - iree_hal_deferred_work_queue_native_event_t* out_event) { - return iree_hal_hip_event_semaphore_acquire_timepoint_device_signal( - semaphore, value, (hipEvent_t*)out_event); -} -static bool -iree_hal_hip_deferred_work_queue_device_interface_acquire_host_wait_event( - iree_hal_deferred_work_queue_device_interface_t* device_interface, - struct iree_hal_semaphore_t* semaphore, uint64_t value, - iree_hal_deferred_work_queue_host_device_event_t* out_event) { - return iree_hal_hip_semaphore_acquire_event_host_wait( - semaphore, value, (iree_hal_hip_event_t**)out_event); -} + iree_hal_allocator_t* device_allocator; -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_device_wait_on_host_event( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_deferred_work_queue_host_device_event_t wait_event) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - return IREE_HIP_RESULT_TO_STATUS( - device_interface->hip_symbols, - hipStreamWaitEvent( - device_interface->dispatch_hip_stream, - iree_hal_hip_event_handle((iree_hal_hip_event_t*)wait_event), 0), - "hipStreamWaitEvent"); -} + iree_hal_hip_cleanup_thread_t* cleanup_thread; -static void -iree_hal_hip_deferred_work_queue_device_interface_release_wait_event( - iree_hal_deferred_work_queue_device_interface_t* device_interface, - iree_hal_deferred_work_queue_host_device_event_t wait_event) { - iree_hal_hip_event_release(wait_event); -} + iree_hal_hip_cleanup_thread_t* buffer_free_thread; -static iree_hal_deferred_work_queue_native_event_t -iree_hal_hip_deferred_work_queue_device_interface_native_event_from_wait_event( - iree_hal_deferred_work_queue_device_interface_t* device_interface, - iree_hal_deferred_work_queue_host_device_event_t event) { - iree_hal_hip_event_t* wait_event = (iree_hal_hip_event_t*)event; - return iree_hal_hip_event_handle(wait_event); -} + iree_host_size_t device_count; -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_create_stream_command_buffer( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t categories, - iree_hal_command_buffer_t** out) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - 
(iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); - return iree_hal_hip_device_create_stream_command_buffer( - device_interface->device, mode, categories, 0, out); -} - -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_submit_command_buffer( - iree_hal_deferred_work_queue_device_interface_t* device_interface, - iree_hal_command_buffer_t* command_buffer) { - iree_hal_hip_deferred_work_queue_device_interface_t* table = - (iree_hal_hip_deferred_work_queue_device_interface_t*)(device_interface); - iree_status_t status = iree_ok_status(); - if (iree_hal_hip_stream_command_buffer_isa(command_buffer)) { - // Stream command buffer so nothing to do but notify it was submitted. - iree_hal_hip_stream_notify_submitted_commands(command_buffer); - } else { - hipGraphExec_t exec = - iree_hal_hip_graph_command_buffer_handle(command_buffer); - status = IREE_HIP_RESULT_TO_STATUS( - table->hip_symbols, hipGraphLaunch(exec, table->dispatch_hip_stream)); - if (IREE_LIKELY(iree_status_is_ok(status))) { - iree_hal_hip_graph_tracing_notify_submitted_commands(command_buffer); - } - } - return status; -} + iree_hal_hip_per_device_info_t devices[]; +} iree_hal_hip_device_t; -static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_async_alloc( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_buffer_t* buffer) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*) - base_device_interface; - iree_hal_hip_device_t* device = - iree_hal_hip_device_cast(device_interface->device); - if (device->supports_memory_pools) { - return iree_hal_hip_memory_pools_allocate_pointer( - &device->memory_pools, buffer, device->hip_dispatch_stream, - iree_hal_buffer_allocation_size(buffer)); - } +static iree_hal_hip_device_t* iree_hal_hip_device_cast( + iree_hal_device_t* base_value); - return iree_hal_hip_allocator_alloc_async( - iree_hal_device_allocator(device_interface->device), - device->hip_dispatch_stream, buffer); -} +static const iree_hal_device_vtable_t iree_hal_hip_device_vtable; -// Asynchronously frees a buffer. 
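For reference while reading the new per-device loops in the rest of this file: iree_hal_hip_per_device_info_t lives in the added per_device_information.h, which is not reproduced in this excerpt. Reconstructed from how this file uses it, its shape is roughly the following sketch (field order and any omitted members are guesses, not the header verbatim):

    typedef struct iree_hal_hip_per_device_info_t {
      hipDevice_t hip_device;           // physical device handle
      hipCtx_t hip_context;             // primary context, retained at create
      hipStream_t hip_dispatch_stream;  // per-device dispatch stream
      iree_hal_stream_tracing_context_t* tracing_context;
      iree_hal_hip_memory_pools_t memory_pools;
      iree_hal_hip_event_pool_t* device_event_pool;
      iree_hal_hip_dispatch_thread_t* dispatch_thread;
    } iree_hal_hip_per_device_info_t;

Throughout the new code, bit i of a queue affinity mask selects device->devices[i].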
-static iree_status_t -iree_hal_hip_deferred_work_queue_device_interface_async_dealloc( - iree_hal_deferred_work_queue_device_interface_t* base_device_interface, - iree_hal_buffer_t* buffer) { - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = - (iree_hal_hip_deferred_work_queue_device_interface_t*) - base_device_interface; - iree_hal_hip_device_t* device = - iree_hal_hip_device_cast(device_interface->device); - if (device->supports_memory_pools) { - return iree_hal_hip_memory_pools_deallocate( - &device->memory_pools, device->hip_dispatch_stream, buffer); - } - return iree_hal_hip_allocator_free_async( - iree_hal_device_allocator(device_interface->device), - device->hip_dispatch_stream, buffer); -} +static iree_status_t iree_hal_hip_device_create_command_buffer_internal( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + iree_hip_device_commandbuffer_type_t type, + iree_hal_command_buffer_t** out_command_buffer); typedef struct iree_hal_hip_tracing_device_interface_t { iree_hal_stream_tracing_device_interface_t base; - hipDevice_t device; - hipCtx_t context; - hipStream_t dispatch_stream; + iree_hal_hip_per_device_info_t* device_context; iree_allocator_t host_allocator; const iree_hal_hip_dynamic_symbols_t* hip_symbols; } iree_hal_hip_tracing_device_interface_t; static const iree_hal_stream_tracing_device_interface_vtable_t iree_hal_hip_tracing_device_interface_vtable_t; -void iree_hal_hip_tracing_device_interface_destroy( +static void iree_hal_hip_tracing_device_interface_destroy( iree_hal_stream_tracing_device_interface_t* base_device_interface) { iree_hal_hip_tracing_device_interface_t* device_interface = (iree_hal_hip_tracing_device_interface_t*)base_device_interface; @@ -320,38 +115,39 @@ void iree_hal_hip_tracing_device_interface_destroy( iree_allocator_free(device_interface->host_allocator, device_interface); } -iree_status_t iree_hal_hip_tracing_device_interface_synchronize_native_event( +static iree_status_t +iree_hal_hip_tracing_device_interface_synchronize_native_event( iree_hal_stream_tracing_device_interface_t* base_device_interface, iree_hal_stream_tracing_native_event_t base_event) { iree_hal_hip_tracing_device_interface_t* device_interface = (iree_hal_hip_tracing_device_interface_t*)base_device_interface; - return IREE_HIP_RESULT_TO_STATUS(device_interface->hip_symbols, - hipEventSynchronize((hipEvent_t)base_event)); + return IREE_HIP_CALL_TO_STATUS(device_interface->hip_symbols, + hipEventSynchronize((hipEvent_t)base_event)); } -iree_status_t iree_hal_hip_tracing_device_interface_create_native_event( +static iree_status_t iree_hal_hip_tracing_device_interface_create_native_event( iree_hal_stream_tracing_device_interface_t* base_device_interface, iree_hal_stream_tracing_native_event_t* base_event) { iree_hal_hip_tracing_device_interface_t* device_interface = (iree_hal_hip_tracing_device_interface_t*)base_device_interface; - return IREE_HIP_RESULT_TO_STATUS( + return IREE_HIP_CALL_TO_STATUS( device_interface->hip_symbols, hipEventCreateWithFlags((hipEvent_t*)base_event, hipEventDefault)); } -iree_status_t iree_hal_hip_tracing_device_interface_query_native_event( +static iree_status_t iree_hal_hip_tracing_device_interface_query_native_event( iree_hal_stream_tracing_device_interface_t* base_device_interface, iree_hal_stream_tracing_native_event_t base_event) { iree_hal_hip_tracing_device_interface_t* 
device_interface = (iree_hal_hip_tracing_device_interface_t*)base_device_interface; - return IREE_HIP_RESULT_TO_STATUS(device_interface->hip_symbols, - hipEventQuery((hipEvent_t)base_event)); + return IREE_HIP_CALL_TO_STATUS(device_interface->hip_symbols, + hipEventQuery((hipEvent_t)base_event)); } -void iree_hal_hip_tracing_device_interface_event_elapsed_time( +static void iree_hal_hip_tracing_device_interface_event_elapsed_time( iree_hal_stream_tracing_device_interface_t* base_device_interface, float* relative_millis, iree_hal_stream_tracing_native_event_t start_event, iree_hal_stream_tracing_native_event_t end_event) { @@ -364,7 +160,7 @@ void iree_hal_hip_tracing_device_interface_event_elapsed_time( (hipEvent_t)end_event)); } -void iree_hal_hip_tracing_device_interface_destroy_native_event( +static void iree_hal_hip_tracing_device_interface_destroy_native_event( iree_hal_stream_tracing_device_interface_t* base_device_interface, iree_hal_stream_tracing_native_event_t base_event) { iree_hal_hip_tracing_device_interface_t* device_interface = @@ -374,19 +170,21 @@ void iree_hal_hip_tracing_device_interface_destroy_native_event( hipEventDestroy((hipEvent_t)base_event)); } -iree_status_t iree_hal_hip_tracing_device_interface_record_native_event( +static iree_status_t iree_hal_hip_tracing_device_interface_record_native_event( iree_hal_stream_tracing_device_interface_t* base_device_interface, iree_hal_stream_tracing_native_event_t base_event) { iree_hal_hip_tracing_device_interface_t* device_interface = (iree_hal_hip_tracing_device_interface_t*)base_device_interface; - return IREE_HIP_RESULT_TO_STATUS( + return IREE_HIP_CALL_TO_STATUS( device_interface->hip_symbols, - hipEventRecord((hipEvent_t)base_event, - (hipStream_t)device_interface->dispatch_stream)); + hipEventRecord( + (hipEvent_t)base_event, + (hipStream_t)device_interface->device_context->hip_dispatch_stream)); } -iree_status_t iree_hal_hip_tracing_device_interface_add_graph_event_record_node( +static iree_status_t +iree_hal_hip_tracing_device_interface_add_graph_event_record_node( iree_hal_stream_tracing_device_interface_t* base_device_interface, iree_hal_stream_tracing_native_graph_node_t* out_node, iree_hal_stream_tracing_native_graph_t graph, @@ -396,7 +194,7 @@ iree_status_t iree_hal_hip_tracing_device_interface_add_graph_event_record_node( iree_hal_hip_tracing_device_interface_t* device_interface = (iree_hal_hip_tracing_device_interface_t*)base_device_interface; - return IREE_HIP_RESULT_TO_STATUS( + return IREE_HIP_CALL_TO_STATUS( device_interface->hip_symbols, hipGraphAddEventRecordNode((hipGraphNode_t*)out_node, (hipGraph_t)graph, (hipGraphNode_t*)dependency_nodes, @@ -439,123 +237,141 @@ static iree_status_t iree_hal_hip_device_check_params( return iree_ok_status(); } -static iree_status_t iree_hal_hip_device_create_internal( +static iree_hal_hip_device_topology_t iree_hal_hip_device_make_topology( + iree_hal_hip_device_t* device) { + iree_hal_hip_device_topology_t topology = {.count = device->device_count, + .devices = device->devices}; + return topology; +} + +static iree_status_t iree_hal_hip_device_initialize_internal( iree_hal_driver_t* driver, iree_string_view_t identifier, - const iree_hal_hip_device_params_t* params, hipDevice_t hip_device, - hipStream_t dispatch_stream, hipCtx_t context, + const iree_hal_hip_device_params_t* params, iree_hal_hip_device_t* device, const iree_hal_hip_dynamic_symbols_t* symbols, const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, - iree_allocator_t host_allocator, 
iree_hal_device_t** out_device) { - iree_hal_hip_device_t* device = NULL; - iree_host_size_t total_size = iree_sizeof_struct(*device) + identifier.size; - IREE_RETURN_IF_ERROR( - iree_allocator_malloc(host_allocator, total_size, (void**)&device)); - - iree_hal_resource_initialize(&iree_hal_hip_device_vtable, &device->resource); - iree_string_view_append_to_buffer( - identifier, &device->identifier, - (char*)device + iree_sizeof_struct(*device)); - iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, - &device->block_pool); - device->driver = driver; - iree_hal_driver_retain(device->driver); - device->hip_symbols = symbols; - device->nccl_symbols = nccl_symbols; - device->params = *params; - device->hip_context = context; - device->hip_device = hip_device; - device->hip_dispatch_stream = dispatch_stream; - device->host_allocator = host_allocator; + iree_allocator_t host_allocator) { + IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_hip_deferred_work_queue_device_interface_t* device_interface; - iree_status_t status = iree_allocator_malloc( - host_allocator, - sizeof(iree_hal_hip_deferred_work_queue_device_interface_t), - (void**)&device_interface); - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_device_release((iree_hal_device_t*)device); - return status; - } - device_interface->base.vtable = - &iree_hal_hip_deferred_work_queue_device_interface_vtable; - device_interface->hip_context = context; - device_interface->hip_symbols = symbols; - device_interface->device = (iree_hal_device_t*)device; - device_interface->hip_device = hip_device; - device_interface->dispatch_hip_stream = dispatch_stream; - device_interface->host_allocator = host_allocator; - status = iree_hal_deferred_work_queue_create( - (iree_hal_deferred_work_queue_device_interface_t*)device_interface, - &device->block_pool, host_allocator, &device->work_queue); - - // Enable tracing for the (currently only) stream - no-op if disabled. - if (iree_status_is_ok(status) && device->params.stream_tracing) { + if (device->params.stream_tracing) { if (device->params.stream_tracing >= IREE_HAL_STREAM_TRACING_VERBOSITY_MAX || device->params.stream_tracing < IREE_HAL_STREAM_TRACING_VERBOSITY_OFF) { + IREE_TRACE_ZONE_END(z0); return iree_make_status( IREE_STATUS_INVALID_ARGUMENT, "invalid stream_tracing argument: expected to be between %d and %d", IREE_HAL_STREAM_TRACING_VERBOSITY_OFF, IREE_HAL_STREAM_TRACING_VERBOSITY_MAX); } + } - iree_hal_hip_tracing_device_interface_t* tracing_device_interface = NULL; - status = iree_allocator_malloc( - host_allocator, sizeof(iree_hal_hip_tracing_device_interface_t), - (void**)&tracing_device_interface); + const iree_host_size_t identifier_offset = + sizeof(*device) + + sizeof(iree_hal_hip_per_device_info_t) * device->device_count; - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_device_release((iree_hal_device_t*)device); - return status; - } + iree_hal_resource_initialize(&iree_hal_hip_device_vtable, &device->resource); + iree_string_view_append_to_buffer(identifier, &device->identifier, + (char*)device + identifier_offset); + iree_arena_block_pool_initialize(params->arena_block_size, host_allocator, + &device->block_pool); + device->driver = driver; + iree_hal_driver_retain(device->driver); + device->hip_symbols = symbols; + device->nccl_symbols = nccl_symbols; + device->params = *params; + + device->host_allocator = host_allocator; + iree_status_t status = iree_ok_status(); + // Enable tracing for each of the streams - no-op if disabled. 
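Each per-device step in this initializer brackets its HIP calls between hipCtxPushCurrent and hipCtxPopCurrent on that device's primary context; the queue paths later in this file additionally join the pop status into the running status so a failed pop is not dropped. A minimal sketch of the discipline (the helper name and callback shape are illustrative, not part of this patch):

    // Illustrative only: run one HIP-facing step against a specific physical
    // device's primary context. IREE_HIP_CALL_TO_STATUS converts a hipError_t
    // into an iree_status_t via the loaded dynamic symbol table.
    static iree_status_t iree_hal_hip_with_context_sketch(
        const iree_hal_hip_dynamic_symbols_t* syms, hipCtx_t context,
        iree_status_t (*step)(void* user_data), void* user_data) {
      IREE_RETURN_IF_ERROR(
          IREE_HIP_CALL_TO_STATUS(syms, hipCtxPushCurrent(context)));
      iree_status_t status = step(user_data);
      // Pop unconditionally and keep whichever error happened first.
      return iree_status_join(
          status, IREE_HIP_CALL_TO_STATUS(syms, hipCtxPopCurrent(NULL)));
    }

The tracing loop below follows this pattern inline.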
+ if (device->params.stream_tracing) { + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + iree_hal_hip_tracing_device_interface_t* tracing_device_interface = NULL; + status = iree_allocator_malloc(host_allocator, + sizeof(*tracing_device_interface), + (void**)&tracing_device_interface); + + if (!iree_status_is_ok(status)) { + break; + } - tracing_device_interface->base.vtable = - &iree_hal_hip_tracing_device_interface_vtable_t; - tracing_device_interface->context = context; - tracing_device_interface->device = hip_device; - tracing_device_interface->dispatch_stream = dispatch_stream; - tracing_device_interface->host_allocator = host_allocator; - tracing_device_interface->hip_symbols = symbols; + tracing_device_interface->base.vtable = + &iree_hal_hip_tracing_device_interface_vtable_t; + tracing_device_interface->device_context = &device->devices[i]; + tracing_device_interface->host_allocator = host_allocator; + tracing_device_interface->hip_symbols = symbols; - status = iree_hal_stream_tracing_context_allocate( - (iree_hal_stream_tracing_device_interface_t*)tracing_device_interface, - device->identifier, device->params.stream_tracing, &device->block_pool, - host_allocator, &device->tracing_context); + status = IREE_HIP_CALL_TO_STATUS( + symbols, hipCtxPushCurrent(device->devices[i].hip_context)); + if (!iree_status_is_ok(status)) { + break; + } + status = iree_hal_stream_tracing_context_allocate( + (iree_hal_stream_tracing_device_interface_t*)tracing_device_interface, + device->identifier, device->params.stream_tracing, + &device->block_pool, host_allocator, + &device->devices[i].tracing_context); + status = IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL)); + if (!iree_status_is_ok(status)) { + break; + } + } } // Memory pool support is conditional. if (iree_status_is_ok(status) && params->async_allocations) { - int supports_memory_pools = 0; - status = IREE_HIP_RESULT_TO_STATUS( - symbols, - hipDeviceGetAttribute(&supports_memory_pools, - hipDeviceAttributeMemoryPoolsSupported, - hip_device), - "hipDeviceGetAttribute"); - device->supports_memory_pools = supports_memory_pools != 0; + device->supports_memory_pools = true; + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + int supports_memory_pools = 0; + status = IREE_HIP_CALL_TO_STATUS( + symbols, + hipDeviceGetAttribute(&supports_memory_pools, + hipDeviceAttributeMemoryPoolsSupported, + device->devices[i].hip_device), + "hipDeviceGetAttribute"); + device->supports_memory_pools &= (supports_memory_pools != 0); + } } // Create memory pools first so that we can share them with the allocator. 
if (iree_status_is_ok(status) && device->supports_memory_pools) { - status = iree_hal_hip_memory_pools_initialize( - (iree_hal_device_t*)device, symbols, hip_device, context, - ¶ms->memory_pools, host_allocator, &device->memory_pools); + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + device->supports_memory_pools = false; + status = iree_hal_hip_memory_pools_initialize( + (iree_hal_device_t*)device, symbols, device->devices[i].hip_device, + ¶ms->memory_pools, host_allocator, + &device->devices[i].memory_pools); + } } + status = iree_hal_hip_allocator_create( + (iree_hal_device_t*)device, symbols, + iree_hal_hip_device_make_topology(device), device->supports_memory_pools, + host_allocator, &device->device_allocator); + if (iree_status_is_ok(status)) { - status = iree_hal_hip_allocator_create( - (iree_hal_device_t*)device, symbols, hip_device, context, - dispatch_stream, - device->supports_memory_pools ? &device->memory_pools : NULL, - host_allocator, &device->device_allocator); + status = iree_hal_hip_cleanup_thread_initialize(symbols, host_allocator, + &device->cleanup_thread); } if (iree_status_is_ok(status)) { - *out_device = (iree_hal_device_t*)device; - } else { + status = iree_hal_hip_cleanup_thread_initialize( + symbols, host_allocator, &device->buffer_free_thread); + } + + if (iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + status = iree_hal_hip_dispatch_thread_initialize( + host_allocator, &device->devices[i].dispatch_thread); + if (!iree_status_is_ok(status)) { + break; + } + } + } + + if (!iree_status_is_ok(status)) { iree_hal_device_release((iree_hal_device_t*)device); } + IREE_TRACE_ZONE_END(z0); return status; } @@ -563,7 +379,8 @@ iree_status_t iree_hal_hip_device_create( iree_hal_driver_t* driver, iree_string_view_t identifier, const iree_hal_hip_device_params_t* params, const iree_hal_hip_dynamic_symbols_t* symbols, - const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, hipDevice_t device, + const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, + iree_host_size_t device_count, hipDevice_t* devices, iree_allocator_t host_allocator, iree_hal_device_t** out_device) { IREE_ASSERT_ARGUMENT(driver); IREE_ASSERT_ARGUMENT(params); @@ -571,35 +388,53 @@ iree_status_t iree_hal_hip_device_create( IREE_ASSERT_ARGUMENT(out_device); IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hip_device_t* device = NULL; + const iree_host_size_t total_device_size = + sizeof(*device) + sizeof(device->devices[0]) * device_count + + identifier.size; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, total_device_size, + (void**)&device)); + device->device_count = device_count; + iree_status_t status = iree_hal_hip_device_check_params(params); // Get the main context for the device. - hipCtx_t context = NULL; - if (iree_status_is_ok(status)) { - status = IREE_HIP_RESULT_TO_STATUS( - symbols, hipDevicePrimaryCtxRetain(&context, device)); - } - if (iree_status_is_ok(status)) { - status = IREE_HIP_RESULT_TO_STATUS(symbols, hipCtxSetCurrent(context)); - } - - // Create the default dispatch stream for the device. 
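A note on the sizing arithmetic used when this struct is allocated in iree_hal_hip_device_create: iree_hal_hip_device_t now ends in a flexible array member, so the identifier string can no longer be parked at iree_sizeof_struct(*device) as the old code did; it is appended after the per-device table instead. Schematically (a restatement of the arithmetic in this patch, not new behavior):

    // One allocation backs the whole logical device:
    //   [ iree_hal_hip_device_t | devices[0..device_count) | identifier chars ]
    const iree_host_size_t total_device_size =
        sizeof(*device) +
        sizeof(device->devices[0]) * device_count +  // trailing array
        identifier.size;                             // appended name bytes
    // identifier_offset == sizeof(*device) +
    //     sizeof(iree_hal_hip_per_device_info_t) * device_count,
    // which is where iree_string_view_append_to_buffer() writes the name.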
- hipStream_t dispatch_stream = NULL; - if (iree_status_is_ok(status)) { - status = IREE_HIP_RESULT_TO_STATUS( + for (iree_host_size_t i = 0; i < device_count && iree_status_is_ok(status); + ++i) { + device->devices[i].hip_device = devices[i]; + status = IREE_HIP_CALL_TO_STATUS( symbols, - hipStreamCreateWithFlags(&dispatch_stream, hipStreamNonBlocking)); + hipDevicePrimaryCtxRetain(&device->devices[i].hip_context, devices[i])); + if (iree_status_is_ok(status)) { + status = IREE_HIP_CALL_TO_STATUS( + symbols, hipCtxSetCurrent(device->devices[i].hip_context)); + } + + // Create the default dispatch stream for the device. + if (iree_status_is_ok(status)) { + status = IREE_HIP_CALL_TO_STATUS( + symbols, + hipStreamCreateWithFlags(&device->devices[i].hip_dispatch_stream, + hipStreamNonBlocking)); + } + + if (iree_status_is_ok(status)) { + for (iree_host_size_t j = 0; + j < device_count && iree_status_is_ok(status); ++j) { + if (i == j) { + continue; + } + status = IREE_HIP_CALL_TO_STATUS( + symbols, hipDeviceEnablePeerAccess(devices[j], 0)); + } + } } if (iree_status_is_ok(status)) { - status = iree_hal_hip_device_create_internal( - driver, identifier, params, device, dispatch_stream, context, symbols, - nccl_symbols, host_allocator, out_device); - } else { - if (dispatch_stream) symbols->hipStreamDestroy(dispatch_stream); - // NOTE: This function return hipSuccess though doesn't release the - // primaryCtx by design on HIP/HCC path. - if (context) symbols->hipDevicePrimaryCtxRelease(device); + status = iree_hal_hip_device_initialize_internal( + driver, identifier, params, device, symbols, nccl_symbols, + host_allocator); } iree_event_pool_t* host_event_pool = NULL; @@ -608,43 +443,26 @@ iree_status_t iree_hal_hip_device_create( host_allocator, &host_event_pool); } - iree_hal_hip_event_pool_t* device_event_pool = NULL; - if (iree_status_is_ok(status)) { + for (iree_host_size_t i = 0; i < device_count && iree_status_is_ok(status); + ++i) { status = iree_hal_hip_event_pool_allocate( - symbols, context, params->event_pool_capacity, host_allocator, - &device_event_pool); + symbols, params->event_pool_capacity, host_allocator, + device->devices[i].hip_context, &device->devices[i].device_event_pool); } - iree_hal_hip_timepoint_pool_t* timepoint_pool = NULL; if (iree_status_is_ok(status)) { - status = iree_hal_hip_timepoint_pool_allocate( - host_event_pool, device_event_pool, params->event_pool_capacity, - host_allocator, &timepoint_pool); - } - - if (iree_status_is_ok(status)) { - iree_hal_hip_device_t* hip_device = iree_hal_hip_device_cast(*out_device); - hip_device->host_event_pool = host_event_pool; - hip_device->device_event_pool = device_event_pool; - hip_device->timepoint_pool = timepoint_pool; + device->host_event_pool = host_event_pool; + *out_device = (iree_hal_device_t*)device; } else { - // Release resources we have accquired after HAL device creation. - if (timepoint_pool) iree_hal_hip_timepoint_pool_free(timepoint_pool); - if (device_event_pool) iree_hal_hip_event_pool_release(device_event_pool); - if (host_event_pool) iree_event_pool_free(host_event_pool); // Release other resources via the HAL device. 
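Just above, each physical device also received its own event pool, allocated against that device's context. Most of the new synchronization code drives these pools with the same three calls; a sketch of the pattern (|device_ordinal| is an illustrative index and error handling is elided):

    iree_hal_hip_event_t* event = NULL;
    // Take one pooled, reference-counted event for this physical device.
    IREE_RETURN_IF_ERROR(iree_hal_hip_event_pool_acquire(
        device->devices[device_ordinal].device_event_pool, 1, &event));
    // Record it on that device's dispatch stream; any other stream (or the
    // cleanup thread) may now wait on iree_hal_hip_event_handle(event).
    IREE_RETURN_IF_ERROR(IREE_HIP_CALL_TO_STATUS(
        device->hip_symbols,
        hipEventRecord(iree_hal_hip_event_handle(event),
                       device->devices[device_ordinal].hip_dispatch_stream)));
    // Once every consumer holding a reference is done with it, releasing the
    // event returns it to the pool.
    iree_hal_hip_event_release(event);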
- iree_hal_device_release(*out_device); + iree_hal_device_release((iree_hal_device_t*)device); + device = NULL; } IREE_TRACE_ZONE_END(z0); return status; } -hipCtx_t iree_hal_hip_device_context(iree_hal_device_t* base_device) { - iree_hal_hip_device_t* device = iree_hal_hip_device_cast_unsafe(base_device); - return device->hip_context; -} - const iree_hal_hip_dynamic_symbols_t* iree_hal_hip_device_dynamic_symbols( iree_hal_device_t* base_device) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast_unsafe(base_device); @@ -654,11 +472,12 @@ const iree_hal_hip_dynamic_symbols_t* iree_hal_hip_device_dynamic_symbols( static void iree_hal_hip_device_destroy(iree_hal_device_t* base_device) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); - const iree_hal_hip_dynamic_symbols_t* symbols = device->hip_symbols; IREE_TRACE_ZONE_BEGIN(z0); - // Destroy the pending workload queue. - iree_hal_deferred_work_queue_destroy(device->work_queue); + const iree_hal_hip_dynamic_symbols_t* symbols = device->hip_symbols; + + iree_hal_hip_cleanup_thread_deinitialize(device->cleanup_thread); + iree_hal_hip_cleanup_thread_deinitialize(device->buffer_free_thread); // There should be no more buffers live that use the allocator. iree_hal_allocator_release(device->device_allocator); @@ -666,26 +485,29 @@ static void iree_hal_hip_device_destroy(iree_hal_device_t* base_device) { // Buffers may have been retaining collective resources. iree_hal_channel_provider_release(device->channel_provider); - // Destroy memory pools that hold on to reserved memory. - iree_hal_hip_memory_pools_deinitialize(&device->memory_pools); - - iree_hal_stream_tracing_context_free(device->tracing_context); - - // Destroy various pools for synchronization. - if (device->timepoint_pool) { - iree_hal_hip_timepoint_pool_free(device->timepoint_pool); + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + iree_hal_hip_memory_pools_deinitialize(&device->devices[i].memory_pools); + iree_hal_stream_tracing_context_free(device->devices[i].tracing_context); } - if (device->device_event_pool) { - iree_hal_hip_event_pool_release(device->device_event_pool); + + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + iree_hal_hip_event_pool_release(device->devices[i].device_event_pool); } if (device->host_event_pool) iree_event_pool_free(device->host_event_pool); - IREE_HIP_IGNORE_ERROR(symbols, hipStreamDestroy(device->hip_dispatch_stream)); + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + IREE_HIP_IGNORE_ERROR( + symbols, hipStreamDestroy(device->devices[i].hip_dispatch_stream)); + // NOTE: hipDevicePrimaryCtxRelease returns hipSuccess though by design it + // doesn't release the primary context on the HIP/HCC path. + IREE_HIP_IGNORE_ERROR( + symbols, hipDevicePrimaryCtxRelease(device->devices[i].hip_device)); + } - // NOTE: This function return hipSuccess though doesn't release the - // primaryCtx by design on HIP/HCC path.
- IREE_HIP_IGNORE_ERROR(symbols, - hipDevicePrimaryCtxRelease(device->hip_device)); + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + iree_hal_hip_dispatch_thread_deinitialize( + device->devices[i].dispatch_thread); + } iree_arena_block_pool_deinitialize(&device->block_pool); @@ -733,13 +555,13 @@ static void iree_hal_hip_replace_channel_provider( static iree_status_t iree_hal_hip_device_trim(iree_hal_device_t* base_device) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); iree_arena_block_pool_trim(&device->block_pool); IREE_RETURN_IF_ERROR(iree_hal_allocator_trim(device->device_allocator)); if (device->supports_memory_pools) { - IREE_RETURN_IF_ERROR(iree_hal_hip_memory_pools_trim( - &device->memory_pools, &device->params.memory_pools)); + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + IREE_RETURN_IF_ERROR(iree_hal_hip_memory_pools_trim( + &device->devices[i].memory_pools, &device->params.memory_pools)); + } } return iree_ok_status(); } @@ -747,12 +569,13 @@ static iree_status_t iree_hal_hip_device_trim(iree_hal_device_t* base_device) { static iree_status_t iree_hal_hip_device_query_attribute( iree_hal_hip_device_t* device, hipDeviceAttribute_t attribute, int64_t* out_value) { - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + IREE_ASSERT_ARGUMENT(out_value); + + *out_value = 0; int value = 0; IREE_HIP_RETURN_IF_ERROR( device->hip_symbols, - hipDeviceGetAttribute(&value, attribute, device->hip_device), + hipDeviceGetAttribute(&value, attribute, device->devices[0].hip_device), "hipDeviceGetAttribute"); *out_value = value; return iree_ok_status(); @@ -762,8 +585,6 @@ static iree_status_t iree_hal_hip_device_query_i64( iree_hal_device_t* base_device, iree_string_view_t category, iree_string_view_t key, int64_t* out_value) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); *out_value = 0; if (iree_string_view_equal(category, IREE_SV("hal.device.id"))) { @@ -777,6 +598,13 @@ static iree_status_t iree_hal_hip_device_query_i64( return iree_ok_status(); } + if (iree_string_view_equal(category, IREE_SV("hal.device"))) { + if (iree_string_view_equal(key, IREE_SV("concurrency"))) { + *out_value = device->device_count; + return iree_ok_status(); + } + } + return iree_make_status( IREE_STATUS_NOT_FOUND, "unknown device configuration key value '%.*s :: %.*s'", @@ -787,9 +615,6 @@ static iree_status_t iree_hal_hip_device_create_channel( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); - if (!device->nccl_symbols || !device->nccl_symbols->dylib) { return iree_make_status( IREE_STATUS_UNAVAILABLE, @@ -865,25 +690,107 @@ static iree_status_t iree_hal_hip_device_create_channel( // TODO: when we support multiple logical devices we'll want to pass in the // context of the device mapped to the queue_affinity. For now since this // implementation only supports one device we pass in the only one we have. 
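The affinity-to-device mapping the TODO above asks for already has a fixed convention in this patch: bit i of a queue affinity selects device->devices[i], and the set bits are walked with count-trailing-zeros. A condensed sketch of the iteration (the loop body is a placeholder; the real loops appear in the command-buffer path below):

    // Visit each device ordinal selected by |queue_affinity|.
    queue_affinity &= ~(IREE_HAL_QUEUE_AFFINITY_ANY << device->device_count);
    iree_host_size_t device_ordinal = 0;
    while (queue_affinity) {
      int offset = iree_math_count_trailing_zeros_u64(queue_affinity);
      device_ordinal += offset;
      // ... operate on device->devices[device_ordinal] here ...
      queue_affinity >>= offset + 1;  // consume this bit and those below it
      ++device_ordinal;
    }

For example, an affinity of 0b0101 visits ordinals 0 and 2.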
- return iree_hal_hip_nccl_channel_create( + iree_status_t status = iree_hal_hip_nccl_channel_create( device->hip_symbols, device->nccl_symbols, &id, params.rank, params.count, device->host_allocator, out_channel); + return status; } -iree_status_t iree_hal_hip_device_create_stream_command_buffer( +static iree_status_t iree_hal_hip_device_create_command_buffer_internal( iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, - iree_host_size_t binding_capacity, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + iree_hip_device_commandbuffer_type_t type, iree_hal_command_buffer_t** out_command_buffer) { + IREE_TRACE_ZONE_BEGIN(z0); + + *out_command_buffer = NULL; + iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); - - return iree_hal_hip_stream_command_buffer_create( - iree_hal_device_allocator(base_device), device->hip_symbols, - device->nccl_symbols, device->hip_context, device->tracing_context, mode, - command_categories, binding_capacity, device->hip_dispatch_stream, - &device->block_pool, device->host_allocator, out_command_buffer); + + iree_hal_command_buffer_t* buffers[IREE_HAL_MAX_QUEUES]; + memset(buffers, 0x00, sizeof(buffers[0]) * IREE_HAL_MAX_QUEUES); + if (queue_affinity == 0) { + queue_affinity = IREE_HAL_QUEUE_AFFINITY_ANY; + } + queue_affinity = + queue_affinity & ~(IREE_HAL_QUEUE_AFFINITY_ANY << device->device_count); + + iree_status_t status = iree_ok_status(); + iree_host_size_t device_ordinal = 0; + iree_host_size_t command_buffer_ordinal = 0; + iree_hal_queue_affinity_t current_affinity = queue_affinity; + while (current_affinity) { + int next_device_ordinal_offset = + iree_math_count_trailing_zeros_u64(current_affinity); + device_ordinal += next_device_ordinal_offset; + current_affinity >>= next_device_ordinal_offset + 1; + status = IREE_HIP_CALL_TO_STATUS( + device->hip_symbols, + hipCtxPushCurrent(device->devices[device_ordinal].hip_context)); + if (!iree_status_is_ok(status)) { + break; + } + switch (type) { + case IREE_HAL_HIP_DEVICE_COMMAND_BUFFER_TYPE_STREAM: + status = iree_hal_hip_stream_command_buffer_create( + iree_hal_device_allocator(base_device), device->hip_symbols, + device->nccl_symbols, + device->devices[device_ordinal].tracing_context, mode, + command_categories, (iree_hal_queue_affinity_t)1 << device_ordinal, + binding_capacity, + device->devices[device_ordinal].hip_dispatch_stream, + &device->block_pool, device->host_allocator, + &buffers[command_buffer_ordinal]); + break; + case IREE_HAL_HIP_DEVICE_COMMAND_BUFFER_TYPE_GRAPH: + status = iree_hal_hip_graph_command_buffer_create( + iree_hal_device_allocator(base_device), device->hip_symbols, + device->devices[device_ordinal].tracing_context, + device->devices[device_ordinal].hip_context, mode, + command_categories, (iree_hal_queue_affinity_t)1 << device_ordinal, + binding_capacity, &device->block_pool, device->host_allocator, + &buffers[command_buffer_ordinal]); + break; + } + + status = iree_status_join( + status, + IREE_HIP_CALL_TO_STATUS(device->hip_symbols, hipCtxPopCurrent(NULL))); + ++device_ordinal; + ++command_buffer_ordinal; + if (!iree_status_is_ok(status)) { + break; + } + } + + if (iree_status_is_ok(status)) { + status = iree_hal_hip_multi_queue_command_buffer_create( + command_buffer_ordinal, &buffers[0], device->device_allocator, mode, + command_categories, queue_affinity, 
device->hip_symbols, + iree_hal_hip_device_make_topology(device), binding_capacity, + device->host_allocator, out_command_buffer); + } + + // If |iree_hal_hip_multi_queue_command_buffer_create| succeeded it retained + // the command buffers; if it failed it did not. Either way we release our + // local references here. + for (iree_host_size_t i = 0; i < IREE_HAL_MAX_QUEUES; ++i) { + iree_hal_resource_release(buffers[i]); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_hip_device_create_stream_command_buffer( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + iree_hal_command_buffer_t** out_command_buffer) { + return iree_hal_hip_device_create_command_buffer_internal( + base_device, mode, command_categories, queue_affinity, binding_capacity, + IREE_HAL_HIP_DEVICE_COMMAND_BUFFER_TYPE_STREAM, out_command_buffer); } static iree_status_t iree_hal_hip_device_create_command_buffer( @@ -891,10 +798,8 @@ static iree_status_t iree_hal_hip_device_create_command_buffer( iree_hal_command_category_t command_categories, iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, iree_hal_command_buffer_t** out_command_buffer) { + *out_command_buffer = NULL; iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); - if (device->params.allow_inline_execution && iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION)) { // Inline execution requires that the command buffer be executed while it is // being recorded, implying that the command buffer cannot be reused and doesn't // need to be persisted. This lets us lower the execution delay as we can // directly route commands to a HIP stream and let it eagerly flush.
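As a usage-side illustration (hypothetical caller code, not from this patch): requesting the inline-execution path described above goes through the ordinary HAL entry point, with the affinity mask now choosing which physical device's stream backs the recording:

    iree_hal_command_buffer_t* command_buffer = NULL;
    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create(
        base_device,
        IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
            IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
        IREE_HAL_COMMAND_CATEGORY_ANY,
        /*queue_affinity=*/(iree_hal_queue_affinity_t)1 << 0,  // device 0
        /*binding_capacity=*/0, &command_buffer));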
- return iree_hal_hip_stream_command_buffer_create( - iree_hal_device_allocator(base_device), device->hip_symbols, - device->nccl_symbols, device->hip_context, device->tracing_context, - mode, command_categories, binding_capacity, device->hip_dispatch_stream, - &device->block_pool, device->host_allocator, out_command_buffer); + return iree_hal_hip_device_create_command_buffer_internal( + base_device, mode, command_categories, queue_affinity, binding_capacity, + IREE_HAL_HIP_DEVICE_COMMAND_BUFFER_TYPE_STREAM, out_command_buffer); } switch (device->params.command_buffer_mode) { case IREE_HAL_HIP_COMMAND_BUFFER_MODE_GRAPH: @@ -916,19 +819,18 @@ static iree_status_t iree_hal_hip_device_create_command_buffer( if (binding_capacity > 0) { return iree_hal_deferred_command_buffer_create( iree_hal_device_allocator(base_device), mode, command_categories, - binding_capacity, &device->block_pool, + queue_affinity, binding_capacity, &device->block_pool, iree_hal_device_host_allocator(base_device), out_command_buffer); } else { - return iree_hal_hip_graph_command_buffer_create( - iree_hal_device_allocator(base_device), device->hip_symbols, - device->tracing_context, device->hip_context, mode, - command_categories, queue_affinity, binding_capacity, - &device->block_pool, device->host_allocator, out_command_buffer); + return iree_hal_hip_device_create_command_buffer_internal( + base_device, mode, command_categories, queue_affinity, + binding_capacity, IREE_HAL_HIP_DEVICE_COMMAND_BUFFER_TYPE_GRAPH, + out_command_buffer); } case IREE_HAL_HIP_COMMAND_BUFFER_MODE_STREAM: return iree_hal_deferred_command_buffer_create( iree_hal_device_allocator(base_device), mode, command_categories, - binding_capacity, &device->block_pool, + queue_affinity, binding_capacity, &device->block_pool, iree_hal_device_host_allocator(base_device), out_command_buffer); default: return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, @@ -947,10 +849,7 @@ static iree_status_t iree_hal_hip_device_import_file( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, iree_hal_memory_access_t access, iree_io_file_handle_t* handle, iree_hal_external_file_flags_t flags, iree_hal_file_t** out_file) { - iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); - + *out_file = NULL; if (iree_io_file_handle_type(handle) != IREE_IO_FILE_HANDLE_TYPE_HOST_ALLOCATION) { return iree_make_status( @@ -966,24 +865,25 @@ static iree_status_t iree_hal_hip_device_create_executable_cache( iree_hal_device_t* base_device, iree_string_view_t identifier, iree_loop_t loop, iree_hal_executable_cache_t** out_executable_cache) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + hipDevice_t devices[IREE_HAL_MAX_QUEUES]; + hipCtx_t contexts[IREE_HAL_MAX_QUEUES]; + for (iree_host_size_t i = 0; i < device->device_count; ++i) { + devices[i] = device->devices[i].hip_device; + contexts[i] = device->devices[i].hip_context; + } return iree_hal_hip_nop_executable_cache_create( - identifier, device->hip_symbols, device->hip_device, device->hip_context, - device->host_allocator, out_executable_cache); + identifier, device->hip_symbols, + iree_hal_hip_device_make_topology(device), device->host_allocator, + out_executable_cache); } static iree_status_t iree_hal_hip_device_create_semaphore( iree_hal_device_t* base_device, uint64_t initial_value, 
iree_hal_semaphore_flags_t flags, iree_hal_semaphore_t** out_semaphore) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); - - return iree_hal_hip_event_semaphore_create( - initial_value, device->hip_symbols, device->hip_context, - device->timepoint_pool, device->work_queue, device->host_allocator, - out_semaphore); + return iree_hal_hip_event_semaphore_create(initial_value, device->hip_symbols, + device->host_allocator, + out_semaphore); } static iree_hal_semaphore_compatibility_t @@ -993,21 +893,22 @@ iree_hal_hip_device_query_semaphore_compatibility( return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY; } -static iree_status_t iree_hal_hip_device_pepare_async_alloc( +static iree_status_t iree_hal_hip_device_prepare_async_alloc( iree_hal_hip_device_t* device, iree_hal_buffer_params_t params, iree_device_size_t allocation_size, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { IREE_TRACE_ZONE_BEGIN(z0); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)allocation_size); + *out_buffer = NULL; iree_hal_buffer_params_canonicalize(¶ms); - const iree_hal_buffer_placement_t placement = { .device = (iree_hal_device_t*)device, .queue_affinity = params.queue_affinity ? params.queue_affinity : IREE_HAL_QUEUE_AFFINITY_ANY, .flags = IREE_HAL_BUFFER_PLACEMENT_FLAG_ASYNCHRONOUS, }; + iree_hal_buffer_t* buffer = NULL; iree_status_t status = iree_hal_hip_buffer_wrap( placement, params.type, params.access, params.usage, allocation_size, @@ -1026,6 +927,422 @@ static iree_status_t iree_hal_hip_device_pepare_async_alloc( IREE_TRACE_ZONE_END(z0); return status; } +typedef enum iree_hal_hip_device_semaphore_buffer_operation_type_e { + IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_ALLOC, + IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_DEALLOC, + IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_MAX = + IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_DEALLOC, +} iree_hal_hip_device_semaphore_buffer_operation_type_t; + +typedef struct iree_hal_hip_device_semaphore_buffer_operation_callback_data_t { + iree_allocator_t host_allocator; + iree_atomic_int64_t wait_semaphore_count; + iree_hal_hip_device_t* device; + iree_hal_queue_affinity_t queue_affinity; + iree_hal_semaphore_list_t wait_semaphore_list; + iree_hal_semaphore_list_t signal_semaphore_list; + iree_hal_buffer_t* buffer; + iree_hal_hip_device_semaphore_buffer_operation_type_t type; + iree_slim_mutex_t status_mutex; + iree_status_t status; +} iree_hal_hip_device_semaphore_buffer_operation_callback_data_t; + +static iree_status_t iree_hal_hip_device_make_buffer_callback_data( + iree_hal_hip_device_t* device, iree_allocator_t host_allocator, + iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_buffer_t* buffer, + iree_hal_hip_device_semaphore_buffer_operation_type_t type, + iree_hal_hip_device_semaphore_buffer_operation_callback_data_t** out_data) { + *out_data = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + // Embed captured tables in the action allocation. 
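Concretely, embedding the captured tables means one allocation whose tail stores both semaphore lists, so no second allocation or separate lifetime is needed. In outline, the carving the code below performs:

    // [ callback_data | wait.semaphores[] | wait.payload_values[]
    //                 | signal.semaphores[] | signal.payload_values[] ]
    uint8_t* p = (uint8_t*)callback_data + sizeof(*callback_data);
    // Each list is a pointer table followed by its 64-bit payload values;
    // every semaphore is retained so the copy can outlive the caller's lists.
    callback_data->wait_semaphore_list.semaphores = (iree_hal_semaphore_t**)p;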
+ iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* + callback_data = NULL; + + const iree_host_size_t wait_semaphore_list_size = + wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores) + + wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values); + const iree_host_size_t signal_semaphore_list_size = + signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores) + + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.payload_values); + + const iree_host_size_t total_callback_size = sizeof(*callback_data) + + wait_semaphore_list_size + + signal_semaphore_list_size; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, total_callback_size, + (void**)&callback_data)); + uint8_t* callback_ptr = (uint8_t*)callback_data + sizeof(*callback_data); + + iree_atomic_store(&callback_data->wait_semaphore_count, + wait_semaphore_list.count, iree_memory_order_relaxed); + + callback_data->host_allocator = host_allocator; + callback_data->device = device; + callback_data->queue_affinity = queue_affinity; + + // Copy wait list for later access. + callback_data->wait_semaphore_list.count = wait_semaphore_list.count; + callback_data->wait_semaphore_list.semaphores = + (iree_hal_semaphore_t**)callback_ptr; + memcpy(callback_data->wait_semaphore_list.semaphores, + wait_semaphore_list.semaphores, + wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores)); + callback_data->wait_semaphore_list.payload_values = + (uint64_t*)(callback_ptr + wait_semaphore_list.count * + sizeof(*wait_semaphore_list.semaphores)); + memcpy( + callback_data->wait_semaphore_list.payload_values, + wait_semaphore_list.payload_values, + wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values)); + for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { + iree_hal_resource_retain(wait_semaphore_list.semaphores[i]); + } + callback_ptr += wait_semaphore_list_size; + + // Copy signal list for later access. 
+ callback_data->signal_semaphore_list.count = signal_semaphore_list.count; + callback_data->signal_semaphore_list.semaphores = + (iree_hal_semaphore_t**)callback_ptr; + memcpy( + callback_data->signal_semaphore_list.semaphores, + signal_semaphore_list.semaphores, + signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores)); + callback_data->signal_semaphore_list.payload_values = + (uint64_t*)(callback_ptr + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.semaphores)); + memcpy(callback_data->signal_semaphore_list.payload_values, + signal_semaphore_list.payload_values, + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.payload_values)); + for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { + iree_hal_resource_retain(signal_semaphore_list.semaphores[i]); + } + callback_ptr += signal_semaphore_list_size; + + callback_data->buffer = buffer; + iree_hal_buffer_retain(buffer); + callback_data->type = type; + + iree_slim_mutex_initialize(&callback_data->status_mutex); + callback_data->status = iree_ok_status(); + *out_data = callback_data; + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +void iree_hal_hip_device_destroy_buffer_callback_data( + iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* data) { + if (!data) { + return; + } + iree_slim_mutex_deinitialize(&data->status_mutex); + for (iree_host_size_t i = 0; i < data->wait_semaphore_list.count; ++i) { + iree_hal_resource_release(data->wait_semaphore_list.semaphores[i]); + } + for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { + iree_hal_resource_release(data->signal_semaphore_list.semaphores[i]); + } + iree_hal_buffer_release(data->buffer); + + iree_allocator_free(data->host_allocator, data); +} + +static iree_status_t +iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup( + iree_hal_hip_device_t* device, iree_hal_hip_cleanup_thread_t* thread, + iree_hal_semaphore_list_t signal_semaphore_list, + iree_host_size_t device_ordinal, iree_hal_hip_cleanup_callback_t callback, + void* user_data) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { + status = iree_hal_hip_semaphore_create_event_and_record_if_necessary( + signal_semaphore_list.semaphores[i], + signal_semaphore_list.payload_values[i], + device->devices[device_ordinal].hip_dispatch_stream, + device->devices[device_ordinal].device_event_pool); + if (!iree_status_is_ok(status)) { + break; + } + } + + for (iree_host_size_t i = 0; + i < signal_semaphore_list.count && iree_status_is_ok(status); ++i) { + status = iree_hal_hip_semaphore_notify_forward_progress_to( + signal_semaphore_list.semaphores[i], + signal_semaphore_list.payload_values[i]); + } + + iree_hal_hip_event_t* event = NULL; + if (iree_status_is_ok(status)) { + status = iree_hal_hip_event_pool_acquire( + device->devices[device_ordinal].device_event_pool, 1, &event); + } + + if (iree_status_is_ok(status)) { + status = IREE_HIP_CALL_TO_STATUS( + device->hip_symbols, + hipEventRecord(iree_hal_hip_event_handle(event), + device->devices[device_ordinal].hip_dispatch_stream)); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_hip_cleanup_thread_add_cleanup(thread, event, callback, + user_data); + } else { + iree_hal_hip_event_release(event); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +typedef struct iree_hal_hip_device_buffer_free_callback_data_t { + iree_hal_hip_device_t* device; + iree_allocator_t 
host_allocator; + iree_hal_queue_affinity_t queue_affinity; + iree_hal_buffer_t* buffer; +} iree_hal_hip_device_buffer_free_callback_data_t; + +static iree_status_t iree_hal_hip_device_make_buffer_free_callback_data( + iree_hal_hip_device_t* device, iree_hal_queue_affinity_t queue_affinity, + iree_hal_buffer_t* buffer, iree_allocator_t host_allocator, + iree_hal_hip_device_buffer_free_callback_data_t** out_data) { + IREE_TRACE_ZONE_BEGIN(z0); + + *out_data = NULL; + + iree_hal_hip_device_buffer_free_callback_data_t* callback_data = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*callback_data), + (void**)&callback_data)); + + callback_data->buffer = buffer; + callback_data->device = device; + callback_data->queue_affinity = queue_affinity; + callback_data->host_allocator = host_allocator; + *out_data = callback_data; + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_hal_hip_async_free_buffer(void* user_data, + iree_hal_hip_event_t* event, + iree_status_t status) { + iree_hal_hip_device_buffer_free_callback_data_t* data = + (iree_hal_hip_device_buffer_free_callback_data_t*)(user_data); + + iree_hal_hip_device_t* device = data->device; + int device_ordinal = iree_math_count_trailing_zeros_u64(data->queue_affinity); + + if (device->supports_memory_pools) { + status = iree_status_join( + status, + iree_hal_hip_memory_pools_deallocate( + &device->devices[device_ordinal].memory_pools, + device->devices[device_ordinal].hip_dispatch_stream, data->buffer)); + } else { + status = iree_status_join( + status, + iree_hal_hip_allocator_free_async( + iree_hal_device_allocator((iree_hal_device_t*)data->device), + device->devices[device_ordinal].hip_dispatch_stream, data->buffer)); + } + + iree_hal_hip_event_release(event); + iree_hal_buffer_release(data->buffer); + iree_allocator_free(device->host_allocator, data); + + return status; +} + +static iree_status_t iree_hal_hip_device_complete_buffer_operation( + void* user_data, iree_hal_hip_event_t* event, iree_status_t status) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* data = + (iree_hal_hip_device_semaphore_buffer_operation_callback_data_t*) + user_data; + + // Free the event we specifically created. + iree_hal_hip_event_release(event); + + // Notify all of the signal semaphores that they have been incremented. 
+ for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { + iree_status_ignore(iree_hal_hip_event_semaphore_advance( + data->signal_semaphore_list.semaphores[i])); + } + + if (data->buffer && + data->type == IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_DEALLOC) { + int device_ordinal = + iree_math_count_trailing_zeros_u64(data->queue_affinity); + if (data->device->supports_memory_pools) { + status = iree_status_join( + status, iree_hal_hip_memory_pools_deallocate( + &data->device->devices[device_ordinal].memory_pools, + data->device->devices[device_ordinal].hip_dispatch_stream, + data->buffer)); + } else { + status = iree_status_join( + status, + iree_hal_hip_allocator_free_async( + iree_hal_device_allocator((iree_hal_device_t*)data->device), + data->device->devices[device_ordinal].hip_dispatch_stream, + data->buffer)); + } + } + + iree_hal_hip_device_destroy_buffer_callback_data(data); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_hip_device_stream_wait_for_semaphores( + iree_hal_hip_device_t* device, + iree_hal_semaphore_list_t wait_semaphore_list, + iree_host_size_t device_ordinal) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_t status = iree_ok_status(); + // TODO(awoloszyn): Because of how HIP works, if we only have a single + // physical device in the hip_device we could avoid waiting on any of these + // semaphores, since work on a single stream is already ordered. But if we + // want this to work across multiple devices/streams, we need these waits. + for (iree_host_size_t i = 0; + i < wait_semaphore_list.count && iree_status_is_ok(status); ++i) { + iree_hal_hip_event_t* event = NULL; + status = iree_hal_hip_semaphore_get_hip_event( + wait_semaphore_list.semaphores[i], + wait_semaphore_list.payload_values[i], + device->devices[device_ordinal].device_event_pool, &event); + if (!iree_status_is_ok(status)) { + break; + } + // If we don't have an event, then we don't have to wait for it since it + // has already been signaled on the host. + if (!event) { + continue; + } + + status = IREE_HIP_CALL_TO_STATUS( + device->hip_symbols, + hipStreamWaitEvent(device->devices[device_ordinal].hip_dispatch_stream, + iree_hal_hip_event_handle(event), 0)); + iree_hal_hip_event_release(event); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_hip_device_perform_buffer_operation_now( + void* user_data, iree_status_t status) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* data = + (iree_hal_hip_device_semaphore_buffer_operation_callback_data_t*) + user_data; + IREE_ASSERT_LE(data->type, IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_MAX); + + iree_hal_hip_device_t* device = data->device; + + // If we had a semaphore failure then we should propagate it + // but not run anything. + if (!iree_status_is_ok(data->status)) { + status = iree_status_join(data->status, status); + } + + int device_ordinal = iree_math_count_trailing_zeros_u64(data->queue_affinity); + + if (iree_status_is_ok(status)) { + status = IREE_HIP_CALL_TO_STATUS( + data->device->hip_symbols, + hipCtxPushCurrent(data->device->devices[device_ordinal].hip_context)); + } + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, device_ordinal); + + if (iree_status_is_ok(status)) { + status = iree_hal_hip_device_stream_wait_for_semaphores( + data->device, data->wait_semaphore_list, device_ordinal); + } + + // We have satisfied all of the waits.
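None of those waits involve the CPU: each is a hipStreamWaitEvent enqueued against an event recorded on some other queue's stream, so cross-device dependencies are resolved on the GPU itself. Stripped of the event pool and status plumbing, the underlying HIP idiom is (raw HIP sketch; stream names and the flag choice are illustrative):

    hipEvent_t event = NULL;
    hipEventCreateWithFlags(&event, hipEventDisableTiming);  // sync-only event
    hipEventRecord(event, producer_stream);         // stream on device A
    hipStreamWaitEvent(consumer_stream, event, 0);  // device B's stream waits;
                                                    // the host never blocks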
+  IREE_TRACE_ZONE_BEGIN_NAMED(
+      z3, "iree_hal_hip_device_perform_buffer_operation_now_launch_operation");
+  if (iree_status_is_ok(status)) {
+    switch (data->type) {
+      case IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_ALLOC:
+        if (device->supports_memory_pools) {
+          status = iree_hal_hip_memory_pools_allocate_pointer(
+              &device->devices[device_ordinal].memory_pools, data->buffer,
+              device->devices[device_ordinal].hip_dispatch_stream,
+              iree_hal_buffer_allocation_size(data->buffer));
+          break;
+        }
+        status = iree_hal_hip_allocator_alloc_async(
+            iree_hal_device_allocator((iree_hal_device_t*)data->device),
+            device->devices[device_ordinal].hip_dispatch_stream, data->buffer);
+        break;
+      case IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_DEALLOC: {
+        // Due to a quirk of HIP, freeing asynchronously here can cause large
+        // amounts of memory fragmentation, so the actual free is deferred to
+        // the cleanup thread instead.
+      } break;
+    }
+  }
+  IREE_TRACE_ZONE_END(z3);
+
+  const iree_hal_hip_dynamic_symbols_t* symbols = data->device->hip_symbols;
+  if (iree_status_is_ok(status)) {
+    // The callback data may be deleted any time after it has been handed to
+    // the cleanup thread, so retain the symbols here.
+    status = iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup(
+        data->device, data->device->cleanup_thread, data->signal_semaphore_list,
+        device_ordinal, &iree_hal_hip_device_complete_buffer_operation, data);
+  } else {
+    for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) {
+      iree_hal_semaphore_fail(data->signal_semaphore_list.semaphores[i],
+                              iree_status_clone(status));
+    }
+    iree_hal_hip_device_destroy_buffer_callback_data(data);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_status_join(
+      status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL)));
+}
+
+static iree_status_t iree_hal_hip_device_semaphore_buffer_operation_callback(
+    void* user_context, iree_hal_semaphore_t* semaphore, iree_status_t status) {
+  iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* data =
+      (iree_hal_hip_device_semaphore_buffer_operation_callback_data_t*)
+          user_context;
+  if (!iree_status_is_ok(status)) {
+    iree_slim_mutex_lock(&data->status_mutex);
+    data->status = iree_status_join(data->status, status);
+    iree_slim_mutex_unlock(&data->status_mutex);
+  }
+  if (iree_atomic_fetch_sub(&data->wait_semaphore_count, 1,
+                            iree_memory_order_acq_rel) != 1) {
+    return iree_ok_status();
+  }
+
+  int device_ordinal = iree_math_count_trailing_zeros_u64(data->queue_affinity);
+  // Now the actual buffer operation happens, as all semaphores have been
+  // satisfied (by satisfied here, we specifically mean that the semaphore has
+  // been scheduled, not necessarily completed).
+  return iree_hal_hip_dispatch_thread_add_dispatch(
+      data->device->devices[device_ordinal].dispatch_thread,
+      &iree_hal_hip_device_perform_buffer_operation_now, data);
+}
 
 // TODO: implement multiple streams; today we only have one and queue_affinity
 // is ignored.
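The scheduling pattern above is reused by the queue operations in the hunks
below: when there is nothing to wait on, the operation is handed straight to
the per-device dispatch thread; otherwise a callback is registered on every
wait semaphore and the callback that observes the wait count reach zero
performs the hand-off. A condensed sketch of that pattern, using the names
from this patch (not verbatim source):

  if (wait_semaphore_list.count == 0) {
    // Nothing to wait on: enqueue the operation on the dispatch thread now.
    status = iree_hal_hip_dispatch_thread_add_dispatch(
        device->devices[device_ordinal].dispatch_thread,
        &iree_hal_hip_device_perform_buffer_operation_now, callback_data);
  } else {
    // Each satisfied wait decrements wait_semaphore_count; the decrement that
    // reaches zero enqueues the operation (see
    // iree_hal_hip_device_semaphore_buffer_operation_callback above).
    for (iree_host_size_t i = 0;
         i < wait_semaphore_list.count && iree_status_is_ok(status); ++i) {
      status = iree_hal_hip_semaphore_notify_work(
          wait_semaphore_list.semaphores[i],
          wait_semaphore_list.payload_values[i],
          device->devices[device_ordinal].device_event_pool,
          &iree_hal_hip_device_semaphore_buffer_operation_callback,
          callback_data);
    }
  }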
@@ -1038,64 +1355,77 @@ static iree_status_t iree_hal_hip_device_queue_alloca( iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params, iree_device_size_t allocation_size, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { + IREE_TRACE_ZONE_BEGIN(z0); + + *out_buffer = NULL; + iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + uint64_t queue_affinity_mask = + ((iree_hal_queue_affinity_t)1 << device->device_count); + queue_affinity_mask = queue_affinity_mask | (queue_affinity_mask - 1); + queue_affinity &= queue_affinity_mask; + + int device_ordinal = iree_math_count_trailing_zeros_u64(queue_affinity); + queue_affinity = (uint64_t)1 << device_ordinal; - if (device->supports_memory_pools && - !iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { + iree_status_t status = iree_ok_status(); + if (!iree_all_bits_set(params.type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) && + (device->supports_memory_pools || + iree_hal_hip_allocator_isa(iree_hal_device_allocator(base_device)))) { iree_hal_buffer_t* buffer = NULL; - IREE_RETURN_IF_ERROR(iree_hal_hip_memory_pools_prepare_buffer( - &device->memory_pools, device->hip_dispatch_stream, pool, params, - allocation_size, &buffer)); + status = iree_hal_hip_device_prepare_async_alloc(device, params, + allocation_size, &buffer); - iree_status_t status = iree_hal_deferred_work_queue_enqueue_alloc( - device->work_queue, wait_semaphore_list, signal_semaphore_list, buffer); + iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* + callback_data = NULL; if (iree_status_is_ok(status)) { - status = iree_hal_deferred_work_queue_issue(device->work_queue); + status = iree_hal_hip_device_make_buffer_callback_data( + device, device->host_allocator, queue_affinity, wait_semaphore_list, + signal_semaphore_list, buffer, + IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_ALLOC, &callback_data); } - if (iree_status_is_ok(status)) { - *out_buffer = buffer; + if (iree_status_is_ok(status) && wait_semaphore_list.count == 0) { + status = iree_hal_hip_dispatch_thread_add_dispatch( + device->devices[device_ordinal].dispatch_thread, + &iree_hal_hip_device_perform_buffer_operation_now, callback_data); + } else if (iree_status_is_ok(status) && wait_semaphore_list.count != 0) { + for (iree_host_size_t i = 0; + i < wait_semaphore_list.count && iree_status_is_ok(status); ++i) { + status = iree_status_join( + status, + iree_hal_hip_semaphore_notify_work( + wait_semaphore_list.semaphores[i], + wait_semaphore_list.payload_values[i], + device->devices[device_ordinal].device_event_pool, + &iree_hal_hip_device_semaphore_buffer_operation_callback, + callback_data)); + } } else { - iree_hal_hip_buffer_set_allocation_empty(buffer); - iree_hal_resource_release(&buffer->resource); + iree_hal_hip_device_destroy_buffer_callback_data(callback_data); } - return status; - } else if (!iree_all_bits_set(params.type, - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) && - iree_hal_hip_allocator_isa( - iree_hal_device_allocator(base_device))) { - iree_hal_buffer_t* buffer = NULL; - - IREE_RETURN_IF_ERROR(iree_hal_hip_device_pepare_async_alloc( - device, params, allocation_size, &buffer)); - iree_status_t status = iree_hal_deferred_work_queue_enqueue_alloc( - device->work_queue, wait_semaphore_list, signal_semaphore_list, buffer); - if (iree_status_is_ok(status)) { - status = iree_hal_deferred_work_queue_issue(device->work_queue); - } if (iree_status_is_ok(status)) { *out_buffer = 
buffer;
     } else {
-      iree_hal_hip_buffer_set_allocation_empty(buffer);
-      iree_hal_resource_release(&buffer->resource);
+      if (buffer) {
+        iree_hal_hip_buffer_set_allocation_empty(buffer);
+        iree_hal_resource_release(&buffer->resource);
+      }
     }
+
+    IREE_TRACE_ZONE_END(z0);
     return status;
   }
 
   // NOTE: block on the semaphores here; we could avoid this by properly
   // sequencing device work with semaphores. The HIP HAL is not currently
   // asynchronous.
-  IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list,
-                                                    iree_infinite_timeout()));
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_semaphore_list_wait(wait_semaphore_list,
+                                       iree_infinite_timeout()));
 
-  // Allocate from the pool; likely to fail in cases of virtual memory
-  // exhaustion but the error may be deferred until a later synchronization.
-  // If pools are not supported we allocate a buffer as normal from whatever
-  // allocator is set on the device.
-  iree_status_t status =
+  status =
       iree_hal_allocator_allocate_buffer(iree_hal_device_allocator(base_device),
                                          params, allocation_size, out_buffer);
 
@@ -1105,6 +1435,7 @@ static iree_status_t iree_hal_hip_device_queue_alloca(
   if (iree_status_is_ok(status)) {
     status = iree_hal_semaphore_list_signal(signal_semaphore_list);
   }
+  IREE_TRACE_ZONE_END(z0);
   return status;
 }
 
@@ -1117,31 +1448,71 @@ static iree_status_t iree_hal_hip_device_queue_dealloca(
     const iree_hal_semaphore_list_t wait_semaphore_list,
     const iree_hal_semaphore_list_t signal_semaphore_list,
     iree_hal_buffer_t* buffer) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
   iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device);
-  IREE_RETURN_IF_ERROR(
-      iree_hal_hip_set_context(device->hip_symbols, device->hip_context));
+  uint64_t queue_affinity_mask =
+      ((iree_hal_queue_affinity_t)1 << device->device_count);
+  queue_affinity_mask = queue_affinity_mask | (queue_affinity_mask - 1);
+  queue_affinity &= queue_affinity_mask;
+
+  int device_ordinal = iree_math_count_trailing_zeros_u64(queue_affinity);
+
+  if (device_ordinal >= device->device_count) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "device ordinal out of range; the device count is %" PRIhsz,
+        device->device_count);
+  }
+
+  queue_affinity = (uint64_t)1 << device_ordinal;
 
+  iree_status_t status = iree_ok_status();
   if (iree_hal_hip_allocator_isa(iree_hal_device_allocator(base_device))) {
-    iree_status_t status = iree_hal_deferred_work_queue_enqueue_dealloc(
-        device->work_queue, wait_semaphore_list, signal_semaphore_list, buffer);
-    if (iree_status_is_ok(status)) {
-      status = iree_hal_deferred_work_queue_issue(device->work_queue);
+    iree_hal_hip_device_semaphore_buffer_operation_callback_data_t*
+        callback_data = NULL;
+    status = iree_hal_hip_device_make_buffer_callback_data(
+        device, device->host_allocator, queue_affinity, wait_semaphore_list,
+        signal_semaphore_list, buffer,
+        IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_DEALLOC, &callback_data);
+
+    if (iree_status_is_ok(status) && wait_semaphore_list.count == 0) {
+      status = iree_hal_hip_dispatch_thread_add_dispatch(
+          device->devices[device_ordinal].dispatch_thread,
+          &iree_hal_hip_device_perform_buffer_operation_now, callback_data);
+    } else if (iree_status_is_ok(status) && wait_semaphore_list.count != 0) {
+      for (iree_host_size_t i = 0;
+           i < wait_semaphore_list.count && iree_status_is_ok(status); ++i) {
+        status = iree_status_join(
+            status,
+            iree_hal_hip_semaphore_notify_work(
+                wait_semaphore_list.semaphores[i],
+                wait_semaphore_list.payload_values[i],
device->devices[device_ordinal].device_event_pool, + &iree_hal_hip_device_semaphore_buffer_operation_callback, + callback_data)); + } + } else { + iree_hal_hip_device_destroy_buffer_callback_data(callback_data); } + + IREE_TRACE_ZONE_END(z0); return status; } // NOTE: block on the semaphores here; we could avoid this by properly // sequencing device work with semaphores. The HIP HAL is not currently // asynchronous. - IREE_RETURN_IF_ERROR(iree_hal_semaphore_list_wait(wait_semaphore_list, - iree_infinite_timeout())); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_semaphore_list_wait(wait_semaphore_list, + iree_infinite_timeout())); // Schedule the buffer deallocation if we got it from a pool and otherwise // drop it on the floor and let it be freed when the buffer is released. - iree_status_t status = iree_ok_status(); if (device->supports_memory_pools) { status = iree_hal_hip_memory_pools_deallocate( - &device->memory_pools, device->hip_dispatch_stream, buffer); + &device->devices[device_ordinal].memory_pools, + device->devices[device_ordinal].hip_dispatch_stream, buffer); } // Only signal if not returning a synchronous error - synchronous failure @@ -1150,6 +1521,8 @@ static iree_status_t iree_hal_hip_device_queue_dealloca( if (iree_status_is_ok(status)) { status = iree_hal_semaphore_list_signal(signal_semaphore_list); } + + IREE_TRACE_ZONE_END(z0); return status; } @@ -1160,6 +1533,8 @@ static iree_status_t iree_hal_hip_device_queue_read( iree_hal_file_t* source_file, uint64_t source_offset, iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t length, iree_hal_read_flags_t flags) { + IREE_TRACE_ZONE_BEGIN(z0); + // TODO: expose streaming chunk count/size options. iree_status_t loop_status = iree_ok_status(); iree_hal_file_transfer_options_t options = { @@ -1167,10 +1542,13 @@ static iree_status_t iree_hal_hip_device_queue_read( .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, }; - IREE_RETURN_IF_ERROR(iree_hal_device_queue_read_streaming( - base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, - source_file, source_offset, target_buffer, target_offset, length, flags, - options)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_device_queue_read_streaming( + base_device, queue_affinity, wait_semaphore_list, + signal_semaphore_list, source_file, source_offset, target_buffer, + target_offset, length, flags, options)); + + IREE_TRACE_ZONE_END(z0); return loop_status; } @@ -1181,6 +1559,8 @@ static iree_status_t iree_hal_hip_device_queue_write( iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, iree_hal_file_t* target_file, uint64_t target_offset, iree_device_size_t length, iree_hal_write_flags_t flags) { + IREE_TRACE_ZONE_BEGIN(z0); + // TODO: expose streaming chunk count/size options. 
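+  // NOTE: the write is emulated by streaming chunks of the source buffer
+  // through host memory (see iree_hal_device_queue_write_streaming); the
+  // options below select the default chunk count/size.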
iree_status_t loop_status = iree_ok_status(); iree_hal_file_transfer_options_t options = { @@ -1188,16 +1568,354 @@ static iree_status_t iree_hal_hip_device_queue_write( .chunk_count = IREE_HAL_FILE_TRANSFER_CHUNK_COUNT_DEFAULT, .chunk_size = IREE_HAL_FILE_TRANSFER_CHUNK_SIZE_DEFAULT, }; - IREE_RETURN_IF_ERROR(iree_hal_device_queue_write_streaming( - base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, - source_buffer, source_offset, target_file, target_offset, length, flags, - options)); + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_device_queue_write_streaming( + base_device, queue_affinity, wait_semaphore_list, + signal_semaphore_list, source_buffer, source_offset, target_file, + target_offset, length, flags, options)); + + IREE_TRACE_ZONE_END(z0); return loop_status; } -static void iree_hal_hip_device_collect_tracing_context(void* user_data) { - iree_hal_stream_tracing_context_collect( - (iree_hal_stream_tracing_context_t*)user_data); +typedef struct iree_hal_hip_device_semaphore_submit_callback_data_t { + iree_allocator_t host_allocator; + iree_atomic_int64_t wait_semaphore_count; + iree_hal_hip_device_t* device; + iree_hal_queue_affinity_t queue_affinity; + iree_hal_command_buffer_t* command_buffer; + iree_hal_buffer_binding_table_t binding_table; + iree_hal_semaphore_list_t wait_semaphore_list; + iree_hal_semaphore_list_t signal_semaphore_list; + iree_hal_resource_set_t* resource_set; + iree_slim_mutex_t status_mutex; + iree_status_t status; +} iree_hal_hip_device_semaphore_submit_callback_data_t; + +static iree_status_t iree_hal_hip_device_make_callback_data( + iree_hal_hip_device_t* device, iree_allocator_t host_allocator, + iree_arena_block_pool_t* block_pool, + iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_command_buffer_t* command_buffer, + iree_hal_buffer_binding_table_t binding_table, + iree_hal_hip_device_semaphore_submit_callback_data_t** out_data) { + IREE_TRACE_ZONE_BEGIN(z0); + + *out_data = NULL; + + // Embed captured tables in the action allocation. + iree_hal_hip_device_semaphore_submit_callback_data_t* callback_data = NULL; + + const iree_host_size_t wait_semaphore_list_size = + wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores) + + wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values); + const iree_host_size_t signal_semaphore_list_size = + signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores) + + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.payload_values); + + const iree_host_size_t payload_size = + binding_table.count * sizeof(*binding_table.bindings); + + const iree_host_size_t total_callback_size = + sizeof(*callback_data) + wait_semaphore_list_size + + signal_semaphore_list_size + payload_size; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, total_callback_size, + (void**)&callback_data)); + uint8_t* callback_ptr = (uint8_t*)callback_data + sizeof(*callback_data); + + callback_data->host_allocator = host_allocator; + callback_data->device = device; + + iree_atomic_store(&callback_data->wait_semaphore_count, + wait_semaphore_list.count, iree_memory_order_relaxed); + // Copy wait list for later access. 
+ callback_data->wait_semaphore_list.count = wait_semaphore_list.count; + callback_data->wait_semaphore_list.semaphores = + (iree_hal_semaphore_t**)callback_ptr; + memcpy(callback_data->wait_semaphore_list.semaphores, + wait_semaphore_list.semaphores, + wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores)); + callback_data->wait_semaphore_list.payload_values = + (uint64_t*)(callback_ptr + wait_semaphore_list.count * + sizeof(*wait_semaphore_list.semaphores)); + memcpy( + callback_data->wait_semaphore_list.payload_values, + wait_semaphore_list.payload_values, + wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values)); + for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { + iree_hal_resource_retain(wait_semaphore_list.semaphores[i]); + } + callback_ptr += wait_semaphore_list_size; + + // Copy signal list for later access. + callback_data->signal_semaphore_list.count = signal_semaphore_list.count; + callback_data->signal_semaphore_list.semaphores = + (iree_hal_semaphore_t**)callback_ptr; + memcpy( + callback_data->signal_semaphore_list.semaphores, + signal_semaphore_list.semaphores, + signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores)); + callback_data->signal_semaphore_list.payload_values = + (uint64_t*)(callback_ptr + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.semaphores)); + memcpy(callback_data->signal_semaphore_list.payload_values, + signal_semaphore_list.payload_values, + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.payload_values)); + for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { + iree_hal_resource_retain(signal_semaphore_list.semaphores[i]); + } + callback_ptr += signal_semaphore_list_size; + + // Copy the execution resources for later access. + callback_data->queue_affinity = queue_affinity; + callback_data->command_buffer = command_buffer; + + // Retain all command buffers and semaphores. 
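+  // The resource set below holds a reference on every resource touched by
+  // this submission (semaphores, the command buffer, and bound buffers) until
+  // the completion callback destroys this callback data.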
+  iree_status_t status =
+      iree_hal_resource_set_allocate(block_pool, &callback_data->resource_set);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_resource_set_insert(callback_data->resource_set,
+                                          wait_semaphore_list.count,
+                                          wait_semaphore_list.semaphores);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_resource_set_insert(callback_data->resource_set,
+                                          signal_semaphore_list.count,
+                                          signal_semaphore_list.semaphores);
+  }
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_resource_set_insert(callback_data->resource_set, 1,
+                                          &command_buffer);
+  }
+
+  callback_data->binding_table = binding_table;
+  iree_hal_buffer_binding_t* binding_element_ptr =
+      (iree_hal_buffer_binding_t*)callback_ptr;
+  callback_data->binding_table.bindings = binding_element_ptr;
+  memcpy(binding_element_ptr, binding_table.bindings,
+         sizeof(*binding_element_ptr) * binding_table.count);
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_resource_set_insert_strided(
+        callback_data->resource_set, binding_table.count,
+        callback_data->binding_table.bindings,
+        offsetof(iree_hal_buffer_binding_t, buffer),
+        sizeof(iree_hal_buffer_binding_t));
+  }
+
+  callback_data->status = iree_ok_status();
+  iree_slim_mutex_initialize(&callback_data->status_mutex);
+  *out_data = callback_data;
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+void iree_hal_hip_device_destroy_callback_data(
+    iree_hal_hip_device_semaphore_submit_callback_data_t* data) {
+  if (!data) {
+    return;
+  }
+  iree_slim_mutex_deinitialize(&data->status_mutex);
+  iree_hal_resource_set_free(data->resource_set);
+  for (iree_host_size_t i = 0; i < data->wait_semaphore_list.count; ++i) {
+    iree_hal_resource_release(data->wait_semaphore_list.semaphores[i]);
+  }
+  for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) {
+    iree_hal_resource_release(data->signal_semaphore_list.semaphores[i]);
+  }
+  iree_allocator_free(data->host_allocator, data);
+}
+
+static iree_status_t iree_hal_hip_device_complete_submission(
+    void* user_data, iree_hal_hip_event_t* event, iree_status_t status) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_hip_device_semaphore_submit_callback_data_t* data =
+      (iree_hal_hip_device_semaphore_submit_callback_data_t*)user_data;
+  iree_hal_hip_device_t* device = data->device;
+
+  // Map the queue affinity to the ordinal of the backing physical device.
+  int device_ordinal = iree_math_count_trailing_zeros_u64(data->queue_affinity);
+
+  // Collect any tracing events that were recorded for this submission.
+  if (iree_status_is_ok(status)) {
+    iree_hal_command_buffer_t* command_buffer = data->command_buffer;
+    if (iree_hal_hip_multi_queue_command_buffer_isa(command_buffer)) {
+      status = iree_hal_hip_multi_queue_command_buffer_get(
+          command_buffer, data->queue_affinity, &command_buffer);
+    }
+
+    if (iree_status_is_ok(status)) {
+      if (iree_hal_hip_stream_command_buffer_isa(command_buffer)) {
+        status = iree_hal_stream_tracing_context_collect_list(
+            // The tracing context for this device/stream/queue affinity.
+            device->devices[device_ordinal].tracing_context,
+            // The tracing event list recorded into the command buffer.
+            iree_hal_hip_stream_command_buffer_tracing_events(command_buffer)
+                .head);
+      } else if (iree_hal_hip_graph_command_buffer_isa(command_buffer)) {
+        status = iree_hal_stream_tracing_context_collect_list(
+            // The tracing context for this device/stream/queue affinity.
+            device->devices[device_ordinal].tracing_context,
+            // The tracing event list recorded into the command buffer.
+ iree_hal_hip_graph_command_buffer_tracing_events(command_buffer) + .head); + } + } + } + + // Free the event we specifically created. + iree_hal_hip_event_release(event); + + // Notify all of the signal semaphores that they have been incremented. + for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { + iree_status_ignore(iree_hal_hip_event_semaphore_advance( + data->signal_semaphore_list.semaphores[i])); + } + iree_hal_hip_device_destroy_callback_data(data); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_hip_device_execute_now(void* user_data, + iree_status_t status) { + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hip_device_semaphore_submit_callback_data_t* data = + (iree_hal_hip_device_semaphore_submit_callback_data_t*)user_data; + IREE_ASSERT_EQ(iree_math_count_ones_u64(data->queue_affinity), 1, + "Cannot execute a command buffer on more than one queue"); + + iree_hal_hip_device_t* device = data->device; + + // If we had a semaphore failure then we should propagate it + // but not run anything. + status = iree_status_join(status, data->status); + + int device_ordinal = iree_math_count_trailing_zeros_u64(data->queue_affinity); + IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, device_ordinal); + + if (iree_status_is_ok(status)) { + status = IREE_HIP_CALL_TO_STATUS( + data->device->hip_symbols, + hipCtxPushCurrent(data->device->devices[device_ordinal].hip_context)); + } + + if (iree_status_is_ok(status)) { + status = iree_hal_hip_device_stream_wait_for_semaphores( + data->device, data->wait_semaphore_list, device_ordinal); + } + + // We have satisfied all of the waits. + IREE_TRACE_ZONE_BEGIN_NAMED(z1, "iree_hal_hip_device_execute_now_launch"); + iree_hal_command_buffer_t* command_buffer = data->command_buffer; + if (iree_status_is_ok(status)) { + if (iree_hal_hip_multi_queue_command_buffer_isa(command_buffer)) { + status = iree_hal_hip_multi_queue_command_buffer_get( + command_buffer, data->queue_affinity, &command_buffer); + } + } + if (iree_status_is_ok(status)) { + iree_hal_buffer_binding_table_t binding_table = data->binding_table; + if (iree_hal_deferred_command_buffer_isa(command_buffer)) { + iree_hal_command_buffer_t* stream_command_buffer = NULL; + iree_hal_command_buffer_mode_t mode = + iree_hal_command_buffer_mode(command_buffer) | + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | + // NOTE: we need to validate if a binding table is provided as the + // bindings were not known when it was originally recorded. + (iree_hal_buffer_binding_table_is_empty(binding_table) + ? 
IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED
+               : 0);
+      status = iree_hal_hip_device_create_stream_command_buffer(
+          (iree_hal_device_t*)data->device, mode,
+          command_buffer->allowed_categories, data->queue_affinity, 0,
+          &stream_command_buffer);
+      if (iree_status_is_ok(status)) {
+        status = iree_hal_resource_set_insert(data->resource_set, 1,
+                                              &stream_command_buffer);
+      }
+      if (iree_status_is_ok(status)) {
+        status = iree_hal_deferred_command_buffer_apply(
+            command_buffer, stream_command_buffer, binding_table);
+      }
+      data->command_buffer = stream_command_buffer;
+      iree_hal_resource_release(stream_command_buffer);
+    } else if (iree_hal_hip_stream_command_buffer_isa(command_buffer)) {
+      status =
+          iree_hal_resource_set_insert(data->resource_set, 1, &command_buffer);
+    } else if (iree_hal_hip_graph_command_buffer_isa(command_buffer)) {
+      status =
+          iree_hal_resource_set_insert(data->resource_set, 1, &command_buffer);
+      if (iree_status_is_ok(status)) {
+        IREE_TRACE_ZONE_BEGIN_NAMED(
+            z2, "iree_hal_hip_device_execute_now_hip_graph_launch");
+        hipGraphExec_t exec =
+            iree_hal_hip_graph_command_buffer_handle(command_buffer);
+        status = IREE_HIP_CALL_TO_STATUS(
+            data->device->hip_symbols,
+            hipGraphLaunch(
+                exec, device->devices[device_ordinal].hip_dispatch_stream));
+        IREE_TRACE_ZONE_END(z2);
+      }
+    } else if (command_buffer) {
+      status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                                "unsupported command buffer type");
+    }
+  }
+
+  IREE_TRACE_ZONE_END(z1);
+
+  // Retain the symbols locally as the cleanup callback may free |data|
+  // off-thread before this function returns.
+  const iree_hal_hip_dynamic_symbols_t* symbols = data->device->hip_symbols;
+
+  if (iree_status_is_ok(status)) {
+    status = iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup(
+        data->device, data->device->cleanup_thread, data->signal_semaphore_list,
+        device_ordinal, iree_hal_hip_device_complete_submission, data);
+  }
+
+  if (!iree_status_is_ok(status)) {
+    for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) {
+      iree_hal_semaphore_fail(data->signal_semaphore_list.semaphores[i],
+                              iree_status_clone(status));
+    }
+    iree_hal_hip_device_destroy_callback_data(data);
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_status_join(
+      status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL)));
+}
+
+static iree_status_t iree_hal_hip_device_semaphore_submit_callback(
+    void* user_context, iree_hal_semaphore_t* semaphore, iree_status_t status) {
+  iree_hal_hip_device_semaphore_submit_callback_data_t* data =
+      (iree_hal_hip_device_semaphore_submit_callback_data_t*)user_context;
+
+  if (!iree_status_is_ok(status)) {
+    iree_slim_mutex_lock(&data->status_mutex);
+    data->status = iree_status_join(data->status, status);
+    iree_slim_mutex_unlock(&data->status_mutex);
+  }
+  if (iree_atomic_fetch_sub(&data->wait_semaphore_count, 1,
+                            iree_memory_order_acq_rel) != 1) {
+    return iree_ok_status();
+  }
+
+  int device_ordinal = iree_math_count_trailing_zeros_u64(data->queue_affinity);
+
+  // Now the actual submit happens, as all semaphores have been satisfied
+  // (by satisfied here, we specifically mean that the semaphore has been
+  // scheduled, not necessarily completed).
+  return iree_hal_hip_dispatch_thread_add_dispatch(
+      data->device->devices[device_ordinal].dispatch_thread,
+      &iree_hal_hip_device_execute_now, data);
+}
 
 static iree_status_t iree_hal_hip_device_queue_execute(
@@ -1206,20 +1924,51 @@
     const iree_hal_semaphore_list_t signal_semaphore_list,
     iree_hal_command_buffer_t*
command_buffer, iree_hal_buffer_binding_table_t binding_table) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); + if (queue_affinity == IREE_HAL_QUEUE_AFFINITY_ANY) { + queue_affinity = 0x1; + } + + uint64_t queue_affinity_mask = + ((iree_hal_queue_affinity_t)1 << device->device_count); + queue_affinity_mask = queue_affinity_mask | (queue_affinity_mask - 1); + queue_affinity &= queue_affinity_mask; + + int device_ordinal = iree_math_count_trailing_zeros_u64(queue_affinity); + queue_affinity = (uint64_t)1 << device_ordinal; + + iree_hal_hip_device_semaphore_submit_callback_data_t* callback_data = NULL; + iree_status_t status = iree_ok_status(); + status = iree_hal_hip_device_make_callback_data( + device, device->host_allocator, &device->block_pool, queue_affinity, + wait_semaphore_list, signal_semaphore_list, command_buffer, binding_table, + &callback_data); + + if (iree_status_is_ok(status)) { + if (wait_semaphore_list.count == 0) { + status = iree_hal_hip_dispatch_thread_add_dispatch( + device->devices[device_ordinal].dispatch_thread, + &iree_hal_hip_device_execute_now, callback_data); + IREE_TRACE_ZONE_END(z0); + return status; + } + } - iree_status_t status = iree_hal_deferred_work_queue_enqueue( - device->work_queue, iree_hal_hip_device_collect_tracing_context, - device->tracing_context, wait_semaphore_list, signal_semaphore_list, - command_buffer ? 1 : 0, command_buffer ? &command_buffer : NULL, - &binding_table); if (iree_status_is_ok(status)) { - // Try to advance the deferred work queue. - status = iree_hal_deferred_work_queue_issue(device->work_queue); + for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { + status = iree_status_join( + status, + iree_hal_hip_semaphore_notify_work( + wait_semaphore_list.semaphores[i], + wait_semaphore_list.payload_values[i], + device->devices[device_ordinal].device_event_pool, + &iree_hal_hip_device_semaphore_submit_callback, callback_data)); + } + } else { + iree_hal_hip_device_destroy_callback_data(callback_data); } IREE_TRACE_ZONE_END(z0); @@ -1228,23 +1977,15 @@ static iree_status_t iree_hal_hip_device_queue_execute( static iree_status_t iree_hal_hip_device_queue_flush( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) { - iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_TRACE_ZONE_BEGIN(z0); - // Try to advance the deferred work queue. 
- iree_status_t status = iree_hal_deferred_work_queue_issue(device->work_queue); - IREE_TRACE_ZONE_END(z0); - return status; + return iree_ok_status(); } static iree_status_t iree_hal_hip_device_wait_semaphores( iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, const iree_hal_semaphore_list_t semaphore_list, iree_timeout_t timeout) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(device->hip_symbols, device->hip_context)); - return iree_hal_hip_semaphore_multi_wait(semaphore_list, wait_mode, timeout, - &device->block_pool); + device->host_allocator); } static iree_status_t iree_hal_hip_device_profiling_begin( @@ -1298,41 +2039,6 @@ static const iree_hal_device_vtable_t iree_hal_hip_device_vtable = { .profiling_end = iree_hal_hip_device_profiling_end, }; -static const iree_hal_deferred_work_queue_device_interface_vtable_t - iree_hal_hip_deferred_work_queue_device_interface_vtable = { - .destroy = iree_hal_hip_deferred_work_queue_device_interface_destroy, - .bind_to_thread = - iree_hal_hip_deferred_work_queue_device_interface_bind_to_thread, - .wait_native_event = - iree_hal_hip_deferred_work_queue_device_interface_wait_native_event, - .create_native_event = - iree_hal_hip_deferred_work_queue_device_interface_create_native_event, - .record_native_event = - iree_hal_hip_deferred_work_queue_device_interface_record_native_event, - .synchronize_native_event = - iree_hal_hip_deferred_work_queue_device_interface_synchronize_native_event, - .destroy_native_event = - iree_hal_hip_deferred_work_queue_device_interface_destroy_native_event, - .semaphore_acquire_timepoint_device_signal_native_event = - iree_hal_hip_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event, - .acquire_host_wait_event = - iree_hal_hip_deferred_work_queue_device_interface_acquire_host_wait_event, - .device_wait_on_host_event = - iree_hal_hip_deferred_work_queue_device_interface_device_wait_on_host_event, - .release_wait_event = - iree_hal_hip_deferred_work_queue_device_interface_release_wait_event, - .native_event_from_wait_event = - iree_hal_hip_deferred_work_queue_device_interface_native_event_from_wait_event, - .create_stream_command_buffer = - iree_hal_hip_deferred_work_queue_device_interface_create_stream_command_buffer, - .submit_command_buffer = - iree_hal_hip_deferred_work_queue_device_interface_submit_command_buffer, - .async_alloc = - iree_hal_hip_deferred_work_queue_device_interface_async_alloc, - .async_dealloc = - iree_hal_hip_deferred_work_queue_device_interface_async_dealloc, -}; - static const iree_hal_stream_tracing_device_interface_vtable_t iree_hal_hip_tracing_device_interface_vtable_t = { .destroy = iree_hal_hip_tracing_device_interface_destroy, diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.h b/runtime/src/iree/hal/drivers/hip/hip_device.h index 044f4d53f844..f8fc144d90f5 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.h +++ b/runtime/src/iree/hal/drivers/hip/hip_device.h @@ -15,33 +15,16 @@ #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/rccl_dynamic_symbols.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// Creates a device that owns and manages its own hipCtx_t. +// Creates a device group from a set of hip devices that manage their own +// hipCtxs. 
iree_status_t iree_hal_hip_device_create( iree_hal_driver_t* driver, iree_string_view_t identifier, const iree_hal_hip_device_params_t* params, const iree_hal_hip_dynamic_symbols_t* symbols, - const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, hipDevice_t device, + const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, + iree_host_size_t device_count, hipDevice_t* devices, iree_allocator_t host_allocator, iree_hal_device_t** out_device); -// Creates a HIP stream-backed command buffer using resources from the -// given |base_device|. -iree_status_t iree_hal_hip_device_create_stream_command_buffer( - iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories, - iree_host_size_t binding_capacity, - iree_hal_command_buffer_t** out_command_buffer); - -// Returns the HIP context bound to the given |device| if it is a HIP device -// and otherwise returns NULL. -// -// WARNING: this API is unsafe and unstable. HAL devices may have any number of -// contexts and the context may be in use on other threads. -hipCtx_t iree_hal_hip_device_context(iree_hal_device_t* device); - // Returns the dynamic symbol table from the |device| if it is a HIP device // and otherwise returns NULL. // @@ -66,8 +49,4 @@ static inline hipDeviceptr_t iree_hal_hip_device_size_to_hip_device_prt( return (hipDeviceptr_t)p; } -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_DEVICE_H_ diff --git a/runtime/src/iree/hal/drivers/hip/hip_driver.c b/runtime/src/iree/hal/drivers/hip/hip_driver.c index 4600d48b086d..a8c6cff5c108 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_driver.c +++ b/runtime/src/iree/hal/drivers/hip/hip_driver.c @@ -60,6 +60,15 @@ IREE_API_EXPORT void iree_hal_hip_driver_options_initialize( out_options->default_device_index = 0; } +// Initializes the HIP system. +static iree_status_t iree_hal_hip_init(iree_hal_hip_driver_t* driver) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = + IREE_HIP_CALL_TO_STATUS(&driver->hip_symbols, hipInit(0), "hipInit"); + IREE_TRACE_ZONE_END(z0); + return status; +} + static iree_status_t iree_hal_hip_driver_create_internal( iree_string_view_t identifier, const iree_hal_hip_driver_options_t* options, const iree_hal_hip_device_params_t* device_params, @@ -91,6 +100,10 @@ static iree_status_t iree_hal_hip_driver_create_internal( memcpy(&driver->device_params, device_params, sizeof(driver->device_params)); + if (iree_status_is_ok(status)) { + status = iree_hal_hip_init(driver); + } + if (iree_status_is_ok(status)) { *out_driver = (iree_hal_driver_t*)driver; } else { @@ -129,15 +142,6 @@ static void iree_hal_hip_driver_destroy(iree_hal_driver_t* base_driver) { IREE_TRACE_ZONE_END(z0); } -// Initializes the HIP system. -static iree_status_t iree_hal_hip_init(iree_hal_hip_driver_t* driver) { - IREE_TRACE_ZONE_BEGIN(z0); - iree_status_t status = - IREE_HIP_RESULT_TO_STATUS(&driver->hip_symbols, hipInit(0), "hipInit"); - IREE_TRACE_ZONE_END(z0); - return status; -} - // Populates device information from the given HIP physical device handle. // |out_device_info| must point to valid memory and additional data will be // appended to |buffer_ptr| and the new pointer is returned. @@ -198,9 +202,6 @@ static iree_status_t iree_hal_hip_driver_query_available_devices( iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver); IREE_TRACE_ZONE_BEGIN(z0); - // Ensure HIP is initialized before querying it. 
- IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_hip_init(driver)); - // Query the number of available HIP devices. int device_count = 0; IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(z0, &driver->hip_symbols, @@ -218,17 +219,17 @@ static iree_status_t iree_hal_hip_driver_query_available_devices( int valid_device_count = 0; if (iree_status_is_ok(status)) { uint8_t* buffer_ptr = - (uint8_t*)device_infos + device_count * sizeof(iree_hal_device_info_t); + (uint8_t*)device_infos + device_count * sizeof(*device_infos); for (iree_host_size_t i = 0; i < device_count; ++i) { hipDevice_t device = 0; - status = IREE_HIP_RESULT_TO_STATUS( + status = IREE_HIP_CALL_TO_STATUS( &driver->hip_symbols, hipDeviceGet(&device, i), "hipDeviceGet"); if (!iree_status_is_ok(status)) break; status = iree_hal_hip_populate_device_info( device, &driver->hip_symbols, buffer_ptr, &buffer_ptr, &device_infos[valid_device_count]); if (!iree_status_is_ok(status)) break; - valid_device_count++; + ++valid_device_count; } } if (iree_status_is_ok(status)) { @@ -356,9 +357,6 @@ static iree_status_t iree_hal_hip_driver_create_device_by_id( iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver); IREE_TRACE_ZONE_BEGIN(z0); - // Ensure HIP is initialized before querying it. - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_hip_init(driver)); - // Use either the specified device (enumerated earlier) or whatever default // one was specified when the driver was created. hipDevice_t device = 0; @@ -376,21 +374,30 @@ static iree_status_t iree_hal_hip_driver_create_device_by_id( // Attempt to create the device now. iree_status_t status = iree_hal_hip_device_create( base_driver, device_name, &driver->device_params, &driver->hip_symbols, - &driver->nccl_symbols, device, host_allocator, out_device); + &driver->nccl_symbols, 1, &device, host_allocator, out_device); IREE_TRACE_ZONE_END(z0); return status; } -static iree_status_t iree_hal_hip_driver_create_device_by_uuid( - iree_hal_driver_t* base_driver, iree_string_view_t driver_name, - const hipUUID* device_uuid, iree_host_size_t param_count, - const iree_string_pair_t* params, iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { +static iree_status_t iree_hal_hip_driver_get_device_id_by_uuid( + iree_hal_driver_t* base_driver, iree_string_view_t device_path, + iree_hal_device_id_t* out_id) { iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver); - // Ensure HIP is initialized before querying it. - IREE_RETURN_IF_ERROR(iree_hal_hip_init(driver)); + if (!iree_string_view_consume_prefix(&device_path, IREE_SV("GPU-"))) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "device path is not a UUID"); + } + // UUID as returned by hipDeviceGetUuid. + hipUUID device_uuid; + if (!iree_string_view_parse_hex_bytes(device_path, + IREE_ARRAYSIZE(device_uuid.bytes), + (uint8_t*)device_uuid.bytes)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid GPU UUID: '%.*s'", (int)device_path.size, + device_path.data); + } // HIP doesn't have an API to do this so we need to scan all devices to // find the one with the matching UUID. 
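For reference, the scan loop elided between these hunks looks roughly like
this (a sketch reconstructed from the surrounding diff context, not the
verbatim source):

  int device_count = 0;
  IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols,
                           hipGetDeviceCount(&device_count),
                           "hipGetDeviceCount");
  hipDevice_t device = 0;
  bool found_device = false;
  for (int i = 0; i < device_count; ++i) {
    IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols, hipDeviceGet(&device, i),
                             "hipDeviceGet");
    hipUUID query_uuid;
    IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols,
                             hipDeviceGetUuid(&query_uuid, device),
                             "hipDeviceGetUuid");
    // Compare the queried UUID against the one parsed from the device path.
    if (memcmp(&device_uuid.bytes[0], &query_uuid.bytes[0],
               sizeof(device_uuid)) == 0) {
      found_device = true;
      break;
    }
  }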
@@ -407,12 +414,13 @@ static iree_status_t iree_hal_hip_driver_create_device_by_uuid( IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols, hipDeviceGetUuid(&query_uuid, device), "hipDeviceGetUuid"); - if (memcmp(&device_uuid->bytes[0], &query_uuid.bytes[0], + if (memcmp(&device_uuid.bytes[0], &query_uuid.bytes[0], sizeof(device_uuid)) == 0) { found_device = true; break; } } + if (!found_device) { return iree_make_status( IREE_STATUS_NOT_FOUND, @@ -423,32 +431,29 @@ static iree_status_t iree_hal_hip_driver_create_device_by_uuid( "%02x%02x-" "%02x%02x%02x%02x%02x%02x" " not found", - (uint8_t)device_uuid->bytes[0], (uint8_t)device_uuid->bytes[1], - (uint8_t)device_uuid->bytes[2], (uint8_t)device_uuid->bytes[3], - (uint8_t)device_uuid->bytes[4], (uint8_t)device_uuid->bytes[5], - (uint8_t)device_uuid->bytes[6], (uint8_t)device_uuid->bytes[7], - (uint8_t)device_uuid->bytes[8], (uint8_t)device_uuid->bytes[9], - (uint8_t)device_uuid->bytes[10], (uint8_t)device_uuid->bytes[11], - (uint8_t)device_uuid->bytes[12], (uint8_t)device_uuid->bytes[13], - (uint8_t)device_uuid->bytes[14], (uint8_t)device_uuid->bytes[15]); + (uint8_t)device_uuid.bytes[0], (uint8_t)device_uuid.bytes[1], + (uint8_t)device_uuid.bytes[2], (uint8_t)device_uuid.bytes[3], + (uint8_t)device_uuid.bytes[4], (uint8_t)device_uuid.bytes[5], + (uint8_t)device_uuid.bytes[6], (uint8_t)device_uuid.bytes[7], + (uint8_t)device_uuid.bytes[8], (uint8_t)device_uuid.bytes[9], + (uint8_t)device_uuid.bytes[10], (uint8_t)device_uuid.bytes[11], + (uint8_t)device_uuid.bytes[12], (uint8_t)device_uuid.bytes[13], + (uint8_t)device_uuid.bytes[14], (uint8_t)device_uuid.bytes[15]); } - - iree_status_t status = iree_hal_hip_driver_create_device_by_id( - base_driver, IREE_HIPDEVICE_TO_DEVICE_ID(device), param_count, params, - host_allocator, out_device); - - return status; + *out_id = IREE_HIPDEVICE_TO_DEVICE_ID(device); + return iree_ok_status(); } -static iree_status_t iree_hal_hip_driver_create_device_by_index( - iree_hal_driver_t* base_driver, iree_string_view_t driver_name, - int device_index, iree_host_size_t param_count, - const iree_string_pair_t* params, iree_allocator_t host_allocator, - iree_hal_device_t** out_device) { +static iree_status_t iree_hal_hip_driver_get_device_id_by_index( + iree_hal_driver_t* base_driver, iree_string_view_t device_path, + iree_hal_device_id_t* out_id) { iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver); - // Ensure HIP is initialized before querying it. - IREE_RETURN_IF_ERROR(iree_hal_hip_init(driver)); + int32_t device_index = 0; + if (!iree_string_view_atoi_int32(device_path, &device_index)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "device path is not an index"); + } // Query the number of available HIP devices. 
int device_count = 0;
@@ -465,10 +470,122 @@
   IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols,
                            hipDeviceGet(&device, device_index), "hipDeviceGet");
 
-  iree_status_t status = iree_hal_hip_driver_create_device_by_id(
-      base_driver, IREE_HIPDEVICE_TO_DEVICE_ID(device), param_count, params,
-      host_allocator, out_device);
+  *out_id = IREE_HIPDEVICE_TO_DEVICE_ID(device);
+  return iree_ok_status();
+}
+
+static bool iree_hal_hip_driver_is_path_uuid(iree_string_view_t device_path) {
+  return iree_string_view_starts_with(device_path, IREE_SV("GPU-"));
+}
+
+static bool iree_hal_hip_driver_is_path_index(iree_string_view_t device_path) {
+  int32_t unused_device_index = 0;
+  return iree_string_view_atoi_int32(device_path, &unused_device_index);
+}
+
+static iree_status_t iree_hal_hip_driver_get_device_id_by_path(
+    iree_hal_driver_t* base_driver, iree_string_view_t device_path,
+    iree_hal_device_id_t* out_id) {
+  if (iree_hal_hip_driver_is_path_uuid(device_path)) {
+    return iree_hal_hip_driver_get_device_id_by_uuid(base_driver, device_path,
+                                                     out_id);
+  }
+  if (iree_hal_hip_driver_is_path_index(device_path)) {
+    return iree_hal_hip_driver_get_device_id_by_index(base_driver, device_path,
+                                                      out_id);
+  }
+
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unsupported device path");
+}
+
+static iree_status_t iree_hal_hip_driver_create_multi_queue_device_by_ids(
+    iree_hal_driver_t* base_driver, iree_hal_device_id_t* device_ids,
+    iree_host_size_t device_count, iree_host_size_t param_count,
+    const iree_string_pair_t* params, iree_allocator_t host_allocator,
+    iree_hal_device_t** out_device) {
+  IREE_ASSERT_ARGUMENT(base_driver);
+  IREE_ASSERT_ARGUMENT(out_device);
+  iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  if (device_count > IREE_HAL_MAX_QUEUES) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "too many physical devices allocated for this logical device");
+  }
+
+  hipDevice_t* devices =
+      (hipDevice_t*)iree_alloca(sizeof(*devices) * device_count);
+
+  for (iree_host_size_t i = 0; i < device_count; ++i) {
+    if (device_ids[i] == IREE_HAL_DEVICE_ID_DEFAULT) {
+      IREE_TRACE_ZONE_END(z0);
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "invalid to create a device group with the default device id");
+    } else {
+      devices[i] = IREE_DEVICE_ID_TO_HIPDEVICE(device_ids[i]);
+    }
+  }
+
+  iree_string_view_t device_name = iree_make_cstring_view("hip");
+
+  // Attempt to create the device now.
+ iree_status_t status = iree_hal_hip_device_create( + base_driver, device_name, &driver->device_params, &driver->hip_symbols, + &driver->nccl_symbols, device_count, devices, host_allocator, out_device); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static iree_status_t iree_hal_hip_driver_create_multi_queue_device_by_path( + iree_hal_driver_t* base_driver, iree_string_view_t driver_name, + iree_string_view_t device_path, iree_host_size_t param_count, + const iree_string_pair_t* params, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(base_driver); + IREE_ASSERT_ARGUMENT(out_device); + + iree_host_size_t multi_count = 0; + for (iree_host_size_t offs = 0; offs < device_path.size;) { + iree_host_size_t comma_pos = + iree_string_view_find_char(device_path, ',', offs); + if (comma_pos == IREE_STRING_VIEW_NPOS) { + comma_pos = device_path.size; + } + offs = comma_pos + 1; + ++multi_count; + } + + iree_hal_device_id_t* device_ids = NULL; + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + host_allocator, sizeof(*device_ids) * multi_count, (void**)&device_ids)); + + iree_host_size_t device_index = 0; + for (iree_host_size_t offset = 0; offset < device_path.size;) { + iree_host_size_t comma_pos = + iree_string_view_find_char(device_path, ',', offset); + if (comma_pos == IREE_STRING_VIEW_NPOS) { + comma_pos = device_path.size; + } + iree_string_view_t this_device_path = + iree_string_view_substr(device_path, offset, comma_pos - offset); + iree_status_t status = iree_hal_hip_driver_get_device_id_by_path( + base_driver, this_device_path, &device_ids[device_index]); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { + iree_allocator_free(host_allocator, device_ids); + return status; + } + offset = comma_pos + 1; + ++device_index; + } + + iree_status_t status = iree_hal_hip_driver_create_multi_queue_device_by_ids( + base_driver, device_ids, device_index, param_count, params, + host_allocator, out_device); + iree_allocator_free(host_allocator, device_ids); return status; } @@ -486,30 +603,18 @@ static iree_status_t iree_hal_hip_driver_create_device_by_path( host_allocator, out_device); } - if (iree_string_view_consume_prefix(&device_path, IREE_SV("GPU-"))) { - // UUID as returned by hipDeviceGetUuid. - hipUUID device_uuid; - if (!iree_string_view_parse_hex_bytes(device_path, - IREE_ARRAYSIZE(device_uuid.bytes), - (uint8_t*)device_uuid.bytes)) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "invalid GPU UUID: '%.*s'", (int)device_path.size, - device_path.data); - } - return iree_hal_hip_driver_create_device_by_uuid( - base_driver, driver_name, &device_uuid, param_count, params, - host_allocator, out_device); - } - - // Try to parse as a device index. 
- int device_index = 0; - if (iree_string_view_atoi_int32(device_path, &device_index)) { - return iree_hal_hip_driver_create_device_by_index( - base_driver, driver_name, device_index, param_count, params, + if (iree_string_view_find_char(device_path, ',', 0) != + IREE_STRING_VIEW_NPOS) { + return iree_hal_hip_driver_create_multi_queue_device_by_path( + base_driver, driver_name, device_path, param_count, params, host_allocator, out_device); } - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unsupported device path"); + iree_hal_device_id_t id; + IREE_RETURN_IF_ERROR( + iree_hal_hip_driver_get_device_id_by_path(base_driver, device_path, &id)); + return iree_hal_hip_driver_create_device_by_id( + base_driver, id, param_count, params, host_allocator, out_device); } static const iree_hal_driver_vtable_t iree_hal_hip_driver_vtable = { diff --git a/runtime/src/iree/hal/drivers/hip/hip_multi_queue_command_buffer.c b/runtime/src/iree/hal/drivers/hip/hip_multi_queue_command_buffer.c new file mode 100644 index 000000000000..5c002a19faf7 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/hip_multi_queue_command_buffer.c @@ -0,0 +1,356 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/drivers/hip/hip_multi_queue_command_buffer.h" + +#include "iree/base/internal/arena.h" +#include "iree/base/internal/math.h" +#include "iree/hal/drivers/hip/context_util.h" +#include "iree/hal/drivers/hip/status_util.h" +#include "iree/hal/utils/resource_set.h" + +//===----------------------------------------------------------------------===// +// iree_hal_hip_multi_queue_command_buffer_t implementation +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_hip_multi_queue_command_buffer_t { + iree_hal_command_buffer_t base; + iree_allocator_t host_allocator; + iree_host_size_t command_buffer_count; + iree_hal_hip_device_topology_t topology; + const iree_hal_hip_dynamic_symbols_t* hip_symbols; + iree_hal_command_buffer_t* child_buffers[]; +} iree_hal_hip_multi_queue_command_buffer_t; + +static const iree_hal_command_buffer_vtable_t + iree_hal_hip_multi_queue_command_buffer_vtable; + +static iree_hal_hip_multi_queue_command_buffer_t* +iree_hal_hip_multi_queue_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_hip_multi_queue_command_buffer_vtable); + return (iree_hal_hip_multi_queue_command_buffer_t*)base_value; +} + +IREE_API_EXPORT iree_status_t iree_hal_hip_multi_queue_command_buffer_create( + iree_host_size_t command_buffer_count, + iree_hal_command_buffer_t** in_command_buffers, + iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, + const iree_hal_hip_dynamic_symbols_t* hip_symbols, + iree_hal_hip_device_topology_t topology, iree_host_size_t binding_capacity, + iree_allocator_t host_allocator, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(out_command_buffer); + *out_command_buffer = NULL; + + if (iree_math_count_ones_u64(queue_affinity) != command_buffer_count) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "expected one command buffer per enabled queue"); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_hip_multi_queue_command_buffer_t* 
command_buffer = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_allocator_malloc(
+              host_allocator,
+              sizeof(*command_buffer) +
+                  command_buffer_count * sizeof(iree_hal_command_buffer_t*) +
+                  iree_hal_command_buffer_validation_state_size(
+                      mode, binding_capacity),
+              (void**)&command_buffer));
+  iree_hal_command_buffer_initialize(
+      device_allocator, mode, command_categories, queue_affinity,
+      binding_capacity,
+      (uint8_t*)command_buffer + sizeof(*command_buffer) +
+          command_buffer_count * sizeof(iree_hal_command_buffer_t*),
+      &iree_hal_hip_multi_queue_command_buffer_vtable, &command_buffer->base);
+  command_buffer->host_allocator = host_allocator;
+  command_buffer->command_buffer_count = command_buffer_count;
+  command_buffer->topology = topology;
+  command_buffer->hip_symbols = hip_symbols;
+  for (iree_host_size_t i = 0; i < command_buffer_count; ++i) {
+    command_buffer->child_buffers[i] = in_command_buffers[i];
+    iree_hal_resource_retain(command_buffer->child_buffers[i]);
+  }
+
+  *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+static void iree_hal_hip_multi_queue_command_buffer_destroy(
+    iree_hal_command_buffer_t* base_command_buffer) {
+  iree_hal_hip_multi_queue_command_buffer_t* command_buffer =
+      iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+  for (iree_host_size_t i = 0; i < command_buffer->command_buffer_count; ++i) {
+    iree_hal_resource_release(command_buffer->child_buffers[i]);
+  }
+  iree_allocator_free(command_buffer->host_allocator, command_buffer);
+  IREE_TRACE_ZONE_END(z0);
+}
+
+IREE_API_EXPORT bool iree_hal_hip_multi_queue_command_buffer_isa(
+    iree_hal_command_buffer_t* command_buffer) {
+  return iree_hal_resource_is(&command_buffer->resource,
+                              &iree_hal_hip_multi_queue_command_buffer_vtable);
+}
+
+IREE_API_EXPORT iree_status_t iree_hal_hip_multi_queue_command_buffer_get(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_queue_affinity_t queue_affinity,
+    iree_hal_command_buffer_t** out_command_buffer) {
+  *out_command_buffer = NULL;
+  if (iree_math_count_ones_u64(queue_affinity) != 1) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "exactly one queue must be specified");
+  }
+  iree_hal_hip_multi_queue_command_buffer_t* command_buffer =
+      iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer);
+  if (!(command_buffer->base.queue_affinity & queue_affinity)) {
+    return iree_make_status(IREE_STATUS_NOT_FOUND,
+                            "no command buffer for affinity %" PRIu64,
+                            queue_affinity);
+  }
+  int index = iree_math_count_ones_u64(command_buffer->base.queue_affinity &
+                                       (queue_affinity - 1));
+
+  *out_command_buffer = command_buffer->child_buffers[index];
+  return iree_ok_status();
+}
+
+// Invokes |command| once for each queue enabled in the command buffer's
+// queue affinity; within |command|, |command_buffer_index| selects the
+// matching child command buffer and |device_ordinal| names the physical
+// device whose HIP context is made current first.
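+//
+// For example, with queue_affinity 0b0101 |command| runs twice: once with
+// device_ordinal 0 / command_buffer_index 0 and once with device_ordinal 2 /
+// command_buffer_index 1.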
+#define CALL_COMMAND(status, command) \ + do { \ + iree_hal_queue_affinity_t queue_affinity = \ + command_buffer->base.queue_affinity; \ + int command_buffer_index = 0; \ + int device_ordinal = 0; \ + while (queue_affinity && IREE_LIKELY(iree_status_is_ok(status))) { \ + int count = iree_math_count_trailing_zeros_u64(queue_affinity); \ + device_ordinal += count; \ + status = iree_hal_hip_set_context( \ + command_buffer->hip_symbols, \ + command_buffer->topology.devices[device_ordinal].hip_context); \ + if (!iree_status_is_ok(status)) { \ + break; \ + } \ + status = command; \ + queue_affinity >>= (count + 1); \ + device_ordinal += 1; \ + ++command_buffer_index; \ + } \ + } while (false) + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, + iree_hal_command_buffer_begin( + command_buffer->child_buffers[command_buffer_index])); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, + iree_hal_command_buffer_end( + command_buffer->child_buffers[command_buffer_index])); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_hal_execution_barrier_flags_t flags, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_execution_barrier( + command_buffer->child_buffers[command_buffer_index], + source_stage_mask, target_stage_mask, flags, + memory_barrier_count, memory_barriers, + buffer_barrier_count, buffer_barriers)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_signal_event( + command_buffer->child_buffers[command_buffer_index], + event, source_stage_mask)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_reset_event( + command_buffer->child_buffers[command_buffer_index], + event, source_stage_mask)); + return status; +} + +static 
iree_status_t iree_hal_hip_multi_queue_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND( + status, + iree_hal_command_buffer_wait_events( + command_buffer->child_buffers[command_buffer_index], event_count, + events, source_stage_mask, target_stage_mask, memory_barrier_count, + memory_barriers, buffer_barrier_count, buffer_barriers)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_advise_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_ref_t buffer_ref, iree_hal_memory_advise_flags_t flags, + uint64_t arg0, uint64_t arg1) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_advise_buffer( + command_buffer->child_buffers[command_buffer_index], + buffer_ref, flags, arg0, arg1)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_ref_t target_ref, const void* pattern, + iree_host_size_t pattern_length, iree_hal_fill_flags_t flags) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_fill_buffer( + command_buffer->child_buffers[command_buffer_index], + target_ref, pattern, pattern_length, flags)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref, + iree_hal_update_flags_t flags) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_update_buffer( + command_buffer->child_buffers[command_buffer_index], + source_buffer, source_offset, target_ref, flags)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref, + iree_hal_copy_flags_t flags) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_copy_buffer( + command_buffer->child_buffers[command_buffer_index], + source_ref, target_ref, flags)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_collective( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel, + iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_ref_t send_ref, + 
iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, + iree_hal_command_buffer_collective( + command_buffer->child_buffers[command_buffer_index], channel, + op, param, send_ref, recv_ref, element_count)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_dispatch( + command_buffer->child_buffers[command_buffer_index], + executable, entry_point, workgroup_count, constants, + bindings, flags)); + return status; +} + +static iree_status_t iree_hal_hip_multi_queue_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_hip_multi_queue_command_buffer_t* command_buffer = + iree_hal_hip_multi_queue_command_buffer_cast(base_command_buffer); + iree_status_t status = iree_ok_status(); + CALL_COMMAND(status, iree_hal_command_buffer_dispatch_indirect( + command_buffer->child_buffers[command_buffer_index], + executable, entry_point, workgroups_ref, constants, + bindings, flags)); + return status; +} + +static const iree_hal_command_buffer_vtable_t + iree_hal_hip_multi_queue_command_buffer_vtable = { + .destroy = iree_hal_hip_multi_queue_command_buffer_destroy, + .begin = iree_hal_hip_multi_queue_command_buffer_begin, + .end = iree_hal_hip_multi_queue_command_buffer_end, + .execution_barrier = + iree_hal_hip_multi_queue_command_buffer_execution_barrier, + .signal_event = iree_hal_hip_multi_queue_command_buffer_signal_event, + .reset_event = iree_hal_hip_multi_queue_command_buffer_reset_event, + .wait_events = iree_hal_hip_multi_queue_command_buffer_wait_events, + .advise_buffer = iree_hal_hip_multi_queue_command_buffer_advise_buffer, + .fill_buffer = iree_hal_hip_multi_queue_command_buffer_fill_buffer, + .update_buffer = iree_hal_hip_multi_queue_command_buffer_update_buffer, + .copy_buffer = iree_hal_hip_multi_queue_command_buffer_copy_buffer, + .collective = iree_hal_hip_multi_queue_command_buffer_collective, + .dispatch = iree_hal_hip_multi_queue_command_buffer_dispatch, + .dispatch_indirect = + iree_hal_hip_multi_queue_command_buffer_dispatch_indirect, +}; + +#undef CALL_COMMAND diff --git a/runtime/src/iree/hal/drivers/hip/hip_multi_queue_command_buffer.h b/runtime/src/iree/hal/drivers/hip/hip_multi_queue_command_buffer.h new file mode 100644 index 000000000000..1dfe18f4da07 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/hip_multi_queue_command_buffer.h @@ -0,0 +1,48 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_HIP_MULTI_QUEUE_COMMAND_BUFFER_H_ +#define IREE_HAL_DRIVERS_HIP_MULTI_QUEUE_COMMAND_BUFFER_H_ + +#include "iree/base/api.h" +#include "iree/hal/command_buffer.h" +#include "iree/hal/drivers/hip/dynamic_symbols.h" +#include "iree/hal/drivers/hip/per_device_information.h" + +typedef struct iree_arena_block_pool_t iree_arena_block_pool_t; + +//===----------------------------------------------------------------------===// +// iree_hal_command_buffer_t multi-queue broadcast wrapper +//===----------------------------------------------------------------------===// + +// Creates a command buffer that records into multiple command buffers +// at a time based on the given queue affinity. +// +// After recording, the underlying command buffers can be retrieved with +// iree_hal_hip_multi_queue_command_buffer_get for submission. +IREE_API_EXPORT iree_status_t iree_hal_hip_multi_queue_command_buffer_create( + iree_host_size_t command_buffer_count, + iree_hal_command_buffer_t** in_command_buffers, + iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_queue_affinity_t queue_affinity, + const iree_hal_hip_dynamic_symbols_t* hip_symbols, + iree_hal_hip_device_topology_t topology, iree_host_size_t binding_capacity, + iree_allocator_t host_allocator, + iree_hal_command_buffer_t** out_command_buffer); + +// Returns true if |command_buffer| is a multi-queue command buffer. +IREE_API_EXPORT bool iree_hal_hip_multi_queue_command_buffer_isa( + iree_hal_command_buffer_t* command_buffer); + +// Returns the recorded command buffer for the given |queue_affinity|. +// Exactly one bit is expected to be set in |queue_affinity|.
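// A minimal usage sketch (variable names here are hypothetical, not part of
// this header): create the wrapper over one recorded child per enabled queue
// bit, record through it once, then fetch a specific child for submission:
//
//   iree_hal_command_buffer_t* multi_cb = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_hip_multi_queue_command_buffer_create(
//       /*command_buffer_count=*/2, child_buffers, device_allocator, mode,
//       command_categories, /*queue_affinity=*/0x3, hip_symbols, topology,
//       /*binding_capacity=*/0, host_allocator, &multi_cb));
//   // ... record commands; each call fans out to both children ...
//   iree_hal_command_buffer_t* queue0_cb = NULL;
//   IREE_RETURN_IF_ERROR(iree_hal_hip_multi_queue_command_buffer_get(
//       multi_cb, /*queue_affinity=*/1ull << 0, &queue0_cb));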
+IREE_API_EXPORT iree_status_t iree_hal_hip_multi_queue_command_buffer_get( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_queue_affinity_t queue_affinity, + iree_hal_command_buffer_t** out_command_buffer); + +#endif // IREE_HAL_DRIVERS_HIP_MULTI_QUEUE_COMMAND_BUFFER_H_ diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.c b/runtime/src/iree/hal/drivers/hip/memory_pools.c index c5a927fc0bbd..557ee860e0c2 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.c +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.c @@ -6,7 +6,6 @@ #include "iree/hal/drivers/hip/memory_pools.h" -#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/status_util.h" @@ -44,7 +43,7 @@ static iree_status_t iree_hal_hip_create_memory_pool( IREE_HIP_RETURN_IF_ERROR(hip_symbols, hipMemPoolCreate(&pool, &pool_props), "hipMemPoolCreate"); - iree_status_t status = IREE_HIP_RESULT_TO_STATUS( + iree_status_t status = IREE_HIP_CALL_TO_STATUS( hip_symbols, hipMemPoolSetAttribute(pool, hipMemPoolAttrReleaseThreshold, ¶ms.release_threshold), @@ -61,7 +60,6 @@ static iree_status_t iree_hal_hip_create_memory_pool( iree_status_t iree_hal_hip_memory_pools_initialize( iree_hal_device_t* parent_device, const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t hip_device, - hipCtx_t hip_context, const iree_hal_hip_memory_pooling_params_t* pooling_params, iree_allocator_t host_allocator, iree_hal_hip_memory_pools_t* IREE_RESTRICT out_pools) { @@ -70,14 +68,11 @@ iree_status_t iree_hal_hip_memory_pools_initialize( IREE_ASSERT_ARGUMENT(pooling_params); IREE_ASSERT_ARGUMENT(out_pools); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(hip_symbols, hip_context)); memset(out_pools, 0, sizeof(*out_pools)); out_pools->parent_device = parent_device; out_pools->hip_symbols = hip_symbols; out_pools->host_allocator = host_allocator; - out_pools->hip_context = hip_context; iree_status_t status = iree_ok_status(); @@ -99,8 +94,6 @@ iree_status_t iree_hal_hip_memory_pools_initialize( void iree_hal_hip_memory_pools_deinitialize( iree_hal_hip_memory_pools_t* pools) { IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR( - iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); if (pools->device_local) { IREE_HIP_IGNORE_ERROR(pools->hip_symbols, @@ -159,9 +152,6 @@ static void iree_hal_hip_memory_pool_track_free( void iree_hal_hip_memory_pools_merge_statistics( iree_hal_hip_memory_pools_t* pools, iree_hal_allocator_statistics_t* statistics) { - IREE_IGNORE_ERROR( - iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); - IREE_STATISTICS({ statistics->device_bytes_allocated = iree_atomic_load( &pools->statistics.device_bytes_allocated, iree_memory_order_relaxed); @@ -194,9 +184,6 @@ void iree_hal_hip_memory_pools_merge_statistics( iree_status_t iree_hal_hip_memory_pools_trim( iree_hal_hip_memory_pools_t* pools, const iree_hal_hip_memory_pooling_params_t* pooling_params) { - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); - IREE_HIP_RETURN_IF_ERROR( pools->hip_symbols, hipMemPoolTrimTo(pools->device_local, @@ -216,8 +203,6 @@ static void iree_hal_hip_async_buffer_release_callback( void* user_data, iree_hal_buffer_t* buffer) { iree_hal_hip_memory_pools_t* pools = (iree_hal_hip_memory_pools_t*)user_data; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR( - iree_hal_hip_set_context(pools->hip_symbols, 
pools->hip_context)); hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (device_ptr) { @@ -231,9 +216,6 @@ static void iree_hal_hip_async_buffer_release_callback( iree_status_t iree_hal_hip_memory_pools_allocate_pointer( iree_hal_hip_memory_pools_t* pools, iree_hal_buffer_t* buffer, hipStream_t stream, iree_device_size_t allocation_size) { - IREE_RETURN_IF_ERROR( - iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); - // TODO: more pools and better selection; this is coarsely deciding between // only device local (variables, constants, transients) and other (staging, // external) but could use more buffer properties (including usage/export @@ -246,7 +228,7 @@ iree_status_t iree_hal_hip_memory_pools_allocate_pointer( : pools->other; hipDeviceptr_t device_ptr = NULL; - IREE_RETURN_IF_ERROR(IREE_HIP_RESULT_TO_STATUS( + IREE_RETURN_IF_ERROR(IREE_HIP_CALL_TO_STATUS( pools->hip_symbols, hipMallocFromPoolAsync(&device_ptr, (size_t)allocation_size, memory_pool, stream), @@ -264,7 +246,6 @@ iree_status_t iree_hal_hip_memory_pools_prepare_buffer( iree_hal_buffer_t** IREE_RESTRICT out_buffer) { IREE_TRACE_ZONE_BEGIN(z0); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)allocation_size); - iree_hal_buffer_params_canonicalize(¶ms); // NOTE: we don't provide a device allocator because we didn't allocate from @@ -281,18 +262,19 @@ iree_status_t iree_hal_hip_memory_pools_prepare_buffer( .user_data = pools, }; iree_hal_buffer_t* buffer = NULL; + iree_status_t status = iree_hal_hip_buffer_wrap( placement, params.type, params.access, params.usage, allocation_size, - /*byte_offset=*/0, - /*byte_length=*/allocation_size, IREE_HAL_HIP_BUFFER_TYPE_ASYNC, - /*device_ptr*/ NULL, /*host_ptr=*/NULL, release_callback, - pools->host_allocator, &buffer); + /*byte_offset=*/0, /*byte_length=*/allocation_size, + IREE_HAL_HIP_BUFFER_TYPE_ASYNC, /*device_ptr*/ NULL, /*host_ptr=*/NULL, + release_callback, pools->host_allocator, &buffer); if (iree_status_is_ok(status)) { - // Update statistics (note that it may not yet be accurate). *out_buffer = buffer; - } else if (buffer) { - iree_hal_hip_buffer_set_allocation_empty(buffer); + } else { + if (buffer) { + iree_hal_hip_buffer_set_allocation_empty(buffer); + } iree_hal_buffer_release(buffer); } @@ -304,8 +286,6 @@ iree_status_t iree_hal_hip_memory_pools_deallocate( iree_hal_hip_memory_pools_t* pools, hipStream_t stream, iree_hal_buffer_t* buffer) { IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(pools->hip_symbols, pools->hip_context)); IREE_TRACE_ZONE_APPEND_VALUE_I64( z0, (int64_t)iree_hal_buffer_allocation_size(buffer)); @@ -318,8 +298,8 @@ iree_status_t iree_hal_hip_memory_pools_deallocate( // Try to schedule the buffer for freeing. 
hipDeviceptr_t device_ptr = iree_hal_hip_buffer_device_pointer(buffer); if (device_ptr) { - status = IREE_HIP_RESULT_TO_STATUS( - pools->hip_symbols, hipFreeAsync(device_ptr, stream), "hipFreeAsync"); + status = IREE_HIP_CALL_TO_STATUS(pools->hip_symbols, hipFree(device_ptr), + "hipFree"); } if (iree_status_is_ok(status)) { // Drop the release callback so that we don't try to double-free the diff --git a/runtime/src/iree/hal/drivers/hip/memory_pools.h b/runtime/src/iree/hal/drivers/hip/memory_pools.h index 7d66090e33f6..5b3914c5cf01 100644 --- a/runtime/src/iree/hal/drivers/hip/memory_pools.h +++ b/runtime/src/iree/hal/drivers/hip/memory_pools.h @@ -14,10 +14,6 @@ #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_headers.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - // WARNING: hipMemPool API's are marked as beta in HIP library meaning // that while the feature is complete, it is still open to changes and may // have outstanding issues. @@ -34,7 +30,6 @@ typedef struct iree_hal_hip_memory_pools_t { iree_hal_device_t* parent_device; const iree_hal_hip_dynamic_symbols_t* hip_symbols; - hipCtx_t hip_context; iree_allocator_t host_allocator; IREE_STATISTICS(struct { @@ -49,7 +44,6 @@ typedef struct iree_hal_hip_memory_pools_t { iree_status_t iree_hal_hip_memory_pools_initialize( iree_hal_device_t* parent_device, const iree_hal_hip_dynamic_symbols_t* hip_symbols, hipDevice_t hip_device, - hipCtx_t hip_context, const iree_hal_hip_memory_pooling_params_t* pooling_params, iree_allocator_t host_allocator, iree_hal_hip_memory_pools_t* IREE_RESTRICT out_pools); @@ -87,8 +81,4 @@ iree_status_t iree_hal_hip_memory_pools_deallocate( iree_hal_hip_memory_pools_t* pools, hipStream_t stream, iree_hal_buffer_t* buffer); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_MEMORY_POOLS_H_ diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.c b/runtime/src/iree/hal/drivers/hip/native_executable.c index e37aba578f02..902fb8374116 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.c +++ b/runtime/src/iree/hal/drivers/hip/native_executable.c @@ -9,6 +9,7 @@ #include #include "iree/base/api.h" +#include "iree/base/internal/math.h" #include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/status_util.h" @@ -21,6 +22,16 @@ #include "iree/schemas/hip_executable_def_reader.h" #include "iree/schemas/hip_executable_def_verifier.h" +typedef struct iree_hal_hip_native_executable_per_device_data_t { + // Loaded HIP modules. + iree_host_size_t module_count; + hipModule_t* modules; + + // Exported kernels referencing the loaded modules. + iree_host_size_t export_count; + iree_hal_hip_kernel_params_t exports[]; +} iree_hal_hip_native_executable_per_device_data_t; + typedef struct iree_hal_hip_native_executable_t { // Abstract resource used for injecting reference counting and vtable; // must be at offset 0. @@ -29,13 +40,8 @@ typedef struct iree_hal_hip_native_executable_t { const iree_hal_hip_dynamic_symbols_t* symbols; - // Loaded HIP modules. - iree_host_size_t module_count; - hipModule_t* modules; - - // Exported kernels referencing the loaded modules. 
- iree_host_size_t export_count; - iree_hal_hip_kernel_params_t exports[]; + iree_host_size_t num_devices; + iree_hal_hip_native_executable_per_device_data_t* per_device_data[]; } iree_hal_hip_native_executable_t; static const iree_hal_executable_vtable_t iree_hal_hip_native_executable_vtable; @@ -207,22 +213,27 @@ static iree_status_t iree_hal_hip_native_executable_flatbuffer_verify( } iree_status_t iree_hal_hip_native_executable_create( - const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - hipCtx_t context, const iree_hal_executable_params_t* executable_params, + const iree_hal_hip_dynamic_symbols_t* symbols, + iree_hal_hip_device_topology_t topology, + const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(executable_params); IREE_ASSERT_ARGUMENT(out_executable); + if (topology.count < 1) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "at least one device is required but none were provided"); + } IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, - iree_hal_hip_set_context(symbols, context)); *out_executable = NULL; // TODO: move to the executable cache to avoid repeated queries. iree_hal_hip_limits_t limits = {0}; IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_query_limits(symbols, device, &limits)); + z0, iree_hal_hip_query_limits(symbols, topology.devices[0].hip_device, + &limits)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_native_executable_flatbuffer_verify( @@ -255,9 +266,17 @@ iree_status_t iree_hal_hip_native_executable_create( // Allocate storage for the executable and its associated data structures. iree_hal_hip_native_executable_t* executable = NULL; + iree_host_size_t native_executable_device_info_size = + sizeof(*executable->per_device_data[0]) + + module_count * sizeof(executable->per_device_data[0]->modules[0]) + + export_count * sizeof(executable->per_device_data[0]->exports[0]) + + total_export_info_length; + native_executable_device_info_size = + iree_host_align(native_executable_device_info_size, iree_max_align_t); const iree_host_size_t total_size = - sizeof(*executable) + module_count * sizeof(executable->modules[0]) + - export_count * sizeof(executable->exports[0]) + total_export_info_length; + sizeof(*executable) + + topology.count * sizeof(executable->per_device_data[0]) + + topology.count * native_executable_device_info_size; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, total_size, (void**)&executable)); @@ -265,127 +284,164 @@ iree_status_t iree_hal_hip_native_executable_create( &executable->resource); executable->host_allocator = host_allocator; executable->symbols = symbols; - executable->module_count = module_count; - executable->modules = - (hipModule_t*)((uint8_t*)executable + sizeof(*executable) + - export_count * sizeof(executable->exports[0])); - executable->export_count = export_count; - IREE_TRACE(uint8_t* export_info_ptr = - ((uint8_t*)executable->modules + - module_count * sizeof(executable->modules[0]))); + executable->num_devices = topology.count; + const iree_host_size_t per_device_data_size = + topology.count * sizeof(executable->per_device_data[0]); + const uint8_t* per_device_data_location = + (uint8_t*)executable + sizeof(*executable); + + for (iree_host_size_t i = 0; i < topology.count; ++i) { + const iree_host_size_t native_executable_device_info_size_offset = + (i * native_executable_device_info_size); + + 
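// To make the layout concrete: the executable is a single allocation of the
// form [header][topology.count pointer slots][topology.count per-device
// blocks], where each per-device block packs its module table, export table,
// and (in tracing builds) export debug info, padded out to
// native_executable_device_info_size. Device i's block therefore lives at
// per_device_data_location + per_device_data_size +
// i * native_executable_device_info_size, which is the address computed next.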
executable->per_device_data[i] = + (iree_hal_hip_native_executable_per_device_data_t*)(per_device_data_location + + per_device_data_size + + native_executable_device_info_size_offset); + } // Publish any embedded source files to the tracing infrastructure. iree_hal_debug_publish_source_files( iree_hal_hip_ExecutableDef_source_files_get(executable_def)); - // Load each module first so that exports can reference them. iree_status_t status = iree_ok_status(); - for (iree_host_size_t i = 0; i < module_count; ++i) { - iree_hal_hip_ModuleDef_table_t module_def = - iree_hal_hip_ModuleDef_vec_at(modules_vec, i); + for (iree_host_size_t j = 0; j < topology.count && iree_status_is_ok(status); + ++j) { + IREE_RETURN_IF_ERROR(IREE_HIP_CALL_TO_STATUS( + symbols, hipCtxPushCurrent(topology.devices[j].hip_context))); - // WARNING: HIP doesn't take an expected length here so we can't bound it. - // It's likely that users could craft inputs that read beyond the extents of - // the embedded binary. - flatbuffers_string_t hsaco_image = - iree_hal_hip_ModuleDef_hsaco_image_get(module_def); - - // TODO: pass hipJitOption values to get log info and other info back. - // We pass the error buffer today but could use the info log to diagnose - // performance warnings. - char error_log[8192] = {0}; - hipJitOption jit_options[] = { - hipJitOptionErrorLogBuffer, - hipJitOptionErrorLogBufferSizeBytes, - }; - void* jit_option_values[] = { - (void*)error_log, - (void*)(uint32_t)sizeof(error_log), - }; - hipModule_t module = NULL; - status = IREE_HIP_RESULT_TO_STATUS( - symbols, - hipModuleLoadDataEx(&module, hsaco_image, IREE_ARRAYSIZE(jit_options), - jit_options, jit_option_values), - "hipModuleLoadDataEx"); if (!iree_status_is_ok(status)) { - status = iree_status_annotate( - status, - IREE_SV("mismatched target chip? missing/wrong bitcode directory?")); - if (strlen(error_log) > 0) { - status = - iree_status_annotate(status, iree_make_cstring_view(error_log)); - } + IREE_RETURN_IF_ERROR( + IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL))); break; } + iree_hal_hip_native_executable_per_device_data_t* per_device_data = + executable->per_device_data[j]; + + per_device_data->module_count = module_count; + per_device_data->modules = + (hipModule_t*)((uint8_t*)per_device_data + sizeof(*per_device_data) + + (export_count * sizeof(per_device_data->exports[0]))); + per_device_data->export_count = export_count; + IREE_TRACE(uint8_t* export_info_ptr = + ((uint8_t*)per_device_data->modules + + module_count * sizeof(per_device_data->modules[0]))); + + // Load each module first so that exports can reference them. + for (iree_host_size_t i = 0; i < module_count; ++i) { + iree_hal_hip_ModuleDef_table_t module_def = + iree_hal_hip_ModuleDef_vec_at(modules_vec, i); + + // WARNING: HIP doesn't take an expected length here so we can't bound it. + // It's likely that users could craft inputs that read beyond the extents + // of the embedded binary. + flatbuffers_string_t hsaco_image = + iree_hal_hip_ModuleDef_hsaco_image_get(module_def); + + // TODO: pass hipJitOption values to get log info and other info back. + // We pass the error buffer today but could use the info log to diagnose + // performance warnings. 
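// (One possible extension, not something this patch does:
// hipJitOptionInfoLogBuffer and hipJitOptionInfoLogBufferSizeBytes could be
// appended alongside the error-log pair below to capture that info log.)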
+ char error_log[8192] = {0}; + hipJitOption jit_options[] = { + hipJitOptionErrorLogBuffer, + hipJitOptionErrorLogBufferSizeBytes, + }; + void* jit_option_values[] = { + (void*)error_log, + (void*)(uint32_t)sizeof(error_log), + }; + hipModule_t module = NULL; + status = IREE_HIP_CALL_TO_STATUS( + symbols, + hipModuleLoadDataEx(&module, hsaco_image, IREE_ARRAYSIZE(jit_options), + jit_options, jit_option_values), + "hipModuleLoadDataEx"); + if (!iree_status_is_ok(status)) { + status = iree_status_annotate( + status, + IREE_SV( + "mismatched target chip? missing/wrong bitcode directory?")); + if (strlen(error_log) > 0) { + status = + iree_status_annotate(status, iree_make_cstring_view(error_log)); + } + break; + } - executable->modules[i] = module; - } - - if (iree_status_is_ok(status)) { - for (iree_host_size_t i = 0; i < export_count; ++i) { - iree_hal_hip_ExportDef_table_t export_def = - iree_hal_hip_ExportDef_vec_at(exports_vec, i); + per_device_data->modules[i] = module; - // Lookup the function in the module; this should always succeed but - // we cannot trust that the input was generated by our compiler. - uint32_t module_ordinal = - iree_hal_hip_ExportDef_module_ordinal_get(export_def); - hipModule_t module = executable->modules[module_ordinal]; - flatbuffers_string_t kernel_name = - iree_hal_hip_ExportDef_kernel_name_get(export_def); - hipFunction_t function = NULL; - status = IREE_HIP_RESULT_TO_STATUS( - symbols, hipModuleGetFunction(&function, module, kernel_name), - "hipModuleGetFunction"); - if (!iree_status_is_ok(status)) break; - if (!function) { - status = iree_make_status(IREE_STATUS_NOT_FOUND, - "exports[%" PRIhsz - "] kernel `%s` not found in modules[%u]", - i, kernel_name, module_ordinal); + if (!iree_status_is_ok(status)) { break; } - - uint32_t block_shared_memory_size = - iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def); - status = IREE_HIP_RESULT_TO_STATUS( - symbols, - hipFuncSetAttribute( - function, - (hipFuncAttribute) - HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - block_shared_memory_size), - "hipFuncSetAttribute"); - if (!iree_status_is_ok(status)) break; - - // Package required parameters for kernel launches for each entry point. 
- iree_hal_hip_kernel_params_t* kernel_info = &executable->exports[i]; - kernel_info->function = function; - const iree_hal_hip_BlockDims_t* block_dims = - iree_hal_hip_ExportDef_block_dims_get(export_def); - kernel_info->block_dims[0] = block_dims->x; - kernel_info->block_dims[1] = block_dims->y; - kernel_info->block_dims[2] = block_dims->z; - kernel_info->block_shared_memory_size = - iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def); - kernel_info->constant_count = - iree_hal_hip_ExportDef_constant_count_get(export_def); - iree_hal_hip_BindingBits_vec_t binding_flags_vec = - iree_hal_hip_ExportDef_binding_flags_get(export_def); - kernel_info->binding_count = - iree_hal_hip_BindingBits_vec_len(binding_flags_vec); - - IREE_TRACE({ - iree_hal_debug_export_info_t* export_info = - (iree_hal_debug_export_info_t*)export_info_ptr; - export_info_ptr += iree_hal_debug_copy_export_info( - iree_hal_hip_ExportDef_debug_info_get(export_def), export_info); - kernel_info->debug_info.function_name = export_info->function_name; - kernel_info->debug_info.source_filename = export_info->source_filename; - kernel_info->debug_info.source_line = export_info->source_line; - }); + for (iree_host_size_t i = 0; i < export_count; ++i) { + iree_hal_hip_ExportDef_table_t export_def = + iree_hal_hip_ExportDef_vec_at(exports_vec, i); + + // Lookup the function in the module; this should always succeed but + // we cannot trust that the input was generated by our compiler. + uint32_t module_ordinal = + iree_hal_hip_ExportDef_module_ordinal_get(export_def); + hipModule_t module = per_device_data->modules[module_ordinal]; + flatbuffers_string_t kernel_name = + iree_hal_hip_ExportDef_kernel_name_get(export_def); + hipFunction_t function = NULL; + status = IREE_HIP_CALL_TO_STATUS( + symbols, hipModuleGetFunction(&function, module, kernel_name), + "hipModuleGetFunction"); + if (!iree_status_is_ok(status)) break; + if (!function) { + status = iree_make_status(IREE_STATUS_NOT_FOUND, + "exports[%" PRIhsz + "] kernel `%s` not found in modules[%u]", + i, kernel_name, module_ordinal); + break; + } + + uint32_t block_shared_memory_size = + iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def); + status = IREE_HIP_CALL_TO_STATUS( + symbols, + hipFuncSetAttribute( + function, + (hipFuncAttribute) + HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + block_shared_memory_size), + "hipFuncSetAttribute"); + if (!iree_status_is_ok(status)) break; + + // Package required parameters for kernel launches for each entry + // point. 
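// These cached fields are exactly what the stream command buffer later feeds
// to hipModuleLaunchKernel: block_dims become the block dimensions,
// block_shared_memory_size the dynamic shared memory size, and
// constant_count/binding_count size the packed kernel argument array.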
+ iree_hal_hip_kernel_params_t* kernel_info = + &per_device_data->exports[i]; + kernel_info->function = function; + const iree_hal_hip_BlockDims_t* block_dims = + iree_hal_hip_ExportDef_block_dims_get(export_def); + kernel_info->block_dims[0] = block_dims->x; + kernel_info->block_dims[1] = block_dims->y; + kernel_info->block_dims[2] = block_dims->z; + kernel_info->block_shared_memory_size = + iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def); + kernel_info->constant_count = + iree_hal_hip_ExportDef_constant_count_get(export_def); + iree_hal_hip_BindingBits_vec_t binding_flags_vec = + iree_hal_hip_ExportDef_binding_flags_get(export_def); + kernel_info->binding_count = + iree_hal_hip_BindingBits_vec_len(binding_flags_vec); + + IREE_TRACE({ + iree_hal_debug_export_info_t* export_info = + (iree_hal_debug_export_info_t*)export_info_ptr; + export_info_ptr += iree_hal_debug_copy_export_info( + iree_hal_hip_ExportDef_debug_info_get(export_def), export_info); + kernel_info->debug_info.function_name = export_info->function_name; + kernel_info->debug_info.source_filename = + export_info->source_filename; + kernel_info->debug_info.source_line = export_info->source_line; + }); + } } + IREE_RETURN_IF_ERROR( + IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL))); } if (iree_status_is_ok(status)) { @@ -405,10 +461,14 @@ static void iree_hal_hip_native_executable_destroy( iree_allocator_t host_allocator = executable->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); - for (iree_host_size_t i = 0; i < executable->module_count; ++i) { - if (executable->modules[i]) { - IREE_HIP_IGNORE_ERROR(executable->symbols, - hipModuleUnload(executable->modules[i])); + for (iree_host_size_t i = 0; i < executable->num_devices; ++i) { + const iree_hal_hip_native_executable_per_device_data_t* data = + executable->per_device_data[i]; + for (iree_host_size_t j = 0; j < data->module_count; ++j) { + if (data->modules[j]) { + IREE_HIP_IGNORE_ERROR(executable->symbols, + hipModuleUnload(data->modules[j])); + } } } @@ -419,17 +479,30 @@ iree_status_t iree_hal_hip_native_executable_lookup_kernel_params( iree_hal_executable_t* base_executable, int32_t ordinal, + iree_hal_queue_affinity_t queue_affinity, const iree_hal_hip_kernel_params_t** out_params) { + *out_params = NULL; iree_hal_hip_native_executable_t* executable = iree_hal_hip_native_executable_cast(base_executable); - if (ordinal >= executable->export_count) { + int device_ordinal = 0; + if (queue_affinity) { + device_ordinal = iree_math_count_trailing_zeros_u64(queue_affinity); + } + if (device_ordinal >= executable->num_devices) { + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "affinity for non-existent queue was provided"); + } + + const iree_hal_hip_native_executable_per_device_data_t* data = + executable->per_device_data[device_ordinal]; + if (ordinal >= data->export_count) { return iree_make_status( IREE_STATUS_OUT_OF_RANGE, "export ordinal %d out of range; executable contains %" PRIhsz " exports", - ordinal, executable->export_count); + ordinal, data->export_count); } - *out_params = &executable->exports[ordinal]; + *out_params = &data->exports[ordinal]; return iree_ok_status(); } diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.h b/runtime/src/iree/hal/drivers/hip/native_executable.h index b67f5e73e599..52fe9300878a 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.h +++ b/runtime/src/iree/hal/drivers/hip/native_executable.h @@ -14,10 +14,7 @@ #include "iree/hal/api.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_headers.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus +#include "iree/hal/drivers/hip/per_device_information.h" // The max number of per-dispatch bindings allowed in the HIP HAL // implementation. @@ -48,18 +45,16 @@ typedef struct iree_hal_hip_kernel_params_t { // Creates an IREE executable from a HSACO module. The module may contain // several kernels that can be extracted along with the associated block size. iree_status_t iree_hal_hip_native_executable_create( - const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - hipCtx_t context, const iree_hal_executable_params_t* executable_params, + const iree_hal_hip_dynamic_symbols_t* symbols, + iree_hal_hip_device_topology_t topology, + const iree_hal_executable_params_t* executable_params, iree_allocator_t host_allocator, iree_hal_executable_t** out_executable); // Returns the kernel launch parameters for the given |entry_point| in the // |executable|. iree_status_t iree_hal_hip_native_executable_lookup_kernel_params( iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_queue_affinity_t queue_affinity, const iree_hal_hip_kernel_params_t** out_params); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_NATIVE_EXECUTABLE_H_ diff --git a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c index 9680e3bf9f90..672d2ba5344e 100644 --- a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c +++ b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.c @@ -21,9 +21,7 @@ typedef struct iree_hal_hip_nop_executable_cache_t { iree_allocator_t host_allocator; const iree_hal_hip_dynamic_symbols_t* symbols; - - hipDevice_t device; - hipCtx_t hip_context; + iree_hal_hip_device_topology_t topology; } iree_hal_hip_nop_executable_cache_t; static const iree_hal_executable_cache_vtable_t @@ -38,8 +36,8 @@ iree_hal_hip_nop_executable_cache_cast( iree_status_t iree_hal_hip_nop_executable_cache_create( iree_string_view_t identifier, - const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - hipCtx_t hip_context, iree_allocator_t host_allocator, + const iree_hal_hip_dynamic_symbols_t* symbols, + iree_hal_hip_device_topology_t topology, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(out_executable_cache); @@ -55,8 +53,8 @@ iree_status_t iree_hal_hip_nop_executable_cache_create( &executable_cache->resource); executable_cache->host_allocator = host_allocator; executable_cache->symbols = symbols; - executable_cache->device = device; - executable_cache->hip_context = hip_context; + + executable_cache->topology = topology; *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; @@ -91,8 +89,7 @@ static iree_status_t iree_hal_hip_nop_executable_cache_prepare_executable( iree_hal_hip_nop_executable_cache_t* executable_cache = iree_hal_hip_nop_executable_cache_cast(base_executable_cache); return iree_hal_hip_native_executable_create( - executable_cache->symbols, executable_cache->device, - executable_cache->hip_context, executable_params, + executable_cache->symbols, executable_cache->topology, executable_params, executable_cache->host_allocator, out_executable); } diff --git a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h index 795aa21b53c9..fb0a9dd8da1b 
100644 --- a/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h +++ b/runtime/src/iree/hal/drivers/hip/nop_executable_cache.h @@ -11,22 +11,15 @@ #include "iree/hal/api.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" #include "iree/hal/drivers/hip/hip_headers.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus +#include "iree/hal/drivers/hip/per_device_information.h" // Creates a no-op executable cache that does not cache at all. // This is useful to isolate pipeline caching behavior and verify compilation // behavior. iree_status_t iree_hal_hip_nop_executable_cache_create( iree_string_view_t identifier, - const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - hipCtx_t hip_context, iree_allocator_t host_allocator, + const iree_hal_hip_dynamic_symbols_t* symbols, + iree_hal_hip_device_topology_t topology, iree_allocator_t host_allocator, iree_hal_executable_cache_t** out_executable_cache); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_NOP_EXECUTABLE_CACHE_H_ diff --git a/runtime/src/iree/hal/drivers/hip/per_device_information.h b/runtime/src/iree/hal/drivers/hip/per_device_information.h new file mode 100644 index 000000000000..88c332f1b3e9 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/per_device_information.h @@ -0,0 +1,37 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_HIP_PER_DEVICE_INFORMATION_H_ +#define IREE_HAL_DRIVERS_HIP_PER_DEVICE_INFORMATION_H_ + +#include "iree/hal/drivers/hip/dispatch_thread.h" +#include "iree/hal/drivers/hip/hip_headers.h" +#include "iree/hal/drivers/hip/memory_pools.h" + +typedef struct iree_hal_stream_tracing_context_t + iree_hal_stream_tracing_context_t; +typedef struct iree_hal_hip_event_pool_t iree_hal_hip_event_pool_t; + +typedef struct iree_hal_hip_per_device_info_t { + hipCtx_t hip_context; + hipDevice_t hip_device; + hipStream_t hip_dispatch_stream; + + iree_hal_stream_tracing_context_t* tracing_context; + + iree_hal_hip_event_pool_t* device_event_pool; + + iree_hal_hip_dispatch_thread_t* dispatch_thread; + + iree_hal_hip_memory_pools_t memory_pools; +} iree_hal_hip_per_device_info_t; + +typedef struct iree_hal_hip_device_topology_t { + iree_host_size_t count; + iree_hal_hip_per_device_info_t* devices; +} iree_hal_hip_device_topology_t; + +#endif // IREE_HAL_DRIVERS_HIP_PER_DEVICE_INFORMATION_H_ diff --git a/runtime/src/iree/hal/drivers/hip/rccl_channel.h b/runtime/src/iree/hal/drivers/hip/rccl_channel.h index 4f25dd6e6356..3447c7cc584e 100644 --- a/runtime/src/iree/hal/drivers/hip/rccl_channel.h +++ b/runtime/src/iree/hal/drivers/hip/rccl_channel.h @@ -15,10 +15,6 @@ #include "iree/hal/utils/collective_batch.h" #include "iree/hal/utils/stream_tracing.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - // Returns true if |id| is all zeros indicating an empty ID. 
static inline bool iree_hal_hip_nccl_id_is_empty( const iree_hal_hip_nccl_id_t* id) { @@ -52,8 +48,4 @@ iree_status_t iree_hal_hip_nccl_submit_batch( iree_hal_stream_tracing_context_event_list_t* tracing_event_list, const iree_hal_collective_batch_t* batch, hipStream_t stream); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_RCCL_CHANNEL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/rccl_status_util.h b/runtime/src/iree/hal/drivers/hip/rccl_status_util.h index 960961b9a9ec..0e4d33e6663f 100644 --- a/runtime/src/iree/hal/drivers/hip/rccl_status_util.h +++ b/runtime/src/iree/hal/drivers/hip/rccl_status_util.h @@ -12,10 +12,6 @@ #include "iree/base/api.h" #include "iree/hal/drivers/hip/rccl_dynamic_symbols.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - // Converts a ncclResult_t to an iree_status_t. // // Usage: @@ -61,8 +57,4 @@ iree_status_t iree_hal_hip_nccl_result_to_status( const iree_hal_hip_nccl_dynamic_symbols_t* syms, ncclResult_t result, const char* file, uint32_t line); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_RCCL_STATUS_UTIL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/status_util.h b/runtime/src/iree/hal/drivers/hip/status_util.h index 221f55fe0214..f9ed8ef7af0f 100644 --- a/runtime/src/iree/hal/drivers/hip/status_util.h +++ b/runtime/src/iree/hal/drivers/hip/status_util.h @@ -12,17 +12,21 @@ #include "iree/base/api.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus +// Converts a call into the hip driver into an iree_status_t. +// +// Usage: +// iree_status_t status = IREE_HIP_CALL_TO_STATUS(hip_symbols, +// hipDoThing(...)); +#define IREE_HIP_CALL_TO_STATUS(syms, expr, ...) \ + iree_hal_hip_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__) -// Converts a hipError_t to an iree_status_t. +// Converts hip status into an iree_status_t. // // Usage: // iree_status_t status = IREE_HIP_RESULT_TO_STATUS(hip_symbols, // hipDoThing(...)); -#define IREE_HIP_RESULT_TO_STATUS(syms, expr, ...) \ - iree_hal_hip_result_to_status((syms), ((syms)->expr), __FILE__, __LINE__) +#define IREE_HIP_RESULT_TO_STATUS(syms, result, ...) \ + iree_hal_hip_result_to_status((syms), (result), __FILE__, __LINE__) // IREE_RETURN_IF_ERROR but implicitly converts the hipError_t return value to // an iree_status_t. @@ -61,8 +65,4 @@ iree_status_t iree_hal_hip_result_to_status( const iree_hal_hip_dynamic_symbols_t* syms, hipError_t result, const char* file, uint32_t line); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - #endif // IREE_HAL_DRIVERS_HIP_STATUS_UTIL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index 6a201ca2976c..d0dcea11678c 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c @@ -7,7 +7,6 @@ #include "iree/hal/drivers/hip/stream_command_buffer.h" -#include "iree/hal/drivers/hip/context_util.h" #include "iree/hal/drivers/hip/hip_buffer.h" #include "iree/hal/drivers/hip/native_executable.h" #include "iree/hal/drivers/hip/rccl_channel.h" @@ -28,7 +27,6 @@ typedef struct iree_hal_hip_stream_command_buffer_t { iree_hal_stream_tracing_context_event_list_t tracing_event_list; hipStream_t hip_stream; - hipCtx_t hip_context; // A resource set to maintain references to all resources used within the // command buffer. Reset on each begin. 
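A sketch of the distinction the status_util.h hunk above draws, assuming a hypothetical loaded symbol table |syms| with hipStreamSynchronize and hipStreamQuery present, and a valid |stream| (none of these names come from the patch): IREE_HIP_CALL_TO_STATUS routes the call itself through the dynamic symbol table, while IREE_HIP_RESULT_TO_STATUS converts a hipError_t value the caller already holds.

  // The call is dispatched through |syms| and its result converted:
  iree_status_t status =
      IREE_HIP_CALL_TO_STATUS(syms, hipStreamSynchronize(stream));

  // An already-obtained hipError_t is converted directly:
  hipError_t result = syms->hipStreamQuery(stream);
  iree_status_t query_status = IREE_HIP_RESULT_TO_STATUS(syms, result);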
@@ -56,11 +54,12 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( iree_hal_allocator_t* device_allocator, const iree_hal_hip_dynamic_symbols_t* hip_symbols, const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, - hipCtx_t hip_context, iree_hal_stream_tracing_context_t* tracing_context, + iree_hal_stream_tracing_context_t* tracing_context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, - iree_host_size_t binding_capacity, hipStream_t stream, - iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + hipStream_t stream, iree_arena_block_pool_t* block_pool, + iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer) { IREE_ASSERT_ARGUMENT(device_allocator); IREE_ASSERT_ARGUMENT(hip_symbols); @@ -75,8 +74,6 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( } IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(hip_symbols, hip_context)); iree_hal_hip_stream_command_buffer_t* command_buffer = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( @@ -88,7 +85,7 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( (void**)&command_buffer)); iree_hal_command_buffer_initialize( - device_allocator, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, + device_allocator, mode, command_categories, queue_affinity, binding_capacity, (uint8_t*)command_buffer + sizeof(*command_buffer), &iree_hal_hip_stream_command_buffer_vtable, &command_buffer->base); command_buffer->host_allocator = host_allocator; @@ -98,7 +95,6 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( command_buffer->tracing_event_list.head = NULL; command_buffer->tracing_event_list.tail = NULL; command_buffer->hip_stream = stream; - command_buffer->hip_context = hip_context; iree_arena_initialize(block_pool, &command_buffer->arena); iree_status_t status = @@ -121,8 +117,6 @@ static void iree_hal_hip_stream_command_buffer_destroy( iree_hal_hip_stream_command_buffer_cast(base_command_buffer); iree_allocator_t host_allocator = command_buffer->host_allocator; IREE_TRACE_ZONE_BEGIN(z0); - IREE_IGNORE_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); iree_hal_stream_tracing_free(command_buffer->tracing_context, &command_buffer->tracing_event_list); @@ -179,8 +173,7 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); + (void)command_buffer; IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -196,9 +189,6 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -230,8 +220,7 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin_debug_group( const iree_hal_label_location_t* location) { iree_hal_hip_stream_command_buffer_t* command_buffer = 
iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); + (void)command_buffer; IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -239,7 +228,6 @@ static iree_status_t iree_hal_hip_stream_command_buffer_begin_debug_group( location ? location->file.data : NULL, location ? location->file.size : 0, location ? location->line : 0, /*func_name=*/NULL, 0, label.data, label.size); - return iree_ok_status(); } @@ -247,13 +235,11 @@ static iree_status_t iree_hal_hip_stream_command_buffer_end_debug_group( iree_hal_command_buffer_t* base_command_buffer) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); + (void)command_buffer; IREE_HAL_STREAM_TRACE_ZONE_END(command_buffer->tracing_context, &command_buffer->tracing_event_list, IREE_HAL_STREAM_TRACING_VERBOSITY_COARSE); - return iree_ok_status(); } @@ -268,8 +254,6 @@ static iree_status_t iree_hal_hip_stream_command_buffer_execution_barrier( const iree_hal_buffer_barrier_t* buffer_barriers) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); - IREE_RETURN_IF_ERROR(iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); if (iree_any_bit_set(source_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST) || iree_any_bit_set(target_stage_mask, IREE_HAL_EXECUTION_STAGE_HOST)) { @@ -334,9 +318,6 @@ static iree_status_t iree_hal_hip_stream_command_buffer_fill_buffer( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -390,9 +371,6 @@ static iree_status_t iree_hal_hip_stream_command_buffer_update_buffer( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -436,9 +414,6 @@ static iree_status_t iree_hal_hip_stream_command_buffer_copy_buffer( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer)); @@ -471,9 +446,6 @@ static iree_status_t iree_hal_hip_stream_command_buffer_collective( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); iree_hal_buffer_binding_t send_binding = { .buffer = send_ref.buffer, @@ -501,9 +473,6 @@ static 
iree_status_t iree_hal_hip_stream_command_buffer_dispatch( iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_set_context(command_buffer->hip_symbols, - command_buffer->hip_context)); // If any of the workgroup counts are zero, we can skip execution // of the kernel. This prevents a 'hipErrorInvalidConfiguration' error when @@ -522,7 +491,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( const iree_hal_hip_kernel_params_t* kernel_params = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_hip_native_executable_lookup_kernel_params( - executable, entry_point, &kernel_params)); + executable, entry_point, command_buffer->base.queue_affinity, + &kernel_params)); IREE_HAL_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( command_buffer->tracing_context, &command_buffer->tracing_event_list, @@ -558,7 +528,7 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( void** params_ptr = (void**)storage_base; hipDeviceptr_t* payload_ptr = (hipDeviceptr_t*)((uint8_t*)params_ptr + kernel_params_length); - for (size_t i = 0; i < kernel_params_count; i++) { + for (iree_host_size_t i = 0; i < kernel_params_count; i++) { params_ptr[i] = &payload_ptr[i]; } for (iree_host_size_t i = 0; i < bindings.count; i++) { @@ -585,7 +555,7 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( ((const uint32_t*)constants.data)[i]; } - iree_status_t status = IREE_HIP_RESULT_TO_STATUS( + iree_status_t status = IREE_HIP_CALL_TO_STATUS( command_buffer->hip_symbols, hipModuleLaunchKernel( kernel_params->function, workgroup_count[0], workgroup_count[1], @@ -612,6 +582,14 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch_indirect( "indirect dispatch not yet implemented"); } +iree_hal_stream_tracing_context_event_list_t +iree_hal_hip_stream_command_buffer_tracing_events( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_hip_stream_command_buffer_t* command_buffer = + iree_hal_hip_stream_command_buffer_cast(base_command_buffer); + return command_buffer->tracing_event_list; +} + static const iree_hal_command_buffer_vtable_t iree_hal_hip_stream_command_buffer_vtable = { .destroy = iree_hal_hip_stream_command_buffer_destroy, diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h index 43820866d8cd..da388c1e253b 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.h @@ -14,10 +14,6 @@ #include "iree/hal/drivers/hip/rccl_dynamic_symbols.h" #include "iree/hal/utils/stream_tracing.h" -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - // Creates command buffer that immediately issues commands against the given // HIP |stream|. Access to |stream| must be synchronized by the user. 
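The new tracing-events accessor gives the submission path a way to take over a command buffer's pending tracing events once its stream work is in flight. A minimal sketch, assuming |command_buffer| is a HIP stream command buffer and the collection step is handled by the caller:

  iree_hal_stream_tracing_context_event_list_t events =
      iree_hal_hip_stream_command_buffer_tracing_events(command_buffer);
  // |events| can then be handed to the owning tracing context for collection
  // once the submitted stream work has completed.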
// @@ -33,11 +29,12 @@ iree_status_t iree_hal_hip_stream_command_buffer_create( iree_hal_allocator_t* device_allocator, const iree_hal_hip_dynamic_symbols_t* hip_symbols, const iree_hal_hip_nccl_dynamic_symbols_t* nccl_symbols, - hipCtx_t hip_context, iree_hal_stream_tracing_context_t* tracing_context, + iree_hal_stream_tracing_context_t* tracing_context, iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, - iree_host_size_t binding_capacity, hipStream_t stream, - iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, + iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity, + hipStream_t stream, iree_arena_block_pool_t* block_pool, + iree_allocator_t host_allocator, iree_hal_command_buffer_t** out_command_buffer); // Returns true if |command_buffer| is a HIP stream-based command buffer. @@ -49,8 +46,11 @@ bool iree_hal_hip_stream_command_buffer_isa( // to collect. void iree_hal_hip_stream_notify_submitted_commands( iree_hal_command_buffer_t* base_command_buffer); -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus + +// Returns the set of tracing events that are associated with +// this command buffer. +iree_hal_stream_tracing_context_event_list_t +iree_hal_hip_stream_command_buffer_tracing_events( + iree_hal_command_buffer_t* base_command_buffer); #endif // IREE_HAL_DRIVERS_HIP_STREAM_COMMAND_BUFFER_H_ diff --git a/runtime/src/iree/hal/drivers/hip/timepoint_pool.c b/runtime/src/iree/hal/drivers/hip/timepoint_pool.c deleted file mode 100644 index fd0e6bcacded..000000000000 --- a/runtime/src/iree/hal/drivers/hip/timepoint_pool.c +++ /dev/null @@ -1,352 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "iree/hal/drivers/hip/timepoint_pool.h" - -#include -#include -#include - -#include "iree/base/api.h" -#include "iree/base/internal/atomics.h" -#include "iree/base/internal/event_pool.h" -#include "iree/base/internal/synchronization.h" -#include "iree/hal/api.h" -#include "iree/hal/drivers/hip/dynamic_symbols.h" -#include "iree/hal/drivers/hip/event_pool.h" -#include "iree/hal/drivers/hip/status_util.h" -#include "iree/hal/utils/semaphore_base.h" - -//===----------------------------------------------------------------------===// -// iree_hal_hip_timepoint_t -//===----------------------------------------------------------------------===// - -static iree_status_t iree_hal_hip_timepoint_allocate( - iree_hal_hip_timepoint_pool_t* pool, iree_allocator_t host_allocator, - iree_hal_hip_timepoint_t** out_timepoint) { - IREE_ASSERT_ARGUMENT(pool); - IREE_ASSERT_ARGUMENT(out_timepoint); - *out_timepoint = NULL; - IREE_TRACE_ZONE_BEGIN(z0); - - iree_hal_hip_timepoint_t* timepoint = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(host_allocator, sizeof(*timepoint), - (void**)&timepoint)); - // iree_allocator_malloc zeros out the whole struct. - timepoint->host_allocator = host_allocator; - timepoint->pool = pool; - - *out_timepoint = timepoint; - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -// Clears all data fields in the given |timepoint| except the original host -// allocator and owning pool.
-static void iree_hal_hip_timepoint_clear(iree_hal_hip_timepoint_t* timepoint) { - iree_allocator_t host_allocator = timepoint->host_allocator; - iree_hal_hip_timepoint_pool_t* pool = timepoint->pool; - memset(timepoint, 0, sizeof(*timepoint)); - timepoint->host_allocator = host_allocator; - timepoint->pool = pool; -} - -static void iree_hal_hip_timepoint_free(iree_hal_hip_timepoint_t* timepoint) { - iree_allocator_t host_allocator = timepoint->host_allocator; - IREE_TRACE_ZONE_BEGIN(z0); - - IREE_ASSERT(timepoint->kind == IREE_HAL_HIP_TIMEPOINT_KIND_NONE); - iree_allocator_free(host_allocator, timepoint); - - IREE_TRACE_ZONE_END(z0); -} - -//===----------------------------------------------------------------------===// -// iree_hal_hip_timepoint_pool_t -//===----------------------------------------------------------------------===// - -struct iree_hal_hip_timepoint_pool_t { - // The allocator used to create the timepoint pool. - iree_allocator_t host_allocator; - - // The pool to acquire host events. - iree_event_pool_t* host_event_pool; - // The pool to acquire device events. Internally synchronized. - iree_hal_hip_event_pool_t* device_event_pool; - - // Note that the above pools are internally synchronized; so we don't and - // shouldn't use the following mutex to guard access to them. - - // Guards timepoint related fields this pool. We don't expect a performant - // program to frequently allocate timepoints for synchronization purposes; the - // traffic to this pool should be low. So it should be fine to use mutex to - // guard here. - iree_slim_mutex_t timepoint_mutex; - - // Maximum number of timepoint objects that will be maintained in the pool. - // More timepoints may be allocated at any time, but they will be disposed - // directly when they are no longer needed. - iree_host_size_t available_capacity IREE_GUARDED_BY(timepoint_mutex); - // Total number of currently available timepoint objects. - iree_host_size_t available_count IREE_GUARDED_BY(timepoint_mutex); - // The list of available_count timepoint objects. - iree_hal_hip_timepoint_t* available_list[] IREE_GUARDED_BY(timepoint_mutex); -}; -// + Additional inline allocation for holding timepoints up to the capacity. 
- -iree_status_t iree_hal_hip_timepoint_pool_allocate( - iree_event_pool_t* host_event_pool, - iree_hal_hip_event_pool_t* device_event_pool, - iree_host_size_t available_capacity, iree_allocator_t host_allocator, - iree_hal_hip_timepoint_pool_t** out_timepoint_pool) { - IREE_ASSERT_ARGUMENT(host_event_pool); - IREE_ASSERT_ARGUMENT(device_event_pool); - IREE_ASSERT_ARGUMENT(out_timepoint_pool); - *out_timepoint_pool = NULL; - IREE_TRACE_ZONE_BEGIN(z0); - - iree_hal_hip_timepoint_pool_t* timepoint_pool = NULL; - iree_host_size_t total_size = - sizeof(*timepoint_pool) + - available_capacity * sizeof(*timepoint_pool->available_list); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(host_allocator, total_size, - (void**)&timepoint_pool)); - timepoint_pool->host_allocator = host_allocator; - timepoint_pool->host_event_pool = host_event_pool; - timepoint_pool->device_event_pool = device_event_pool; - - iree_slim_mutex_initialize(&timepoint_pool->timepoint_mutex); - timepoint_pool->available_capacity = available_capacity; - timepoint_pool->available_count = 0; - - iree_status_t status = iree_ok_status(); - for (iree_host_size_t i = 0; i < available_capacity; ++i) { - status = iree_hal_hip_timepoint_allocate( - timepoint_pool, host_allocator, - &timepoint_pool->available_list[timepoint_pool->available_count++]); - if (!iree_status_is_ok(status)) break; - } - - if (iree_status_is_ok(status)) { - *out_timepoint_pool = timepoint_pool; - } else { - iree_hal_hip_timepoint_pool_free(timepoint_pool); - } - IREE_TRACE_ZONE_END(z0); - return status; -} - -void iree_hal_hip_timepoint_pool_free( - iree_hal_hip_timepoint_pool_t* timepoint_pool) { - iree_allocator_t host_allocator = timepoint_pool->host_allocator; - IREE_TRACE_ZONE_BEGIN(z0); - - for (iree_host_size_t i = 0; i < timepoint_pool->available_count; ++i) { - iree_hal_hip_timepoint_free(timepoint_pool->available_list[i]); - } - iree_slim_mutex_deinitialize(&timepoint_pool->timepoint_mutex); - iree_allocator_free(host_allocator, timepoint_pool); - - IREE_TRACE_ZONE_END(z0); -} - -// Acquires |timepoint_count| timepoints from the given |timepoint_pool|. -// The |out_timepoints| needs to be further initialized with proper kind and -// payload values. -static iree_status_t iree_hal_hip_timepoint_pool_acquire_internal( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, - iree_hal_hip_timepoint_t** out_timepoints) { - IREE_ASSERT_ARGUMENT(timepoint_pool); - if (!timepoint_count) return iree_ok_status(); - IREE_ASSERT_ARGUMENT(out_timepoints); - IREE_TRACE_ZONE_BEGIN(z0); - - // We'll try to get what we can from the pool and fall back to initializing - // new iree_hal_hip_timepoint_t objects. - iree_host_size_t remaining_count = timepoint_count; - - // Try first to grab from the pool. - iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex); - iree_host_size_t from_pool_count = - iree_min(timepoint_pool->available_count, timepoint_count); - if (from_pool_count > 0) { - iree_host_size_t pool_base_index = - timepoint_pool->available_count - from_pool_count; - memcpy(out_timepoints, &timepoint_pool->available_list[pool_base_index], - from_pool_count * sizeof(*timepoint_pool->available_list)); - timepoint_pool->available_count -= from_pool_count; - remaining_count -= from_pool_count; - } - iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex); - - // Allocate the rest of the timepoints. 
- if (remaining_count > 0) { - IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-acquire"); - iree_status_t status = iree_ok_status(); - for (iree_host_size_t i = 0; i < remaining_count; ++i) { - status = iree_hal_hip_timepoint_allocate( - timepoint_pool, timepoint_pool->host_allocator, - &out_timepoints[from_pool_count + i]); - if (!iree_status_is_ok(status)) { - // Must release all timepoints we've acquired so far. - iree_hal_hip_timepoint_pool_release(timepoint_pool, from_pool_count + i, - out_timepoints); - IREE_TRACE_ZONE_END(z1); - IREE_TRACE_ZONE_END(z0); - return status; - } - } - IREE_TRACE_ZONE_END(z1); - } - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -iree_status_t iree_hal_hip_timepoint_pool_acquire_host_wait( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, - iree_hal_hip_timepoint_t** out_timepoints) { - IREE_TRACE_ZONE_BEGIN(z0); - - // Acquire host events to wrap up. This should happen before acquiring the - // timepoints to avoid nested locks. - iree_event_t* host_events = iree_alloca( - timepoint_count * sizeof((*out_timepoints)->timepoint.host_wait)); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_event_pool_acquire(timepoint_pool->host_event_pool, - timepoint_count, host_events)); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_timepoint_pool_acquire_internal( - timepoint_pool, timepoint_count, out_timepoints)); - for (iree_host_size_t i = 0; i < timepoint_count; ++i) { - out_timepoints[i]->kind = IREE_HAL_HIP_TIMEPOINT_KIND_HOST_WAIT; - out_timepoints[i]->timepoint.host_wait = host_events[i]; - } - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -iree_status_t iree_hal_hip_timepoint_pool_acquire_device_signal( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, - iree_hal_hip_timepoint_t** out_timepoints) { - IREE_TRACE_ZONE_BEGIN(z0); - - // Acquire device events to wrap up. This should happen before acquiring the - // timepoints to avoid nested locks. - iree_hal_hip_event_t** device_events = iree_alloca( - timepoint_count * sizeof((*out_timepoints)->timepoint.device_signal)); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_event_pool_acquire(timepoint_pool->device_event_pool, - timepoint_count, device_events)); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_timepoint_pool_acquire_internal( - timepoint_pool, timepoint_count, out_timepoints)); - for (iree_host_size_t i = 0; i < timepoint_count; ++i) { - out_timepoints[i]->kind = IREE_HAL_HIP_TIMEPOINT_KIND_DEVICE_SIGNAL; - out_timepoints[i]->timepoint.device_signal = device_events[i]; - } - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -iree_status_t iree_hal_hip_timepoint_pool_acquire_device_wait( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, - iree_hal_hip_timepoint_t** out_timepoints) { - IREE_TRACE_ZONE_BEGIN(z0); - - // Acquire device events to wrap up. This should happen before acquiring the - // timepoints to avoid nested locks. 
- iree_hal_hip_event_t** device_events = iree_alloca( - timepoint_count * sizeof((*out_timepoints)->timepoint.device_wait)); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_event_pool_acquire(timepoint_pool->device_event_pool, - timepoint_count, device_events)); - - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_timepoint_pool_acquire_internal( - timepoint_pool, timepoint_count, out_timepoints)); - for (iree_host_size_t i = 0; i < timepoint_count; ++i) { - out_timepoints[i]->kind = IREE_HAL_HIP_TIMEPOINT_KIND_DEVICE_WAIT; - out_timepoints[i]->timepoint.device_wait = device_events[i]; - } - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -void iree_hal_hip_timepoint_pool_release( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, iree_hal_hip_timepoint_t** timepoints) { - IREE_ASSERT_ARGUMENT(timepoint_pool); - if (!timepoint_count) return; - IREE_ASSERT_ARGUMENT(timepoints); - IREE_TRACE_ZONE_BEGIN(z0); - - // Release the wrapped host/device events. This should happen before acquiring - // the timepoint pool's lock given that the host/device event pool its - // internal lock too. - // TODO: Release in batch to avoid lock overhead from separate calls. - for (iree_host_size_t i = 0; i < timepoint_count; ++i) { - switch (timepoints[i]->kind) { - case IREE_HAL_HIP_TIMEPOINT_KIND_HOST_WAIT: - iree_event_pool_release(timepoint_pool->host_event_pool, 1, - &timepoints[i]->timepoint.host_wait); - break; - case IREE_HAL_HIP_TIMEPOINT_KIND_DEVICE_SIGNAL: - iree_hal_hip_event_release(timepoints[i]->timepoint.device_signal); - break; - case IREE_HAL_HIP_TIMEPOINT_KIND_DEVICE_WAIT: - iree_hal_hip_event_release(timepoints[i]->timepoint.device_wait); - break; - default: - break; - } - } - - // We'll try to release all we can back to the pool and then deinitialize - // the ones that won't fit. - iree_host_size_t remaining_count = timepoint_count; - - // Try first to release to the pool. - iree_slim_mutex_lock(&timepoint_pool->timepoint_mutex); - iree_host_size_t to_pool_count = iree_min( - timepoint_pool->available_capacity - timepoint_pool->available_count, - timepoint_count); - if (to_pool_count > 0) { - for (iree_host_size_t i = 0; i < to_pool_count; ++i) { - iree_hal_hip_timepoint_clear(timepoints[i]); - } - iree_host_size_t pool_base_index = timepoint_pool->available_count; - memcpy(&timepoint_pool->available_list[pool_base_index], timepoints, - to_pool_count * sizeof(*timepoint_pool->available_list)); - timepoint_pool->available_count += to_pool_count; - remaining_count -= to_pool_count; - } - iree_slim_mutex_unlock(&timepoint_pool->timepoint_mutex); - - // Deallocate the rest of the timepoints. We don't bother resetting them as we - // are getting rid of them. - if (remaining_count > 0) { - IREE_TRACE_ZONE_BEGIN_NAMED(z1, "timepoint-pool-unpooled-release"); - for (iree_host_size_t i = 0; i < remaining_count; ++i) { - iree_hal_hip_timepoint_clear(timepoints[to_pool_count + i]); - iree_hal_hip_timepoint_free(timepoints[to_pool_count + i]); - } - IREE_TRACE_ZONE_END(z1); - } - IREE_TRACE_ZONE_END(z0); -} diff --git a/runtime/src/iree/hal/drivers/hip/timepoint_pool.h b/runtime/src/iree/hal/drivers/hip/timepoint_pool.h deleted file mode 100644 index 0f8c769ec716..000000000000 --- a/runtime/src/iree/hal/drivers/hip/timepoint_pool.h +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_HAL_DRIVERS_HIP_TIMEPOINT_POOL_H_ -#define IREE_HAL_DRIVERS_HIP_TIMEPOINT_POOL_H_ - -#include "iree/base/api.h" -#include "iree/base/internal/event_pool.h" -#include "iree/hal/drivers/hip/event_pool.h" -#include "iree/hal/utils/semaphore_base.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -//===----------------------------------------------------------------------===// -// iree_hal_hip_timepoint_t -//===----------------------------------------------------------------------===// - -// Forward declaration of the timepoint pool. -typedef struct iree_hal_hip_timepoint_pool_t iree_hal_hip_timepoint_pool_t; - -// An enum to identify the timepoint kind in iree_hal_hip_timepoint_t objects. -typedef enum iree_hal_hip_timepoint_kind_e { - // None; for uninitialized timepoint objects. - IREE_HAL_HIP_TIMEPOINT_KIND_NONE = 0, - // A timepoint waited by the host. - IREE_HAL_HIP_TIMEPOINT_KIND_HOST_WAIT, - // A timepoint signaled by the device. - IREE_HAL_HIP_TIMEPOINT_KIND_DEVICE_SIGNAL, - // A timepoint waited by the device. - IREE_HAL_HIP_TIMEPOINT_KIND_DEVICE_WAIT, -} iree_hal_hip_timepoint_kind_t; - -// An object that wraps a host iree_event_t or device iree_hal_hip_event_t to -// represent wait/signal of a timepoint on a timeline. -// -// iree_hal_hip_timepoint_t objects cannot be directly created; it should be -// acquired from the timeline pool and released back to the pool once done. -// -// Thread-compatible; a timepoint is typically only accessed by one thread. -typedef struct iree_hal_hip_timepoint_t { - // Base timepoint structure providing intrusive linked list pointers and - // timepoint callback mechanisms. - iree_hal_semaphore_timepoint_t base; - - // The allocator used to create the timepoint. - iree_allocator_t host_allocator; - - // The timepoint pool that owns this timepoint. - iree_hal_hip_timepoint_pool_t* pool; - - iree_hal_hip_timepoint_kind_t kind; - union { - iree_event_t host_wait; - iree_hal_hip_event_t* device_signal; - // The device event to wait. NULL means no device event available to wait - // for this timepoint at the moment. - iree_hal_hip_event_t* device_wait; - } timepoint; -} iree_hal_hip_timepoint_t; - -//===----------------------------------------------------------------------===// -// iree_hal_hip_timepoint_pool_t -//===----------------------------------------------------------------------===// - -// A simple pool of iree_hal_hip_timepoint_t objects to recycle. -// -// Thread-safe; multiple threads may acquire and release timepoints from the -// pool. -typedef struct iree_hal_hip_timepoint_pool_t iree_hal_hip_timepoint_pool_t; - -// Allocates a new timepoint pool with up to |available_capacity| timepoints. -// -// Extra timepoint requests beyond the capability are directly created and -// destroyed without pooling. -iree_status_t iree_hal_hip_timepoint_pool_allocate( - iree_event_pool_t* host_event_pool, - iree_hal_hip_event_pool_t* device_event_pool, - iree_host_size_t available_capacity, iree_allocator_t host_allocator, - iree_hal_hip_timepoint_pool_t** out_timepoint_pool); - -// Deallocates a timepoint pool and destroys all timepoints. -// -// All timepoints that were acquired from the pool must have already been -// released back to it prior to deallocation. -void iree_hal_hip_timepoint_pool_free( - iree_hal_hip_timepoint_pool_t* timepoint_pool); - -// Acquires one or more timepoints from the timepoint pool. 
-// -// |out_timepoints| are owned by the caller and must be kept live until the -// timepoints have been reached, or cancelled by the caller. -iree_status_t iree_hal_hip_timepoint_pool_acquire_host_wait( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, - iree_hal_hip_timepoint_t** out_timepoints); -iree_status_t iree_hal_hip_timepoint_pool_acquire_device_signal( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, - iree_hal_hip_timepoint_t** out_timepoints); -iree_status_t iree_hal_hip_timepoint_pool_acquire_device_wait( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, - iree_hal_hip_timepoint_t** out_timepoints); - -// Releases one or more timepoints back to the timepoint pool. -void iree_hal_hip_timepoint_pool_release( - iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_host_size_t timepoint_count, iree_hal_hip_timepoint_t** timepoints); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // IREE_HAL_DRIVERS_HIP_TIMEPOINT_POOL_H_ diff --git a/runtime/src/iree/hal/drivers/hip/util/CMakeLists.txt b/runtime/src/iree/hal/drivers/hip/util/CMakeLists.txt new file mode 100644 index 000000000000..8fd817addcfd --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/util/CMakeLists.txt @@ -0,0 +1,41 @@ +# Copyright 2023 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +iree_cc_library( + NAME + hip_util + HDRS + "queue.h" + "tree.h" + SRCS + "queue.c" + "tree.c" + DEPS + iree::base + PUBLIC +) + +iree_cc_test( + NAME + iree_hal_hip_util_queue_test + SRCS + "queue_test.cc" + DEPS + ::hip_util + iree::testing::gtest + iree::testing::gtest_main +) + +iree_cc_test( + NAME + iree_hal_hip_util_tree_test + SRCS + "tree_test.cc" + DEPS + ::hip_util + iree::testing::gtest + iree::testing::gtest_main +) diff --git a/runtime/src/iree/hal/drivers/hip/util/queue.c b/runtime/src/iree/hal/drivers/hip/util/queue.c new file mode 100644 index 000000000000..26e9f41513c3 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/util/queue.c @@ -0,0 +1,88 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/drivers/hip/util/queue.h" + +void iree_hal_hip_util_queue_initialize(iree_allocator_t allocator, + iree_host_size_t element_size, + iree_host_size_t inline_count, + iree_hal_hip_util_queue_t* out_queue) { + out_queue->allocator = allocator; + out_queue->elements = &out_queue->initial_allocation[0]; + out_queue->element_size = element_size; + out_queue->element_count = 0; + out_queue->capacity = inline_count; + out_queue->head = 0; +} + +void iree_hal_hip_util_queue_deinitialize(iree_hal_hip_util_queue_t* queue) { + IREE_ASSERT_ARGUMENT(queue); + if (queue->elements != &queue->initial_allocation[0]) { + iree_allocator_free(queue->allocator, queue->elements); + } +} + +iree_status_t iree_hal_hip_util_queue_push_back( + iree_hal_hip_util_queue_t* queue, void* element) { + // Expand the queue if necessary. 
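+  // Growth doubles the capacity (with a floor of 16 elements) and un-wraps
+  // the circular contents so the live elements are contiguous again: at
+  // offset 0 when migrating out of the inline storage, or starting at |head|
+  // when reallocating a heap array.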
+  if (queue->capacity == queue->element_count) {
+    uint8_t* new_mem = NULL;
+    queue->capacity = iree_max(16, queue->capacity * 2);
+    if (queue->elements == &queue->initial_allocation[0]) {
+      IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+          queue->allocator, queue->element_size * queue->capacity,
+          (void**)&new_mem));
+      memcpy(new_mem, queue->elements + (queue->head * queue->element_size),
+             (queue->element_count - queue->head) * queue->element_size);
+      memcpy(new_mem +
+                 ((queue->element_count - queue->head) * queue->element_size),
+             queue->elements, queue->head * queue->element_size);
+      queue->head = 0;
+    } else {
+      new_mem = queue->elements;
+      IREE_RETURN_IF_ERROR(iree_allocator_realloc(
+          queue->allocator, queue->element_size * queue->capacity,
+          (void**)&new_mem));
+      const iree_host_size_t num_head_elements =
+          queue->element_count - queue->head;
+      const iree_host_size_t num_wrapped_elements =
+          queue->element_count - num_head_elements;
+
+      // If we have wrapped elements, move them to the end after the head;
+      // since we have at least doubled the size of our array, there is
+      // enough room.
+      if (num_wrapped_elements) {
+        memcpy(
+            new_mem + (queue->head + num_head_elements) * queue->element_size,
+            new_mem, num_wrapped_elements * queue->element_size);
+      }
+    }
+    queue->elements = new_mem;
+  }
+  memcpy(queue->elements +
+             (((queue->head + queue->element_count) % queue->capacity) *
+              queue->element_size),
+         element, queue->element_size);
+  ++queue->element_count;
+  return iree_ok_status();
+}
+
+void iree_hal_hip_util_queue_pop_front(iree_hal_hip_util_queue_t* queue,
+                                       iree_host_size_t count) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_LE(count, queue->element_count, "Popping too many elements");
+  queue->head += count;
+  queue->head = queue->head % queue->capacity;
+  queue->element_count -= count;
+}
+
+void* iree_hal_hip_util_queue_at(const iree_hal_hip_util_queue_t* queue,
+                                 iree_host_size_t i) {
+  IREE_ASSERT_ARGUMENT(queue);
+  IREE_ASSERT_LT(i, queue->element_count, "Index out of range");
+  return queue->elements +
+         ((queue->head + i) % queue->capacity) * queue->element_size;
+}
diff --git a/runtime/src/iree/hal/drivers/hip/util/queue.h b/runtime/src/iree/hal/drivers/hip/util/queue.h
new file mode 100644
index 000000000000..f8c2b81efd05
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/hip/util/queue.h
@@ -0,0 +1,100 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_HIP_UTIL_QUEUE_H_
+#define IREE_HAL_DRIVERS_HIP_UTIL_QUEUE_H_
+
+#include "iree/base/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif  // __cplusplus
+
+// A circular array where we can push to the back and pop from the front.
+// The helper functions allow you to index into the array. Furthermore, an
+// initial allocation may be provided inline as an optimization.
+typedef struct iree_hal_hip_util_queue_t {
+  iree_allocator_t allocator;
+  uint8_t* elements;
+  iree_host_size_t element_size;
+  iree_host_size_t element_count;
+  iree_host_size_t capacity;
+  iree_host_size_t head;
+  uint8_t initial_allocation[];
+} iree_hal_hip_util_queue_t;
+
+// Initializes the queue with elements of the given |element_size|.
+//
+// Optionally |inline_count| can be provided to notify the queue
+// that an initial allocation is present for the given number of elements.
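+//
+// A minimal usage sketch of the untyped API (hypothetical values; the typed
+// wrapper macro below is usually more convenient):
+//
+//   iree_hal_hip_util_queue_t queue;
+//   iree_hal_hip_util_queue_initialize(iree_allocator_system(),
+//                                      sizeof(int32_t), /*inline_count=*/0,
+//                                      &queue);
+//   int32_t value = 42;
+//   iree_status_t status = iree_hal_hip_util_queue_push_back(&queue, &value);
+//   // ... consume *(int32_t*)iree_hal_hip_util_queue_at(&queue, 0) ...
+//   iree_hal_hip_util_queue_pop_front(&queue, 1);
+//   iree_hal_hip_util_queue_deinitialize(&queue);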
+void iree_hal_hip_util_queue_initialize(iree_allocator_t allocator,
+                                        iree_host_size_t element_size,
+                                        iree_host_size_t inline_count,
+                                        iree_hal_hip_util_queue_t* out_queue);
+
+// Deinitializes the queue; it does not have to be empty.
+void iree_hal_hip_util_queue_deinitialize(iree_hal_hip_util_queue_t* queue);
+
+// Copies the given element into the back of the array. This may cause a
+// re-allocation of data.
+iree_status_t iree_hal_hip_util_queue_push_back(
+    iree_hal_hip_util_queue_t* queue, void* element);
+
+// Pops |count| elements from the front of the array and advances the head.
+void iree_hal_hip_util_queue_pop_front(iree_hal_hip_util_queue_t* queue,
+                                       iree_host_size_t count);
+
+// Returns a pointer to the element at index |i|.
+void* iree_hal_hip_util_queue_at(const iree_hal_hip_util_queue_t* queue,
+                                 iree_host_size_t i);
+
+#define IREE_HAL_HIP_UTIL_TYPED_QUEUE_WRAPPER(name, type,                    \
+                                              default_element_count)         \
+  typedef struct name##_t {                                                  \
+    iree_allocator_t allocator;                                              \
+    void* elements;                                                          \
+    iree_host_size_t element_size;                                           \
+    iree_host_size_t element_count;                                          \
+    iree_host_size_t capacity;                                               \
+    iree_host_size_t head;                                                   \
+    iree_alignas(iree_max_align_t) uint8_t                                   \
+        initial_allocation[default_element_count * sizeof(type)];            \
+  } name##_t;                                                                \
+  static inline void name##_initialize(iree_allocator_t allocator,           \
+                                       name##_t* out_queue) {                \
+    iree_hal_hip_util_queue_initialize(                                      \
+        allocator, sizeof(type), default_element_count,                      \
+        (iree_hal_hip_util_queue_t*)out_queue);                              \
+  }                                                                          \
+  static inline void name##_deinitialize(name##_t* out_queue) {              \
+    iree_hal_hip_util_queue_deinitialize(                                    \
+        (iree_hal_hip_util_queue_t*)out_queue);                              \
+  }                                                                          \
+  static inline iree_status_t name##_push_back(name##_t* queue,              \
+                                               type element) {               \
+    return iree_hal_hip_util_queue_push_back(                                \
+        (iree_hal_hip_util_queue_t*)queue, &element);                        \
+  }                                                                          \
+  static inline void name##_pop_front(name##_t* queue,                       \
+                                      iree_host_size_t count) {              \
+    iree_hal_hip_util_queue_pop_front((iree_hal_hip_util_queue_t*)queue,     \
+                                      count);                                \
+  }                                                                          \
+  static inline type name##_at(name##_t* queue, iree_host_size_t i) {        \
+    type t;                                                                  \
+    memcpy(&t,                                                               \
+           iree_hal_hip_util_queue_at((iree_hal_hip_util_queue_t*)queue, i), \
+           sizeof(type));                                                    \
+    return t;                                                                \
+  }                                                                          \
+  static inline bool name##_empty(name##_t* queue) {                         \
+    return queue->element_count == 0;                                        \
+  }                                                                          \
+  static inline iree_host_size_t name##_count(name##_t* queue) {             \
+    return queue->element_count;                                             \
+  }
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_HIP_UTIL_QUEUE_H_
diff --git a/runtime/src/iree/hal/drivers/hip/util/queue_test.cc b/runtime/src/iree/hal/drivers/hip/util/queue_test.cc
new file mode 100644
index 000000000000..aa282c6efa94
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/hip/util/queue_test.cc
@@ -0,0 +1,111 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/hal/drivers/hip/util/queue.h" + +#include "iree/testing/gtest.h" +#include "iree/testing/status_matchers.h" + +IREE_HAL_HIP_UTIL_TYPED_QUEUE_WRAPPER(test_queue, int32_t, 4); + +class QueueTest : public ::testing::Test { + protected: + void SetUp() override { + test_queue_initialize(iree_allocator_system(), &queue_); + } + + void TearDown() override { test_queue_deinitialize(&queue_); } + + test_queue_t queue_; +}; + +TEST_F(QueueTest, initialize) { + EXPECT_EQ(queue_.element_count, 0); + EXPECT_EQ(queue_.capacity, 4); + EXPECT_EQ(queue_.element_size, sizeof(int)); +} + +TEST_F(QueueTest, push_back) { + int value = 42; + IREE_ASSERT_OK(test_queue_push_back(&queue_, value)); + EXPECT_EQ(queue_.element_count, 1); + EXPECT_EQ(test_queue_at(&queue_, 0), value); +} + +TEST_F(QueueTest, push_back_and_expand) { + for (int i = 0; i < 5; ++i) { + IREE_ASSERT_OK(test_queue_push_back(&queue_, i)); + } + EXPECT_EQ(queue_.element_count, 5); + EXPECT_GT(queue_.capacity, 4); +} + +TEST_F(QueueTest, pop_front) { + IREE_ASSERT_OK(test_queue_push_back(&queue_, 12)); + IREE_ASSERT_OK(test_queue_push_back(&queue_, 15)); + test_queue_pop_front(&queue_, 1); + EXPECT_EQ(queue_.element_count, 1); + EXPECT_EQ(test_queue_at(&queue_, 0), 15); +} + +TEST_F(QueueTest, at) { + int value1 = 42; + int value2 = 84; + IREE_ASSERT_OK(test_queue_push_back(&queue_, value1)); + IREE_ASSERT_OK(test_queue_push_back(&queue_, value2)); + EXPECT_EQ(test_queue_at(&queue_, 0), value1); + EXPECT_EQ(test_queue_at(&queue_, 1), value2); +} + +TEST_F(QueueTest, cycle_around_queue) { + for (int i = 0; i < 4; ++i) { + IREE_ASSERT_OK(test_queue_push_back(&queue_, i)); + } + for (int i = 0; i < 4; ++i) { + test_queue_pop_front(&queue_, 1); + IREE_ASSERT_OK(test_queue_push_back(&queue_, i + 4)); + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(test_queue_at(&queue_, i), i + 4); + } + EXPECT_EQ(queue_.element_count, 4); +} + +TEST_F(QueueTest, cycle_around_queue_twice) { + for (int i = 0; i < 4; ++i) { + IREE_ASSERT_OK(test_queue_push_back(&queue_, i)); + } + for (int i = 0; i < 8; ++i) { + test_queue_pop_front(&queue_, 1); + IREE_ASSERT_OK(test_queue_push_back(&queue_, i + 4)); + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(test_queue_at(&queue_, i), i + 8); + } + EXPECT_EQ(queue_.element_count, 4); +} + +TEST_F(QueueTest, allocate_twice) { + for (int i = 0; i < 8; ++i) { + IREE_ASSERT_OK(test_queue_push_back(&queue_, i)); + } + EXPECT_EQ(queue_.element_count, 8); + EXPECT_GT(queue_.capacity, 4); + + for (int i = 8; i < 16; ++i) { + IREE_ASSERT_OK(test_queue_push_back(&queue_, i)); + } + EXPECT_EQ(queue_.element_count, 16); + EXPECT_GT(queue_.capacity, 8); +} + +TEST_F(QueueTest, no_reallocation_when_capacity_sufficient) { + size_t initial_capacity = queue_.capacity; + for (int i = 0; i < initial_capacity; ++i) { + IREE_ASSERT_OK(test_queue_push_back(&queue_, i)); + } + EXPECT_EQ(queue_.capacity, initial_capacity); +} diff --git a/runtime/src/iree/hal/drivers/hip/util/tree.c b/runtime/src/iree/hal/drivers/hip/util/tree.c new file mode 100644 index 000000000000..21b5b322da47 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/util/tree.c @@ -0,0 +1,578 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/hip/util/tree.h"
+
+static iree_hal_hip_util_tree_node_t*
+iree_hal_hip_util_tree_get_node_from_cache(iree_hal_hip_util_tree_t* tree) {
+  if (tree->cache) {
+    iree_hal_hip_util_tree_node_t* node = tree->cache;
+    tree->cache = node->right;
+    return node;
+  }
+  return NULL;
+}
+
+static void iree_hal_hip_util_tree_add_node_to_cache(
+    iree_hal_hip_util_tree_t* tree, iree_hal_hip_util_tree_node_t* node) {
+  node->right = tree->cache;
+  tree->cache = node;
+}
+
+static void iree_hal_hip_util_tree_delete_node(
+    iree_hal_hip_util_tree_t* tree, iree_hal_hip_util_tree_node_t* node) {
+  if (node != &tree->nil) {
+    iree_hal_hip_util_tree_add_node_to_cache(tree, node);
+  }
+}
+
+static iree_status_t iree_hal_hip_util_tree_allocate_node(
+    iree_hal_hip_util_tree_t* tree, iree_hal_hip_util_tree_node_t** out_node) {
+  *out_node = NULL;
+  iree_hal_hip_util_tree_node_t* node =
+      iree_hal_hip_util_tree_get_node_from_cache(tree);
+  if (node) {
+    memset(node, 0, sizeof(*node) + tree->element_size);
+  } else {
+    IREE_RETURN_IF_ERROR(iree_allocator_malloc(
+        tree->allocator, sizeof(*node) + tree->element_size, (void**)&node));
+  }
+  *out_node = node;
+  node->data = (uint8_t*)node + sizeof(*node);
+  return iree_ok_status();
+}
+
+static bool iree_hal_hip_util_tree_free_node(
+    iree_hal_hip_util_tree_node_t* node, void* user_data) {
+  iree_hal_hip_util_tree_t* tree = (iree_hal_hip_util_tree_t*)user_data;
+  if ((uint8_t*)node >= tree->initial_node_cache &&
+      (uint8_t*)node <
+          tree->initial_node_cache + tree->initial_node_cache_size) {
+    return true;
+  }
+  iree_allocator_free(tree->allocator, node);
+  return true;
+}
+
+static void iree_hal_hip_util_tree_rotate_left(
+    iree_hal_hip_util_tree_t* tree, iree_hal_hip_util_tree_node_t* node) {
+  iree_hal_hip_util_tree_node_t* right_child = node->right;
+  node->right = right_child->left;
+  if (right_child->left != &tree->nil) {
+    right_child->left->parent = node;
+  }
+  right_child->parent = node->parent;
+  if (node->parent == NULL) {
+    tree->root = right_child;
+  } else if (node == node->parent->left) {
+    node->parent->left = right_child;
+  } else {
+    node->parent->right = right_child;
+  }
+  right_child->left = node;
+  node->parent = right_child;
+}
+
+static void iree_hal_hip_util_tree_rotate_right(
+    iree_hal_hip_util_tree_t* tree, iree_hal_hip_util_tree_node_t* node) {
+  iree_hal_hip_util_tree_node_t* left_child = node->left;
+  node->left = left_child->right;
+  if (left_child->right != &tree->nil) {
+    left_child->right->parent = node;
+  }
+  left_child->parent = node->parent;
+  if (node->parent == NULL) {
+    tree->root = left_child;
+  } else if (node == node->parent->right) {
+    node->parent->right = left_child;
+  } else {
+    node->parent->left = left_child;
+  }
+  left_child->right = node;
+  node->parent = left_child;
+}
+
+static iree_status_t iree_hal_hip_util_tree_insert_internal(
+    iree_hal_hip_util_tree_t* tree, iree_host_size_t key,
+    iree_hal_hip_util_tree_node_t* node) {
+  node->left = &tree->nil;
+  node->right = &tree->nil;
+  node->key = key;
+  node->red = true;
+  node->parent = NULL;
+
+  iree_hal_hip_util_tree_node_t* search_position = tree->root;
+  iree_hal_hip_util_tree_node_t* target_parent = NULL;
+  while (search_position != &tree->nil) {
+    target_parent = search_position;
+    if (node->key < search_position->key) {
+      search_position = search_position->left;
+    } else if (node->key > search_position->key) {
+      search_position = search_position->right;
+    } else {
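+      // The key collides with an existing node; keys are unique in this
+      // tree, so surface the failure instead of replacing the node.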
return iree_make_status(IREE_STATUS_ALREADY_EXISTS, + "trying to insert a duplicate key"); + } + } + node->parent = target_parent; + + if (!target_parent) { + tree->root = node; + } else if (node->key < target_parent->key) { + target_parent->left = node; + } else { + target_parent->right = node; + } + + if (node->parent == NULL) { + node->red = false; + return iree_ok_status(); + } + + if (node->parent == tree->root) { + return iree_ok_status(); + } + + while (node->parent->red) { + if (node->parent == node->parent->parent->right) { + iree_hal_hip_util_tree_node_t* uncle = node->parent->parent->left; + if (uncle->red) { + uncle->red = false; + node->parent->red = false; + node->parent->parent->red = true; + node = node->parent->parent; + } else { + if (node == node->parent->left) { + node = node->parent; + iree_hal_hip_util_tree_rotate_right(tree, node); + } + node->parent->red = false; + node->parent->parent->red = true; + iree_hal_hip_util_tree_rotate_left(tree, node->parent->parent); + } + } else { + iree_hal_hip_util_tree_node_t* uncle = node->parent->parent->right; + if (uncle && uncle->red) { + uncle->red = false; + node->parent->red = false; + node->parent->parent->red = true; + node = node->parent->parent; + } else { + if (node == node->parent->right) { + node = node->parent; + iree_hal_hip_util_tree_rotate_left(tree, node); + } + node->parent->red = false; + node->parent->parent->red = true; + iree_hal_hip_util_tree_rotate_right(tree, node->parent->parent); + } + } + if (node == tree->root) { + break; + } + } + tree->root->red = false; + + return iree_ok_status(); +} + +static bool iree_hal_hip_util_tree_walk_helper( + iree_hal_hip_util_tree_node_t* node, + iree_hal_hip_util_tree_walk_type_t walk_type, + iree_hal_hip_util_tree_walk_callback_fn_t callback, void* user_data) { + IREE_ASSERT_LE(walk_type, IREE_TREE_WALK_POSTORDER); + if (!node || node->is_sentinel) { + return true; + } + switch (walk_type) { + case IREE_TREE_WALK_PREORDER: + if (!callback(node, user_data)) { + return false; + } + if (!iree_hal_hip_util_tree_walk_helper(node->left, walk_type, callback, + user_data)) { + return false; + } + return iree_hal_hip_util_tree_walk_helper(node->right, walk_type, + callback, user_data); + case IREE_TREE_WALK_INORDER: + if (!iree_hal_hip_util_tree_walk_helper(node->left, walk_type, callback, + user_data)) { + return false; + } + if (!callback(node, user_data)) { + return false; + } + return iree_hal_hip_util_tree_walk_helper(node->right, walk_type, + callback, user_data); + case IREE_TREE_WALK_POSTORDER: + if (!iree_hal_hip_util_tree_walk_helper(node->left, walk_type, callback, + user_data)) { + return false; + } + if (!iree_hal_hip_util_tree_walk_helper(node->right, walk_type, callback, + user_data)) { + return false; + } + return callback(node, user_data); + } + return false; +} + +static void iree_hal_hip_util_tree_replace(iree_hal_hip_util_tree_t* tree, + iree_hal_hip_util_tree_node_t* dst, + iree_hal_hip_util_tree_node_t* src) { + if (!dst->parent) { + tree->root = src; + } else if (dst == dst->parent->left) { + dst->parent->left = src; + } else { + dst->parent->right = src; + } + src->parent = dst->parent; +} + +static void iree_hal_hip_util_tree_remove( + iree_hal_hip_util_tree_t* tree, iree_hal_hip_util_tree_node_t* to_remove) { + iree_hal_hip_util_tree_node_t* replacement = NULL; + iree_hal_hip_util_tree_node_t* next = to_remove; + + bool initial_red = next->red; + if (to_remove->left == &tree->nil) { + replacement = to_remove->right; + 
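+    // No left child: splice the (possibly sentinel) right child into
+    // to_remove's position; the fixup loop below restores the red-black
+    // invariants if a black node was unlinked.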
iree_hal_hip_util_tree_replace(tree, to_remove, to_remove->right); + } else if (to_remove->right == &tree->nil) { + replacement = to_remove->left; + iree_hal_hip_util_tree_replace(tree, to_remove, to_remove->left); + } else { + next = iree_hal_hip_util_tree_node_next(to_remove); + initial_red = next->red; + replacement = next->right; + if (next->parent == to_remove) { + replacement->parent = next; + } else { + iree_hal_hip_util_tree_replace(tree, next, next->right); + next->right = to_remove->right; + next->right->parent = next; + } + + iree_hal_hip_util_tree_replace(tree, to_remove, next); + next->left = to_remove->left; + next->left->parent = next; + next->red = to_remove->red; + } + if (initial_red) { + return; + } + while (replacement != tree->root && !replacement->red) { + if (replacement == replacement->parent->left) { + iree_hal_hip_util_tree_node_t* sibling = replacement->parent->right; + if (sibling->red) { + sibling->red = false; + replacement->parent->red = true; + iree_hal_hip_util_tree_rotate_left(tree, replacement->parent); + sibling = replacement->parent->right; + } + + if (!sibling->left->red && !sibling->right->red) { + sibling->red = true; + replacement = replacement->parent; + } else { + if (!sibling->right->red) { + sibling->left->red = false; + sibling->red = true; + iree_hal_hip_util_tree_rotate_right(tree, sibling); + sibling = replacement->parent->right; + } + sibling->red = replacement->parent->red; + replacement->parent->red = false; + sibling->right->red = false; + iree_hal_hip_util_tree_rotate_left(tree, replacement->parent); + replacement = tree->root; + } + } else { + iree_hal_hip_util_tree_node_t* sibling = replacement->parent->left; + if (sibling->red) { + sibling->red = false; + replacement->parent->red = true; + iree_hal_hip_util_tree_rotate_right(tree, replacement->parent); + sibling = replacement->parent->left; + } + + if (!sibling->left->red && !sibling->right->red) { + sibling->red = true; + replacement = replacement->parent; + } else { + if (!sibling->left->red) { + sibling->right->red = false; + sibling->red = true; + iree_hal_hip_util_tree_rotate_left(tree, sibling); + sibling = replacement->parent->left; + } + + sibling->red = replacement->parent->red; + replacement->parent->red = false; + sibling->left->red = false; + iree_hal_hip_util_tree_rotate_right(tree, replacement->parent); + replacement = tree->root; + } + } + } + replacement->red = false; +} + +//===----------------------------------------------------------------------===// +// iree_hal_hip_util_tree_node_t +//===----------------------------------------------------------------------===// + +void* iree_hal_hip_util_tree_node_get_value( + const iree_hal_hip_util_tree_node_t* node) { + IREE_ASSERT_ARGUMENT(node); + return node->data; +} + +iree_host_size_t iree_hal_hip_util_tree_node_get_key( + const iree_hal_hip_util_tree_node_t* node) { + IREE_ASSERT_ARGUMENT(node); + return node->key; +} + +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_node_next( + iree_hal_hip_util_tree_node_t* node) { + IREE_ASSERT_ARGUMENT(node != NULL); + // 1. Find the smallest thing on our right hand side. + if (!node->right->is_sentinel) { + node = node->right; + while (!node->left->is_sentinel) { + node = node->left; + } + return node; + } + + // 2. 
Find the parent who is not on the right + iree_hal_hip_util_tree_node_t* parent = node->parent; + while (parent && node == parent->right) { + node = parent; + parent = node->parent; + } + return parent; +} + +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_node_prev( + iree_hal_hip_util_tree_node_t* node) { + IREE_ASSERT_ARGUMENT(node); + // 1. to find the largest thing on our left hand side. + if (!node->left->is_sentinel) { + node = node->left; + while (!node->right->is_sentinel) { + node = node->right; + } + return node; + } + + // 2. Find the parent who is not on the left + iree_hal_hip_util_tree_node_t* parent = node->parent; + while (parent && node == parent->left) { + node = parent; + parent = node->parent; + } + return parent; +} + +//===----------------------------------------------------------------------===// +// iree_hal_hip_util_tree_t +//===----------------------------------------------------------------------===// + +void iree_hal_hip_util_tree_initialize(iree_allocator_t allocator, + iree_host_size_t element_size, + void* initial_node_cache, + iree_host_size_t initial_node_cache_size, + iree_hal_hip_util_tree_t* out_tree) { + out_tree->element_size = element_size; + out_tree->allocator = allocator; + out_tree->root = &out_tree->nil; + out_tree->size = 0; + out_tree->cache = NULL; // Initialize cache + memset(&out_tree->nil, 0x00, sizeof(out_tree->nil)); + out_tree->nil.is_sentinel = true; + out_tree->initial_node_cache = initial_node_cache; + out_tree->initial_node_cache_size = initial_node_cache_size; + if (initial_node_cache) { + memset(initial_node_cache, 0, initial_node_cache_size); + iree_host_size_t node_size = + iree_host_align(sizeof(out_tree->nil) + element_size, 16); + + iree_hal_hip_util_tree_node_t* node = + (iree_hal_hip_util_tree_node_t*)initial_node_cache; + for (iree_host_size_t i = 0; i < initial_node_cache_size / node_size; ++i) { + node->data = (uint8_t*)node + sizeof(*node); + iree_hal_hip_util_tree_add_node_to_cache(out_tree, node); + node = (iree_hal_hip_util_tree_node_t*)((uint8_t*)node + node_size); + } + } +} + +void iree_hal_hip_util_tree_deinitialize(iree_hal_hip_util_tree_t* tree) { + iree_hal_hip_util_tree_walk(tree, IREE_TREE_WALK_POSTORDER, + iree_hal_hip_util_tree_free_node, tree); + + // Free cache nodes + iree_hal_hip_util_tree_node_t* node = tree->cache; + while (node) { + iree_hal_hip_util_tree_node_t* next = node->right; + if ((uint8_t*)node < tree->initial_node_cache || + (uint8_t*)node > + tree->initial_node_cache + tree->initial_node_cache_size) { + iree_allocator_free(tree->allocator, node); + } + node = next; + } + + // Reset the tree structure. 
+ tree->root = &tree->nil; + memset(&tree->nil, 0, sizeof(tree->nil)); + tree->nil.is_sentinel = true; + tree->size = 0; + tree->cache = NULL; +} + +iree_host_size_t iree_hal_hip_util_tree_element_size( + const iree_hal_hip_util_tree_t* tree) { + return tree->element_size; +} + +iree_status_t iree_hal_hip_util_tree_insert( + iree_hal_hip_util_tree_t* tree, iree_host_size_t key, + iree_hal_hip_util_tree_node_t** out_data) { + *out_data = NULL; + iree_hal_hip_util_tree_node_t* t = NULL; + IREE_RETURN_IF_ERROR(iree_hal_hip_util_tree_allocate_node(tree, &t)); + + iree_status_t status = iree_hal_hip_util_tree_insert_internal(tree, key, t); + if (!iree_status_is_ok(status)) { + iree_hal_hip_util_tree_delete_node(tree, t); + return status; + } + ++tree->size; + *out_data = t; + return status; +} + +iree_host_size_t iree_hal_hip_util_tree_size( + const iree_hal_hip_util_tree_t* tree) { + return tree->size; +} + +iree_status_t iree_hal_hip_util_tree_move_node( + iree_hal_hip_util_tree_t* tree, iree_hal_hip_util_tree_node_t* node, + iree_host_size_t new_key) { + iree_hal_hip_util_tree_node_t* next = iree_hal_hip_util_tree_node_next(node); + iree_hal_hip_util_tree_node_t* prev = iree_hal_hip_util_tree_node_prev(node); + if ((!next || next->key > new_key) && (!prev || prev->key < new_key)) { + // This node isn't going to move, just update its value. + node->key = new_key; + return iree_ok_status(); + } + iree_hal_hip_util_tree_remove(tree, node); + return iree_hal_hip_util_tree_insert_internal(tree, new_key, node); +} + +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_get( + const iree_hal_hip_util_tree_t* tree, iree_host_size_t key) { + iree_hal_hip_util_tree_node_t* node = tree->root; + while (node->is_sentinel == false) { + if (key == node->key) { + return node; + } else if (key < node->key) { + node = node->left; + } else { + node = node->right; + } + } + return NULL; +} + +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_lower_bound( + const iree_hal_hip_util_tree_t* tree, iree_host_size_t key) { + iree_hal_hip_util_tree_node_t* node = tree->root; + iree_hal_hip_util_tree_node_t* last = NULL; + while (node->is_sentinel == false) { + last = node; + if (key == node->key) { + return node; + } else if (key < node->key) { + node = node->left; + } else { + node = node->right; + } + } + if (!last || last->key > key) { + return last; + } + return iree_hal_hip_util_tree_node_next(last); +} + +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_upper_bound( + const iree_hal_hip_util_tree_t* tree, iree_host_size_t key) { + iree_hal_hip_util_tree_node_t* node = tree->root; + iree_hal_hip_util_tree_node_t* last = NULL; + while (node->is_sentinel == false) { + last = node; + if (key == node->key) { + return node; + } else if (key < node->key) { + node = node->left; + } else { + node = node->right; + } + } + if (!last || last->key > key) { + return last; + } + while (last && last->key <= key) { + last = iree_hal_hip_util_tree_node_next(last); + } + return last; +} + +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_first( + const iree_hal_hip_util_tree_t* tree) { + if (!tree->root || tree->root->is_sentinel) { + return NULL; + } + iree_hal_hip_util_tree_node_t* val = tree->root; + while (!val->left->is_sentinel) { + val = val->left; + } + return val; +} + +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_last( + const iree_hal_hip_util_tree_t* tree) { + if (!tree->root || tree->root->is_sentinel) { + return NULL; + } + iree_hal_hip_util_tree_node_t* val = tree->root; + while 
(!val->right->is_sentinel) { + val = val->right; + } + return val; +} + +void iree_hal_hip_util_tree_erase(iree_hal_hip_util_tree_t* tree, + iree_hal_hip_util_tree_node_t* node) { + iree_hal_hip_util_tree_remove(tree, node); + iree_hal_hip_util_tree_delete_node(tree, node); + --tree->size; +} + +void iree_hal_hip_util_tree_walk( + const iree_hal_hip_util_tree_t* tree, + iree_hal_hip_util_tree_walk_type_t walk_type, + iree_hal_hip_util_tree_walk_callback_fn_t callback, void* user_data) { + iree_hal_hip_util_tree_walk_helper(tree->root, walk_type, callback, + user_data); +} diff --git a/runtime/src/iree/hal/drivers/hip/util/tree.h b/runtime/src/iree/hal/drivers/hip/util/tree.h new file mode 100644 index 000000000000..b261248f0672 --- /dev/null +++ b/runtime/src/iree/hal/drivers/hip/util/tree.h @@ -0,0 +1,151 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_DRIVERS_HIP_UTIL_TREE_H_ +#define IREE_HAL_DRIVERS_HIP_UTIL_TREE_H_ + +#include + +#include "iree/base/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_hip_util_tree_node_t iree_hal_hip_util_tree_node_t; +typedef struct iree_hal_hip_util_tree_t iree_hal_hip_util_tree_t; + +typedef enum iree_hal_hip_util_tree_walk_type_e { + IREE_TREE_WALK_PREORDER, + IREE_TREE_WALK_INORDER, + IREE_TREE_WALK_POSTORDER, +} iree_hal_hip_util_tree_walk_type_t; + +//===----------------------------------------------------------------------===// +// iree_hal_hip_util_tree_node_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_hip_util_tree_node_t { + bool red; + iree_hal_hip_util_tree_node_t* left; + iree_hal_hip_util_tree_node_t* right; + iree_hal_hip_util_tree_node_t* parent; + iree_host_size_t key; + bool is_sentinel; + uint8_t* data; +} iree_hal_hip_util_tree_node_t; + +// Returns the value associated with the given node. +void* iree_hal_hip_util_tree_node_get_value( + const iree_hal_hip_util_tree_node_t* node); + +// Returns the key for the given node. +iree_host_size_t iree_hal_hip_util_tree_node_get_key( + const iree_hal_hip_util_tree_node_t* node); + +// Callback function for the iree_hip_util_tree_walk. +// +// This is provided the node and user_data for every node in the tree. A return +// of false from this function will cause the tree walk to complete without +// walking any further nodes. +typedef bool (*iree_hal_hip_util_tree_walk_callback_fn_t)( + iree_hal_hip_util_tree_node_t* node, void* user_data); + +// Walks the entire tree invoking the callback for every node in the tree. +void iree_hal_hip_util_tree_walk( + const iree_hal_hip_util_tree_t* tree, + iree_hal_hip_util_tree_walk_type_t walk_type, + iree_hal_hip_util_tree_walk_callback_fn_t callback, void* user_data); + +// Returns the next node in the tree or NULL. +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_node_next( + iree_hal_hip_util_tree_node_t* node); + +// Returns the previous node in the tree or NULL. 
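+//
+// For example, a reverse in-order traversal is a sketch of the form:
+//
+//   for (iree_hal_hip_util_tree_node_t* node =
+//            iree_hal_hip_util_tree_last(tree);
+//        node != NULL; node = iree_hal_hip_util_tree_node_prev(node)) {
+//     // ... visit iree_hal_hip_util_tree_node_get_key(node) ...
+//   }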
+iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_node_prev( + iree_hal_hip_util_tree_node_t* node); + +//===----------------------------------------------------------------------===// +// iree_hal_hip_util_tree_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_hip_util_tree_t { + iree_allocator_t allocator; + iree_host_size_t element_size; + iree_hal_hip_util_tree_node_t* root; + iree_host_size_t size; + iree_hal_hip_util_tree_node_t* cache; // Cache for deleted nodes + iree_hal_hip_util_tree_node_t nil; + uint8_t* initial_node_cache; + iree_host_size_t initial_node_cache_size; +} iree_hal_hip_util_tree_t; + +// Initializes the tree for values of |element_size|. +// +// If |initial_node_cache| is not null then it points to +// a block of memory that will be used to hold nodes before +// the tree ever tries to use the allocator. +void iree_hal_hip_util_tree_initialize(iree_allocator_t allocator, + iree_host_size_t element_size, + void* initial_node_cache, + iree_host_size_t initial_node_cache_size, + iree_hal_hip_util_tree_t* out_tree); + +// Deinitializes the tree and frees any memory that was allocated. +void iree_hal_hip_util_tree_deinitialize(iree_hal_hip_util_tree_t* tree); + +// Returns the number of bytes that are allocated for every value in the tree. +iree_host_size_t iree_hal_hip_util_tree_element_size( + const iree_hal_hip_util_tree_t* tree); + +// Inserts a new node into the tree with the given |key|. +// +// If the key is already present in the tree an error is returned. +iree_status_t iree_hal_hip_util_tree_insert( + iree_hal_hip_util_tree_t* tree, iree_host_size_t key, + iree_hal_hip_util_tree_node_t** out_data); + +// Returns the number of elements in the tree. +iree_host_size_t iree_hal_hip_util_tree_size( + const iree_hal_hip_util_tree_t* tree); + +// Moves a node that already exists in the tree to a new location with the given +// key. +iree_status_t iree_hal_hip_util_tree_move_node( + iree_hal_hip_util_tree_t* tree, iree_hal_hip_util_tree_node_t* node, + iree_host_size_t new_key); + +// Returns the node in the tree that has a given key. +// +// Returns NULL if the key could not be found in the tree. +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_get( + const iree_hal_hip_util_tree_t* tree, iree_host_size_t key); + +// Returns the first node in the tree that has a key that is >= |key| or NULL. +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_lower_bound( + const iree_hal_hip_util_tree_t* tree, iree_host_size_t key); + +// Returns the first node in the tree that has a key that is > |key| or NULL; +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_upper_bound( + const iree_hal_hip_util_tree_t* tree, iree_host_size_t key); + +// Returns the node in the tree with the smallest |key| or NULL. +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_first( + const iree_hal_hip_util_tree_t* tree); + +// Returns the node in the tree with the largest |key| or NULL. +iree_hal_hip_util_tree_node_t* iree_hal_hip_util_tree_last( + const iree_hal_hip_util_tree_t* tree); + +// Erases the given node from the tree. 
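+//
+// Typical usage is lookup-then-erase (a sketch; assumes |tree| and |key|):
+//
+//   iree_hal_hip_util_tree_node_t* node =
+//       iree_hal_hip_util_tree_get(tree, key);
+//   if (node) {
+//     iree_hal_hip_util_tree_erase(tree, node);
+//   }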
+void iree_hal_hip_util_tree_erase(iree_hal_hip_util_tree_t* tree,
+                                  iree_hal_hip_util_tree_node_t* node);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+
+#endif  // IREE_HAL_DRIVERS_HIP_UTIL_TREE_H_
diff --git a/runtime/src/iree/hal/drivers/hip/util/tree_test.cc b/runtime/src/iree/hal/drivers/hip/util/tree_test.cc
new file mode 100644
index 000000000000..341966cac7e9
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/hip/util/tree_test.cc
@@ -0,0 +1,180 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/hip/util/tree.h"
+
+#include <vector>
+
+#include "iree/testing/gtest.h"
+
+class RedBlackTreeTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    iree_allocator_t allocator = iree_allocator_system();
+    iree_hal_hip_util_tree_initialize(allocator, sizeof(int), initial_cache,
+                                      1024, &tree_);
+  }
+
+  void TearDown() override { iree_hal_hip_util_tree_deinitialize(&tree_); }
+
+  iree_hal_hip_util_tree_t tree_;
+  uint8_t initial_cache[1024];
+};
+
+TEST_F(RedBlackTreeTest, initialize) {
+  EXPECT_EQ(iree_hal_hip_util_tree_size(&tree_), 0);
+}
+
+TEST_F(RedBlackTreeTest, insert) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  EXPECT_EQ(iree_hal_hip_util_tree_insert(&tree_, 10, &node), iree_ok_status());
+  EXPECT_EQ(iree_hal_hip_util_tree_size(&tree_), 1);
+  EXPECT_EQ(iree_hal_hip_util_tree_node_get_key(node), 10);
+}
+
+TEST_F(RedBlackTreeTest, get) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  iree_hal_hip_util_tree_insert(&tree_, 10, &node);
+  EXPECT_NE(iree_hal_hip_util_tree_get(&tree_, 10), nullptr);
+  EXPECT_EQ(iree_hal_hip_util_tree_get(&tree_, 20), nullptr);
+}
+
+TEST_F(RedBlackTreeTest, delete) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  iree_hal_hip_util_tree_insert(&tree_, 10, &node);
+  iree_hal_hip_util_tree_erase(&tree_, node);
+  EXPECT_EQ(iree_hal_hip_util_tree_get(&tree_, 10), nullptr);
+  EXPECT_EQ(iree_hal_hip_util_tree_size(&tree_), 0);
+}
+
+TEST_F(RedBlackTreeTest, walk) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  iree_hal_hip_util_tree_insert(&tree_, 10, &node);
+  static_cast<int*>(iree_hal_hip_util_tree_node_get_value(node))[0] = 10;
+  iree_hal_hip_util_tree_insert(&tree_, 20, &node);
+  static_cast<int*>(iree_hal_hip_util_tree_node_get_value(node))[0] = 20;
+  iree_hal_hip_util_tree_insert(&tree_, 30, &node);
+  static_cast<int*>(iree_hal_hip_util_tree_node_get_value(node))[0] = 30;
+
+  int sum = 0;
+  auto callback = [](iree_hal_hip_util_tree_node_t* node,
+                     void* user_data) -> bool {
+    int* sum = static_cast<int*>(user_data);
+    EXPECT_EQ(*static_cast<int*>(iree_hal_hip_util_tree_node_get_value(node)),
+              iree_hal_hip_util_tree_node_get_key(node));
+    *sum += *static_cast<int*>(iree_hal_hip_util_tree_node_get_value(node));
+    return true;
+  };
+  iree_hal_hip_util_tree_walk(&tree_, IREE_TREE_WALK_PREORDER, callback, &sum);
+  EXPECT_EQ(sum, 60);
+}
+
+TEST_F(RedBlackTreeTest, boundary_conditions) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  iree_hal_hip_util_tree_insert(&tree_, 10, &node);
+  iree_hal_hip_util_tree_insert(&tree_, 20, &node);
+  iree_hal_hip_util_tree_insert(&tree_, 30, &node);
+
+  EXPECT_EQ(iree_hal_hip_util_tree_node_get_key(
+                iree_hal_hip_util_tree_first(&tree_)),
+            10);
+  EXPECT_EQ(iree_hal_hip_util_tree_node_get_key(
+                iree_hal_hip_util_tree_last(&tree_)),
+            30);
+  EXPECT_EQ(iree_hal_hip_util_tree_node_get_key(
+                iree_hal_hip_util_tree_lower_bound(&tree_, 15)),
+            20);
+  EXPECT_EQ(iree_hal_hip_util_tree_node_get_key(
+                iree_hal_hip_util_tree_upper_bound(&tree_, 15)),
+            20);
+}
+
+TEST_F(RedBlackTreeTest, move_node) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  iree_hal_hip_util_tree_insert(&tree_, 10, &node);
+  iree_hal_hip_util_tree_move_node(&tree_, node, 20);
+  EXPECT_EQ(iree_hal_hip_util_tree_get(&tree_, 10), nullptr);
+  EXPECT_NE(iree_hal_hip_util_tree_get(&tree_, 20), nullptr);
+}
+
+TEST_F(RedBlackTreeTest, in_order_iterators) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  iree_hal_hip_util_tree_insert(&tree_, 10, &node);
+  iree_hal_hip_util_tree_insert(&tree_, 20, &node);
+  iree_hal_hip_util_tree_insert(&tree_, 30, &node);
+
+  std::vector<iree_host_size_t> keys;
+  for (iree_hal_hip_util_tree_node_t* node =
+           iree_hal_hip_util_tree_first(&tree_);
+       node != nullptr; node = iree_hal_hip_util_tree_node_next(node)) {
+    keys.push_back(iree_hal_hip_util_tree_node_get_key(node));
+  }
+
+  EXPECT_EQ(keys.size(), 3);
+  EXPECT_EQ(keys[0], 10);
+  EXPECT_EQ(keys[1], 20);
+  EXPECT_EQ(keys[2], 30);
+}
+
+TEST_F(RedBlackTreeTest, in_order_iterators_last) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  iree_hal_hip_util_tree_insert(&tree_, 10, &node);
+  iree_hal_hip_util_tree_insert(&tree_, 20, &node);
+  iree_hal_hip_util_tree_insert(&tree_, 30, &node);
+
+  std::vector<iree_host_size_t> keys;
+  for (iree_hal_hip_util_tree_node_t* node =
+           iree_hal_hip_util_tree_last(&tree_);
+       node != nullptr; node = iree_hal_hip_util_tree_node_prev(node)) {
+    keys.push_back(iree_hal_hip_util_tree_node_get_key(node));
+  }
+
+  EXPECT_EQ(keys.size(), 3);
+  EXPECT_EQ(keys[0], 30);
+  EXPECT_EQ(keys[1], 20);
+  EXPECT_EQ(keys[2], 10);
+}
+
+class RedBlackTreeWalkTest
+    : public RedBlackTreeTest,
+      public ::testing::WithParamInterface<
+          iree_hal_hip_util_tree_walk_type_t> {};
+
+TEST_P(RedBlackTreeWalkTest, walk) {
+  iree_hal_hip_util_tree_node_t* node = NULL;
+  iree_hal_hip_util_tree_insert(&tree_, 10, &node);
+  iree_hal_hip_util_tree_insert(&tree_, 20, &node);
+  iree_hal_hip_util_tree_insert(&tree_, 30, &node);
+
+  std::vector<iree_host_size_t> keys;
+  auto callback = [](iree_hal_hip_util_tree_node_t* node,
+                     void* user_data) -> bool {
+    auto* keys = static_cast<std::vector<iree_host_size_t>*>(user_data);
+    keys->push_back(iree_hal_hip_util_tree_node_get_key(node));
+    return true;
+  };
+  iree_hal_hip_util_tree_walk(&tree_, GetParam(), callback, &keys);
+
+  if (GetParam() == IREE_TREE_WALK_INORDER) {
+    EXPECT_EQ(keys.size(), 3);
+    EXPECT_EQ(keys[0], 10);
+    EXPECT_EQ(keys[1], 20);
+    EXPECT_EQ(keys[2], 30);
+  } else if (GetParam() == IREE_TREE_WALK_PREORDER) {
+    EXPECT_EQ(keys.size(), 3);
+    EXPECT_EQ(keys[0], 20);  // 20 is the root after rebalancing.
+    EXPECT_EQ(keys[1], 10);
+    EXPECT_EQ(keys[2], 30);
+  } else if (GetParam() == IREE_TREE_WALK_POSTORDER) {
+    EXPECT_EQ(keys.size(), 3);
+    EXPECT_EQ(keys[0], 10);
+    EXPECT_EQ(keys[1], 30);
+    EXPECT_EQ(keys[2], 20);  // 20 is the root after rebalancing.
+  }
+}
+
+INSTANTIATE_TEST_SUITE_P(WalkTypes, RedBlackTreeWalkTest,
+                         ::testing::Values(IREE_TREE_WALK_PREORDER,
+                                           IREE_TREE_WALK_INORDER,
+                                           IREE_TREE_WALK_POSTORDER));
diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_device.c b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
index 7283e580024b..c539dcbcf66f 100644
--- a/runtime/src/iree/hal/drivers/local_sync/sync_device.c
+++ b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
@@ -241,8 +241,8 @@ static iree_status_t iree_hal_sync_device_create_command_buffer(
   iree_hal_sync_device_t* device = iree_hal_sync_device_cast(base_device);
   return iree_hal_deferred_command_buffer_create(
iree_hal_device_allocator(base_device), mode, command_categories, - binding_capacity, &device->large_block_pool, device->host_allocator, - out_command_buffer); + queue_affinity, binding_capacity, &device->large_block_pool, + device->host_allocator, out_command_buffer); } } diff --git a/runtime/src/iree/hal/drivers/local_task/task_device.c b/runtime/src/iree/hal/drivers/local_task/task_device.c index 8aa092590e9d..c01979c36f45 100644 --- a/runtime/src/iree/hal/drivers/local_task/task_device.c +++ b/runtime/src/iree/hal/drivers/local_task/task_device.c @@ -303,8 +303,8 @@ static iree_status_t iree_hal_task_device_create_command_buffer( // destructive. return iree_hal_deferred_command_buffer_create( iree_hal_device_allocator(base_device), mode, command_categories, - binding_capacity, &device->large_block_pool, device->host_allocator, - out_command_buffer); + queue_affinity, binding_capacity, &device->large_block_pool, + device->host_allocator, out_command_buffer); } else { iree_host_size_t queue_index = iree_hal_task_device_select_queue( device, command_categories, queue_affinity); diff --git a/runtime/src/iree/hal/drivers/metal/metal_device.m b/runtime/src/iree/hal/drivers/metal/metal_device.m index ef8e2c974465..0a155d8fec13 100644 --- a/runtime/src/iree/hal/drivers/metal/metal_device.m +++ b/runtime/src/iree/hal/drivers/metal/metal_device.m @@ -256,8 +256,8 @@ static iree_status_t iree_hal_metal_device_create_command_buffer( // for argument buffer updates to pass in binding tables. if (!iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) || binding_capacity > 0) { return iree_hal_deferred_command_buffer_create( - device->device_allocator, mode, command_categories, binding_capacity, &device->block_pool, - device->host_allocator, out_command_buffer); + device->device_allocator, mode, command_categories, queue_affinity, binding_capacity, + &device->block_pool, device->host_allocator, out_command_buffer); } return iree_hal_metal_direct_command_buffer_create( diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc index 043394292188..2e25d1193c3f 100644 --- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc +++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc @@ -1540,7 +1540,7 @@ static iree_status_t iree_hal_vulkan_device_create_command_buffer( if (binding_capacity > 0) { return iree_hal_deferred_command_buffer_create( iree_hal_device_allocator(base_device), mode, command_categories, - binding_capacity, &device->block_pool, + queue_affinity, binding_capacity, &device->block_pool, iree_hal_device_host_allocator(base_device), out_command_buffer); } diff --git a/runtime/src/iree/hal/queue.h b/runtime/src/iree/hal/queue.h index a96528849203..4e54e0b68e66 100644 --- a/runtime/src/iree/hal/queue.h +++ b/runtime/src/iree/hal/queue.h @@ -34,6 +34,7 @@ typedef uint64_t iree_hal_queue_affinity_t; // Specifies that any queue may be selected. 
 #define IREE_HAL_QUEUE_AFFINITY_ANY ((iree_hal_queue_affinity_t)(-1))
+#define IREE_HAL_MAX_QUEUES (sizeof(iree_hal_queue_affinity_t) * 8)
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c
index 939cddcb8828..6593f5b68bda 100644
--- a/runtime/src/iree/hal/utils/deferred_command_buffer.c
+++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c
@@ -161,8 +161,8 @@ iree_hal_deferred_command_buffer_cast(iree_hal_command_buffer_t* base_value) {
 IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_create(
     iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
-    iree_host_size_t binding_capacity, iree_arena_block_pool_t* block_pool,
-    iree_allocator_t host_allocator,
+    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
+    iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
     iree_hal_command_buffer_t** out_command_buffer) {
   IREE_ASSERT_ARGUMENT(block_pool);
   IREE_ASSERT_ARGUMENT(out_command_buffer);
@@ -177,7 +177,7 @@ IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_create(
       (void**)&command_buffer);
   if (iree_status_is_ok(status)) {
     iree_hal_command_buffer_initialize(
-        device_allocator, mode, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY,
+        device_allocator, mode, command_categories, queue_affinity,
         binding_capacity, (uint8_t*)command_buffer + sizeof(*command_buffer),
         &iree_hal_deferred_command_buffer_vtable, &command_buffer->base);
     command_buffer->host_allocator = host_allocator;
diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.h b/runtime/src/iree/hal/utils/deferred_command_buffer.h
index 500c405f4271..51714e262cd3 100644
--- a/runtime/src/iree/hal/utils/deferred_command_buffer.h
+++ b/runtime/src/iree/hal/utils/deferred_command_buffer.h
@@ -44,8 +44,8 @@ typedef struct iree_arena_block_pool_t iree_arena_block_pool_t;
 IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_create(
     iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode,
     iree_hal_command_category_t command_categories,
-    iree_host_size_t binding_capacity, iree_arena_block_pool_t* block_pool,
-    iree_allocator_t host_allocator,
+    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
+    iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
     iree_hal_command_buffer_t** out_command_buffer);
 
 // Returns true if |command_buffer| is a deferred command buffer.
diff --git a/runtime/src/iree/hal/utils/stream_tracing.c b/runtime/src/iree/hal/utils/stream_tracing.c
index 47a9e3ed48c6..f99e932cbb51 100644
--- a/runtime/src/iree/hal/utils/stream_tracing.c
+++ b/runtime/src/iree/hal/utils/stream_tracing.c
@@ -11,7 +11,7 @@
 // Total number of events per tracing context. This translates to the maximum
 // number of outstanding timestamp queries before collection is required.
 // To prevent spilling pages we leave some room for the context structure.
-#define IREE_HAL_TRACING_DEFAULT_QUERY_CAPACITY (16 * 1024 - 256)
+#define IREE_HAL_TRACING_DEFAULT_QUERY_CAPACITY (32 * 1024 - 256)
 
 // iree_hal_stream_tracing_context_event_t contains a native event that is used
 // to record timestamps for tracing GPU execution. In this struct, there are
@@ -207,7 +207,7 @@ void iree_hal_stream_tracing_context_free(
   IREE_TRACE_ZONE_BEGIN(z0);
 
   // Always perform a collection on shutdown.
- iree_hal_stream_tracing_context_collect(context); + iree_status_ignore(iree_hal_stream_tracing_context_collect(context)); // Release all events; since collection completed they should all be unused. IREE_TRACE_ZONE_BEGIN_NAMED( @@ -226,64 +226,109 @@ void iree_hal_stream_tracing_context_free( iree_slim_mutex_deinitialize(&context->event_mutex); + context->device_interface->vtable->destroy(context->device_interface); iree_allocator_t host_allocator = context->host_allocator; iree_allocator_free(host_allocator, context); IREE_TRACE_ZONE_END(z0); } -void iree_hal_stream_tracing_context_collect( +static iree_status_t iree_hal_stream_tracing_context_collect_list_internal( + iree_hal_stream_tracing_context_t* context, + iree_hal_stream_tracing_context_event_t* event) { + if (!context) return iree_ok_status(); + IREE_ASSERT_ARGUMENT(event); + // Inner per-event loop. + while (event) { + uint32_t query_id = (uint32_t)(event - &context->event_pool[0]); + IREE_RETURN_IF_ERROR( + context->device_interface->vtable->synchronize_native_event( + context->device_interface, event->event)); + IREE_RETURN_IF_ERROR(context->device_interface->vtable->query_native_event( + context->device_interface, event->event)); + + // Calculate context-relative time and notify tracy. + float relative_millis = 0.0f; + context->device_interface->vtable->event_elapsed_time( + context->device_interface, &relative_millis, context->base_event, + event->event); + + int64_t gpu_timestamp = (int64_t)((double)relative_millis * 1000000.0); + + iree_tracing_gpu_zone_notify(context->id, query_id, gpu_timestamp); + event = event->next_in_command_buffer; + } + return iree_ok_status(); +} + +iree_status_t iree_hal_stream_tracing_context_collect_list( + iree_hal_stream_tracing_context_t* context, + iree_hal_stream_tracing_context_event_t* completion_event) { + if (!context) return iree_ok_status(); + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_stream_tracing_context_collect_list_internal( + context, completion_event)); + + iree_slim_mutex_lock(&context->event_mutex); + iree_hal_stream_tracing_context_event_t* events = + context->submitted_event_list.head; + iree_hal_stream_tracing_context_event_t* last_events = events; + if (events == completion_event) { + context->submitted_event_list.head = events->next_submission; + } + + while (events) { + if (events == completion_event) { + // Remove completed events from the list. + last_events->next_submission = events->next_submission; + break; + } + last_events = events; + events = events->next_submission; + } + + completion_event->was_submitted = true; + iree_slim_mutex_unlock(&context->event_mutex); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +iree_status_t iree_hal_stream_tracing_context_collect( iree_hal_stream_tracing_context_t* context) { - if (!context) return; + if (!context) return iree_ok_status(); iree_slim_mutex_lock(&context->event_mutex); // No outstanding queries if (!context->submitted_event_list.head) { iree_slim_mutex_unlock(&context->event_mutex); - return; + return iree_ok_status(); } IREE_TRACE_ZONE_BEGIN(z0); - + iree_status_t status = iree_ok_status(); // submitted_event_list is a list of the head elements for each command // buffer that has been submitted. Here we loop over all of the events, // wait for them to complete and gather the results with event_query. iree_hal_stream_tracing_context_event_t* events = context->submitted_event_list.head; - uint32_t read_query_count = 0; // Outer per-command_buffer loop. 
   while (events) {
     iree_hal_stream_tracing_context_event_t* event = events;
-    // Inner per-event loop.
-    while (event) {
-      uint32_t query_id = (uint32_t)(event - &context->event_pool[0]);
-      iree_status_t status =
-          context->device_interface->vtable->synchronize_native_event(
-              context->device_interface, event->event);
-      if (!iree_status_is_ok(status)) break;
-      status = context->device_interface->vtable->query_native_event(
-          context->device_interface, event->event);
-      if (!iree_status_is_ok(status)) break;
-
-      // Calculate context-relative time and notify tracy.
-      float relative_millis = 0.0f;
-      context->device_interface->vtable->event_elapsed_time(
-          context->device_interface, &relative_millis, context->base_event,
-          event->event);
-
-      int64_t gpu_timestamp = (int64_t)((double)relative_millis * 1000000.0);
-      iree_tracing_gpu_zone_notify(context->id, query_id, gpu_timestamp);
-      read_query_count += 1;
-      event = event->next_in_command_buffer;
+    status =
+        iree_hal_stream_tracing_context_collect_list_internal(context, event);
+    if (!iree_status_is_ok(status)) {
+      break;
     }
     iree_hal_stream_tracing_context_event_t* next = events->next_submission;
     events->was_submitted = true;
     events = next;
     context->submitted_event_list.head = events;
   }
-  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)read_query_count);
   IREE_TRACE_ZONE_END(z0);
   iree_slim_mutex_unlock(&context->event_mutex);
+  return status;
 }
 
 void iree_hal_stream_tracing_notify_submitted(
@@ -513,14 +558,23 @@ iree_status_t iree_hal_stream_tracing_context_allocate(
     iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator,
     iree_hal_stream_tracing_context_t** out_context) {
   *out_context = NULL;
+  interface->vtable->destroy(interface);
   return iree_ok_status();
 }
 
 void iree_hal_stream_tracing_context_free(
     iree_hal_stream_tracing_context_t* context) {}
 
-void iree_hal_stream_tracing_context_collect(
-    iree_hal_stream_tracing_context_t* context) {}
+iree_status_t iree_hal_stream_tracing_context_collect(
+    iree_hal_stream_tracing_context_t* context) {
+  return iree_ok_status();
+}
+
+iree_status_t iree_hal_stream_tracing_context_collect_list(
+    iree_hal_stream_tracing_context_t* context,
+    iree_hal_stream_tracing_context_event_t* event) {
+  return iree_ok_status();
+}
 
 void iree_hal_stream_tracing_notify_submitted(
     iree_hal_stream_tracing_context_t* context,
diff --git a/runtime/src/iree/hal/utils/stream_tracing.h b/runtime/src/iree/hal/utils/stream_tracing.h
index 16a319e14da5..902a41a323e9 100644
--- a/runtime/src/iree/hal/utils/stream_tracing.h
+++ b/runtime/src/iree/hal/utils/stream_tracing.h
@@ -133,7 +133,7 @@ void iree_hal_stream_tracing_context_free(
 // Collects in-flight timestamp queries from the stream and feeds them to tracy.
 // Must be called frequently (every submission, etc) to drain the backlog;
 // tracing may start failing if the internal ringbuffer is exceeded.
-void iree_hal_stream_tracing_context_collect(
+iree_status_t iree_hal_stream_tracing_context_collect(
     iree_hal_stream_tracing_context_t* context);
 
 // Notifies that the given list of events has been dispatched on to the gpu.
@@ -141,6 +141,17 @@ void iree_hal_stream_tracing_notify_submitted(
     iree_hal_stream_tracing_context_t* context,
     iree_hal_stream_tracing_context_event_list_t* event_list);
 
+// Manually collects the events for a specified event list.
+// Callers must free the `events` to release them back into the context.
+//
+// Use this instead of `iree_hal_stream_tracing_notify_submitted` if
+// you don't want the stream tracing context to handle collecting the
+// events automatically (because it may cause more blocking than you
+// would prefer).
+iree_status_t iree_hal_stream_tracing_context_collect_list(
+    iree_hal_stream_tracing_context_t* context,
+    iree_hal_stream_tracing_context_event_t* event_list_head);
+
 // Frees the events and returns them back into the tracing context.
 void iree_hal_stream_tracing_free(
     iree_hal_stream_tracing_context_t* context,
diff --git a/third_party/llvm-project b/third_party/llvm-project
index 65099e8406e8..6038573ce5f7 160000
--- a/third_party/llvm-project
+++ b/third_party/llvm-project
@@ -1 +1 @@
-Subproject commit 65099e8406e8b7003b64bb9f929511d25358a521
+Subproject commit 6038573ce5f70b6c62db858950ae6040aa182fb9
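
For illustration, a minimal sketch (not part of the patch) of how a caller
would address one of the physical devices backing a multi-queue logical
device: each queue is selected by one bit of the iree_hal_queue_affinity_t
bitmap (up to IREE_HAL_MAX_QUEUES bits), while IREE_HAL_QUEUE_AFFINITY_ANY
leaves the choice to the device. The helper name and its parameters are
hypothetical; iree_hal_device_queue_execute and
iree_hal_buffer_binding_table_empty are the HAL entry points exercised by the
CTS tests above.

    #include "iree/hal/api.h"

    // Submits |command_buffer| to the queue (physical device) selected by
    // |queue_index|; bit i of the affinity bitmap selects queue i.
    static iree_status_t submit_to_queue(
        iree_hal_device_t* device, iree_hal_command_buffer_t* command_buffer,
        iree_host_size_t queue_index, iree_hal_semaphore_list_t wait_semaphores,
        iree_hal_semaphore_list_t signal_semaphores) {
      iree_hal_queue_affinity_t queue_affinity = (iree_hal_queue_affinity_t)1
                                                 << queue_index;
      return iree_hal_device_queue_execute(
          device, queue_affinity, wait_semaphores, signal_semaphores,
          command_buffer, iree_hal_buffer_binding_table_empty());
    }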
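Similarly, a sketch (under the same caveats) of driving the new tracing
collection entry point: iree_hal_stream_tracing_context_collect_list collects
the events of a single command buffer once the caller knows it has completed,
without scanning the whole submitted list, and the events are then recycled
with iree_hal_stream_tracing_free. The helper is a hypothetical stand-in for
logic a driver-side cleanup thread would hold, and assumes the event list
struct exposes its head pointer as used in stream_tracing.c above.

    #include "iree/base/api.h"
    #include "iree/hal/utils/stream_tracing.h"

    // Collects the tracing events recorded by one completed command buffer and
    // then returns them to the tracing context's pool.
    static iree_status_t collect_command_buffer_tracing(
        iree_hal_stream_tracing_context_t* tracing_context,
        iree_hal_stream_tracing_context_event_list_t* event_list) {
      IREE_RETURN_IF_ERROR(iree_hal_stream_tracing_context_collect_list(
          tracing_context, event_list->head));
      iree_hal_stream_tracing_free(tracing_context, event_list);
      return iree_ok_status();
    }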