Skip to content

Commit

Permalink
Dev/logloss (#2284)
Browse files Browse the repository at this point in the history
* Initial commit

* Fix reduction 2d files

* Update tests

* Fix code format

* Add gemv implemetation and tests

* Minor fix

* Add logloss without regularization computation

* Add derivative and naive hessian computation

* Add 4 hessian computation implementations and perfomance tests

* Add tests, header and update BUILD file

* Add skip_if to double tests, fix timeouts

* Update regularization computation
  • Loading branch information
avolkov-intel authored Mar 2, 2023
1 parent 32c31e5 commit 98fc59c
Show file tree
Hide file tree
Showing 11 changed files with 1,311 additions and 28 deletions.
2 changes: 2 additions & 0 deletions cpp/oneapi/dal/backend/primitives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dal_collect_modules(
"heap",
"intersection",
"lapack",
"objective_function",
"reduction",
"regression",
"rng",
Expand Down Expand Up @@ -74,6 +75,7 @@ dal_collect_test_suites(
modules = [
"blas",
"lapack",
"objective_function",
"reduction",
"selection",
"sort",
Expand Down
65 changes: 40 additions & 25 deletions cpp/oneapi/dal/backend/primitives/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ class ndarray_base : public base {
ndshape<axis_count> strides_;
};

template <typename T, std::int64_t axis_count, ndorder order = ndorder::c>
class ndarray;

template <typename T, std::int64_t axis_count, ndorder order = ndorder::c>
class ndview : public ndarray_base<axis_count, order> {
static_assert(!std::is_const_v<T>, "T must be non-const");
Expand Down Expand Up @@ -309,6 +312,11 @@ class ndview : public ndarray_base<axis_count, order> {
return this->t().get_row_slice(from_col, to_col).t();
}

#ifdef ONEDAL_DATA_PARALLEL
ndarray<T, axis_count, order> to_host(sycl::queue& q, const event_vector& deps = {}) const;
ndarray<T, axis_count, order> to_device(sycl::queue& q, const event_vector& deps = {}) const;
#endif

#ifdef ONEDAL_DATA_PARALLEL
sycl::event prefetch(sycl::queue& queue) const {
return queue.prefetch(data_, this->get_count());
Expand Down Expand Up @@ -487,7 +495,7 @@ inline sycl::event fill(sycl::queue& q,
}
#endif

template <typename T, std::int64_t axis_count, ndorder order = ndorder::c>
template <typename T, std::int64_t axis_count, ndorder order>
class ndarray : public ndview<T, axis_count, order> {
template <typename, std::int64_t, ndorder>
friend class ndarray;
Expand Down Expand Up @@ -763,30 +771,6 @@ class ndarray : public ndview<T, axis_count, order> {
}
#endif

#ifdef ONEDAL_DATA_PARALLEL
ndarray to_host(sycl::queue& q, const event_vector& deps = {}) const {
T* host_ptr = dal::detail::host_allocator<T>().allocate(this->get_count());
dal::backend::copy_usm2host(q, host_ptr, this->get_data(), this->get_count(), deps)
.wait_and_throw();
return wrap(host_ptr,
this->get_shape(),
dal::detail::make_default_delete<T>(dal::detail::default_host_policy{}));
}
#endif

#ifdef ONEDAL_DATA_PARALLEL
ndarray to_device(sycl::queue& q, const event_vector& deps = {}) const {
ndarray dev = empty(q, this->get_shape(), sycl::usm::alloc::device);
dal::backend::copy_host2usm(q,
dev.get_mutable_data(),
this->get_data(),
this->get_count(),
deps)
.wait_and_throw();
return dev;
}
#endif

ndarray slice(std::int64_t offset, std::int64_t count, std::int64_t axis = 0) const {
ONEDAL_ASSERT(order == ndorder::c, "Only C-order is supported");
ONEDAL_ASSERT(axis == 0, "Non-zero axis is not supported");
Expand Down Expand Up @@ -862,6 +846,37 @@ class ndarray : public ndview<T, axis_count, order> {
shared_t data_;
};

#ifdef ONEDAL_DATA_PARALLEL
template <typename T, std::int64_t axis_count, ndorder order>
ndarray<T, axis_count, order> ndview<T, axis_count, order>::to_host(
sycl::queue& q,
const event_vector& deps) const {
T* host_ptr = dal::detail::host_allocator<T>().allocate(this->get_count());
dal::backend::copy_usm2host(q, host_ptr, this->get_data(), this->get_count(), deps)
.wait_and_throw();
return ndarray<T, axis_count, order>::wrap(
host_ptr,
this->get_shape(),
dal::detail::make_default_delete<T>(dal::detail::default_host_policy{}));
}
#endif

#ifdef ONEDAL_DATA_PARALLEL
template <typename T, std::int64_t axis_count, ndorder order>
ndarray<T, axis_count, order> ndview<T, axis_count, order>::to_device(
sycl::queue& q,
const event_vector& deps) const {
auto dev = ndarray<T, axis_count, order>::empty(q, this->get_shape(), sycl::usm::alloc::device);
dal::backend::copy_host2usm(q,
dev.get_mutable_data(),
this->get_data(),
this->get_count(),
deps)
.wait_and_throw();
return dev;
}
#endif

#ifdef ONEDAL_DATA_PARALLEL

template <ndorder yorder,
Expand Down
19 changes: 19 additions & 0 deletions cpp/oneapi/dal/backend/primitives/objective_function.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "oneapi/dal/backend/primitives/objective_function/logloss.hpp"
42 changes: 42 additions & 0 deletions cpp/oneapi/dal/backend/primitives/objective_function/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package(default_visibility = ["//visibility:public"])
load("@onedal//dev/bazel:dal.bzl",
"dal_module",
"dal_test_suite",
)

dal_module(
name = "objective_function",
auto = True,
dal_deps = [
"@onedal//cpp/oneapi/dal/backend/primitives:common",
"@onedal//cpp/oneapi/dal/backend/primitives:blas"
],
)

dal_test_suite(
name = "tests",
framework = "catch2",
compile_as = [ "dpc++" ],
private = True,
srcs = glob([
"test/*_dpc.cpp",
], exclude=[
"test/*perf*.cpp",
]),
dal_deps = [
":objective_function",
],
)

dal_test_suite(
name = "perf_tests",
framework = "catch2",
compile_as = [ "dpc++" ],
private = True,
srcs = glob([
"test/*perf_dpc.cpp",
]),
dal_deps = [
":objective_function",
],
)
85 changes: 85 additions & 0 deletions cpp/oneapi/dal/backend/primitives/objective_function/logloss.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*******************************************************************************
* Copyright 2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "oneapi/dal/backend/primitives/ndarray.hpp"

namespace oneapi::dal::backend::primitives {

template <typename Float>
sycl::event compute_probabilities(sycl::queue& q,
const ndview<Float, 1>& parameters,
const ndview<Float, 2>& data,
ndview<Float, 1>& predictions,
const event_vector& deps = {});

template <typename Float>
sycl::event compute_logloss(sycl::queue& q,
const ndview<Float, 1>& parameters,
const ndview<Float, 2>& data,
const ndview<std::int32_t, 1>& labels,
ndview<Float, 1>& out,
Float L1 = Float(0),
Float L2 = Float(0),
const event_vector& deps = {});

template <typename Float>
sycl::event compute_logloss(sycl::queue& q,
const ndview<Float, 1>& parameters,
const ndview<Float, 2>& data,
const ndview<std::int32_t, 1>& labels,
const ndview<Float, 1>& probabilities,
ndview<Float, 1>& out,
Float L1 = Float(0),
Float L2 = Float(0),
const event_vector& deps = {});

template <typename Float>
sycl::event compute_logloss_with_der(sycl::queue& q,
const ndview<Float, 1>& parameters,
const ndview<Float, 2>& data,
const ndview<std::int32_t, 1>& labels,
const ndview<Float, 1>& probabilities,
ndview<Float, 1>& out,
ndview<Float, 1>& out_derivative,
Float L1 = Float(0),
Float L2 = Float(0),
const event_vector& deps = {});

template <typename Float>
sycl::event compute_derivative(sycl::queue& q,
const ndview<Float, 1>& parameters,
const ndview<Float, 2>& data,
const ndview<std::int32_t, 1>& labels,
const ndview<Float, 1>& probabilities,
ndview<Float, 1>& out_derivative,
Float L1 = Float(0),
Float L2 = Float(0),
const event_vector& deps = {});

template <typename Float>
sycl::event compute_hessian(sycl::queue& q,
const ndview<Float, 1>& parameters,
const ndview<Float, 2>& data,
const ndview<std::int32_t, 1>& labels,
const ndview<Float, 1>& probabilities,
ndview<Float, 2>& out_hessian,
Float L1 = Float(0),
Float L2 = Float(0),
const event_vector& deps = {});

} // namespace oneapi::dal::backend::primitives
Loading

0 comments on commit 98fc59c

Please sign in to comment.