From 138ddd64f9ee9d17e2ba8ca734a29200d1d92b1b Mon Sep 17 00:00:00 2001 From: Aleksandr Solovev Date: Wed, 10 Jan 2024 11:00:40 +0100 Subject: [PATCH] feature: adding blocking in a table convert function (#2625) --- cpp/oneapi/dal/backend/transfer.hpp | 7 +++ cpp/oneapi/dal/backend/transfer_dpc.cpp | 80 ++++++++++++++++++++++-- cpp/oneapi/dal/table/backend/convert.cpp | 28 ++++++--- 3 files changed, 103 insertions(+), 12 deletions(-) diff --git a/cpp/oneapi/dal/backend/transfer.hpp b/cpp/oneapi/dal/backend/transfer.hpp index 8a994de3c6d..103e46d450e 100644 --- a/cpp/oneapi/dal/backend/transfer.hpp +++ b/cpp/oneapi/dal/backend/transfer.hpp @@ -93,6 +93,13 @@ sycl::event scatter_host2device(sycl::queue& q, std::int64_t dst_stride_in_bytes, std::int64_t block_size_in_bytes, const event_vector& deps = {}); +sycl::event scatter_host2device_blocking(sycl::queue& q, + void* dst_device, + const void* src_host, + std::int64_t block_count, + std::int64_t dst_stride_in_bytes, + std::int64_t block_size_in_bytes, + const event_vector& deps = {}); #endif } // namespace oneapi::dal::backend diff --git a/cpp/oneapi/dal/backend/transfer_dpc.cpp b/cpp/oneapi/dal/backend/transfer_dpc.cpp index 2a72e9ad7cd..6f772e96c56 100644 --- a/cpp/oneapi/dal/backend/transfer_dpc.cpp +++ b/cpp/oneapi/dal/backend/transfer_dpc.cpp @@ -18,6 +18,12 @@ #include namespace oneapi::dal::backend { +namespace bk = dal::backend; +template +std::int64_t propose_block_size(const sycl::queue& q, const std::int64_t r) { + constexpr std::int64_t fsize = sizeof(Float); + return 0x10000l * (8 / fsize); +} sycl::event gather_device2host(sycl::queue& q, void* dst_host, @@ -101,10 +107,11 @@ sycl::event scatter_host2device(sycl::queue& q, auto scatter_event = q.submit([&](sycl::handler& cgh) { cgh.depends_on(copy_event); - byte_t* gathered_byte = reinterpret_cast(gathered_device_unique.get()); - byte_t* dst_byte = reinterpret_cast(dst_device); + const byte_t* const gathered_byte = + reinterpret_cast(gathered_device_unique.get()); + byte_t* const dst_byte = reinterpret_cast(dst_device); - const std::int64_t required_local_size = 256; + const std::int64_t required_local_size = bk::device_max_wg_size(q); const std::int64_t local_size = std::min(down_pow2(block_count), required_local_size); const auto range = make_multiple_nd_range_1d(block_count, local_size); @@ -112,7 +119,7 @@ sycl::event scatter_host2device(sycl::queue& q, const auto i = id.get_global_id(); if (i < block_count) { // TODO: Unroll for optimization - for (int j = 0; j < block_size_in_bytes; j++) { + for (std::int64_t j = 0; j < block_size_in_bytes; ++j) { dst_byte[i * dst_stride_in_bytes + j] = gathered_byte[i * block_size_in_bytes + j]; } @@ -127,4 +134,69 @@ sycl::event scatter_host2device(sycl::queue& q, return sycl::event{}; } +sycl::event scatter_host2device_blocking(sycl::queue& q, + void* dst_device, + const void* src_host, + std::int64_t block_count, + std::int64_t dst_stride_in_bytes, + std::int64_t block_size_in_bytes, + const event_vector& deps) { + ONEDAL_ASSERT(dst_device); + ONEDAL_ASSERT(src_host); + ONEDAL_ASSERT(block_count > 0); + ONEDAL_ASSERT(dst_stride_in_bytes > 0); + ONEDAL_ASSERT(block_size_in_bytes > 0); + ONEDAL_ASSERT(dst_stride_in_bytes >= block_size_in_bytes); + ONEDAL_ASSERT(is_known_usm(q, dst_device)); + ONEDAL_ASSERT_MUL_OVERFLOW(std::int64_t, block_count, block_size_in_bytes); + const auto gathered_device_unique = + make_unique_usm_device(q, block_count * block_size_in_bytes); + + auto copy_event = memcpy_host2usm(q, + gathered_device_unique.get(), + src_host, + block_count * block_size_in_bytes, + deps); + + const byte_t* const gathered_byte = + reinterpret_cast(gathered_device_unique.get()); + byte_t* const dst_byte = reinterpret_cast(dst_device); + + const auto block_size = propose_block_size(q, block_count); + const bk::uniform_blocking blocking(block_count, block_size); + std::vector events(blocking.get_block_count()); + + const auto block_range = blocking.get_block_count(); + + for (std::int64_t block_index = 0; block_index < block_range; ++block_index) { + const auto start_block = blocking.get_block_start_index(block_index); + const auto end_block = blocking.get_block_end_index(block_index); + const auto curr_block = end_block - start_block; + ONEDAL_ASSERT(curr_block > 0); + + auto scatter_event = q.submit([&](sycl::handler& cgh) { + cgh.depends_on(copy_event); + + const std::int64_t required_local_size = bk::device_max_wg_size(q); + const std::int64_t local_size = std::min(down_pow2(curr_block), required_local_size); + const auto range = make_multiple_nd_range_1d(curr_block, local_size); + + cgh.parallel_for(range, [=](sycl::nd_item<1> id) { + const auto i = id.get_global_id() + start_block; + if (i < block_count) { + // TODO: Unroll for optimization + for (std::int64_t j = 0; j < block_size_in_bytes; ++j) { + dst_byte[i * dst_stride_in_bytes + j] = + gathered_byte[i * block_size_in_bytes + j]; + } + } + }); + }); + events.push_back(scatter_event); + } + // We need to wait until scatter kernel is completed to deallocate + // `gathered_device_unique` + return bk::wait_or_pass(events); +} + } // namespace oneapi::dal::backend diff --git a/cpp/oneapi/dal/table/backend/convert.cpp b/cpp/oneapi/dal/table/backend/convert.cpp index 4c996609c7b..1830d00c432 100644 --- a/cpp/oneapi/dal/table/backend/convert.cpp +++ b/cpp/oneapi/dal/table/backend/convert.cpp @@ -291,14 +291,26 @@ sycl::event convert_vector_host2device(sycl::queue& q, src_stride, 1L, element_count); - - auto scatter_event = scatter_host2device(q, - dst_device, - tmp_host_unique.get(), - element_count, - dst_stride_in_bytes, - element_size_in_bytes, - deps); + const std::int64_t max_loop_range = std::numeric_limits::max(); + sycl::event scatter_event; + if (element_count > max_loop_range) { + scatter_event = scatter_host2device_blocking(q, + dst_device, + tmp_host_unique.get(), + element_count, + dst_stride_in_bytes, + element_size_in_bytes, + deps); + } + else { + scatter_event = scatter_host2device(q, + dst_device, + tmp_host_unique.get(), + element_count, + dst_stride_in_bytes, + element_size_in_bytes, + deps); + } return scatter_event; }