Dev/logloss primitive fix (#2304)
* Substitute sycl::reduction with sycl::atomic_ref. Note that this leads to performance degradation.

* Clang-format
avolkov-intel authored Mar 31, 2023
1 parent 8967113 commit 063c09a
Showing 1 changed file with 57 additions and 60 deletions.
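
The change replaces SYCL's built-in reduction objects with per-work-item atomic updates on a single output scalar. A minimal standalone sketch of the two patterns (illustration only, not the library code; it uses the portable global_space address space, whereas the commit uses Intel's ext_intel_global_device_space extension):

#include <sycl/sycl.hpp>
#include <cstddef>
#include <iostream>

int main() {
    sycl::queue q;
    constexpr std::size_t n = 1024;

    float* data = sycl::malloc_shared<float>(n, q);
    float* sum_reduction = sycl::malloc_shared<float>(1, q);
    float* sum_atomic = sycl::malloc_shared<float>(1, q);
    for (std::size_t i = 0; i < n; ++i)
        data[i] = 1.0f;
    *sum_reduction = 0.0f;
    *sum_atomic = 0.0f;

    // Pattern removed by the commit: a sycl::reduction object lets the
    // runtime build an efficient (typically tree-shaped) reduction.
    q.submit([&](sycl::handler& cgh) {
         auto red = sycl::reduction(sum_reduction, sycl::plus<>());
         cgh.parallel_for(sycl::range<1>{ n }, red, [=](sycl::id<1> idx, auto& s) {
             s += data[idx];
         });
     }).wait();

    // Pattern introduced by the commit: every work item performs a relaxed
    // atomic fetch_add on one memory location. Correct, but the contention
    // on a single address is why the commit message notes a slowdown.
    q.submit([&](sycl::handler& cgh) {
         cgh.parallel_for(sycl::range<1>{ n }, [=](sycl::id<1> idx) {
             sycl::atomic_ref<float,
                              sycl::memory_order::relaxed,
                              sycl::memory_scope::device,
                              sycl::access::address_space::global_space>(*sum_atomic)
                 .fetch_add(data[idx]);
         });
     }).wait();

    std::cout << *sum_reduction << " == " << *sum_atomic << '\n'; // both print 1024

    sycl::free(data, q);
    sycl::free(sum_reduction, q);
    sycl::free(sum_atomic, q);
}
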
cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp
@@ -85,42 +85,36 @@ sycl::event compute_logloss(sycl::queue& q,
     auto loss_event = q.submit([&](sycl::handler& cgh) {
         const auto range = make_range_1d(n);
         using oneapi::dal::backend::operator+;
-        using sycl::reduction;
 
         cgh.depends_on(deps);
 
-        auto sumReduction = reduction(out_ptr, sycl::plus<>());
-
-        cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) {
+        cgh.parallel_for(range, [=](sycl::id<1> idx) {
             const Float prob = prob_ptr[idx];
             const std::int32_t label = labels_ptr[idx];
-            sum += -label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob);
+            Float& out = *out_ptr;
+            sycl::atomic_ref<Float,
+                             sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::ext_intel_global_device_space>(out)
+                .fetch_add(-label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob));
         });
     });
 
-    auto [out_reg, out_reg_e] = ndarray<Float, 1>::zeros(q, { 1 }, sycl::usm::alloc::device);
-    auto* const reg_ptr = out_reg.get_mutable_data();
-    const event_vector vector_out_reg = { out_reg_e };
-
     const auto* const param_ptr = parameters.get_data();
 
     if (L1 > 0 || L2 > 0) {
         auto reg_event = q.submit([&](sycl::handler& cgh) {
-            cgh.depends_on(vector_out_reg);
+            cgh.depends_on({ loss_event });
             const auto range = make_range_1d(p);
-            auto sumReduction = sycl::reduction(reg_ptr, sycl::plus<>());
-            cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) {
+            cgh.parallel_for(range, [=](sycl::id<1> idx) {
                 const Float param = param_ptr[idx + 1];
-                sum += L1 * sycl::abs(param) + L2 * param * param;
-            });
-        });
-        auto final_event = q.submit([&](sycl::handler& cgh) {
-            cgh.depends_on({ reg_event, loss_event });
-            cgh.single_task([=] {
-                out_ptr[0] += reg_ptr[0];
+                Float& out = *out_ptr;
+                sycl::atomic_ref<Float,
+                                 sycl::memory_order::relaxed,
+                                 sycl::memory_scope::device,
+                                 sycl::access::address_space::ext_intel_global_device_space>(out)
+                    .fetch_add(L1 * sycl::abs(param) + L2 * param * param);
             });
         });
-        return final_event;
+        return reg_event;
     }
     return loss_event;
 }
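
For reference, the quantity compute_logloss accumulates, read directly off the kernels above (p_i is the predicted probability, y_i the 0/1 label; the idx + 1 offset makes the regularization sum skip the intercept w_0):

L(w) = \sum_{i=1}^{n} \left[ -y_i \log p_i - (1 - y_i) \log(1 - p_i) \right] + \sum_{j=1}^{p} \left( L_1 \lvert w_j \rvert + L_2 w_j^2 \right)
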
@@ -199,28 +193,35 @@ sycl::event compute_logloss_with_der(sycl::queue& q,

     auto loss_event = q.submit([&](sycl::handler& cgh) {
         using oneapi::dal::backend::operator+;
-        using sycl::reduction;
 
         cgh.depends_on(deps);
-        auto sumReductionLogLoss = reduction(out_ptr, sycl::plus<>());
-        auto sumReductionDerivativeW0 = reduction(out_derivative_ptr, sycl::plus<>());
         const auto wg_size = propose_wg_size(q);
         const auto range = make_multiple_nd_range_1d(n, wg_size);
 
-        cgh.parallel_for(
-            range,
-            sumReductionLogLoss,
-            sumReductionDerivativeW0,
-            [=](sycl::nd_item<1> id, auto& sum_logloss, auto& sum_Dw0) {
-                auto idx = id.get_group_linear_id() * wg_size + id.get_local_linear_id();
-                if (idx >= std::size_t(n))
-                    return;
-                const Float prob = proba_ptr[idx];
-                const float label = labels_ptr[idx];
-                sum_logloss += -label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob);
-                der_obj_ptr[idx] = prob - label;
-                sum_Dw0 += der_obj_ptr[idx];
-            });
+        cgh.parallel_for(range, [=](sycl::nd_item<1> id) {
+            auto idx = id.get_group_linear_id() * wg_size + id.get_local_linear_id();
+            if (idx >= std::size_t(n))
+                return;
+            const Float prob = proba_ptr[idx];
+            const float label = labels_ptr[idx];
+
+            Float& out_logloss = *out_ptr;
+            Float& out_der = *out_derivative_ptr;
+            sycl::atomic_ref<Float,
+                             sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::ext_intel_global_device_space>(
+                out_logloss)
+                .fetch_add(-label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob));
+
+            der_obj_ptr[idx] = prob - label;
+
+            sycl::atomic_ref<Float,
+                             sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::ext_intel_global_device_space>(out_der)
+                .fetch_add(der_obj_ptr[idx]);
+        });
     });
 
     auto out_der_suffix = out_derivative.get_slice(1, p + 1);
@@ -229,30 +230,24 @@
     if (L1 == 0 && L2 == 0) {
         return der_event;
     }
-    auto [reg_val, reg_val_e] = ndarray<Float, 1>::zeros(q, { 1 }, sycl::usm::alloc::device);
-
-    const event_vector reg_deps = { reg_val_e, der_event };
-    auto* const reg_ptr = reg_val.get_mutable_data();
 
     auto reg_event = q.submit([&](sycl::handler& cgh) {
-        cgh.depends_on(reg_deps);
+        cgh.depends_on({ loss_event, der_event });
         const auto range = make_range_1d(p);
-        auto sumReduction = sycl::reduction(reg_ptr, sycl::plus<>());
-        cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) {
+        cgh.parallel_for(range, [=](sycl::id<1> idx) {
             const Float param = param_ptr[idx + 1];
-            sum += L1 * sycl::abs(param) + L2 * param * param;
+            Float& out_logloss = *out_ptr;
+            sycl::atomic_ref<Float,
+                             sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::ext_intel_global_device_space>(
+                out_logloss)
+                .fetch_add(L1 * sycl::abs(param) + L2 * param * param);
             out_derivative_ptr[idx + 1] += L2 * 2 * param;
         });
     });
 
-    auto final_event = q.submit([&](sycl::handler& cgh) {
-        cgh.depends_on({ reg_event, loss_event });
-        cgh.single_task([=] {
-            out_ptr[0] += reg_ptr[0];
-        });
-    });
-
-    return final_event;
+    return reg_event;
 }
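
Likewise, the derivative kernels above store der_obj_i = p_i - y_i per sample, atomically fold it into the intercept component of the gradient, and add the L_2 term coordinate-wise:

\frac{\partial L}{\partial w_0} = \sum_{i=1}^{n} (p_i - y_i), \qquad \frac{\partial L}{\partial w_j} \mathrel{+}= 2 L_2 w_j \quad (j \ge 1)
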

template <typename Float>
@@ -291,21 +286,23 @@ sycl::event compute_derivative(sycl::queue& q,
     auto* const out_derivative_ptr = out_derivative.get_mutable_data();
 
     auto loss_event = q.submit([&](sycl::handler& cgh) {
-        using sycl::reduction;
-
         cgh.depends_on(deps);
-        auto sumReductionDerivativeW0 = reduction(out_derivative_ptr, sycl::plus<>());
         const auto wg_size = propose_wg_size(q);
         const auto range = make_multiple_nd_range_1d(n, wg_size);
 
-        cgh.parallel_for(range, sumReductionDerivativeW0, [=](sycl::nd_item<1> id, auto& sum_Dw0) {
+        cgh.parallel_for(range, [=](sycl::nd_item<1> id) {
             auto idx = id.get_group_linear_id() * wg_size + id.get_local_linear_id();
             if (idx >= std::size_t(n))
                 return;
             const Float prob = proba_ptr[idx];
             const Float label = labels_ptr[idx];
             der_obj_ptr[idx] = prob - label;
-            sum_Dw0 += der_obj_ptr[idx];
+            Float& out_der = *out_derivative_ptr;
+            sycl::atomic_ref<Float,
+                             sycl::memory_order::relaxed,
+                             sycl::memory_scope::device,
+                             sycl::access::address_space::ext_intel_global_device_space>(out_der)
+                .fetch_add(der_obj_ptr[idx]);
         });
     });
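
Not part of this commit, but worth noting against the performance caveat in the commit message: the usual middle ground between a full sycl::reduction and one atomic per work item is a work-group partial sum followed by a single atomic per group. A sketch under the assumption that wg_size divides n:

#include <sycl/sycl.hpp>
#include <iostream>

int main() {
    sycl::queue q;
    constexpr std::size_t n = 1024;
    constexpr std::size_t wg_size = 128; // assumed to divide n evenly

    float* data = sycl::malloc_shared<float>(n, q);
    float* sum = sycl::malloc_shared<float>(1, q);
    for (std::size_t i = 0; i < n; ++i)
        data[i] = 1.0f;
    *sum = 0.0f;

    q.submit([&](sycl::handler& cgh) {
         cgh.parallel_for(
             sycl::nd_range<1>{ n, wg_size },
             [=](sycl::nd_item<1> it) {
                 // Tree-reduce within the work-group, then one atomic per group.
                 const float partial = sycl::reduce_over_group(it.get_group(),
                                                               data[it.get_global_id(0)],
                                                               sycl::plus<>());
                 if (it.get_local_id(0) == 0) {
                     sycl::atomic_ref<float,
                                      sycl::memory_order::relaxed,
                                      sycl::memory_scope::device,
                                      sycl::access::address_space::global_space>(*sum)
                         .fetch_add(partial);
                 }
             });
     }).wait();

    std::cout << "sum = " << *sum << '\n'; // expected: 1024

    sycl::free(data, q);
    sycl::free(sum, q);
}
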

