diff --git a/BUILD.md b/BUILD.md
index 0c7fdd7e82..f572b11848 100644
--- a/BUILD.md
+++ b/BUILD.md
@@ -5,6 +5,7 @@
- [Build Dependencies](#required_depenencies)
- [Header-only C++](#install_header_only_cpp)
- [C++ Shared Libraries](#shared_cpp_libs)
+ - [Improving Rebuild Times](#ccache)
- [Googletests](#gtests)
- [C++ Using Cmake](#cpp_using_cmake)
- [Python](#python)
@@ -29,7 +30,6 @@ In addition to the libraries included with cudatoolkit 11.0+, there are some oth
- [RMM](https://github.com/rapidsai/rmm) corresponding to RAFT version.
#### Optional
-- [mdspan](https://github.com/rapidsai/mdspan) - On by default but can be disabled.
- [Thrust](https://github.com/NVIDIA/thrust) v1.15 / [CUB](https://github.com/NVIDIA/cub) - On by default but can be disabled.
- [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API.
- [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0
@@ -53,11 +53,6 @@ The following example will download the needed dependencies and install the RAFT
./build.sh libraft --install
```
-The `--minimal-deps` flag can be used to install the headers with minimal dependencies:
-```bash
-./build.sh libraft --install --minimal-deps
-```
-
### C++ Shared Libraries (optional)
For larger projects which make heavy use of the pairwise distances or nearest neighbors APIs, shared libraries can be built to speed up compile times. These shared libraries can also significantly improve re-compile times both while developing RAFT and developing against the APIs. Build all of the available shared libraries by passing `--compile-libs` flag to `build.sh`:
@@ -72,6 +67,14 @@ Individual shared libraries have their own flags and multiple can be used (thoug
Add the `--install` flag to the above example to also install the shared libraries into `$INSTALL_PREFIX/lib`.
+### `ccache` and `sccache`
+
+`ccache` and `sccache` can be used to better cache parts of the build when rebuilding frequently, such as when working on a new feature. You can also use `ccache` or `sccache` with `build.sh`:
+
+```bash
+./build.sh libraft --cache-tool=ccache
+```
+
### Tests
Compile the tests using the `tests` target in `build.sh`.
@@ -86,10 +89,17 @@ Test compile times can be improved significantly by using the optional shared li
./build.sh libraft tests --compile-libs
```
-To run C++ tests:
+The tests are broken apart by algorithm category, so you will find several binaries in `cpp/build/` named `*_TEST`.
+
+For example, to run the distance tests:
+```bash
+./cpp/build/DISTANCE_TEST
+```
+
+It can take some time to compile all of the tests. You can build individual tests by providing a semicolon-separated list to the `--limit-tests` option in `build.sh`:
```bash
-./cpp/build/test_raft
+./build.sh libraft tests --limit-tests="SPATIAL_TEST;DISTANCE_TEST;MATRIX_TEST"
```
### Benchmarks
diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh
index 7de942444e..9eee21b09c 100644
--- a/cpp/include/raft/cluster/detail/single_linkage.cuh
+++ b/cpp/include/raft/cluster/detail/single_linkage.cuh
@@ -54,7 +54,7 @@ void single_linkage(const raft::handle_t& handle,
size_t m,
size_t n,
raft::distance::DistanceType metric,
- linkage_output* out,
+ linkage_output* out,
int c,
size_t n_clusters)
{
diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh
index 98735c74e4..8e33b8389d 100644
--- a/cpp/include/raft/cluster/single_linkage.cuh
+++ b/cpp/include/raft/cluster/single_linkage.cuh
@@ -17,9 +17,12 @@
#include
#include
+#include
namespace raft::cluster {
+constexpr int DEFAULT_CONST_C = 15;
+
/**
* Single-linkage clustering, capable of constructing a KNN graph to
* scale the algorithm beyond the n^2 memory consumption of implementations
@@ -48,11 +51,53 @@ void single_linkage(const raft::handle_t& handle,
size_t m,
size_t n,
raft::distance::DistanceType metric,
- linkage_output* out,
+ linkage_output* out,
int c,
size_t n_clusters)
{
detail::single_linkage(
handle, X, m, n, metric, out, c, n_clusters);
}
+
+/**
+ * Single-linkage clustering, capable of constructing a KNN graph to
+ * scale the algorithm beyond the n^2 memory consumption of implementations
+ * that use the fully-connected graph of pairwise distances by connecting
+ * a knn graph when k is not large enough to connect it.
+
+ * @tparam value_idx
+ * @tparam value_t
+ * @tparam dist_type method to use for constructing connectivities graph
+ * @param[in] handle raft handle
+ * @param[in] X dense input matrix in row-major layout
+ * @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2)
+ * @param[out] labels output labels vector (size n_rows)
+ * @param[in] metric distance metric to use when constructing connectivities graph
+ * @param[in] n_clusters number of clusters to assign data samples
+ * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect
+ control of k. The algorithm will set `k = log(n) + c`
+ */
+template
+void single_linkage(const raft::handle_t& handle,
+ raft::device_matrix_view X,
+ raft::device_matrix_view dendrogram,
+ raft::device_vector_view labels,
+ raft::distance::DistanceType metric,
+ size_t n_clusters,
+ std::optional c = std::make_optional(DEFAULT_CONST_C))
+{
+ linkage_output out_arrs;
+ out_arrs.children = dendrogram.data_handle();
+ out_arrs.labels = labels.data_handle();
+
+ single_linkage(handle,
+ X.data_handle(),
+ static_cast(X.extent(0)),
+ static_cast(X.extent(1)),
+ metric,
+ &out_arrs,
+ c.has_value() ? c.value() : DEFAULT_CONST_C,
+ n_clusters);
+}
+
}; // namespace raft::cluster
diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp
index 1c35cf5c68..d97e6afed3 100644
--- a/cpp/include/raft/cluster/single_linkage_types.hpp
+++ b/cpp/include/raft/cluster/single_linkage_types.hpp
@@ -16,6 +16,8 @@
#pragma once
+#include
+
namespace raft::cluster {
enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 };
@@ -27,23 +29,33 @@ enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 };
* @tparam value_idx
* @tparam value_t
*/
-template
+template
class linkage_output {
public:
- value_idx m;
- value_idx n_clusters;
+ idx_t m;
+ idx_t n_clusters;
+
+ idx_t n_leaves;
+ idx_t n_connected_components;
- value_idx n_leaves;
- value_idx n_connected_components;
+ // TODO: These will be made private in a future release
+ idx_t* labels; // size: m
+ idx_t* children; // size: (m-1, 2)
- value_idx* labels; // size: m
+ raft::device_vector_view get_labels()
+ {
+ return raft::make_device_vector_view(labels, m);
+ }
- value_idx* children; // size: (m-1, 2)
+ raft::device_matrix_view get_children()
+ {
+ return raft::make_device_matrix_view(children, m - 1, 2);
+ }
};
-class linkage_output_int_float : public linkage_output {
+class linkage_output_int_float : public linkage_output {
};
-class linkage_output__int64_float : public linkage_output {
+class linkage_output__int64_float : public linkage_output {
};
}; // namespace raft::cluster
\ No newline at end of file
diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp
index 6281ca98ea..a858633e07 100644
--- a/cpp/include/raft/core/mdspan.hpp
+++ b/cpp/include/raft/core/mdspan.hpp
@@ -255,5 +255,4 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx,
return unravel_index_impl(static_cast(idx), shape);
}
}
-
-} // namespace raft
\ No newline at end of file
+} // namespace raft
diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh
index afdec24ebd..d26f5f73cf 100644
--- a/cpp/include/raft/matrix/col_wise_sort.cuh
+++ b/cpp/include/raft/matrix/col_wise_sort.cuh
@@ -18,10 +18,11 @@
#pragma once
+#include
+#include
#include
-namespace raft {
-namespace matrix {
+namespace raft::matrix {
/**
* @brief sort columns within each row of row-major input matrix and return sorted indexes
@@ -50,7 +51,105 @@ void sort_cols_per_row(const InType* in,
detail::sortColumnsPerRow(
in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys);
}
-}; // end namespace matrix
-}; // end namespace raft
+
+/**
+ * @brief sort columns within each row of row-major input matrix and return sorted indexes
+ * modelled as key-value sort with key being input matrix and value being index of values
+ * @tparam in_t: element type of input matrix
+ * @tparam out_t: element type of output matrix
+ * @tparam matrix_idx_t: integer type for matrix indexing
+ * @param[in] handle: raft handle
+ * @param[in] in: input matrix
+ * @param[out] out: output value(index) matrix
+ * @param[out] sorted_keys: Optional, output matrix for sorted keys (input)
+ */
+template
+void sort_cols_per_row(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ std::optional>
+ sorted_keys = std::nullopt)
+{
+ RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0),
+ "Input and output matrices must have the same shape.");
+
+ if (sorted_keys.has_value()) {
+ RAFT_EXPECTS(in.extent(1) == sorted_keys.value().extent(1) &&
+ in.extent(0) == sorted_keys.value().extent(0),
+ "Input and `sorted_keys` matrices must have the same shape.");
+ }
+
+ size_t workspace_size = 0;
+ bool alloc_workspace = false;
+
+ in_t* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr;
+
+ detail::sortColumnsPerRow(in.data_handle(),
+ out.data_handle(),
+ in.extent(0),
+ in.extent(1),
+ alloc_workspace,
+ (void*)nullptr,
+ workspace_size,
+ handle.get_stream(),
+ keys);
+
+ if (alloc_workspace) {
+ auto workspace = raft::make_device_vector(handle, workspace_size);
+
+ detail::sortColumnsPerRow(in.data_handle(),
+ out.data_handle(),
+ in.extent(0),
+ in.extent(1),
+ alloc_workspace,
+ (void*)workspace.data_handle(),
+ workspace_size,
+ handle.get_stream(),
+ keys);
+ }
+}
+
+namespace sort_cols_per_row_impl {
+template
+struct sorted_keys_alias {
+};
+
+template <>
+struct sorted_keys_alias {
+ using type = double;
+};
+
+template
+struct sorted_keys_alias<
+ std::optional>> {
+ using type = typename raft::device_matrix_view::value_type;
+};
+
+template
+using sorted_keys_t = typename sorted_keys_alias::type;
+} // namespace sort_cols_per_row_impl
+
+/**
+ * @brief Overload of `sort_cols_per_row` to help the
+ * compiler find the above overload, in case users pass in
+ * `std::nullopt` for the optional `sorted_keys` argument.
+ *
+ * Please see above for documentation of `sort_cols_per_row`.
+ */
+template
+void sort_cols_per_row(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ sorted_keys_vector_type sorted_keys)
+{
+ using sorted_keys_type = sort_cols_per_row_impl::sorted_keys_t<
+ std::remove_const_t>>;
+ std::optional> sorted_keys_opt =
+ std::forward(sorted_keys);
+
+ sort_cols_per_row(handle, in, out, sorted_keys_opt);
+}
+
+}; // end namespace raft::matrix
#endif
\ No newline at end of file
diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh
new file mode 100644
index 0000000000..5f1d16485c
--- /dev/null
+++ b/cpp/include/raft/matrix/copy.cuh
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief Copy selected rows of the input matrix into contiguous space.
+ *
+ * On exit out[i + k*n_rows] = in[indices[i] + k*n_rows],
+ * where i = 0..n_rows_indices-1, and k = 0..n_cols-1.
+ *
+ * @param[in] handle raft handle
+ * @param[in] in input matrix
+ * @param[out] out output matrix
+ * @param[in] indices of the rows to be copied
+ */
+template
+void copy_rows(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ raft::device_vector_view indices)
+{
+ RAFT_EXPECTS(in.extent(1) == out.extent(1),
+ "Input and output matrices must have same number of columns");
+ RAFT_EXPECTS(indices.extent(0) == out.extent(0),
+ "Number of rows in output matrix must equal number of indices");
+ detail::copyRows(in.data_handle(),
+ in.extent(0),
+ in.extent(1),
+ out.data_handle(),
+ indices.data_handle(),
+ indices.extent(0),
+ handle.get_stream(),
+ raft::is_row_major(in));
+}
+
+/**
+ * @brief copy matrix operation for column major matrices.
+ * @param[in] handle: raft handle
+ * @param[in] in: input matrix
+ * @param[out] out: output matrix
+ */
+template
+void copy(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out)
+{
+ RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1),
+ "Input and output matrix shapes must match.");
+
+ raft::copy_async(
+ out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream());
+}
+
+/**
+ * @brief copy matrix operation for column major matrices. First n_rows and
+ * n_cols of input matrix "in" is copied to "out" matrix.
+ * @param handle: raft handle for managing resources
+ * @param in: input matrix
+ * @param out: output matrix
+ */
+template
+void trunc_zero_origin(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out)
+{
+ RAFT_EXPECTS(out.extent(0) <= in.extent(0) && out.extent(1) <= in.extent(1),
+ "Output matrix must have less or equal number of rows and columns");
+
+ detail::truncZeroOrigin(in.data_handle(),
+ in.extent(0),
+ out.data_handle(),
+ out.extent(0),
+ out.extent(1),
+ handle.get_stream());
+}
+
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh
index 95953feca4..07b9ccc12b 100644
--- a/cpp/include/raft/matrix/detail/math.cuh
+++ b/cpp/include/raft/matrix/detail/math.cuh
@@ -141,7 +141,7 @@ void setSmallValuesZero(math_t* inout, IdxType len, cudaStream_t stream, math_t
}
template
-void reciprocal(math_t* in,
+void reciprocal(const math_t* in,
math_t* out,
math_t scalar,
int len,
@@ -363,8 +363,8 @@ void matrixVectorBinarySub(Type* data,
}
// Computes the argmax(d_in) column-wise in a DxN matrix
-template
-__global__ void argmaxKernel(const T* d_in, int D, int N, T* argmax)
+template
+__global__ void argmaxKernel(const T* d_in, int D, int N, IdxT* argmax)
{
typedef cub::BlockReduce, TPB> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -384,19 +384,19 @@ __global__ void argmaxKernel(const T* d_in, int D, int N, T* argmax)
if (threadIdx.x == 0) { argmax[blockIdx.x] = maxKV.key; }
}
-template
-void argmax(const math_t* in, int n_rows, int n_cols, math_t* out, cudaStream_t stream)
+template
+void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream)
{
int D = n_rows;
int N = n_cols;
if (D <= 32) {
- argmaxKernel<<>>(in, D, N, out);
+ argmaxKernel<<>>(in, D, N, out);
} else if (D <= 64) {
- argmaxKernel<<>>(in, D, N, out);
+ argmaxKernel<<>>(in, D, N, out);
} else if (D <= 128) {
- argmaxKernel<<>>(in, D, N, out);
+ argmaxKernel<<>>(in, D, N, out);
} else {
- argmaxKernel<<>>(in, D, N, out);
+ argmaxKernel<<>>(in, D, N, out);
}
RAFT_CUDA_TRY(cudaPeekAtLastError());
}
diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh
index a8568b0859..c425aad79b 100644
--- a/cpp/include/raft/matrix/detail/matrix.cuh
+++ b/cpp/include/raft/matrix/detail/matrix.cuh
@@ -67,7 +67,7 @@ void copyRows(const m_t* in,
template
void truncZeroOrigin(
- m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream)
+ const m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream)
{
auto m = out_n_rows;
auto k = in_n_rows;
@@ -279,7 +279,6 @@ m_t getL2Norm(const raft::handle_t& handle, m_t* in, idx_t size, cudaStream_t st
{
cublasHandle_t cublasH = handle.get_cublas_handle();
m_t normval = 0;
- // #TODO: Call from the public API when ready
RAFT_CUBLAS_TRY(raft::linalg::detail::cublasnrm2(cublasH, size, in, 1, &normval, stream));
return normval;
}
diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp
new file mode 100644
index 0000000000..fc3d14861c
--- /dev/null
+++ b/cpp/include/raft/matrix/detail/print.hpp
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::matrix::detail {
+
+/**
+ * @brief Print a matrix to stdout with printf.
+ *
+ * Element (i, j) is read as in[j * n_rows + i], i.e. column-major addressing,
+ * and formatted with "%1.4f". The matrix is emitted one row at a time.
+ *
+ * @param[in] in          pointer to the matrix elements (dereferenced on the host)
+ * @param[in] n_rows      number of rows
+ * @param[in] n_cols      number of columns
+ * @param[in] h_separator character printed after every element except the last one of a row
+ * @param[in] v_separator character printed after the last element of each row
+ */
+template <typename m_t, typename idx_t>
+void printHost(
+  const m_t* in, idx_t n_rows, idx_t n_cols, char h_separator = ' ', char v_separator = '\n')
+{
+  for (idx_t i = 0; i < n_rows; i++) {
+    for (idx_t j = 0; j < n_cols; j++) {
+      // The last column of a row is followed by the vertical separator instead.
+      printf("%1.4f%c", in[j * n_rows + i], j < n_cols - 1 ? h_separator : v_separator);
+    }
+  }
+}
+
+} // end namespace raft::matrix::detail
diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh
index 31164b2041..fa6e73de49 100644
--- a/cpp/include/raft/matrix/gather.cuh
+++ b/cpp/include/raft/matrix/gather.cuh
@@ -15,10 +15,12 @@
*/
#pragma once
+
+#include
+#include
#include
-namespace raft {
-namespace matrix {
+namespace raft::matrix {
/**
* @brief gather copies rows from a source matrix into a destination matrix according to a map.
@@ -49,6 +51,76 @@ void gather(const MatrixIteratorT in,
detail::gather(in, D, N, map, map_length, out, stream);
}
+/**
+ * @brief gather copies rows from a source matrix into a destination matrix according to a map.
+ *
+ * @tparam matrix_t Matrix element type
+ * @tparam map_t Map vector type
+ * @tparam idx_t integer type used for indexing
+ * @param[in] handle raft handle for managing resources
+ * @param[in] in Input matrix (assumed to be row-major)
+ * @param[in] map Vector of gather locations
+ * @param[out] out Output matrix (assumed to be row-major)
+ */
+template
+void gather(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_vector_view map,
+ raft::device_matrix_view out)
+{
+ RAFT_EXPECTS(out.extent(0) == map.extent(0),
+ "Number of rows in output matrix must equal the size of the map vector");
+ RAFT_EXPECTS(out.extent(1) == in.extent(1),
+ "Number of columns in input and output matrices must be equal.");
+
+ raft::matrix::detail::gather(
+ const_cast(in.data_handle()), // TODO: There's a better way to handle this
+ static_cast(in.extent(1)),
+ static_cast(in.extent(0)),
+ map.data_handle(),
+ static_cast(map.extent(0)),
+ out.data_handle(),
+ handle.get_stream());
+}
+
+/**
+ * @brief gather copies rows from a source matrix into a destination matrix according to a
+ * transformed map.
+ *
+ * @tparam matrix_t Matrix type
+ * @tparam map_t Map vector type
+ * @tparam map_xform_t Unary lambda expression or operator type, map_xform_t's result
+ * type must be convertible to idx_t (= int) type.
+ * @tparam idx_t integer type for indexing
+ * @param[in] handle raft handle for managing resources
+ * @param[in] in Input matrix (assumed to be row-major)
+ * @param[in] map Input vector of gather locations
+ * @param[out] out Output matrix (assumed to be row-major)
+ * @param[in] transform_op The transformation operation, transforms the map values to idx_t
+ */
+template
+void gather(const raft::handle_t& handle,
+            raft::device_matrix_view in,
+            raft::device_vector_view map,
+            raft::device_matrix_view out,
+            map_xform_t transform_op)
+{
+  // One output row is produced per map entry, so the row count of `out` must match the map.
+  RAFT_EXPECTS(out.extent(0) == map.extent(0),
+               "Number of rows in output matrix must equal the size of the map vector");
+  // Rows are copied whole, so input and output must have the same width.
+  RAFT_EXPECTS(out.extent(1) == in.extent(1),
+               "Number of columns in input and output matrices must be equal.");
+
+  detail::gather(
+    const_cast(in.data_handle()),  // TODO: There's a better way to handle this
+    static_cast(in.extent(1)),
+    static_cast(in.extent(0)),
+    // Hand the raw pointer to the detail layer, exactly like the other mdspan
+    // overloads in this file do, instead of the view object itself.
+    map.data_handle(),
+    static_cast(map.extent(0)),
+    out.data_handle(),
+    transform_op,
+    handle.get_stream());
+}
+
/**
* @brief gather copies rows from a source matrix into a destination matrix according to a
* transformed map.
@@ -124,6 +196,53 @@ void gather_if(const MatrixIteratorT in,
detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, stream);
}
+/**
+ * @brief gather_if conditionally copies rows from a source matrix into a destination matrix
+ * according to a map.
+ *
+ * @tparam matrix_t Matrix value type
+ * @tparam map_t Map vector type
+ * @tparam stencil_t Stencil vector type
+ * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result
+ * type must be convertible to bool type.
+ * @tparam idx_t integer type for indexing
+ * @param[in] handle raft handle for managing resources
+ * @param[in] in Input matrix (assumed to be row-major)
+ * @param[in] map Input vector of gather locations
+ * @param[in] stencil Input vector of stencil or predicate values
+ * @param[out] out Output matrix (assumed to be row-major)
+ * @param[in] pred_op Predicate to apply to the stencil values
+ */
+template
+void gather_if(const raft::handle_t& handle,
+               raft::device_matrix_view in,
+               raft::device_matrix_view out,
+               raft::device_vector_view map,
+               raft::device_vector_view stencil,
+               unary_pred_t pred_op)
+{
+  // One output row per map entry.
+  RAFT_EXPECTS(out.extent(0) == map.extent(0),
+               "Number of rows in output matrix must equal the size of the map vector");
+  // Rows are copied whole, so input and output must have the same width.
+  RAFT_EXPECTS(out.extent(1) == in.extent(1),
+               "Number of columns in input and output matrices must be equal.");
+  // Every map entry needs a stencil value for the predicate to test.
+  RAFT_EXPECTS(map.extent(0) == stencil.extent(0),
+               "Number of elements in stencil must equal number of elements in map");
+
+  // The two dimension arguments describe the *input* matrix, as in the other
+  // gather/gather_if mdspan overloads; out.extent(0) is the map length, not
+  // the number of input rows.
+  detail::gather_if(const_cast(in.data_handle()),
+                    in.extent(1),
+                    in.extent(0),
+                    map.data_handle(),
+                    stencil.data_handle(),
+                    map.extent(0),
+                    out.data_handle(),
+                    pred_op,
+                    handle.get_stream());
+}
+
/**
* @brief gather_if conditionally copies rows from a source matrix into a destination matrix
* according to a transformed map.
@@ -169,5 +288,58 @@ void gather_if(const MatrixIteratorT in,
{
detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream);
}
-} // namespace matrix
-} // namespace raft
+
+/**
+ * @brief gather_if conditionally copies rows from a source matrix into a destination matrix
+ * according to a transformed map.
+ *
+ * @tparam matrix_t Matrix value type, for reading input matrix
+ * @tparam map_t Vector value type for map
+ * @tparam stencil_t Vector value type for stencil
+ * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result
+ * type must be convertible to bool type.
+ * @tparam map_xform_t Unary lambda expression or operator type, map_xform_t's result
+ * type must be convertible to idx_t (= int) type.
+ * @tparam idx_t integer type for indexing
+ * @param[in] handle raft handle for managing resources
+ * @param[in] in Input matrix (assumed to be row-major)
+ * @param[in] map Vector of gather locations
+ * @param[in] stencil Vector of stencil or predicate values
+ * @param[out] out Output matrix (assumed to be row-major)
+ * @param[in] pred_op Predicate to apply to the stencil values
+ * @param[in] transform_op The transformation operation, transforms the map values to idx_t
+ */
+template
+void gather_if(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ raft::device_vector_view map,
+ raft::device_vector_view stencil,
+ unary_pred_t pred_op,
+ map_xform_t transform_op)
+{
+ RAFT_EXPECTS(out.extent(0) == map.extent(0),
+ "Number of rows in output matrix must equal the size of the map vector");
+ RAFT_EXPECTS(out.extent(1) == in.extent(1),
+ "Number of columns in input and output matrices must be equal.");
+ RAFT_EXPECTS(map.extent(0) == stencil.extent(0),
+ "Number of elements in stencil must equal number of elements in map");
+
+ detail::gather_if(const_cast(in.data_handle()),
+ in.extent(1),
+ in.extent(0),
+ map.data_handle(),
+ stencil.data_handle(),
+ map.extent(0),
+ out.data_handle(),
+ pred_op,
+ transform_op,
+ handle.get_stream());
+}
+
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh
new file mode 100644
index 0000000000..e3a6c09fe6
--- /dev/null
+++ b/cpp/include/raft/matrix/init.cuh
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace raft::matrix {
+/**
+ * @brief set values to scalar in matrix
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle
+ * @param[in] in input matrix
+ * @param[out] out output matrix. The result is stored in the out matrix
+ * @param[in] scalar scalar value to fill matrix elements
+ */
+template
+void fill(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ raft::host_scalar_view scalar)
+{
+ RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size.");
+ detail::setValue(
+ out.data_handle(), in.data_handle(), *(scalar.data_handle()), in.size(), handle.get_stream());
+}
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh
new file mode 100644
index 0000000000..6b383b14f5
--- /dev/null
+++ b/cpp/include/raft/matrix/linewise_op.cuh
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * Run a function over matrix lines (rows or columns) with a variable number of
+ * row-vectors or column-vectors.
+ * The term `line` here signifies that the lines can be either columns or rows,
+ * depending on the matrix layout.
+ * What matters is if the vectors are applied along lines (indices of vectors correspond to
+ * indices within lines), or across lines (indices of vectors correspond to line numbers).
+ * @tparam m_t matrix elements type
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @tparam Lambda type of lambda function used for the operation
+ * @tparam vec_t variadic types of device_vector_view vectors (size m if alongRows, size n
+ * otherwise)
+ * @param[in] handle raft handle for managing resources
+ * @param [out] out result of the operation; can be same as `in`; should be aligned the same
+ * as `in` to allow faster vectorized memory transfers.
+ * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long.
+ * @param [in] alongLines whether vectors are indices along or across lines.
+ * @param [in] op the operation applied on each line:
+ * for i in [0..lineLen) and j in [0..nLines):
+ * out[i, j] = op(in[i, j], vec1[i], vec2[i], ... veck[i]) if alongLines = true
+ * out[i, j] = op(in[i, j], vec1[j], vec2[j], ... veck[j]) if alongLines = false
+ * where matrix indexing is row-major ([i, j] = [i + lineLen * j]).
+ * @param [in] vecs zero or more vectors to be passed as arguments,
+ * size of each vector is `alongLines ? lineLen : nLines`.
+ */
+template >
+void linewise_op(const raft::handle_t& handle,
+                 raft::device_matrix_view in,
+                 raft::device_matrix_view out,
+                 const bool alongLines,
+                 Lambda op,
+                 vec_t... vecs)
+{
+  constexpr auto is_rowmajor = std::is_same_v;
+  constexpr auto is_colmajor = std::is_same_v;
+
+  static_assert(is_rowmajor || is_colmajor,
+                "layout for in and out must be either row or col major");
+
+  // A "line" is a contiguous run of elements: a row for row-major data, a
+  // column for column-major data. Its length is therefore the extent of the
+  // fastest-varying dimension, and the number of lines is the other extent.
+  const idx_t lineLen = is_rowmajor ? in.extent(1) : in.extent(0);
+  const idx_t nLines  = is_rowmajor ? in.extent(0) : in.extent(1);
+
+  RAFT_EXPECTS(out.extent(0) == in.extent(0) && out.extent(1) == in.extent(1),
+               "Input and output must have the same shape.");
+
+  // Each vector view is unpacked to its raw pointer for the detail implementation.
+  detail::MatrixLinewiseOp<16, 256>::run(out.data_handle(),
+                                         in.data_handle(),
+                                         lineLen,
+                                         nLines,
+                                         alongLines,
+                                         op,
+                                         handle.get_stream(),
+                                         vecs.data_handle()...);
+}
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh
index 9e103afda5..3c2705cf87 100644
--- a/cpp/include/raft/matrix/math.cuh
+++ b/cpp/include/raft/matrix/math.cuh
@@ -14,6 +14,15 @@
* limitations under the License.
*/
+/**
+ * This file is deprecated and will be removed in a future release.
+ * Please use versions in individual header files instead.
+ */
+
+#pragma message(__FILE__ \
+ " is deprecated and will be removed in a future release." \
+ " Please use versions in individual header files instead.")
+
#ifndef __MATH_H
#define __MATH_H
@@ -301,8 +310,8 @@ void ratio(
* @param out: output vector of size n_cols
* @param stream: cuda stream
*/
-template
-void argmax(const math_t* in, int n_rows, int n_cols, math_t* out, cudaStream_t stream)
+template
+void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream)
{
detail::argmax(in, n_rows, n_cols, out, stream);
}
diff --git a/cpp/include/raft/matrix/matrix.cuh b/cpp/include/raft/matrix/matrix.cuh
index 1af7e37dec..3a7e0dad47 100644
--- a/cpp/include/raft/matrix/matrix.cuh
+++ b/cpp/include/raft/matrix/matrix.cuh
@@ -14,6 +14,15 @@
* limitations under the License.
*/
+/**
+ * This file is deprecated and will be removed in a future release.
+ * Please use versions in individual header files instead.
+ */
+
+#pragma message(__FILE__ \
+ " is deprecated and will be removed in a future release." \
+ " Please use versions in individual header files instead.")
+
#ifndef __MATRIX_H
#define __MATRIX_H
@@ -21,6 +30,7 @@
#include "detail/linewise_op.cuh"
#include "detail/matrix.cuh"
+#include
#include
@@ -71,6 +81,24 @@ void copy(const m_t* in, m_t* out, idx_t n_rows, idx_t n_cols, cudaStream_t stre
raft::copy_async(out, in, n_rows * n_cols, stream);
}
+/**
+ * @brief Asynchronously copy a column major matrix into another matrix of identical shape.
+ * @param[in] handle: raft handle; the copy is issued on handle.get_stream()
+ * @param[in] in: input matrix; both extents must equal those of out (checked by RAFT_EXPECTS)
+ * @param[out] out: output matrix; receives in.extent(0) * out.extent(1) elements
+ */
+template
+void copy(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out)
+{
+ RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1),
+ "Input and output matrix shapes must match.");
+
+ raft::copy_async(
+ out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream());
+}
+
/**
* @brief copy matrix operation for column major matrices. First n_rows and
* n_cols of input matrix "in" is copied to "out" matrix.
diff --git a/cpp/include/raft/matrix/matrix_types.hpp b/cpp/include/raft/matrix/matrix_types.hpp
new file mode 100644
index 0000000000..1f22154627
--- /dev/null
+++ b/cpp/include/raft/matrix/matrix_types.hpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace raft::matrix {
+
+struct print_separators { // pair of separator characters used when printing a matrix
+ char horizontal = ' '; // horizontal separator; defaults to a single space
+ char vertical = '\n'; // vertical separator; defaults to a newline
+};
+
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/power.cuh b/cpp/include/raft/matrix/power.cuh
new file mode 100644
index 0000000000..4e2b3b7d72
--- /dev/null
+++ b/cpp/include/raft/matrix/power.cuh
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief Power of every element in the input matrix, weighted by a scalar (out of place)
+ * @tparam math_t type of matrix elements
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::power
+ * @param[in] in: input matrix
+ * @param[out] out: output matrix; must hold as many elements as in (checked by RAFT_EXPECTS)
+ * @param[in] scalar: every element is multiplied with scalar.
+ */
+template
+void weighted_power(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ math_t scalar)
+{
+ RAFT_EXPECTS(in.size() == out.size(), "Size of input and output matrices must be equal");
+ detail::power(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream());
+}
+
+/**
+ * @brief Power of every element in the matrix, weighted by a scalar (in place)
+ * @tparam math_t matrix element type
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::power
+ * @param[inout] inout: matrix that is read as input and in which the result is stored
+ * @param[in] scalar: every element is multiplied with scalar.
+ */
+template
+void weighted_power(const raft::handle_t& handle,
+ raft::device_matrix_view inout,
+ math_t scalar)
+{
+ detail::power(inout.data_handle(), scalar, inout.size(), handle.get_stream());
+}
+
+/**
+ * @brief Power of every element in the matrix (in place); no exponent or weight is passed here
+ * @tparam math_t matrix element type
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::power
+ * @param[inout] inout: matrix that is read as input and in which the result is stored
+ */
+template
+void power(const raft::handle_t& handle, raft::device_matrix_view inout)
+{
+ detail::power(inout.data_handle(), inout.size(), handle.get_stream());
+}
+
+/**
+ * @brief Power of every element in the input matrix (out of place)
+ * @tparam math_t type used for matrix elements
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix (row or column major)
+ * @param[in] handle: raft handle; its stream is passed to detail::power
+ * @param[in] in: input matrix
+ * @param[out] out: output matrix. The result is stored in the out matrix
+ * @note in and out must have the same number of elements (checked by RAFT_EXPECTS).
+ */
+template
+void power(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out)
+{
+ RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size.");
+ detail::power(in.data_handle(), out.data_handle(), in.size(), handle.get_stream());
+}
+
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh
new file mode 100644
index 0000000000..4d3a8ca938
--- /dev/null
+++ b/cpp/include/raft/matrix/print.cuh
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief Prints the data stored in GPU memory
+ * @tparam m_t type of matrix elements
+ * @tparam idx_t integer type used for indexing
+ * @param[in] handle: raft handle; its stream is passed to detail::print
+ * @param[in] in: input matrix; extent(0) and extent(1) are forwarded as its dimensions
+ * @param[in] separators: horizontal and vertical separator characters (read only, so taken by const reference)
+ */
+template
+void print(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ const print_separators& separators)
+{
+ detail::print(in.data_handle(),
+ in.extent(0),
+ in.extent(1),
+ separators.horizontal,
+ separators.vertical,
+ handle.get_stream());
+}
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/print.hpp b/cpp/include/raft/matrix/print.hpp
new file mode 100644
index 0000000000..86c314ed44
--- /dev/null
+++ b/cpp/include/raft/matrix/print.hpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief Prints the data stored in CPU memory
+ * @param[in] in: input matrix with column-major layout; extent(0) and extent(1) are forwarded
+ * @param[in] separators: horizontal and vertical separator characters (read only, const reference)
+ */
+template
+void print(raft::host_matrix_view in, const print_separators& separators)
+{
+ detail::printHost(
+ in.data_handle(), in.extent(0), in.extent(1), separators.horizontal, separators.vertical);
+}
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh
new file mode 100644
index 0000000000..7895ea972f
--- /dev/null
+++ b/cpp/include/raft/matrix/ratio.cuh
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief ratio of every element over the sum of the input is calculated (out of place)
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; passed to detail::ratio together with its stream
+ * @param[in] src: input matrix
+ * @param[out] dest: output matrix; must hold as many elements as src (checked by RAFT_EXPECTS)
+ */
+template
+void ratio(const raft::handle_t& handle,
+ raft::device_matrix_view src,
+ raft::device_matrix_view dest)
+{
+ RAFT_EXPECTS(src.size() == dest.size(), "Input and output matrices must be the same size.");
+ detail::ratio(handle, src.data_handle(), dest.data_handle(), src.size(), handle.get_stream());
+}
+
+/**
+ * @brief ratio of every element over the sum of the input is calculated (in place)
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; passed to detail::ratio together with its stream
+ * @param[inout] inout: matrix passed to detail::ratio as both the source and the destination
+ */
+template
+void ratio(const raft::handle_t& handle, raft::device_matrix_view inout)
+{
+ detail::ratio(
+ handle, inout.data_handle(), inout.data_handle(), inout.size(), handle.get_stream());
+}
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh
new file mode 100644
index 0000000000..c41ecfb999
--- /dev/null
+++ b/cpp/include/raft/matrix/reciprocal.cuh
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief Reciprocal of every element in the input matrix
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @param[in] handle: raft handle
+ * @param[in] in: input matrix
+ * @param[out] out: output matrix. The result is stored in the out matrix
+ * @param[in] scalar: host scalar view; every element is multiplied with the scalar it refers to
+ * @param[in] setzero round down to zero if the input is less than the threshold
+ * @param[in] thres the threshold used to forcibly set inputs to zero
+ * @note in and out must have the same number of elements (checked by RAFT_EXPECTS).
+ */
+template
+void reciprocal(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ raft::host_scalar_view scalar,
+ bool setzero = false,
+ math_t thres = 1e-15)
+{
+ RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have the same size.");
+ detail::reciprocal(in.data_handle(),
+ out.data_handle(),
+ *(scalar.data_handle()),
+ in.size(),
+ handle.get_stream(),
+ setzero,
+ thres);
+}
+
+/**
+ * @brief Reciprocal of every element in the input matrix (in place)
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle to manage resources
+ * @param[inout] inout: input matrix with in-place results
+ * @param[in] scalar: host scalar view; every element is multiplied with the scalar it refers to
+ * @param[in] setzero round down to zero if the input is less than the threshold
+ * @param[in] thres the threshold used to forcibly set inputs to zero
+ * @note Only the inout buffer is passed to detail::reciprocal; there is no separate output.
+ */
+template
+void reciprocal(const raft::handle_t& handle,
+ raft::device_matrix_view inout,
+ raft::host_scalar_view scalar,
+ bool setzero = false,
+ math_t thres = 1e-15)
+{
+ detail::reciprocal(inout.data_handle(),
+ *(scalar.data_handle()),
+ inout.size(),
+ handle.get_stream(),
+ setzero,
+ thres);
+}
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/sign_flip.cuh b/cpp/include/raft/matrix/sign_flip.cuh
new file mode 100644
index 0000000000..01f8829c85
--- /dev/null
+++ b/cpp/include/raft/matrix/sign_flip.cuh
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief sign flip stabilizes the sign of col major eigen vectors.
+ * The sign is flipped if the column has negative |max|.
+ * @tparam math_t floating point type used for matrix elements
+ * @tparam idx_t integer type used for indexing
+ * @param[in] handle: raft handle; its stream is passed to detail::signFlip
+ * @param[inout] inout: input matrix, modified in place; extent(0) and extent(1) are forwarded
+ */
+template
+void sign_flip(const raft::handle_t& handle,
+ raft::device_matrix_view inout)
+{
+ detail::signFlip(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream());
+}
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh
new file mode 100644
index 0000000000..302167480e
--- /dev/null
+++ b/cpp/include/raft/matrix/sqrt.cuh
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief Square root of every element in the input matrix (out of place)
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::seqRoot
+ * @param[in] in: input matrix
+ * @param[out] out: output matrix; must hold as many elements as in (checked by RAFT_EXPECTS)
+ */
+template
+void sqrt(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out)
+{
+ RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size.");
+ detail::seqRoot(in.data_handle(), out.data_handle(), in.size(), handle.get_stream());
+}
+
+/**
+ * @brief Square root of every element in the matrix (in place)
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::seqRoot
+ * @param[inout] inout: input matrix with in-place results; the only buffer handed to detail::seqRoot
+ */
+template
+void sqrt(const raft::handle_t& handle, raft::device_matrix_view inout)
+{
+ detail::seqRoot(inout.data_handle(), inout.size(), handle.get_stream());
+}
+
+/**
+ * @brief Square root of every element in the input matrix, weighted by a scalar (out of place)
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::seqRoot
+ * @param[in] in: input matrix
+ * @param[out] out: output matrix; must hold as many elements as in (checked by RAFT_EXPECTS)
+ * @param[in] scalar: host scalar view; every element is multiplied with the scalar it refers to
+ * @param[in] set_neg_zero whether to set negative numbers to zero
+ */
+template
+void weighted_sqrt(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ raft::host_scalar_view scalar,
+ bool set_neg_zero = false)
+{
+ RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size.");
+ detail::seqRoot(in.data_handle(),
+ out.data_handle(),
+ *(scalar.data_handle()),
+ in.size(),
+ handle.get_stream(),
+ set_neg_zero);
+}
+
+/**
+ * @brief Square root of every element in the matrix, weighted by a scalar (in place)
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::seqRoot
+ * @param[inout] inout: input matrix and also the result is stored
+ * @param[in] scalar: host scalar view; every element is multiplied with the scalar it refers to
+ * @param[in] set_neg_zero whether to set negative numbers to zero
+ */
+template
+void weighted_sqrt(const raft::handle_t& handle,
+ raft::device_matrix_view inout,
+ raft::host_scalar_view scalar,
+ bool set_neg_zero = false)
+{
+ detail::seqRoot(
+ inout.data_handle(), *(scalar.data_handle()), inout.size(), handle.get_stream(), set_neg_zero);
+}
+
+} // namespace raft::matrix
diff --git a/cpp/include/raft/matrix/threshold.cuh b/cpp/include/raft/matrix/threshold.cuh
new file mode 100644
index 0000000000..7540ceb3c6
--- /dev/null
+++ b/cpp/include/raft/matrix/threshold.cuh
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace raft::matrix {
+
+/**
+ * @brief sets the small values to zero based on a defined threshold (out of place)
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::setSmallValuesZero
+ * @param[in] in: input matrix
+ * @param[out] out: output matrix (passed to the detail call before in); same size as in required
+ * @param[in] thres threshold to set values to zero; defaults to 1e-15
+ */
+template
+void zero_small_values(const raft::handle_t& handle,
+ raft::device_matrix_view in,
+ raft::device_matrix_view out,
+ math_t thres = 1e-15)
+{
+ RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size");
+ detail::setSmallValuesZero(
+ out.data_handle(), in.data_handle(), in.size(), handle.get_stream(), thres);
+}
+
+/**
+ * @brief sets the small values to zero in-place based on a defined threshold
+ * @tparam math_t data-type upon which the math operation will be performed
+ * @tparam idx_t integer type used for indexing
+ * @tparam layout layout of the matrix data (must be row or col major)
+ * @param[in] handle: raft handle; its stream is passed to detail::setSmallValuesZero
+ * @param[inout] inout: input matrix and also the result is stored
+ * @param[in] thres: threshold; defaults to 1e-15
+ */
+template
+void zero_small_values(const raft::handle_t& handle,
+ raft::device_matrix_view inout,
+ math_t thres = 1e-15)
+{
+ detail::setSmallValuesZero(inout.data_handle(), inout.size(), handle.get_stream(), thres);
+}
+} // namespace raft::matrix
diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh
index 714b019fba..9cb9b573b1 100644
--- a/cpp/include/raft/spatial/knn/ball_cover.cuh
+++ b/cpp/include/raft/spatial/knn/ball_cover.cuh
@@ -203,7 +203,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
*/
template
void rbc_knn_query(const raft::handle_t& handle,
- BallCoverIndex& index,
+ const BallCoverIndex& index,
int_t k,
const value_t* query,
int_t n_query_pts,
@@ -272,7 +272,7 @@ void rbc_knn_query(const raft::handle_t& handle,
*/
template
void rbc_knn_query(const raft::handle_t& handle,
- BallCoverIndex& index,
+ const BallCoverIndex& index,
raft::device_matrix_view query,
raft::device_matrix_view inds,
raft::device_matrix_view dists,
@@ -289,7 +289,7 @@ void rbc_knn_query(const raft::handle_t& handle,
"Number of rows in output indices and distances matrices must equal number of rows "
"in search matrix.");
- RAFT_EXPECTS(query.extent(1) == index.get_R().extent(1),
+ RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1),
"Number of columns in query and index matrices must match.");
rbc_knn_query(handle,
@@ -311,4 +311,4 @@ void rbc_knn_query(const raft::handle_t& handle,
} // namespace spatial
} // namespace raft
-#endif
\ No newline at end of file
+#endif
diff --git a/cpp/include/raft/spatial/knn/ball_cover_types.hpp b/cpp/include/raft/spatial/knn/ball_cover_types.hpp
index 1dd45365b7..897bb4df5b 100644
--- a/cpp/include/raft/spatial/knn/ball_cover_types.hpp
+++ b/cpp/include/raft/spatial/knn/ball_cover_types.hpp
@@ -58,13 +58,12 @@ class BallCoverIndex {
* Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m)
*/
n_landmarks(sqrt(m_)),
- R_indptr(std::move(raft::make_device_vector(handle, sqrt(m_) + 1))),
- R_1nn_cols(std::move(raft::make_device_vector(handle, m_))),
- R_1nn_dists(std::move(raft::make_device_vector(handle, m_))),
- R_closest_landmark_dists(
- std::move(raft::make_device_vector(handle, m_))),
- R(std::move(raft::make_device_matrix(handle, sqrt(m_), n_))),
- R_radius(std::move(raft::make_device_vector(handle, sqrt(m_)))),
+ R_indptr(raft::make_device_vector(handle, sqrt(m_) + 1)),
+ R_1nn_cols(raft::make_device_vector(handle, m_)),
+ R_1nn_dists(raft::make_device_vector(handle, m_)),
+ R_closest_landmark_dists(raft::make_device_vector(handle, m_)),
+ R(raft::make_device_matrix(handle, sqrt(m_), n_)),
+ R_radius(raft::make_device_vector(handle, sqrt(m_))),
index_trained(false)
{
}
@@ -83,20 +82,41 @@ class BallCoverIndex {
* Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m)
*/
n_landmarks(sqrt(X_.extent(0))),
- R_indptr(
- std::move(raft::make_device_vector