From 98fc59c44c01a5221703cc2af467c88ace6b67cf Mon Sep 17 00:00:00 2001
From: avolkov-intel <117643568+avolkov-intel@users.noreply.github.com>
Date: Thu, 2 Mar 2023 12:59:45 +0100
Subject: [PATCH] Dev/logloss (#2284)

* Initial commit

* Fix reduction 2d files

* Update tests

* Fix code format

* Add gemv implementation and tests

* Minor fix

* Add logloss computation without regularization

* Add derivative and naive hessian computation

* Add 4 hessian computation implementations and performance tests

* Add tests, header and update BUILD file

* Add skip_if to double tests, fix timeouts

* Update regularization computation
---
 cpp/oneapi/dal/backend/primitives/BUILD       |   2 +
 cpp/oneapi/dal/backend/primitives/ndarray.hpp |  65 ++-
 .../backend/primitives/objective_function.hpp |  19 +
 .../primitives/objective_function/BUILD       |  42 ++
 .../primitives/objective_function/logloss.hpp |  85 +++
 .../objective_function/logloss_dpc.cpp        | 470 +++++++++++++++++
 .../objective_function/test/logloss_dpc.cpp   | 499 ++++++++++++++++++
 .../test/logloss_perf_dpc.cpp                 | 151 ++++++
 .../primitives/reduction/reduction_1d.hpp     |   2 +-
 .../primitives/reduction/reduction_1d_dpc.cpp |   2 +-
 .../reduction/test/reduction_1d_dpc.cpp       |   2 +-
 11 files changed, 1311 insertions(+), 28 deletions(-)
 create mode 100644 cpp/oneapi/dal/backend/primitives/objective_function.hpp
 create mode 100644 cpp/oneapi/dal/backend/primitives/objective_function/BUILD
 create mode 100644 cpp/oneapi/dal/backend/primitives/objective_function/logloss.hpp
 create mode 100644 cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp
 create mode 100644 cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_dpc.cpp
 create mode 100644 cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_perf_dpc.cpp

diff --git a/cpp/oneapi/dal/backend/primitives/BUILD b/cpp/oneapi/dal/backend/primitives/BUILD
index f111d5dbd24..01c53411152 100644
--- a/cpp/oneapi/dal/backend/primitives/BUILD
+++ b/cpp/oneapi/dal/backend/primitives/BUILD
@@ -28,6 +28,7 @@ dal_collect_modules(
         "heap",
         "intersection",
         "lapack",
+        "objective_function",
         "reduction",
         "regression",
         "rng",
@@ -74,6 +75,7 @@ dal_collect_test_suites(
     modules = [
         "blas",
         "lapack",
+        "objective_function",
         "reduction",
         "selection",
         "sort",
diff --git a/cpp/oneapi/dal/backend/primitives/ndarray.hpp b/cpp/oneapi/dal/backend/primitives/ndarray.hpp
index 8d7e22423a5..761298f6e3e 100644
--- a/cpp/oneapi/dal/backend/primitives/ndarray.hpp
+++ b/cpp/oneapi/dal/backend/primitives/ndarray.hpp
@@ -165,6 +165,9 @@ class ndarray_base : public base {
     ndshape<axis_count> strides_;
 };
 
+template <typename T, std::int64_t axis_count, ndorder order = ndorder::c>
+class ndarray;
+
 template <typename T, std::int64_t axis_count, ndorder order = ndorder::c>
 class ndview : public ndarray_base<axis_count> {
     static_assert(!std::is_const_v<T>, "T must be non-const");
@@ -309,6 +312,11 @@ class ndview : public ndarray_base {
         return this->t().get_row_slice(from_col, to_col).t();
     }
 
+#ifdef ONEDAL_DATA_PARALLEL
+    ndarray<T, axis_count, order> to_host(sycl::queue& q, const event_vector& deps = {}) const;
+    ndarray<T, axis_count, order> to_device(sycl::queue& q, const event_vector& deps = {}) const;
+#endif
+
 #ifdef ONEDAL_DATA_PARALLEL
     sycl::event prefetch(sycl::queue& queue) const {
         return queue.prefetch(data_, this->get_count());
@@ -487,7 +495,7 @@ inline sycl::event fill(sycl::queue& q,
 }
 #endif
 
-template <typename T, std::int64_t axis_count, ndorder order = ndorder::c>
+template <typename T, std::int64_t axis_count, ndorder order>
 class ndarray : public ndview<T, axis_count, order> {
     template <typename, std::int64_t, ndorder>
     friend class ndarray;
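The two hunks above declare to_host and to_device on ndview, and the hunk below removes their old definitions from the body of ndarray; they are re-defined after the class (further down in this diff), where ndarray is a complete type. The practical effect is that any view can now be copied across the host/device boundary, not only an owning ndarray. A minimal sketch of the resulting round trip, assuming a SYCL queue q and a device-USM build; the shape and slice bounds are illustrative, not from the patch:

    auto host = ndarray<float, 1>::empty({ 10 });   // host-side array
    auto dev = host.to_device(q);                   // host -> device USM copy
    auto back = dev.get_slice(0, 5).to_host(q);     // to_host now works on plain views too

@@ -763,30 +771,6 @@ class ndarray : public ndview<T, axis_count, order> {
     }
 #endif
 
-#ifdef ONEDAL_DATA_PARALLEL
-    ndarray to_host(sycl::queue& q, const event_vector& deps = {}) const {
-        T* host_ptr = dal::detail::host_allocator<T>().allocate(this->get_count());
-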
dal::backend::copy_usm2host(q, host_ptr, this->get_data(), this->get_count(), deps) - .wait_and_throw(); - return wrap(host_ptr, - this->get_shape(), - dal::detail::make_default_delete(dal::detail::default_host_policy{})); - } -#endif - -#ifdef ONEDAL_DATA_PARALLEL - ndarray to_device(sycl::queue& q, const event_vector& deps = {}) const { - ndarray dev = empty(q, this->get_shape(), sycl::usm::alloc::device); - dal::backend::copy_host2usm(q, - dev.get_mutable_data(), - this->get_data(), - this->get_count(), - deps) - .wait_and_throw(); - return dev; - } -#endif - ndarray slice(std::int64_t offset, std::int64_t count, std::int64_t axis = 0) const { ONEDAL_ASSERT(order == ndorder::c, "Only C-order is supported"); ONEDAL_ASSERT(axis == 0, "Non-zero axis is not supported"); @@ -862,6 +846,37 @@ class ndarray : public ndview { shared_t data_; }; +#ifdef ONEDAL_DATA_PARALLEL +template +ndarray ndview::to_host( + sycl::queue& q, + const event_vector& deps) const { + T* host_ptr = dal::detail::host_allocator().allocate(this->get_count()); + dal::backend::copy_usm2host(q, host_ptr, this->get_data(), this->get_count(), deps) + .wait_and_throw(); + return ndarray::wrap( + host_ptr, + this->get_shape(), + dal::detail::make_default_delete(dal::detail::default_host_policy{})); +} +#endif + +#ifdef ONEDAL_DATA_PARALLEL +template +ndarray ndview::to_device( + sycl::queue& q, + const event_vector& deps) const { + auto dev = ndarray::empty(q, this->get_shape(), sycl::usm::alloc::device); + dal::backend::copy_host2usm(q, + dev.get_mutable_data(), + this->get_data(), + this->get_count(), + deps) + .wait_and_throw(); + return dev; +} +#endif + #ifdef ONEDAL_DATA_PARALLEL template +sycl::event compute_probabilities(sycl::queue& q, + const ndview& parameters, + const ndview& data, + ndview& predictions, + const event_vector& deps = {}); + +template +sycl::event compute_logloss(sycl::queue& q, + const ndview& parameters, + const ndview& data, + const ndview& labels, + ndview& out, + Float L1 = Float(0), + Float L2 = Float(0), + const event_vector& deps = {}); + +template +sycl::event compute_logloss(sycl::queue& q, + const ndview& parameters, + const ndview& data, + const ndview& labels, + const ndview& probabilities, + ndview& out, + Float L1 = Float(0), + Float L2 = Float(0), + const event_vector& deps = {}); + +template +sycl::event compute_logloss_with_der(sycl::queue& q, + const ndview& parameters, + const ndview& data, + const ndview& labels, + const ndview& probabilities, + ndview& out, + ndview& out_derivative, + Float L1 = Float(0), + Float L2 = Float(0), + const event_vector& deps = {}); + +template +sycl::event compute_derivative(sycl::queue& q, + const ndview& parameters, + const ndview& data, + const ndview& labels, + const ndview& probabilities, + ndview& out_derivative, + Float L1 = Float(0), + Float L2 = Float(0), + const event_vector& deps = {}); + +template +sycl::event compute_hessian(sycl::queue& q, + const ndview& parameters, + const ndview& data, + const ndview& labels, + const ndview& probabilities, + ndview& out_hessian, + Float L1 = Float(0), + Float L2 = Float(0), + const event_vector& deps = {}); + +} // namespace oneapi::dal::backend::primitives diff --git a/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp b/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp new file mode 100644 index 00000000000..bda30c55331 --- /dev/null +++ b/cpp/oneapi/dal/backend/primitives/objective_function/logloss_dpc.cpp @@ -0,0 +1,470 @@ 
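Everything this new file adds evaluates one objective. With probabilities p_i = sigma(w_0 + x_i . w) for feature rows x_i and labels y_i in {0, 1}, the kernels below compute

    L(w) = sum_i [ -y_i * log(p_i) - (1 - y_i) * log(1 - p_i) ]
         + L1 * sum_{j >= 1} |w_j| + L2 * sum_{j >= 1} w_j^2,

its gradient dL/dw_j = sum_i (p_i - y_i) * x_ij + 2 * L2 * w_j (with x_i0 = 1 for the intercept), and its Hessian H_jk = sum_i x_ij * x_ik * p_i * (1 - p_i) + 2 * L2 on the diagonal for j > 0. The intercept w_0 is never regularized, which is why the regularization kernels index param_ptr[idx + 1].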
+/*******************************************************************************
+* Copyright 2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/backend/primitives/objective_function/logloss.hpp"
+#include "oneapi/dal/backend/primitives/blas/gemv.hpp"
+
+namespace oneapi::dal::backend::primitives {
+
+template <typename Float>
+sycl::event compute_probabilities(sycl::queue& q,
+                                  const ndview<Float, 1>& parameters,
+                                  const ndview<Float, 2>& data,
+                                  ndview<Float, 1>& probabilities,
+                                  const event_vector& deps) {
+    const std::int64_t n = data.get_dimension(0);
+    const std::int64_t p = data.get_dimension(1);
+    ONEDAL_ASSERT(data.has_data());
+    ONEDAL_ASSERT(parameters.has_data());
+    ONEDAL_ASSERT(probabilities.has_mutable_data());
+    ONEDAL_ASSERT(parameters.get_dimension(0) == p + 1);
+    ONEDAL_ASSERT(probabilities.get_dimension(0) == n);
+
+    auto fill_event = fill(q, probabilities, Float(1), {});
+    using oneapi::dal::backend::operator+;
+
+    auto param_arr = ndarray<Float, 1>::wrap(parameters.get_data(), 1);
+    Float w0 = param_arr.slice(0, 1).to_host(q, deps).at(0); // Poor performance: single-scalar device-to-host copy
+
+    auto event = gemv(q,
+                      data,
+                      parameters.get_slice(1, parameters.get_dimension(0)),
+                      probabilities,
+                      Float(1),
+                      w0,
+                      { fill_event });
+    auto* const prob_ptr = probabilities.get_mutable_data();
+
+    return q.submit([&](sycl::handler& cgh) {
+        cgh.depends_on(event);
+        const auto range = make_range_1d(n);
+        cgh.parallel_for(range, [=](sycl::id<1> idx) {
+            prob_ptr[idx] = 1 / (1 + sycl::exp(-prob_ptr[idx]));
+        });
+    });
+}
+
+template <typename Float>
+sycl::event compute_logloss(sycl::queue& q,
+                            const ndview<Float, 1>& parameters,
+                            const ndview<Float, 2>& data,
+                            const ndview<std::int32_t, 1>& labels,
+                            const ndview<Float, 1>& probabilities,
+                            ndview<Float, 1>& out,
+                            Float L1,
+                            Float L2,
+                            const event_vector& deps) {
+    const std::int64_t n = data.get_dimension(0);
+    const std::int64_t p = data.get_dimension(1);
+    ONEDAL_ASSERT(parameters.get_dimension(0) == p + 1);
+    ONEDAL_ASSERT(labels.get_dimension(0) == n);
+    ONEDAL_ASSERT(probabilities.get_dimension(0) == n);
+    ONEDAL_ASSERT(labels.has_data());
+    ONEDAL_ASSERT(parameters.has_data());
+    ONEDAL_ASSERT(data.has_data());
+    ONEDAL_ASSERT(probabilities.has_data());
+
+    const auto* const labels_ptr = labels.get_data();
+    const auto* const prob_ptr = probabilities.get_data();
+
+    auto* const out_ptr = out.get_mutable_data();
+
+    auto loss_event = q.submit([&](sycl::handler& cgh) {
+        const auto range = make_range_1d(n);
+        using oneapi::dal::backend::operator+;
+        using sycl::reduction;
+
+        cgh.depends_on(deps);
+
+        auto sumReduction = reduction(out_ptr, sycl::plus<>());
+
+        cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) {
+            const Float prob = prob_ptr[idx];
+            const std::int32_t label = labels_ptr[idx];
+            sum += -label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob);
+        });
+    });
+
+    auto [out_reg, out_reg_e] = ndarray<Float, 1>::zeros(q, { 1 }, sycl::usm::alloc::device);
+    auto* const reg_ptr = out_reg.get_mutable_data();
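+    // Elastic-net term: L1 * sum_j |w_j| + L2 * sum_j w_j^2 over the feature
+    // weights w_1..w_p only; the idx + 1 below skips the intercept w_0.
+    const event_vector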
vector_out_reg = { out_reg_e }; + + const auto* const param_ptr = parameters.get_data(); + + if (L1 > 0 || L2 > 0) { + auto reg_event = q.submit([&](sycl::handler& cgh) { + cgh.depends_on(vector_out_reg); + const auto range = make_range_1d(p); + auto sumReduction = sycl::reduction(reg_ptr, sycl::plus<>()); + cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) { + const Float param = param_ptr[idx + 1]; + sum += L1 * sycl::abs(param) + L2 * param * param; + }); + }); + auto final_event = q.submit([&](sycl::handler& cgh) { + cgh.depends_on({ reg_event, loss_event }); + cgh.single_task([=] { + out_ptr[0] += reg_ptr[0]; + }); + }); + return final_event; + } + return loss_event; +} + +template +sycl::event compute_logloss(sycl::queue& q, + const ndview& parameters, + const ndview& data, + const ndview& labels, + ndview& out, + Float L1, + Float L2, + const event_vector& deps) { + const std::int64_t n = data.get_dimension(0); + const std::int64_t p = data.get_dimension(1); + ONEDAL_ASSERT(parameters.get_dimension(0) == p + 1); + ONEDAL_ASSERT(labels.get_dimension(0) == n); + ONEDAL_ASSERT(labels.has_data()); + ONEDAL_ASSERT(parameters.has_data()); + ONEDAL_ASSERT(data.has_data()); + + // out should be filled with zero + + auto probabilities = ndarray::empty(q, { n }, sycl::usm::alloc::device); + auto prediction_event = compute_probabilities(q, parameters, data, probabilities, deps); + + return compute_logloss(q, + parameters, + data, + labels, + probabilities, + out, + L1, + L2, + { prediction_event }); +} + +template +sycl::event compute_logloss_with_der(sycl::queue& q, + const ndview& parameters, + const ndview& data, + const ndview& labels, + const ndview& probabilities, + ndview& out, + ndview& out_derivative, + Float L1, + Float L2, + const event_vector& deps) { + // out, out_derivative should be filled with zeros + + const std::int64_t n = data.get_dimension(0); + const std::int64_t p = data.get_dimension(1); + + ONEDAL_ASSERT(parameters.get_dimension(0) == p + 1); + ONEDAL_ASSERT(labels.get_dimension(0) == n); + ONEDAL_ASSERT(probabilities.get_dimension(0) == n); + ONEDAL_ASSERT(out.get_count() == 1); + ONEDAL_ASSERT(out_derivative.get_count() == p + 1); + + ONEDAL_ASSERT(labels.has_data()); + ONEDAL_ASSERT(parameters.has_data()); + ONEDAL_ASSERT(data.has_data()); + ONEDAL_ASSERT(probabilities.has_data()); + ONEDAL_ASSERT(out.has_mutable_data()); + ONEDAL_ASSERT(out_derivative.has_mutable_data()); + + // d loss_i / d pred_i + auto derivative_object = ndarray::empty(q, { n }, sycl::usm::alloc::device); + + auto* const der_obj_ptr = derivative_object.get_mutable_data(); + const auto* const proba_ptr = probabilities.get_data(); + const auto* const labels_ptr = labels.get_data(); + const auto* const param_ptr = parameters.get_data(); + auto* const out_ptr = out.get_mutable_data(); + auto* const out_derivative_ptr = out_derivative.get_mutable_data(); + + auto loss_event = q.submit([&](sycl::handler& cgh) { + using oneapi::dal::backend::operator+; + using sycl::reduction; + + cgh.depends_on(deps); + auto sumReductionLogLoss = reduction(out_ptr, sycl::plus<>()); + auto sumReductionDerivativeW0 = reduction(out_derivative_ptr, sycl::plus<>()); + const auto wg_size = propose_wg_size(q); + const auto range = make_multiple_nd_range_1d(n, wg_size); + + cgh.parallel_for( + range, + sumReductionLogLoss, + sumReductionDerivativeW0, + [=](sycl::nd_item<1> id, auto& sum_logloss, auto& sum_Dw0) { + auto idx = id.get_group_linear_id() * wg_size + id.get_local_linear_id(); + if (idx >= 
std::size_t(n))
+                    return;
+                const Float prob = proba_ptr[idx];
+                const Float label = labels_ptr[idx];
+                sum_logloss += -label * sycl::log(prob) - (1 - label) * sycl::log(1 - prob);
+                der_obj_ptr[idx] = prob - label;
+                sum_Dw0 += der_obj_ptr[idx];
+            });
+    });
+
+    auto out_der_suffix = out_derivative.get_slice(1, p + 1);
+
+    auto der_event = gemv(q, data.t(), derivative_object, out_der_suffix, { loss_event });
+    if (L1 == 0 && L2 == 0) {
+        return der_event;
+    }
+    auto [reg_val, reg_val_e] = ndarray<Float, 1>::zeros(q, { 1 }, sycl::usm::alloc::device);
+
+    const event_vector reg_deps = { reg_val_e, der_event };
+    auto* const reg_ptr = reg_val.get_mutable_data();
+
+    auto reg_event = q.submit([&](sycl::handler& cgh) {
+        cgh.depends_on(reg_deps);
+        const auto range = make_range_1d(p);
+        auto sumReduction = sycl::reduction(reg_ptr, sycl::plus<>());
+        cgh.parallel_for(range, sumReduction, [=](sycl::id<1> idx, auto& sum) {
+            const Float param = param_ptr[idx + 1];
+            sum += L1 * sycl::abs(param) + L2 * param * param;
+            out_derivative_ptr[idx + 1] += L2 * 2 * param;
+        });
+    });
+
+    auto final_event = q.submit([&](sycl::handler& cgh) {
+        cgh.depends_on({ reg_event, loss_event });
+        cgh.single_task([=] {
+            out_ptr[0] += reg_ptr[0];
+        });
+    });
+
+    return final_event;
+}
+
+template <typename Float>
+sycl::event compute_derivative(sycl::queue& q,
+                               const ndview<Float, 1>& parameters,
+                               const ndview<Float, 2>& data,
+                               const ndview<std::int32_t, 1>& labels,
+                               const ndview<Float, 1>& probabilities,
+                               ndview<Float, 1>& out_derivative,
+                               Float L1,
+                               Float L2,
+                               const event_vector& deps) {
+    // out_derivative should be filled with zeros
+
+    const std::int64_t n = data.get_dimension(0);
+    const std::int64_t p = data.get_dimension(1);
+
+    ONEDAL_ASSERT(parameters.get_dimension(0) == p + 1);
+    ONEDAL_ASSERT(labels.get_dimension(0) == n);
+    ONEDAL_ASSERT(probabilities.get_dimension(0) == n);
+    ONEDAL_ASSERT(out_derivative.get_count() == p + 1);
+
+    ONEDAL_ASSERT(labels.has_data());
+    ONEDAL_ASSERT(parameters.has_data());
+    ONEDAL_ASSERT(data.has_data());
+    ONEDAL_ASSERT(probabilities.has_data());
+    ONEDAL_ASSERT(out_derivative.has_mutable_data());
+
+    // d loss_i / d pred_i
+    auto derivative_object = ndarray<Float, 1>::empty(q, { n }, sycl::usm::alloc::device);
+
+    auto* const der_obj_ptr = derivative_object.get_mutable_data();
+    const auto* const proba_ptr = probabilities.get_data();
+    const auto* const labels_ptr = labels.get_data();
+    const auto* const param_ptr = parameters.get_data();
+    auto* const out_derivative_ptr = out_derivative.get_mutable_data();
+
+    auto loss_event = q.submit([&](sycl::handler& cgh) {
+        using sycl::reduction;
+
+        cgh.depends_on(deps);
+        auto sumReductionDerivativeW0 = reduction(out_derivative_ptr, sycl::plus<>());
+        const auto wg_size = propose_wg_size(q);
+        const auto range = make_multiple_nd_range_1d(n, wg_size);
+
+        cgh.parallel_for(range, sumReductionDerivativeW0, [=](sycl::nd_item<1> id, auto& sum_Dw0) {
+            auto idx = id.get_group_linear_id() * wg_size + id.get_local_linear_id();
+            if (idx >= std::size_t(n))
+                return;
+            const Float prob = proba_ptr[idx];
+            const Float label = labels_ptr[idx];
+            der_obj_ptr[idx] = prob - label;
+            sum_Dw0 += der_obj_ptr[idx];
+        });
+    });
+
+    auto out_der_suffix = out_derivative.get_slice(1, p + 1);
+
+    auto der_event = gemv(q, data.t(), derivative_object, out_der_suffix, { loss_event });
+
+    if (L1 == 0 && L2 == 0) {
+        return der_event;
+    }
+
+    auto reg_event = q.submit([&](sycl::handler& cgh) {
+        using oneapi::dal::backend::operator+;
+        cgh.depends_on({ der_event });
+        const auto range = make_range_1d(p);
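+        // Only the smooth L2 term contributes to the gradient here:
+        // d(L2 * w_j^2)/dw_j = 2 * L2 * w_j. The non-smooth L1 term has no
+        // unique derivative and is skipped, matching compute_logloss_with_der.
+        cgh.parallel_for(range,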
[=](sycl::id<1> idx) { + const Float param = param_ptr[idx + 1]; + out_derivative_ptr[idx + 1] += L2 * 2 * param; + }); + }); + + return reg_event; +} + +template +sycl::event compute_hessian(sycl::queue& q, + const ndview& parameters, + const ndview& data, + const ndview& labels, + const ndview& probabilities, + ndview& out_hessian, + Float L1, + Float L2, + const event_vector& deps) { + const int64_t n = data.get_dimension(0); + const int64_t p = data.get_dimension(1); + + ONEDAL_ASSERT(parameters.get_dimension(0) == p + 1); + ONEDAL_ASSERT(labels.get_dimension(0) == n); + ONEDAL_ASSERT(probabilities.get_dimension(0) == n); + ONEDAL_ASSERT(out_hessian.get_dimension(0) == (p + 1)); + ONEDAL_ASSERT(out_hessian.get_dimension(1) == (p + 1)); + + ONEDAL_ASSERT(labels.has_data()); + ONEDAL_ASSERT(parameters.has_data()); + ONEDAL_ASSERT(data.has_data()); + ONEDAL_ASSERT(probabilities.has_data()); + ONEDAL_ASSERT(out_hessian.has_mutable_data()); + + const auto* const data_ptr = data.get_data(); + auto* const hes_ptr = out_hessian.get_mutable_data(); + const auto* const proba_ptr = probabilities.get_data(); + + const auto max_wg = device_max_wg_size(q); + const auto wg = std::min(p + 1, max_wg); + const auto inp_str = data.get_leading_stride(); + const auto out_str = out_hessian.get_leading_stride(); + + constexpr std::int64_t block_size = 32; + const std::int64_t num_blocks = (n + block_size - 1) / block_size; + const auto range = make_multiple_nd_range_3d({ num_blocks, p + 1, wg }, { 1, 1, wg }); + + auto hes_event = q.submit([&](sycl::handler& cgh) { + cgh.depends_on(deps); + cgh.parallel_for(range, [=](sycl::nd_item<3> item) { + const std::int64_t obj_ind = item.get_global_id(0); + const auto j = item.get_global_id(1); + const auto param_ind_2 = item.get_global_id(2); + Float val = 0; + for (auto k = param_ind_2; k < j + 1; k += wg) { + val = 0; + const std::int64_t last_ind = std::min((obj_ind + 1) * block_size, n); + for (auto i = obj_ind * block_size; i < last_ind; ++i) { + const Float x1 = j > 0 ? data_ptr[i * inp_str + (j - 1)] : 1; + const Float x2 = k > 0 ? 
data_ptr[i * inp_str + (k - 1)] : 1; + const Float prob = proba_ptr[i] * (1 - proba_ptr[i]); + val += x1 * x2 * prob; + } + Float& out = hes_ptr[j * out_str + k]; + sycl::atomic_ref(out) + .fetch_add(val); + } + }); + }); + + auto make_symmetric = q.submit([&](sycl::handler& cgh) { + cgh.depends_on({ hes_event }); + const auto range = make_range_2d(p + 1, p + 1); + cgh.parallel_for(range, [=](sycl::id<2> idx) { + auto j = idx[0]; + auto k = idx[1]; + if (j > k) { + hes_ptr[k * out_str + j] = hes_ptr[j * out_str + k]; + } + else if (j == k && j > 0) { + hes_ptr[j * out_str + j] += 2 * L2; + } + }); + }); + + return make_symmetric; +} + +#define INSTANTIATE(F) \ + template sycl::event compute_probabilities(sycl::queue&, \ + const ndview&, \ + const ndview&, \ + ndview&, \ + const event_vector&); \ + template sycl::event compute_logloss(sycl::queue&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + ndview&, \ + F, \ + F, \ + const event_vector&); \ + template sycl::event compute_logloss(sycl::queue&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + ndview&, \ + F, \ + F, \ + const event_vector&); \ + template sycl::event compute_logloss_with_der(sycl::queue&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + ndview&, \ + ndview&, \ + F, \ + F, \ + const event_vector&); \ + template sycl::event compute_derivative(sycl::queue&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + ndview&, \ + F, \ + F, \ + const event_vector&); \ + template sycl::event compute_hessian(sycl::queue&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + const ndview&, \ + ndview&, \ + F, \ + F, \ + const event_vector&); + +INSTANTIATE(float); +INSTANTIATE(double); + +} // namespace oneapi::dal::backend::primitives diff --git a/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_dpc.cpp b/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_dpc.cpp new file mode 100644 index 00000000000..12991d907b4 --- /dev/null +++ b/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_dpc.cpp @@ -0,0 +1,499 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include "oneapi/dal/backend/primitives/objective_function/logloss.hpp" +#include "oneapi/dal/test/engine/common.hpp" +#include "oneapi/dal/test/engine/fixtures.hpp" +#include "oneapi/dal/backend/primitives/debug.hpp" +#include "oneapi/dal/table/row_accessor.hpp" + +namespace oneapi::dal::backend::primitives::test { + +namespace te = dal::test::engine; + +template +struct order_tag { + static constexpr ndorder value = order; +}; + +using c_order = order_tag; +using f_order = order_tag; + +template +class logloss_test : public te::float_algo_fixture { +public: + using float_t = Param; + + void check_val(const float_t real, + const float_t expected, + const float_t rtol, + const float_t atol) { + REQUIRE(abs(real - expected) < atol); + REQUIRE(abs(real - expected) / std::max(std::abs(expected), (float_t)1.0) < rtol); + } + + void generate_input(std::int64_t n = -1, std::int64_t p = -1) { + if (n == -1 || p == -1) { + this->n_ = GENERATE(7, 827, 13, 216); + this->p_ = GENERATE(4, 17, 41, 256); + } + else { + this->n_ = n; + this->p_ = p; + } + + const auto dataframe = + GENERATE_DATAFRAME(te::dataframe_builder{ n_, p_ }.fill_uniform(-0.5, 0.5)); + const auto parameters = + GENERATE_DATAFRAME(te::dataframe_builder{ 1, p_ + 1 }.fill_uniform(-1, 1)); + this->data_ = dataframe.get_table(this->get_homogen_table_id()); + this->params_ = parameters.get_table(this->get_homogen_table_id()); + this->labels_ = + ndarray::empty(this->get_queue(), { n_ }, sycl::usm::alloc::host); + + std::srand(2007 + n_); + auto* const ptr_lab = this->labels_.get_mutable_data(); + for (std::int64_t i = 0; i < n_; ++i) { + ptr_lab[i] = std::rand() % 2; + } + } + + void run_test(const float_t L1 = 0, const float_t L2 = 0) { + auto data_array = row_accessor{ this->data_ }.pull(this->get_queue()); + auto data_host = ndarray::wrap(data_array.get_data(), { n_, p_ }); + + auto param_array = row_accessor{ this->params_ }.pull(this->get_queue()); + auto params_host = ndarray::wrap(param_array.get_data(), { p_ + 1 }); + test_input(data_host, params_host, this->labels_, L1, L2); + + SUCCEED(); + } + + void test_gold_input() { + constexpr std::int64_t n = 5; + constexpr std::int64_t p = 3; + constexpr float_t data[n * p] = { 0.83731708, -0.70899924, -1.23362082, 0.23468538, + -0.10549413, 1.12902673, -0.61035703, -1.55617932, + 0.60419908, 0.30589827, 0.63919892, -0.23380754, + 2.38196927, 1.64158111, 0.13677077 }; + constexpr std::int32_t labels[n] = { 0, 1, 1, 0, 1 }; + constexpr float_t L1 = 0; + constexpr float_t L2 = 3.123; + constexpr float_t cur_param[p + 1] = { -0.2, 0.1, -1, 0.4 }; + + auto data_host = ndarray::wrap(data, { n, p }); + auto labels_host = ndarray::wrap(labels, n); + auto params_host = ndarray::wrap(cur_param, p + 1); + + test_input(data_host, params_host, labels_host, L1, L2); + + SUCCEED(); + } + + void test_input(const ndarray& data_host, + const ndarray& params_host, + const ndarray& labels_host, + const float_t L1, + const float_t L2) { + constexpr float_t rtol = sizeof(float_t) > 4 ? 1e-6 : 1e-4; + constexpr float_t atol = sizeof(float_t) > 4 ? 1e-6 : 1; + constexpr float_t atol2 = sizeof(float_t) > 4 ? 
1e-6 : 1e-4; + const std::int64_t n = data_host.get_dimension(0); + const std::int64_t p = data_host.get_dimension(1); + + auto data_gpu = data_host.to_device(this->get_queue()); + auto labels_gpu = labels_host.to_device(this->get_queue()); + auto params_gpu = params_host.to_device(this->get_queue()); + + auto out_predictions = + ndarray::empty(this->get_queue(), { n }, sycl::usm::alloc::device); + + auto p_event = + compute_probabilities(this->get_queue(), params_gpu, data_gpu, out_predictions, {}); + p_event.wait_and_throw(); + + auto predictions_host = out_predictions.to_host(this->get_queue(), {}); + + const float_t logloss = test_predictions_and_logloss(data_host, + params_host, + labels_host, + predictions_host, + L1, + L2, + rtol, + atol); + + auto [out_logloss, out_e] = + ndarray::zeros(this->get_queue(), { 1 }, sycl::usm::alloc::device); + sycl::event logloss_event = compute_logloss(this->get_queue(), + params_gpu, + data_gpu, + labels_gpu, + out_logloss, + L1, + L2, + { out_e }); + logloss_event.wait_and_throw(); + const float_t val_logloss1 = out_logloss.to_host(this->get_queue(), {}).at(0); + check_val(val_logloss1, logloss, rtol, atol); + auto fill_event = fill(this->get_queue(), out_logloss, float_t(0), {}); + auto [out_derivative, out_der_e] = + ndarray::zeros(this->get_queue(), { p + 1 }, sycl::usm::alloc::device); + auto logloss_event_der = compute_logloss_with_der(this->get_queue(), + params_gpu, + data_gpu, + labels_gpu, + out_predictions, + out_logloss, + out_derivative, + L1, + L2, + { fill_event, out_der_e }); + logloss_event_der.wait_and_throw(); + auto out_derivative_host = out_derivative.to_host(this->get_queue()); + const float_t val_logloss2 = out_logloss.to_host(this->get_queue(), {}).at(0); + check_val(val_logloss2, logloss, rtol, atol); + auto [out_derivative2, out_der_e2] = + ndarray::zeros(this->get_queue(), { p + 1 }, sycl::usm::alloc::device); + auto der_event = compute_derivative(this->get_queue(), + params_gpu, + data_gpu, + labels_gpu, + out_predictions, + out_derivative2, + L1, + L2, + { out_der_e2 }); + der_event.wait_and_throw(); + auto out_derivative_host2 = out_derivative2.to_host(this->get_queue()); + for (auto i = 0; i <= p; ++i) { + REQUIRE(abs(out_derivative_host.at(i) - out_derivative_host2.at(i)) < atol); + } + auto [out_hessian, out_hess_e] = ndarray::zeros(this->get_queue(), + { p + 1, p + 1 }, + sycl::usm::alloc::device); + auto hess_event = compute_hessian(this->get_queue(), + params_gpu, + data_gpu, + labels_gpu, + out_predictions, + out_hessian, + L1, + L2, + { out_hess_e }); + + auto hessian_host = out_hessian.to_host(this->get_queue(), { hess_event }); + test_formula_derivative(data_host, + predictions_host, + params_host, + labels_host, + out_derivative_host, + L1, + L2, + rtol, + atol2); + test_formula_hessian(data_host, predictions_host, hessian_host, L2, rtol, atol2); + test_derivative_and_hessian(data_gpu, + labels_gpu, + out_derivative_host, + hessian_host, + params_host, + L1, + L2, + rtol, + atol); + } + + float_t test_predictions_and_logloss(const ndarray& data_host, + const ndarray& params_host, + const ndarray& labels_host, + const ndarray& probabilities, + const float_t L1, + const float_t L2, + const float_t rtol = 1e-3, + const float_t atol = 1e-3) { + const std::int64_t n = data_host.get_dimension(0); + const std::int64_t p = data_host.get_dimension(1); + + float_t logloss = 0; + for (std::int64_t i = 0; i < n; ++i) { + float_t pred = 0; + for (std::int64_t j = 0; j < p; ++j) { + pred += params_host.at(j + 1) * 
data_host.at(i, j); + } + pred += params_host.at(0); + float_t prob = 1 / (1 + std::exp(-pred)); + logloss -= + labels_host.at(i) * std::log(prob) + (1 - labels_host.at(i)) * std::log(1 - prob); + float_t out_val = probabilities.at(i); + REQUIRE(abs(out_val - prob) < atol); + } + for (std::int64_t i = 1; i < p + 1; ++i) { + logloss += L1 * abs(params_host.at(i)); + logloss += L2 * params_host.at(i) * params_host.at(i); + } + return logloss; + } + + double naive_logloss(const ndarray& data_host, + const ndarray& params_host, + const ndarray& labels_host, + const float_t L1, + const float_t L2) { + const std::int64_t n = data_host.get_dimension(0); + const std::int64_t p = data_host.get_dimension(1); + + double logloss = 0; + for (std::int64_t i = 0; i < n; ++i) { + double pred = 0; + for (std::int64_t j = 0; j < p; ++j) { + pred += (double)params_host.at(j + 1) * (double)data_host.at(i, j); + } + pred += (double)params_host.at(0); + logloss += std::log(1 + std::exp(-(2 * labels_host.at(i) - 1) * pred)); + } + for (std::int64_t i = 1; i < p + 1; ++i) { + logloss += L1 * abs(params_host.at(i)); + logloss += L2 * params_host.at(i) * params_host.at(i); + } + return logloss; + } + + void naive_derivative(const ndarray& data, + const ndarray& probabilities, + const ndarray& params, + const ndarray& labels, + ndarray& out_der, + float_t L1, + float_t L2) { + const std::int64_t n = data.get_dimension(0); + const std::int64_t p = data.get_dimension(1); + for (std::int64_t j = 0; j <= p; ++j) { + double val = 0; + for (std::int64_t i = 0; i < n; ++i) { + double x1 = j > 0 ? data.at(i, j - 1) : 1; + double prob = probabilities.at(i); + val += (prob - labels.at(i)) * x1; + } + val += j > 0 ? L2 * 2 * params.at(j) : 0; + out_der.at(j) = val; + } + } + + void naive_hessian(const ndarray& data_host, + const ndarray& probabilities_host, + ndarray& out_hessian, + float_t L2) { + const std::int64_t n = data_host.get_dimension(0); + const std::int64_t p = data_host.get_dimension(1); + for (std::int64_t j = 0; j <= p; ++j) { + for (std::int64_t k = 0; k <= p; ++k) { + double val = 0; + for (std::int64_t i = 0; i < n; ++i) { + double x1 = j > 0 ? data_host.at(i, j - 1) : 1; + double x2 = k > 0 ? 
data_host.at(i, k - 1) : 1; + double prob = probabilities_host.at(i); + val += x1 * x2 * (1 - prob) * prob; + } + out_hessian.at(j, k) = val; + } + if (j > 0) { + out_hessian.at(j, j) += 2 * L2; + } + } + } + + void test_formula_derivative(const ndarray& data, + const ndarray& probabilities, + const ndarray& params, + const ndarray& labels, + const ndarray& derivative, + const float_t L1, + const float_t L2, + const float_t rtol = 1e-3, + const float_t atol = 1e-3) { + const std::int64_t p = data.get_dimension(1); + auto out_derivative = + ndarray::empty(this->get_queue(), { p + 1 }, sycl::usm::alloc::host); + + naive_derivative(data, probabilities, params, labels, out_derivative, L1, L2); + + for (std::int64_t i = 0; i < p + 1; ++i) { + check_val(out_derivative.at(i), derivative.at(i), rtol, atol); + } + } + + void test_formula_hessian(const ndarray& data, + const ndarray& probabilities, + const ndarray& hessian, + const float_t L2, + const float_t rtol = 1e-3, + const float_t atol = 1e-3) { + const std::int64_t p = data.get_dimension(1); + auto out_hessian = + ndarray::empty(this->get_queue(), { p + 1, p + 1 }, sycl::usm::alloc::host); + + naive_hessian(data, probabilities, out_hessian, L2); + + for (std::int64_t i = 0; i <= p; ++i) { + for (std::int64_t j = 0; j <= p; ++j) { + check_val(out_hessian.at(i, j), hessian.at(i, j), rtol, atol); + } + } + } + + void test_derivative_and_hessian(const ndarray& data, + const ndarray& labels, + const ndarray& derivative, + const ndarray& hessian, + const ndarray& params_host, + const float_t L1, + const float_t L2, + const float_t rtol = 1e-3, + const float_t atol = 1e-3) { + const std::int64_t n = data.get_dimension(0); + const std::int64_t p = data.get_dimension(1); + constexpr std::int64_t max_n = 2000; + constexpr float_t step = sizeof(float_t) > 4 ? 1e-4 : 1e-3; + + const auto data_host = data.to_host(this->get_queue()); + const auto labels_host = labels.to_host(this->get_queue()); + + std::array cur_param; + for (std::int64_t i = 0; i < p + 1; ++i) { + cur_param[i] = params_host.at(i); + } + + auto out_logloss = + ndarray::empty(this->get_queue(), { 1 }, sycl::usm::alloc::device); + auto out_predictions = + ndarray::empty(this->get_queue(), { n }, sycl::usm::alloc::device); + auto out_derivative_up = + ndarray::empty(this->get_queue(), { p + 1 }, sycl::usm::alloc::device); + auto out_derivative_down = + ndarray::empty(this->get_queue(), { p + 1 }, sycl::usm::alloc::device); + + for (std::int64_t i = 0; i < p + 1; ++i) { + auto fill_event_1 = fill(this->get_queue(), out_logloss, float_t(0), {}); + auto fill_event_2 = fill(this->get_queue(), out_derivative_up, float_t(0), {}); + auto fill_event_3 = + fill(this->get_queue(), out_derivative_down, float_t(0), {}); + + cur_param[i] = params_host.at(i) + step; + auto params_host_up = ndarray::wrap(cur_param.begin(), p + 1); + auto params_gpu_up = params_host_up.to_device(this->get_queue()); + + // Compute logloss and derivative with params [w0, w1, ... 
w_i + eps, ...., w_p] + + sycl::event pred_up_event = + compute_probabilities(this->get_queue(), params_gpu_up, data, out_predictions, {}); + sycl::event der_event_up = + compute_logloss_with_der(this->get_queue(), + params_gpu_up, + data, + labels, + out_predictions, + out_logloss, + out_derivative_up, + L1, + L2, + { fill_event_1, fill_event_2, pred_up_event }); + der_event_up.wait_and_throw(); + double logloss_up = naive_logloss(data_host, params_host_up, labels_host, L1, L2); + auto der_up_host = out_derivative_up.to_host(this->get_queue(), {}); + + cur_param[i] = params_host.at(i) - step; + + auto params_host_down = ndarray::wrap(cur_param.begin(), p + 1); + auto params_gpu_down = params_host_down.to_device(this->get_queue()); + auto fill_event_4 = fill(this->get_queue(), out_logloss, float_t(0), {}); + + // Compute logloss and derivative with params [w0, w1, ... w_i - eps, ...., w_p] + + sycl::event pred_down_event = compute_probabilities(this->get_queue(), + params_gpu_down, + data, + out_predictions, + {}); + sycl::event der_event_down = + compute_logloss_with_der(this->get_queue(), + params_gpu_down, + data, + labels, + out_predictions, + out_logloss, + out_derivative_down, + L1, + L2, + { fill_event_3, fill_event_4, pred_down_event }); + der_event_down.wait_and_throw(); + + double logloss_down = naive_logloss(data_host, params_host_down, labels_host, L1, L2); + auto der_down_host = out_derivative_down.to_host(this->get_queue(), {}); + // Check condition: (logloss(w_i + eps) - logloss(w_i - eps)) / 2eps ~ d logloss / dw_i + if (L1 == 0) { + check_val(derivative.at(i), (logloss_up - logloss_down) / (2 * step), rtol, atol); + } + if (sizeof(float_t) > 4) { + for (std::int64_t j = 0; j < p + 1; ++j) { + // Check condition (d logloss(w_i + eps) / d w_j - d logloss(w_i - eps) / d w_j) / 2eps ~ h_i,j + // due to lack of precision this condition is not checked for 32-bit floating point numbers + check_val(hessian.at(i, j), + (der_up_host.at(j) - der_down_host.at(j)) / (2 * step), + rtol, + atol); + } + } + cur_param[i] += step; + } + } + +private: + std::int64_t n_; + std::int64_t p_; + table data_; + table params_; + ndarray labels_; +}; + +TEMPLATE_TEST_M(logloss_test, "gold input test - double", "[logloss]", double) { + SKIP_IF(this->not_float64_friendly()); + this->test_gold_input(); +} +TEMPLATE_TEST_M(logloss_test, "gold input test - float", "[logloss]", float) { + this->test_gold_input(); +} + +TEMPLATE_TEST_M(logloss_test, "test random input - double without L1", "[logloss]", double) { + SKIP_IF(this->not_float64_friendly()); + this->generate_input(); + this->run_test(0.0, 1.3); +} + +TEMPLATE_TEST_M(logloss_test, "test random input - double with L1", "[logloss]", double) { + SKIP_IF(this->not_float64_friendly()); + this->generate_input(); + this->run_test(0.4, 1.3); +} + +TEMPLATE_TEST_M(logloss_test, "test random input - float", "[logloss]", float) { + this->generate_input(); + this->run_test(0.4, 1.3); +} + +} // namespace oneapi::dal::backend::primitives::test diff --git a/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_perf_dpc.cpp b/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_perf_dpc.cpp new file mode 100644 index 00000000000..120b35f42f2 --- /dev/null +++ b/cpp/oneapi/dal/backend/primitives/objective_function/test/logloss_perf_dpc.cpp @@ -0,0 +1,151 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the 
"License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "oneapi/dal/backend/primitives/objective_function/logloss.hpp" +#include "oneapi/dal/test/engine/common.hpp" +#include "oneapi/dal/test/engine/fixtures.hpp" +#include "oneapi/dal/backend/primitives/debug.hpp" +#include "oneapi/dal/table/row_accessor.hpp" + +namespace oneapi::dal::backend::primitives::test { + +namespace te = dal::test::engine; + +template +struct order_tag { + static constexpr ndorder value = order; +}; + +using c_order = order_tag; +using f_order = order_tag; + +template +class logloss_perf_test : public te::float_algo_fixture { +public: + using float_t = Param; + + void generate_input(std::int64_t n, std::int64_t p) { + this->n_ = n; + this->p_ = p; + + const auto dataframe = + GENERATE_DATAFRAME(te::dataframe_builder{ n_, p_ }.fill_uniform(-0.5, 0.5)); + const auto parameters = + GENERATE_DATAFRAME(te::dataframe_builder{ 1, p_ + 1 }.fill_uniform(-1, 1)); + this->data_ = dataframe.get_table(this->get_homogen_table_id()); + this->params_ = parameters.get_table(this->get_homogen_table_id()); + this->labels_ = + ndarray::empty(this->get_queue(), { n_ }, sycl::usm::alloc::host); + + std::srand(2007 + n_); + auto* const ptr_lab = this->labels_.get_mutable_data(); + for (std::int64_t i = 0; i < n_; ++i) { + ptr_lab[i] = std::rand() % 2; + } + } + + void measure_time() { + constexpr float_t L1 = 1.2; + constexpr float_t L2 = 0.7; + + auto data_array = row_accessor{ this->data_ }.pull(this->get_queue()); + auto data_host = ndarray::wrap(data_array.get_data(), { n_, p_ }); + + auto param_array = row_accessor{ this->params_ }.pull(this->get_queue()); + auto params_host = ndarray::wrap(param_array.get_data(), { p_ + 1 }); + + auto data_gpu = data_host.to_device(this->get_queue()); + auto labels_gpu = this->labels_.to_device(this->get_queue()); + auto params_gpu = params_host.to_device(this->get_queue()); + + auto out_predictions = + ndarray::empty(this->get_queue(), { n_ }, sycl::usm::alloc::device); + + auto p_event = + compute_probabilities(this->get_queue(), params_gpu, data_gpu, out_predictions, {}); + p_event.wait_and_throw(); + + auto out_logloss = + ndarray::empty(this->get_queue(), { 1 }, sycl::usm::alloc::device); + + auto out_derivative = + ndarray::empty(this->get_queue(), { p_ + 1 }, sycl::usm::alloc::device); + + BENCHMARK("Derivative computation") { + auto fill_event1 = fill(this->get_queue(), out_logloss, float_t(0), {}); + auto fill_event2 = fill(this->get_queue(), out_derivative, float_t(0), {}); + + auto logloss_event_der = compute_logloss_with_der(this->get_queue(), + params_gpu, + data_gpu, + labels_gpu, + out_predictions, + out_logloss, + out_derivative, + L1, + L2, + { fill_event1, fill_event2 }); + logloss_event_der.wait_and_throw(); + }; + + auto out_hessian = ndarray::empty(this->get_queue(), + { p_ + 1, p_ + 1 }, + sycl::usm::alloc::device); + + BENCHMARK("Hessian computation") { + auto fill_event = fill(this->get_queue(), out_hessian, float_t(0), {}); + auto 
hess_event = compute_hessian(this->get_queue(),
+                                         params_gpu,
+                                         data_gpu,
+                                         labels_gpu,
+                                         out_predictions,
+                                         out_hessian,
+                                         L1,
+                                         L2,
+                                         { fill_event });
+            hess_event.wait_and_throw();
+        };
+    }
+
+private:
+    std::int64_t n_;
+    std::int64_t p_;
+    table data_;
+    table params_;
+    ndarray<std::int32_t, 1> labels_;
+};
+
+TEMPLATE_TEST_M(logloss_perf_test, "performance test square", "[logloss][5000*5000]", double) {
+    SKIP_IF(this->not_float64_friendly());
+    this->generate_input(5000, 5000);
+    this->measure_time();
+}
+
+TEMPLATE_TEST_M(logloss_perf_test, "performance test small p", "[logloss][100000*100]", double) {
+    SKIP_IF(this->not_float64_friendly());
+    this->generate_input(100000, 100);
+    this->measure_time();
+}
+
+TEMPLATE_TEST_M(logloss_perf_test, "performance test small n", "[logloss][100*7000]", double) {
+    SKIP_IF(this->not_float64_friendly());
+    this->generate_input(100, 7000);
+    this->measure_time();
+}
+
+} // namespace oneapi::dal::backend::primitives::test
diff --git a/cpp/oneapi/dal/backend/primitives/reduction/reduction_1d.hpp b/cpp/oneapi/dal/backend/primitives/reduction/reduction_1d.hpp
index adc34faca5c..3cb98993bac 100644
--- a/cpp/oneapi/dal/backend/primitives/reduction/reduction_1d.hpp
+++ b/cpp/oneapi/dal/backend/primitives/reduction/reduction_1d.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2022 Intel Corporation
+* Copyright 2023 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
diff --git a/cpp/oneapi/dal/backend/primitives/reduction/reduction_1d_dpc.cpp b/cpp/oneapi/dal/backend/primitives/reduction/reduction_1d_dpc.cpp
index 9ffacbc53b7..152121afdcb 100644
--- a/cpp/oneapi/dal/backend/primitives/reduction/reduction_1d_dpc.cpp
+++ b/cpp/oneapi/dal/backend/primitives/reduction/reduction_1d_dpc.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2022 Intel Corporation
+* Copyright 2023 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
diff --git a/cpp/oneapi/dal/backend/primitives/reduction/test/reduction_1d_dpc.cpp b/cpp/oneapi/dal/backend/primitives/reduction/test/reduction_1d_dpc.cpp
index 5efdb5e0ae0..f2d645ce742 100644
--- a/cpp/oneapi/dal/backend/primitives/reduction/test/reduction_1d_dpc.cpp
+++ b/cpp/oneapi/dal/backend/primitives/reduction/test/reduction_1d_dpc.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2022 Intel Corporation
+* Copyright 2023 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
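Taken together, the new primitives compose as follows. This is a minimal sketch rather than code from the patch: the helper name eval_logloss is invented for illustration, and it assumes a DPC++ build with device USM and already-populated device arrays.

    #include "oneapi/dal/backend/primitives/objective_function/logloss.hpp"

    namespace primitives = oneapi::dal::backend::primitives;

    template <typename Float>
    Float eval_logloss(sycl::queue& q,
                       const primitives::ndview<Float, 1>& params,        // size p + 1
                       const primitives::ndview<Float, 2>& data,          // n x p
                       const primitives::ndview<std::int32_t, 1>& labels, // size n
                       Float l1,
                       Float l2) {
        using primitives::ndarray;
        const std::int64_t n = data.get_dimension(0);

        // Probabilities are computed once and shared by the loss/derivative kernels.
        auto proba = ndarray<Float, 1>::empty(q, { n }, sycl::usm::alloc::device);
        auto proba_event = primitives::compute_probabilities(q, params, data, proba, {});

        // The loss kernel accumulates via a sycl::reduction, so `out` must start at zero.
        auto [out, zero_event] = ndarray<Float, 1>::zeros(q, { 1 }, sycl::usm::alloc::device);
        primitives::compute_logloss(q, params, data, labels, proba, out, l1, l2,
                                    { proba_event, zero_event })
            .wait_and_throw();
        return out.to_host(q).at(0);
    }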