Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: enabling oneDPL and sort primitive refactoring #3046

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
9 changes: 8 additions & 1 deletion .ci/pipeline/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ variables:
VM_IMAGE : 'ubuntu-22.04'
SYSROOT_OS: 'jammy'
WINDOWS_BASEKIT_URL: 'https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe'
WINDOWS_DPCPP_COMPONENTS: 'intel.oneapi.win.mkl.devel:intel.oneapi.win.tbb.devel'
WINDOWS_DPCPP_COMPONENTS: 'intel.oneapi.win.mkl.devel:intel.oneapi.win.tbb.devel:intel.oneapi.win.dpl'
LINUX_DPL_URL: 'https://registrationcenter-download.intel.com/akdlm/IRC_NAS/de3c613f-829c-4bdc-aa2b-6129eece3bd9/intel-onedpl-2022.7.1.15_offline.sh'

resources:
repositories:
Expand Down Expand Up @@ -71,6 +72,9 @@ jobs:
- script: |
.ci/env/apt.sh mkl
displayName: 'mkl installation'
- script:
chmod +x .ci/scripts/install_dpl.sh && .ci/scripts/install_dpl.sh $(LINUX_DPL_URL)
displayName: 'Install oneAPI Base Toolkit'
Alexandr-Solovev marked this conversation as resolved.
Show resolved Hide resolved
- script: |
source /opt/intel/oneapi/setvars.sh
.ci/scripts/describe_system.sh
Expand Down Expand Up @@ -393,6 +397,9 @@ jobs:
- script: |
.ci/env/apt.sh mkl
displayName: 'mkl installation'
- script:
chmod +x .ci/scripts/install_dpl.sh && .ci/scripts/install_dpl.sh $(LINUX_DPL_URL)
displayName: 'Install oneAPI Base Toolkit'
- script: |
source /opt/intel/oneapi/setvars.sh
.ci/scripts/describe_system.sh
Expand Down
31 changes: 31 additions & 0 deletions .ci/scripts/install_dpl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/bin/bash
#===============================================================================
# Copyright contributors to the oneDAL project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

URL=$1

# Download the installation script
curl --output installer.sh --url "$URL" --retry 5 --retry-delay 5
chmod +x installer.sh

# Execute the installation script
sudo sh installer.sh -a --silent --eula accept
installer_exit_code=$?

# Clean up
rm -f installer.sh

exit $installer_exit_code
15 changes: 15 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ ccl_repo(
]
)

load("@onedal//dev/bazel/deps:dpl.bzl", "dpl_repo")
dpl_repo(
name = "dpl",
root_env_var = "DPL_ROOT",
david-cortes-intel marked this conversation as resolved.
Show resolved Hide resolved
urls = [
"https://files.pythonhosted.org/packages/95/f6/18f78cb933e01ecd9e99d37a10da4971a795fcfdd1d24640799b4050fdbb/onedpl_devel-2022.7.1-py2.py3-none-manylinux_2_28_x86_64.whl",
],
sha256s = [
"3b270999d2464c5151aa0e7995dda9e896d072c75069ccee1efae9dc56bdc417",
],
strip_prefixes = [
"onedpl_devel-2022.7.1.data/data",
],
)

load("@onedal//dev/bazel/deps:mkl.bzl", "mkl_repo")
mkl_repo(
name = "mkl",
Expand Down
1 change: 1 addition & 0 deletions cpp/oneapi/dal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dal_module(
],
dpc_deps = [
"@mkl//:mkl_dpc",
"@dpl//:headers",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ inline sycl::event sort_inplace(sycl::queue& queue_,
const bk::event_vector& deps = {}) {
ONEDAL_ASSERT(src.get_count() > 0);
auto src_ind = pr::ndarray<Index, 1>::empty(queue_, { src.get_count() });
return pr::radix_sort_indices_inplace<Float, Index>{ queue_ }(src, src_ind, deps);
return pr::radix_sort_indices_inplace<Float, Index>(queue_, src, src_ind, deps);
}

template <typename Float, typename Bin, typename Index>
Expand Down Expand Up @@ -429,13 +429,14 @@ sycl::event indexed_features<Float, Bin, Index>::operator()(const table& tbl,
pr::ndarray<Bin, 1>::empty(queue_, { row_count_ }, sycl::usm::alloc::device);
}

pr::radix_sort_indices_inplace<Float, Index> sort{ queue_ };

sycl::event last_event;

for (Index i = 0; i < column_count_; i++) {
last_event = extract_column(data_nd_, values_nd, indices_nd, i, { last_event });
last_event = sort(values_nd, indices_nd, { last_event });
last_event = pr::radix_sort_indices_inplace<Float, Index>(queue_,
values_nd,
indices_nd,
{ last_event });
last_event =
compute_bins(values_nd, indices_nd, column_bin_vec_[i], entries_[i], i, { last_event });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ static auto fill_candidate_indices_and_distances(sycl::queue& queue,
});
});

pr::radix_sort_indices_inplace<Float, std::int32_t> radix_sort{ queue };
auto sort_event = radix_sort(values, indices, { fill_event });
auto sort_event =
pr::radix_sort_indices_inplace<Float, std::int32_t>(queue, values, indices, { fill_event });

auto copy_event = queue.submit([&](sycl::handler& cgh) {
cgh.depends_on(sort_event);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,12 @@ sycl::event working_set_selector<Float>::sort_f_indices(sycl::queue& q,

auto copy_event = dal::backend::copy(q, tmp_sort_ptr, f_ptr, row_count_, deps);
auto arange_event = sorted_f_indices_.arange(q);
auto radix_sort = pr::radix_sort_indices_inplace<Float, std::int32_t>{ q };

auto radix_sort_event =
radix_sort(tmp_sort_values_, sorted_f_indices_, { copy_event, arange_event });
pr::radix_sort_indices_inplace<Float, std::int32_t>(q,
tmp_sort_values_,
sorted_f_indices_,
{ copy_event, arange_event });

return radix_sort_event;
}
Expand Down
134 changes: 14 additions & 120 deletions cpp/oneapi/dal/backend/primitives/sort/sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,130 +36,24 @@ struct float2uint_map<double> {
using integer_t = std::uint64_t;
};

/// @tparam Float Floating-point type used for storing input values
/// @tparam Index Integer type used for storing input indices
template <typename Float, typename Index = std::uint32_t>
class radix_sort_indices_inplace {
static_assert(std::is_same_v<float, Float> || std::is_same_v<double, Float>);
using radix_integer_t = typename float2uint_map<Float>::integer_t;
sycl::event radix_sort_indices_inplace(sycl::queue& queue,
ndview<Float, 1>& val,
ndview<Index, 1>& ind,
const event_vector& deps = {});

public:
/// Performs initialization of auxiliary variables and required auxiliary buffers
///
/// @param[in] queue The queue
/// @param[in] elem_count The number of elements in input vector
radix_sort_indices_inplace(const sycl::queue& queue);
radix_sort_indices_inplace(const radix_sort_indices_inplace&) = delete;
~radix_sort_indices_inplace();
radix_sort_indices_inplace& operator=(const radix_sort_indices_inplace&) = delete;

/// Performs inplace radix sort of input vector and corresponding indices
/// NOTE: auxiliary buffers and variables are reset in case if number of elements in val
/// differs from the number of elements provided in constructor
///
/// @param[in, out] val The [n] input/output vector of values to sort out
/// @param[in, out] ind The [n] input/output vector of corresponding indices
sycl::event operator()(ndview<Float, 1>& val,
ndview<Index, 1>& ind,
const event_vector& deps = {});

private:
void init(sycl::queue& queue, std::int64_t elem_count);
sycl::event radix_scan(sycl::queue& queue,
const ndview<Float, 1>& val,
ndarray<Index, 1>& part_hist,
Index elem_count,
std::uint32_t bit_offset,
std::int64_t local_size,
std::int64_t local_hist_count,
sycl::event& deps);
sycl::event radix_hist_scan(sycl::queue& queue,
const ndarray<Index, 1>& part_hist,
ndarray<Index, 1>& part_prefix_hist,
std::int64_t local_size,
std::int64_t local_hist_count,
sycl::event& deps);
sycl::event radix_reorder(sycl::queue& queue,
const ndview<Float, 1>& val_in,
const ndview<Index, 1>& ind_in,
const ndview<Index, 1>& part_prefix_hist,
ndview<Float, 1>& val_out,
ndview<Index, 1>& ind_out,
Index elem_count,
std::uint32_t bit_offset,
std::int64_t local_size,
std::int64_t local_hist_count,
sycl::event& deps);

sycl::queue queue_;
sycl::event sort_event_;

ndarray<Float, 1> val_buff_;
ndarray<Index, 1> ind_buff_;

ndarray<Index, 1> part_hist_;
ndarray<Index, 1> part_prefix_hist_;

std::uint32_t elem_count_;
std::uint32_t local_size_;
std::uint32_t local_hist_count_;
std::uint32_t hist_buff_size_;

static constexpr inline std::uint32_t radix_bits_ = 4;
static constexpr inline std::uint32_t radix_range_ = (std::uint32_t)1 << radix_bits_;
static constexpr inline std::uint32_t radix_range_1_ = radix_range_ - 1;

static constexpr inline std::uint32_t byte_range_ = 8;
static constexpr inline std::uint32_t max_local_hist_count_ = 1024;
static constexpr inline std::uint32_t preferable_sbg_size_ = 16;
};

/// @tparam Integer Integer type used for storing input values
template <typename Integer>
class radix_sort {
public:
/// Performs initialization of auxiliary variables and required auxiliary buffers
///
/// @param[in] queue The queue
/// @param[in] vector_count The number of vectors (rows) in input array
radix_sort(const sycl::queue& queue);
radix_sort(const radix_sort&) = delete;
~radix_sort();
radix_sort& operator=(const radix_sort&) = delete;

/// Performs radix sort of batch of integer input vectors
/// NOTE: only positive values are supported for now.
/// Auxiliary buffers and variables are reset in case if number of elements in val
/// differs from the number of elements provided in constructor
///
/// @param[in] val_in The [n x p] input array of vectors (row major format) to sort out,
/// is also used for temporary data storage
/// @param[out] val_out The [n x p] output array of sorted vectors (row major format)
/// @param[in] sorted_elem_count The number of elements to sort in each vector
/// TODO: Extend interface with strided (not dense) input & output arrays
sycl::event operator()(ndview<Integer, 2>& val_in,
ndview<Integer, 2>& val_out,
std::int64_t sorted_elem_count,
const event_vector& deps = {});

sycl::event operator()(ndview<Integer, 2>& val_in,
ndview<Integer, 2>& val_out,
const event_vector& deps = {});

private:
void init(sycl::queue& queue, std::int64_t vector_count);

sycl::queue queue_;
sycl::event sort_event_;
sycl::event radix_sort(sycl::queue& queue,
ndview<Integer, 2>& val_in,
ndview<Integer, 2>& val_out,
std::int64_t sorted_elem_count,
const event_vector& deps = {});

ndarray<Integer, 2> buffer_;

std::uint32_t vector_count_;

static constexpr inline std::uint32_t preferable_wg_size_ = 32;
static constexpr inline std::uint32_t radix_range_ = 256;
static constexpr inline std::uint32_t radix_count_ = sizeof(Integer);
};
template <typename Integer>
sycl::event radix_sort(sycl::queue& queue,
ndview<Integer, 2>& val_in,
ndview<Integer, 2>& val_out,
const event_vector& deps = {});

#endif

Expand Down
Loading
Loading