-
Notifications
You must be signed in to change notification settings - Fork 217
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial commit * Fix reduction 2d files * Update tests * Fix code format * Add gemv implemetation and tests * Minor fix * Add logloss without regularization computation * Add derivative and naive hessian computation * Add 4 hessian computation implementations and perfomance tests * Add tests, header and update BUILD file * Add skip_if to double tests, fix timeouts * Update regularization computation
- Loading branch information
1 parent
32c31e5
commit 98fc59c
Showing
11 changed files
with
1,311 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/******************************************************************************* | ||
* Copyright 2023 Intel Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*******************************************************************************/ | ||
|
||
#pragma once | ||
|
||
#include "oneapi/dal/backend/primitives/objective_function/logloss.hpp" |
42 changes: 42 additions & 0 deletions
42
cpp/oneapi/dal/backend/primitives/objective_function/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package(default_visibility = ["//visibility:public"]) | ||
load("@onedal//dev/bazel:dal.bzl", | ||
"dal_module", | ||
"dal_test_suite", | ||
) | ||
|
||
dal_module( | ||
name = "objective_function", | ||
auto = True, | ||
dal_deps = [ | ||
"@onedal//cpp/oneapi/dal/backend/primitives:common", | ||
"@onedal//cpp/oneapi/dal/backend/primitives:blas" | ||
], | ||
) | ||
|
||
dal_test_suite( | ||
name = "tests", | ||
framework = "catch2", | ||
compile_as = [ "dpc++" ], | ||
private = True, | ||
srcs = glob([ | ||
"test/*_dpc.cpp", | ||
], exclude=[ | ||
"test/*perf*.cpp", | ||
]), | ||
dal_deps = [ | ||
":objective_function", | ||
], | ||
) | ||
|
||
dal_test_suite( | ||
name = "perf_tests", | ||
framework = "catch2", | ||
compile_as = [ "dpc++" ], | ||
private = True, | ||
srcs = glob([ | ||
"test/*perf_dpc.cpp", | ||
]), | ||
dal_deps = [ | ||
":objective_function", | ||
], | ||
) |
85 changes: 85 additions & 0 deletions
85
cpp/oneapi/dal/backend/primitives/objective_function/logloss.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/******************************************************************************* | ||
* Copyright 2023 Intel Corporation | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*******************************************************************************/ | ||
|
||
#pragma once | ||
|
||
#include "oneapi/dal/backend/primitives/ndarray.hpp" | ||
|
||
namespace oneapi::dal::backend::primitives { | ||
|
||
template <typename Float> | ||
sycl::event compute_probabilities(sycl::queue& q, | ||
const ndview<Float, 1>& parameters, | ||
const ndview<Float, 2>& data, | ||
ndview<Float, 1>& predictions, | ||
const event_vector& deps = {}); | ||
|
||
template <typename Float> | ||
sycl::event compute_logloss(sycl::queue& q, | ||
const ndview<Float, 1>& parameters, | ||
const ndview<Float, 2>& data, | ||
const ndview<std::int32_t, 1>& labels, | ||
ndview<Float, 1>& out, | ||
Float L1 = Float(0), | ||
Float L2 = Float(0), | ||
const event_vector& deps = {}); | ||
|
||
template <typename Float> | ||
sycl::event compute_logloss(sycl::queue& q, | ||
const ndview<Float, 1>& parameters, | ||
const ndview<Float, 2>& data, | ||
const ndview<std::int32_t, 1>& labels, | ||
const ndview<Float, 1>& probabilities, | ||
ndview<Float, 1>& out, | ||
Float L1 = Float(0), | ||
Float L2 = Float(0), | ||
const event_vector& deps = {}); | ||
|
||
template <typename Float> | ||
sycl::event compute_logloss_with_der(sycl::queue& q, | ||
const ndview<Float, 1>& parameters, | ||
const ndview<Float, 2>& data, | ||
const ndview<std::int32_t, 1>& labels, | ||
const ndview<Float, 1>& probabilities, | ||
ndview<Float, 1>& out, | ||
ndview<Float, 1>& out_derivative, | ||
Float L1 = Float(0), | ||
Float L2 = Float(0), | ||
const event_vector& deps = {}); | ||
|
||
template <typename Float> | ||
sycl::event compute_derivative(sycl::queue& q, | ||
const ndview<Float, 1>& parameters, | ||
const ndview<Float, 2>& data, | ||
const ndview<std::int32_t, 1>& labels, | ||
const ndview<Float, 1>& probabilities, | ||
ndview<Float, 1>& out_derivative, | ||
Float L1 = Float(0), | ||
Float L2 = Float(0), | ||
const event_vector& deps = {}); | ||
|
||
template <typename Float> | ||
sycl::event compute_hessian(sycl::queue& q, | ||
const ndview<Float, 1>& parameters, | ||
const ndview<Float, 2>& data, | ||
const ndview<std::int32_t, 1>& labels, | ||
const ndview<Float, 1>& probabilities, | ||
ndview<Float, 2>& out_hessian, | ||
Float L1 = Float(0), | ||
Float L2 = Float(0), | ||
const event_vector& deps = {}); | ||
|
||
} // namespace oneapi::dal::backend::primitives |
Oops, something went wrong.