Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement raft::stats API with mdspan #802

Merged
merged 50 commits into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
853d984
Add mdspan for cov and mean_center
lowener Jul 25, 2022
2b5218c
Change only public API
lowener Jul 25, 2022
9c97e01
Add accuracy, randIndex, completeness and contingency
lowener Jul 25, 2022
426c201
Merge branch 'branch-22.10' into 22.10-stats-api
lowener Aug 11, 2022
1591389
Update meanvar
lowener Aug 11, 2022
a6cdb62
Start using vanilla mdspan
lowener Aug 11, 2022
16c1d03
Add vanilla mdspan for stats public api
lowener Aug 15, 2022
3b262cf
Fix comments, start adding tests
lowener Aug 18, 2022
b199ef9
Merge branch 'branch-22.10' into 22.10-stats-api
lowener Aug 31, 2022
21f374a
Remove constness, add tests
lowener Sep 3, 2022
82a9bd7
Add remaining tests, fix style
lowener Sep 5, 2022
0d11ec7
Merge branch 'branch-22.10' into 22.10-stats-api
lowener Sep 5, 2022
7ff8a2f
Using device_*_view instead of vanilla mdspan
lowener Sep 20, 2022
1e52d85
Template fix, add static_assert and fix tests
lowener Sep 21, 2022
37894d7
Add optional argument to contingency matrix
lowener Sep 21, 2022
8e95d2f
Prefer extent over size, change workspace type of contingency
lowener Sep 21, 2022
b85044c
Merge branch 'branch-22.10' into 22.10-stats-api
lowener Sep 22, 2022
e9a929c
Add device_mdspan include. Fix parameter order
lowener Sep 22, 2022
bbdf6dd
Fix copyright
lowener Sep 22, 2022
55c0c91
Fix tests
lowener Sep 23, 2022
b5cf18b
Update remaining stats function and their tests with mdspan
lowener Sep 26, 2022
ef94359
Use snake case for variables, parameters and templates
lowener Sep 27, 2022
0479c11
fix style
lowener Sep 27, 2022
f8ae9e1
Remove workspace from public api
lowener Sep 28, 2022
093dc4c
Add [in] [out] to parameter documentation
lowener Sep 28, 2022
2a4b5b8
Adding const specifier when possible
lowener Sep 28, 2022
669163e
Merge branch 'branch-22.10' into 22.10-stats-api
lowener Sep 29, 2022
00448a1
Remove default template, rename dispersion, fix silhouette_score
lowener Oct 3, 2022
36f066f
Fix silhouette test file
lowener Oct 3, 2022
39cc643
Add overload for std::nullopt
lowener Oct 3, 2022
2d4d285
Merge branch 'branch-22.10' into 22.10-stats-api
lowener Oct 3, 2022
2a97ef2
Add cluster dispersion definition
lowener Oct 3, 2022
461eee5
Fix bcast_along_rows
lowener Oct 4, 2022
8b9d840
Merge branch 'branch-22.10' into 22.10-stats-api
lowener Oct 4, 2022
9c4adf4
mean_center and weighted_mean correction for along_rows parameter
lowener Oct 5, 2022
29d360e
Updating row weighted mean
cjnolet Oct 5, 2022
333e596
iRemoving weighted mean mdspanification for now.
cjnolet Oct 5, 2022
b489e3b
Updating weighted mean test.
cjnolet Oct 6, 2022
a78d38a
Weighted mean
cjnolet Oct 6, 2022
c3d11a3
Skipping re-install of raft-dask for docs build
cjnolet Oct 6, 2022
5a7307f
Adding missing semicolon
cjnolet Oct 6, 2022
a7790e3
Adding pylibraft to docs build.
cjnolet Oct 6, 2022
374c91c
Reverting changes to build.sh
cjnolet Oct 6, 2022
be1baa1
Merge branch 'branch-22.10' into 22.10-stats-api
cjnolet Oct 6, 2022
faf86cc
Merge branch 'branch-22.10' into 22.10-stats-api
cjnolet Oct 6, 2022
72d6c80
enabling verbose logging in build.sh for docs
cjnolet Oct 7, 2022
ff2b9b0
Removing --build from cmake --build
cjnolet Oct 7, 2022
db02b2d
Merge branch 'fea-2212-increase_docs_build_logging' into 22.10-stats-api
cjnolet Oct 7, 2022
dd89e81
Fixing doxygen build
cjnolet Oct 7, 2022
4a3ad93
Fixing style
cjnolet Oct 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions cpp/include/raft/stats/accuracy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,25 @@ float accuracy(const math_t* predictions, const math_t* ref_predictions, int n,

/**
* @brief Compute accuracy of predictions. Useful for classification.
* @tparam math_t: data type for predictions (e.g., int for classification)
* @tparam DataT: data type for predictions (e.g., int for classification)
* @tparam IdxType Index type of matrix extent.
* @tparam LayoutPolicy Layout type of the input data.
* @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
* device.
* @param[in] handle: the raft handle.
* @param[in] predictions: array of predictions (GPU pointer).
* @param[in] ref_predictions: array of reference (ground-truth) predictions (GPU pointer).
* @return: Accuracy score in [0, 1]; higher is better.
*/
template <typename math_t, typename IdxType, typename LayoutPolicy, typename AccessorPolicy>
float accuracy(
const raft::handle_t& handle,
raft::mdspan<math_t, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> predictions,
raft::mdspan<math_t, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> ref_predictions)
template <typename DataT, typename IdxType>
float accuracy(const raft::handle_t& handle,
raft::device_vector_view<const DataT, IdxType> predictions,
raft::device_vector_view<const DataT, IdxType> ref_predictions)
{
RAFT_EXPECTS(predictions.size() == ref_predictions.size(), "Size mismatch");
RAFT_EXPECTS(predictions.is_exhaustive(), "predictions must be contiguous");
RAFT_EXPECTS(ref_predictions.is_exhaustive(), "ref_predictions must be contiguous");

return detail::accuracy_score(predictions.data_handle(),
lowener marked this conversation as resolved.
Show resolved Hide resolved
ref_predictions.data_handle(),
predictions.extent(0),
lowener marked this conversation as resolved.
Show resolved Hide resolved
predictions.size(),
handle.get_stream());
}
} // namespace stats
Expand Down
30 changes: 13 additions & 17 deletions cpp/include/raft/stats/adjusted_rand_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,26 @@ double adjusted_rand_index(const T* firstClusterArray,
/**
* @brief Function to calculate Adjusted RandIndex as described
* <a href="https://en.wikipedia.org/wiki/Rand_index">here</a>
* @tparam T data-type for input label arrays
* @tparam DataT data-type for input label arrays
* @tparam MathT integral data-type used for computing n-choose-r
* @tparam IdxType Index type of matrix extent.
* @tparam LayoutPolicy Layout type of the input data.
* @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
* device.
* @param handle: the raft handle.
* @param firstClusterArray: the array of classes
* @param secondClusterArray: the array of classes
*/
template <typename T,
typename MathT = int,
typename IdxType,
typename LayoutPolicy,
typename AccessorPolicy>
double adjusted_rand_index(
const raft::handle_t& handle,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> firstClusterArray,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> secondClusterArray)
template <typename DataT, typename MathT = int, typename IdxType>
double adjusted_rand_index(const raft::handle_t& handle,
raft::device_vector_view<const DataT, IdxType> firstClusterArray,
raft::device_vector_view<const DataT, IdxType> secondClusterArray)
{
return detail::compute_adjusted_rand_index<T, MathT>(firstClusterArray.data_handle(),
secondClusterArray.data_handle(),
firstClusterArray.extent(0),
handle.get_stream());
RAFT_EXPECTS(firstClusterArray.size() == secondClusterArray.size(), "Size mismatch");
RAFT_EXPECTS(firstClusterArray.is_exhaustive(), "firstClusterArray must be contiguous");
RAFT_EXPECTS(secondClusterArray.is_exhaustive(), "secondClusterArray must be contiguous");

return detail::compute_adjusted_rand_index<DataT, MathT>(firstClusterArray.data_handle(),
secondClusterArray.data_handle(),
firstClusterArray.size(),
handle.get_stream());
}

}; // end namespace stats
Expand Down
27 changes: 13 additions & 14 deletions cpp/include/raft/stats/completeness_score.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,27 @@ double completeness_score(const T* truthClusterArray,
/**
* @brief Function to calculate the completeness score between two clusters
*
* @tparam T the data type
* @tparam DataT the data type
* @tparam IdxType Index type of matrix extent.
* @tparam LayoutPolicy Layout type of the input data.
* @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
* device.
* @param handle: the raft handle.
* @param truthClusterArray: the array of truth classes of type T
* @param predClusterArray: the array of predicted classes of type T
* @param truthClusterArray: the array of truth classes of type DataT
* @param predClusterArray: the array of predicted classes of type DataT
* @param lowerLabelRange: the lower bound of the range of labels
* @param upperLabelRange: the upper bound of the range of labels
*/
template <typename T, typename IdxType, typename LayoutPolicy, typename AccessorPolicy>
double completeness_score(
const raft::handle_t& handle,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> truthClusterArray,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> predClusterArray,
T lowerLabelRange,
T upperLabelRange)
template <typename DataT, typename IdxType>
double completeness_score(const raft::handle_t& handle,
raft::device_vector_view<const DataT, IdxType> truthClusterArray,
raft::device_vector_view<const DataT, IdxType> predClusterArray,
DataT lowerLabelRange,
DataT upperLabelRange)
{
RAFT_EXPECTS(truthClusterArray.size() == predClusterArray.size(), "Size mismatch");
RAFT_EXPECTS(truthClusterArray.is_exhaustive(), "truthClusterArray must be contiguous");
RAFT_EXPECTS(predClusterArray.is_exhaustive(), "predClusterArray must be contiguous");
return detail::homogeneity_score(predClusterArray.data_handle(),
truthClusterArray.data_handle(),
truthClusterArray.extent(0),
truthClusterArray.size(),
lowerLabelRange,
upperLabelRange,
handle.get_stream());
Expand Down
91 changes: 46 additions & 45 deletions cpp/include/raft/stats/contingency_matrix.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@ void getInputClassCardinality(
* @param minLabel: [out] calculated min value in input array
* @param maxLabel: [out] calculated max value in input array
*/
template <typename T, typename IdxType, typename LayoutPolicy, typename AccessorPolicy>
void getInputClassCardinality(
const raft::handle_t& handle,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> groundTruth,
const raft::host_scalar_view<T>& minLabel,
const raft::host_scalar_view<T>& maxLabel)
template <typename DataT, typename IdxType>
void getInputClassCardinality(const raft::handle_t& handle,
lowener marked this conversation as resolved.
Show resolved Hide resolved
raft::device_vector_view<const DataT, IdxType> groundTruth,
raft::host_scalar_view<DataT> minLabel,
raft::host_scalar_view<DataT> maxLabel)
{
detail::getInputClassCardinality(groundTruth.data_handle(),
groundTruth.extent(0),
Expand Down Expand Up @@ -88,22 +87,16 @@ size_t getContingencyMatrixWorkspaceSize(int nSamples,
/**
* @brief Calculate workspace size for running contingency matrix calculations
* @tparam T label type
* @tparam OutT output matrix type
* @param handle: the raft handle.
* @param groundTruth: device 1-d array for ground truth (num of rows)
* @param minLabel: Optional, min value in input array
* @param maxLabel: Optional, max value in input array
*/
template <typename T,
typename OutT = int,
typename IdxType,
typename LayoutPolicy,
typename AccessorPolicy>
size_t getContingencyMatrixWorkspaceSize(
const raft::handle_t& handle,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> groundTruth,
T minLabel = std::numeric_limits<T>::max(),
T maxLabel = std::numeric_limits<T>::max())
template <typename DataT, typename IdxType>
size_t getContingencyMatrixWorkspaceSize(const raft::handle_t& handle,
lowener marked this conversation as resolved.
Show resolved Hide resolved
raft::device_vector_view<const DataT, IdxType> groundTruth,
DataT minLabel = std::numeric_limits<DataT>::max(),
DataT maxLabel = std::numeric_limits<DataT>::max())
{
return detail::getContingencyMatrixWorkspaceSize(
groundTruth.extent(0), groundTruth.data_handle(), handle.get_stream(), minLabel, maxLabel);
Expand All @@ -119,7 +112,7 @@ size_t getContingencyMatrixWorkspaceSize(
* @param groundTruth: device 1-d array for ground truth (num of rows)
* @param predictedLabel: device 1-d array for prediction (num of columns)
* @param nSamples: number of elements in input array
* @param outMat: output buffer for contingecy matrix
* @param outMat: output buffer for contingency matrix
* @param stream: cuda stream for execution
* @param workspace: Optional, workspace memory allocation
* @param workspaceSize: Optional, size of workspace memory
Expand Down Expand Up @@ -153,45 +146,53 @@ void contingencyMatrix(const T* groundTruth,
* labels. Users should call function getInputClassCardinality to find
* and allocate memory for output. Similarly workspace requirements
* should be checked using function getContingencyMatrixWorkspaceSize
* @tparam T label type
* @tparam OutT output matrix type
* @tparam DataT label type
* @tparam OutType output matrix type
* @tparam IdxType Index type of matrix extent.
* @tparam LayoutPolicy Layout type of the input data.
* @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
* device.
* @param handle: the raft handle.
* @param groundTruth: device 1-d array for ground truth (num of rows)
* @param predictedLabel: device 1-d array for prediction (num of columns)
* @param outMat: output buffer for contingecy matrix
* @param outMat: output buffer for contingency matrix
lowener marked this conversation as resolved.
Show resolved Hide resolved
* @param workspace: Optional, workspace memory allocation
* @param workspaceSize: Optional, size of workspace memory
* @param minLabel: Optional, min value in input ground truth array
* @param maxLabel: Optional, max value in input ground truth array
*/
template <typename T,
typename OutT = int,
template <typename DataT,
typename OutType,
typename IdxType,
typename LayoutPolicy,
typename AccessorPolicy>
void contingencyMatrix(
const raft::handle_t& handle,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> groundTruth,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> predictedLabel,
raft::mdspan<OutT, raft::matrix_extent<IdxType>, LayoutPolicy, AccessorPolicy> outMat,
void* workspace = nullptr,
size_t workspaceSize = 0,
T minLabel = std::numeric_limits<T>::max(),
T maxLabel = std::numeric_limits<T>::max())
typename WorkspaceType,
lowener marked this conversation as resolved.
Show resolved Hide resolved
typename = raft::enable_if_mdspan<WorkspaceType>>
void contingencyMatrix(const raft::handle_t& handle,
raft::device_vector_view<const DataT, IdxType> groundTruth,
raft::device_vector_view<const DataT, IdxType> predictedLabel,
raft::device_matrix_view<OutType, IdxType, LayoutPolicy> outMat,
std::optional<WorkspaceType> workspace,
DataT minLabel = std::numeric_limits<DataT>::max(),
DataT maxLabel = std::numeric_limits<DataT>::max())
{
detail::contingencyMatrix<T, OutT>(groundTruth.data_handle(),
predictedLabel.data_handle(),
groundTruth.extent(0),
outMat.data_handle(),
handle.get_stream(),
workspace,
workspaceSize,
minLabel,
maxLabel);
RAFT_EXPECTS(groundTruth.size() == predictedLabel.size(), "Size mismatch");
RAFT_EXPECTS(groundTruth.is_exhaustive(), "groundTruth must be contiguous");
RAFT_EXPECTS(predictedLabel.is_exhaustive(), "predictedLabel must be contiguous");
RAFT_EXPECTS(outMat.is_exhaustive(), "outMat must be contiguous");

using workspaceElemType = typename WorkspaceType::element_type;
workspaceElemType* workspace_p = nullptr;
IdxType workspace_size = 0;
if (workspace.has_value()) {
workspace_p = workspace.value().data_handle();
workspace_size = workspace.value().size() * sizeof(workspaceElemType);
}
detail::contingencyMatrix<DataT, OutType>(groundTruth.data_handle(),
predictedLabel.data_handle(),
groundTruth.size(),
outMat.data_handle(),
handle.get_stream(),
workspace_p,
workspace_size,
minLabel,
maxLabel);
}

}; // namespace stats
Expand Down
18 changes: 12 additions & 6 deletions cpp/include/raft/stats/cov.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ void cov(const raft::handle_t& handle,
* @tparam Type the data type
* @tparam IdxT the index type
* @tparam LayoutPolicy Layout type of the input data.
* @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
* device.
* @param handle the raft handle
* @param covar the output covariance matrix
* @param data the input matrix (this will get mean-centered at the end!)
Expand All @@ -80,14 +78,22 @@ void cov(const raft::handle_t& handle,
* @note if stable=true, then the input data will be mean centered after this
* function returns!
*/
template <typename Type, typename IdxType, typename LayoutPolicy, typename AccessorPolicy>
template <typename DataT, typename IdxType, typename LayoutPolicy>
void cov(const raft::handle_t& handle,
raft::mdspan<Type, raft::matrix_extent<IdxType>, LayoutPolicy, AccessorPolicy> covar,
raft::mdspan<Type, raft::matrix_extent<IdxType>, LayoutPolicy, AccessorPolicy> data,
raft::mdspan<Type, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> mu,
raft::device_matrix_view<DataT, IdxType, LayoutPolicy> covar,
raft::device_matrix_view<DataT, IdxType, LayoutPolicy> data,
raft::device_vector_view<const DataT, IdxType> mu,
bool sample,
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
bool stable)
{
static_assert(
std::is_same_v<LayoutPolicy, raft::row_major> || std::is_same_v<LayoutPolicy, raft::col_major>,
"Data layout not supported");
RAFT_EXPECTS(data.size() == covar.size(), "Size mismatch");
RAFT_EXPECTS(data.is_exhaustive(), "data must be contiguous");
RAFT_EXPECTS(covar.is_exhaustive(), "covar must be contiguous");
RAFT_EXPECTS(mu.is_exhaustive(), "mu must be contiguous");

detail::cov(handle,
covar.data_handle(),
data.data_handle(),
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/stats/detail/histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ HistType selectBestHistAlgo(IdxT nbins)
* @param nbins number of bins
* @param data input data (length = ncols * nrows)
* @param nrows data array length in each column (or batch)
* @param ncols number of columsn (or batch size)
* @param ncols number of columns (or batch size)
* @param stream cuda stream
* @param binner the operation that computes the bin index of the input data
*
Expand Down
28 changes: 13 additions & 15 deletions cpp/include/raft/stats/dispersion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ DataT dispersion(const DataT* centroids,
* @tparam DataT data type
* @tparam IdxType index type
* @tparam LayoutPolicy Layout type of the input data.
* @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
* device.
* @tparam TPB threads block for kernels launched
* @param handle the raft handle
* @param centroids the cluster centroids. This is assumed to be row-major
Expand All @@ -76,21 +74,21 @@ DataT dispersion(const DataT* centroids,
* @param nPoints number of points in the dataset
* @return the cluster dispersion value
*/
template <typename DataT,
typename IdxType = int,
typename LayoutPolicy,
typename AccessorPolicy,
int TPB = 256>
DataT dispersion(
const raft::handle_t& handle,
raft::mdspan<DataT, raft::matrix_extent<IdxType>, LayoutPolicy, AccessorPolicy> centroids,
raft::mdspan<IdxType, raft::vector_extent<IdxType>> clusterSizes,
std::optional<raft::mdspan<DataT, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy>>
globalCentroid,
const IdxType nPoints)
template <typename DataT, typename IdxType, int TPB = 256>
DataT dispersion(const raft::handle_t& handle,
raft::device_matrix_view<const DataT, IdxType, raft::row_major> centroids,
raft::device_vector_view<const IdxType, IdxType> clusterSizes,
std::optional<raft::device_vector_view<DataT, IdxType>> globalCentroid,
const IdxType nPoints)
{
RAFT_EXPECTS(clusterSizes.extent(0) == centroids.extent(0), "Size mismatch");
RAFT_EXPECTS(clusterSizes.is_exhaustive(), "clusterSizes must be contiguous");

DataT* globalCentroid_ptr = nullptr;
if (globalCentroid.has_value()) { globalCentroid_ptr = globalCentroid.value().data_handle(); }
if (globalCentroid.has_value()) {
RAFT_EXPECTS(globalCentroid.value().is_exhaustive(), "globalCentroid must be contiguous");
globalCentroid_ptr = globalCentroid.value().data_handle();
}
return detail::dispersion<DataT, IdxType, TPB>(centroids.data_handle(),
clusterSizes.data_handle(),
globalCentroid_ptr,
Expand Down
21 changes: 9 additions & 12 deletions cpp/include/raft/stats/entropy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,23 @@ double entropy(const T* clusterArray,
* @brief Function to calculate entropy
* <a href="https://en.wikipedia.org/wiki/Entropy_(information_theory)">more info on entropy</a>
*
* @tparam T data type
* @tparam DataT data type
* @tparam IdxT index type
* @tparam LayoutPolicy Layout type of the input data.
* @tparam AccessorPolicy Accessor for the input and output, must be valid accessor on
* device.
* @param handle the raft handle
* @param clusterArray: the array of classes of type T
* @param clusterArray: the array of classes of type DataT
* @param lowerLabelRange: the lower bound of the range of labels
* @param upperLabelRange: the upper bound of the range of labels
* @return the entropy score
*/
template <typename T, typename IdxType, typename LayoutPolicy, typename AccessorPolicy>
double entropy(
const raft::handle_t& handle,
raft::mdspan<T, raft::vector_extent<IdxType>, LayoutPolicy, AccessorPolicy> clusterArray,
const T lowerLabelRange,
const T upperLabelRange)
template <typename DataT, typename IdxType>
double entropy(const raft::handle_t& handle,
raft::device_vector_view<const DataT, IdxType> clusterArray,
const DataT lowerLabelRange,
const DataT upperLabelRange)
{
RAFT_EXPECTS(clusterArray.is_exhaustive(), "clusterArray must be contiguous");
return detail::entropy(clusterArray.data_handle(),
clusterArray.extent(0),
clusterArray.size(),
lowerLabelRange,
upperLabelRange,
handle.get_stream());
Expand Down
Loading