Skip to content

Commit

Permalink
[hip] Update the hip runtime so that multiple physical devices can ba…
Browse files Browse the repository at this point in the history
…ck a logical device. (#18790)

This gives us an interface for creating a logical device from a set of
physical hip devices.

Each physical device in the new logical device is represented as a queue in the
logical device and therefore can be accessed by specifying a queue affinity
for operations and buffers.

This is a pretty substantial change to how hip devices are handled. 
In general:
1. Synchronization is updated so that we no longer have to do a CPU round trip
  for GPU<->GPU synchronization.
2. The deferred work queue is no longer in use in HIP at all.
3. We have 2 much simpler helper threads per device. They handle keeping high-cost operations off the main thread. Specifically, async memory operations, and queue executions which are emulated in HIP, and have a relatively high cost.

Signed-off-by: Andrew Woloszyn <[email protected]>
  • Loading branch information
AWoloszyn authored Dec 11, 2024
1 parent c315833 commit 0e71e72
Show file tree
Hide file tree
Showing 57 changed files with 5,146 additions and 2,026 deletions.
23 changes: 10 additions & 13 deletions runtime/src/iree/hal/cts/semaphore_submission_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ TEST_F(SemaphoreSubmissionTest, SubmitWithNoCommandBuffers) {
signal_payload_values,
};

IREE_ASSERT_OK(iree_hal_device_queue_barrier(device_,
/*queue_affinity=*/0,
iree_hal_semaphore_list_empty(),
signal_semaphores));
IREE_ASSERT_OK(iree_hal_device_queue_barrier(
device_, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
signal_semaphores));
IREE_ASSERT_OK(
iree_hal_semaphore_wait(signal_semaphore, 1, iree_infinite_timeout()));

Expand All @@ -54,9 +53,9 @@ TEST_F(SemaphoreSubmissionTest, SubmitAndSignal) {
};

IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_,
/*queue_affinity=*/0, iree_hal_semaphore_list_empty(), signal_semaphores,
command_buffer, iree_hal_buffer_binding_table_empty()));
device_, IREE_HAL_QUEUE_AFFINITY_ANY, iree_hal_semaphore_list_empty(),
signal_semaphores, command_buffer,
iree_hal_buffer_binding_table_empty()));
IREE_ASSERT_OK(
iree_hal_semaphore_wait(signal_semaphore, 1, iree_infinite_timeout()));

Expand Down Expand Up @@ -87,9 +86,8 @@ TEST_F(SemaphoreSubmissionTest, SubmitWithWait) {
};

IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_,
/*queue_affinity=*/0, wait_semaphores, signal_semaphores, command_buffer,
iree_hal_buffer_binding_table_empty()));
device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores, signal_semaphores,
command_buffer, iree_hal_buffer_binding_table_empty()));

// Work shouldn't start until the wait semaphore reaches its payload value.
CheckSemaphoreValue(signal_semaphore, 100);
Expand Down Expand Up @@ -130,9 +128,8 @@ TEST_F(SemaphoreSubmissionTest, SubmitWithMultipleSemaphores) {
};

IREE_ASSERT_OK(iree_hal_device_queue_execute(
device_,
/*queue_affinity=*/0, wait_semaphores, signal_semaphores, command_buffer,
iree_hal_buffer_binding_table_empty()));
device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores, signal_semaphores,
command_buffer, iree_hal_buffer_binding_table_empty()));

// Work shouldn't start until all wait semaphores reach their payload values.
CheckSemaphoreValue(signal_semaphore_1, 0);
Expand Down
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/cuda/cuda_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ static iree_status_t iree_hal_cuda_device_create_command_buffer(
if (binding_capacity > 0) {
return iree_hal_deferred_command_buffer_create(
iree_hal_device_allocator(base_device), mode, command_categories,
binding_capacity, &device->block_pool,
queue_affinity, binding_capacity, &device->block_pool,
iree_hal_device_host_allocator(base_device), out_command_buffer);
} else {
return iree_hal_cuda_graph_command_buffer_create(
Expand All @@ -867,7 +867,7 @@ static iree_status_t iree_hal_cuda_device_create_command_buffer(
case IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM: {
return iree_hal_deferred_command_buffer_create(
iree_hal_device_allocator(base_device), mode, command_categories,
binding_capacity, &device->block_pool,
queue_affinity, binding_capacity, &device->block_pool,
iree_hal_device_host_allocator(base_device), out_command_buffer);
}
default: {
Expand Down
16 changes: 10 additions & 6 deletions runtime/src/iree/hal/drivers/hip/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ iree_cc_library(
"api.h"
SRCS
"api.h"
"context_util.h"
"cleanup_thread.c"
"cleanup_thread.h"
"dispatch_thread.c"
"dispatch_thread.h"
"event_pool.c"
"event_pool.h"
"event_semaphore.c"
Expand All @@ -32,21 +35,22 @@ iree_cc_library(
"hip_allocator.h"
"hip_buffer.c"
"hip_buffer.h"
"hip_device.c"
"hip_device.h"
"hip_driver.c"
"hip_device.h"
"hip_device.c"
"hip_multi_queue_command_buffer.h"
"hip_multi_queue_command_buffer.c"
"memory_pools.c"
"memory_pools.h"
"native_executable.c"
"native_executable.h"
"nop_executable_cache.c"
"nop_executable_cache.h"
"per_device_information.h"
"rccl_channel.c"
"rccl_channel.h"
"stream_command_buffer.c"
"stream_command_buffer.h"
"timepoint_pool.c"
"timepoint_pool.h"
INCLUDES
"${HIP_API_HEADERS_ROOT}"
DEPS
Expand All @@ -65,12 +69,12 @@ iree_cc_library(
iree::hal::utils::collective_batch
iree::hal::utils::executable_debug_info
iree::hal::utils::deferred_command_buffer
iree::hal::utils::deferred_work_queue
iree::hal::utils::file_transfer
iree::hal::utils::memory_file
iree::hal::utils::resource_set
iree::hal::utils::semaphore_base
iree::hal::utils::stream_tracing
iree::hal::drivers::hip::util::hip_util
iree::schemas::executable_debug_info_c_fbs
iree::schemas::hip_executable_def_c_fbs
PUBLIC
Expand Down
170 changes: 170 additions & 0 deletions runtime/src/iree/hal/drivers/hip/cleanup_thread.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/hal/drivers/hip/cleanup_thread.h"

#include "iree/base/internal/synchronization.h"
#include "iree/base/internal/threading.h"
#include "iree/hal/drivers/hip/event_pool.h"
#include "iree/hal/drivers/hip/status_util.h"
#include "iree/hal/drivers/hip/util/queue.h"

#define iree_hal_hip_cleanup_thread_default_queue_size 64

typedef struct iree_hal_hip_cleanup_thread_callback_t {
iree_hal_hip_cleanup_callback_t callback;
void* user_data;
iree_hal_hip_event_t* event;
} iree_hal_hip_cleanup_thread_callback_t;

IREE_HAL_HIP_UTIL_TYPED_QUEUE_WRAPPER(
iree_hal_hip_callback_queue, iree_hal_hip_cleanup_thread_callback_t,
iree_hal_hip_cleanup_thread_default_queue_size);

typedef struct iree_hal_hip_cleanup_thread_t {
iree_thread_t* thread;
iree_allocator_t host_allocator;
const iree_hal_hip_dynamic_symbols_t* symbols;
iree_slim_mutex_t mutex;

iree_hal_hip_callback_queue_t queue;
iree_status_t failure_status;
iree_notification_t notification;
bool do_exit;
} iree_hal_hip_cleanup_thread_t;

static bool iree_hal_hip_cleanup_thread_has_request(void* user_data) {
iree_hal_hip_cleanup_thread_t* thread =
(iree_hal_hip_cleanup_thread_t*)user_data;
iree_slim_mutex_lock(&thread->mutex);
bool has_request = !iree_hal_hip_callback_queue_empty(&thread->queue);
has_request |= thread->do_exit;
iree_slim_mutex_unlock(&thread->mutex);
return has_request;
}

static int iree_hal_hip_cleanup_thread_main(void* param) {
iree_hal_hip_cleanup_thread_t* thread = (iree_hal_hip_cleanup_thread_t*)param;
bool exit = false;
while (true) {
iree_notification_await(&thread->notification,
&iree_hal_hip_cleanup_thread_has_request, thread,
iree_infinite_timeout());

iree_slim_mutex_lock(&thread->mutex);
exit |= thread->do_exit;
iree_status_t status = thread->failure_status;
while (!iree_hal_hip_callback_queue_empty(&thread->queue)) {
iree_hal_hip_cleanup_thread_callback_t callback =
iree_hal_hip_callback_queue_at(&thread->queue, 0);
iree_hal_hip_callback_queue_pop_front(&thread->queue, 1);
iree_slim_mutex_unlock(&thread->mutex);

if (iree_status_is_ok(status)) {
status = IREE_HIP_CALL_TO_STATUS(
thread->symbols,
hipEventSynchronize(iree_hal_hip_event_handle(callback.event)));
}

status = iree_status_join(
status,
callback.callback(callback.user_data, callback.event, status));
iree_slim_mutex_lock(&thread->mutex);
if (!iree_status_is_ok(status)) {
thread->failure_status = status;
}
if (!iree_status_is_ok(thread->failure_status)) {
status = iree_status_clone(thread->failure_status);
}
}
iree_slim_mutex_unlock(&thread->mutex);

if (!iree_status_is_ok(status) || exit) {
break;
}
}
return 0;
}

iree_status_t iree_hal_hip_cleanup_thread_initialize(
const iree_hal_hip_dynamic_symbols_t* symbols,
iree_allocator_t host_allocator,
iree_hal_hip_cleanup_thread_t** out_thread) {
IREE_TRACE_ZONE_BEGIN(z0);
*out_thread = NULL;
iree_hal_hip_cleanup_thread_t* thread = NULL;

IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
iree_allocator_malloc(host_allocator, sizeof(*thread), (void**)&thread));

thread->symbols = symbols;
thread->do_exit = false;
iree_slim_mutex_initialize(&thread->mutex);
iree_hal_hip_callback_queue_initialize(host_allocator, &thread->queue);
thread->failure_status = iree_ok_status();
thread->host_allocator = host_allocator;
iree_notification_initialize(&thread->notification);

iree_thread_create_params_t params;
memset(&params, 0x00, sizeof(params));
params.name = iree_make_cstring_view("iree-hal-hip-cleanup");
iree_status_t status =
iree_thread_create((iree_thread_entry_t)iree_hal_hip_cleanup_thread_main,
thread, params, host_allocator, &thread->thread);
if (iree_status_is_ok(status)) {
*out_thread = thread;
} else {
iree_hal_hip_callback_queue_deinitialize(&thread->queue);
iree_slim_mutex_deinitialize(&thread->mutex);
iree_allocator_free(host_allocator, thread);
}
IREE_TRACE_ZONE_END(z0);
return status;
}

void iree_hal_hip_cleanup_thread_deinitialize(
iree_hal_hip_cleanup_thread_t* thread) {
IREE_TRACE_ZONE_BEGIN(z0);

iree_slim_mutex_lock(&thread->mutex);
thread->do_exit = true;
iree_slim_mutex_unlock(&thread->mutex);

iree_notification_post(&thread->notification, IREE_ALL_WAITERS);
// There is only one owner for the thread, so this also joins the thread.
iree_thread_release(thread->thread);

iree_hal_hip_callback_queue_deinitialize(&thread->queue);
iree_slim_mutex_deinitialize(&thread->mutex);
iree_allocator_free(thread->host_allocator, thread);
IREE_TRACE_ZONE_END(z0);
}

iree_status_t iree_hal_hip_cleanup_thread_add_cleanup(
iree_hal_hip_cleanup_thread_t* thread, iree_hal_hip_event_t* event,
iree_hal_hip_cleanup_callback_t callback, void* user_data) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_slim_mutex_lock(&thread->mutex);
if (!iree_status_is_ok(thread->failure_status)) {
IREE_TRACE_ZONE_END(z0);
iree_slim_mutex_unlock(&thread->mutex);
return thread->failure_status;
}

iree_hal_hip_cleanup_thread_callback_t callback_data = {
.callback = callback,
.user_data = user_data,
.event = event,
};
iree_hal_hip_callback_queue_push_back(&thread->queue, callback_data);
iree_slim_mutex_unlock(&thread->mutex);
iree_notification_post(&thread->notification, IREE_ALL_WAITERS);

IREE_TRACE_ZONE_END(z0);

return iree_ok_status();
}
40 changes: 40 additions & 0 deletions runtime/src/iree/hal/drivers/hip/cleanup_thread.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#ifndef IREE_HAL_DRIVERS_HIP_CLEANUP_THREAD_H_
#define IREE_HAL_DRIVERS_HIP_CLEANUP_THREAD_H_

#include "iree/base/api.h"
#include "iree/hal/api.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"

typedef struct iree_hal_hip_cleanup_thread_t iree_hal_hip_cleanup_thread_t;
typedef struct iree_hal_hip_event_t iree_hal_hip_event_t;

typedef iree_status_t (*iree_hal_hip_cleanup_callback_t)(
void* user_data, iree_hal_hip_event_t* event, iree_status_t status);

// Initializes the cleanup thread for HIP driver.
iree_status_t iree_hal_hip_cleanup_thread_initialize(
const iree_hal_hip_dynamic_symbols_t* symbols,
iree_allocator_t host_allocator,
iree_hal_hip_cleanup_thread_t** out_thread);

// Deinitializes the cleanup thread for HIP driver.
void iree_hal_hip_cleanup_thread_deinitialize(
iree_hal_hip_cleanup_thread_t* thread);

// Adds a pending cleanup to the thread.
//
// The thread will wait on the event and fire the callback,
// once the event has completed.
// |user_data| must remain valid until the callback is called,
// and it is up to the callee to clean up user_data if required.
iree_status_t iree_hal_hip_cleanup_thread_add_cleanup(
iree_hal_hip_cleanup_thread_t* thread, iree_hal_hip_event_t* event,
iree_hal_hip_cleanup_callback_t callback, void* user_data);

#endif // IREE_HAL_DRIVERS_HIP_CLEANUP_THREAD_H_
4 changes: 2 additions & 2 deletions runtime/src/iree/hal/drivers/hip/context_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ static inline iree_status_t iree_hal_hip_set_context(
if (current_context != hip_context) {
IREE_TRACE_ZONE_BEGIN_NAMED(z0, "iree_hal_hip_set_context_switch");
iree_status_t status =
IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
IREE_HIP_CALL_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
IREE_TRACE_ZONE_END(z0);
return status;
}
});
return IREE_HIP_RESULT_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
return IREE_HIP_CALL_TO_STATUS(syms, hipCtxSetCurrent(hip_context));
}

#endif // IREE_HAL_DRIVERS_HIP_CONTEXT_UTIL_H_
Loading

0 comments on commit 0e71e72

Please sign in to comment.