From 063c09a0949cc20e9bfd5b698f5943bb8fefffee Mon Sep 17 00:00:00 2001 From: avolkov-intel <117643568+avolkov-intel@users.noreply.github.com> Date: Fri, 31 Mar 2023 10:01:53 +0200 Subject: [PATCH] Dev/logloss primitive fix (#2304) * Substitute sycl::reduction with sycl::atomic_reference. Note that leads to performance degradation. * Clang-format --- .../objective_function/logloss_dpc.cpp | 117 +++++++++--------- 1 file changed, 57 insertions(+), 60 deletions(-) diff --git a/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp b/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp index bda30c55331..8d311468345 100644 --- a/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp +++ b/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp @@ -85,42 +85,36 @@ sycl::event compute_logloss(sycl::queue& q, auto loss_event = q.submit([&](sycl::handler& cgh) { const auto range = make_range_1d(n); using oneapi::dal::backend::operator+; - using sycl::reduction; - cgh.depends_on(deps); - - auto sumReduction = reduction(out_ptr, sycl::plus<>()); - - cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) { + cgh.parallel_for(range, [=](sycl::id<1> idx) { const Float prob = prob_ptr[idx]; const std::int32_t label = labels_ptr[idx]; - sum += -label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob); + Float& out = *out_ptr; + sycl::atomic_ref(out) + .fetch_add(-label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob)); }); }); - auto [out_reg, out_reg_e] = ndarray::zeros(q, { 1 }, sycl::usm::alloc::device); - auto* const reg_ptr = out_reg.get_mutable_data(); - const event_vector vector_out_reg = { out_reg_e }; - const auto* const param_ptr = parameters.get_data(); if (L1 > 0 || L2 > 0) { auto reg_event = q.submit([&](sycl::handler& cgh) { - cgh.depends_on(vector_out_reg); + cgh.depends_on({ loss_event }); const auto range = make_range_1d(p); - auto sumReduction = sycl::reduction(reg_ptr, sycl::plus<>()); - cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) { + cgh.parallel_for(range, [=](sycl::id<1> idx) { const Float param = param_ptr[idx + 1]; - sum += L1 * sycl::abs(param) + L2 * param * param; - }); - }); - auto final_event = q.submit([&](sycl::handler& cgh) { - cgh.depends_on({ reg_event, loss_event }); - cgh.single_task([=] { - out_ptr[0] += reg_ptr[0]; + Float& out = *out_ptr; + sycl::atomic_ref(out) + .fetch_add(L1 * sycl::abs(param) + L2 * param * param); }); }); - return final_event; + return reg_event; } return loss_event; } @@ -199,28 +193,35 @@ sycl::event compute_logloss_with_der(sycl::queue& q, auto loss_event = q.submit([&](sycl::handler& cgh) { using oneapi::dal::backend::operator+; - using sycl::reduction; cgh.depends_on(deps); - auto sumReductionLogLoss = reduction(out_ptr, sycl::plus<>()); - auto sumReductionDerivativeW0 = reduction(out_derivative_ptr, sycl::plus<>()); const auto wg_size = propose_wg_size(q); const auto range = make_multiple_nd_range_1d(n, wg_size); - cgh.parallel_for( - range, - sumReductionLogLoss, - sumReductionDerivativeW0, - [=](sycl::nd_item<1> id, auto& sum_logloss, auto& sum_Dw0) { - auto idx = id.get_group_linear_id() * wg_size + id.get_local_linear_id(); - if (idx >= std::size_t(n)) - return; - const Float prob = proba_ptr[idx]; - const float label = labels_ptr[idx]; - sum_logloss += -label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob); - der_obj_ptr[idx] = prob - label; - sum_Dw0 += der_obj_ptr[idx]; - }); + cgh.parallel_for(range, [=](sycl::nd_item<1> id) { + auto idx = id.get_group_linear_id() * wg_size + id.get_local_linear_id(); + if (idx >= std::size_t(n)) + return; + const Float prob = proba_ptr[idx]; + const float label = labels_ptr[idx]; + + Float& out_logloss = *out_ptr; + Float& out_der = *out_derivative_ptr; + sycl::atomic_ref( + out_logloss) + .fetch_add(-label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob)); + + der_obj_ptr[idx] = prob - label; + + sycl::atomic_ref(out_der) + .fetch_add(der_obj_ptr[idx]); + }); }); auto out_der_suffix = out_derivative.get_slice(1, p + 1); @@ -229,30 +230,24 @@ sycl::event compute_logloss_with_der(sycl::queue& q, if (L1 == 0 && L2 == 0) { return der_event; } - auto [reg_val, reg_val_e] = ndarray::zeros(q, { 1 }, sycl::usm::alloc::device); - - const event_vector reg_deps = { reg_val_e, der_event }; - auto* const reg_ptr = reg_val.get_mutable_data(); auto reg_event = q.submit([&](sycl::handler& cgh) { - cgh.depends_on(reg_deps); + cgh.depends_on({ loss_event, der_event }); const auto range = make_range_1d(p); - auto sumReduction = sycl::reduction(reg_ptr, sycl::plus<>()); - cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) { + cgh.parallel_for(range, [=](sycl::id<1> idx) { const Float param = param_ptr[idx + 1]; - sum += L1 * sycl::abs(param) + L2 * param * param; + Float& out_logloss = *out_ptr; + sycl::atomic_ref( + out_logloss) + .fetch_add(L1 * sycl::abs(param) + L2 * param * param); out_derivative_ptr[idx + 1] += L2 * 2 * param; }); }); - auto final_event = q.submit([&](sycl::handler& cgh) { - cgh.depends_on({ reg_event, loss_event }); - cgh.single_task([=] { - out_ptr[0] += reg_ptr[0]; - }); - }); - - return final_event; + return reg_event; } template @@ -291,21 +286,23 @@ sycl::event compute_derivative(sycl::queue& q, auto* const out_derivative_ptr = out_derivative.get_mutable_data(); auto loss_event = q.submit([&](sycl::handler& cgh) { - using sycl::reduction; - cgh.depends_on(deps); - auto sumReductionDerivativeW0 = reduction(out_derivative_ptr, sycl::plus<>()); const auto wg_size = propose_wg_size(q); const auto range = make_multiple_nd_range_1d(n, wg_size); - cgh.parallel_for(range, sumReductionDerivativeW0, [=](sycl::nd_item<1> id, auto& sum_Dw0) { + cgh.parallel_for(range, [=](sycl::nd_item<1> id) { auto idx = id.get_group_linear_id() * wg_size + id.get_local_linear_id(); if (idx >= std::size_t(n)) return; const Float prob = proba_ptr[idx]; const Float label = labels_ptr[idx]; der_obj_ptr[idx] = prob - label; - sum_Dw0 += der_obj_ptr[idx]; + Float& out_der = *out_derivative_ptr; + sycl::atomic_ref(out_der) + .fetch_add(der_obj_ptr[idx]); }); });