From 5d697b3ef6e12d6ffd77fbe7842a47bc7d4819e5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 7 Sep 2022 14:47:53 -0400 Subject: [PATCH 01/58] Breaking apart mdspan/mdarray into host_ and device_ variants --- .../raft/core/detail/accessor_mixin.hpp | 39 + cpp/include/raft/core/device_mdarray.hpp | 174 ++++ cpp/include/raft/core/device_mdspan.hpp | 194 +++++ cpp/include/raft/core/host_mdarray.hpp | 144 ++++ cpp/include/raft/core/host_mdspan.hpp | 142 ++++ cpp/include/raft/core/mdarray.hpp | 778 +----------------- cpp/include/raft/core/mdspan.hpp | 265 +++++- cpp/include/raft/detail/mdarray.hpp | 74 +- cpp/include/raft/detail/span.hpp | 2 +- 9 files changed, 962 insertions(+), 850 deletions(-) create mode 100644 cpp/include/raft/core/detail/accessor_mixin.hpp create mode 100644 cpp/include/raft/core/device_mdarray.hpp create mode 100644 cpp/include/raft/core/device_mdspan.hpp create mode 100644 cpp/include/raft/core/host_mdarray.hpp create mode 100644 cpp/include/raft/core/host_mdspan.hpp diff --git a/cpp/include/raft/core/detail/accessor_mixin.hpp b/cpp/include/raft/core/detail/accessor_mixin.hpp new file mode 100644 index 0000000000..6edd85dbaf --- /dev/null +++ b/cpp/include/raft/core/detail/accessor_mixin.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::detail { + +/** + * @brief A mixin to distinguish host and device memory. 
+ */ +template +struct accessor_mixin : public AccessorPolicy { + using accessor_type = AccessorPolicy; + using is_host_type = std::conditional_t; + using is_device_type = std::conditional_t; + using is_managed_type = std::conditional_t; + static constexpr bool is_host_accessible = is_host; + static constexpr bool is_device_accessible = is_device; + static constexpr bool is_managed_accessible = is_device && is_host; + // make sure the explicit ctor can fall through + using AccessorPolicy::AccessorPolicy; + using offset_policy = accessor_mixin; + accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT +}; + +} // namespace raft::detail diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp new file mode 100644 index 0000000000..c5b5d73fdc --- /dev/null +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { + +/** + * @brief mdarray with device container policy + * @tparam ElementType the data type of the elements + * @tparam Extents defines the shape + * @tparam LayoutPolicy policy for indexing strides and layout ordering + * @tparam ContainerPolicy storage and accessor policy + */ +template > +using device_mdarray = + mdarray>; + +/** + * @brief Shorthand for 0-dim device mdarray (scalar). 
+ * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using device_scalar = device_mdarray>; + +/** + * @brief Shorthand for 1-dim device mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using device_vector = device_mdarray, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous device matrix. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using device_matrix = device_mdarray, LayoutPolicy>; + +/** + * @brief Create a device mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param handle raft::handle_t + * @param exts dimensionality of the array (series of integers) + * @return raft::device_mdarray + */ +template +auto make_device_mdarray(const raft::handle_t& handle, extents exts) +{ + using mdarray_t = device_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy{handle.get_stream()}; + + return mdarray_t{layout, policy}; +} + +/** + * @brief Create a device mdarray. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param handle raft::handle_t + * @param mr rmm memory resource used for allocating the memory for the array + * @param exts dimensionality of the array (series of integers) + * @return raft::device_mdarray + */ +template +auto make_device_mdarray(const raft::handle_t& handle, + rmm::mr::device_memory_resource* mr, + extents exts) +{ + using mdarray_t = device_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy{handle.get_stream(), mr}; + + return mdarray_t{layout, policy}; +} + +/** + * @brief Create a 2-dim c-contiguous device mdarray. + * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive resources + * @param[in] n_rows number of rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::device_matrix + */ +template +auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexType n_cols) +{ + return make_device_mdarray( + handle.get_stream(), make_extents(n_rows, n_cols)); +} + +/** + * @brief Create a device scalar from v. 
+ * + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] v scalar to wrap on device + * @return raft::device_scalar + */ +template +auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) +{ + scalar_extent extents; + using policy_t = typename device_scalar::container_policy_type; + policy_t policy{handle.get_stream()}; + auto scalar = device_scalar{extents, policy}; + scalar(0) = v; + return scalar; +} + +/** + * @brief Create a 1-dim device mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] n number of elements in vector + * @return raft::device_vector + */ +template +auto make_device_vector(raft::handle_t const& handle, IndexType n) +{ + return make_device_mdarray(handle.get_stream(), + make_extents(n)); +} + +} // end namespace raft diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp new file mode 100644 index 0000000000..f835397e9d --- /dev/null +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +namespace raft { + +template +using device_accessor = detail::accessor_mixin; + +template +using managed_accessor = detail::accessor_mixin; + +/** + * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. + */ +template > +using device_mdspan = mdspan>; + +template > +using managed_mdspan = mdspan>; + +namespace detail { +template +struct is_device_mdspan : std::false_type { +}; +template +struct is_device_mdspan : std::bool_constant { +}; + +/** + * @brief Boolean to determine if template type T is either raft::device_mdspan or a derived type + */ +template +using is_device_mdspan_t = is_device_mdspan>; + +template +struct is_managed_mdspan : std::false_type { +}; +template +struct is_managed_mdspan : std::bool_constant { +}; + +/** + * @brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type + */ +template +using is_managed_mdspan_t = is_managed_mdspan>; + +} // end namespace detail + +/** + * @brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a + * derived type + */ +template +inline constexpr bool is_device_mdspan_v = std::conjunction_v...>; + +template +using enable_if_device_mdspan = std::enable_if_t>; + +/** + * @brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a + * derived type + */ +template +inline constexpr bool is_managed_mdspan_v = std::conjunction_v...>; + +template +using enable_if_managed_mdspan = std::enable_if_t>; + +/** + * @brief Shorthand for 0-dim device mdspan (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using device_scalar_view = device_mdspan>; + +/** + * @brief Shorthand for 1-dim device mdspan. 
+ * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using device_vector_view = device_mdspan, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous device matrix view. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using device_matrix_view = device_mdspan, LayoutPolicy>; + +/** + * @brief Create a raft::managed_mdspan + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param ptr Pointer to the data + * @param exts dimensionality of the array (series of integers) + * @return raft::managed_mdspan + */ +template +auto make_managed_mdspan(ElementType* ptr, extents exts) +{ + return make_mdspan(ptr, exts); +} + +/** + * @brief Create a 0-dim (scalar) mdspan instance for device value. + * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @param[in] ptr on device to wrap + */ +template +auto make_device_scalar_view(ElementType* ptr) +{ + scalar_extent extents; + return device_scalar_view{ptr, extents}; +} + +/** + * @brief Create a 2-dim c-contiguous mdspan instance for device pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @tparam IndexType the index type of the extents + * @param[in] ptr on device to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + */ +template +auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +{ + matrix_extent extents{n_rows, n_cols}; + return device_matrix_view{ptr, extents}; +} + +/** + * @brief Create a 1-dim mdspan instance for device pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on device to wrap + * @param[in] n number of elements in pointer + * @return raft::device_vector_view + */ +template +auto make_device_vector_view(ElementType* ptr, IndexType n) +{ + vector_extent extents{n}; + return device_vector_view{ptr, extents}; +} + +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp new file mode 100644 index 0000000000..872c007255 --- /dev/null +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace raft { +/** + * @brief mdarray with host container policy + * @tparam ElementType the data type of the elements + * @tparam Extents defines the shape + * @tparam LayoutPolicy policy for indexing strides and layout ordering + * @tparam ContainerPolicy storage and accessor policy + */ +template > +using host_mdarray = + mdarray>; + +/** + * @brief Shorthand for 0-dim host mdarray (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using host_scalar = host_mdarray>; + +/** + * @brief Shorthand for 1-dim host mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using host_vector = host_mdarray, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous host matrix. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using host_matrix = host_mdarray, LayoutPolicy>; + +/** + * @brief Create a host mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param exts dimensionality of the array (series of integers) + * @return raft::host_mdarray + */ +template +auto make_host_mdarray(extents exts) +{ + using mdarray_t = host_mdarray; + + typename mdarray_t::mapping_type layout{exts}; + typename mdarray_t::container_policy_type policy; + + return mdarray_t{layout, policy}; +} + +/** + * @brief Create a 2-dim c-contiguous host mdarray. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] n_rows number of rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::host_matrix + */ +template +auto make_host_matrix(IndexType n_rows, IndexType n_cols) +{ + return make_host_mdarray( + make_extents(n_rows, n_cols)); +} + +/** + * @brief Create a host scalar from v. + * + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + * @param[in] v scalar type to wrap + * @return raft::host_scalar + */ +template +auto make_host_scalar(ElementType const& v) +{ + // FIXME(jiamingy): We can optimize this by using std::array as container policy, which + // requires some more compile time dispatching. This is enabled in the ref impl but + // hasn't been ported here yet. + scalar_extent extents; + using policy_t = typename host_scalar::container_policy_type; + policy_t policy; + auto scalar = host_scalar{extents, policy}; + scalar(0) = v; + return scalar; +} + +/** + * @brief Create a 1-dim host mdarray. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] n number of elements in vector + * @return raft::host_vector + */ +template +auto make_host_vector(IndexType n) +{ + return make_host_mdarray(make_extents(n)); +} + +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp new file mode 100644 index 0000000000..f46fb6ff17 --- /dev/null +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft { + +template +using host_accessor = detail::accessor_mixin; + +/** + * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. + */ +template > +using host_mdspan = mdspan>; + +namespace detail { + +template +struct is_host_mdspan : std::false_type { +}; +template +struct is_host_mdspan : std::bool_constant { +}; + +/** + * @brief Boolean to determine if template type T is either raft::host_mdspan or a derived type + */ +template +using is_host_mdspan_t = is_host_mdspan>; + +} // namespace detail + +/** + * @brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a + * derived type + */ +template +inline constexpr bool is_host_mdspan_v = std::conjunction_v...>; + +template +using enable_if_host_mdspan = std::enable_if_t>; + +/** + * @brief Shorthand for 0-dim host mdspan (scalar). + * @tparam ElementType the data type of the scalar element + * @tparam IndexType the index type of the extents + */ +template +using host_scalar_view = host_mdspan>; + +/** + * @brief Shorthand for 1-dim host mdspan. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + */ +template +using host_vector_view = host_mdspan, LayoutPolicy>; + +/** + * @brief Shorthand for c-contiguous host matrix view. 
+ * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + */ +template +using host_matrix_view = host_mdspan, LayoutPolicy>; + +/** + * @brief Create a 0-dim (scalar) mdspan instance for host value. + * + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @param[in] ptr on host to wrap + */ +template +auto make_host_scalar_view(ElementType* ptr) +{ + scalar_extent extents; + return host_scalar_view{ptr, extents}; +} + +/** + * @brief Create a 2-dim c-contiguous mdspan instance for host pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on host to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + */ +template +auto make_host_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) +{ + matrix_extent extents{n_rows, n_cols}; + return host_matrix_view{ptr, extents}; +} + +/** + * @brief Create a 1-dim mdspan instance for host pointer. 
+ * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @param[in] ptr on host to wrap + * @param[in] n number of elements in pointer + * @return raft::host_vector_view + */ +template +auto make_host_vector_view(ElementType* ptr, IndexType n) +{ + vector_extent extents{n}; + return host_vector_view{ptr, extents}; +} +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index d251c2b419..e918d81ff1 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -24,212 +24,15 @@ #include +#include #include +#include #include #include - #include #include namespace raft { -/** - * @brief Dimensions extents for raft::host_mdspan or raft::device_mdspan - */ -template -using extents = std::experimental::extents; - -/** - * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. - * @{ - */ -using detail::stdex::layout_right; -using layout_c_contiguous = layout_right; -using row_major = layout_right; -/** @} */ - -/** - * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. - * @{ - */ -using detail::stdex::layout_left; -using layout_f_contiguous = layout_left; -using col_major = layout_left; -/** @} */ - -/** - * @brief Strided layout for non-contiguous memory. - */ -using detail::stdex::layout_stride; - -/** - * @defgroup Common mdarray/mdspan extent types. The rank is known at compile time, each dimension - * is known at run time (dynamic_extent in each dimension). 
- * @{ - */ -using detail::matrix_extent; -using detail::scalar_extent; -using detail::vector_extent; - -template -using extent_1d = vector_extent; - -template -using extent_2d = matrix_extent; - -template -using extent_3d = detail::stdex::extents; - -template -using extent_4d = - detail::stdex::extents; - -template -using extent_5d = detail::stdex::extents; -/** @} */ - -template > -using mdspan = detail::stdex::mdspan; - -namespace detail { -/** - * @\brief Template checks and helpers to determine if type T is an std::mdspan - * or a derived type - */ - -template -void __takes_an_mdspan_ptr(mdspan*); - -template -struct is_mdspan : std::false_type { -}; -template -struct is_mdspan()))>> - : std::true_type { -}; - -template -using is_mdspan_t = is_mdspan>; - -template -inline constexpr bool is_mdspan_v = is_mdspan_t::value; -} // namespace detail - -/** - * @\brief Boolean to determine if variadic template types Tn are either - * raft::host_mdspan/raft::device_mdspan or their derived types - */ -template -inline constexpr bool is_mdspan_v = std::conjunction_v...>; - -template -using enable_if_mdspan = std::enable_if_t>; - -/** - * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. - */ -template > -using device_mdspan = - mdspan>; - -/** - * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. 
- */ -template > -using host_mdspan = - mdspan>; - -template > -using managed_mdspan = - mdspan>; - -namespace detail { -template -struct is_device_mdspan : std::false_type { -}; -template -struct is_device_mdspan : std::bool_constant { -}; - -/** - * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type - */ -template -using is_device_mdspan_t = is_device_mdspan>; - -template -struct is_host_mdspan : std::false_type { -}; -template -struct is_host_mdspan : std::bool_constant { -}; - -/** - * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type - */ -template -using is_host_mdspan_t = is_host_mdspan>; - -template -struct is_managed_mdspan : std::false_type { -}; -template -struct is_managed_mdspan : std::bool_constant { -}; - -/** - * @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type - */ -template -using is_managed_mdspan_t = is_managed_mdspan>; -} // namespace detail - -/** - * @\brief Boolean to determine if variadic template types Tn are either raft::device_mdspan or a - * derived type - */ -template -inline constexpr bool is_device_mdspan_v = std::conjunction_v...>; - -template -using enable_if_device_mdspan = std::enable_if_t>; - -/** - * @\brief Boolean to determine if variadic template types Tn are either raft::host_mdspan or a - * derived type - */ -template -inline constexpr bool is_host_mdspan_v = std::conjunction_v...>; - -template -using enable_if_host_mdspan = std::enable_if_t>; - -/** - * @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a - * derived type - */ -template -inline constexpr bool is_managed_mdspan_v = std::conjunction_v...>; - -template -using enable_if_managed_mdspan = std::enable_if_t>; - /** * @brief Interface to implement an owning multi-dimensional array * @@ -531,521 +334,6 @@ class mdarray container_type c_; }; -/** - * @brief mdarray with host container 
policy - * @tparam ElementType the data type of the elements - * @tparam Extents defines the shape - * @tparam LayoutPolicy policy for indexing strides and layout ordering - * @tparam ContainerPolicy storage and accessor policy - */ -template > -using host_mdarray = - mdarray>; - -/** - * @brief mdarray with device container policy - * @tparam ElementType the data type of the elements - * @tparam Extents defines the shape - * @tparam LayoutPolicy policy for indexing strides and layout ordering - * @tparam ContainerPolicy storage and accessor policy - */ -template > -using device_mdarray = - mdarray>; - -/** - * @brief Shorthand for 0-dim host mdarray (scalar). - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - */ -template -using host_scalar = host_mdarray>; - -/** - * @brief Shorthand for 0-dim host mdarray (scalar). - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - */ -template -using device_scalar = device_mdarray>; - -/** - * @brief Shorthand for 1-dim host mdarray. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using host_vector = host_mdarray, LayoutPolicy>; - -/** - * @brief Shorthand for 1-dim device mdarray. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using device_vector = device_mdarray, LayoutPolicy>; - -/** - * @brief Shorthand for c-contiguous host matrix. 
- * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using host_matrix = host_mdarray, LayoutPolicy>; - -/** - * @brief Shorthand for c-contiguous device matrix. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using device_matrix = device_mdarray, LayoutPolicy>; - -/** - * @brief Shorthand for 0-dim host mdspan (scalar). - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - */ -template -using host_scalar_view = host_mdspan>; - -/** - * @brief Shorthand for 0-dim host mdspan (scalar). - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - */ -template -using device_scalar_view = device_mdspan>; - -/** - * @brief Shorthand for 1-dim host mdspan. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - */ -template -using host_vector_view = host_mdspan, LayoutPolicy>; - -/** - * @brief Shorthand for 1-dim device mdspan. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using device_vector_view = device_mdspan, LayoutPolicy>; - -/** - * @brief Shorthand for c-contiguous host matrix view. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using host_matrix_view = host_mdspan, LayoutPolicy>; - -/** - * @brief Shorthand for c-contiguous device matrix view. 
- * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - */ -template -using device_matrix_view = device_mdspan, LayoutPolicy>; - -/** - * @brief Create a raft::mdspan - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @tparam is_host_accessible whether the data is accessible on host - * @tparam is_device_accessible whether the data is accessible on device - * @param ptr Pointer to the data - * @param exts dimensionality of the array (series of integers) - * @return raft::mdspan - */ -template -auto make_mdspan(ElementType* ptr, extents exts) -{ - using accessor_type = detail::accessor_mixin, - is_host_accessible, - is_device_accessible>; - - return mdspan{ptr, exts}; -} - -/** - * @brief Create a raft::managed_mdspan - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param ptr Pointer to the data - * @param exts dimensionality of the array (series of integers) - * @return raft::managed_mdspan - */ -template -auto make_managed_mdspan(ElementType* ptr, extents exts) -{ - return make_mdspan(ptr, exts); -} - -/** - * @brief Create a 0-dim (scalar) mdspan instance for host value. - * - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @param[in] ptr on device to wrap - */ -template -auto make_host_scalar_view(ElementType* ptr) -{ - scalar_extent extents; - return host_scalar_view{ptr, extents}; -} - -/** - * @brief Create a 0-dim (scalar) mdspan instance for device value. 
- * - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @param[in] ptr on device to wrap - */ -template -auto make_device_scalar_view(ElementType* ptr) -{ - scalar_extent extents; - return device_scalar_view{ptr, extents}; -} - -/** - * @brief Create a 2-dim c-contiguous mdspan instance for host pointer. It's - * expected that the given layout policy match the layout of the underlying - * pointer. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] ptr on host to wrap - * @param[in] n_rows number of rows in pointer - * @param[in] n_cols number of columns in pointer - */ -template -auto make_host_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) -{ - matrix_extent extents{n_rows, n_cols}; - return host_matrix_view{ptr, extents}; -} -/** - * @brief Create a 2-dim c-contiguous mdspan instance for device pointer. It's - * expected that the given layout policy match the layout of the underlying - * pointer. - * @tparam ElementType the data type of the matrix elements - * @tparam LayoutPolicy policy for strides and layout ordering - * @tparam IndexType the index type of the extents - * @param[in] ptr on device to wrap - * @param[in] n_rows number of rows in pointer - * @param[in] n_cols number of columns in pointer - */ -template -auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_cols) -{ - matrix_extent extents{n_rows, n_cols}; - return device_matrix_view{ptr, extents}; -} - -/** - * @brief Create a 1-dim mdspan instance for host pointer. 
- * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @param[in] ptr on host to wrap - * @param[in] n number of elements in pointer - * @return raft::host_vector_view - */ -template -auto make_host_vector_view(ElementType* ptr, IndexType n) -{ - vector_extent extents{n}; - return host_vector_view{ptr, extents}; -} - -/** - * @brief Create a 1-dim mdspan instance for device pointer. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] ptr on device to wrap - * @param[in] n number of elements in pointer - * @return raft::device_vector_view - */ -template -auto make_device_vector_view(ElementType* ptr, IndexType n) -{ - vector_extent extents{n}; - return device_vector_view{ptr, extents}; -} - -/** - * @brief Create a host mdarray. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param exts dimensionality of the array (series of integers) - * @return raft::host_mdarray - */ -template -auto make_host_mdarray(extents exts) -{ - using mdarray_t = host_mdarray; - - typename mdarray_t::mapping_type layout{exts}; - typename mdarray_t::container_policy_type policy; - - return mdarray_t{layout, policy}; -} - -/** - * @brief Create a device mdarray. 
- * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param handle raft::handle_t - * @param exts dimensionality of the array (series of integers) - * @return raft::device_mdarray - */ -template -auto make_device_mdarray(const raft::handle_t& handle, extents exts) -{ - using mdarray_t = device_mdarray; - - typename mdarray_t::mapping_type layout{exts}; - typename mdarray_t::container_policy_type policy{handle.get_stream()}; - - return mdarray_t{layout, policy}; -} - -/** - * @brief Create a device mdarray. - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param handle raft::handle_t - * @param mr rmm memory resource used for allocating the memory for the array - * @param exts dimensionality of the array (series of integers) - * @return raft::device_mdarray - */ -template -auto make_device_mdarray(const raft::handle_t& handle, - rmm::mr::device_memory_resource* mr, - extents exts) -{ - using mdarray_t = device_mdarray; - - typename mdarray_t::mapping_type layout{exts}; - typename mdarray_t::container_policy_type policy{handle.get_stream(), mr}; - - return mdarray_t{layout, policy}; -} - -/** - * @brief Create raft::extents to specify dimensionality - * - * @tparam IndexType The type of each dimension of the extents - * @tparam Extents Dimensions (a series of integers) - * @param exts The desired dimensions - * @return raft::extents - */ -template > -auto make_extents(Extents... exts) -{ - return extents{exts...}; -} - -/** - * @brief Create a 2-dim c-contiguous host mdarray. 
- * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] n_rows number or rows in matrix - * @param[in] n_cols number of columns in matrix - * @return raft::host_matrix - */ -template -auto make_host_matrix(IndexType n_rows, IndexType n_cols) -{ - return make_host_mdarray( - make_extents(n_rows, n_cols)); -} - -/** - * @brief Create a 2-dim c-contiguous device mdarray. - * - * @tparam ElementType the data type of the matrix elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] handle raft handle for managing expensive resources - * @param[in] n_rows number or rows in matrix - * @param[in] n_cols number of columns in matrix - * @return raft::device_matrix - */ -template -auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexType n_cols) -{ - return make_device_mdarray( - handle.get_stream(), make_extents(n_rows, n_cols)); -} - -/** - * @brief Create a host scalar from v. - * - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - * @param[in] v scalar type to wrap - * @return raft::host_scalar - */ -template -auto make_host_scalar(ElementType const& v) -{ - // FIXME(jiamingy): We can optimize this by using std::array as container policy, which - // requires some more compile time dispatching. This is enabled in the ref impl but - // hasn't been ported here yet. - scalar_extent extents; - using policy_t = typename host_scalar::container_policy_type; - policy_t policy; - auto scalar = host_scalar{extents, policy}; - scalar(0) = v; - return scalar; -} - -/** - * @brief Create a device scalar from v. 
- * - * @tparam ElementType the data type of the scalar element - * @tparam IndexType the index type of the extents - * @param[in] handle raft handle for managing expensive cuda resources - * @param[in] v scalar to wrap on device - * @return raft::device_scalar - */ -template -auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) -{ - scalar_extent extents; - using policy_t = typename device_scalar::container_policy_type; - policy_t policy{handle.get_stream()}; - auto scalar = device_scalar{extents, policy}; - scalar(0) = v; - return scalar; -} - -/** - * @brief Create a 1-dim host mdarray. - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] n number of elements in vector - * @return raft::host_vector - */ -template -auto make_host_vector(IndexType n) -{ - return make_host_mdarray(make_extents(n)); -} - -/** - * @brief Create a 1-dim device mdarray. 
- * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] handle raft handle for managing expensive cuda resources - * @param[in] n number of elements in vector - * @return raft::device_vector - */ -template -auto make_device_vector(raft::handle_t const& handle, IndexType n) -{ - return make_device_mdarray(handle.get_stream(), - make_extents(n)); -} - -/** - * @brief Flatten raft::host_mdspan or raft::device_mdspan into a 1-dim array view - * - * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan - * @param mds raft::host_mdspan or raft::device_mdspan object - * @return raft::host_mdspan or raft::device_mdspan with vector_extent - * depending on AccessoryPolicy - */ -template > -auto flatten(mdspan_type mds) -{ - RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous."); - - vector_extent ext{mds.size()}; - - return detail::stdex::mdspan(mds.data_handle(), ext); -} - /** * @brief Flatten object implementing raft::array_interface into a 1-dim array view * @@ -1061,36 +349,6 @@ auto flatten(const array_interface_type& mda) return flatten(mda.view()); } -/** - * @brief Reshape raft::host_mdspan or raft::device_mdspan - * - * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan - * @tparam IndexType the index type of the extents - * @tparam Extents raft::extents for dimensions - * @param mds raft::host_mdspan or raft::device_mdspan object - * @param new_shape Desired new shape of the input - * @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy - */ -template > -auto reshape(mdspan_type mds, extents new_shape) -{ - RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous."); - - size_t new_size = 1; - for (size_t i = 0; i < new_shape.rank(); ++i) { - new_size *= new_shape.extent(i); - } - RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size 
mismatch"); - - return detail::stdex::mdspan(mds.data_handle(), new_shape); -} - /** * @brief Reshape object implementing raft::array_interface * @@ -1111,36 +369,4 @@ auto reshape(const array_interface_type& mda, extents new return reshape(mda.view(), new_shape); } -/** - * \brief Turns linear index into coordinate. Similar to numpy unravel_index. - * - * \code - * auto m = make_host_matrix(7, 6); - * auto m_v = m.view(); - * auto coord = unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); - * std::apply(m_v, coord) = 2; - * \endcode - * - * \param idx The linear index. - * \param shape The shape of the array to use. - * \param layout Must be `layout_c_contiguous` (row-major) in current implementation. - * - * \return A std::tuple that represents the coordinate. - */ -template -MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx, - extents shape, - LayoutPolicy const& layout) -{ - static_assert(std::is_same_v>, - layout_c_contiguous>, - "Only C layout is supported."); - static_assert(std::is_integral_v, "Index must be integral."); - auto constexpr kIs64 = sizeof(std::remove_cv_t>) == sizeof(uint64_t); - if (kIs64 && static_cast(idx) > std::numeric_limits::max()) { - return detail::unravel_index_impl(static_cast(idx), shape); - } else { - return detail::unravel_index_impl(static_cast(idx), shape); - } -} } // namespace raft diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 809134e96e..6078c730bf 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -22,4 +22,267 @@ */ #pragma once -#include \ No newline at end of file +#include +#include // dynamic_extent +#include + +namespace raft { +/** + * @brief Dimensions extents for raft::mdspan + */ +template +using extents = std::experimental::extents; + +namespace stdex = std::experimental; + +/** + * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. 
+ * @{ + */ +using stdex::layout_right; +using layout_c_contiguous = layout_right; +using row_major = layout_right; +/** @} */ + +/** + * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. + * @{ + */ +using stdex::layout_left; +using layout_f_contiguous = layout_left; +using col_major = layout_left; +/** @} */ + +template +using vector_extent = stdex::extents; + +template +using matrix_extent = stdex::extents; + +template +using scalar_extent = stdex::extents; + +/** + * @brief Strided layout for non-contiguous memory. + */ +using stdex::layout_stride; + +template +using extent_1d = vector_extent; + +template +using extent_2d = matrix_extent; + +template +using extent_3d = stdex::extents; + +template +using extent_4d = + stdex::extents; + +template +using extent_5d = stdex::extents; +/** @} */ + +template > +using mdspan = stdex::mdspan; + +/** + * Ensure all types listed in the parameter pack `Extents` are integral types. + * Usage: + * put it as the last nameless template parameter of a function: + * `typename = ensure_integral_extents` + */ +template +using ensure_integral_extents = std::enable_if_t...>>; + +/** + * @\brief Template checks and helpers to determine if type T is an std::mdspan + * or a derived type + */ + +template +void __takes_an_mdspan_ptr(mdspan*); + +template +struct is_mdspan : std::false_type { +}; +template +struct is_mdspan()))>> + : std::true_type { +}; + +template +using is_mdspan_t = is_mdspan>; + +// template +// inline constexpr bool is_mdspan_v = is_mdspan_t::value; + +/** + * @\brief Boolean to determine if variadic template types Tn are either + * raft::host_mdspan/raft::device_mdspan or their derived types + */ +template +inline constexpr bool is_mdspan_v = std::conjunction_v...>; + +template +using enable_if_mdspan = std::enable_if_t>; + +// uint division optimization inspired by the CIndexer in cupy. Division operation is +// slow on both CPU and GPU, especially 64 bit integer. 
So here we first try to avoid 64 +// bit when the index is smaller, then try to avoid division when it's exp of 2. +template +MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents shape) +{ + constexpr auto kRank = static_cast(shape.rank()); + std::size_t index[shape.rank()]{0}; // NOLINT + static_assert(std::is_signed::value, + "Don't change the type without changing the for loop."); + for (int32_t dim = kRank; --dim > 0;) { + auto s = static_cast>>(shape.extent(dim)); + if (s & (s - 1)) { + auto t = idx / s; + index[dim] = idx - t * s; + idx = t; + } else { // exp of 2 + index[dim] = idx & (s - 1); + idx >>= popc(s - 1); + } + } + index[0] = idx; + return arr_to_tup(index); +} + +/** + * @brief Create a raft::mdspan + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @tparam is_host_accessible whether the data is accessible on host + * @tparam is_device_accessible whether the data is accessible on device + * @param ptr Pointer to the data + * @param exts dimensionality of the array (series of integers) + * @return raft::mdspan + */ +template +auto make_mdspan(ElementType* ptr, extents exts) +{ + using accessor_type = detail::accessor_mixin, + is_host_accessible, + is_device_accessible>; + + return mdspan{ptr, exts}; +} + +/** + * @brief Create raft::extents to specify dimensionality + * + * @tparam IndexType The type of each dimension of the extents + * @tparam Extents Dimensions (a series of integers) + * @param exts The desired dimensions + * @return raft::extents + */ +template > +auto make_extents(Extents... 
exts) +{ + return extents{exts...}; +} + +/** + * @brief Flatten raft::mdspan into a 1-dim array view + * + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @param mds raft::host_mdspan or raft::device_mdspan object + * @return raft::host_mdspan or raft::device_mdspan with vector_extent + * depending on AccessoryPolicy + */ +template > +auto flatten(mdspan_type mds) +{ + RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous."); + + vector_extent ext{mds.size()}; + + return stdex::mdspan(mds.data_handle(), ext); +} + +/** + * @brief Reshape raft::host_mdspan or raft::device_mdspan + * + * @tparam mdspan_type Expected type raft::host_mdspan or raft::device_mdspan + * @tparam IndexType the index type of the extents + * @tparam Extents raft::extents for dimensions + * @param mds raft::host_mdspan or raft::device_mdspan object + * @param new_shape Desired new shape of the input + * @return raft::host_mdspan or raft::device_mdspan, depending on AccessorPolicy + */ +template > +auto reshape(mdspan_type mds, extents new_shape) +{ + RAFT_EXPECTS(mds.is_exhaustive(), "Input must be contiguous."); + + size_t new_size = 1; + for (size_t i = 0; i < new_shape.rank(); ++i) { + new_size *= new_shape.extent(i); + } + RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch"); + + return stdex::mdspan(mds.data_handle(), new_shape); +} + +/** + * \brief Turns linear index into coordinate. Similar to numpy unravel_index. + * + * \code + * auto m = make_host_matrix(7, 6); + * auto m_v = m.view(); + * auto coord = unravel_index(2, m.extents(), typename decltype(m)::layout_type{}); + * std::apply(m_v, coord) = 2; + * \endcode + * + * \param idx The linear index. + * \param shape The shape of the array to use. + * \param layout Must be `layout_c_contiguous` (row-major) in current implementation. + * + * \return A std::tuple that represents the coordinate. 
+ */ +template +MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx, + extents shape, + LayoutPolicy const& layout) +{ + static_assert(std::is_same_v>, + layout_c_contiguous>, + "Only C layout is supported."); + static_assert(std::is_integral_v, "Index must be integral."); + auto constexpr kIs64 = sizeof(std::remove_cv_t>) == sizeof(uint64_t); + if (kIs64 && static_cast(idx) > std::numeric_limits::max()) { + return unravel_index_impl(static_cast(idx), shape); + } else { + return unravel_index_impl(static_cast(idx), shape); + } +} + +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index dd813a7c18..e3d2b0bf9e 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -21,7 +21,9 @@ * limitations under the License. */ #pragma once +#include #include + #include #include // dynamic_extent @@ -228,44 +230,6 @@ class host_vector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; -/** - * @brief A mixin to distinguish host and device memory. 
- */ -template -struct accessor_mixin : public AccessorPolicy { - using accessor_type = AccessorPolicy; - using is_host_type = std::conditional_t; - using is_device_type = std::conditional_t; - using is_managed_type = std::conditional_t; - static constexpr bool is_host_accessible = is_host; - static constexpr bool is_device_accessible = is_device; - static constexpr bool is_managed_accessible = is_device && is_host; - // make sure the explicit ctor can fall through - using AccessorPolicy::AccessorPolicy; - using offset_policy = accessor_mixin; - accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT -}; - -template -using host_accessor = accessor_mixin; - -template -using device_accessor = accessor_mixin; - -template -using managed_accessor = accessor_mixin; - -namespace stdex = std::experimental; - -template -using vector_extent = stdex::extents; - -template -using matrix_extent = stdex::extents; - -template -using scalar_extent = stdex::extents; - template MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t { @@ -310,38 +274,4 @@ MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N]) return arr_to_tup(arr, std::make_index_sequence{}); } -// uint division optimization inspired by the CIndexer in cupy. Division operation is -// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 -// bit when the index is smaller, then try to avoid division when it's exp of 2. 
-template -MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents shape) -{ - constexpr auto kRank = static_cast(shape.rank()); - std::size_t index[shape.rank()]{0}; // NOLINT - static_assert(std::is_signed::value, - "Don't change the type without changing the for loop."); - for (int32_t dim = kRank; --dim > 0;) { - auto s = static_cast>>(shape.extent(dim)); - if (s & (s - 1)) { - auto t = idx / s; - index[dim] = idx - t * s; - idx = t; - } else { // exp of 2 - index[dim] = idx & (s - 1); - idx >>= popc(s - 1); - } - } - index[0] = idx; - return arr_to_tup(index); -} - -/** - * Ensure all types listed in the parameter pack `Extents` are integral types. - * Usage: - * put it as the last nameless template parameter of a function: - * `typename = ensure_integral_extents` - */ -template -using ensure_integral_extents = std::enable_if_t...>>; - } // namespace raft::detail diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/detail/span.hpp index 555b47dcae..de76ff3138 100644 --- a/cpp/include/raft/detail/span.hpp +++ b/cpp/include/raft/detail/span.hpp @@ -16,7 +16,7 @@ #pragma once #include // numeric_limits -#include +#include #include // __host__ __device__ #include From 0e6cc8618a139fda8f476a28edcd616c13d717db Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 7 Sep 2022 15:28:18 -0400 Subject: [PATCH 02/58] Updates --- cpp/include/raft/cluster/detail/kmeans.cuh | 3 +- .../raft/cluster/detail/kmeans_common.cuh | 2 +- cpp/include/raft/core/detail/mdspan_util.hpp | 68 +++++++++++++++++++ cpp/include/raft/core/device_mdarray.hpp | 3 +- cpp/include/raft/core/host_mdarray.hpp | 4 +- cpp/include/raft/core/mdspan.hpp | 6 +- cpp/include/raft/detail/mdarray.hpp | 44 ------------ cpp/include/raft/linalg/transpose.cuh | 2 +- cpp/include/raft/random/make_blobs.cuh | 2 +- .../raft/spatial/knn/ivf_flat_types.hpp | 2 +- cpp/test/linalg/transpose.cu | 3 +- cpp/test/mdarray.cu | 10 +-- cpp/test/mdspan_utils.cu | 7 +- cpp/test/random/make_blobs.cu | 4 +- 14 files changed, 95 insertions(+), 65 deletions(-) create mode 100644 cpp/include/raft/core/detail/mdspan_util.hpp diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index 303de77078..b3f17295aa 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -29,9 +29,10 @@ #include #include #include +#include #include +#include #include -#include #include #include #include diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 358c8ce16e..4a77d230c2 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -29,9 +29,9 @@ #include #include +#include #include #include -#include #include #include #include diff --git a/cpp/include/raft/core/detail/mdspan_util.hpp b/cpp/include/raft/core/detail/mdspan_util.hpp new file mode 100644 index 0000000000..af5eabf3e3 --- /dev/null +++ b/cpp/include/raft/core/detail/mdspan_util.hpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::detail { + +template +MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N], std::index_sequence) +{ + return std::make_tuple(arr[Idx]...); +} + +template +MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N]) +{ + return arr_to_tup(arr, std::make_index_sequence{}); +} + +template +MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t +{ + int c = 0; + for (; v != 0; v &= v - 1) { + c++; + } + return c; +} + +MDSPAN_INLINE_FUNCTION auto popc(uint32_t v) -> int32_t +{ +#if defined(__CUDA_ARCH__) + return __popc(v); +#elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcount(v); +#else + return native_popc(v); +#endif // compiler +} + +MDSPAN_INLINE_FUNCTION auto popc(uint64_t v) -> int32_t +{ +#if defined(__CUDA_ARCH__) + return __popcll(v); +#elif defined(__GNUC__) || defined(__clang__) + return __builtin_popcountll(v); +#else + return native_popc(v); +#endif // compiler +} + +} // end namespace raft::detail \ No newline at end of file diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index c5b5d73fdc..393ff45815 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -16,6 +16,7 @@ #pragma once +#include #include namespace raft { @@ -32,7 +33,7 @@ template > using device_mdarray = - mdarray>; + mdarray>; /** * @brief Shorthand for 0-dim host mdarray (scalar). 
diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp index 872c007255..448a639390 100644 --- a/cpp/include/raft/core/host_mdarray.hpp +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -16,6 +16,7 @@ #pragma once +#include #include namespace raft { @@ -30,8 +31,7 @@ template > -using host_mdarray = - mdarray>; +using host_mdarray = mdarray>; /** * @brief Shorthand for 0-dim host mdarray (scalar). diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 6078c730bf..93645e01cc 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -23,6 +23,8 @@ #pragma once #include +#include +#include #include // dynamic_extent #include @@ -154,11 +156,11 @@ MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents>= popc(s - 1); + idx >>= detail::popc(s - 1); } } index[0] = idx; - return arr_to_tup(index); + return detail::arr_to_tup(index); } /** diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/detail/mdarray.hpp index e3d2b0bf9e..1c13f5d1fc 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/detail/mdarray.hpp @@ -230,48 +230,4 @@ class host_vector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; -template -MDSPAN_INLINE_FUNCTION auto native_popc(T v) -> int32_t -{ - int c = 0; - for (; v != 0; v &= v - 1) { - c++; - } - return c; -} - -MDSPAN_INLINE_FUNCTION auto popc(uint32_t v) -> int32_t -{ -#if defined(__CUDA_ARCH__) - return __popc(v); -#elif defined(__GNUC__) || defined(__clang__) - return __builtin_popcount(v); -#else - return native_popc(v); -#endif // compiler -} - -MDSPAN_INLINE_FUNCTION auto popc(uint64_t v) -> int32_t -{ -#if defined(__CUDA_ARCH__) - return __popcll(v); -#elif defined(__GNUC__) || defined(__clang__) - return __builtin_popcountll(v); -#else - return native_popc(v); -#endif // compiler -} - -template -MDSPAN_INLINE_FUNCTION constexpr auto 
arr_to_tup(T (&arr)[N], std::index_sequence) -{ - return std::make_tuple(arr[Idx]...); -} - -template -MDSPAN_INLINE_FUNCTION constexpr auto arr_to_tup(T (&arr)[N]) -{ - return arr_to_tup(arr, std::make_index_sequence{}); -} - } // namespace raft::detail diff --git a/cpp/include/raft/linalg/transpose.cuh b/cpp/include/raft/linalg/transpose.cuh index cd78a2f495..e765ea7925 100644 --- a/cpp/include/raft/linalg/transpose.cuh +++ b/cpp/include/raft/linalg/transpose.cuh @@ -19,7 +19,7 @@ #pragma once #include "detail/transpose.cuh" -#include +#include namespace raft { namespace linalg { diff --git a/cpp/include/raft/random/make_blobs.cuh b/cpp/include/raft/random/make_blobs.cuh index 8bd78d98eb..82c940b471 100644 --- a/cpp/include/raft/random/make_blobs.cuh +++ b/cpp/include/raft/random/make_blobs.cuh @@ -21,7 +21,7 @@ #include "detail/make_blobs.cuh" #include -#include +#include namespace raft::random { diff --git a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp index 02c4e30c1f..9beba05a5c 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp +++ b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp @@ -18,8 +18,8 @@ #include "common.hpp" +#include #include -#include #include #include diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index 98f6d5e7e4..6edf9448b0 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -233,7 +234,7 @@ void test_transpose_submatrix() } auto vv = v.view(); - auto submat = raft::detail::stdex::submdspan( + auto submat = raft::stdex::submdspan( vv, std::make_tuple(row_beg, row_end), std::make_tuple(col_beg, col_end)); static_assert(std::is_same_v); diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu index af7bb7adf3..39cf76333b 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/mdarray.cu @@ -14,8 +14,8 @@ * limitations under the License. 
*/ #include -#include -#include +#include +#include #include #include #include @@ -467,19 +467,19 @@ void test_mdarray_unravel() // examples from numpy unravel_index { - auto coord = unravel_index(22, detail::matrix_extent{7, 6}, stdex::layout_right{}); + auto coord = unravel_index(22, matrix_extent{7, 6}, stdex::layout_right{}); static_assert(std::tuple_size::value == 2); ASSERT_EQ(std::get<0>(coord), 3); ASSERT_EQ(std::get<1>(coord), 4); } { - auto coord = unravel_index(41, detail::matrix_extent{7, 6}, stdex::layout_right{}); + auto coord = unravel_index(41, matrix_extent{7, 6}, stdex::layout_right{}); static_assert(std::tuple_size::value == 2); ASSERT_EQ(std::get<0>(coord), 6); ASSERT_EQ(std::get<1>(coord), 5); } { - auto coord = unravel_index(37, detail::matrix_extent{7, 6}, stdex::layout_right{}); + auto coord = unravel_index(37, matrix_extent{7, 6}, stdex::layout_right{}); static_assert(std::tuple_size::value == 2); ASSERT_EQ(std::get<0>(coord), 6); ASSERT_EQ(std::get<1>(coord), 1); diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 0d7d180b8f..7f1efb78bb 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -15,7 +15,8 @@ */ #include -#include +#include +#include namespace raft { @@ -24,7 +25,7 @@ namespace stdex = std::experimental; template > + typename AccessorPolicy = stdex::default_accessor> struct derived_device_mdspan : public device_mdspan { }; @@ -37,7 +38,7 @@ void test_template_asserts() using d_mdspan = derived_device_mdspan; static_assert( - std::is_same_v, device_mdspan>>, + std::is_same_v, device_mdspan>>, "not same"); static_assert(std::is_same_v, device_mdspan>>, diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index 3f75a4cf0a..2971f62c56 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -18,8 +18,8 @@ #include #include #include -#include -#include +#include +#include #include namespace raft { From 71922f61ab4b9c1e4319d19efd48069b994fccbf Mon Sep 
17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 7 Sep 2022 15:33:14 -0400 Subject: [PATCH 03/58] Fixing style --- cpp/test/linalg/transpose.cu | 2 +- cpp/test/mdarray.cu | 2 +- cpp/test/random/make_blobs.cu | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index 6edf9448b0..b5ad073f62 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -20,8 +20,8 @@ #include #include -#include #include +#include #include diff --git a/cpp/test/mdarray.cu b/cpp/test/mdarray.cu index 39cf76333b..0954073d86 100644 --- a/cpp/test/mdarray.cu +++ b/cpp/test/mdarray.cu @@ -14,8 +14,8 @@ * limitations under the License. */ #include -#include #include +#include #include #include #include diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index 2971f62c56..d06fa4c1cc 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -17,9 +17,9 @@ #include "../test_utils.h" #include #include -#include #include #include +#include #include namespace raft { From b0e5a02233597d2c2691c194bda27d5903bfaa86 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 7 Sep 2022 15:47:55 -0400 Subject: [PATCH 04/58] Separating host_span and device_span as well --- cpp/include/raft/core/device_span.hpp | 29 +++++++++++++ cpp/include/raft/core/host_span.hpp | 28 ++++++++++++ cpp/include/raft/core/mdarray.hpp | 2 +- cpp/include/raft/core/mdspan.hpp | 61 ++++++++++++++------------- cpp/include/raft/core/span.hpp | 42 ++++++++---------- cpp/include/raft/detail/span.hpp | 30 +++++++------ cpp/test/span.cpp | 2 +- cpp/test/span.cu | 2 +- 8 files changed, 126 insertions(+), 70 deletions(-) create mode 100644 cpp/include/raft/core/device_span.hpp create mode 100644 cpp/include/raft/core/host_span.hpp diff --git a/cpp/include/raft/core/device_span.hpp b/cpp/include/raft/core/device_span.hpp new file mode 100644 index 0000000000..0730b20bfb --- /dev/null +++ b/cpp/include/raft/core/device_span.hpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { + +/** + * @brief A span class for device pointer. + */ +template +using device_span = span; + +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_span.hpp b/cpp/include/raft/core/host_span.hpp new file mode 100644 index 0000000000..3cad62b7cd --- /dev/null +++ b/cpp/include/raft/core/host_span.hpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft { +/** + * @brief A span class for host pointer. + */ +template +using host_span = span; + +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index e918d81ff1..1110ecb75e 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -104,7 +104,7 @@ inline constexpr bool is_array_interface_v = is_array_interface::value; * are some inconsistencies in between them. We have made some modificiations to fit our * needs, which are listed below. * - * - Layout policy is different, the mdarray in raft uses `stdex::extent` directly just + * - Layout policy is different, the mdarray in raft uses `std::experimental::extent` directly just * like `mdspan`, while the `mdarray` in the reference implementation uses varidic * template. 
* diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 93645e01cc..64a69db171 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -25,23 +25,23 @@ #include #include #include -#include // dynamic_extent #include namespace raft { + +constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; + /** * @brief Dimensions extents for raft::mdspan */ template using extents = std::experimental::extents; -namespace stdex = std::experimental; - /** * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. * @{ */ -using stdex::layout_right; +using std::experimental::layout_right; using layout_c_contiguous = layout_right; using row_major = layout_right; /** @} */ @@ -50,24 +50,24 @@ using row_major = layout_right; * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. * @{ */ -using stdex::layout_left; +using std::experimental::layout_left; using layout_f_contiguous = layout_left; using col_major = layout_left; /** @} */ template -using vector_extent = stdex::extents; +using vector_extent = std::experimental::extents; template -using matrix_extent = stdex::extents; +using matrix_extent = std::experimental::extents; template -using scalar_extent = stdex::extents; +using scalar_extent = std::experimental::extents; /** * @brief Strided layout for non-contiguous memory. 
*/ -using stdex::layout_stride; +using std::experimental::layout_stride; template using extent_1d = vector_extent; @@ -76,26 +76,27 @@ template using extent_2d = matrix_extent; template -using extent_3d = stdex::extents; +using extent_3d = + std::experimental::extents; template -using extent_4d = - stdex::extents; +using extent_4d = std::experimental:: + extents; template -using extent_5d = stdex::extents; +using extent_5d = std::experimental::extents; /** @} */ template > -using mdspan = stdex::mdspan; + typename AccessorPolicy = std::experimental::default_accessor> +using mdspan = std::experimental::mdspan; /** * Ensure all types listed in the parameter pack `Extents` are integral types. @@ -142,7 +143,8 @@ using enable_if_mdspan = std::enable_if_t>; // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 // bit when the index is smaller, then try to avoid division when it's exp of 2. template -MDSPAN_INLINE_FUNCTION auto unravel_index_impl(I idx, stdex::extents shape) +MDSPAN_INLINE_FUNCTION auto unravel_index_impl( + I idx, std::experimental::extents shape) { constexpr auto kRank = static_cast(shape.rank()); std::size_t index[shape.rank()]{0}; // NOLINT @@ -218,10 +220,10 @@ auto flatten(mdspan_type mds) vector_extent ext{mds.size()}; - return stdex::mdspan(mds.data_handle(), ext); + return std::experimental::mdspan(mds.data_handle(), ext); } /** @@ -248,10 +250,11 @@ auto reshape(mdspan_type mds, extents new_shape) } RAFT_EXPECTS(new_size == mds.size(), "Cannot reshape array with size mismatch"); - return stdex::mdspan(mds.data_handle(), new_shape); + return std::experimental::mdspan(mds.data_handle(), + new_shape); } /** diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index 96950e979e..e32cf47138 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -35,7 +35,7 @@ namespace raft { * auto view = device_span{uvec.data(), uvec.size()}; * @endcode */ -template +template 
class span { public: using element_type = T; @@ -62,7 +62,7 @@ class span { */ constexpr span(pointer ptr, size_type count) noexcept : storage_{ptr, count} { - assert(!(Extent != dynamic_extent && count != Extent)); + assert(!(Extent != std::experimental::dynamic_extent && count != Extent)); assert(ptr || count == 0); } /** @@ -159,7 +159,8 @@ class span { return {data(), Count}; } - constexpr auto first(std::size_t _count) const -> span + constexpr auto first(std::size_t _count) const + -> span { assert(_count <= size()); return {data(), _count}; @@ -172,47 +173,40 @@ class span { return {data() + size() - Count, Count}; } - constexpr auto last(std::size_t _count) const -> span + constexpr auto last(std::size_t _count) const + -> span { assert(_count <= size()); return subspan(size() - _count, _count); } /*! - * If Count is std::dynamic_extent, r.size() == this->size() - Offset; + * If Count is std::std::experimental::dynamic_extent, r.size() == this->size() - Offset; * Otherwise r.size() == Count. */ - template + template constexpr auto subspan() const -> span::value> { - assert((Count == dynamic_extent) ? (Offset <= size()) : (Offset + Count <= size())); - return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; + assert((Count == std::experimental::dynamic_extent) ? (Offset <= size()) + : (Offset + Count <= size())); + return {data() + Offset, Count == std::experimental::dynamic_extent ? size() - Offset : Count}; } - constexpr auto subspan(size_type _offset, size_type _count = dynamic_extent) const - -> span + constexpr auto subspan(size_type _offset, + size_type _count = std::experimental::dynamic_extent) const + -> span { - assert((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size())); - return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count}; + assert((_count == std::experimental::dynamic_extent) ? 
(_offset <= size()) + : (_offset + _count <= size())); + return {data() + _offset, + _count == std::experimental::dynamic_extent ? size() - _offset : _count}; } private: detail::span_storage storage_; }; -/** - * @brief A span class for host pointer. - */ -template -using host_span = span; - -/** - * @brief A span class for device pointer. - */ -template -using device_span = span; - template constexpr auto operator==(span l, span r) -> bool { diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/detail/span.hpp index de76ff3138..408224a617 100644 --- a/cpp/include/raft/detail/span.hpp +++ b/cpp/include/raft/detail/span.hpp @@ -17,11 +17,9 @@ #include // numeric_limits #include -#include // __host__ __device__ #include namespace raft { -constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; template class span; @@ -30,31 +28,35 @@ namespace detail { /*! * The extent E of the span returned by subspan is determined as follows: * - * - If Count is not dynamic_extent, Count; - * - Otherwise, if Extent is not dynamic_extent, Extent - Offset; - * - Otherwise, dynamic_extent. + * - If Count is not std::experimental::dynamic_extent, Count; + * - Otherwise, if Extent is not std::experimental::dynamic_extent, Extent - Offset; + * - Otherwise, std::experimental::dynamic_extent. */ template struct extent_value_t - : public std::integral_constant< - std::size_t, - Count != dynamic_extent ? Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> { + : public std::integral_constant { }; /*! - * If N is dynamic_extent, the extent of the returned span E is also - * dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. + * If N is std::experimental::dynamic_extent, the extent of the returned span E is also + * std::experimental::dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. 
*/ template struct extent_as_bytes_value_t - : public std::integral_constant { + : public std::integral_constant< + std::size_t, + Extent == std::experimental::dynamic_extent ? Extent : sizeof(T) * Extent> { }; template struct is_allowed_extent_conversion_t : public std::integral_constant { + From == To || From == std::experimental::dynamic_extent || + To == std::experimental::dynamic_extent> { }; template @@ -101,7 +103,7 @@ struct span_storage { }; template -struct span_storage { +struct span_storage { private: T* ptr_{nullptr}; std::size_t size_{0}; diff --git a/cpp/test/span.cpp b/cpp/test/span.cpp index 6163811b95..0867427e1b 100644 --- a/cpp/test/span.cpp +++ b/cpp/test/span.cpp @@ -16,7 +16,7 @@ #include "test_span.hpp" #include #include // iota -#include +#include namespace raft { TEST(Span, DlfConstructors) diff --git a/cpp/test/span.cu b/cpp/test/span.cu index dcde9b5432..e03d1ddf5a 100644 --- a/cpp/test/span.cu +++ b/cpp/test/span.cu @@ -18,7 +18,7 @@ #include // iota #include #include -#include +#include #include #include From d69b16335d8de5880761855cfa403f5e5e54d9da Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 7 Sep 2022 17:21:58 -0400 Subject: [PATCH 05/58] Cleanup and getting to build --- cpp/include/raft/core/device_mdspan.hpp | 6 +- cpp/include/raft/core/host_mdspan.hpp | 4 +- cpp/include/raft/core/mdarray.hpp | 1 + cpp/include/raft/core/mdspan.hpp | 62 +----------------- cpp/include/raft/core/mdspan_types.hpp | 85 +++++++++++++++++++++++++ cpp/include/raft/core/span.hpp | 41 ++++++------ cpp/include/raft/detail/span.hpp | 31 ++++----- cpp/test/linalg/transpose.cu | 2 +- cpp/test/span.cpp | 2 +- cpp/test/span.cu | 2 +- 10 files changed, 127 insertions(+), 109 deletions(-) create mode 100644 cpp/include/raft/core/mdspan_types.hpp diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index f835397e9d..88cd3dbf8e 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -28,18 +28,18 @@ template using managed_accessor = detail::accessor_mixin; /** - * @brief stdex::mdspan with device tag to avoid accessing incorrect memory location. + * @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location. */ template > + typename AccessorPolicy = std::experimental::default_accessor> using device_mdspan = mdspan>; template > + typename AccessorPolicy = std::experimental::default_accessor> using managed_mdspan = mdspan>; namespace detail { diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index f46fb6ff17..4602088a44 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -25,12 +25,12 @@ template using host_accessor = detail::accessor_mixin; /** - * @brief stdex::mdspan with host tag to avoid accessing incorrect memory location. + * @brief std::experimental::mdspan with host tag to avoid accessing incorrect memory location. 
*/ template > + typename AccessorPolicy = std::experimental::default_accessor> using host_mdspan = mdspan>; namespace detail { diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 1110ecb75e..611d01fb70 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 64a69db171..55b651d69f 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -25,71 +25,11 @@ #include #include #include +#include #include namespace raft { -constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; - -/** - * @brief Dimensions extents for raft::mdspan - */ -template -using extents = std::experimental::extents; - -/** - * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. - * @{ - */ -using std::experimental::layout_right; -using layout_c_contiguous = layout_right; -using row_major = layout_right; -/** @} */ - -/** - * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. - * @{ - */ -using std::experimental::layout_left; -using layout_f_contiguous = layout_left; -using col_major = layout_left; -/** @} */ - -template -using vector_extent = std::experimental::extents; - -template -using matrix_extent = std::experimental::extents; - -template -using scalar_extent = std::experimental::extents; - -/** - * @brief Strided layout for non-contiguous memory. 
- */ -using std::experimental::layout_stride; - -template -using extent_1d = vector_extent; - -template -using extent_2d = matrix_extent; - -template -using extent_3d = - std::experimental::extents; - -template -using extent_4d = std::experimental:: - extents; - -template -using extent_5d = std::experimental::extents; /** @} */ template + +namespace raft { + +constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; + +/** + * @brief Dimensions extents for raft::mdspan + */ +template +using extents = std::experimental::extents; + +/** + * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. + * @{ + */ +using std::experimental::layout_right; +using layout_c_contiguous = layout_right; +using row_major = layout_right; +/** @} */ + +/** + * @defgroup F-Contiguous layout for mdarray and mdspan. Implies column-major and contiguous memory. + * @{ + */ +using std::experimental::layout_left; +using layout_f_contiguous = layout_left; +using col_major = layout_left; +/** @} */ + +template +using vector_extent = std::experimental::extents; + +template +using matrix_extent = std::experimental::extents; + +template +using scalar_extent = std::experimental::extents; + +/** + * @brief Strided layout for non-contiguous memory. 
+ */ +using std::experimental::layout_stride; + +template +using extent_1d = vector_extent; + +template +using extent_2d = matrix_extent; + +template +using extent_3d = + std::experimental::extents; + +template +using extent_4d = std::experimental:: + extents; + +template +using extent_5d = std::experimental::extents; + +} // namespace raft diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index e32cf47138..3dec7e6fa8 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -18,10 +18,11 @@ #include #include // size_t #include // std::byte +#include #include #include #include -#include // __host__ __device__ +#include // _MDSPAN_HOST_DEVICE #include #include @@ -35,7 +36,7 @@ namespace raft { * auto view = device_span{uvec.data(), uvec.size()}; * @endcode */ -template +template class span { public: using element_type = T; @@ -62,7 +63,7 @@ class span { */ constexpr span(pointer ptr, size_type count) noexcept : storage_{ptr, count} { - assert(!(Extent != std::experimental::dynamic_extent && count != Extent)); + assert(!(Extent != dynamic_extent && count != Extent)); assert(ptr || count == 0); } /** @@ -108,22 +109,22 @@ class span { constexpr auto cend() const noexcept -> const_iterator { return data() + size(); } - __host__ __device__ constexpr auto rbegin() const noexcept -> reverse_iterator + _MDSPAN_HOST_DEVICE constexpr auto rbegin() const noexcept -> reverse_iterator { return reverse_iterator{end()}; } - __host__ __device__ constexpr auto rend() const noexcept -> reverse_iterator + _MDSPAN_HOST_DEVICE constexpr auto rend() const noexcept -> reverse_iterator { return reverse_iterator{begin()}; } - __host__ __device__ constexpr auto crbegin() const noexcept -> const_reverse_iterator + _MDSPAN_HOST_DEVICE constexpr auto crbegin() const noexcept -> const_reverse_iterator { return const_reverse_iterator{cend()}; } - __host__ __device__ constexpr auto crend() const noexcept -> const_reverse_iterator + 
_MDSPAN_HOST_DEVICE constexpr auto crend() const noexcept -> const_reverse_iterator { return const_reverse_iterator{cbegin()}; } @@ -159,8 +160,7 @@ class span { return {data(), Count}; } - constexpr auto first(std::size_t _count) const - -> span + constexpr auto first(std::size_t _count) const -> span { assert(_count <= size()); return {data(), _count}; @@ -173,34 +173,29 @@ class span { return {data() + size() - Count, Count}; } - constexpr auto last(std::size_t _count) const - -> span + constexpr auto last(std::size_t _count) const -> span { assert(_count <= size()); return subspan(size() - _count, _count); } /*! - * If Count is std::std::experimental::dynamic_extent, r.size() == this->size() - Offset; + * If Count is std::dynamic_extent, r.size() == this->size() - Offset; * Otherwise r.size() == Count. */ - template + template constexpr auto subspan() const -> span::value> { - assert((Count == std::experimental::dynamic_extent) ? (Offset <= size()) - : (Offset + Count <= size())); - return {data() + Offset, Count == std::experimental::dynamic_extent ? size() - Offset : Count}; + assert((Count == dynamic_extent) ? (Offset <= size()) : (Offset + Count <= size())); + return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; } - constexpr auto subspan(size_type _offset, - size_type _count = std::experimental::dynamic_extent) const - -> span + constexpr auto subspan(size_type _offset, size_type _count = dynamic_extent) const + -> span { - assert((_count == std::experimental::dynamic_extent) ? (_offset <= size()) - : (_offset + _count <= size())); - return {data() + _offset, - _count == std::experimental::dynamic_extent ? size() - _offset : _count}; + assert((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size())); + return {data() + _offset, _count == dynamic_extent ? 
size() - _offset : _count}; } private: diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/detail/span.hpp index 408224a617..c11e6ba32b 100644 --- a/cpp/include/raft/detail/span.hpp +++ b/cpp/include/raft/detail/span.hpp @@ -16,6 +16,7 @@ #pragma once #include // numeric_limits +#include #include #include @@ -28,35 +29,31 @@ namespace detail { /*! * The extent E of the span returned by subspan is determined as follows: * - * - If Count is not std::experimental::dynamic_extent, Count; - * - Otherwise, if Extent is not std::experimental::dynamic_extent, Extent - Offset; - * - Otherwise, std::experimental::dynamic_extent. + * - If Count is not dynamic_extent, Count; + * - Otherwise, if Extent is not dynamic_extent, Extent - Offset; + * - Otherwise, dynamic_extent. */ template struct extent_value_t - : public std::integral_constant { + : public std::integral_constant< + std::size_t, + Count != dynamic_extent ? Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> { }; /*! - * If N is std::experimental::dynamic_extent, the extent of the returned span E is also - * std::experimental::dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. + * If N is dynamic_extent, the extent of the returned span E is also + * dynamic_extent; otherwise it is std::size_t(sizeof(T)) * N. */ template struct extent_as_bytes_value_t - : public std::integral_constant< - std::size_t, - Extent == std::experimental::dynamic_extent ? 
Extent : sizeof(T) * Extent> { + : public std::integral_constant { }; template struct is_allowed_extent_conversion_t : public std::integral_constant { + From == To || From == dynamic_extent || To == dynamic_extent> { }; template @@ -77,7 +74,7 @@ struct is_span_t : public is_span_oracle_t::type> { }; template -__host__ __device__ constexpr auto lexicographical_compare(InputIt1 first1, +_MDSPAN_HOST_DEVICE constexpr auto lexicographical_compare(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2) -> bool @@ -103,7 +100,7 @@ struct span_storage { }; template -struct span_storage { +struct span_storage { private: T* ptr_{nullptr}; std::size_t size_{0}; diff --git a/cpp/test/linalg/transpose.cu b/cpp/test/linalg/transpose.cu index b5ad073f62..adfaf6e49d 100644 --- a/cpp/test/linalg/transpose.cu +++ b/cpp/test/linalg/transpose.cu @@ -234,7 +234,7 @@ void test_transpose_submatrix() } auto vv = v.view(); - auto submat = raft::stdex::submdspan( + auto submat = std::experimental::submdspan( vv, std::make_tuple(row_beg, row_end), std::make_tuple(col_beg, col_end)); static_assert(std::is_same_v); diff --git a/cpp/test/span.cpp b/cpp/test/span.cpp index 0867427e1b..f8d9345a12 100644 --- a/cpp/test/span.cpp +++ b/cpp/test/span.cpp @@ -16,7 +16,7 @@ #include "test_span.hpp" #include #include // iota -#include +#include namespace raft { TEST(Span, DlfConstructors) diff --git a/cpp/test/span.cu b/cpp/test/span.cu index e03d1ddf5a..91833d4dc7 100644 --- a/cpp/test/span.cu +++ b/cpp/test/span.cu @@ -16,9 +16,9 @@ #include "test_span.hpp" #include #include // iota +#include #include #include -#include #include #include From 50d750b63b823ee81ed455fed1ef2c868e5de270 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 7 Sep 2022 18:00:43 -0400 Subject: [PATCH 06/58] Updates --- .../detail/device_mdarray.hpp} | 50 ++------------ cpp/include/raft/core/detail/host_mdarray.hpp | 69 +++++++++++++++++++ .../{mdspan_util.hpp => mdspan_util.cuh} | 1 + cpp/include/raft/{ => core}/detail/span.hpp | 0 cpp/include/raft/core/device_mdarray.hpp | 1 + cpp/include/raft/core/host_mdarray.hpp | 2 + cpp/include/raft/core/host_mdspan.hpp | 3 +- cpp/include/raft/core/mdarray.hpp | 16 ++--- cpp/include/raft/core/mdspan.hpp | 6 +- cpp/include/raft/core/span.hpp | 6 +- cpp/include/raft/distance/distance.cuh | 2 +- 11 files changed, 98 insertions(+), 58 deletions(-) rename cpp/include/raft/{detail/mdarray.hpp => core/detail/device_mdarray.hpp} (77%) create mode 100644 cpp/include/raft/core/detail/host_mdarray.hpp rename cpp/include/raft/core/detail/{mdspan_util.hpp => mdspan_util.cuh} (99%) rename cpp/include/raft/{ => core}/detail/span.hpp (100%) diff --git a/cpp/include/raft/detail/mdarray.hpp b/cpp/include/raft/core/detail/device_mdarray.hpp similarity index 77% rename from cpp/include/raft/detail/mdarray.hpp rename to cpp/include/raft/core/detail/device_mdarray.hpp index 1c13f5d1fc..569f573c19 100644 --- a/cpp/include/raft/detail/mdarray.hpp +++ b/cpp/include/raft/core/detail/device_mdarray.hpp @@ -21,11 +21,12 @@ * limitations under the License. */ #pragma once -#include -#include +#include +#include +#include -#include -#include // dynamic_extent +#include +#include // dynamic_extent #include #include @@ -189,45 +190,4 @@ class device_uvector_policy { [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } }; -/** - * @brief A container policy for host mdarray. 
- */ -template > -class host_vector_policy { - public: - using element_type = ElementType; - using container_type = std::vector; - using allocator_type = typename container_type::allocator_type; - using pointer = typename container_type::pointer; - using const_pointer = typename container_type::const_pointer; - using reference = element_type&; - using const_reference = element_type const&; - using accessor_policy = std::experimental::default_accessor; - using const_accessor_policy = std::experimental::default_accessor; - - public: - auto create(size_t n) -> container_type { return container_type(n); } - - constexpr host_vector_policy() noexcept(std::is_nothrow_default_constructible_v) = - default; - explicit constexpr host_vector_policy(rmm::cuda_stream_view) noexcept( - std::is_nothrow_default_constructible_v) - : host_vector_policy() - { - } - - [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference - { - return c[n]; - } - [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept - -> const_reference - { - return c[n]; - } - - [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } - [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } -}; - } // namespace raft::detail diff --git a/cpp/include/raft/core/detail/host_mdarray.hpp b/cpp/include/raft/core/detail/host_mdarray.hpp new file mode 100644 index 0000000000..74bd55e78c --- /dev/null +++ b/cpp/include/raft/core/detail/host_mdarray.hpp @@ -0,0 +1,69 @@ +/* + * Copyright (2019) Sandia Corporation + * + * The source code is licensed under the 3-clause BSD license found in the LICENSE file + * thirdparty/LICENSES/mdarray.license + */ + +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace raft::detail { + +/** + * @brief A container policy for host mdarray. + */ +template > +class host_vector_policy { + public: + using element_type = ElementType; + using container_type = std::vector; + using allocator_type = typename container_type::allocator_type; + using pointer = typename container_type::pointer; + using const_pointer = typename container_type::const_pointer; + using reference = element_type&; + using const_reference = element_type const&; + using accessor_policy = std::experimental::default_accessor; + using const_accessor_policy = std::experimental::default_accessor; + + public: + auto create(size_t n) -> container_type { return container_type(n); } + + constexpr host_vector_policy() noexcept(std::is_nothrow_default_constructible_v) = + default; + explicit constexpr host_vector_policy(rmm::cuda_stream_view) noexcept( + std::is_nothrow_default_constructible_v) + : host_vector_policy() + { + } + + [[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference + { + return c[n]; + } + [[nodiscard]] constexpr auto access(container_type const& c, size_t n) const noexcept + -> const_reference + { + return c[n]; + } + + [[nodiscard]] auto make_accessor_policy() noexcept { return accessor_policy{}; } + [[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; } +}; +} // namespace raft::detail diff --git a/cpp/include/raft/core/detail/mdspan_util.hpp b/cpp/include/raft/core/detail/mdspan_util.cuh similarity index 99% 
rename from cpp/include/raft/core/detail/mdspan_util.hpp rename to cpp/include/raft/core/detail/mdspan_util.cuh index af5eabf3e3..b2f73857b2 100644 --- a/cpp/include/raft/core/detail/mdspan_util.hpp +++ b/cpp/include/raft/core/detail/mdspan_util.cuh @@ -16,6 +16,7 @@ #pragma once #include + #include #include diff --git a/cpp/include/raft/detail/span.hpp b/cpp/include/raft/core/detail/span.hpp similarity index 100% rename from cpp/include/raft/detail/span.hpp rename to cpp/include/raft/core/detail/span.hpp diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 393ff45815..1b20629f5f 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -16,6 +16,7 @@ #pragma once +#include #include #include diff --git a/cpp/include/raft/core/host_mdarray.hpp b/cpp/include/raft/core/host_mdarray.hpp index 448a639390..6221ca59f0 100644 --- a/cpp/include/raft/core/host_mdarray.hpp +++ b/cpp/include/raft/core/host_mdarray.hpp @@ -17,6 +17,8 @@ #pragma once #include + +#include #include namespace raft { diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index 4602088a44..fcd637f3a5 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -16,9 +16,10 @@ #pragma once -#include #include +#include + namespace raft { template diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 611d01fb70..64e8504708 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -24,14 +24,11 @@ #include -#include -#include -#include +#include #include #include -#include + #include -#include namespace raft { /** @@ -158,9 +155,12 @@ class mdarray typename container_policy_type::const_accessor_policy, typename container_policy_type::accessor_policy>> using view_type_impl = - std::conditional_t, - device_mdspan>; + mdspan>; public: /** diff --git 
a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 55b651d69f..f608f3d085 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -22,10 +22,12 @@ */ #pragma once -#include -#include #include #include + +#include +#include + #include namespace raft { diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index 3dec7e6fa8..db3b25296b 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -19,7 +19,11 @@ #include // size_t #include // std::byte #include -#include + +#include + +// TODO (cjnolet): Remove thrust dependencies here so host_span can be used without CUDA Toolkit +// being installed. Reference: https://github.com/rapidsai/raft/issues/812. #include #include #include // _MDSPAN_HOST_DEVICE diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 3db1749bb4..8aaeadd531 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -23,7 +23,7 @@ #include #include -#include +#include /** * @defgroup pairwise_distance pairwise distance prims From 6fda1fd454260ed5478c414b090bfbedb1663a82 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 8 Sep 2022 09:05:41 -0400 Subject: [PATCH 07/58] Fixing docs --- cpp/include/raft/core/mdspan.hpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index f608f3d085..289cc82f1b 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -1,10 +1,3 @@ -/* - * Copyright (2019) Sandia Corporation - * - * The source code is licensed under the 3-clause BSD license found in the LICENSE file - * thirdparty/LICENSES/mdarray.license - */ - /* * Copyright (c) 2022, NVIDIA CORPORATION. 
* @@ -32,8 +25,6 @@ namespace raft { -/** @} */ - template Date: Thu, 8 Sep 2022 09:54:34 -0400 Subject: [PATCH 08/58] Updating readme to use proper header paths --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2159f128bf..8194c27cf9 100755 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ The APIs in RAFT currently accept raw pointers to device memory and we are in th The `mdarray` forms a convenience layer over RMM and can be constructed in RAFT using a number of different helper functions: ```c++ -#include +#include int n_rows = 10; int n_cols = 10; @@ -56,8 +56,8 @@ Most of the primitives in RAFT accept a `raft::handle_t` object for the manageme The example below demonstrates creating a RAFT handle and using it with `device_matrix` and `device_vector` to allocate memory, generating random clusters, and computing pairwise Euclidean distances: ```c++ -#include -#include +#include +#include #include #include From 3aeb5304f0024a05580c011bc36f30fdaff3d72e Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 8 Sep 2022 17:14:40 -0400 Subject: [PATCH 09/58] More updates based on review feedback --- cpp/include/raft.hpp | 5 +-- .../raft/core/detail/device_mdarray.hpp | 2 +- ...sor_mixin.hpp => host_device_accessor.hpp} | 6 ++-- cpp/include/raft/core/detail/macros.hpp | 35 ++++++++++++++++++ cpp/include/raft/core/detail/mdspan_util.cuh | 1 + cpp/include/raft/core/detail/span.hpp | 10 +++--- cpp/include/raft/core/device_mdarray.hpp | 2 +- cpp/include/raft/core/device_mdspan.hpp | 21 +++++------ cpp/include/raft/core/host_mdspan.hpp | 19 +++++----- cpp/include/raft/core/mdarray.hpp | 36 +++++++++---------- cpp/include/raft/core/mdspan.hpp | 21 ++++++----- cpp/include/raft/core/mdspan_types.hpp | 9 ++--- cpp/include/raft/core/span.hpp | 11 +++--- cpp/test/mdspan_utils.cu | 11 +++--- 14 files changed, 112 insertions(+), 77 deletions(-) rename cpp/include/raft/core/detail/{accessor_mixin.hpp => host_device_accessor.hpp} (88%) create mode 100644 cpp/include/raft/core/detail/macros.hpp diff --git a/cpp/include/raft.hpp b/cpp/include/raft.hpp index b1b8255b7e..7b997bfb87 100644 --- a/cpp/include/raft.hpp +++ b/cpp/include/raft.hpp @@ -17,9 +17,10 @@ /** * This file is deprecated and will be removed in release 22.06. 
*/ +#include "raft/core/device_mdarray.hpp" +#include "raft/core/device_mdspan.hpp" +#include "raft/core/device_span.hpp" #include "raft/handle.hpp" -#include "raft/mdarray.hpp" -#include "raft/span.hpp" #include diff --git a/cpp/include/raft/core/detail/device_mdarray.hpp b/cpp/include/raft/core/detail/device_mdarray.hpp index 569f573c19..aad35b6282 100644 --- a/cpp/include/raft/core/detail/device_mdarray.hpp +++ b/cpp/include/raft/core/detail/device_mdarray.hpp @@ -25,7 +25,7 @@ #include #include -#include +#include #include // dynamic_extent #include diff --git a/cpp/include/raft/core/detail/accessor_mixin.hpp b/cpp/include/raft/core/detail/host_device_accessor.hpp similarity index 88% rename from cpp/include/raft/core/detail/accessor_mixin.hpp rename to cpp/include/raft/core/detail/host_device_accessor.hpp index 6edd85dbaf..3a71e6366b 100644 --- a/cpp/include/raft/core/detail/accessor_mixin.hpp +++ b/cpp/include/raft/core/detail/host_device_accessor.hpp @@ -22,7 +22,7 @@ namespace raft::detail { * @brief A mixin to distinguish host and device memory. 
*/ template -struct accessor_mixin : public AccessorPolicy { +struct host_device_accessor : public AccessorPolicy { using accessor_type = AccessorPolicy; using is_host_type = std::conditional_t; using is_device_type = std::conditional_t; @@ -32,8 +32,8 @@ struct accessor_mixin : public AccessorPolicy { static constexpr bool is_managed_accessible = is_device && is_host; // make sure the explicit ctor can fall through using AccessorPolicy::AccessorPolicy; - using offset_policy = accessor_mixin; - accessor_mixin(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT + using offset_policy = host_device_accessor; + host_device_accessor(AccessorPolicy const& that) : AccessorPolicy{that} {} // NOLINT }; } // namespace raft::detail diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp new file mode 100644 index 0000000000..00fbab1530 --- /dev/null +++ b/cpp/include/raft/core/detail/macros.hpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#ifndef _RAFT_HAS_CUDA +#if defined(__CUDACC__) +#define _RAFT_HAS_CUDA __CUDACC__ +#endif +#endif + +#ifndef _RAFT_HOST_DEVICE +#if defined(_RAFT_HAS_CUDA) +#define _RAFT_HOST_DEVICE __host__ __device__ +#else +#define _RAFT_HOST_DEVICE +#endif +#endif + +#ifndef RAFT_INLINE_FUNCTION +#define RAFT_INLINE_FUNCTION inline _RAFT_HOST_DEVICE +#endif diff --git a/cpp/include/raft/core/detail/mdspan_util.cuh b/cpp/include/raft/core/detail/mdspan_util.cuh index b2f73857b2..6b2c90abcc 100644 --- a/cpp/include/raft/core/detail/mdspan_util.cuh +++ b/cpp/include/raft/core/detail/mdspan_util.cuh @@ -17,6 +17,7 @@ #include +#include #include #include diff --git a/cpp/include/raft/core/detail/span.hpp b/cpp/include/raft/core/detail/span.hpp index c11e6ba32b..20500d618b 100644 --- a/cpp/include/raft/core/detail/span.hpp +++ b/cpp/include/raft/core/detail/span.hpp @@ -16,8 +16,8 @@ #pragma once #include // numeric_limits +#include #include -#include #include namespace raft { @@ -74,10 +74,10 @@ struct is_span_t : public is_span_oracle_t::type> { }; template -_MDSPAN_HOST_DEVICE constexpr auto lexicographical_compare(InputIt1 first1, - InputIt1 last1, - InputIt2 first2, - InputIt2 last2) -> bool +_RAFT_HOST_DEVICE constexpr auto lexicographical_compare(InputIt1 first1, + InputIt1 last1, + InputIt2 first2, + InputIt2 last2) -> bool { Compare comp; for (; first1 != last1 && first2 != last2; ++first1, ++first2) { diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 1b20629f5f..1c17b5bcb9 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include namespace raft { diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 88cd3dbf8e..4d6c1836fc 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -16,16 +16,16 @@ 
#pragma once -#include +#include #include namespace raft { template -using device_accessor = detail::accessor_mixin; +using device_accessor = detail::host_device_accessor; template -using managed_accessor = detail::accessor_mixin; +using managed_accessor = detail::host_device_accessor; /** * @brief std::experimental::mdspan with device tag to avoid accessing incorrect memory location. @@ -44,17 +44,18 @@ using managed_mdspan = mdspan -struct is_device_mdspan : std::false_type { +struct is_device_accessible_mdspan : std::false_type { }; template -struct is_device_mdspan : std::bool_constant { +struct is_device_accessible_mdspan + : std::bool_constant { }; /** * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type */ template -using is_device_mdspan_t = is_device_mdspan>; +using is_device_accessible_mdspan_t = is_device_accessible_mdspan>; template struct is_managed_mdspan : std::false_type { @@ -76,10 +77,11 @@ using is_managed_mdspan_t = is_managed_mdspan>; * derived type */ template -inline constexpr bool is_device_mdspan_v = std::conjunction_v...>; +inline constexpr bool is_device_accessible_mdspan_v = + std::conjunction_v...>; template -using enable_if_device_mdspan = std::enable_if_t>; +using enable_if_device_mdspan = std::enable_if_t>; /** * @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a @@ -187,8 +189,7 @@ template auto make_device_vector_view(ElementType* ptr, IndexType n) { - vector_extent extents{n}; - return device_vector_view{ptr, extents}; + return device_vector_view{ptr, n}; } } // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index fcd637f3a5..e6ab22004e 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -18,12 +18,12 @@ #include -#include +#include namespace raft { template -using host_accessor = detail::accessor_mixin; 
+using host_accessor = detail::host_device_accessor; /** * @brief std::experimental::mdspan with host tag to avoid accessing incorrect memory location. @@ -37,17 +37,18 @@ using host_mdspan = mdspan -struct is_host_mdspan : std::false_type { +struct is_host_accessible_mdspan : std::false_type { }; template -struct is_host_mdspan : std::bool_constant { +struct is_host_accessible_mdspan + : std::bool_constant { }; /** * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type */ template -using is_host_mdspan_t = is_host_mdspan>; +using is_host_accessible_mdspan_t = is_host_accessible_mdspan>; } // namespace detail @@ -56,10 +57,11 @@ using is_host_mdspan_t = is_host_mdspan>; * derived type */ template -inline constexpr bool is_host_mdspan_v = std::conjunction_v...>; +inline constexpr bool is_host_accessible_mdspan_v = + std::conjunction_v...>; template -using enable_if_host_mdspan = std::enable_if_t>; +using enable_if_host_mdspan = std::enable_if_t>; /** * @brief Shorthand for 0-dim host mdspan (scalar). 
@@ -137,7 +139,6 @@ template auto make_host_vector_view(ElementType* ptr, IndexType n) { - vector_extent extents{n}; - return host_vector_view{ptr, extents}; + return host_vector_view{ptr, n}; } } // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 64e8504708..44730d901e 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -24,10 +24,10 @@ #include -#include +#include +#include #include #include - #include namespace raft { @@ -158,9 +158,9 @@ class mdarray mdspan>; + detail::host_device_accessor>; public: /** @@ -266,61 +266,61 @@ class mdarray } // basic_mdarray observers of the domain multidimensional index space (also in basic_mdspan) - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto rank() noexcept -> rank_type + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto rank() noexcept -> rank_type { return extents_type::rank(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto rank_dynamic() noexcept -> rank_type + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto rank_dynamic() noexcept -> rank_type { return extents_type::rank_dynamic(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto static_extent(size_t r) noexcept + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto static_extent(size_t r) noexcept -> index_type { return extents_type::static_extent(r); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto extents() const noexcept -> extents_type + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto extents() const noexcept -> extents_type { return map_.extents(); } /** * @brief the extent of rank r */ - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto extent(size_t r) const noexcept -> index_type + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto extent(size_t r) const noexcept -> index_type { return map_.extents().extent(r); } // mapping - [[nodiscard]] MDSPAN_INLINE_FUNCTION 
constexpr auto mapping() const noexcept -> mapping_type + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto mapping() const noexcept -> mapping_type { return map_; } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto is_unique() const noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto is_unique() const noexcept -> bool { return map_.is_unique(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto is_exhaustive() const noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto is_exhaustive() const noexcept -> bool { return map_.is_exhaustive(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto is_strided() const noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto is_strided() const noexcept -> bool { return map_.is_strided(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION constexpr auto stride(size_t r) const -> index_type + [[nodiscard]] RAFT_INLINE_FUNCTION constexpr auto stride(size_t r) const -> index_type { return map_.stride(r); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto is_always_unique() noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto is_always_unique() noexcept -> bool { return mapping_type::is_always_unique(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto is_always_exhaustive() noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto is_always_exhaustive() noexcept -> bool { return mapping_type::is_always_exhaustive(); } - [[nodiscard]] MDSPAN_INLINE_FUNCTION static constexpr auto is_always_strided() noexcept -> bool + [[nodiscard]] RAFT_INLINE_FUNCTION static constexpr auto is_always_strided() noexcept -> bool { return mapping_type::is_always_strided(); } diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 289cc82f1b..7169a010b6 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -18,7 +18,8 @@ #include #include -#include +#include +#include #include 
#include @@ -59,9 +60,6 @@ struct is_mdspan( template using is_mdspan_t = is_mdspan>; -// template -// inline constexpr bool is_mdspan_v = is_mdspan_t::value; - /** * @\brief Boolean to determine if variadic template types Tn are either * raft::host_mdspan/raft::device_mdspan or their derived types @@ -76,7 +74,7 @@ using enable_if_mdspan = std::enable_if_t>; // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 // bit when the index is smaller, then try to avoid division when it's exp of 2. template -MDSPAN_INLINE_FUNCTION auto unravel_index_impl( +RAFT_INLINE_FUNCTION auto unravel_index_impl( I idx, std::experimental::extents shape) { constexpr auto kRank = static_cast(shape.rank()); @@ -117,9 +115,10 @@ template auto make_mdspan(ElementType* ptr, extents exts) { - using accessor_type = detail::accessor_mixin, - is_host_accessible, - is_device_accessible>; + using accessor_type = + detail::host_device_accessor, + is_host_accessible, + is_device_accessible>; return mdspan{ptr, exts}; } @@ -207,9 +206,9 @@ auto reshape(mdspan_type mds, extents new_shape) * \return A std::tuple that represents the coordinate. 
*/ template -MDSPAN_INLINE_FUNCTION auto unravel_index(Idx idx, - extents shape, - LayoutPolicy const& layout) +RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, + extents shape, + LayoutPolicy const& layout) { static_assert(std::is_same_v>, layout_c_contiguous>, diff --git a/cpp/include/raft/core/mdspan_types.hpp b/cpp/include/raft/core/mdspan_types.hpp index 1417dea8f6..bc2ba314a3 100644 --- a/cpp/include/raft/core/mdspan_types.hpp +++ b/cpp/include/raft/core/mdspan_types.hpp @@ -20,13 +20,8 @@ namespace raft { -constexpr std::size_t dynamic_extent = std::experimental::dynamic_extent; - -/** - * @brief Dimensions extents for raft::mdspan - */ -template -using extents = std::experimental::extents; +using std::experimental::dynamic_extent; +using std::experimental::extents; /** * @defgroup C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory. diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index db3b25296b..188d58c896 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -20,13 +20,14 @@ #include // std::byte #include +#include #include // TODO (cjnolet): Remove thrust dependencies here so host_span can be used without CUDA Toolkit // being installed. Reference: https://github.com/rapidsai/raft/issues/812. 
#include #include -#include // _MDSPAN_HOST_DEVICE +#include // _RAFT_HOST_DEVICE #include #include @@ -113,22 +114,22 @@ class span { constexpr auto cend() const noexcept -> const_iterator { return data() + size(); } - _MDSPAN_HOST_DEVICE constexpr auto rbegin() const noexcept -> reverse_iterator + _RAFT_HOST_DEVICE constexpr auto rbegin() const noexcept -> reverse_iterator { return reverse_iterator{end()}; } - _MDSPAN_HOST_DEVICE constexpr auto rend() const noexcept -> reverse_iterator + _RAFT_HOST_DEVICE constexpr auto rend() const noexcept -> reverse_iterator { return reverse_iterator{begin()}; } - _MDSPAN_HOST_DEVICE constexpr auto crbegin() const noexcept -> const_reverse_iterator + _RAFT_HOST_DEVICE constexpr auto crbegin() const noexcept -> const_reverse_iterator { return const_reverse_iterator{cend()}; } - _MDSPAN_HOST_DEVICE constexpr auto crend() const noexcept -> const_reverse_iterator + _RAFT_HOST_DEVICE constexpr auto crend() const noexcept -> const_reverse_iterator { return const_reverse_iterator{cbegin()}; } diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 7f1efb78bb..5683c0267a 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -55,16 +55,17 @@ void test_template_asserts() static_assert(is_mdspan_v, "Derived device mdspan type is not mdspan"); // Checking if types are device_mdspan - static_assert(is_device_mdspan_v>, + static_assert(is_device_accessible_mdspan_v>, "device_matrix_view type not a device_mdspan"); - static_assert(!is_device_mdspan_v>, + static_assert(!is_device_accessible_mdspan_v>, "host_matrix_view type is a device_mdspan"); - static_assert(is_device_mdspan_v, "Derived device mdspan type is not device_mdspan"); + static_assert(is_device_accessible_mdspan_v, + "Derived device mdspan type is not device_mdspan"); // Checking if types are host_mdspan - static_assert(!is_host_mdspan_v>, + static_assert(!is_host_accessible_mdspan_v>, "device_matrix_view type is a host_mdspan"); - 
static_assert(is_host_mdspan_v>, + static_assert(is_host_accessible_mdspan_v>, "host_matrix_view type is not a host_mdspan"); // checking variadics From ca354f3075a3f0ad691e02e157acc629c2de8218 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 14 Sep 2022 14:48:48 -0400 Subject: [PATCH 10/58] Mdspanifying spatial/knn functions --- cpp/include/raft/spatial/knn/knn.cuh | 66 ++++++++++++++++++++++++++++ cpp/test/spatial/knn.cu | 24 ++++------ 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index 52e7e31cc2..c9fdf94fbd 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include "detail/knn_brute_force_faiss.cuh" #include "detail/selection_faiss.cuh" @@ -224,4 +225,69 @@ void brute_force_knn(raft::handle_t const& handle, metric, metric_arg); } + +/** + * @brief Flat C++ API function to perform a brute force knn on + * a series of input arrays and combine the results into a single + * output array for indexes and distances. Inputs can be either + * row- or column-major but the output matrices will always be in + * row-major format. + * + * @param[in] handle the cuml handle to use + * @param[in] index vector of device matrices (each size m_i*d) to be used as the knn index + * @param[in] search matrix (size n*d) to be used for searching the index + * @param[out] indices matrix (size n*k) to store output knn indices + * @param[out] distances matrix (size n*k) to store the output knn distance + * @param[in] k the number of nearest neighbors to return + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This + * is ignored if the metric_type is not Minkowski. 
+ */ +template +void brute_force_knn(raft::handle_t const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + value_int k, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional> translations = std::nullopt) { + + RAFT_EXPECTS(index.extent(1) == search.extent(1), "Number of dimensions for both index and search matrices must be equal"); + RAFT_EXPECTS(indices.extent(0) == distances.extent(0) == search.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows in search matrix."); + RAFT_EXPECTS(indices.extent(1) == distances.extent(1) == k, + "Number of columns in output indices and distances matrices must be equal to k"); + + constexpr auto rowMajorIndex = std::is_same_v; + constexpr auto rowMajorQuery = std::is_same_v; + + std::vector inputs; + std::vector sizes; + for(int i = 0; i < index.size(); ++i) { + inputs.push_back(const_cast(index[i].data_handle())); + sizes.push_back(index[i].extents(0)); + } + + detail::brute_force_knn_impl(handle, + input, + sizes, + index.extents(1), + search.data_handle(), + search.extents(1), + indices.data_handle(), + distances.data_handle(), + k, + rowMajorIndex, + rowMajorQuery, + translations.value_or(nullptr), + metric, + metric_arg.value_or(2.0)); +} + } // namespace raft::spatial::knn diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index 37e0edb6ab..d679971401 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include +#include #include #include #if defined RAFT_NN_COMPILED @@ -86,22 +87,13 @@ class KNNTest : public ::testing::TestWithParam { raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout); #endif - std::vector input_vec; - std::vector sizes_vec; - input_vec.push_back(input_.data()); - 
sizes_vec.push_back(rows_); - - brute_force_knn(handle, - input_vec, - sizes_vec, - cols_, - search_data_.data(), - rows_, - indices_.data(), - distances_.data(), - k_, - true, - true); + auto index = raft::make_device_matrix_view(input_.data(), rows_, cols_); + auto search = raft::make_device_matrix_view(search_data_.data(), rows_, cols_); + + auto indices = raft::make_device_matrix_view(indices_.data(), rows_, k_); + auto distances = raft::make_device_matrix_view(distances_.data(), rows_, k_); + + brute_force_knn(handle, index, search, indices, distances, k_); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data()); From 46b0750531d0b9ffa9ce250ea143e914edc037b3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 14 Sep 2022 17:09:54 -0400 Subject: [PATCH 11/58] Getting knn test to build --- cpp/include/raft/spatial/knn/knn.cuh | 217 ++++++++++++++++++++++----- cpp/test/spatial/knn.cu | 13 +- 2 files changed, 185 insertions(+), 45 deletions(-) diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index c9fdf94fbd..f8c13c9abf 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -16,9 +16,9 @@ #pragma once -#include #include "detail/knn_brute_force_faiss.cuh" #include "detail/selection_faiss.cuh" +#include #include "detail/topk/radix_topk.cuh" #include "detail/topk/warpsort_topk.cuh" @@ -66,6 +66,61 @@ inline void knn_merge_parts(value_t* in_keys, in_keys, in_values, out_keys, out_values, n_samples, n_parts, k, stream, translations); } +/** + * Performs a k-select across row partitioned index/distance + * matrices formatted like the following: + * row1: k0, k1, k2 + * row2: k0, k1, k2 + * row3: k0, k1, k2 + * row1: k0, k1, k2 + * row2: k0, k1, k2 + * row3: k0, k1, k2 + * + * etc... 
+ * + * @tparam idx_t + * @tparam value_t + * @param in_keys + * @param in_values + * @param out_keys + * @param out_values + * @param n_samples + * @param n_parts + * @param k + * @param stream + * @param translations + */ +template +inline void knn_merge_parts( + const raft::handle_t& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + size_t n_samples, + int k, + std::optional> translations = std::nullopt) +{ + RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), + "in_keys and in_values must have the same shape."); + RAFT_EXPECTS( + out_keys.extent(0) == out_values.extent(0) == n_samples, + "Number of rows in output keys and val matrices must equal number of rows in search matrix."); + RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == k, + "Number of columns in output indices and distances matrices must be equal to k"); + + auto n_parts = in_keys.extent(0) / n_samples; + detail::knn_merge_parts(in_keys.data_handle(), + in_values.data_handle(), + out_keys.data_handle(), + out_values.data_handle(), + n_samples, + n_parts, + k, + handle.get_stream(), + translations.value_or(nullptr)); +} + /** Choose an implementation for the select-top-k, */ enum class SelectKAlgo { /** Adapted from the faiss project. Result: sorted (not stable). */ @@ -169,6 +224,77 @@ inline void select_k(const value_t* in_keys, } } +/** + * Select k smallest or largest key/values from each row in the input data. + * + * If you think of the input data `in_keys` as a row-major matrix with input_len columns and + * n_inputs rows, then this function selects k smallest/largest values in each row and fills + * in the row-major matrix `out_keys` of size (n_inputs, k). + * + * Note, depending on the selected algorithm, the values within rows of `out_keys` are not + * necessarily sorted. See the `SelectKAlgo` enumeration for more details. 
+ * + * @tparam idx_t + * the payload type (what is being selected together with the keys). + * @tparam value_t + * the type of the keys (what is being compared). + * + * @param[in] in_keys + * contiguous device array of inputs of size (input_len * n_inputs); + * these are compared and selected. + * @param[in] in_values + * contiguous device array of inputs of size (input_len * n_inputs); + * typically, these are indices of the corresponding in_keys. + * You can pass `NULL` as an argument here; this would imply `in_values` is a homogeneous array + * of indices from `0` to `input_len - 1` for every input and reduce the usage of memory + * bandwidth. + * @param[in] n_inputs + * number of input rows, i.e. the batch size. + * @param[in] input_len + * length of a single input array (row); also sometimes referred as n_cols. + * Invariant: input_len >= k. + * @param[out] out_keys + * contiguous device array of outputs of size (k * n_inputs); + * the k smallest/largest values from each row of the `in_keys`. + * @param[out] out_values + * contiguous device array of outputs of size (k * n_inputs); + * the payload selected together with `out_keys`. + * @param[in] select_min + * whether to select k smallest (true) or largest (false) keys. + * @param[in] k + * the number of outputs to select in each input row. 
+ * @param[in] stream + * @param[in] algo + * the implementation of the algorithm + */ +template +inline void select_k(const raft::handle_t& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + int k, + bool select_min = true, + SelectKAlgo algo = SelectKAlgo::FAISS) +{ + size_t n_inputs = in_keys.extents(0); + size_t input_len = in_keys.extents(1); + + RAFT_EXPECTS(in_keys.extent(0) == in_values.extent(0) && in_keys.extent(1) == in_values.extent(1), + "in_keys and in_values must have the same shape"); + + select_k(in_keys.data_handle(), + in_values.data_handle(), + n_inputs, + input_len, + out_keys.data_handle(), + out_values.data_handle(), + select_min, + k, + handle.get_stream(), + algo); +} + /** * @brief Flat C++ API function to perform a brute force knn on * a series of input arrays and combine the results into a single @@ -233,6 +359,10 @@ void brute_force_knn(raft::handle_t const& handle, * row- or column-major but the output matrices will always be in * row-major format. * + * @example + * + * + * * @param[in] handle the cuml handle to use * @param[in] index vector of device matrices (each size m_i*d) to be used as the knn index * @param[in] search matrix (size n*d) to be used for searching the index @@ -243,51 +373,58 @@ void brute_force_knn(raft::handle_t const& handle, * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. 
*/ -template -void brute_force_knn(raft::handle_t const& handle, - std::vector> index, - raft::device_matrix_view search, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - value_int k, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional> translations = std::nullopt) { +void brute_force_knn( + raft::handle_t const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + value_int k, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional> translations = std::nullopt) +{ + RAFT_EXPECTS(index[0].extent(1) == search.extent(1), + "Number of dimensions for both index and search matrices must be equal"); + RAFT_EXPECTS(indices.extent(0) == distances.extent(0) == search.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + RAFT_EXPECTS(indices.extent(1) == distances.extent(1) == k, + "Number of columns in output indices and distances matrices must be equal to k"); - RAFT_EXPECTS(index.extent(1) == search.extent(1), "Number of dimensions for both index and search matrices must be equal"); - RAFT_EXPECTS(indices.extent(0) == distances.extent(0) == search.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows in search matrix."); - RAFT_EXPECTS(indices.extent(1) == distances.extent(1) == k, - "Number of columns in output indices and distances matrices must be equal to k"); + bool rowMajorIndex = std::is_same_v; + bool rowMajorQuery = std::is_same_v; - constexpr auto rowMajorIndex = std::is_same_v; - constexpr auto rowMajorQuery = std::is_same_v; + std::vector inputs; + std::vector sizes; + for (std::size_t i = 0; i < index.size(); ++i) { + 
inputs.push_back(const_cast(index[i].data_handle())); + sizes.push_back(index[i].extent(0)); + } - std::vector inputs; - std::vector sizes; - for(int i = 0; i < index.size(); ++i) { - inputs.push_back(const_cast(index[i].data_handle())); - sizes.push_back(index[i].extents(0)); - } + std::vector* trans = translations.has_value() ? &(*translations) : nullptr; - detail::brute_force_knn_impl(handle, - input, - sizes, - index.extents(1), - search.data_handle(), - search.extents(1), - indices.data_handle(), - distances.data_handle(), - k, - rowMajorIndex, - rowMajorQuery, - translations.value_or(nullptr), - metric, - metric_arg.value_or(2.0)); + detail::brute_force_knn_impl(handle, + inputs, + sizes, + (value_int)index[0].extent(1), + // TODO: This is unfortunate. Need to fix. + const_cast(search.data_handle()), + (value_int)search.extent(1), + indices.data_handle(), + distances.data_handle(), + k, + rowMajorIndex, + rowMajorQuery, + trans, + metric, + metric_arg.value_or(2.0)); } } // namespace raft::spatial::knn diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index d679971401..0b9572b57f 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -16,8 +16,8 @@ #include "../test_utils.h" -#include #include +#include #include #include #if defined RAFT_NN_COMPILED @@ -87,11 +87,14 @@ class KNNTest : public ::testing::TestWithParam { raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout); #endif - auto index = raft::make_device_matrix_view(input_.data(), rows_, cols_); - auto search = raft::make_device_matrix_view(search_data_.data(), rows_, cols_); + std::vector> index = { + make_device_matrix_view((const T*)(input_.data()), rows_, cols_)}; + auto search = raft::make_device_matrix_view( + (const T*)(search_data_.data()), rows_, cols_); - auto indices = raft::make_device_matrix_view(indices_.data(), rows_, k_); - auto distances = raft::make_device_matrix_view(distances_.data(), rows_, k_); + auto indices = 
raft::make_device_matrix_view(indices_.data(), rows_, k_); + auto distances = + raft::make_device_matrix_view(distances_.data(), rows_, k_); brute_force_knn(handle, index, search, indices, distances, k_); From b6c758c44e6186149f36f2ccabb857353d64b5be Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 14 Sep 2022 19:55:06 -0400 Subject: [PATCH 12/58] Fixing style --- cpp/include/raft/core/detail/device_mdarray.hpp | 2 +- cpp/include/raft/spatial/knn/ivf_flat_types.hpp | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/include/raft/core/detail/device_mdarray.hpp b/cpp/include/raft/core/detail/device_mdarray.hpp index ad387db33d..ff7c31000d 100644 --- a/cpp/include/raft/core/detail/device_mdarray.hpp +++ b/cpp/include/raft/core/detail/device_mdarray.hpp @@ -21,9 +21,9 @@ * limitations under the License. */ #pragma once -#include #include #include +#include #include #include // dynamic_extent diff --git a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp index 5014e6c41e..aabd6edfe3 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp +++ b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp @@ -20,7 +20,6 @@ #include #include -#include #include #include From 105a3a650bbf67d14c48ef3e2799a4f7793ef24e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 14 Sep 2022 20:56:37 -0400 Subject: [PATCH 13/58] Trying to FIND_RAFT_CPP on by default --- python/pylibraft/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index 6567fdff0f..9076626692 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -29,7 +29,7 @@ project( CXX) option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulting to local files" - OFF) + ON) # If the user requested it we attempt to find RAFT. 
if(FIND_RAFT_CPP) From 7eae6e3d16ad20ceab58b0bf97a041aab177d04a Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 15 Sep 2022 13:00:12 -0400 Subject: [PATCH 14/58] Fixing bad merge --- cpp/include/raft/spatial/knn/ivf_flat_types.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp index aabd6edfe3..41fa1dd8ce 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat_types.hpp +++ b/cpp/include/raft/spatial/knn/ivf_flat_types.hpp @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include From 141a2d10e33c25329bf8ed6b839c24817a0693df Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 19 Sep 2022 13:27:50 -0400 Subject: [PATCH 15/58] Fixing knn wrapper --- .../raft/spatial/knn/ball_cover_types.hpp | 1 + cpp/include/raft/spatial/knn/knn.cuh | 18 +++++++++++------- cpp/test/spatial/knn.cu | 6 ++++-- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover_types.hpp b/cpp/include/raft/spatial/knn/ball_cover_types.hpp index 9870217011..337df692bd 100644 --- a/cpp/include/raft/spatial/knn/ball_cover_types.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover_types.hpp @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index ae144fdc18..31e26b6e77 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -369,9 +369,11 @@ void brute_force_knn(raft::handle_t const& handle, * @param[out] indices matrix (size n*k) to store output knn indices * @param[out] distances matrix (size n*k) to store the output knn distance * @param[in] k the number of nearest neighbors to return - * @param[in] metric distance metric to use. Euclidean (L2) is used by default * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. 
This * is ignored if the metric_type is not Minkowski. + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + * @param[in] translations starting offsets for partitions. should be the same size + * as input vector. */ template (k), + "Number of columns in output indices and distances matrices must be equal to k"); bool rowMajorIndex = std::is_same_v; bool rowMajorQuery = std::is_same_v; @@ -413,10 +417,10 @@ void brute_force_knn( detail::brute_force_knn_impl(handle, inputs, sizes, - (value_int)index[0].extent(1), + static_cast(index[0].extent(1)), // TODO: This is unfortunate. Need to fix. const_cast(search.data_handle()), - (value_int)search.extent(1), + static_cast(search.extent(0)), indices.data_handle(), distances.data_handle(), k, @@ -424,7 +428,7 @@ void brute_force_knn( rowMajorQuery, trans, metric, - metric_arg.value_or(2.0)); + metric_arg.value_or(2.0f)); } } // namespace raft::spatial::knn diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index e41b1ba541..744d58b412 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -81,11 +81,11 @@ class KNNTest : public ::testing::TestWithParam { protected: void testBruteForce() { -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) +//#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) raft::print_device_vector("Input array: ", input_.data(), rows_ * cols_, std::cout); std::cout << "K: " << k_ << std::endl; raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout); -#endif +//#endif std::vector> index = { make_device_matrix_view((const T*)(input_.data()), rows_, cols_)}; @@ -96,6 +96,8 @@ class KNNTest : public ::testing::TestWithParam { auto distances = raft::make_device_matrix_view(distances_.data(), rows_, k_); + printf("indices: %ld, distances: %ld, search: %ld\n", (size_t)indices.extent(0), (size_t)(distances.extent(0)), (size_t)search.extent(0)); + brute_force_knn(handle, index, search, indices, distances, k_); build_actual_output<<>>( From 
6cd27af9365421cb9f2e6f74d2616fb2c3108359 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 19 Sep 2022 14:10:21 -0400 Subject: [PATCH 16/58] mdspanidying random ball cover --- cpp/include/raft/spatial/knn/ball_cover.cuh | 109 ++++++++ .../raft/spatial/knn/ball_cover_types.hpp | 56 +++- .../raft/spatial/knn/detail/ball_cover.cuh | 53 ++-- .../knn/detail/ball_cover/registers.cuh | 250 +++++++++--------- cpp/test/spatial/ball_cover.cu | 17 +- cpp/test/spatial/knn.cu | 9 +- 6 files changed, 325 insertions(+), 169 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index a354f6d5a4..e166c7f9ca 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -111,6 +111,55 @@ void rbc_all_knn_query(const raft::handle_t& handle, index.set_index_trained(); } +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * performs an all neighbors knn, which can reuse memory when + * the index and query are the same array. This function will + * build the index and assumes rbc_build_index() has not already + * been called. + * @tparam value_idx knn index type + * @tparam value_t knn distance type + * @tparam value_int type for integers, such as number of rows/cols + * @param handle raft handle for resource management + * @param index ball cover index which has not yet been built + * @param k number of nearest neighbors to find + * @param perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. 
+ * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + */ +template +void rbc_all_knn_query(const raft::handle_t& handle, + BallCoverIndex& index, + value_int k, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + bool perform_post_filtering = true, + float weight = 1.0) +{ + RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in index matrix."); + + rbc_all_knn_query( + handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); +} + /** * Performs a faster exact knn in metric spaces using the triangle * inequality with a number of landmark points to reduce the @@ -180,6 +229,66 @@ void rbc_knn_query(const raft::handle_t& handle, } } +/** + * Performs a faster exact knn in metric spaces using the triangle + * inequality with a number of landmark points to reduce the + * number of distance computations from O(n^2) to O(sqrt(n)). This + * function does not build the index and assumes rbc_build_index() has + * already been called. Use this function when the index and + * query arrays are different, otherwise use rbc_all_knn_query(). 
+ * @tparam value_idx index type + * @tparam value_t distances type + * @tparam value_int integer type for size info + * @param handle raft handle for resource management + * @param index ball cover index which has not yet been built + * @param k number of nearest neighbors to find + * @param query the + * @param perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[out] inds output knn indices + * @param[out] dists output knn distances + * @param weight a weight for overlap between the closest landmark and + * the radius of other landmarks when pruning distances. + * Setting this value below 1 can effectively turn off + * computing distances against many other balls, enabling + * approximate nearest neighbors. Recall can be adjusted + * based on how many relevant balls are ignored. Note that + * many datasets can still have great recall even by only + * looking in the closest landmark. + * @param[in] n_query_pts number of query points + */ +template +void rbc_knn_query(const raft::handle_t& handle, + BallCoverIndex& index, + value_int k, + raft::device_matrix_view query, + raft::device_matrix_view inds, + raft::device_matric_view dists, + bool perform_post_filtering = true, + float weight = 1.0) +{ + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + + RAFT_EXPECTS(query.extent(1) == index.get_R().extent(1), + "Number of columns in query and index matrices must match."); + + rbc_knn_query(handle, + index, + k, + query.data_handle(), + query.extent(0), + inds.data_handle(), + dists.data_handle(), + perform_post_filtering, + weight); +} + // TODO: implement 
functions for: // 4. rbc_eps_neigh() - given a populated index, perform query against different query array // 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data diff --git a/cpp/include/raft/spatial/knn/ball_cover_types.hpp b/cpp/include/raft/spatial/knn/ball_cover_types.hpp index 337df692bd..b05f1ff901 100644 --- a/cpp/include/raft/spatial/knn/ball_cover_types.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover_types.hpp @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -64,13 +65,40 @@ class BallCoverIndex { { } - value_idx* get_R_indptr() { return R_indptr.data(); } - value_idx* get_R_1nn_cols() { return R_1nn_cols.data(); } - value_t* get_R_1nn_dists() { return R_1nn_dists.data(); } - value_t* get_R_radius() { return R_radius.data(); } - value_t* get_R() { return R.data(); } - value_t* get_R_closest_landmark_dists() { return R_closest_landmark_dists.data(); } - const value_t* get_X() { return X; } + explicit BallCoverIndex(const raft::handle_t& handle_, + raft::device_matrix_view X_, + raft::distance::DistanceType metric_) + : handle(handle_), + X(X_.data_handle()), + m(X_.extent(0)), + n(X_.extent(1)), + metric(metric_), + /** + * the sqrt() here makes the sqrt(m)^2 a linear-time lower bound + * + * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) + */ + n_landmarks(sqrt(m_)), + R_indptr(std::move(raft::make_device_vector(handle, sqrt(m_) + 1))), + R_1nn_cols(std::move(raft::make_device_vector(handle, m_))), + R_1nn_dists(std::move(raft::make_device_vector(handle, m_))), + R_closest_landmark_dists(std::move(raft::make_device_vector(handle, m_))), + R(std::move(raft::make_device_matrix(handle, sqrt(m_), n_))), + R_radius(std::move(raft::make_device_vector(handle, sqrt(m_)))), + index_trained(false) + { + } + + raft::device_vector_view get_R_indptr() { return R_indptr.view(); } + raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } + raft::device_vector_view 
get_R_1nn_dists() { return R_1nn_dists.view(); } + raft::device_vector_view get_R_radius() { return R_radius.view(); } + raft::device_matrix_view get_R() { return R.view(); } + raft::device_vector_view get_R_closest_landmark_dists() + { + return R_closest_landmark_dists.view(); + } + const raft::device_matrix_view get_X() { return X; } bool is_index_trained() const { return index_trained; }; @@ -83,20 +111,20 @@ class BallCoverIndex { const value_int n; const value_int n_landmarks; - const value_t* X; + raft::device_matrix_view X; raft::distance::DistanceType metric; private: // CSR storing the neighborhoods for each data point - rmm::device_uvector R_indptr; - rmm::device_uvector R_1nn_cols; - rmm::device_uvector R_1nn_dists; - rmm::device_uvector R_closest_landmark_dists; + raft::device_vector R_indptr; + raft::device_vector R_1nn_cols; + raft::device_vector R_1nn_dists; + raft::device_vector R_closest_landmark_dists; - rmm::device_uvector R_radius; + raft::device_vector R_radius; - rmm::device_uvector R; + raft::device_matrix R; protected: bool index_trained; diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 457e1f495a..e65a895f60 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -75,8 +75,8 @@ void sample_landmarks(const raft::handle_t& handle, rmm::device_uvector R_indices(index.n_landmarks, handle.get_stream()); thrust::sequence(handle.get_thrust_policy(), - index.get_R_1nn_cols(), - index.get_R_1nn_cols() + index.m, + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_cols().data_handle() + index.m, (value_idx)0); thrust::fill( @@ -93,15 +93,15 @@ void sample_landmarks(const raft::handle_t& handle, rng_state, R_indices.data(), R_1nn_cols2.data(), - index.get_R_1nn_cols(), + index.get_R_1nn_cols().data_handle(), R_1nn_ones.data(), (value_idx)index.n_landmarks, (value_idx)index.m); - 
raft::matrix::copyRows(index.get_X(), + raft::matrix::copyRows(index.get_X().data_handle(), index.m, index.n, - index.get_R(), + index.get_R().data_handle(), R_1nn_cols2.data(), index.n_landmarks, handle.get_stream(), @@ -133,7 +133,7 @@ void construct_landmark_1nn(const raft::handle_t& handle, std::numeric_limits::max()); value_idx* R_1nn_inds_ptr = R_1nn_inds.data(); - value_t* R_1nn_dists_ptr = index.get_R_1nn_dists(); + value_t* R_1nn_dists_ptr = index.get_R_1nn_dists().data_handle(); auto idxs = thrust::make_counting_iterator(0); thrust::for_each(handle.get_thrust_policy(), idxs, idxs + index.m, [=] __device__(value_idx i) { @@ -141,16 +141,22 @@ void construct_landmark_1nn(const raft::handle_t& handle, R_1nn_dists_ptr[i] = R_knn_dists_ptr[i * k]; }); - auto keys = - thrust::make_zip_iterator(thrust::make_tuple(R_1nn_inds.data(), index.get_R_1nn_dists())); + auto keys = thrust::make_zip_iterator( + thrust::make_tuple(R_1nn_inds.data(), index.get_R_1nn_dists().data_handle())); // group neighborhoods for each reference landmark and sort each group by distance - thrust::sort_by_key( - handle.get_thrust_policy(), keys, keys + index.m, index.get_R_1nn_cols(), NNComp()); + thrust::sort_by_key(handle.get_thrust_policy(), + keys, + keys + index.m, + index.get_R_1nn_cols().data_handle(), + NNComp()); // convert to CSR for fast lookup - raft::sparse::convert::sorted_coo_to_csr( - R_1nn_inds.data(), index.m, index.get_R_indptr(), index.n_landmarks + 1, handle.get_stream()); + raft::sparse::convert::sorted_coo_to_csr(R_1nn_inds.data(), + index.m, + index.get_R_indptr().data_handle(), + index.n_landmarks + 1, + handle.get_stream()); } /** @@ -175,7 +181,7 @@ void k_closest_landmarks(const raft::handle_t& handle, value_idx* R_knn_inds, value_t* R_knn_dists) { - std::vector input = {index.get_R()}; + std::vector input = {index.get_R().data_handle()}; std::vector sizes = {index.n_landmarks}; brute_force_knn_impl(handle, @@ -207,9 +213,9 @@ void compute_landmark_radii(const 
raft::handle_t& handle, { auto entries = thrust::make_counting_iterator(0); - const value_idx* R_indptr_ptr = index.get_R_indptr(); - const value_t* R_1nn_dists_ptr = index.get_R_1nn_dists(); - value_t* R_radius_ptr = index.get_R_radius(); + const value_idx* R_indptr_ptr = index.get_R_indptr().data_handle(); + const value_t* R_1nn_dists_ptr = index.get_R_1nn_dists().data_handle(); + value_t* R_radius_ptr = index.get_R_radius().data_handle(); thrust::for_each(handle.get_thrust_policy(), entries, entries + index.n_landmarks, @@ -350,8 +356,8 @@ void rbc_build_index(const raft::handle_t& handle, R_knn_inds.end(), std::numeric_limits::max()); thrust::fill(handle.get_thrust_policy(), - index.get_R_closest_landmark_dists(), - index.get_R_closest_landmark_dists() + index.m, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_closest_landmark_dists().data_handle() + index.m, std::numeric_limits::max()); /** @@ -365,11 +371,11 @@ void rbc_build_index(const raft::handle_t& handle, value_int k = 1; k_closest_landmarks(handle, index, - index.get_X(), + index.get_X().data_handle(), index.m, k, R_knn_inds.data(), - index.get_R_closest_landmark_dists()); + index.get_R_closest_landmark_dists().data_handle()); /** * 3. 
Create L_r = knn[:,0].T (CSR) @@ -377,7 +383,8 @@ void rbc_build_index(const raft::handle_t& handle, * Slice closest neighboring R * Secondary sort by (R_knn_inds, R_knn_dists) */ - construct_landmark_1nn(handle, R_knn_inds.data(), index.get_R_closest_landmark_dists(), k, index); + construct_landmark_1nn( + handle, R_knn_inds.data(), index.get_R_closest_landmark_dists().data_handle(), k, index); /** * Compute radius of each R for filtering: p(q, r) <= p(q, q_r) + radius(r) @@ -432,7 +439,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, sample_landmarks(handle, index); k_closest_landmarks( - handle, index, index.get_X(), index.m, k, R_knn_inds.data(), R_knn_dists.data()); + handle, index, index.get_X().data_handle(), index.m, k, R_knn_inds.data(), R_knn_dists.data()); construct_landmark_1nn(handle, R_knn_inds.data(), R_knn_dists.data(), k, index); @@ -440,7 +447,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, perform_rbc_query(handle, index, - index.get_X(), + index.get_X().data_handle(), index.m, k, R_knn_inds.data(), diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 88f5aa3460..751289fca4 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -486,95 +486,95 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, { if (k <= 32) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 64) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), 
query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 128) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 256) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); else if (k <= 512) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), query, index.n, R_knn_inds, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); @@ -587,13 +587,13 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, R_knn_dists, index.m, k, - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + 
index.get_R_1nn_dists().data_handle(), inds, dists, dists_counter, - index.get_R_radius(), + index.get_R_radius().data_handle(), dfunc, weight); } @@ -645,22 +645,22 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, 32, 2, 128, - dims> - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 64) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 128) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + 
index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 256) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 512) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); else if (k <= 1024) compute_final_dists_registers - <<>>(index.get_X(), - query, - index.n, - bitset.data(), - bitset_size, - index.get_R_closest_landmark_dists(), - index.get_R_indptr(), - index.get_R_1nn_cols(), - index.get_R_1nn_dists(), - inds, - dists, - index.n_landmarks, - k, - dfunc, - post_dists_counter); + dims><<>>( + index.get_X().data_handle(), + query, + index.n, + bitset.data(), + bitset_size, + index.get_R_closest_landmark_dists().data_handle(), + 
index.get_R_indptr().data_handle(), + index.get_R_1nn_cols().data_handle(), + index.get_R_1nn_dists().data_handle(), + inds, + dists, + index.n_landmarks, + k, + dfunc, + post_dists_counter); } }; // namespace detail diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 46867f0fa7..8e831a25d9 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -16,6 +16,7 @@ #include "../test_utils.h" #include "spatial_data.h" +#include #include #include #include @@ -200,12 +201,20 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { rmm::device_uvector d_pred_I(params.n_query * k, handle.get_stream()); rmm::device_uvector d_pred_D(params.n_query * k, handle.get_stream()); - BallCoverIndex index( - handle, X.data(), params.n_rows, params.n_cols, metric); + auto X_view = raft::make_device_matrix_view( + handle, X.data(), params.n_rows, params.n_cols); + auto X2_view = raft::make_device_matrix_view( + handle, X2.data(), params.n_query, params.n_cols); + + auto d_pred_I_view = + raft::make_device_matrix_view(handle, params.n_query, k); + auto d_pred_D_view = + raft::make_device_matrix_view(handle, params.n_query, k); + + BallCoverIndex index(handle, X, metric); raft::spatial::knn::rbc_build_index(handle, index); - raft::spatial::knn::rbc_knn_query( - handle, index, k, X2.data(), params.n_query, d_pred_I.data(), d_pred_D.data(), true, weight); + raft::spatial::knn::rbc_knn_query(handle, index, k, X2, d_pred_I, d_pred_D, true, weight); handle.sync_stream(); // What we really want are for the distances to match exactly. 
The diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index 744d58b412..acdf122acf 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -81,11 +81,11 @@ class KNNTest : public ::testing::TestWithParam { protected: void testBruteForce() { -//#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) + //#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) raft::print_device_vector("Input array: ", input_.data(), rows_ * cols_, std::cout); std::cout << "K: " << k_ << std::endl; raft::print_device_vector("Labels array: ", search_labels_.data(), rows_, std::cout); -//#endif + //#endif std::vector> index = { make_device_matrix_view((const T*)(input_.data()), rows_, cols_)}; @@ -96,7 +96,10 @@ class KNNTest : public ::testing::TestWithParam { auto distances = raft::make_device_matrix_view(distances_.data(), rows_, k_); - printf("indices: %ld, distances: %ld, search: %ld\n", (size_t)indices.extent(0), (size_t)(distances.extent(0)), (size_t)search.extent(0)); + printf("indices: %ld, distances: %ld, search: %ld\n", + (size_t)indices.extent(0), + (size_t)(distances.extent(0)), + (size_t)search.extent(0)); brute_force_knn(handle, index, search, indices, distances, k_); From 487055430698cf13ca4924d7a0af57ea16c5eaa8 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Mon, 19 Sep 2022 16:23:52 -0400 Subject: [PATCH 17/58] mdspan-ifying ivf_flat, rbc, and epsilon neighborhoods --- cpp/include/raft/spatial/knn/ball_cover.cuh | 133 ++++++----- .../raft/spatial/knn/ball_cover_types.hpp | 70 +++--- .../knn/detail/ball_cover/registers.cuh | 6 +- .../raft/spatial/knn/epsilon_neighborhood.cuh | 60 ++++- cpp/include/raft/spatial/knn/ivf_flat.cuh | 207 +++++++++++++++++- cpp/test/spatial/ann_ivf_flat.cu | 36 +-- cpp/test/spatial/ball_cover.cu | 57 +++-- 7 files changed, 429 insertions(+), 140 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index e166c7f9ca..16cdc504fd 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -30,16 +30,28 @@ namespace raft { namespace spatial { namespace knn { -template +/** + * Builds and populates a previously unbuilt BallCoverIndex + * @tparam idx_t knn index type + * @tparam value_t knn value type + * @tparam int_t integral type for knn params + * @tparam matrix_idx_t matrix indexing type + * @param handle library resource management handle + * @param index an empty (and not previous built) instance of BallCoverIndex + */ +template void rbc_build_index(const raft::handle_t& handle, - BallCoverIndex& index) + BallCoverIndex& index) { ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { - detail::rbc_build_index(handle, index, detail::HaversineFunc()); + detail::rbc_build_index(handle, index, detail::HaversineFunc()); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { - detail::rbc_build_index(handle, index, detail::EuclideanFunc()); + detail::rbc_build_index(handle, index, detail::EuclideanFunc()); } else { RAFT_FAIL("Metric not support"); } @@ -55,9 +67,9 @@ void rbc_build_index(const 
raft::handle_t& handle, * the index and query are the same array. This function will * build the index and assumes rbc_build_index() has not already * been called. - * @tparam value_idx knn index type + * @tparam idx_t knn index type * @tparam value_t knn distance type - * @tparam value_int type for integers, such as number of rows/cols + * @tparam int_t type for integers, such as number of rows/cols * @param handle raft handle for resource management * @param index ball cover index which has not yet been built * @param k number of nearest neighbors to find @@ -75,11 +87,11 @@ void rbc_build_index(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. */ -template +template void rbc_all_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, - value_int k, - value_idx* inds, + BallCoverIndex& index, + int_t k, + idx_t* inds, value_t* dists, bool perform_post_filtering = true, float weight = 1.0) @@ -91,7 +103,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, k, inds, dists, - detail::HaversineFunc(), + detail::HaversineFunc(), perform_post_filtering, weight); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -101,7 +113,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, k, inds, dists, - detail::EuclideanFunc(), + detail::EuclideanFunc(), perform_post_filtering, weight); } else { @@ -119,18 +131,19 @@ void rbc_all_knn_query(const raft::handle_t& handle, * the index and query are the same array. This function will * build the index and assumes rbc_build_index() has not already * been called. 
- * @tparam value_idx knn index type + * @tparam idx_t knn index type * @tparam value_t knn distance type - * @tparam value_int type for integers, such as number of rows/cols - * @param handle raft handle for resource management - * @param index ball cover index which has not yet been built - * @param k number of nearest neighbors to find - * @param perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). + * @tparam int_t type for integers, such as number of rows/cols + * @tparam matrix_idx_t matrix indexing type + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has not yet been built * @param[out] inds output knn indices * @param[out] dists output knn distances - * @param weight a weight for overlap between the closest landmark and + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[in] weight a weight for overlap between the closest landmark and * the radius of other landmarks when pruning distances. * Setting this value below 1 can effectively turn off * computing distances against many other balls, enabling @@ -139,17 +152,22 @@ void rbc_all_knn_query(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. 
*/ -template +template void rbc_all_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, - value_int k, - raft::device_matrix_view inds, - raft::device_matrix_view dists, + BallCoverIndex& index, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k = 5, bool perform_post_filtering = true, float weight = 1.0) { RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + RAFT_EXPECTS(k <= index.m, + "k must be less than or equal to the number of data points in the index"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), "Number of columns in output indices and distances matrices must be equal to k"); RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0), @@ -167,9 +185,9 @@ void rbc_all_knn_query(const raft::handle_t& handle, * function does not build the index and assumes rbc_build_index() has * already been called. Use this function when the index and * query arrays are different, otherwise use rbc_all_knn_query(). - * @tparam value_idx index type + * @tparam idx_t index type * @tparam value_t distances type - * @tparam value_int integer type for size info + * @tparam int_t integer type for size info * @param handle raft handle for resource management * @param index ball cover index which has not yet been built * @param k number of nearest neighbors to find @@ -179,7 +197,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, * results). * @param[out] inds output knn indices * @param[out] dists output knn distances - * @param weight a weight for overlap between the closest landmark and + * @param[in] weight a weight for overlap between the closest landmark and * the radius of other landmarks when pruning distances. 
* Setting this value below 1 can effectively turn off * computing distances against many other balls, enabling @@ -189,13 +207,13 @@ void rbc_all_knn_query(const raft::handle_t& handle, * looking in the closest landmark. * @param[in] n_query_pts number of query points */ -template +template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, - value_int k, + BallCoverIndex& index, + int_t k, const value_t* query, - value_int n_query_pts, - value_idx* inds, + int_t n_query_pts, + idx_t* inds, value_t* dists, bool perform_post_filtering = true, float weight = 1.0) @@ -209,7 +227,7 @@ void rbc_knn_query(const raft::handle_t& handle, n_query_pts, inds, dists, - detail::HaversineFunc(), + detail::HaversineFunc(), perform_post_filtering, weight); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -221,7 +239,7 @@ void rbc_knn_query(const raft::handle_t& handle, n_query_pts, inds, dists, - detail::EuclideanFunc(), + detail::EuclideanFunc(), perform_post_filtering, weight); } else { @@ -236,19 +254,20 @@ void rbc_knn_query(const raft::handle_t& handle, * function does not build the index and assumes rbc_build_index() has * already been called. Use this function when the index and * query arrays are different, otherwise use rbc_all_knn_query(). - * @tparam value_idx index type + * @tparam idx_t index type * @tparam value_t distances type - * @tparam value_int integer type for size info - * @param handle raft handle for resource management - * @param index ball cover index which has not yet been built - * @param k number of nearest neighbors to find - * @param query the - * @param perform_post_filtering if this is false, only the closest k landmarks - * are considered (which will return approximate - * results). 
+ * @tparam int_t integer type for size info + * @tparam matrix_idx_t + * @param[in] handle raft handle for resource management + * @param[in] index ball cover index which has not yet been built + * @param[in] query device matrix containing query data points * @param[out] inds output knn indices * @param[out] dists output knn distances - * @param weight a weight for overlap between the closest landmark and + * @param[in] k number of nearest neighbors to find + * @param[in] perform_post_filtering if this is false, only the closest k landmarks + * are considered (which will return approximate + * results). + * @param[in] weight a weight for overlap between the closest landmark and * the radius of other landmarks when pruning distances. * Setting this value below 1 can effectively turn off * computing distances against many other balls, enabling @@ -256,19 +275,23 @@ void rbc_knn_query(const raft::handle_t& handle, * based on how many relevant balls are ignored. Note that * many datasets can still have great recall even by only * looking in the closest landmark. 
- * @param[in] n_query_pts number of query points */ -template +template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, - value_int k, - raft::device_matrix_view query, - raft::device_matrix_view inds, - raft::device_matric_view dists, + BallCoverIndex& index, + raft::device_matrix_view query, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k = 5, bool perform_post_filtering = true, float weight = 1.0) { - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + RAFT_EXPECTS(k <= index.m, + "k must be less than or equal to the number of data points in the index"); + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), "Number of columns in output indices and distances matrices must be equal to k"); RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0), diff --git a/cpp/include/raft/spatial/knn/ball_cover_types.hpp b/cpp/include/raft/spatial/knn/ball_cover_types.hpp index b05f1ff901..47c9397fdd 100644 --- a/cpp/include/raft/spatial/knn/ball_cover_types.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover_types.hpp @@ -36,7 +36,10 @@ namespace knn { * @tparam value_t * @tparam value_int */ -template +template class BallCoverIndex { public: explicit BallCoverIndex(const raft::handle_t& handle_, @@ -45,7 +48,7 @@ class BallCoverIndex { value_int n_, raft::distance::DistanceType metric_) : handle(handle_), - X(X_), + X(std::move(raft::make_device_matrix_view(X_, m_, n_))), m(m_), n(n_), metric(metric_), @@ -55,21 +58,22 @@ class BallCoverIndex { * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) */ n_landmarks(sqrt(m_)), - R_indptr(sqrt(m_) + 1, handle.get_stream()), - R_1nn_cols(m_, handle.get_stream()), - R_1nn_dists(m_, handle.get_stream()), - R_closest_landmark_dists(m_, handle.get_stream()), - R(sqrt(m_) * n_, handle.get_stream()), - R_radius(sqrt(m_), handle.get_stream()), + 
R_indptr(std::move(raft::make_device_vector(handle, sqrt(m_) + 1))), + R_1nn_cols(std::move(raft::make_device_vector(handle, m_))), + R_1nn_dists(std::move(raft::make_device_vector(handle, m_))), + R_closest_landmark_dists( + std::move(raft::make_device_vector(handle, m_))), + R(std::move(raft::make_device_matrix(handle, sqrt(m_), n_))), + R_radius(std::move(raft::make_device_vector(handle, sqrt(m_)))), index_trained(false) { } explicit BallCoverIndex(const raft::handle_t& handle_, - raft::device_matrix_view X_, + raft::device_matrix_view X_, raft::distance::DistanceType metric_) : handle(handle_), - X(X_.data_handle()), + X(X_), m(X_.extent(0)), n(X_.extent(1)), metric(metric_), @@ -78,27 +82,31 @@ class BallCoverIndex { * * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) */ - n_landmarks(sqrt(m_)), - R_indptr(std::move(raft::make_device_vector(handle, sqrt(m_) + 1))), - R_1nn_cols(std::move(raft::make_device_vector(handle, m_))), - R_1nn_dists(std::move(raft::make_device_vector(handle, m_))), - R_closest_landmark_dists(std::move(raft::make_device_vector(handle, m_))), - R(std::move(raft::make_device_matrix(handle, sqrt(m_), n_))), - R_radius(std::move(raft::make_device_vector(handle, sqrt(m_)))), + n_landmarks(sqrt(X_.extent(0))), + R_indptr( + std::move(raft::make_device_vector(handle, sqrt(X_.extent(0)) + 1))), + R_1nn_cols(std::move(raft::make_device_vector(handle, X_.extent(0)))), + R_1nn_dists(std::move(raft::make_device_vector(handle, X_.extent(0)))), + R_closest_landmark_dists( + std::move(raft::make_device_vector(handle, X_.extent(0)))), + R(std::move( + raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1)))), + R_radius( + std::move(raft::make_device_vector(handle, sqrt(X_.extent(0))))), index_trained(false) { } - raft::device_vector_view get_R_indptr() { return R_indptr.view(); } - raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } - raft::device_vector_view get_R_1nn_dists() { return 
R_1nn_dists.view(); } - raft::device_vector_view get_R_radius() { return R_radius.view(); } - raft::device_matrix_view get_R() { return R.view(); } - raft::device_vector_view get_R_closest_landmark_dists() + raft::device_vector_view get_R_indptr() { return R_indptr.view(); } + raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } + raft::device_vector_view get_R_1nn_dists() { return R_1nn_dists.view(); } + raft::device_vector_view get_R_radius() { return R_radius.view(); } + raft::device_matrix_view get_R() { return R.view(); } + raft::device_vector_view get_R_closest_landmark_dists() { return R_closest_landmark_dists.view(); } - const raft::device_matrix_view get_X() { return X; } + raft::device_matrix_view get_X() { return X; } bool is_index_trained() const { return index_trained; }; @@ -111,20 +119,20 @@ class BallCoverIndex { const value_int n; const value_int n_landmarks; - raft::device_matrix_view X; + raft::device_matrix_view X; raft::distance::DistanceType metric; private: // CSR storing the neighborhoods for each data point - raft::device_vector R_indptr; - raft::device_vector R_1nn_cols; - raft::device_vector R_1nn_dists; - raft::device_vector R_closest_landmark_dists; + raft::device_vector R_indptr; + raft::device_vector R_1nn_cols; + raft::device_vector R_1nn_dists; + raft::device_vector R_closest_landmark_dists; - raft::device_vector R_radius; + raft::device_vector R_radius; - raft::device_matrix R; + raft::device_matrix R; protected: bool index_trained; diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index 751289fca4..c0056e7137 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -580,7 +580,7 @@ void rbc_low_dim_pass_one(const raft::handle_t& handle, else if (k <= 1024) block_rbc_kernel_registers - <<>>(index.get_X(), + <<>>(index.get_X().data_handle(), 
query, index.n, R_knn_inds, @@ -627,8 +627,8 @@ void rbc_low_dim_pass_two(const raft::handle_t& handle, index.n, R_knn_inds, R_knn_dists, - index.get_R_radius(), - index.get_R(), + index.get_R_radius().data_handle(), + index.get_R().data_handle(), index.n_landmarks, bitset_size, k, diff --git a/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh index 29ed51fb3d..dce5f0f99d 100644 --- a/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh @@ -19,6 +19,7 @@ #pragma once +#include #include namespace raft { @@ -28,8 +29,8 @@ namespace knn { /** * @brief Computes epsilon neighborhood for the L2-Squared distance metric * - * @tparam DataT IO and math type - * @tparam IdxT Index type + * @tparam value_t IO and math type + * @tparam idx_t Index type * * @param[out] adj adjacency matrix [row-major] [on device] [dim = m x n] * @param[out] vd vertex degree array [on device] [len = m + 1] @@ -44,19 +45,56 @@ namespace knn { * squared as we compute L2-squared distance in this method) * @param[in] stream cuda stream */ -template +template void epsUnexpL2SqNeighborhood(bool* adj, - IdxT* vd, - const DataT* x, - const DataT* y, - IdxT m, - IdxT n, - IdxT k, - DataT eps, + idx_t* vd, + const value_t* x, + const value_t* y, + idx_t m, + idx_t n, + idx_t k, + value_t eps, cudaStream_t stream) { - detail::epsUnexpL2SqNeighborhood(adj, vd, x, y, m, n, k, eps, stream); + detail::epsUnexpL2SqNeighborhood(adj, vd, x, y, m, n, k, eps, stream); } + +/** + * @brief Computes epsilon neighborhood for the L2-Squared distance metric + * + * @tparam value_t IO and math type + * @tparam idx_t Index type + * @tparam matrix_idx_t matrix indexing type + * + * @param[in] handle raft handle to manage library resources + * @param[in] x first matrix [row-major] [on device] [dim = m x k] + * @param[in] y second matrix [row-major] [on device] [dim = n x k] + * @param[out] adj adjacency 
matrix [row-major] [on device] [dim = m x n] + * @param[out] vd vertex degree array [on device] [len = m + 1] + * `vd + m` stores the total number of edges in the adjacency + * matrix. Pass a nullptr if you don't need this info. + * @param[in] eps defines epsilon neighborhood radius (should be passed as + * squared as we compute L2-squared distance in this method) + */ +template +void eps_neighbors_l2sq(const raft::handle_t& handle, + raft::device_matrix_view x, + raft::device_matrix_view y, + raft::device_matrix_view adj, + raft::device_vector_view vd, + value_t eps) +{ + epsUnexpL2SqNeighborhood(adj.data_handle(), + vd.data_handle(), + x.data_handle(), + y.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + eps, + handle.get_stream()); +} + } // namespace knn } // namespace spatial } // namespace raft diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index 09bd8edd85..010586579c 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -22,6 +22,7 @@ #include +#include #include #include @@ -67,6 +68,53 @@ inline auto build( return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); } +/** + * @brief Build the index from the dataset for efficient search. 
+ * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * + * Usage example: + * @code{.cpp} + * using namespace raft::spatial::knn; + * // use default index parameters + * ivf_flat::index_params index_params; + * // create and fill the index from a [N, D] dataset + * auto index = ivf_flat::build(handle, index_params, dataset, N, D); + * // use default search parameters + * ivf_flat::search_params search_params; + * // search K nearest neighbours for each of the N queries + * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param handle + * @param params configure the index building + * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] + * + * @return the constructed ivf-flat index + */ +template +auto build_index(const handle_t& handle, + raft::device_matrix_view dataset, + const index_params& params) -> index +{ + return raft::spatial::knn::ivf_flat::detail::build(handle, + params, + dataset.data_handle(), + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); +} + /** * @brief Build a new index containing the data of the original plus new extra vectors. * @@ -95,7 +143,6 @@ inline auto build( * @param[in] new_indices a device pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. 
- * @param n_rows the number of samples * * @return the constructed extended ivf-flat index */ @@ -110,6 +157,57 @@ inline auto extend(const handle_t& handle, handle, orig_index, new_vectors, new_indices, n_rows); } +/** + * @brief Build a new index containing the data of the original plus new extra vectors. + * + * Implementation note: + * The new data is clustered according to existing kmeans clusters, then the cluster + * centers are adjusted to match the newly labeled data. + * + * Usage example: + * @code{.cpp} + * using namespace raft::spatial::knn; + * ivf_flat::index_params index_params; + * index_params.add_data_on_build = false; // don't populate index on build + * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training + * // train the index from a [N, D] dataset + * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); + * // fill the index with the data + * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); + * @endcode + * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param handle + * @param orig_index original index + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` + * here to imply a continuous range `[0...n_rows)`. + * + * @return the constructed extended ivf-flat index + */ +template +auto extend(const handle_t& handle, + const index& orig_index, + raft::device_matrix_view new_vectors, + std::optional> new_indices = + std::nullopt) -> index +{ + return raft::spatial::knn::ivf_flat::detail::extend( + handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? 
new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); +} + /** * @brief Extend the index with the new data. * * @@ -134,6 +232,38 @@ inline void extend(const handle_t& handle, *index = extend(handle, *index, new_vectors, new_indices, n_rows); } +/** + * @brief Extend the index with the new data. + * * + * @tparam value_t data element type + * @tparam idx_t type of the indices in the source dataset + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param handle + * @param[inout] index + * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] + * @param[in] new_indices a device pointer to a vector of indices [n_rows]. + * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` + * here to imply a continuous range `[0...n_rows)`. + */ +template +void extend( + const handle_t& handle, + index* index, + raft::device_matrix_view new_vectors, + std::optional> new_indices = std::nullopt) +{ + *index = extend(handle, + *index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + static_cast(new_vectors.extent(0))); +} + /** * @brief Search ANN using the constructed index. * @@ -191,4 +321,79 @@ inline void search(const handle_t& handle, handle, params, index, queries, n_queries, k, neighbors, distances, mr); } +/** + * @brief Search ANN using the constructed index. + * + * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. + * + * Note, this function requires a temporary buffer to store intermediate results between cuda kernel + * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can + * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or + * eliminate entirely allocations happening within `search`: + * @code{.cpp} + * ... 
+ * // Create a pooling memory resource with a pre-defined initial size. + * rmm::mr::pool_memory_resource mr( + * rmm::mr::get_current_device_resource(), 1024 * 1024); + * // use default search parameters + * ivf_flat::search_params search_params; + * // Use the same allocator across multiple searches to reduce the number of + * // cuda memory allocations + * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); + * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); + * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); + * ... + * @endcode + * The exact size of the temporary buffer depends on multiple factors and is an implementation + * detail. However, you can safely specify a small initial size for the memory pool, so that only a + * few allocations happen to grow it during the first invocations of the `search`. + * + * @tparam value_t data element type + * @tparam idx_t type of the indices + * @tparam int_t precision / type of integral arguments + * @tparam matrix_idx_t matrix indexing type + * + * @param handle + * @param index ivf-flat constructed index + * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] + * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset + * [n_queries, k] + * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] + * @param params configure the search + * @param k the number of neighbors to find for each query. 
+ */ +template +void search(const handle_t& handle, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + const search_params& params, + int_t k) +{ + RAFT_EXPECTS( + queries.extent(0) == neigbors.extent(0) && queries.extent(0) == distances.extent(0), + "Number of rows in output neighbors and distances matrices must equal the number of queries."); + + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1) && + neighbors.extent(1) == static_cast(k), + "Number of columns in output neighbors and distances matrices must equal k"); + + RAFT_EXPECTS(queries.extent(1) == index.dim(), + "Number of query dimensions should equal number of dimensions in the index."); + + return raft::spatial::knn::ivf_flat::detail::search(handle, + params, + index, + queries.data_handle(), + queries.extent(0), + k, + neighbors.data_handle(), + distances.data_handle(), + nullptr); +} + } // namespace raft::spatial::knn::ivf_flat diff --git a/cpp/test/spatial/ann_ivf_flat.cu b/cpp/test/spatial/ann_ivf_flat.cu index a049c3f428..7619218515 100644 --- a/cpp/test/spatial/ann_ivf_flat.cu +++ b/cpp/test/spatial/ann_ivf_flat.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include "./ann_base_kernel.cuh" +#include #include #include #include @@ -109,7 +110,7 @@ auto eval_knn(const std::vector& expected_idx, return testing::AssertionSuccess(); } -template +template class AnnIVFFlatTest : public ::testing::TestWithParam { public: AnnIVFFlatTest() @@ -124,14 +125,14 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { void testIVFFlat() { size_t queries_size = ps.num_queries * ps.k; - std::vector indices_ivfflat(queries_size); - std::vector indices_naive(queries_size); + std::vector indices_ivfflat(queries_size); + std::vector indices_naive(queries_size); std::vector distances_ivfflat(queries_size); std::vector distances_naive(queries_size); { rmm::device_uvector distances_naive_dev(queries_size, stream_); - 
rmm::device_uvector indices_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); using acc_t = typename detail::utils::config::value_t; naiveBfKnn(distances_naive_dev.data(), indices_naive_dev.data(), @@ -155,7 +156,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); rmm::device_uvector distances_ivfflat_dev(queries_size, stream_); - rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); + rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); { // legacy interface @@ -206,10 +207,13 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 0.5; - auto index = - ivf_flat::build(handle_, index_params, database.data(), int64_t(ps.num_db_vecs), ps.dim); - rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + + auto index = ivf_flat::build_index(handle_, database_view, index_params); + + rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), thrust::device_pointer_cast(vector_indices.data()), thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); @@ -217,14 +221,16 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { int64_t half_of_data = ps.num_db_vecs / 2; - auto index_2 = - ivf_flat::extend(handle_, index, database.data(), nullptr, half_of_data); + auto half_of_data_view = raft::make_device_matrix_view( + (const DataT*)database.data(), static_cast(half_of_data), ps.dim); + + auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); - ivf_flat::extend(handle_, - &index_2, - database.data() + half_of_data * ps.dim, - vector_indices.data() + half_of_data, - int64_t(ps.num_db_vecs) - half_of_data); + ivf_flat::extend(handle_, + &index_2, + 
database.data() + half_of_data * ps.dim, + vector_indices.data() + half_of_data, + int64_t(ps.num_db_vecs) - half_of_data); ivf_flat::search(handle_, search_params, diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 8e831a25d9..ffd57a6b42 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -139,21 +139,22 @@ struct ToRadians { __device__ __host__ float operator()(float a) { return a * (CUDART_PI_F / 180.0); } }; +template struct BallCoverInputs { - uint32_t k; - uint32_t n_rows; - uint32_t n_cols; + value_int k; + value_int n_rows; + value_int n_cols; float weight; - uint32_t n_query; + value_int n_query; raft::distance::DistanceType metric; }; -template -class BallCoverKNNQueryTest : public ::testing::TestWithParam { +template +class BallCoverKNNQueryTest : public ::testing::TestWithParam> { protected: void basicTest() { - params = ::testing::TestWithParam::GetParam(); + params = ::testing::TestWithParam>::GetParam(); raft::handle_t handle; uint32_t k = params.k; @@ -201,20 +202,21 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { rmm::device_uvector d_pred_I(params.n_query * k, handle.get_stream()); rmm::device_uvector d_pred_D(params.n_query * k, handle.get_stream()); - auto X_view = raft::make_device_matrix_view( - handle, X.data(), params.n_rows, params.n_cols); - auto X2_view = raft::make_device_matrix_view( - handle, X2.data(), params.n_query, params.n_cols); + auto X_view = + raft::make_device_matrix_view(X.data(), params.n_rows, params.n_cols); + auto X2_view = raft::make_device_matrix_view( + (const value_t*)X2.data(), params.n_query, params.n_cols); auto d_pred_I_view = - raft::make_device_matrix_view(handle, params.n_query, k); + raft::make_device_matrix_view(d_pred_I.data(), params.n_query, k); auto d_pred_D_view = - raft::make_device_matrix_view(handle, params.n_query, k); + raft::make_device_matrix_view(d_pred_D.data(), params.n_query, k); - BallCoverIndex index(handle, 
X, metric); + BallCoverIndex index(handle, X_view, metric); raft::spatial::knn::rbc_build_index(handle, index); - raft::spatial::knn::rbc_knn_query(handle, index, k, X2, d_pred_I, d_pred_D, true, weight); + raft::spatial::knn::rbc_knn_query( + handle, index, X2_view, d_pred_I_view, d_pred_D_view, k, true, weight); handle.sync_stream(); // What we really want are for the distances to match exactly. The @@ -245,15 +247,15 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam { protected: uint32_t d = 2; - BallCoverInputs params; + BallCoverInputs params; }; -template -class BallCoverAllKNNTest : public ::testing::TestWithParam { +template +class BallCoverAllKNNTest : public ::testing::TestWithParam> { protected: void basicTest() { - params = ::testing::TestWithParam::GetParam(); + params = ::testing::TestWithParam>::GetParam(); raft::handle_t handle; uint32_t k = params.k; @@ -270,6 +272,9 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { rmm::device_uvector d_ref_I(params.n_rows * k, handle.get_stream()); rmm::device_uvector d_ref_D(params.n_rows * k, handle.get_stream()); + auto X_view = raft::make_device_matrix_view( + (const value_t*)X.data(), params.n_rows, params.n_cols); + if (metric == raft::distance::DistanceType::Haversine) { thrust::transform( handle.get_thrust_policy(), X.data(), X.data() + X.size(), X.data(), ToRadians()); @@ -292,11 +297,15 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { rmm::device_uvector d_pred_I(params.n_rows * k, handle.get_stream()); rmm::device_uvector d_pred_D(params.n_rows * k, handle.get_stream()); - BallCoverIndex index( - handle, X.data(), params.n_rows, params.n_cols, metric); + auto d_pred_I_view = + raft::make_device_matrix_view(d_pred_I.data(), params.n_query, k); + auto d_pred_D_view = + raft::make_device_matrix_view(d_pred_D.data(), params.n_query, k); + + BallCoverIndex index(handle, X_view, metric); raft::spatial::knn::rbc_all_knn_query( - handle, index, k, d_pred_I.data(), 
d_pred_D.data(), true, weight); + handle, index, d_pred_I_view, d_pred_D_view, k, true, weight); handle.sync_stream(); // What we really want are for the distances to match exactly. The @@ -330,13 +339,13 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam { void TearDown() override {} protected: - BallCoverInputs params; + BallCoverInputs params; }; typedef BallCoverAllKNNTest BallCoverAllKNNTestF; typedef BallCoverKNNQueryTest BallCoverKNNQueryTestF; -const std::vector ballcover_inputs = { +const std::vector> ballcover_inputs = { {11, 5000, 2, 1.0, 10000, raft::distance::DistanceType::Haversine}, {25, 10000, 2, 1.0, 5000, raft::distance::DistanceType::Haversine}, {2, 10000, 2, 1.0, 5000, raft::distance::DistanceType::L2SqrtUnexpanded}, From 428a9e692b615c8b79ac9989b117d7ab2560147d Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 19 Sep 2022 16:53:47 -0400 Subject: [PATCH 18/58] Fixing last compile error --- cpp/include/raft/spatial/knn/ivf_flat.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index 010586579c..e0058e93cf 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -375,7 +375,7 @@ void search(const handle_t& handle, int_t k) { RAFT_EXPECTS( - queries.extent(0) == neigbors.extent(0) && queries.extent(0) == distances.extent(0), + queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), "Number of rows in output neighbors and distances matrices must equal the number of queries."); RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1) && From a612bc26541e7b21a970019e53002698fb42356f Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Tue, 20 Sep 2022 11:47:41 -0400 Subject: [PATCH 19/58] Updating ball cover specializations and API --- cpp/include/raft/spatial/knn/ball_cover.cuh | 9 +++-- .../knn/specializations/ball_cover.cuh | 15 +++++---- cpp/src/nn/specializations/ball_cover.cu | 13 ++++---- cpp/test/spatial/ball_cover.cu | 4 +-- cpp/test/spatial/epsilon_neighborhood.cu | 33 ++++++++++++------- cpp/test/spatial/knn.cu | 5 --- 6 files changed, 44 insertions(+), 35 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index 16cdc504fd..ad9658872f 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -87,9 +87,12 @@ void rbc_build_index(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. */ -template +template void rbc_all_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, int_t k, idx_t* inds, value_t* dists, @@ -157,7 +160,7 @@ template void rbc_all_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, raft::device_matrix_view inds, raft::device_matrix_view dists, int_t k = 5, diff --git a/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh b/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh index 0c35bf4b9c..c859f2c5ec 100644 --- a/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh @@ -25,15 +25,16 @@ namespace raft { namespace spatial { namespace knn { -extern template class BallCoverIndex; -extern template class BallCoverIndex; +extern template class BallCoverIndex; +extern template class BallCoverIndex; -extern template void rbc_build_index( - const raft::handle_t& handle, BallCoverIndex& index); +extern template void rbc_build_index( + const raft::handle_t& handle, + BallCoverIndex& index); extern template void rbc_knn_query( 
const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, std::uint32_t k, const float* query, std::uint32_t n_query_pts, @@ -42,9 +43,9 @@ extern template void rbc_knn_query( bool perform_post_filtering, float weight); -extern template void rbc_all_knn_query( +extern template void rbc_all_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, std::uint32_t k, std::int64_t* inds, float* dists, diff --git a/cpp/src/nn/specializations/ball_cover.cu b/cpp/src/nn/specializations/ball_cover.cu index 87796752d9..7473b65d25 100644 --- a/cpp/src/nn/specializations/ball_cover.cu +++ b/cpp/src/nn/specializations/ball_cover.cu @@ -28,15 +28,16 @@ namespace raft { namespace spatial { namespace knn { -template class BallCoverIndex; -template class BallCoverIndex; +template class BallCoverIndex; +template class BallCoverIndex; -template void rbc_build_index( - const raft::handle_t& handle, BallCoverIndex& index); +template void rbc_build_index( + const raft::handle_t& handle, + BallCoverIndex& index); template void rbc_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, std::uint32_t k, const float* query, std::uint32_t n_query_pts, @@ -47,7 +48,7 @@ template void rbc_knn_query( template void rbc_all_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + BallCoverIndex& index, std::uint32_t k, std::int64_t* inds, float* dists, diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index ffd57a6b42..15f6b5fa87 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -298,9 +298,9 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam d_pred_D(params.n_rows * k, handle.get_stream()); auto d_pred_I_view = - raft::make_device_matrix_view(d_pred_I.data(), params.n_query, k); + raft::make_device_matrix_view(d_pred_I.data(), params.n_rows, k); auto d_pred_D_view = - raft::make_device_matrix_view(d_pred_D.data(), 
params.n_query, k); + raft::make_device_matrix_view(d_pred_D.data(), params.n_rows, k); BallCoverIndex index(handle, X_view, metric); diff --git a/cpp/test/spatial/epsilon_neighborhood.cu b/cpp/test/spatial/epsilon_neighborhood.cu index 515636ad8c..18f24e1800 100644 --- a/cpp/test/spatial/epsilon_neighborhood.cu +++ b/cpp/test/spatial/epsilon_neighborhood.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include #include +#include #include #include #include @@ -40,12 +41,18 @@ template template class EpsNeighTest : public ::testing::TestWithParam> { protected: - EpsNeighTest() : data(0, stream), adj(0, stream), labels(0, stream), vd(0, stream) {} + EpsNeighTest() + : data(0, handle.get_stream()), + adj(0, handle.get_stream()), + labels(0, handle.get_stream()), + vd(0, handle.get_stream()) + { + } void SetUp() override { - param = ::testing::TestWithParam>::GetParam(); - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + auto stream = handle.get_stream(); + param = ::testing::TestWithParam>::GetParam(); data.resize(param.n_row * param.n_col, stream); labels.resize(param.n_row, stream); batchSize = param.n_row / param.n_batches; @@ -73,6 +80,7 @@ class EpsNeighTest : public ::testing::TestWithParam> { rmm::device_uvector adj; rmm::device_uvector labels, vd; IdxT batchSize; + const raft::handle_t handle; }; // class EpsNeighTest const std::vector> inputsfi = { @@ -93,15 +101,16 @@ TEST_P(EpsNeighTestFI, Result) for (int i = 0; i < param.n_batches; ++i) { RAFT_CUDA_TRY(cudaMemsetAsync(adj.data(), 0, sizeof(bool) * param.n_row * batchSize, stream)); RAFT_CUDA_TRY(cudaMemsetAsync(vd.data(), 0, sizeof(int) * (batchSize + 1), stream)); - epsUnexpL2SqNeighborhood(adj.data(), - vd.data(), - data.data(), - data.data() + (i * batchSize * param.n_col), - param.n_row, - batchSize, - param.n_col, - param.eps * param.eps, - stream); + + auto adj_view = make_device_matrix_view(adj.data(), param.n_row, batchSize); + auto vd_view = make_device_vector_view(vd.data(), batchSize + 1); + auto 
x_view = make_device_matrix_view(data.data(), param.n_row, param.n_col); + auto y_view = make_device_matrix_view( + data.data() + (i * batchSize * param.n_col), batchSize, param.n_col); + + eps_neighbors_l2sq( + handle, x_view, y_view, adj_view, vd_view, param.eps * param.eps); + ASSERT_TRUE(raft::devArrMatch( param.n_row / param.n_centers, vd.data(), batchSize, raft::Compare(), stream)); } diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index acdf122acf..9c22a0cb73 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -96,11 +96,6 @@ class KNNTest : public ::testing::TestWithParam { auto distances = raft::make_device_matrix_view(distances_.data(), rows_, k_); - printf("indices: %ld, distances: %ld, search: %ld\n", - (size_t)indices.extent(0), - (size_t)(distances.extent(0)), - (size_t)search.extent(0)); - brute_force_knn(handle, index, search, indices, distances, k_); build_actual_output<<>>( From e63d121e9ff85337c63c12a6eb9d916e9b0114e0 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 20 Sep 2022 15:23:31 -0400 Subject: [PATCH 20/58] Removing stream destroy from eps neigh tests --- cpp/test/spatial/epsilon_neighborhood.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/test/spatial/epsilon_neighborhood.cu b/cpp/test/spatial/epsilon_neighborhood.cu index 18f24e1800..c83817f6f8 100644 --- a/cpp/test/spatial/epsilon_neighborhood.cu +++ b/cpp/test/spatial/epsilon_neighborhood.cu @@ -72,8 +72,6 @@ class EpsNeighTest : public ::testing::TestWithParam> { false); } - void TearDown() override { RAFT_CUDA_TRY(cudaStreamDestroy(stream)); } - EpsInputs param; cudaStream_t stream = 0; rmm::device_uvector data; From 44f7acafffe4c8e32fe9b9e7734103fa19461874 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Tue, 20 Sep 2022 16:32:47 -0400 Subject: [PATCH 21/58] Starting on col wise sort --- cpp/include/raft/matrix/col_wise_sort.cuh | 36 +++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index afdec24ebd..d6265a212e 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -18,6 +18,7 @@ #pragma once +#include #include namespace raft { @@ -50,6 +51,41 @@ void sort_cols_per_row(const InType* in, detail::sortColumnsPerRow( in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys); } + + + +/** + * @brief sort columns within each row of row-major input matrix and return sorted indexes + * modelled as key-value sort with key being input matrix and value being index of values + * @param in: input matrix + * @param out: output value(index) matrix + * @param n_rows: number rows of input matrix + * @param n_columns: number columns of input matrix + * @param bAllocWorkspace: check returned value, if true allocate workspace passed in workspaceSize + * @param workspacePtr: pointer to workspace memory + * @param workspaceSize: Size of workspace to be allocated + * @param stream: cuda stream to execute prim on + * @param sortedKeys: Optional, output matrix for sorted keys (input) + */ +template +void sort_cols_per_row(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + std::optional> sorted_keys = std::nullptr) { + + RAFT_EXPECTS(in.extent(1) == out.extent(1) && + in.extent(0) == out.extent(0), "Input and output matrices must have the same shape."); + + if(sorted_keys.has_value()) { + RAFT_EXPECTS(in.extent(1) == sorted_keys.value().extent(1) && + in.extent(0) == sorted_keys.value().extent(0), "Input and `sorted_keys` matrices must have the same shape."); + } + + detail::sortColumnsPerRow( + in, out, n_rows, n_columns, bAllocWorkspace, 
workspacePtr, workspaceSize, stream, sortedKeys); +} + + }; // end namespace matrix }; // end namespace raft From d1117433c59778983088168edacd62c20b455691 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 20 Sep 2022 16:38:33 -0400 Subject: [PATCH 22/58] Updating docs --- cpp/include/raft/spatial/knn/ivf_flat.cuh | 5 +++-- cpp/include/raft/spatial/knn/knn.cuh | 14 ++++---------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index e0058e93cf..88c08f77e6 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -137,12 +137,13 @@ auto build_index(const handle_t& handle, * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param orig_index original index + * @param[in] handle + * @param[in] orig_index original index * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. + * @param[in] n_rows number of rows in `new_vectors` * * @return the constructed extended ivf-flat index */ diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index 31e26b6e77..e6e54253f6 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -80,14 +80,13 @@ inline void knn_merge_parts(value_t* in_keys, * * @tparam idx_t * @tparam value_t + * @param handle * @param in_keys * @param in_values * @param out_keys * @param out_values * @param n_samples - * @param n_parts * @param k - * @param stream * @param translations */ template @@ -239,6 +238,7 @@ inline void select_k(const value_t* in_keys, * @tparam value_t * the type of the keys (what is being compared). 
* + * @param[in] handle the cuml handle to use * @param[in] in_keys * contiguous device array of inputs of size (input_len * n_inputs); * these are compared and selected. @@ -248,22 +248,16 @@ inline void select_k(const value_t* in_keys, * You can pass `NULL` as an argument here; this would imply `in_values` is a homogeneous array * of indices from `0` to `input_len - 1` for every input and reduce the usage of memory * bandwidth. - * @param[in] n_inputs - * number of input rows, i.e. the batch size. - * @param[in] input_len - * length of a single input array (row); also sometimes referred as n_cols. - * Invariant: input_len >= k. * @param[out] out_keys * contiguous device array of outputs of size (k * n_inputs); * the k smallest/largest values from each row of the `in_keys`. * @param[out] out_values * contiguous device array of outputs of size (k * n_inputs); * the payload selected together with `out_keys`. - * @param[in] select_min - * whether to select k smallest (true) or largest (false) keys. * @param[in] k * the number of outputs to select in each input row. - * @param[in] stream + * @param[in] select_min + * whether to select k smallest (true) or largest (false) keys. * @param[in] algo * the implementation of the algorithm */ From 3a14fa530f5e5d5966279033b5a67df20002d954 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 21 Sep 2022 10:51:23 -0400 Subject: [PATCH 23/58] MOre gather and colwise sort --- cpp/include/raft/core/mdspan.hpp | 109 ++++++++++++++++ cpp/include/raft/matrix/col_wise_sort.cuh | 30 +++-- cpp/include/raft/matrix/gather.cuh | 150 +++++++++++++++++++++- cpp/test/matrix/columnSort.cu | 72 ++++------- cpp/test/matrix/gather.cu | 13 +- 5 files changed, 315 insertions(+), 59 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 7169a010b6..1d21849611 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -222,4 +222,113 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, } } +template +constexpr bool is_row_or_column_major(mdspan /* m */ ) +{ + return false; +} + +template +constexpr bool is_row_or_column_major(mdspan /* m */ ) +{ + return true; +} + +template +constexpr bool is_row_or_column_major(mdspan /* m */ ) +{ + return true; +} + +template +constexpr bool is_row_or_column_major(mdspan m) +{ + return m.is_exhaustive(); +} + + +template +constexpr bool is_row_major(mdspan /* m */ ) +{ + return false; +} + +template +constexpr bool is_row_major(mdspan /* m */ ) +{ + return false; +} + +template +constexpr bool is_row_major(mdspan /* m */ ) +{ + return true; +} + +template +constexpr bool is_row_major(mdspan m) +{ + return m.is_exhaustive(); +} + +template +constexpr bool is_col_major(mdspan /* m */ ) +{ + return false; +} + +template +constexpr bool is_col_major(mdspan /* m */ ) +{ + return true; +} + +template +constexpr bool is_col_major(mdspan /* m */ ) +{ + return false; +} + +template +constexpr bool is_col_major(mdspan m) +{ + return m.is_exhaustive(); +} + + +template +constexpr bool is_matrix_view(mdspan> /* m */) { + return true; +} + +template +constexpr bool is_matrix_view(mdspan m) { + return false; +} + +template +constexpr bool is_vector_view(mdspan> /* m */) { + return true; +} + +template +constexpr bool is_vector_view(mdspan m) { + return false; +} + 
+template +constexpr bool is_scalar_view(mdspan> /* m */) { + return true; +} + +template +constexpr bool is_scalar_view(mdspan m) { + return false; +} + + + + + + } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index d6265a212e..06e0adaab0 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -18,6 +18,7 @@ #pragma once +#include #include #include @@ -59,13 +60,7 @@ void sort_cols_per_row(const InType* in, * modelled as key-value sort with key being input matrix and value being index of values * @param in: input matrix * @param out: output value(index) matrix - * @param n_rows: number rows of input matrix - * @param n_columns: number columns of input matrix - * @param bAllocWorkspace: check returned value, if true allocate workspace passed in workspaceSize - * @param workspacePtr: pointer to workspace memory - * @param workspaceSize: Size of workspace to be allocated - * @param stream: cuda stream to execute prim on - * @param sortedKeys: Optional, output matrix for sorted keys (input) + * @param sorted_keys: Optional, output matrix for sorted keys (input) */ template void sort_cols_per_row(const raft::handle_t &handle, @@ -81,8 +76,27 @@ void sort_cols_per_row(const raft::handle_t &handle, in.extent(0) == sorted_keys.value().extent(0), "Input and `sorted_keys` matrices must have the same shape."); } + size_t workspace_size = 0; + bool alloc_workspace = false; + detail::sortColumnsPerRow( - in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys); + in.data_handle(), out.data_handle(), + in.extent(0), in.extent(1), + alloc_workspace, + nullptr, &workspace_size, handle.get_stream(), + sorted_keys.has_value() ? 
sorted_keys.value().data_handle() : nullptr); + + if(alloc_workspace) { + + auto workspace = raft::make_device_vector(handle, workspace_size); + + detail::sortColumnsPerRow( + in.data_handle(), out.data_handle(), + in.extent(0), in.extent(1), + alloc_workspace, + workspace.data_handle(), &workspace_size, handle.get_stream(), + sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr); + } } diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 31164b2041..384de604c8 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -15,6 +15,8 @@ */ #pragma once + +#include #include namespace raft { @@ -49,6 +51,60 @@ void gather(const MatrixIteratorT in, detail::gather(in, D, N, map, map_length, out, stream); } +/** + * @brief gather copies rows from a source matrix into a destination matrix according to a map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param map Pointer to the input sequence of gather locations + * @param out Pointer to the output matrix (assumed to be row-major) + */ +template +void gather(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view map) { + + RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); + + detail::gather(in.data_handle(), in.extent(1), in.extent(0), map, map.extent(0), out.data_handle(), handle.get_stream()); +} + + +/** + * @brief gather copies rows from a source matrix into a destination matrix according to a + * transformed map. 
+ * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to IndexT (= int) type. + * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param map Pointer to the input sequence of gather locations + * @param out Pointer to the output matrix (assumed to be row-major) + * @param transform_op The transformation operation, transforms the map values to IndexT + */ +template +void gather(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out + raft::device_vector_view map, + MapTransformOp transform_op) { + + RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); + + detail::gather(in.data_handle(), in.extent(1), in.extent(0), map, map.extent(0), out.data_handle(), transform_op, handle.get_stream()); +} + /** * @brief gather copies rows from a source matrix into a destination matrix according to a * transformed map. @@ -72,17 +128,16 @@ void gather(const MatrixIteratorT in, */ template void gather(const MatrixIteratorT in, - int D, - int N, MapIteratorT map, int map_length, MatrixIteratorT out, MapTransformOp transform_op, cudaStream_t stream) { - detail::gather(in, D, N, map, map_length, out, transform_op, stream); + detail::gather(in, D, N, map, map_length, out, transform_op, stream); } + /** * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. 
@@ -124,6 +179,49 @@ void gather_if(const MatrixIteratorT in, detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, stream); } + + +/** + * @brief gather_if conditionally copies rows from a source matrix into a destination matrix + * according to a map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a + * simple pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * type must be convertible to bool type. + * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param map Pointer to the input sequence of gather locations + * @param stencil Pointer to the input sequence of stencil or predicate values + * @param out Pointer to the output matrix (assumed to be row-major) + * @param pred_op Predicate to apply to the stencil values + */ +template +void gather_if(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view map, + raft::device_vector_view stencil, + UnaryPredicateOp pred_op) { + + RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); + RAFT_EXPECTS(map.extent(0) == stencil.extent(0), "Number of elements in stencil must equal number of elements in map"); + + detail::gather_if(in.data_handle(), out.extent(1), out.extent(0), + map.data_handle(), stencil.data_handle(), map.extent(0), + out.data_handle(), pred_op, handle.get_stream()); +} + + /** * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a 
transformed map. @@ -169,5 +267,51 @@ void gather_if(const MatrixIteratorT in, { detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } + +/** + * @brief gather_if conditionally copies rows from a source matrix into a destination matrix + * according to a transformed map. + * + * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * simple pointer type). + * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple + * pointer type). + * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a + * simple pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * type must be convertible to bool type. + * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to IndexT (= int) type. + * + * @param in Pointer to the input matrix (assumed to be row-major) + * @param map Pointer to the input sequence of gather locations + * @param stencil Pointer to the input sequence of stencil or predicate values + * @param out Pointer to the output matrix (assumed to be row-major) + * @param pred_op Predicate to apply to the stencil values + * @param transform_op The transformation operation, transforms the map values to IndexT + */ +template +void gather_if(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view map, + raft::device_vector_view stencil, + UnaryPredicateOp pred_op, + MapTransformOp transform_op) +{ + + RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); + RAFT_EXPECTS(map.extent(0) == stencil.extent(0), "Number of elements in stencil must equal number of elements in map"); 
+ + detail::gather_if(in.data_handle(), in.extent(1), in.extent(0), map.data_handle(), stencil.data_handle(), + map.extent(0), out.data_handle(), pred_op, transform_op, handle.get_stream()); +} + } // namespace matrix } // namespace raft diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index 325ed0204b..7642a4db7d 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -46,7 +47,7 @@ struct columnSort { }; template -::std::ostream& operator<<(::std::ostream& os, const columnSort& dims) +::std::ohandle.get_stream()& operator<<(::std::ohandle.get_stream()& os, const columnSort& dims) { return os; } @@ -55,12 +56,12 @@ template class ColumnSort : public ::testing::TestWithParam> { protected: ColumnSort() - : keyIn(0, stream), - keySorted(0, stream), - keySortGolden(0, stream), - valueOut(0, stream), - goldenValOut(0, stream), - workspacePtr(0, stream) + : keyIn(0, handle.get_stream()), + keySorted(0, handle.get_stream()), + keySortGolden(0, handle.get_stream()), + valueOut(0, handle.get_stream()), + goldenValOut(0, handle.get_stream()), + workspacePtr(0, handle.get_stream()) { } @@ -68,13 +69,13 @@ class ColumnSort : public ::testing::TestWithParam> { { params = ::testing::TestWithParam>::GetParam(); int len = params.n_row * params.n_col; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - keyIn.resize(len, stream); - valueOut.resize(len, stream); - goldenValOut.resize(len, stream); + RAFT_CUDA_TRY(cudahandle.get_stream()Create(&handle.get_stream())); + keyIn.resize(len, handle.get_stream()); + valueOut.resize(len, handle.get_stream()); + goldenValOut.resize(len, handle.get_stream()); if (params.testKeys) { - keySorted.resize(len, stream); - keySortGolden.resize(len, stream); + keySorted.resize(len, handle.get_stream()); + keySortGolden.resize(len, handle.get_stream()); } std::vector vals(len); @@ -97,45 +98,26 @@ class ColumnSort : public 
::testing::TestWithParam> { } } - raft::update_device(keyIn.data(), &vals[0], len, stream); - raft::update_device(goldenValOut.data(), &cValGolden[0], len, stream); - - if (params.testKeys) raft::update_device(keySortGolden.data(), &cKeyGolden[0], len, stream); - - bool needWorkspace = false; - size_t workspaceSize = 0; - // Remove this branch once the implementation of descending sort is fixed. - sort_cols_per_row(keyIn.data(), - valueOut.data(), - params.n_row, - params.n_col, - needWorkspace, - NULL, - workspaceSize, - stream, - keySorted.data()); - if (needWorkspace) { - workspacePtr.resize(workspaceSize, stream); - sort_cols_per_row(keyIn.data(), - valueOut.data(), - params.n_row, - params.n_col, - needWorkspace, - workspacePtr.data(), - workspaceSize, - stream, - keySorted.data()); - } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - RAFT_CUDA_TRY(cudaStreamDestroy(stream)); + raft::update_device(keyIn.data(), &vals[0], len, handle.get_stream()); + raft::update_device(goldenValOut.data(), &cValGolden[0], len, handle.get_stream()); + + if (params.testKeys) raft::update_device(keySortGolden.data(), &cKeyGolden[0], len, handle.get_stream()); + + auto key_in_view = raft::make_device_matrix_view(keyIn.data(), params.n_row, params.n_col); + auto value_out_view = raft::make_device_matrix_view(valueOut.data(), params.n_row, params.n_col); + auto key_sorted_view = raft::make_device_matrix_view(keySorted.data(), params.n_row, params.n_col); + + raft::matrix::sort_cols_per_row(handle, key_in_view, value_out_view, key_sorted_view); + + RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } protected: - cudaStream_t stream = 0; columnSort params; rmm::device_uvector keyIn, keySorted, keySortGolden; rmm::device_uvector valueOut, goldenValOut; // valueOut are indexes rmm::device_uvector workspacePtr; + raft::handle_t handle; }; const std::vector> inputsf1 = {{0.000001f, 503, 2000, false}, diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 
adedaacc81..da2057d4f6 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include #include @@ -46,7 +47,8 @@ void naiveGather( } template -void gatherLaunch(MatrixIteratorT in, +void gatherLaunch(const raft::handle_t &handle, + MatrixIteratorT in, int D, int N, MapIteratorT map, @@ -55,7 +57,12 @@ void gatherLaunch(MatrixIteratorT in, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; - matrix::gather(in, D, N, map, map_length, out, stream); + + auto in_view = raft::make_device_matrix_view(in, N, D); + auto map_view = raft::make_device_vector_view(map, map_length); + auto out_view = raft::make_device_matrix_view(out, N, D); + + matrix::gather(handle, in_view, out_view, map_view); } struct GatherInputs { @@ -110,7 +117,7 @@ class GatherTest : public ::testing::TestWithParam { raft::update_device(d_out_exp.data(), h_out.data(), map_length * ncols, stream); // launch device version of the kernel - gatherLaunch(d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); + gatherLaunch(handle, d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); handle.sync_stream(stream); } From f27b0cb22174a6e4e1a4a8dcc11e96f9f0daa92a Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 21 Sep 2022 17:04:18 -0400 Subject: [PATCH 24/58] Breaking matrix functions out into individual files. 
--- cpp/include/raft/matrix/argmax.cuh | 41 +++++++ cpp/include/raft/matrix/copy.cuh | 87 +++++++++++++ cpp/include/raft/matrix/detail/matrix.cuh | 1 - cpp/include/raft/matrix/detail/print.hpp | 48 ++++++++ cpp/include/raft/matrix/diagonal.cuh | 53 ++++++++ cpp/include/raft/matrix/init.cuh | 39 ++++++ cpp/include/raft/matrix/linewise_op.cuh | 61 ++++++++++ cpp/include/raft/matrix/math.cuh | 10 ++ cpp/include/raft/matrix/matrix.cuh | 30 +++++ cpp/include/raft/matrix/matrix_vector.cuh | 142 ++++++++++++++++++++++ cpp/include/raft/matrix/norm.cuh | 39 ++++++ cpp/include/raft/matrix/power.cuh | 81 ++++++++++++ cpp/include/raft/matrix/print.cuh | 51 ++++++++ cpp/include/raft/matrix/print.hpp | 32 +++++ cpp/include/raft/matrix/ratio.cuh | 42 +++++++ cpp/include/raft/matrix/reciprocal.cuh | 67 ++++++++++ cpp/include/raft/matrix/reverse.cuh | 84 +++++++++++++ cpp/include/raft/matrix/seq_root.cuh | 92 ++++++++++++++ cpp/include/raft/matrix/sign_flip.cuh | 36 ++++++ cpp/include/raft/matrix/slice.cuh | 44 +++++++ cpp/include/raft/matrix/threshold.cuh | 58 +++++++++ cpp/include/raft/matrix/triangular.cuh | 40 ++++++ 22 files changed, 1177 insertions(+), 1 deletion(-) create mode 100644 cpp/include/raft/matrix/argmax.cuh create mode 100644 cpp/include/raft/matrix/copy.cuh create mode 100644 cpp/include/raft/matrix/detail/print.hpp create mode 100644 cpp/include/raft/matrix/diagonal.cuh create mode 100644 cpp/include/raft/matrix/init.cuh create mode 100644 cpp/include/raft/matrix/linewise_op.cuh create mode 100644 cpp/include/raft/matrix/matrix_vector.cuh create mode 100644 cpp/include/raft/matrix/norm.cuh create mode 100644 cpp/include/raft/matrix/power.cuh create mode 100644 cpp/include/raft/matrix/print.cuh create mode 100644 cpp/include/raft/matrix/print.hpp create mode 100644 cpp/include/raft/matrix/ratio.cuh create mode 100644 cpp/include/raft/matrix/reciprocal.cuh create mode 100644 cpp/include/raft/matrix/reverse.cuh create mode 100644 
cpp/include/raft/matrix/seq_root.cuh create mode 100644 cpp/include/raft/matrix/sign_flip.cuh create mode 100644 cpp/include/raft/matrix/slice.cuh create mode 100644 cpp/include/raft/matrix/threshold.cuh create mode 100644 cpp/include/raft/matrix/triangular.cuh diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh new file mode 100644 index 0000000000..fd52441c15 --- /dev/null +++ b/cpp/include/raft/matrix/argmax.cuh @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Argmax: find the row idx with maximum value for each column + * @param in: input matrix + * @param n_rows: number of rows of input matrix + * @param n_cols: number of columns of input matrix + * @param out: output vector of size n_cols + * @param stream: cuda stream + */ + template + void argmax(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_vector_view out) { + + RAFT_EXPECTS(out.extent(1) == in.extent(1), "Size of output vector must equal number of columns in input matrix."); + detail::argmax(in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream()); + } +} diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh new file mode 100644 index 0000000000..2798fe7491 --- /dev/null +++ b/cpp/include/raft/matrix/copy.cuh @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Copy selected rows of the input matrix into contiguous space. + * + * On exit out[i + k*n_rows] = in[indices[i] + k*n_rows], + * where i = 0..n_rows_indices-1, and k = 0..n_cols-1. 
+ * + * @param[in] handle raft handle + * @param[in] in input matrix + * @param[out] out output matrix + * @param[in] indices indices of the rows to be copied + */ +template +void copy_rows(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view indices) +{ + RAFT_EXPECTS(in.extent(1) == out.extent(1), "Input and output matrices must have same number of columns"); + RAFT_EXPECTS(indices.extent(0) == out.extent(0), "Number of rows in output matrix must equal number of indices"); + bool in_rowmajor = raft::is_row_major(in); + bool out_rowmajor = raft::is_row_major(out); + + RAFT_EXPECTS(in_rowmajor == out_rowmajor, "Input and output matrices must have same layout (row- or column-major)"); + + detail::copyRows(in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), indices.data_handle(), indices.extent(0), handle.get_stream(), in_rowmajor); +} + +/** + * @brief copy matrix operation for column major matrices. + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix + */ +template +void copy(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.extent(0) == out.extent(0) && + in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); + + raft::copy_async(out.data_handle(), in.data_handle(), in.extent(0) * in.extent(1), handle.get_stream()); +} + +/** + * @brief copy matrix operation for column major matrices. First n_rows and + * n_cols of input matrix "in" is copied to "out" matrix.
+ * @param in: input matrix + * @param in_n_rows: number of rows of input matrix + * @param out: output matrix + * @param out_n_rows: number of rows of output matrix + * @param out_n_cols: number of columns of output matrix + * @param stream: cuda stream + */ +template +void trunc_zero_origin( + m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) +{ + detail::truncZeroOrigin(in, in_n_rows, out, out_n_rows, out_n_cols, stream); +} + + +} diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index a8568b0859..1b343cf5b4 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -279,7 +279,6 @@ m_t getL2Norm(const raft::handle_t& handle, m_t* in, idx_t size, cudaStream_t st { cublasHandle_t cublasH = handle.get_cublas_handle(); m_t normval = 0; - // #TODO: Call from the public API when ready RAFT_CUBLAS_TRY(raft::linalg::detail::cublasnrm2(cublasH, size, in, 1, &normval, stream)); return normval; } diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp new file mode 100644 index 0000000000..d2b7b1100e --- /dev/null +++ b/cpp/include/raft/matrix/detail/print.hpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace raft::matrix::detail { + +template +void printHost(const m_t* in, idx_t n_rows, idx_t n_cols) +{ + for (idx_t i = 0; i < n_rows; i++) { + for (idx_t j = 0; j < n_cols; j++) { + printf("%1.4f ", in[j * n_rows + i]); + } + printf("\n"); + } +} + +} // end namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/diagonal.cuh b/cpp/include/raft/matrix/diagonal.cuh new file mode 100644 index 0000000000..69bce231d3 --- /dev/null +++ b/cpp/include/raft/matrix/diagonal.cuh @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Initialize a diagonal matrix with a vector + * @param vec: vector of length k = min(n_rows, n_cols) + * @param matrix: matrix of size n_rows x n_cols + */ +template +void initialize_diagonal( + const raft::handle_t &handle, + raft::device_vector_view vec, + raft::device_matrix_view matrix) { + detail::initializeDiagonalMatrix(vec.data_handle(), + matrix.data_handle(), + matrix.extent(0), + matrix.extent(1), + handle.get_stream()); +} + +/** + * @brief Take reciprocal of elements on diagonal of square matrix (in-place) + * @param in: square input matrix with size len x len + */ +template +void invert_diagonal(const raft::handle_t &handle, + raft::device_matrix_view in) +{ + RAFT_EXPECTS(in.extent(0) == in.extent(1), "Matrix must be square."); + detail::getDiagonalInverseMatrix(in.data_handle(), in.extent(0), handle.get_stream()); +} +} diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh new file mode 100644 index 0000000000..7d27689227 --- /dev/null +++ b/cpp/include/raft/matrix/init.cuh @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { +/** + * @brief set values to scalar in matrix + * @tparam math_t data-type upon which the math operation will be performed + * @param handle: raft handle + * @param in input matrix + * @param out output matrix. The result is stored in the out matrix + * @param scalar scalar value + */ +template +void fill(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, math_t scalar) { + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); + detail::setValue(out.data_handle(), in.data_handle(), scalar, in.size(), handle.get_stream()); +} +} diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh new file mode 100644 index 0000000000..adf7732954 --- /dev/null +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + + /** + * Run a function over matrix lines (rows or columns) with a variable number of + * row-vectors or column-vectors. + * The term `line` here signifies that the lines can be either columns or rows, + * depending on the matrix layout.
+ * What matters is if the vectors are applied along lines (indices of vectors correspond to + * indices within lines), or across lines (indices of vectors correspond to line numbers). + * + * @param [out] out result of the operation; can be same as `in`; should be aligned the same + * as `in` to allow faster vectorized memory transfers. + * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. + * @param [in] lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in + * col-major) + * @param [in] nLines number of matrix lines (`=nRows` in row-major or `=nCols` in col-major) + * @param [in] alongLines whether vectors are indices along or across lines. + * @param [in] op the operation applied on each line: + * for i in [0..lineLen) and j in [0..nLines): + * out[i, j] = op(in[i, j], vec1[i], vec2[i], ... veck[i]) if alongLines = true + * out[i, j] = op(in[i, j], vec1[j], vec2[j], ... veck[j]) if alongLines = false + * where matrix indexing is row-major ([i, j] = [i + lineLen * j]). + * @param [in] handle raft handle; the kernels are launched on handle.get_stream() + * @param [in] vecs zero or more vectors to be passed as arguments, + * size of each vector is `alongLines ? lineLen : nLines`. + */ +template +void linewise_op(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + const idx_t lineLen, + const idx_t nLines, + const bool alongLines, + Lambda op, + raft::device_vector_view... vecs) { + detail::MatrixLinewiseOp<16, 256>::run( + out.data_handle(), in.data_handle(), lineLen, nLines, alongLines, op, handle.get_stream(), vecs.data_handle()...); +} +} diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh index 9e103afda5..6fb15e6d8b 100644 --- a/cpp/include/raft/matrix/math.cuh +++ b/cpp/include/raft/matrix/math.cuh @@ -14,6 +14,16 @@ * limitations under the License. */ +/** + * This file is deprecated and will be removed in a future release. + * Please use versions in individual header files instead.
+ */ + +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use versions in individual header files instead.") + + #ifndef __MATH_H #define __MATH_H diff --git a/cpp/include/raft/matrix/matrix.cuh b/cpp/include/raft/matrix/matrix.cuh index 1af7e37dec..588399655b 100644 --- a/cpp/include/raft/matrix/matrix.cuh +++ b/cpp/include/raft/matrix/matrix.cuh @@ -14,11 +14,21 @@ * limitations under the License. */ +/** + * This file is deprecated and will be removed in a future release. + * Please use versions in individual header files instead. + */ + +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use versions in individual header files instead.") + #ifndef __MATRIX_H #define __MATRIX_H #pragma once +#include #include "detail/linewise_op.cuh" #include "detail/matrix.cuh" @@ -57,6 +67,9 @@ void copyRows(const m_t* in, detail::copyRows(in, n_rows, n_cols, out, indices, n_rows_indices, stream, rowMajor); } + + + /** * @brief copy matrix operation for column major matrices. * @param in: input matrix @@ -71,6 +84,23 @@ void copy(const m_t* in, m_t* out, idx_t n_rows, idx_t n_cols, cudaStream_t stre raft::copy_async(out, in, n_rows * n_cols, stream); } +/** + * @brief copy matrix operation for column major matrices. + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix + */ +template +void copy(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.extent(0) == out.extent(0) && + in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); + + raft::copy_async(out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); +} + /** * @brief copy matrix operation for column major matrices. First n_rows and * n_cols of input matrix "in" is copied to "out" matrix. 
diff --git a/cpp/include/raft/matrix/matrix_vector.cuh b/cpp/include/raft/matrix/matrix_vector.cuh new file mode 100644 index 0000000000..b051831a15 --- /dev/null +++ b/cpp/include/raft/matrix/matrix_vector.cuh @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "detail/matrix.cuh" + +namespace raft::matrix { + + /** + * @brief multiply each row or column of matrix with vector, skipping zeros in vector + * @param data input matrix, results are in-place + * @param vec input vector + * @param n_row number of rows of input matrix + * @param n_col number of columns of input matrix + * @param rowMajor whether matrix is row major + * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns + * @param stream cuda stream + */ +template +void binary_mult_skip_zero(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream) +{ + detail::matrixVectorBinaryMultSkipZero( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); +} + +/** + * @brief divide each row or column of matrix with vector + * @param data input matrix, results are in-place + * @param vec input vector + * @param n_row number of rows of input matrix + * @param n_col number of columns of input matrix + * @param rowMajor whether matrix is row major + * @param bcastAlongRows whether to broadcast vector 
along rows of matrix or columns + * @param stream cuda stream + */ + template + void binary_div(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream) + { + detail::matrixVectorBinaryDiv( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); + } + + /** + * @brief divide each row or column of matrix with vector, skipping zeros in vector + * @param data input matrix, results are in-place + * @param vec input vector + * @param n_row number of rows of input matrix + * @param n_col number of columns of input matrix + * @param rowMajor whether matrix is row major + * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns + * @param stream cuda stream + * @param return_zero result is zero if true and vector value is below threshold, original value if + * false + */ + template + void binary_div_skip_zero(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream, + bool return_zero = false) + { + detail::matrixVectorBinaryDivSkipZero( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream, return_zero); + } + + /** + * @brief add each row or column of matrix with vector + * @param data input matrix, results are in-place + * @param vec input vector + * @param n_row number of rows of input matrix + * @param n_col number of columns of input matrix + * @param rowMajor whether matrix is row major + * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns + * @param stream cuda stream + */ + template + void binary_add(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream) + { + detail::matrixVectorBinaryAdd( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); + } + + /** + * @brief subtract each row or column of matrix with vector + * @param data input matrix, results are in-place + * @param 
vec input vector + * @param n_row number of rows of input matrix + * @param n_col number of columns of input matrix + * @param rowMajor whether matrix is row major + * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns + * @param stream cuda stream + */ + template + void binary_sub(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream) + { + detail::matrixVectorBinarySub( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); + } + +} \ No newline at end of file diff --git a/cpp/include/raft/matrix/norm.cuh b/cpp/include/raft/matrix/norm.cuh new file mode 100644 index 0000000000..f4814ccbe9 --- /dev/null +++ b/cpp/include/raft/matrix/norm.cuh @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + + +/** + * @brief Get the L2/F-norm of a matrix + * @param handle: raft handle; the computation uses handle.get_stream() + * @param in: input matrix/vector; the norm is taken over all in.size() elements + * (passed to the detail implementation as a flat pointer plus length) + * @return the norm value returned by detail::getL2Norm + */ +template +m_t l2_norm(const raft::handle_t& handle, + raft::device_mdspan in) +{ + return detail::getL2Norm(handle, in.data_handle(), in.size(), handle.get_stream()); +} +} diff --git a/cpp/include/raft/matrix/power.cuh b/cpp/include/raft/matrix/power.cuh new file mode 100644 index 0000000000..1c38f9c2f8 --- /dev/null +++ b/cpp/include/raft/matrix/power.cuh @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Power of every element in the input matrix + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix. The result is stored in the out matrix + * @param[in] scalar: every element is multiplied with scalar.
+ */ +template +void weighted_power( + const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t scalar) { + RAFT_EXPECTS(in.size() == out.size(), "Size of input and output matrices must be equal"); + detail::power(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream()); +} + +/** + * @brief Power of every element in the input matrix (inplace) + * @param[inout] inout: input matrix and also the result is stored + * @param[in] scalar: every element is multiplied with scalar. + */ +template +void weighted_power(const raft::handle_t &handle, + raft::device_matrix_view inout, + math_t scalar) { + detail::power(inout.data_handle(), scalar, inout.size(), handle.get_stream()); +} + +/** + * @brief Power of every element in the input matrix (inplace) + * @param[inout] inout: input matrix and also the result is stored + */ +template +void power(const raft::handle_t &handle, + raft::device_matrix_view inout) { + detail::power(inout.data_handle(), inout.size(), handle.get_stream()); +} + +/** + * @brief Power of every element in the input matrix + * @param[in] handle: raft handle; the computation uses handle.get_stream() + * @param[in] in: input matrix + * @param[out] out: output matrix. The result is stored in the out matrix + * Input and output must hold the same number of elements. + */ +template +void power(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out) { + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size."); + detail::power(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); +} + + + +} diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh new file mode 100644 index 0000000000..d7f978ec13 --- /dev/null +++ b/cpp/include/raft/matrix/print.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::matrix { + + +/** + * @brief Prints the data stored in GPU memory + * @param handle: raft handle + * @param in: input matrix + * @param h_separator: horizontal separator character + * @param v_separator: vertical separator character + */ +template +void print(const raft::handle_t &handle, + raft::device_matrix_view in, + char h_separator = ' ', + char v_separator = '\n') +{ + detail::print(in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator, handle.get_stream()); +} + +/** + * @brief Prints the data stored in CPU memory + * @param in: input matrix with column-major layout + */ +template +void print(raft::host_matrix_view in) { + detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); +} +} diff --git a/cpp/include/raft/matrix/print.hpp b/cpp/include/raft/matrix/print.hpp new file mode 100644 index 0000000000..cc80671cec --- /dev/null +++ b/cpp/include/raft/matrix/print.hpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::matrix { + +/** + * @brief Prints the data stored in CPU memory + * @param in: input matrix with column-major layout + */ +template +void print(raft::host_matrix_view in) { + detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); +} +} diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh new file mode 100644 index 0000000000..78917ee684 --- /dev/null +++ b/cpp/include/raft/matrix/ratio.cuh @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief ratio of every element over sum of input vector is calculated + * @tparam math_t data-type upon which the math operation will be performed + * @tparam IdxType Integer type used to for addressing + * @param handle + * @param src: input matrix + * @param dest: output matrix. 
The result is stored in the dest matrix + * Both views must hold the same number of elements; the computation uses + * handle.get_stream(). + */ +template +void ratio(const raft::handle_t& handle, + raft::device_matrix_view src, + raft::device_matrix_view dest) { + RAFT_EXPECTS(src.size() == dest.size(), "Input and output matrices must be the same size."); + detail::ratio(handle, src.data_handle(), dest.data_handle(), src.size(), handle.get_stream()); +} +} diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh new file mode 100644 index 0000000000..66c4f9bf5f --- /dev/null +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Reciprocal of every element in the input matrix + * @tparam math_t data-type upon which the math operation will be performed + * @tparam IdxType Integer type used for addressing + * @param handle: raft handle + * @param in: input matrix + * @param out: output matrix.
The result is stored in the out matrix + * @param scalar: every element is multiplied with scalar + * @param setzero round down to zero if the input is less than the threshold + * @param thres the threshold used to forcibly set inputs to zero + * + */ +template +void reciprocal(const raft::handle_t &handle, raft::device_matrix_view in, + raft::device_matrix_view out, + math_t scalar, + bool setzero = false, + math_t thres = 1e-15) { + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have the same size."); + detail::reciprocal(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), setzero, thres); +} + +/** + * @brief Reciprocal of every element in the input matrix (in place) + * @tparam math_t data-type upon which the math operation will be performed + * @tparam IdxType Integer type used for addressing + * @param handle: raft handle; the computation uses handle.get_stream() + * @param inout: input matrix with in-place results + * @param scalar: every element is multiplied with scalar + * @param setzero round down to zero if the input is less than the threshold + * @param thres the threshold used to forcibly set inputs to zero + * + * All inout.size() elements are processed. + */ +template +void reciprocal(const raft::handle_t &handle, + raft::device_matrix_view inout, + math_t scalar, + bool setzero = false, + math_t thres = 1e-15) { + detail::reciprocal(inout.data_handle(), scalar, inout.size(), handle.get_stream(), setzero, thres); +} +} diff --git a/cpp/include/raft/matrix/reverse.cuh b/cpp/include/raft/matrix/reverse.cuh new file mode 100644 index 0000000000..2416a81636 --- /dev/null +++ b/cpp/include/raft/matrix/reverse.cuh @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Columns of a column major matrix are reversed in place (i.e. first column and + * last column are swapped) + * @param handle: raft handle; the computation uses handle.get_stream() + * @param inout: input and output matrix + * Note: inout.extent(0) and inout.extent(1) are forwarded to detail::colReverse as its + * two size arguments. + */ +template +void col_reverse(const raft::handle_t &handle, + raft::device_matrix_view inout) { + detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); +} + +/** + * @brief Columns of a column major matrix are reversed in place (i.e. first column and + * last column are swapped) + * @param handle: raft handle; the computation uses handle.get_stream() + * @param inout: input and output matrix + * NOTE(review): this overload forwards to detail::rowReverse; presumably it is the + * row-major counterpart of the overload above -- confirm against the template parameters. + */ +template +void col_reverse(const raft::handle_t &handle, + raft::device_matrix_view inout) { + detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); +} + + +/** + * @brief Rows of a column major matrix are reversed in place (i.e.
first row and last + * row are swapped) + * @param handle: raft handle; the computation uses handle.get_stream() + * @param inout: input and output matrix + * Note: inout.extent(0) and inout.extent(1) are forwarded to detail::rowReverse as its + * two size arguments. + */ +template +void row_reverse(const raft::handle_t &handle, + raft::device_matrix_view inout) +{ + detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); +} + +/** + * @brief Rows of a column major matrix are reversed in place (i.e. first row and last + * row are swapped) + * @param handle: raft handle; the computation uses handle.get_stream() + * @param inout: input and output matrix + * NOTE(review): this overload forwards to detail::colReverse; presumably it is the + * row-major counterpart of the overload above -- confirm against the template parameters. + */ +template +void row_reverse(const raft::handle_t &handle, + raft::device_matrix_view inout) +{ + detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); +} + +} diff --git a/cpp/include/raft/matrix/seq_root.cuh b/cpp/include/raft/matrix/seq_root.cuh new file mode 100644 index 0000000000..9af90bd842 --- /dev/null +++ b/cpp/include/raft/matrix/seq_root.cuh @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Square root of every element in the input matrix + * @tparam math_t data-type upon which the math operation will be performed + * @param[in] handle: raft handle + * @param[in] in: input matrix and also the result is stored + * @param[out] out: output matrix. The result is stored in the out matrix + */ +template +void seq_root(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out) { + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); + detail::seqRoot(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); +} + +/** + * @brief Square root of every element in the input matrix (in place) + * @tparam math_t data-type upon which the math operation will be performed + * @param[in] handle: raft handle + * @param[inout] inout: input matrix with in-place results + */ +template +void seq_root(const raft::handle_t &handle, + raft::device_matrix_view inout) +{ + detail::seqRoot(inout.data_handle(), inout.size(), handle.get_stream()); +} + +/** + * @brief Square root of every element in the input matrix + * @tparam math_t data-type upon which the math operation will be performed + * @param[in] handle: raft handle + * @param[in] in: input matrix and also the result is stored + * @param[out] out: output matrix. 
The result is stored in the out matrix + * @param[in] scalar: every element is multiplied with scalar + * @param[in] set_neg_zero whether to set negative numbers to zero + */ +template +void weighted_seq_root(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t scalar, + bool set_neg_zero = false) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); + detail::seqRoot(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), set_neg_zero); +} + +/** + * @brief Square root of every element in the input matrix (in place) + * @tparam math_t data-type upon which the math operation will be performed + * @param handle: raft handle + * @param inout: input matrix and also the result is stored + * @param scalar: every element is multiplied with scalar + * @param set_neg_zero whether to set negative numbers to zero + */ +template +void weighted_seq_root( + const raft::handle_t &handle, + raft::device_matrix_view inout, + math_t scalar, bool set_neg_zero = false) +{ + detail::seqRoot(inout.data_handle(), scalar, inout.size(), handle.get_stream(), set_neg_zero); +} + + + +} diff --git a/cpp/include/raft/matrix/sign_flip.cuh b/cpp/include/raft/matrix/sign_flip.cuh new file mode 100644 index 0000000000..479f93cd59 --- /dev/null +++ b/cpp/include/raft/matrix/sign_flip.cuh @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief sign flip stabilizes the sign of col major eigen vectors. + * The sign is flipped if the column has negative |max|. + * @param handle: raft handle + * @param inout: input matrix. Result also stored in this parameter + */ +template +void sign_flip(const raft::handle_t &handle, + raft::device_matrix_view inout) { + detail::signFlip(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); +} +} diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh new file mode 100644 index 0000000000..43209d4054 --- /dev/null +++ b/cpp/include/raft/matrix/slice.cuh @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Slice a matrix (in-place) + * @param handle: raft handle + * @param in: input matrix (column-major) + * @param out: output matrix (column-major) + * @param x1, y1: coordinate of the top-left point of the wanted area (0-based) + * @param x2, y2: coordinate of the bottom-right point of the wanted area + * (1-based) + * example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice_matrix(M_d, 4, + * 3, 0, 1, 4, 3); + */ +template +void slice(const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + idx_t x1, idx_t y1, idx_t x2, idx_t y2) { + detail::sliceMatrix(in.data_handle(), in.extent(0), in.extent(1), + out.data_handle(), x1, y1, x2, y2, handle.get_stream()); +} +} diff --git a/cpp/include/raft/matrix/threshold.cuh b/cpp/include/raft/matrix/threshold.cuh new file mode 100644 index 0000000000..6959392ca9 --- /dev/null +++ b/cpp/include/raft/matrix/threshold.cuh @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief sets the small values to zero based on a defined threshold + * @tparam math_t data-type upon which the math operation will be performed + * @param handle: raft handle + * @param in: input matrix + * @param out: output matrix. 
The result is stored in the out matrix + * @param thres threshold to set values to zero + */ +template +void zero_small_values( + const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t thres = 1e-15) { + + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size"); + detail::setSmallValuesZero(out.data_handle(), in.data_handle(), in.size(), handle.get_stream(), thres); +} + +/** + * @brief sets the small values to zero in-place based on a defined threshold + * @tparam math_t data-type upon which the math operation will be performed + * @param handle: raft handle + * @param inout: input matrix and also the result is stored + * @param thres: threshold + */ +template +void zero_small_values( + const raft::handle_t &handle, + raft::device_matrix_view inout, + math_t thres = 1e-15) { + detail::setSmallValuesZero(inout.data_handle(), inout.size(), handle.get_stream(), thres); +} +} diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh new file mode 100644 index 0000000000..7dcf6e39b4 --- /dev/null +++ b/cpp/include/raft/matrix/triangular.cuh @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Copy the upper triangular part of a matrix to another + * @param[in] handle: raft handle + * @param[in] src: input matrix with a size of n_rows x n_cols + * @param[out] dst: output matrix with a size of kxk, k = min(n_rows, n_cols) + */ +template +void upper_triangular(const raft::handle_t &handle, + raft::device_matrix_view src, + raft::device_matrix_view dst) { + + detail::copyUpperTriangular(src.data_handle(), dst.data_handle(), + src.extent(0), src.extent(1), + handle.get_stream()); +} +} From 9002337ca114017f8600c8f39a34b087b60bd778 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 21 Sep 2022 21:18:55 -0400 Subject: [PATCH 25/58] Fixing style --- cpp/include/raft/core/mdspan.hpp | 110 ++++++++++--------- cpp/include/raft/matrix/argmax.cuh | 20 ++-- cpp/include/raft/matrix/col_wise_sort.cuh | 71 ++++++------ cpp/include/raft/matrix/copy.cuh | 41 ++++--- cpp/include/raft/matrix/detail/print.hpp | 10 +- cpp/include/raft/matrix/diagonal.cuh | 28 ++--- cpp/include/raft/matrix/gather.cuh | 119 ++++++++++++-------- cpp/include/raft/matrix/init.cuh | 14 +-- cpp/include/raft/matrix/linewise_op.cuh | 15 +-- cpp/include/raft/matrix/math.cuh | 1 - cpp/include/raft/matrix/matrix.cuh | 14 ++- cpp/include/raft/matrix/matrix_vector.cuh | 126 +++++++++++----------- cpp/include/raft/matrix/norm.cuh | 10 +- cpp/include/raft/matrix/power.cuh | 42 ++++---- cpp/include/raft/matrix/print.cuh | 21 ++-- cpp/include/raft/matrix/print.hpp | 7 +- cpp/include/raft/matrix/ratio.cuh | 11 +- cpp/include/raft/matrix/reciprocal.cuh | 20 ++-- cpp/include/raft/matrix/reverse.cuh | 27 ++--- cpp/include/raft/matrix/seq_root.cuh | 45 ++++---- cpp/include/raft/matrix/sign_flip.cuh | 11 +- cpp/include/raft/matrix/slice.cuh | 23 ++-- cpp/include/raft/matrix/threshold.cuh | 30 +++--- cpp/include/raft/matrix/triangular.cuh | 17 ++- 24 files changed, 451 insertions(+), 382 deletions(-) diff --git 
a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 336fec3bb4..202e6163a9 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -222,108 +222,112 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, } } -template -constexpr bool is_row_or_column_major(mdspan /* m */ ) +template +constexpr bool is_row_or_column_major(mdspan /* m */) { - return false; + return false; } -template -constexpr bool is_row_or_column_major(mdspan /* m */ ) +template +constexpr bool is_row_or_column_major(mdspan /* m */) { - return true; + return true; } -template -constexpr bool is_row_or_column_major(mdspan /* m */ ) +template +constexpr bool is_row_or_column_major(mdspan /* m */) { - return true; + return true; } -template +template constexpr bool is_row_or_column_major(mdspan m) { - return m.is_exhaustive(); + return m.is_exhaustive(); } - -template -constexpr bool is_row_major(mdspan /* m */ ) +template +constexpr bool is_row_major(mdspan /* m */) { - return false; + return false; } -template -constexpr bool is_row_major(mdspan /* m */ ) +template +constexpr bool is_row_major(mdspan /* m */) { - return false; + return false; } -template -constexpr bool is_row_major(mdspan /* m */ ) +template +constexpr bool is_row_major(mdspan /* m */) { - return true; + return true; } -template +template constexpr bool is_row_major(mdspan m) { - return m.is_exhaustive(); + return m.is_exhaustive(); } -template -constexpr bool is_col_major(mdspan /* m */ ) +template +constexpr bool is_col_major(mdspan /* m */) { - return false; + return false; } -template -constexpr bool is_col_major(mdspan /* m */ ) +template +constexpr bool is_col_major(mdspan /* m */) { - return true; + return true; } -template -constexpr bool is_col_major(mdspan /* m */ ) +template +constexpr bool is_col_major(mdspan /* m */) { - return false; + return false; } -template +template constexpr bool is_col_major(mdspan m) { - return m.is_exhaustive(); + return m.is_exhaustive(); 
} - -template -constexpr bool is_matrix_view(mdspan> /* m */) { - return true; +template +constexpr bool is_matrix_view(mdspan> /* m */) +{ + return true; } -template -constexpr bool is_matrix_view(mdspan m) { - return false; +template +constexpr bool is_matrix_view(mdspan m) +{ + return false; } -template -constexpr bool is_vector_view(mdspan> /* m */) { - return true; +template +constexpr bool is_vector_view(mdspan> /* m */) +{ + return true; } -template -constexpr bool is_vector_view(mdspan m) { - return false; +template +constexpr bool is_vector_view(mdspan m) +{ + return false; } -template -constexpr bool is_scalar_view(mdspan> /* m */) { - return true; +template +constexpr bool is_scalar_view(mdspan> /* m */) +{ + return true; } -template -constexpr bool is_scalar_view(mdspan m) { - return false; +template +constexpr bool is_scalar_view(mdspan m) +{ + return false; } } // namespace raft diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index fd52441c15..28cb69dd8f 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -30,12 +30,14 @@ namespace raft::matrix { * @param out: output vector of size n_cols * @param stream: cuda stream */ - template - void argmax(const raft::handle_t &handle, - raft::device_matrix_view in, - raft::device_vector_view out) { - - RAFT_EXPECTS(out.extent(1) == in.extent(1), "Size of output vector must equal number of columns in input matrix."); - detail::argmax(in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream()); - } +template +void argmax(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view out) +{ + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Size of output vector must equal number of columns in input matrix."); + detail::argmax( + in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), 
handle.get_stream()); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 06e0adaab0..74f78796e8 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -53,8 +53,6 @@ void sort_cols_per_row(const InType* in, in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys); } - - /** * @brief sort columns within each row of row-major input matrix and return sorted indexes * modelled as key-value sort with key being input matrix and value being index of values @@ -63,43 +61,52 @@ void sort_cols_per_row(const InType* in, * @param sorted_keys: Optional, output matrix for sorted keys (input) */ template -void sort_cols_per_row(const raft::handle_t &handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - std::optional> sorted_keys = std::nullptr) { +void sort_cols_per_row( + const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + std::optional> sorted_keys = + std::nullptr) +{ + RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0), + "Input and output matrices must have the same shape."); - RAFT_EXPECTS(in.extent(1) == out.extent(1) && - in.extent(0) == out.extent(0), "Input and output matrices must have the same shape."); + if (sorted_keys.has_value()) { + RAFT_EXPECTS(in.extent(1) == sorted_keys.value().extent(1) && + in.extent(0) == sorted_keys.value().extent(0), + "Input and `sorted_keys` matrices must have the same shape."); + } - if(sorted_keys.has_value()) { - RAFT_EXPECTS(in.extent(1) == sorted_keys.value().extent(1) && - in.extent(0) == sorted_keys.value().extent(0), "Input and `sorted_keys` matrices must have the same shape."); - } + size_t workspace_size = 0; + bool alloc_workspace = false; - size_t workspace_size = 0; - bool alloc_workspace = false; + detail::sortColumnsPerRow( + in.data_handle(), + out.data_handle(), + 
in.extent(0), + in.extent(1), + alloc_workspace, + nullptr, + &workspace_size, + handle.get_stream(), + sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr); + + if (alloc_workspace) { + auto workspace = raft::make_device_vector(handle, workspace_size); detail::sortColumnsPerRow( - in.data_handle(), out.data_handle(), - in.extent(0), in.extent(1), - alloc_workspace, - nullptr, &workspace_size, handle.get_stream(), - sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr); - - if(alloc_workspace) { - - auto workspace = raft::make_device_vector(handle, workspace_size); - - detail::sortColumnsPerRow( - in.data_handle(), out.data_handle(), - in.extent(0), in.extent(1), - alloc_workspace, - workspace.data_handle(), &workspace_size, handle.get_stream(), - sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr); - } + in.data_handle(), + out.data_handle(), + in.extent(0), + in.extent(1), + alloc_workspace, + workspace.data_handle(), + &workspace_size, + handle.get_stream(), + sorted_keys.has_value() ? 
sorted_keys.value().data_handle() : nullptr); + } } - }; // end namespace matrix }; // end namespace raft diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 2798fe7491..4a00a0f732 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -34,19 +34,28 @@ namespace raft::matrix { * @param[in] indices of the rows to be copied */ template -void copy_rows(const raft::handle_t &handle, +void copy_rows(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view indices) { - RAFT_EXPECTS(in.extent(1) == out.extent(1), "Input and output matrices must have same number of columns"); - RAFT_EXPECTS(indices.extent(0) == out.extent(0), "Number of rows in output matrix must equal number of indices"); - bool in_rowmajor = raft::is_row_major(in); - bool out_rowmajor = raft::is_row_major(out); + RAFT_EXPECTS(in.extent(1) == out.extent(1), + "Input and output matrices must have same number of columns"); + RAFT_EXPECTS(indices.extent(0) == out.extent(0), + "Number of rows in output matrix must equal number of indices"); + bool in_rowmajor = raft::is_row_major(in); + bool out_rowmajor = raft::is_row_major(out); - RAFT_EXPECTS(in_rowmajor == out_rowmajor, "Input and output matrices must have same layout (row- or column-major)") + RAFT_EXPECTS(in_rowmajor == out_rowmajor, + "Input and output matrices must have same layout (row- or column-major)") - detail::copyRows(in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), indices.data_handle(), indices.extent(0), handle.get_stream()); + detail::copyRows(in.data_handle(), + in.extent(0), + in.extent(1), + out.data_handle(), + indices.data_handle(), + indices.extent(0), + handle.get_stream()); } /** @@ -56,14 +65,15 @@ void copy_rows(const raft::handle_t &handle, * @param[out] out: output matrix */ template -void 
copy(const raft::handle_t &handle, +void copy(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out) { - RAFT_EXPECTS(in.extent(0) == out.extent(0) && - in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); + RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), + "Input and output matrix shapes must match."); - raft::copy_async(out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); + raft::copy_async( + out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); } /** @@ -78,10 +88,9 @@ void copy(const raft::handle_t &handle, */ template void trunc_zero_origin( - m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) + m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) { - detail::truncZeroOrigin(in, in_n_rows, out, out_n_rows, out_n_cols, stream); + detail::truncZeroOrigin(in, in_n_rows, out, out_n_rows, out_n_cols, stream); } - -} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp index d2b7b1100e..0545d049ad 100644 --- a/cpp/include/raft/matrix/detail/print.hpp +++ b/cpp/include/raft/matrix/detail/print.hpp @@ -37,12 +37,12 @@ namespace raft::matrix::detail { template void printHost(const m_t* in, idx_t n_rows, idx_t n_cols) { - for (idx_t i = 0; i < n_rows; i++) { - for (idx_t j = 0; j < n_cols; j++) { - printf("%1.4f ", in[j * n_rows + i]); - } - printf("\n"); + for (idx_t i = 0; i < n_rows; i++) { + for (idx_t j = 0; j < n_cols; j++) { + printf("%1.4f ", in[j * n_rows + i]); } + printf("\n"); + } } } // end namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/diagonal.cuh b/cpp/include/raft/matrix/diagonal.cuh index 69bce231d3..f5ab33ebd7 100644 --- a/cpp/include/raft/matrix/diagonal.cuh +++ b/cpp/include/raft/matrix/diagonal.cuh @@ -17,8 +17,8 @@ 
#pragma once #include -#include #include +#include namespace raft::matrix { @@ -28,15 +28,15 @@ namespace raft::matrix { * @param matrix: matrix of size n_rows x n_cols */ template -void initialize_diagonal( - const raft::handle_t &handle, - raft::device_vector_view vec, - raft::device_matrix_view matrix) { - detail::initializeDiagonalMatrix(vec.data_handle(), - matrix.data_handle(), - matrix.extent(0), - matrix.extent(1), - handle.get_stream()); +void initialize_diagonal(const raft::handle_t& handle, + raft::device_vector_view vec, + raft::device_matrix_view matrix) +{ + detail::initializeDiagonalMatrix(vec.data_handle(), + matrix.data_handle(), + matrix.extent(0), + matrix.extent(1), + handle.get_stream()); } /** @@ -44,10 +44,10 @@ void initialize_diagonal( * @param in: square input matrix with size len x len */ template -void invert_diagonal(const raft::handle_t &handle, +void invert_diagonal(const raft::handle_t& handle, raft::device_matrix_view in) { - RAFT_EXPECTS(in.extent(0) == in.extent(1), "Matrix must be square."); - detail::getDiagonalInverseMatrix(in.data_handle(), in.extent(0), handle.get_stream()); -} + RAFT_EXPECTS(in.extent(0) == in.extent(1), "Matrix must be square."); + detail::getDiagonalInverseMatrix(in.data_handle(), in.extent(0), handle.get_stream()); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 384de604c8..6cf6ea756d 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -64,18 +64,25 @@ void gather(const MatrixIteratorT in, * @param out Pointer to the output matrix (assumed to be row-major) */ template -void gather(const raft::handle_t &handle, +void gather(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, - raft::device_vector_view map) { - - RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); - RAFT_EXPECTS(out.extent(1) 
== in.extent(1), "Number of columns in input and output matrices must be equal."); + raft::device_vector_view map) +{ + RAFT_EXPECTS(out.extent(0) == map.extent(0), + "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Number of columns in input and output matrices must be equal."); - detail::gather(in.data_handle(), in.extent(1), in.extent(0), map, map.extent(0), out.data_handle(), handle.get_stream()); + detail::gather(in.data_handle(), + in.extent(1), + in.extent(0), + map, + map.extent(0), + out.data_handle(), + handle.get_stream()); } - /** * @brief gather copies rows from a source matrix into a destination matrix according to a * transformed map. @@ -93,16 +100,25 @@ void gather(const raft::handle_t &handle, * @param transform_op The transformation operation, transforms the map values to IndexT */ template -void gather(const raft::handle_t &handle, +void gather(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out - raft::device_vector_view map, - MapTransformOp transform_op) { - - RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); - RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); + raft::device_vector_view map, + MapTransformOp transform_op) +{ + RAFT_EXPECTS(out.extent(0) == map.extent(0), + "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Number of columns in input and output matrices must be equal."); - detail::gather(in.data_handle(), in.extent(1), in.extent(0), map, map.extent(0), out.data_handle(), transform_op, handle.get_stream()); + detail::gather(in.data_handle(), + in.extent(1), + in.extent(0), + map, + map.extent(0), + out.data_handle(), + transform_op, + handle.get_stream()); } /** @@ -134,10 +150,9 @@ void gather(const MatrixIteratorT in, 
MapTransformOp transform_op, cudaStream_t stream) { - detail::gather(in, D, N, map, map_length, out, transform_op, stream); + detail::gather(in, D, N, map, map_length, out, transform_op, stream); } - /** * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. @@ -179,8 +194,6 @@ void gather_if(const MatrixIteratorT in, detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, stream); } - - /** * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. @@ -201,27 +214,35 @@ void gather_if(const MatrixIteratorT in, * @param pred_op Predicate to apply to the stencil values */ template -void gather_if(const raft::handle_t &handle, + typename MapIteratorT, + typename StencilIteratorT, + typename UnaryPredicateOp, + typename MatrixIdxT> +void gather_if(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view map, raft::device_vector_view stencil, - UnaryPredicateOp pred_op) { - - RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); - RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); - RAFT_EXPECTS(map.extent(0) == stencil.extent(0), "Number of elements in stencil must equal number of elements in map"); + UnaryPredicateOp pred_op) +{ + RAFT_EXPECTS(out.extent(0) == map.extent(0), + "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Number of columns in input and output matrices must be equal."); + RAFT_EXPECTS(map.extent(0) == stencil.extent(0), + "Number of elements in stencil must equal number of elements in map"); - detail::gather_if(in.data_handle(), out.extent(1), out.extent(0), - map.data_handle(), stencil.data_handle(), map.extent(0), - out.data_handle(), pred_op, handle.get_stream()); + 
detail::gather_if(in.data_handle(), + out.extent(1), + out.extent(0), + map.data_handle(), + stencil.data_handle(), + map.extent(0), + out.data_handle(), + pred_op, + handle.get_stream()); } - /** * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a transformed map. @@ -291,12 +312,12 @@ void gather_if(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT */ template -void gather_if(const raft::handle_t &handle, + typename MapIteratorT, + typename StencilIteratorT, + typename UnaryPredicateOp, + typename MapTransformOp, + typename MatrixIdxT> +void gather_if(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view map, @@ -304,13 +325,23 @@ void gather_if(const raft::handle_t &handle, UnaryPredicateOp pred_op, MapTransformOp transform_op) { + RAFT_EXPECTS(out.extent(0) == map.extent(0), + "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Number of columns in input and output matrices must be equal."); + RAFT_EXPECTS(map.extent(0) == stencil.extent(0), + "Number of elements in stencil must equal number of elements in map"); - RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); - RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); - RAFT_EXPECTS(map.extent(0) == stencil.extent(0), "Number of elements in stencil must equal number of elements in map"); - - detail::gather_if(in.data_handle(), in.extent(1), in.extent(0), map.data_handle(), stencil.data_handle(), - map.extent(0), out.data_handle(), pred_op, transform_op, handle.get_stream()); + detail::gather_if(in.data_handle(), + in.extent(1), + in.extent(0), + map.data_handle(), + stencil.data_handle(), + map.extent(0), + out.data_handle(), + pred_op, + 
transform_op, + handle.get_stream()); } } // namespace matrix diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh index 7d27689227..0c6f45f904 100644 --- a/cpp/include/raft/matrix/init.cuh +++ b/cpp/include/raft/matrix/init.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { /** @@ -30,10 +30,12 @@ namespace raft::matrix { * @param scalar svalar value */ template -void fill(const raft::handle_t &handle, +void fill(const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_matrix_view out, math_t scalar) { - RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); - detail::setValue(out.data_handle(), in.data_handle(), scalar, in.size(), handle.get_stream()); -} + raft::device_matrix_view out, + math_t scalar) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); + detail::setValue(out.data_handle(), in.data_handle(), scalar, in.size(), handle.get_stream()); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index adf7732954..2321548b35 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -17,12 +17,12 @@ #pragma once #include -#include #include +#include namespace raft::matrix { - /** +/** * Run a function over matrix lines (rows or columns) with a variable number * row-vectors or column-vectors. * The term `line` here signifies that the lines can be either columns or rows, @@ -47,15 +47,16 @@ namespace raft::matrix { * size of each vector is `alongLines ? lineLen : nLines`. */ template -void linewise_op(const raft::handle_t &handle, +void linewise_op(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, const idx_t lineLen, const idx_t nLines, const bool alongLines, Lambda op, - raft::device_vector_view... 
vecs) { - detail::MatrixLinewiseOp<16, 256>::run( - out, in, lineLen, nLines, alongLines, op, stream, vecs...); -} + raft::device_vector_view... vecs) +{ + detail::MatrixLinewiseOp<16, 256>::run( + out, in, lineLen, nLines, alongLines, op, stream, vecs...); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh index 6fb15e6d8b..25ad185935 100644 --- a/cpp/include/raft/matrix/math.cuh +++ b/cpp/include/raft/matrix/math.cuh @@ -23,7 +23,6 @@ " is deprecated and will be removed in a future release." \ " Please use versions in individual header files instead.") - #ifndef __MATH_H #define __MATH_H diff --git a/cpp/include/raft/matrix/matrix.cuh b/cpp/include/raft/matrix/matrix.cuh index 588399655b..3a7e0dad47 100644 --- a/cpp/include/raft/matrix/matrix.cuh +++ b/cpp/include/raft/matrix/matrix.cuh @@ -28,9 +28,9 @@ #pragma once -#include #include "detail/linewise_op.cuh" #include "detail/matrix.cuh" +#include #include @@ -67,9 +67,6 @@ void copyRows(const m_t* in, detail::copyRows(in, n_rows, n_cols, out, indices, n_rows_indices, stream, rowMajor); } - - - /** * @brief copy matrix operation for column major matrices. 
* @param in: input matrix @@ -91,14 +88,15 @@ void copy(const m_t* in, m_t* out, idx_t n_rows, idx_t n_cols, cudaStream_t stre * @param[out] out: output matrix */ template -void copy(const raft::handle_t &handle, +void copy(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out) { - RAFT_EXPECTS(in.extent(0) == out.extent(0) && - in.extent(1) == out.extent(1), "Input and output matrix shapes must match."); + RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), + "Input and output matrix shapes must match."); - raft::copy_async(out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); + raft::copy_async( + out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); } /** diff --git a/cpp/include/raft/matrix/matrix_vector.cuh b/cpp/include/raft/matrix/matrix_vector.cuh index b051831a15..5d05d03d2c 100644 --- a/cpp/include/raft/matrix/matrix_vector.cuh +++ b/cpp/include/raft/matrix/matrix_vector.cuh @@ -16,12 +16,12 @@ #pragma once -#include #include "detail/matrix.cuh" +#include namespace raft::matrix { - /** +/** * @brief multiply each row or column of matrix with vector, skipping zeros in vector * @param data input matrix, results are in-place * @param vec input vector @@ -33,15 +33,15 @@ namespace raft::matrix { */ template void binary_mult_skip_zero(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream) + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream) { - detail::matrixVectorBinaryMultSkipZero( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); + detail::matrixVectorBinaryMultSkipZero( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); } /** @@ -54,20 +54,20 @@ void binary_mult_skip_zero(Type* data, * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns * 
@param stream cuda stream */ - template - void binary_div(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream) - { - detail::matrixVectorBinaryDiv( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); - } +template +void binary_div(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream) +{ + detail::matrixVectorBinaryDiv( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); +} - /** +/** * @brief divide each row or column of matrix with vector, skipping zeros in vector * @param data input matrix, results are in-place * @param vec input vector @@ -79,21 +79,21 @@ void binary_mult_skip_zero(Type* data, * @param return_zero result is zero if true and vector value is below threshold, original value if * false */ - template - void binary_div_skip_zero(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream, - bool return_zero = false) - { - detail::matrixVectorBinaryDivSkipZero( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream, return_zero); - } +template +void binary_div_skip_zero(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream, + bool return_zero = false) +{ + detail::matrixVectorBinaryDivSkipZero( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream, return_zero); +} - /** +/** * @brief add each row or column of matrix with vector * @param data input matrix, results are in-place * @param vec input vector @@ -103,20 +103,20 @@ void binary_mult_skip_zero(Type* data, * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns * @param stream cuda stream */ - template - void binary_add(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream) - { - 
detail::matrixVectorBinaryAdd( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); - } +template +void binary_add(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream) +{ + detail::matrixVectorBinaryAdd( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); +} - /** +/** * @brief subtract each row or column of matrix with vector * @param data input matrix, results are in-place * @param vec input vector @@ -126,17 +126,17 @@ void binary_mult_skip_zero(Type* data, * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns * @param stream cuda stream */ - template - void binary_sub(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream) - { - detail::matrixVectorBinarySub( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); - } +template +void binary_sub(Type* data, + const Type* vec, + IdxType n_row, + IdxType n_col, + bool rowMajor, + bool bcastAlongRows, + cudaStream_t stream) +{ + detail::matrixVectorBinarySub( + data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); +} -} \ No newline at end of file +} // namespace raft::matrix \ No newline at end of file diff --git a/cpp/include/raft/matrix/norm.cuh b/cpp/include/raft/matrix/norm.cuh index f4814ccbe9..5c1e0b9c01 100644 --- a/cpp/include/raft/matrix/norm.cuh +++ b/cpp/include/raft/matrix/norm.cuh @@ -17,12 +17,11 @@ #pragma once #include -#include #include +#include namespace raft::matrix { - /** * @brief Get the L2/F-norm of a matrix * @param handle @@ -31,9 +30,8 @@ namespace raft::matrix { * @param stream: cuda stream */ template -m_t l2_norm(const raft::handle_t& handle, - raft::device_mdspan in) +m_t l2_norm(const raft::handle_t& handle, raft::device_mdspan in) { - return detail::getL2Norm(handle, in.data_handle(), in.size(), handle.get_stream()); -} + return detail::getL2Norm(handle, in.data_handle(), 
in.size(), handle.get_stream()); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/power.cuh b/cpp/include/raft/matrix/power.cuh index 1c38f9c2f8..98d691afd4 100644 --- a/cpp/include/raft/matrix/power.cuh +++ b/cpp/include/raft/matrix/power.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -30,13 +30,13 @@ namespace raft::matrix { * @param[in] scalar: every element is multiplied with scalar. */ template -void weighted_power( - const raft::handle_t &handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - math_t scalar) { - RAFT_EXPECTS(in.size() == out.size(), "Size of input and output matrices must be equal"); - detail::power(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream()); +void weighted_power(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t scalar) +{ + RAFT_EXPECTS(in.size() == out.size(), "Size of input and output matrices must be equal"); + detail::power(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream()); } /** @@ -45,10 +45,11 @@ void weighted_power( * @param[in] scalar: every element is multiplied with scalar. 
*/ template -void weighted_power(const raft::handle_t &handle, +void weighted_power(const raft::handle_t& handle, raft::device_matrix_view inout, - math_t scalar) { - detail::power(inout.data_handle(), scalar, inout.size(), handle.get_stream()); + math_t scalar) +{ + detail::power(inout.data_handle(), scalar, inout.size(), handle.get_stream()); } /** @@ -56,9 +57,9 @@ void weighted_power(const raft::handle_t &handle, * @param[inout] inout: input matrix and also the result is stored */ template -void power(const raft::handle_t &handle, - raft::device_matrix_view inout) { - detail::power(inout.data_handle(), inout.size(), handle.get_stream()); +void power(const raft::handle_t& handle, raft::device_matrix_view inout) +{ + detail::power(inout.data_handle(), inout.size(), handle.get_stream()); } /** @@ -69,13 +70,12 @@ void power(const raft::handle_t &handle, * @{ */ template -void power(const raft::handle_t &handle, +void power(const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_matrix_view out) { - RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size."); - detail::power(in, out, len, stream); + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size."); + detail::power(in, out, len, stream); } - - -} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh index d7f978ec13..5eef7e0fda 100644 --- a/cpp/include/raft/matrix/print.cuh +++ b/cpp/include/raft/matrix/print.cuh @@ -16,14 +16,13 @@ #pragma once -#include #include -#include +#include #include +#include namespace raft::matrix { - /** * @brief Prints the data stored in GPU memory * @param handle: raft handle @@ -32,12 +31,13 @@ namespace raft::matrix { * @param v_separator: vertical separator character */ template -void print(const raft::handle_t &handle, +void print(const raft::handle_t& handle, raft::device_matrix_view in, - char h_separator = ' 
', - char v_separator = '\n') + char h_separator = ' ', + char v_separator = '\n') { - detail::print(in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator, handle.get_stream()); + detail::print( + in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator, handle.get_stream()); } /** @@ -45,7 +45,8 @@ void print(const raft::handle_t &handle, * @param in: input matrix with column-major layout */ template -void print(raft::host_matrix_view in) { - detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); -} +void print(raft::host_matrix_view in) +{ + detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/print.hpp b/cpp/include/raft/matrix/print.hpp index cc80671cec..66e939be0f 100644 --- a/cpp/include/raft/matrix/print.hpp +++ b/cpp/include/raft/matrix/print.hpp @@ -26,7 +26,8 @@ namespace raft::matrix { * @param in: input matrix with column-major layout */ template -void print(raft::host_matrix_view in) { - detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); -} +void print(raft::host_matrix_view in) +{ + detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh index 78917ee684..52c2180f95 100644 --- a/cpp/include/raft/matrix/ratio.cuh +++ b/cpp/include/raft/matrix/ratio.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -35,8 +35,9 @@ namespace raft::matrix { template void ratio(const raft::handle_t& handle, raft::device_matrix_view src, - raft::device_matrix_view dest) { - RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); - detail::ratio(handle, src.data_handle(), dest.data_handle(), in.size(), handle.get_stream()); -} + raft::device_matrix_view dest) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be 
the same size."); + detail::ratio(handle, src.data_handle(), dest.data_handle(), in.size(), handle.get_stream()); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh index 66c4f9bf5f..f74d98969c 100644 --- a/cpp/include/raft/matrix/reciprocal.cuh +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -39,9 +39,11 @@ void reciprocal(raft::device_matrix_view in, raft::device_matrix_view out, math_t scalar, bool setzero = false, - math_t thres = 1e-15) { - RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have the same size."); - detail::reciprocal(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), setzero, thres); + math_t thres = 1e-15) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have the same size."); + detail::reciprocal( + in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), setzero, thres); } /** @@ -57,11 +59,13 @@ void reciprocal(raft::device_matrix_view in, * @{ */ template -void reciprocal(const raft::handle_t &handle, +void reciprocal(const raft::handle_t& handle, raft::device_matrix_view inout, math_t scalar, bool setzero = false, - math_t thres = 1e-15) { - detail::reciprocal(inout.data_handle(), scalar, inout.size(), handle.get_stream(), setzero, thres); -} + math_t thres = 1e-15) +{ + detail::reciprocal( + inout.data_handle(), scalar, inout.size(), handle.get_stream(), setzero, thres); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/reverse.cuh b/cpp/include/raft/matrix/reverse.cuh index 2416a81636..8a9837f467 100644 --- a/cpp/include/raft/matrix/reverse.cuh +++ b/cpp/include/raft/matrix/reverse.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -31,9 +31,10 @@ namespace raft::matrix { * @param stream: cuda stream */ template 
-void col_reverse(const raft::handle_t &handle, - raft::device_matrix_view inout) { - detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), stream); +void col_reverse(const raft::handle_t& handle, + raft::device_matrix_view inout) +{ + detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), stream); } /** @@ -45,12 +46,12 @@ void col_reverse(const raft::handle_t &handle, * @param stream: cuda stream */ template -void col_reverse(const raft::handle_t &handle, - raft::device_matrix_view inout) { - detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), stream); +void col_reverse(const raft::handle_t& handle, + raft::device_matrix_view inout) +{ + detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), stream); } - /** * @brief Rows of a column major matrix are reversed in place (i.e. first row and last * row are swapped) @@ -60,10 +61,10 @@ void col_reverse(const raft::handle_t &handle, * @param stream: cuda stream */ template -void row_reverse(const raft::handle_t &handle, +void row_reverse(const raft::handle_t& handle, raft::device_matrix_view inout) { - detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); } /** @@ -75,10 +76,10 @@ void row_reverse(const raft::handle_t &handle, * @param stream: cuda stream */ template -void row_reverse(const raft::handle_t &handle, +void row_reverse(const raft::handle_t& handle, raft::device_matrix_view inout) { - detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); + detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); } -} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/seq_root.cuh b/cpp/include/raft/matrix/seq_root.cuh index 9af90bd842..0f74799154 100644 --- a/cpp/include/raft/matrix/seq_root.cuh +++ 
b/cpp/include/raft/matrix/seq_root.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -30,11 +30,12 @@ namespace raft::matrix { * @param[out] out: output matrix. The result is stored in the out matrix */ template -void seq_root(const raft::handle_t &handle, +void seq_root(const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_matrix_view out) { - RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); - detail::seqRoot(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); + detail::seqRoot(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); } /** @@ -44,10 +45,9 @@ void seq_root(const raft::handle_t &handle, * @param[inout] inout: input matrix with in-place results */ template -void seq_root(const raft::handle_t &handle, - raft::device_matrix_view inout) +void seq_root(const raft::handle_t& handle, raft::device_matrix_view inout) { - detail::seqRoot(inout.data_handle(), inout.size(), handle.get_stream()); + detail::seqRoot(inout.data_handle(), inout.size(), handle.get_stream()); } /** @@ -60,14 +60,15 @@ void seq_root(const raft::handle_t &handle, * @param[in] set_neg_zero whether to set negative numbers to zero */ template -void weighted_seq_root(const raft::handle_t &handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - math_t scalar, - bool set_neg_zero = false) +void weighted_seq_root(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t scalar, + bool set_neg_zero = false) { - RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); - detail::seqRoot(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), set_neg_zero); + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices 
must have same size."); + detail::seqRoot( + in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), set_neg_zero); } /** @@ -79,14 +80,12 @@ void weighted_seq_root(const raft::handle_t &handle, * @param set_neg_zero whether to set negative numbers to zero */ template -void weighted_seq_root( - const raft::handle_t &handle, - raft::device_matrix_view inout, - math_t scalar, bool set_neg_zero = false) +void weighted_seq_root(const raft::handle_t& handle, + raft::device_matrix_view inout, + math_t scalar, + bool set_neg_zero = false) { - detail::seqRoot(inout.data_handle(), scalar, inout.size(), handle.get_stream(), set_neg_zero); + detail::seqRoot(inout.data_handle(), scalar, inout.size(), handle.get_stream(), set_neg_zero); } - - -} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/sign_flip.cuh b/cpp/include/raft/matrix/sign_flip.cuh index 479f93cd59..f99c3111ab 100644 --- a/cpp/include/raft/matrix/sign_flip.cuh +++ b/cpp/include/raft/matrix/sign_flip.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -29,8 +29,9 @@ namespace raft::matrix { * @param inout: input matrix. 
Result also stored in this parameter */ template -void sign_flip(const raft::handle_t &handle, - raft::device_matrix_view inout) { - detail::signFlip(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); -} +void sign_flip(const raft::handle_t& handle, + raft::device_matrix_view inout) +{ + detail::signFlip(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh index 43209d4054..acb2e35793 100644 --- a/cpp/include/raft/matrix/slice.cuh +++ b/cpp/include/raft/matrix/slice.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -34,11 +34,22 @@ namespace raft::matrix { * 3, 0, 1, 4, 3); */ template -void slice(const raft::handle_t &handle, +void slice(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, - idx_t x1, idx_t y1, idx_t x2, idx_t y2) { - detail::sliceMatrix(in.data_handle(), in.extent(0), in.extent(1), - out.data_handle(), x1, y1, x2, y2, handle.get_stream()); -} + idx_t x1, + idx_t y1, + idx_t x2, + idx_t y2) +{ + detail::sliceMatrix(in.data_handle(), + in.extent(0), + in.extent(1), + out.data_handle(), + x1, + y1, + x2, + y2, + handle.get_stream()); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/threshold.cuh b/cpp/include/raft/matrix/threshold.cuh index 6959392ca9..bdcab41421 100644 --- a/cpp/include/raft/matrix/threshold.cuh +++ b/cpp/include/raft/matrix/threshold.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -31,14 +31,14 @@ namespace raft::matrix { * @param thres threshold to set values to zero */ template -void zero_small_values( - const raft::handle_t &handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - math_t thres = 1e-15) { - - RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size"); 
- detail::setSmallValuesZero(out.data_handle(), in.data_handle(), in.size(), handle.get_stream(), thres); +void zero_small_values(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t thres = 1e-15) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size"); + detail::setSmallValuesZero( + out.data_handle(), in.data_handle(), in.size(), handle.get_stream(), thres); } /** @@ -49,10 +49,10 @@ void zero_small_values( * @param thres: threshold */ template -void zero_small_values( - const raft::handle_t &handle, - raft::device_matrix_view inout, - math_t thres = 1e-15) { - detail::setSmallValuesZero(inout.data_handle(), inout.size(), handle.get_stream(), thres); -} +void zero_small_values(const raft::handle_t& handle, + raft::device_matrix_view inout, + math_t thres = 1e-15) +{ + detail::setSmallValuesZero(inout.data_handle(), inout.size(), handle.get_stream(), thres); } +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh index 7dcf6e39b4..6f520841eb 100644 --- a/cpp/include/raft/matrix/triangular.cuh +++ b/cpp/include/raft/matrix/triangular.cuh @@ -17,8 +17,8 @@ #pragma once #include -#include #include +#include namespace raft::matrix { @@ -29,12 +29,11 @@ namespace raft::matrix { * @param[out] dst: output matrix with a size of kxk, k = min(n_rows, n_cols) */ template -void upper_triangular(const raft::handle_t &handle, - raft::device_matrix_view src, - raft::device_matrix_view dst) { - - detail::copyUpperTriangular(src.data_handle(), dst.data_handle(), - src.extent(0), src.extent(1), - handle.get_stream()); -} +void upper_triangular(const raft::handle_t& handle, + raft::device_matrix_view src, + raft::device_matrix_view dst) +{ + detail::copyUpperTriangular( + src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream()); } +} // namespace raft::matrix From 
6c34351d046f15b75c0e864278371e3dcd30dac1 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 21 Sep 2022 21:32:33 -0400 Subject: [PATCH 26/58] Fixing style for tests --- cpp/test/matrix/columnSort.cu | 14 +++++++++----- cpp/test/matrix/gather.cu | 9 +++++---- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index 7642a4db7d..f795eaa6f0 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -47,7 +47,8 @@ struct columnSort { }; template -::std::ohandle.get_stream()& operator<<(::std::ohandle.get_stream()& os, const columnSort& dims) +::std::ohandle.get_stream() & operator<<(::std::ohandle.get_stream() & os, + const columnSort& dims) { return os; } @@ -69,7 +70,7 @@ class ColumnSort : public ::testing::TestWithParam> { { params = ::testing::TestWithParam>::GetParam(); int len = params.n_row * params.n_col; - RAFT_CUDA_TRY(cudahandle.get_stream()Create(&handle.get_stream())); + RAFT_CUDA_TRY(cudahandle.get_stream() Create(&handle.get_stream())); keyIn.resize(len, handle.get_stream()); valueOut.resize(len, handle.get_stream()); goldenValOut.resize(len, handle.get_stream()); @@ -101,11 +102,14 @@ class ColumnSort : public ::testing::TestWithParam> { raft::update_device(keyIn.data(), &vals[0], len, handle.get_stream()); raft::update_device(goldenValOut.data(), &cValGolden[0], len, handle.get_stream()); - if (params.testKeys) raft::update_device(keySortGolden.data(), &cKeyGolden[0], len, handle.get_stream()); + if (params.testKeys) + raft::update_device(keySortGolden.data(), &cKeyGolden[0], len, handle.get_stream()); auto key_in_view = raft::make_device_matrix_view(keyIn.data(), params.n_row, params.n_col); - auto value_out_view = raft::make_device_matrix_view(valueOut.data(), params.n_row, params.n_col); - auto key_sorted_view = raft::make_device_matrix_view(keySorted.data(), params.n_row, params.n_col); + auto value_out_view = + 
raft::make_device_matrix_view(valueOut.data(), params.n_row, params.n_col); + auto key_sorted_view = + raft::make_device_matrix_view(keySorted.data(), params.n_row, params.n_col); raft::matrix::sort_cols_per_row(handle, key_in_view, value_out_view, key_sorted_view); diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index da2057d4f6..92c68f58f5 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -15,8 +15,8 @@ */ #include -#include #include +#include #include #include #include @@ -47,7 +47,7 @@ void naiveGather( } template -void gatherLaunch(const raft::handle_t &handle, +void gatherLaunch(const raft::handle_t& handle, MatrixIteratorT in, int D, int N, @@ -58,7 +58,7 @@ void gatherLaunch(const raft::handle_t &handle, { typedef typename std::iterator_traits::value_type MapValueT; - auto in_view = raft::make_device_matrix_view(in, N, D); + auto in_view = raft::make_device_matrix_view(in, N, D); auto map_view = raft::make_device_vector_view(map, map_length); auto out_view = raft::make_device_matrix_view(out, N, D); @@ -117,7 +117,8 @@ class GatherTest : public ::testing::TestWithParam { raft::update_device(d_out_exp.data(), h_out.data(), map_length * ncols, stream); // launch device version of the kernel - gatherLaunch(handle, d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); + gatherLaunch( + handle, d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); handle.sync_stream(stream); } From fea8448d54e2cd77c970b70c2212c1ac23044301 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Mon, 26 Sep 2022 14:46:11 -0400 Subject: [PATCH 27/58] iUpdates --- cpp/include/raft/core/mdspan.hpp | 6 +- cpp/include/raft/matrix/argmax.cuh | 10 ++- cpp/include/raft/matrix/gather.cuh | 4 +- cpp/test/CMakeLists.txt | 1 + cpp/test/matrix/argmax.cu | 104 +++++++++++++++++++++++++++++ 5 files changed, 115 insertions(+), 10 deletions(-) create mode 100644 cpp/test/matrix/argmax.cu diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 202e6163a9..3d95bb54cd 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -301,7 +301,7 @@ constexpr bool is_matrix_view(mdspan> /* m } template -constexpr bool is_matrix_view(mdspan m) +constexpr bool is_matrix_view(mdspan m) { return false; } @@ -313,7 +313,7 @@ constexpr bool is_vector_view(mdspan> /* m } template -constexpr bool is_vector_view(mdspan m) +constexpr bool is_vector_view(mdspan m) { return false; } @@ -325,7 +325,7 @@ constexpr bool is_scalar_view(mdspan> /* m } template -constexpr bool is_scalar_view(mdspan m) +constexpr bool is_scalar_view(mdspan m) { return false; } diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index 28cb69dd8f..5afd026745 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -24,16 +24,14 @@ namespace raft::matrix { /** * @brief Argmax: find the row idx with maximum value for each column - * @param in: input matrix - * @param n_rows: number of rows of input matrix - * @param n_cols: number of columns of input matrix + * @param handle: raft handle + * @param in: input matrix of size (n_rows, n_cols) * @param out: output vector of size n_cols - * @param stream: cuda stream */ -template +template void argmax(const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_vector_view out) + raft::device_vector_view out) { RAFT_EXPECTS(out.extent(1) == in.extent(1), "Size of output vector must equal number of columns in input matrix."); 
diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 6cf6ea756d..b5e8b4af7c 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -102,7 +102,7 @@ void gather(const raft::handle_t& handle, template void gather(const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_matrix_view out + raft::device_matrix_view out, raft::device_vector_view map, MapTransformOp transform_op) { @@ -144,6 +144,8 @@ void gather(const raft::handle_t& handle, */ template void gather(const MatrixIteratorT in, + int D, + int N, MapIteratorT map, int map_length, MatrixIteratorT out, diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 82d381bbb5..5e6ecd5094 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -73,6 +73,7 @@ add_executable(test_raft test/linalg/ternary_op.cu test/linalg/transpose.cu test/linalg/unary_op.cu + test/matrix/argmax.cu test/matrix/gather.cu test/matrix/math.cu test/matrix/matrix.cu diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu new file mode 100644 index 0000000000..50ca4e3e92 --- /dev/null +++ b/cpp/test/matrix/argmax.cu @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include + +namespace raft { + namespace matrix { + + template + struct ArgMaxInputs { + std::vector input_matrix; + std::vector output_matrix; + int n_cols; + int n_rows; + }; + + template + ::std::ostream& operator<<(::std::ostream& os, const ArgMaxInputs& dims) + { + return os; + } + + template + class ArgMaxTest : public ::testing::TestWithParam> { + public: + ArgMaxTest() + : params(::testing::TestWithParam>::GetParam()), + input(std::move(raft::make_device_matrix(handle, params.n_rows, params.n_cols))), + output(std::move(raft::make_device_vector(handle, params.n_rows))), + expected(std::move(raft::make_device_vector(handle, params.n_rows))) { + + raft::copy(input.data_handle(), params.input_matrix.data(), params.n_rows * params.n_cols); + raft::copy(expected.data_handle(), params.output_matrix.data(), params.n_rows * params.n_cols); + + raft::matrix::argmax(handle, input, output); + } + + protected: + raft::handle_t handle; + ArgMaxInputs params; + + raft::device_matrix input; + raft::device_vector output; + raft::device_vector expected; + }; + + const std::vector> inputsf = { + {0.1f, 0.2f, 0.3f, 0.4f}, + {0.4f, 0.3f, 0.2f, 0.1f}, + {0.2f, 0.3f, 0.5f, 0.0f}, + {3, 0, 2}, + 3, 4}; + + const std::vector> inputsd = { + {0.1, 0.2, 0.3, 0.4}, + {0.4, 0.3, 0.2, 0.1}, + {0.2, 0.3, 0.5, 0.0}, + {3, 0, 2}, + 3, 4}; + + typedef ArgMaxTest ArgMaxTestF; + TEST_P(ArgMaxTestF, Result) + { + ASSERT_TRUE(devArrMatch(output.data_handle(), + expected.data_handle(), + params.n_rows, + Compare(), + handle.get_stream())); + } + + typedef ArgMaxTest ArgMaxTestD; + TEST_P(ArgMaxTestD, Result) +{ + ASSERT_TRUE(devArrMatch(output.data_handle(), + expected.data_handle(), + params.n_rows, + Compare(), + handle.get_stream())); +} + +INSTANTIATE_TEST_SUITE_P(ArgMaxTest, ArgMaxTestTestF, ::testing::ValuesIn(inputsf)); + +INSTANTIATE_TEST_SUITE_P(ArgMaxTest, ArgMaxTestTestD, ::testing::ValuesIn(inputsd)); + +} // 
namespace matrix +} // namespace raft From bbf2ab1c52408c682aaa87f0fe957dad0636e0d3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 26 Sep 2022 15:18:25 -0400 Subject: [PATCH 28/58] Updates based on review feedback --- cpp/include/raft/spatial/knn/ball_cover.cuh | 5 +- cpp/include/raft/spatial/knn/brute_force.cuh | 158 +++++++++++++++ cpp/include/raft/spatial/knn/knn.cuh | 199 ------------------- cpp/test/spatial/knn.cu | 2 +- 4 files changed, 161 insertions(+), 203 deletions(-) create mode 100644 cpp/include/raft/spatial/knn/brute_force.cuh diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index ad9658872f..63ce2019f8 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -289,8 +289,7 @@ void rbc_knn_query(const raft::handle_t& handle, raft::device_matrix_view inds, raft::device_matrix_view dists, int_t k = 5, - bool perform_post_filtering = true, - float weight = 1.0) + bool perform_post_filtering = true) { RAFT_EXPECTS(k <= index.m, "k must be less than or equal to the number of data points in the index"); @@ -312,7 +311,7 @@ void rbc_knn_query(const raft::handle_t& handle, inds.data_handle(), dists.data_handle(), perform_post_filtering, - weight); + 1.0); } // TODO: implement functions for: diff --git a/cpp/include/raft/spatial/knn/brute_force.cuh b/cpp/include/raft/spatial/knn/brute_force.cuh new file mode 100644 index 0000000000..128c4bbd54 --- /dev/null +++ b/cpp/include/raft/spatial/knn/brute_force.cuh @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/knn_brute_force_faiss.cuh" +#include "detail/selection_faiss.cuh" +#include + +namespace raft::spatial::knn { + +/** + * Performs a k-select across row partitioned index/distance + * matrices formatted like the following: + * row1: k0, k1, k2 + * row2: k0, k1, k2 + * row3: k0, k1, k2 + * row1: k0, k1, k2 + * row2: k0, k1, k2 + * row3: k0, k1, k2 + * + * etc... + * + * @tparam idx_t + * @tparam value_t + * @param handle + * @param in_keys + * @param in_values + * @param out_keys + * @param out_values + * @param n_samples + * @param k + * @param translations + */ +template +inline void knn_merge_parts( + const raft::handle_t& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + size_t n_samples, + int k, + std::optional> translations = std::nullopt) +{ + RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), + "in_keys and in_values must have the same shape."); + RAFT_EXPECTS( + out_keys.extent(0) == out_values.extent(0) == n_samples, + "Number of rows in output keys and val matrices must equal number of rows in search matrix."); + RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == k, + "Number of columns in output indices and distances matrices must be equal to k"); + + auto n_parts = in_keys.extent(0) / n_samples; + detail::knn_merge_parts(in_keys.data_handle(), + in_values.data_handle(), + out_keys.data_handle(), + out_values.data_handle(), + 
n_samples, + n_parts, + k, + handle.get_stream(), + translations.value_or(nullptr)); +} + +/** + * @brief Flat C++ API function to perform a brute force knn on + * a series of input arrays and combine the results into a single + * output array for indexes and distances. Inputs can be either + * row- or column-major but the output matrices will always be in + * row-major format. + * + * @example + * + * + * + * @param[in] handle the cuml handle to use + * @param[in] index vector of device matrices (each size m_i*d) to be used as the knn index + * @param[in] search matrix (size n*d) to be used for searching the index + * @param[out] indices matrix (size n*k) to store output knn indices + * @param[out] distances matrix (size n*k) to store the output knn distance + * @param[in] k the number of nearest neighbors to return + * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This + * is ignored if the metric_type is not Minkowski. + * @param[in] metric distance metric to use. Euclidean (L2) is used by default + * @param[in] translations starting offsets for partitions. should be the same size + * as input vector. 
+ */ +template +void brute_force_knn( + raft::handle_t const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + value_int k, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional> translations = std::nullopt) +{ + RAFT_EXPECTS(index[0].extent(1) == search.extent(1), + "Number of dimensions for both index and search matrices must be equal"); + + RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), + "Number of rows in output indices and distances matrices must equal number of rows " + "in search matrix."); + RAFT_EXPECTS( + indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast(k), + "Number of columns in output indices and distances matrices must be equal to k"); + + bool rowMajorIndex = std::is_same_v; + bool rowMajorQuery = std::is_same_v; + + std::vector inputs; + std::vector sizes; + for (std::size_t i = 0; i < index.size(); ++i) { + inputs.push_back(const_cast(index[i].data_handle())); + sizes.push_back(index[i].extent(0)); + } + + std::vector* trans = translations.has_value() ? &(*translations) : nullptr; + + detail::brute_force_knn_impl(handle, + inputs, + sizes, + static_cast(index[0].extent(1)), + // TODO: This is unfortunate. Need to fix. 
+ const_cast(search.data_handle()), + static_cast(search.extent(0)), + indices.data_handle(), + distances.data_handle(), + k, + rowMajorIndex, + rowMajorQuery, + trans, + metric, + metric_arg.value_or(2.0f)); +} + +} // namespace raft::spatial::knn diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index e6e54253f6..95f7aab9da 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -66,60 +66,6 @@ inline void knn_merge_parts(value_t* in_keys, in_keys, in_values, out_keys, out_values, n_samples, n_parts, k, stream, translations); } -/** - * Performs a k-select across row partitioned index/distance - * matrices formatted like the following: - * row1: k0, k1, k2 - * row2: k0, k1, k2 - * row3: k0, k1, k2 - * row1: k0, k1, k2 - * row2: k0, k1, k2 - * row3: k0, k1, k2 - * - * etc... - * - * @tparam idx_t - * @tparam value_t - * @param handle - * @param in_keys - * @param in_values - * @param out_keys - * @param out_values - * @param n_samples - * @param k - * @param translations - */ -template -inline void knn_merge_parts( - const raft::handle_t& handle, - raft::device_matrix_view in_keys, - raft::device_matrix_view in_values, - raft::device_matrix_view out_keys, - raft::device_matrix_view out_values, - size_t n_samples, - int k, - std::optional> translations = std::nullopt) -{ - RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), - "in_keys and in_values must have the same shape."); - RAFT_EXPECTS( - out_keys.extent(0) == out_values.extent(0) == n_samples, - "Number of rows in output keys and val matrices must equal number of rows in search matrix."); - RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == k, - "Number of columns in output indices and distances matrices must be equal to k"); - - auto n_parts = in_keys.extent(0) / n_samples; - detail::knn_merge_parts(in_keys.data_handle(), - in_values.data_handle(), - out_keys.data_handle(), - 
out_values.data_handle(), - n_samples, - n_parts, - k, - handle.get_stream(), - translations.value_or(nullptr)); -} - /** Choose an implementation for the select-top-k, */ enum class SelectKAlgo { /** Adapted from the faiss project. Result: sorted (not stable). */ @@ -223,72 +169,6 @@ inline void select_k(const value_t* in_keys, } } -/** - * Select k smallest or largest key/values from each row in the input data. - * - * If you think of the input data `in_keys` as a row-major matrix with input_len columns and - * n_inputs rows, then this function selects k smallest/largest values in each row and fills - * in the row-major matrix `out_keys` of size (n_inputs, k). - * - * Note, depending on the selected algorithm, the values within rows of `out_keys` are not - * necessarily sorted. See the `SelectKAlgo` enumeration for more details. - * - * @tparam idx_t - * the payload type (what is being selected together with the keys). - * @tparam value_t - * the type of the keys (what is being compared). - * - * @param[in] handle the cuml handle to use - * @param[in] in_keys - * contiguous device array of inputs of size (input_len * n_inputs); - * these are compared and selected. - * @param[in] in_values - * contiguous device array of inputs of size (input_len * n_inputs); - * typically, these are indices of the corresponding in_keys. - * You can pass `NULL` as an argument here; this would imply `in_values` is a homogeneous array - * of indices from `0` to `input_len - 1` for every input and reduce the usage of memory - * bandwidth. - * @param[out] out_keys - * contiguous device array of outputs of size (k * n_inputs); - * the k smallest/largest values from each row of the `in_keys`. - * @param[out] out_values - * contiguous device array of outputs of size (k * n_inputs); - * the payload selected together with `out_keys`. - * @param[in] k - * the number of outputs to select in each input row. 
- * @param[in] select_min - * whether to select k smallest (true) or largest (false) keys. - * @param[in] algo - * the implementation of the algorithm - */ -template -inline void select_k(const raft::handle_t& handle, - raft::device_matrix_view in_keys, - raft::device_matrix_view in_values, - raft::device_matrix_view out_keys, - raft::device_matrix_view out_values, - int k, - bool select_min = true, - SelectKAlgo algo = SelectKAlgo::FAISS) -{ - size_t n_inputs = in_keys.extents(0); - size_t input_len = in_keys.extents(1); - - RAFT_EXPECTS(in_keys.extent(0) == in_values.extent(0) && in_keys.extent(1) == in_values.extent(1), - "in_keys and in_values must have the same shape"); - - select_k(in_keys.data_handle(), - in_values.data_handle(), - n_inputs, - input_len, - out_keys.data_handle(), - out_values.data_handle(), - select_min, - k, - handle.get_stream(), - algo); -} - /** * @brief Flat C++ API function to perform a brute force knn on * a series of input arrays and combine the results into a single @@ -346,83 +226,4 @@ void brute_force_knn(raft::handle_t const& handle, metric_arg); } -/** - * @brief Flat C++ API function to perform a brute force knn on - * a series of input arrays and combine the results into a single - * output array for indexes and distances. Inputs can be either - * row- or column-major but the output matrices will always be in - * row-major format. - * - * @example - * - * - * - * @param[in] handle the cuml handle to use - * @param[in] index vector of device matrices (each size m_i*d) to be used as the knn index - * @param[in] search matrix (size n*d) to be used for searching the index - * @param[out] indices matrix (size n*k) to store output knn indices - * @param[out] distances matrix (size n*k) to store the output knn distance - * @param[in] k the number of nearest neighbors to return - * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This - * is ignored if the metric_type is not Minkowski. 
- * @param[in] metric distance metric to use. Euclidean (L2) is used by default - * @param[in] translations starting offsets for partitions. should be the same size - * as input vector. - */ -template -void brute_force_knn( - raft::handle_t const& handle, - std::vector> index, - raft::device_matrix_view search, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - value_int k, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional> translations = std::nullopt) -{ - RAFT_EXPECTS(index[0].extent(1) == search.extent(1), - "Number of dimensions for both index and search matrices must be equal"); - - RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in search matrix."); - RAFT_EXPECTS( - indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); - - bool rowMajorIndex = std::is_same_v; - bool rowMajorQuery = std::is_same_v; - - std::vector inputs; - std::vector sizes; - for (std::size_t i = 0; i < index.size(); ++i) { - inputs.push_back(const_cast(index[i].data_handle())); - sizes.push_back(index[i].extent(0)); - } - - std::vector* trans = translations.has_value() ? &(*translations) : nullptr; - - detail::brute_force_knn_impl(handle, - inputs, - sizes, - static_cast(index[0].extent(1)), - // TODO: This is unfortunate. Need to fix. 
- const_cast(search.data_handle()), - static_cast(search.extent(0)), - indices.data_handle(), - distances.data_handle(), - k, - rowMajorIndex, - rowMajorQuery, - trans, - metric, - metric_arg.value_or(2.0f)); -} - } // namespace raft::spatial::knn diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index 9c22a0cb73..5807705038 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -19,7 +19,7 @@ #include #include #include -#include +#include #if defined RAFT_NN_COMPILED #include #endif From f528697b74eebad12ea3eb42fb26496daeb7dab1 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 26 Sep 2022 16:02:13 -0400 Subject: [PATCH 29/58] Updates based on review feedback --- cpp/include/raft/spatial/knn/ball_cover.cuh | 7 ++-- .../raft/spatial/knn/ball_cover_types.hpp | 2 +- cpp/include/raft/spatial/knn/ivf_flat.cuh | 42 +++++++++---------- cpp/test/spatial/ann_ivf_flat.cu | 4 +- 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index 63ce2019f8..c3e457c871 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -164,13 +164,12 @@ void rbc_all_knn_query(const raft::handle_t& handle, raft::device_matrix_view inds, raft::device_matrix_view dists, int_t k = 5, - bool perform_post_filtering = true, - float weight = 1.0) + bool perform_post_filtering = true) { RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); RAFT_EXPECTS(k <= index.m, "k must be less than or equal to the number of data points in the index"); - RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), + RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k), "Number of columns in output indices and distances matrices must be equal to k"); RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == 
index.get_X().extent(0), @@ -178,7 +177,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, "in index matrix."); rbc_all_knn_query( - handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); + handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, 1.0); } /** diff --git a/cpp/include/raft/spatial/knn/ball_cover_types.hpp b/cpp/include/raft/spatial/knn/ball_cover_types.hpp index 47c9397fdd..1dd45365b7 100644 --- a/cpp/include/raft/spatial/knn/ball_cover_types.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover_types.hpp @@ -48,7 +48,7 @@ class BallCoverIndex { value_int n_, raft::distance::DistanceType metric_) : handle(handle_), - X(std::move(raft::make_device_matrix_view(X_, m_, n_))), + X(raft::make_device_matrix_view(X_, m_, n_)), m(m_), n(n_), metric(metric_), diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index 88c08f77e6..288834214f 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -52,11 +52,11 @@ namespace raft::spatial::knn::ivf_flat { * @tparam T data element type * @tparam IdxT type of the indices in the source dataset * - * @param handle - * @param params configure the index building + * @param[in] handle + * @param[in] params configure the index building * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim] - * @param n_rows the number of samples - * @param dim the dimensionality of the data + * @param[in] n_rows the number of samples + * @param[in] dim the dimensionality of the data * * @return the constructed ivf-flat index */ @@ -94,8 +94,8 @@ inline auto build( * @tparam int_t precision / type of integral arguments * @tparam matrix_idx_t matrix indexing type * - * @param handle - * @param params configure the index building + * @param[in] handle + * @param[in] params configure the index building * @param[in] dataset a device pointer to a row-major matrix 
[n_rows, dim] * * @return the constructed ivf-flat index @@ -182,8 +182,8 @@ inline auto extend(const handle_t& handle, * @tparam int_t precision / type of integral arguments * @tparam matrix_idx_t matrix indexing type * - * @param handle - * @param orig_index original index + * @param[in] handle + * @param[in] orig_index original index * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` @@ -221,7 +221,7 @@ auto extend(const handle_t& handle, * @param[in] new_indices a device pointer to a vector of indices [n_rows]. * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` * here to imply a continuous range `[0...n_rows)`. - * @param n_rows the number of samples + * @param[in] n_rows the number of samples */ template inline void extend(const handle_t& handle, @@ -241,7 +241,7 @@ inline void extend(const handle_t& handle, * @tparam int_t precision / type of integral arguments * @tparam matrix_idx_t matrix indexing type * - * @param handle + * @param[in] handle * @param[inout] index * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()] * @param[in] new_indices a device pointer to a vector of indices [n_rows]. @@ -295,17 +295,17 @@ void extend( * @tparam T data element type * @tparam IdxT type of the indices * - * @param handle - * @param params configure the search - * @param index ivf-flat constructed index + * @param[in] handle + * @param[in] params configure the search + * @param[in] index ivf-flat constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param n_queries the batch size - * @param k the number of neighbors to find for each query. + * @param[in] n_queries the batch size + * @param[in] k the number of neighbors to find for each query. 
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param mr an optional memory resource to use across the searches (you can provide a large enough - * memory pool here to avoid memory allocations within search). + * @param[in] mr an optional memory resource to use across the searches (you can provide a large + * enough memory pool here to avoid memory allocations within search). */ template inline void search(const handle_t& handle, @@ -354,14 +354,14 @@ inline void search(const handle_t& handle, * @tparam int_t precision / type of integral arguments * @tparam matrix_idx_t matrix indexing type * - * @param handle - * @param index ivf-flat constructed index + * @param[in] handle + * @param[in] index ivf-flat constructed index * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param params configure the search - * @param k the number of neighbors to find for each query. + * @param[in] params configure the search + * @param[in] k the number of neighbors to find for each query. */ template { int64_t half_of_data = ps.num_db_vecs / 2; - auto half_of_data_view = raft::make_device_matrix_view( - (const DataT*)database.data(), static_cast(half_of_data), ps.dim); + auto half_of_data_view = raft::make_device_matrix_view( + (const DataT*)database.data(), half_of_data, ps.dim); auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); From 90bbb330738a526f40e53b36163e1854869ca2db Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Mon, 26 Sep 2022 18:20:14 -0400 Subject: [PATCH 30/58] Getting to nbuild --- cpp/include/raft/spatial/knn/ball_cover.cuh | 2 +- cpp/include/raft/spatial/knn/brute_force.cuh | 2 +- cpp/test/spatial/ball_cover.cu | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index c3e457c871..6628926c97 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -163,7 +163,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, BallCoverIndex& index, raft::device_matrix_view inds, raft::device_matrix_view dists, - int_t k = 5, + int_t k, bool perform_post_filtering = true) { RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); diff --git a/cpp/include/raft/spatial/knn/brute_force.cuh b/cpp/include/raft/spatial/knn/brute_force.cuh index 128c4bbd54..6c7b9dd893 100644 --- a/cpp/include/raft/spatial/knn/brute_force.cuh +++ b/cpp/include/raft/spatial/knn/brute_force.cuh @@ -93,9 +93,9 @@ inline void knn_merge_parts( * @param[out] indices matrix (size n*k) to store output knn indices * @param[out] distances matrix (size n*k) to store the output knn distance * @param[in] k the number of nearest neighbors to return + * @param[in] metric distance metric to use. Euclidean (L2) is used by default * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. - * @param[in] metric distance metric to use. Euclidean (L2) is used by default * @param[in] translations starting offsets for partitions. should be the same size * as input vector. 
*/ diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 15f6b5fa87..273c5a966b 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -216,7 +216,7 @@ class BallCoverKNNQueryTest : public ::testing::TestWithParam index(handle, X_view, metric); raft::spatial::knn::rbc_all_knn_query( - handle, index, d_pred_I_view, d_pred_D_view, k, true, weight); + handle, index, d_pred_I_view, d_pred_D_view, k, true); handle.sync_stream(); // What we really want are for the distances to match exactly. The From 8cb3ec522386a1da53c04934594fc8f5b2c957a5 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 26 Sep 2022 18:22:51 -0400 Subject: [PATCH 31/58] Fixing style --- cpp/test/spatial/ball_cover.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 273c5a966b..d9ad9cc358 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -304,8 +304,7 @@ class BallCoverAllKNNTest : public ::testing::TestWithParam index(handle, X_view, metric); - raft::spatial::knn::rbc_all_knn_query( - handle, index, d_pred_I_view, d_pred_D_view, k, true); + raft::spatial::knn::rbc_all_knn_query(handle, index, d_pred_I_view, d_pred_D_view, k, true); handle.sync_stream(); // What we really want are for the distances to match exactly. The From c4bd2d17dc464b3184c1bb6448b58597913fab86 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Mon, 26 Sep 2022 19:11:46 -0400 Subject: [PATCH 32/58] Removing files from raft;:Matrix which still need to be tested --- cpp/include/raft/matrix/diagonal.cuh | 53 -------- cpp/include/raft/matrix/matrix_vector.cuh | 142 ---------------------- cpp/include/raft/matrix/norm.cuh | 37 ------ cpp/include/raft/matrix/reverse.cuh | 85 ------------- cpp/include/raft/matrix/slice.cuh | 55 --------- cpp/include/raft/matrix/triangular.cuh | 39 ------ 6 files changed, 411 deletions(-) delete mode 100644 cpp/include/raft/matrix/diagonal.cuh delete mode 100644 cpp/include/raft/matrix/matrix_vector.cuh delete mode 100644 cpp/include/raft/matrix/norm.cuh delete mode 100644 cpp/include/raft/matrix/reverse.cuh delete mode 100644 cpp/include/raft/matrix/slice.cuh delete mode 100644 cpp/include/raft/matrix/triangular.cuh diff --git a/cpp/include/raft/matrix/diagonal.cuh b/cpp/include/raft/matrix/diagonal.cuh deleted file mode 100644 index f5ab33ebd7..0000000000 --- a/cpp/include/raft/matrix/diagonal.cuh +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include - -namespace raft::matrix { - -/** - * @brief Initialize a diagonal matrix with a vector - * @param vec: vector of length k = min(n_rows, n_cols) - * @param matrix: matrix of size n_rows x n_cols - */ -template -void initialize_diagonal(const raft::handle_t& handle, - raft::device_vector_view vec, - raft::device_matrix_view matrix) -{ - detail::initializeDiagonalMatrix(vec.data_handle(), - matrix.data_handle(), - matrix.extent(0), - matrix.extent(1), - handle.get_stream()); -} - -/** - * @brief Take reciprocal of elements on diagonal of square matrix (in-place) - * @param in: square input matrix with size len x len - */ -template -void invert_diagonal(const raft::handle_t& handle, - raft::device_matrix_view in) -{ - RAFT_EXPECTS(in.extent(0) == in.extent(1), "Matrix must be square."); - detail::getDiagonalInverseMatrix(in.data_handle(), in.extent(0), handle.get_stream()); -} -} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/matrix_vector.cuh b/cpp/include/raft/matrix/matrix_vector.cuh deleted file mode 100644 index 5d05d03d2c..0000000000 --- a/cpp/include/raft/matrix/matrix_vector.cuh +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "detail/matrix.cuh" -#include - -namespace raft::matrix { - -/** - * @brief multiply each row or column of matrix with vector, skipping zeros in vector - * @param data input matrix, results are in-place - * @param vec input vector - * @param n_row number of rows of input matrix - * @param n_col number of columns of input matrix - * @param rowMajor whether matrix is row major - * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns - * @param stream cuda stream - */ -template -void binary_mult_skip_zero(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream) -{ - detail::matrixVectorBinaryMultSkipZero( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); -} - -/** - * @brief divide each row or column of matrix with vector - * @param data input matrix, results are in-place - * @param vec input vector - * @param n_row number of rows of input matrix - * @param n_col number of columns of input matrix - * @param rowMajor whether matrix is row major - * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns - * @param stream cuda stream - */ -template -void binary_div(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream) -{ - detail::matrixVectorBinaryDiv( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); -} - -/** - * @brief divide each row or column of matrix with vector, skipping zeros in vector - * @param data input matrix, results are in-place - * @param vec input vector - * @param n_row number of rows of input matrix - * @param n_col number of columns of input matrix - * @param rowMajor whether matrix is row major - * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns - * @param stream cuda stream - * @param return_zero result is zero if true and vector value is below threshold, 
original value if - * false - */ -template -void binary_div_skip_zero(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream, - bool return_zero = false) -{ - detail::matrixVectorBinaryDivSkipZero( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream, return_zero); -} - -/** - * @brief add each row or column of matrix with vector - * @param data input matrix, results are in-place - * @param vec input vector - * @param n_row number of rows of input matrix - * @param n_col number of columns of input matrix - * @param rowMajor whether matrix is row major - * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns - * @param stream cuda stream - */ -template -void binary_add(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream) -{ - detail::matrixVectorBinaryAdd( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); -} - -/** - * @brief subtract each row or column of matrix with vector - * @param data input matrix, results are in-place - * @param vec input vector - * @param n_row number of rows of input matrix - * @param n_col number of columns of input matrix - * @param rowMajor whether matrix is row major - * @param bcastAlongRows whether to broadcast vector along rows of matrix or columns - * @param stream cuda stream - */ -template -void binary_sub(Type* data, - const Type* vec, - IdxType n_row, - IdxType n_col, - bool rowMajor, - bool bcastAlongRows, - cudaStream_t stream) -{ - detail::matrixVectorBinarySub( - data, vec, n_row, n_col, rowMajor, bcastAlongRows, stream); -} - -} // namespace raft::matrix \ No newline at end of file diff --git a/cpp/include/raft/matrix/norm.cuh b/cpp/include/raft/matrix/norm.cuh deleted file mode 100644 index 5c1e0b9c01..0000000000 --- a/cpp/include/raft/matrix/norm.cuh +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -namespace raft::matrix { - -/** - * @brief Get the L2/F-norm of a matrix - * @param handle - * @param in: input matrix/vector with totally size elements - * @param size: size of the matrix/vector - * @param stream: cuda stream - */ -template -m_t l2_norm(const raft::handle_t& handle, raft::device_mdspan in) -{ - return detail::getL2Norm(handle, in.data_handle(), in.size(), handle.get_stream()); -} -} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/reverse.cuh b/cpp/include/raft/matrix/reverse.cuh deleted file mode 100644 index 8a9837f467..0000000000 --- a/cpp/include/raft/matrix/reverse.cuh +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include - -namespace raft::matrix { - -/** - * @brief Columns of a column major matrix are reversed in place (i.e. first column and - * last column are swapped) - * @param inout: input and output matrix - * @param n_rows: number of rows of input matrix - * @param n_cols: number of columns of input matrix - * @param stream: cuda stream - */ -template -void col_reverse(const raft::handle_t& handle, - raft::device_matrix_view inout) -{ - detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), stream); -} - -/** - * @brief Columns of a column major matrix are reversed in place (i.e. first column and - * last column are swapped) - * @param inout: input and output matrix - * @param n_rows: number of rows of input matrix - * @param n_cols: number of columns of input matrix - * @param stream: cuda stream - */ -template -void col_reverse(const raft::handle_t& handle, - raft::device_matrix_view inout) -{ - detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), stream); -} - -/** - * @brief Rows of a column major matrix are reversed in place (i.e. first row and last - * row are swapped) - * @param inout: input and output matrix - * @param n_rows: number of rows of input matrix - * @param n_cols: number of columns of input matrix - * @param stream: cuda stream - */ -template -void row_reverse(const raft::handle_t& handle, - raft::device_matrix_view inout) -{ - detail::rowReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); -} - -/** - * @brief Rows of a column major matrix are reversed in place (i.e. 
first row and last - * row are swapped) - * @param inout: input and output matrix - * @param n_rows: number of rows of input matrix - * @param n_cols: number of columns of input matrix - * @param stream: cuda stream - */ -template -void row_reverse(const raft::handle_t& handle, - raft::device_matrix_view inout) -{ - detail::colReverse(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); -} - -} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh deleted file mode 100644 index acb2e35793..0000000000 --- a/cpp/include/raft/matrix/slice.cuh +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include - -namespace raft::matrix { - -/** - * @brief Slice a matrix (in-place) - * @param handle: raft handle - * @param in: input matrix (column-major) - * @param out: output matrix (column-major) - * @param x1, y1: coordinate of the top-left point of the wanted area (0-based) - * @param x2, y2: coordinate of the bottom-right point of the wanted area - * (1-based) - * example: Slice the 2nd and 3rd columns of a 4x3 matrix: slice_matrix(M_d, 4, - * 3, 0, 1, 4, 3); - */ -template -void slice(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - idx_t x1, - idx_t y1, - idx_t x2, - idx_t y2) -{ - detail::sliceMatrix(in.data_handle(), - in.extent(0), - in.extent(1), - out.data_handle(), - x1, - y1, - x2, - y2, - handle.get_stream()); -} -} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh deleted file mode 100644 index 6f520841eb..0000000000 --- a/cpp/include/raft/matrix/triangular.cuh +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include - -namespace raft::matrix { - -/** - * @brief Copy the upper triangular part of a matrix to another - * @param[in] handle: raft handle - * @param[in] src: input matrix with a size of n_rows x n_cols - * @param[out] dst: output matrix with a size of kxk, k = min(n_rows, n_cols) - */ -template -void upper_triangular(const raft::handle_t& handle, - raft::device_matrix_view src, - raft::device_matrix_view dst) -{ - detail::copyUpperTriangular( - src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream()); -} -} // namespace raft::matrix From b166814d1cc7ecbf6e1efb9ab43f80b0344a17fb Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 26 Sep 2022 20:22:34 -0400 Subject: [PATCH 33/58] Progress on matrix API --- cpp/include/raft/matrix/col_wise_sort.cuh | 3 +- cpp/include/raft/matrix/copy.cuh | 6 +-- cpp/include/raft/matrix/gather.cuh | 1 + cpp/include/raft/matrix/linewise_op.cuh | 24 ++++++---- .../raft/matrix/{seq_root.cuh => sqrt.cuh} | 10 ++-- cpp/test/matrix/linewise_op.cu | 36 ++++++++++++--- cpp/test/matrix/math.cu | 40 ++++++++++++---- cpp/test/matrix/matrix.cu | 46 +++++++++++-------- 8 files changed, 112 insertions(+), 54 deletions(-) rename cpp/include/raft/matrix/{seq_root.cuh => sqrt.cuh} (92%) diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 74f78796e8..f2a6c69c71 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -65,8 +65,7 @@ void sort_cols_per_row( const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, - std::optional> sorted_keys = - std::nullptr) + std::optional> sorted_keys = std::nullopt) { RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0), "Input and output matrices must have the same shape."); diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 
4a00a0f732..455cbcacbd 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -33,7 +33,7 @@ namespace raft::matrix { * @param[out] out output matrix * @param[in] indices of the rows to be copied */ -template +template void copy_rows(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, @@ -64,7 +64,7 @@ void copy_rows(const raft::handle_t& handle, * @param[in] in: input matrix * @param[out] out: output matrix */ -template +template void copy(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out) @@ -86,7 +86,7 @@ void copy(const raft::handle_t& handle, * @param out_n_cols: number of columns of output matrix * @param stream: cuda stream */ -template +template void trunc_zero_origin( m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) { diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index b5e8b4af7c..80de561cd0 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index 2321548b35..02086d52d2 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -33,9 +33,6 @@ namespace raft::matrix { * @param [out] out result of the operation; can be same as `in`; should be aligned the same * as `in` to allow faster vectorized memory transfers. * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. - * @param [in] lineLen length of matrix line in elements (`=nCols` in row-major or `=nRows` in - * col-major) - * @param [in] nLines number of matrix lines (`=nRows` in row-major or `=nCols` in col-major) * @param [in] alongLines whether vectors are indices along or across lines. 
* @param [in] op the operation applied on each line: * for i in [0..lineLen) and j in [0..nLines): @@ -46,17 +43,26 @@ namespace raft::matrix { * @param [in] vecs zero or more vectors to be passed as arguments, * size of each vector is `alongLines ? lineLen : nLines`. */ -template +template void linewise_op(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - const idx_t lineLen, - const idx_t nLines, + raft::device_matrix_view in, + raft::device_matrix_view out, const bool alongLines, Lambda op, raft::device_vector_view... vecs) { + constexpr auto is_rowmajor = std::is_same_v; + constexpr auto is_colmajor = std::is_same_v; + + static_assert(is_rowmajor || is_colmajor, "layout for in and out must be either row or col major"); + + const idx_t lineLen = is_rowmajor ? in.extent(1) : in.extent(0); + const idx_t nLines = is_rowmajor ? in.extent(0) : in.extent(1); + + RAFT_EXPECTS(out.extent(0) == in.extent(0) && + out.extent(1) == in.extent(1), "Input and output must have the same shape."); + detail::MatrixLinewiseOp<16, 256>::run( - out, in, lineLen, nLines, alongLines, op, stream, vecs...); + out.data_handle(), in.data_handle(), lineLen, nLines, alongLines, op, stream, vecs...); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/seq_root.cuh b/cpp/include/raft/matrix/sqrt.cuh similarity index 92% rename from cpp/include/raft/matrix/seq_root.cuh rename to cpp/include/raft/matrix/sqrt.cuh index 0f74799154..8d5edf679b 100644 --- a/cpp/include/raft/matrix/seq_root.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -30,7 +30,7 @@ namespace raft::matrix { * @param[out] out: output matrix. 
The result is stored in the out matrix */ template -void seq_root(const raft::handle_t& handle, +void sqrt(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out) { @@ -44,8 +44,8 @@ void seq_root(const raft::handle_t& handle, * @param[in] handle: raft handle * @param[inout] inout: input matrix with in-place results */ -template -void seq_root(const raft::handle_t& handle, raft::device_matrix_view inout) +template +void sqrt(const raft::handle_t& handle, raft::device_matrix_view inout) { detail::seqRoot(inout.data_handle(), inout.size(), handle.get_stream()); } @@ -60,7 +60,7 @@ void seq_root(const raft::handle_t& handle, raft::device_matrix_view ino * @param[in] set_neg_zero whether to set negative numbers to zero */ template -void weighted_seq_root(const raft::handle_t& handle, +void weighted_sqrt(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, math_t scalar, @@ -80,7 +80,7 @@ void weighted_seq_root(const raft::handle_t& handle, * @param set_neg_zero whether to set negative numbers to zero */ template -void weighted_seq_root(const raft::handle_t& handle, +void weighted_sqrt(const raft::handle_t& handle, raft::device_matrix_view inout, math_t scalar, bool set_neg_zero = false) diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 16e2ceb29a..27a4d9f05d 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -54,23 +54,37 @@ struct LinewiseTest : public ::testing::TestWithParam void runLinewiseSum( - T* out, const T* in, const I lineLen, const I nLines, const bool alongLines, const T* vec) + T* out, const T* in, const I lineLen, const I nLines, const T* vec) { auto f = [] __device__(T a, T b) -> T { return a + b; }; - matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec); + + constexpr auto layout = alongLines ? 
row_major : col_major; + + auto in_view = raft::make_device_matrix_view(in, nLines, lineLen) + auto out_view = raft::make_device_matrix_view(out, nLines, lineLen); + + auto vec_view = raft::make_device_vector_view(vec, lineLen); + matrix::line_wise_op(handle, in_view, out_view, alongLines, f, vec); } + template void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, - const bool alongLines, const T* vec1, const T* vec2) { auto f = [] __device__(T a, T b, T c) -> T { return a + b + c; }; - matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, vec2); + + constexpr auto layout = alongLines ? row_major : col_major; + + auto in_view = raft::make_device_matrix_view(in, nLines, lineLen) + auto out_view = raft::make_device_matrix_view(out, nLines, lineLen); + + matrix::line_wise_op(handle, in_view, out_view, alongLines, f, vec1, vec2); } rmm::device_uvector genData(size_t workSizeBytes) @@ -149,7 +163,11 @@ struct LinewiseTest : public ::testing::TestWithParam(out, in, lineLen, nLines, vec1); + } else { + runLinewiseSum(out, in, lineLen, nLines, vec1); + } } if (params.checkCorrectness) { linalg::naiveMatVec( @@ -161,7 +179,13 @@ struct LinewiseTest : public ::testing::TestWithParam(out, in, lineLen, nLines, vec1, vec2); + + } else { + runLinewiseSum(out, in, lineLen, nLines, vec1, vec2); + + } } if (params.checkCorrectness) { linalg::naiveMatVec( diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index d550852150..fb49fc804a 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -16,7 +16,15 @@ #include "../test_utils.h" #include -#include + +#include +#include +#include +#include +#include +#include +#include + #include #include @@ -147,16 +155,23 @@ class MathTest : public ::testing::TestWithParam> { uniform(handle, r, in_sign_flip.data(), len, T(-100.0), T(100.0)); naivePower(in_power.data(), out_power_ref.data(), len, stream); - power(in_power.data(), len, stream); + + auto in_power_view = 
raft::make_device_matrix_view(in_power.data(), len, 1); + power(handle, in_power_view); naiveSqrt(in_sqrt.data(), out_sqrt_ref.data(), len, stream); - seqRoot(in_sqrt.data(), len, stream); - ratio(handle, in_ratio.data(), in_ratio.data(), 4, stream); + auto in_sqrt_view = raft::make_device_matrix_view(in_sqrt.data(), len, 1); + sqrt(handle, in_sqrt_view); + + auto in_ratio_view = raft::make_device_matrix_view(in_ratio.data(), 4, 1); + ratio(handle, in_ratio_view); naiveSignFlip( in_sign_flip.data(), out_sign_flip_ref.data(), params.n_row, params.n_col, stream); - signFlip(in_sign_flip.data(), params.n_row, params.n_col, stream); + + auto in_sign_flip_view = raft::make_device_matrix_view(in_sign_flip.data(), params.n_row, params.n_col); + sign_flip(handle, in_sign_flip_view); // default threshold is 1e-15 std::vector in_recip_h = {0.1, 0.01, -0.01, 0.1e-16}; @@ -165,18 +180,23 @@ class MathTest : public ::testing::TestWithParam> { update_device(in_recip_ref.data(), in_recip_ref_h.data(), 4, stream); T recip_scalar = T(1.0); - // this `reciprocal()` has to go first bc next one modifies its input - reciprocal(in_recip.data(), out_recip.data(), recip_scalar, 4, stream); + auto in_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); + auto out_recip_view = raft::make_device_matrix_view(out_recip.data(), 4, 1); - reciprocal(in_recip.data(), recip_scalar, 4, stream, true); + // this `reciprocal()` has to go first bc next one modifies its input + reciprocal(handle, in_recip_view, out_recip_view, recip_scalar); + reciprocal(in_recip_view, recip_scalar, 4, stream, true); std::vector in_small_val_zero_h = {0.1, 1e-16, -1e-16, -0.1}; std::vector in_small_val_zero_ref_h = {0.1, 0.0, 0.0, -0.1}; + auto in_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); + auto out_smallzero_view = raft::make_device_matrix_view(out_smallzero.data(), 4, 1); + update_device(in_smallzero.data(), in_small_val_zero_h.data(), 4, stream); 
update_device(out_smallzero_ref.data(), in_small_val_zero_ref_h.data(), 4, stream); - setSmallValuesZero(out_smallzero.data(), in_smallzero.data(), 4, stream); - setSmallValuesZero(in_smallzero.data(), 4, stream); + zero_small_values(handle, in_smallzero_view, out_smallzero_view); + zero_small_values(handle, in_smallzero_view); handle.sync_stream(stream); } diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 6ccd7aa335..9c10d78e75 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -16,7 +16,9 @@ #include "../test_utils.h" #include -#include +#include + +#include #include #include #include @@ -61,12 +63,17 @@ class MatrixTest : public ::testing::TestWithParam> { int len = params.n_row * params.n_col; uniform(handle, r, in1.data(), len, T(-1.0), T(1.0)); - copy(in1.data(), in2.data(), params.n_row, params.n_col, stream); + auto in1_view = raft::make_device_matrix_view(in1.data(), params.n_row, params.n_col); + auto in2_view = raft::Make_device_matrix_view(in2.data(), params.n_row, params.n_col); + + copy(handle, in1_view, in2_view); // copy(in1, in1_revr, params.n_row, params.n_col); // colReverse(in1_revr, params.n_row, params.n_col); rmm::device_uvector outTrunc(6, stream); - truncZeroOrigin(in1.data(), params.n_row, outTrunc.data(), 3, 2, stream); + + auto out_trunc_view = raft::make_device_matrix_view(outTrunc.data(), 3, 2); + trunc_zero_origin(handle, in1_view, out_trunc_view); handle.sync_stream(stream); } @@ -128,24 +135,25 @@ class MatrixCopyRowsTest : public ::testing::Test { void testCopyRows() { - copyRows(input.data(), - n_rows, - n_cols, - output.data(), - indices.data(), - n_selected, - handle.get_stream(), - false); + auto input_view = raft::device_matrix_view(input.data(), n_rows, n_cols); + auto output_view = raft::device_matrix_view(output.data(), n_rows, n_cols); + auto indices_view = raft::device_vector_view(indices.data(), n_selected); + + copy_rows(handle, + input_view, + output_view, + indices_view); 
+ EXPECT_TRUE(raft::devArrMatchHost( output_exp_colmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); - copyRows(input.data(), - n_rows, - n_cols, - output.data(), - indices.data(), - n_selected, - handle.get_stream(), - true); + + auto input_row_view = raft::device_matrix_view(input.data(), n_rows, n_cols); + auto output_row_view = raft::device_matrix_view(output.data(), n_rows, n_cols); + + copy_rows(handle, + input_row_view, + output_row_view, + indices_view); EXPECT_TRUE(raft::devArrMatchHost( output_exp_rowmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); } From 4c552bc0770b90c4d604f8123780e2e9ea5be066 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 27 Sep 2022 10:12:27 -0400 Subject: [PATCH 34/58] Adding weight back into rbc --- cpp/include/raft/spatial/knn/ball_cover.cuh | 26 +++++++++++---------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index 6628926c97..5bc839b67f 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -155,16 +155,17 @@ void rbc_all_knn_query(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. 
*/ -template + typename int_t, + typename matrix_idx_t> void rbc_all_knn_query(const raft::handle_t& handle, BallCoverIndex& index, raft::device_matrix_view inds, raft::device_matrix_view dists, int_t k, - bool perform_post_filtering = true) + bool perform_post_filtering = true, + float weight = 1.0) { RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); RAFT_EXPECTS(k <= index.m, @@ -177,7 +178,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, "in index matrix."); rbc_all_knn_query( - handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, 1.0); + handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight); } /** @@ -209,7 +210,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, * looking in the closest landmark. * @param[in] n_query_pts number of query points */ -template +template void rbc_knn_query(const raft::handle_t& handle, BallCoverIndex& index, int_t k, @@ -278,17 +279,18 @@ void rbc_knn_query(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. */ -template + typename int_t, + typename matrix_idx_t> void rbc_knn_query(const raft::handle_t& handle, BallCoverIndex& index, raft::device_matrix_view query, raft::device_matrix_view inds, raft::device_matrix_view dists, - int_t k = 5, - bool perform_post_filtering = true) + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) { RAFT_EXPECTS(k <= index.m, "k must be less than or equal to the number of data points in the index"); @@ -310,7 +312,7 @@ void rbc_knn_query(const raft::handle_t& handle, inds.data_handle(), dists.data_handle(), perform_post_filtering, - 1.0); + weight); } // TODO: implement functions for: From 53a254eb2a66d9168af1511a8a67cb370dc347f3 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Tue, 27 Sep 2022 10:17:01 -0400 Subject: [PATCH 35/58] Style check --- cpp/include/raft/spatial/knn/ball_cover.cuh | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index 5bc839b67f..704acafd45 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -155,17 +155,14 @@ void rbc_all_knn_query(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. */ -template +template void rbc_all_knn_query(const raft::handle_t& handle, BallCoverIndex& index, raft::device_matrix_view inds, raft::device_matrix_view dists, int_t k, bool perform_post_filtering = true, - float weight = 1.0) + float weight = 1.0) { RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); RAFT_EXPECTS(k <= index.m, @@ -279,10 +276,7 @@ void rbc_knn_query(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. */ -template +template void rbc_knn_query(const raft::handle_t& handle, BallCoverIndex& index, raft::device_matrix_view query, @@ -290,7 +284,7 @@ void rbc_knn_query(const raft::handle_t& handle, raft::device_matrix_view dists, int_t k, bool perform_post_filtering = true, - float weight = 1.0) + float weight = 1.0) { RAFT_EXPECTS(k <= index.m, "k must be less than or equal to the number of data points in the index"); From 278ce4dc88fb0b7553cf016394439fbb731e7283 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Tue, 27 Sep 2022 14:59:31 -0400 Subject: [PATCH 36/58] More updates --- cpp/include/raft/matrix/argmax.cuh | 11 ++- cpp/include/raft/matrix/col_wise_sort.cuh | 17 +++-- cpp/include/raft/matrix/copy.cuh | 15 ++-- cpp/include/raft/matrix/detail/math.cuh | 18 ++--- cpp/include/raft/matrix/detail/matrix.cuh | 2 +- cpp/include/raft/matrix/math.cuh | 4 +- cpp/include/raft/matrix/power.cuh | 7 +- cpp/include/raft/matrix/ratio.cuh | 21 +++++- cpp/include/raft/matrix/reciprocal.cuh | 14 ++-- cpp/include/raft/matrix/threshold.cuh | 1 - cpp/test/matrix/argmax.cu | 89 +++++++++++------------ cpp/test/matrix/columnSort.cu | 15 ++-- cpp/test/matrix/gather.cu | 6 +- cpp/test/matrix/linewise_op.cu | 20 ++--- cpp/test/matrix/math.cu | 26 ++++--- cpp/test/matrix/matrix.cu | 31 ++++---- 16 files changed, 151 insertions(+), 146 deletions(-) diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index 5afd026745..013e496ea1 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -17,23 +17,22 @@ #pragma once #include -#include -#include +#include namespace raft::matrix { /** * @brief Argmax: find the row idx with maximum value for each column - * @param handle: raft handle - * @param in: input matrix of size (n_rows, n_cols) - * @param out: output vector of size n_cols + * @param[in] handle: raft handle + * @param[in] in: input matrix of size (n_rows, n_cols) + * @param[out] out: output vector of size n_cols */ template void argmax(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_vector_view out) { - RAFT_EXPECTS(out.extent(1) == in.extent(1), + RAFT_EXPECTS(static_cast(out.extent(1)) == in.extent(1), "Size of output vector must equal number of columns in input matrix."); detail::argmax( in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream()); diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 
f2a6c69c71..bb5251c346 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -64,7 +64,7 @@ template void sort_cols_per_row( const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_matrix_view out, + raft::device_matrix_view out, std::optional> sorted_keys = std::nullopt) { RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0), @@ -79,16 +79,19 @@ void sort_cols_per_row( size_t workspace_size = 0; bool alloc_workspace = false; + + InType *keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr; + detail::sortColumnsPerRow( in.data_handle(), out.data_handle(), in.extent(0), in.extent(1), alloc_workspace, - nullptr, - &workspace_size, + (void*)nullptr, + workspace_size, handle.get_stream(), - sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr); + keys); if (alloc_workspace) { auto workspace = raft::make_device_vector(handle, workspace_size); @@ -99,10 +102,10 @@ void sort_cols_per_row( in.extent(0), in.extent(1), alloc_workspace, - workspace.data_handle(), - &workspace_size, + (void*)workspace.data_handle(), + workspace_size, handle.get_stream(), - sorted_keys.has_value() ? 
sorted_keys.value().data_handle() : nullptr); + keys); } } diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 455cbcacbd..842548452a 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -18,7 +18,6 @@ #include #include -#include namespace raft::matrix { @@ -47,7 +46,7 @@ void copy_rows(const raft::handle_t& handle, bool out_rowmajor = raft::is_row_major(out); RAFT_EXPECTS(in_rowmajor == out_rowmajor, - "Input and output matrices must have same layout (row- or column-major)") + "Input and output matrices must have same layout (row- or column-major)"); detail::copyRows(in.data_handle(), in.extent(0), @@ -88,9 +87,15 @@ void copy(const raft::handle_t& handle, */ template void trunc_zero_origin( - m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) -{ - detail::truncZeroOrigin(in, in_n_rows, out, out_n_rows, out_n_cols, stream); + const raft::handle_t &handle, + raft::device_matrix_view in, + raft::device_matrix_view out) { + + RAFT_EXPECTS(out.extent(0) <= in.extent(0) && + out.extent(1) <= in.extent(1), + "Output matrix must have less or equal number of rows and columns"); + + detail::truncZeroOrigin(in.data_handle(), in.extent(0), out.data_handle(), out.extent(0), out.extent(1), handle.get_stream()); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 95953feca4..8af0a31504 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -141,7 +141,7 @@ void setSmallValuesZero(math_t* inout, IdxType len, cudaStream_t stream, math_t } template -void reciprocal(math_t* in, +void reciprocal(const math_t* in, math_t* out, math_t scalar, int len, @@ -363,8 +363,8 @@ void matrixVectorBinarySub(Type* data, } // Computes the argmax(d_in) column-wise in a DxN matrix -template -__global__ void argmaxKernel(const T* d_in, int D, int N, T* argmax) 
+template +__global__ void argmaxKernel(const T* d_in, int D, int N, IdxT* argmax) { typedef cub::BlockReduce, TPB> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -384,19 +384,19 @@ __global__ void argmaxKernel(const T* d_in, int D, int N, T* argmax) if (threadIdx.x == 0) { argmax[blockIdx.x] = maxKV.key; } } -template -void argmax(const math_t* in, int n_rows, int n_cols, math_t* out, cudaStream_t stream) +template +void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream) { int D = n_rows; int N = n_cols; if (D <= 32) { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } else if (D <= 64) { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } else if (D <= 128) { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } else { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index 1b343cf5b4..c425aad79b 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -67,7 +67,7 @@ void copyRows(const m_t* in, template void truncZeroOrigin( - m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) + const m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) { auto m = out_n_rows; auto k = in_n_rows; diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh index 25ad185935..3c2705cf87 100644 --- a/cpp/include/raft/matrix/math.cuh +++ b/cpp/include/raft/matrix/math.cuh @@ -310,8 +310,8 @@ void ratio( * @param out: output vector of size n_cols * @param stream: cuda stream */ -template -void argmax(const math_t* in, int n_rows, int n_cols, math_t* out, cudaStream_t stream) +template +void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t 
stream) { detail::argmax(in, n_rows, n_cols, out, stream); } diff --git a/cpp/include/raft/matrix/power.cuh b/cpp/include/raft/matrix/power.cuh index 98d691afd4..60a3231bf0 100644 --- a/cpp/include/raft/matrix/power.cuh +++ b/cpp/include/raft/matrix/power.cuh @@ -17,8 +17,7 @@ #pragma once #include -#include -#include +#include namespace raft::matrix { @@ -59,7 +58,7 @@ void weighted_power(const raft::handle_t& handle, template void power(const raft::handle_t& handle, raft::device_matrix_view inout) { - detail::power(inout.data_handle(), inout.size(), handle.get_stream()); + detail::power(inout.data_handle(), inout.size(), handle.get_stream()); } /** @@ -75,7 +74,7 @@ void power(const raft::handle_t& handle, raft::device_matrix_view out) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size."); - detail::power(in, out, len, stream); + detail::power(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh index 52c2180f95..ae2181757f 100644 --- a/cpp/include/raft/matrix/ratio.cuh +++ b/cpp/include/raft/matrix/ratio.cuh @@ -18,7 +18,6 @@ #include #include -#include namespace raft::matrix { @@ -32,12 +31,26 @@ namespace raft::matrix { * @param len: number elements of input matrix * @param stream cuda stream */ -template +template void ratio(const raft::handle_t& handle, raft::device_matrix_view src, raft::device_matrix_view dest) { - RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); - detail::ratio(handle, src.data_handle(), dest.data_handle(), in.size(), handle.get_stream()); + RAFT_EXPECTS(src.size() == dst.size(), "Input and output matrices must be the same size."); + detail::ratio(handle, src.data_handle(), dest.data_handle(), src.size(), handle.get_stream()); } + +/** + * @brief ratio of every element over sum of input vector is calculated + * @tparam math_t data-type 
upon which the math operation will be performed + * @tparam IdxType Integer type used to for addressing + * @param[in] handle + * @param[inout] inout: input matrix + */ + template + void ratio(const raft::handle_t& handle, + raft::device_matrix_view inout) + { + detail::ratio(handle, inout.data_handle(), inout.data_handle(), inout.size(), handle.get_stream()); + } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh index f74d98969c..e4867aabd3 100644 --- a/cpp/include/raft/matrix/reciprocal.cuh +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -17,15 +17,13 @@ #pragma once #include -#include -#include +#include namespace raft::matrix { /** * @brief Reciprocal of every element in the input matrix * @tparam math_t data-type upon which the math operation will be performed - * @tparam IdxType Integer type used to for addressing * @param handle: raft handle * @param in: input matrix and also the result is stored * @param out: output matrix. 
The result is stored in the out matrix @@ -35,25 +33,23 @@ namespace raft::matrix { * @{ */ template -void reciprocal(raft::device_matrix_view in, +void reciprocal(const raft::handle_t &handle, + raft::device_matrix_view in, raft::device_matrix_view out, math_t scalar, bool setzero = false, math_t thres = 1e-15) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have the same size."); - detail::reciprocal( + detail::reciprocal( in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), setzero, thres); } /** * @brief Reciprocal of every element in the input matrix (in place) * @tparam math_t data-type upon which the math operation will be performed - * @tparam IdxType Integer type used to for addressing * @param inout: input matrix with in-place results * @param scalar: every element is multiplied with scalar - * @param len: number elements of input matrix - * @param stream cuda stream * @param setzero round down to zero if the input is less the threshold * @param thres the threshold used to forcibly set inputs to zero * @{ @@ -65,7 +61,7 @@ void reciprocal(const raft::handle_t& handle, bool setzero = false, math_t thres = 1e-15) { - detail::reciprocal( + detail::reciprocal( inout.data_handle(), scalar, inout.size(), handle.get_stream(), setzero, thres); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/threshold.cuh b/cpp/include/raft/matrix/threshold.cuh index bdcab41421..d13c55ea80 100644 --- a/cpp/include/raft/matrix/threshold.cuh +++ b/cpp/include/raft/matrix/threshold.cuh @@ -18,7 +18,6 @@ #include #include -#include namespace raft::matrix { diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu index 50ca4e3e92..b3abf93377 100644 --- a/cpp/test/matrix/argmax.cu +++ b/cpp/test/matrix/argmax.cu @@ -15,6 +15,7 @@ */ #include "../test_utils.h" +#include #include #include #include @@ -26,14 +27,14 @@ namespace raft { template struct ArgMaxInputs { - std::vector input_matrix; - std::vector 
output_matrix; - int n_cols; - int n_rows; + const std::vector input_matrix; + const std::vector output_matrix; + std::size_t n_cols; + std::size_t n_rows; }; - template - ::std::ostream& operator<<(::std::ostream& os, const ArgMaxInputs& dims) + template + ::std::ostream& operator<<(::std::ostream& os, const ArgMaxInputs& dims) { return os; } @@ -42,63 +43,57 @@ namespace raft { class ArgMaxTest : public ::testing::TestWithParam> { public: ArgMaxTest() - : params(::testing::TestWithParam>::GetParam()), - input(std::move(raft::make_device_matrix(handle, params.n_rows, params.n_cols))), - output(std::move(raft::make_device_vector(handle, params.n_rows))), - expected(std::move(raft::make_device_vector(handle, params.n_rows))) { + : params(::testing::TestWithParam>::GetParam()) + {} + + void test() { + + auto input = raft::make_device_matrix(handle, params.n_rows, params.n_cols); + auto output = raft::make_device_vector(handle, params.n_rows); + auto expected = raft::make_device_vector(handle, params.n_rows); + + raft::copy(input.data_handle(), params.input_matrix.data(), params.n_rows * params.n_cols, handle.get_stream()); + raft::copy(expected.data_handle(), params.output_matrix.data(), params.n_rows * params.n_cols, handle.get_stream()); + + auto input_view = raft::make_device_matrix_view(input.data_handle(), params.n_rows, params.n_cols); - raft::copy(input.data_handle(), params.input_matrix.data(), params.n_rows * params.n_cols); - raft::copy(expected.data_handle(), params.output_matrix.data(), params.n_rows * params.n_cols); + raft::matrix::argmax(handle, input_view, output.view()); - raft::matrix::argmax(handle, input, output); - } + ASSERT_TRUE(devArrMatch(output.data_handle(), + expected.data_handle(), + params.n_rows, + Compare(), + handle.get_stream())); + + } protected: raft::handle_t handle; - ArgMaxInputs params; - - raft::device_matrix input; - raft::device_vector output; - raft::device_vector expected; + ArgMaxInputs params; }; const std::vector> 
inputsf = { - {0.1f, 0.2f, 0.3f, 0.4f}, - {0.4f, 0.3f, 0.2f, 0.1f}, - {0.2f, 0.3f, 0.5f, 0.0f}, + { + {0.1f, 0.2f, 0.3f, 0.4f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.5f, 0.0f}, {3, 0, 2}, - 3, 4}; + 3, 4} + }; const std::vector> inputsd = { - {0.1, 0.2, 0.3, 0.4}, - {0.4, 0.3, 0.2, 0.1}, - {0.2, 0.3, 0.5, 0.0}, + { + {0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1, 0.2, 0.3, 0.5, 0.0}, {3, 0, 2}, - 3, 4}; + 3, 4} + }; typedef ArgMaxTest ArgMaxTestF; - TEST_P(ArgMaxTestF, Result) - { - ASSERT_TRUE(devArrMatch(output.data_handle(), - expected.data_handle(), - params.n_rows, - Compare(), - handle.get_stream())); - } + TEST_P(ArgMaxTestF, Result) { test(); } typedef ArgMaxTest ArgMaxTestD; - TEST_P(ArgMaxTestD, Result) -{ - ASSERT_TRUE(devArrMatch(output.data_handle(), - expected.data_handle(), - params.n_rows, - Compare(), - handle.get_stream())); -} - -INSTANTIATE_TEST_SUITE_P(ArgMaxTest, ArgMaxTestTestF, ::testing::ValuesIn(inputsf)); - -INSTANTIATE_TEST_SUITE_P(ArgMaxTest, ArgMaxTestTestD, ::testing::ValuesIn(inputsd)); + TEST_P(ArgMaxTestD, Result) { test(); } + +INSTANTIATE_TEST_CASE_P(ArgMaxTest, ArgMaxTestF, ::testing::ValuesIn(inputsf)); +INSTANTIATE_TEST_CASE_P(ArgMaxTest, ArgMaxTestD, ::testing::ValuesIn(inputsd)); } // namespace matrix } // namespace raft diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index f795eaa6f0..270a9bb52f 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -47,7 +47,7 @@ struct columnSort { }; template -::std::ohandle.get_stream() & operator<<(::std::ohandle.get_stream() & os, +::std::ostream & operator<<(::std::ostream & os, const columnSort& dims) { return os; @@ -61,8 +61,7 @@ class ColumnSort : public ::testing::TestWithParam> { keySorted(0, handle.get_stream()), keySortGolden(0, handle.get_stream()), valueOut(0, handle.get_stream()), - goldenValOut(0, handle.get_stream()), - workspacePtr(0, handle.get_stream()) + goldenValOut(0, handle.get_stream()) { } @@ -70,7 +69,6 @@ class 
ColumnSort : public ::testing::TestWithParam> { { params = ::testing::TestWithParam>::GetParam(); int len = params.n_row * params.n_col; - RAFT_CUDA_TRY(cudahandle.get_stream() Create(&handle.get_stream())); keyIn.resize(len, handle.get_stream()); valueOut.resize(len, handle.get_stream()); goldenValOut.resize(len, handle.get_stream()); @@ -105,13 +103,13 @@ class ColumnSort : public ::testing::TestWithParam> { if (params.testKeys) raft::update_device(keySortGolden.data(), &cKeyGolden[0], len, handle.get_stream()); - auto key_in_view = raft::make_device_matrix_view(keyIn.data(), params.n_row, params.n_col); + auto key_in_view = raft::make_device_matrix_view(keyIn.data(), params.n_row, params.n_col); auto value_out_view = - raft::make_device_matrix_view(valueOut.data(), params.n_row, params.n_col); + raft::make_device_matrix_view(valueOut.data(), params.n_row, params.n_col); auto key_sorted_view = - raft::make_device_matrix_view(keySorted.data(), params.n_row, params.n_col); + raft::make_device_matrix_view(keySorted.data(), params.n_row, params.n_col); - raft::matrix::sort_cols_per_row(handle, key_in_view, value_out_view, key_sorted_view); + raft::matrix::sort_cols_per_row(handle, key_in_view, value_out_view, std::make_optional(key_sorted_view)); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } @@ -120,7 +118,6 @@ class ColumnSort : public ::testing::TestWithParam> { columnSort params; rmm::device_uvector keyIn, keySorted, keySortGolden; rmm::device_uvector valueOut, goldenValOut; // valueOut are indexes - rmm::device_uvector workspacePtr; raft::handle_t handle; }; diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 92c68f58f5..8b95b28542 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -58,9 +58,9 @@ void gatherLaunch(const raft::handle_t& handle, { typedef typename std::iterator_traits::value_type MapValueT; - auto in_view = raft::make_device_matrix_view(in, N, D); - auto map_view = 
raft::make_device_vector_view(map, map_length); - auto out_view = raft::make_device_matrix_view(out, N, D); + auto in_view = raft::make_device_matrix_view(in, N, D); + auto map_view = raft::make_device_vector_view(map, map_length); + auto out_view = raft::make_device_matrix_view(out, N, D); matrix::gather(handle, in_view, out_view, map_view); } diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 27a4d9f05d..430d842d8e 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -18,7 +18,8 @@ #include "../test_utils.h" #include #include -#include +#include +#include #include #include #include @@ -54,22 +55,21 @@ struct LinewiseTest : public ::testing::TestWithParam + template void runLinewiseSum( T* out, const T* in, const I lineLen, const I nLines, const T* vec) { auto f = [] __device__(T a, T b) -> T { return a + b; }; - constexpr auto layout = alongLines ? row_major : col_major; - auto in_view = raft::make_device_matrix_view(in, nLines, lineLen) + auto in_view = raft::make_device_matrix_view(in, nLines, lineLen); auto out_view = raft::make_device_matrix_view(out, nLines, lineLen); - auto vec_view = raft::make_device_vector_view(vec, lineLen); + auto vec_view = raft::make_device_vector_view(vec, lineLen); matrix::line_wise_op(handle, in_view, out_view, alongLines, f, vec); } - template + template void runLinewiseSum(T* out, const T* in, const I lineLen, @@ -164,9 +164,9 @@ struct LinewiseTest : public ::testing::TestWithParam(out, in, lineLen, nLines, vec1); + runLinewiseSum(out, in, lineLen, nLines, vec1); } else { - runLinewiseSum(out, in, lineLen, nLines, vec1); + runLinewiseSum(out, in, lineLen, nLines, vec1); } } if (params.checkCorrectness) { @@ -180,10 +180,10 @@ struct LinewiseTest : public ::testing::TestWithParam(out, in, lineLen, nLines, vec1, vec2); + runLinewiseSum(out, in, lineLen, nLines, vec1, vec2); } else { - runLinewiseSum(out, in, lineLen, nLines, vec1, vec2); + runLinewiseSum(out, in, 
lineLen, nLines, vec1, vec2); } } diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index fb49fc804a..a5e5ffc8dc 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -51,7 +51,7 @@ template __global__ void nativeSqrtKernel(Type* in, Type* out, int len) { int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < len) { out[idx] = sqrt(in[idx]); } + if (idx < len) { out[idx] = std::sqrt(in[idx]); } } template @@ -157,21 +157,21 @@ class MathTest : public ::testing::TestWithParam> { naivePower(in_power.data(), out_power_ref.data(), len, stream); auto in_power_view = raft::make_device_matrix_view(in_power.data(), len, 1); - power(handle, in_power_view); + power(handle, in_power_view); naiveSqrt(in_sqrt.data(), out_sqrt_ref.data(), len, stream); auto in_sqrt_view = raft::make_device_matrix_view(in_sqrt.data(), len, 1); - sqrt(handle, in_sqrt_view); + sqrt(handle, in_sqrt_view); auto in_ratio_view = raft::make_device_matrix_view(in_ratio.data(), 4, 1); - ratio(handle, in_ratio_view); + ratio(handle, in_ratio_view); naiveSignFlip( in_sign_flip.data(), out_sign_flip_ref.data(), params.n_row, params.n_col, stream); auto in_sign_flip_view = raft::make_device_matrix_view(in_sign_flip.data(), params.n_row, params.n_col); - sign_flip(handle, in_sign_flip_view); + sign_flip(handle, in_sign_flip_view); // default threshold is 1e-15 std::vector in_recip_h = {0.1, 0.01, -0.01, 0.1e-16}; @@ -180,23 +180,27 @@ class MathTest : public ::testing::TestWithParam> { update_device(in_recip_ref.data(), in_recip_ref_h.data(), 4, stream); T recip_scalar = T(1.0); - auto in_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); + auto in_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); auto out_recip_view = raft::make_device_matrix_view(out_recip.data(), 4, 1); // this `reciprocal()` has to go first bc next one modifies its input - reciprocal(handle, in_recip_view, out_recip_view, recip_scalar); - reciprocal(in_recip_view, 
recip_scalar, 4, stream, true); + reciprocal(handle, in_recip_view, out_recip_view, recip_scalar); + + auto inout_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); + + reciprocal(handle, inout_recip_view, recip_scalar, true); std::vector in_small_val_zero_h = {0.1, 1e-16, -1e-16, -0.1}; std::vector in_small_val_zero_ref_h = {0.1, 0.0, 0.0, -0.1}; - auto in_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); + auto in_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); + auto inout_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); auto out_smallzero_view = raft::make_device_matrix_view(out_smallzero.data(), 4, 1); update_device(in_smallzero.data(), in_small_val_zero_h.data(), 4, stream); update_device(out_smallzero_ref.data(), in_small_val_zero_ref_h.data(), 4, stream); - zero_small_values(handle, in_smallzero_view, out_smallzero_view); - zero_small_values(handle, in_smallzero_view); + zero_small_values(handle, in_smallzero_view, out_smallzero_view); + zero_small_values(handle, inout_smallzero_view); handle.sync_stream(stream); } diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 9c10d78e75..f684ebfd45 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -63,17 +63,17 @@ class MatrixTest : public ::testing::TestWithParam> { int len = params.n_row * params.n_col; uniform(handle, r, in1.data(), len, T(-1.0), T(1.0)); - auto in1_view = raft::make_device_matrix_view(in1.data(), params.n_row, params.n_col); - auto in2_view = raft::Make_device_matrix_view(in2.data(), params.n_row, params.n_col); + auto in1_view = raft::make_device_matrix_view(in1.data(), params.n_row, params.n_col); + auto in2_view = raft::make_device_matrix_view(in2.data(), params.n_row, params.n_col); - copy(handle, in1_view, in2_view); + copy(handle, in1_view, in2_view); // copy(in1, in1_revr, params.n_row, params.n_col); // colReverse(in1_revr, params.n_row, params.n_col); 
rmm::device_uvector outTrunc(6, stream); - auto out_trunc_view = raft::make_device_matrix_view(outTrunc.data(), 3, 2); - trunc_zero_origin(handle, in1_view, out_trunc_view); + auto out_trunc_view = raft::make_device_matrix_view(outTrunc.data(), 3, 2); + trunc_zero_origin(handle, in1_view, out_trunc_view); handle.sync_stream(stream); } @@ -135,25 +135,20 @@ class MatrixCopyRowsTest : public ::testing::Test { void testCopyRows() { - auto input_view = raft::device_matrix_view(input.data(), n_rows, n_cols); - auto output_view = raft::device_matrix_view(output.data(), n_rows, n_cols); - auto indices_view = raft::device_vector_view(indices.data(), n_selected); + auto input_view = raft::make_device_matrix_view(input.data(), n_rows, n_cols); + auto output_view = raft::make_device_matrix_view(output.data(), n_rows, n_cols); - copy_rows(handle, - input_view, - output_view, - indices_view); + auto indices_view = raft::make_device_vector_view(indices.data(), n_selected); + + copy_rows(handle, input_view, output_view, indices_view); EXPECT_TRUE(raft::devArrMatchHost( output_exp_colmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); - auto input_row_view = raft::device_matrix_view(input.data(), n_rows, n_cols); - auto output_row_view = raft::device_matrix_view(output.data(), n_rows, n_cols); + auto input_row_view = raft::make_device_matrix_view(input.data(), n_rows, n_cols); + auto output_row_view = raft::make_device_matrix_view(output.data(), n_rows, n_cols); - copy_rows(handle, - input_row_view, - output_row_view, - indices_view); + copy_rows(handle, input_row_view, output_row_view, indices_view); EXPECT_TRUE(raft::devArrMatchHost( output_exp_rowmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); } From 14c4ff1d8402e1d4251c43a5b83ef37bbc231ae7 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Tue, 27 Sep 2022 17:11:20 -0400 Subject: [PATCH 37/58] Trying to figure out gather and linewise op --- cpp/include/raft/matrix/col_wise_sort.cuh | 44 ++++--- cpp/include/raft/matrix/copy.cuh | 30 ++--- cpp/include/raft/matrix/detail/math.cuh | 6 +- cpp/include/raft/matrix/gather.cuh | 45 ++++--- cpp/include/raft/matrix/linewise_op.cuh | 23 ++-- cpp/include/raft/matrix/ratio.cuh | 14 +-- cpp/include/raft/matrix/reciprocal.cuh | 2 +- cpp/include/raft/matrix/sqrt.cuh | 18 +-- cpp/test/matrix/argmax.cu | 136 +++++++++++----------- cpp/test/matrix/columnSort.cu | 17 +-- cpp/test/matrix/gather.cu | 39 ++++--- cpp/test/matrix/linewise_op.cu | 48 ++++---- cpp/test/matrix/math.cu | 23 ++-- cpp/test/matrix/matrix.cu | 27 +++-- 14 files changed, 245 insertions(+), 227 deletions(-) diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index bb5251c346..2a6ecf61a6 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -65,7 +65,8 @@ void sort_cols_per_row( const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, - std::optional> sorted_keys = std::nullopt) + std::optional> sorted_keys = + std::nullopt) { RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0), "Input and output matrices must have the same shape."); @@ -79,33 +80,30 @@ void sort_cols_per_row( size_t workspace_size = 0; bool alloc_workspace = false; + InType* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr; - InType *keys = sorted_keys.has_value() ? 
sorted_keys.value().data_handle() : nullptr; - - detail::sortColumnsPerRow( - in.data_handle(), - out.data_handle(), - in.extent(0), - in.extent(1), - alloc_workspace, - (void*)nullptr, - workspace_size, - handle.get_stream(), - keys); + detail::sortColumnsPerRow(in.data_handle(), + out.data_handle(), + in.extent(0), + in.extent(1), + alloc_workspace, + (void*)nullptr, + workspace_size, + handle.get_stream(), + keys); if (alloc_workspace) { auto workspace = raft::make_device_vector(handle, workspace_size); - detail::sortColumnsPerRow( - in.data_handle(), - out.data_handle(), - in.extent(0), - in.extent(1), - alloc_workspace, - (void*)workspace.data_handle(), - workspace_size, - handle.get_stream(), - keys); + detail::sortColumnsPerRow(in.data_handle(), + out.data_handle(), + in.extent(0), + in.extent(1), + alloc_workspace, + (void*)workspace.data_handle(), + workspace_size, + handle.get_stream(), + keys); } } diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 842548452a..c99f86daac 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -32,11 +32,11 @@ namespace raft::matrix { * @param[out] out output matrix * @param[in] indices of the rows to be copied */ -template +template void copy_rows(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view indices) + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view indices) { RAFT_EXPECTS(in.extent(1) == out.extent(1), "Input and output matrices must have same number of columns"); @@ -45,9 +45,6 @@ void copy_rows(const raft::handle_t& handle, bool in_rowmajor = raft::is_row_major(in); bool out_rowmajor = raft::is_row_major(out); - RAFT_EXPECTS(in_rowmajor == out_rowmajor, - "Input and output matrices must have same layout (row- or column-major)"); - detail::copyRows(in.data_handle(), in.extent(0), in.extent(1), @@ -86,16 +83,19 @@ void copy(const 
raft::handle_t& handle, * @param stream: cuda stream */ template -void trunc_zero_origin( - const raft::handle_t &handle, - raft::device_matrix_view in, - raft::device_matrix_view out) { - - RAFT_EXPECTS(out.extent(0) <= in.extent(0) && - out.extent(1) <= in.extent(1), +void trunc_zero_origin(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(out.extent(0) <= in.extent(0) && out.extent(1) <= in.extent(1), "Output matrix must have less or equal number of rows and columns"); - detail::truncZeroOrigin(in.data_handle(), in.extent(0), out.data_handle(), out.extent(0), out.extent(1), handle.get_stream()); + detail::truncZeroOrigin(in.data_handle(), + in.extent(0), + out.data_handle(), + out.extent(0), + out.extent(1), + handle.get_stream()); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 8af0a31504..07b9ccc12b 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -392,11 +392,11 @@ void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t s if (D <= 32) { argmaxKernel<<>>(in, D, N, out); } else if (D <= 64) { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } else if (D <= 128) { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } else { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 80de561cd0..4415fb866c 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -16,8 +16,8 @@ #pragma once -#include #include +#include #include namespace raft { @@ -61,27 +61,27 @@ void gather(const MatrixIteratorT in, * pointer type). 
* * @param in Pointer to the input matrix (assumed to be row-major) - * @param map Pointer to the input sequence of gather locations * @param out Pointer to the output matrix (assumed to be row-major) + * @param map Pointer to the input sequence of gather locations */ -template +template void gather(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view map) + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view map) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); - detail::gather(in.data_handle(), - in.extent(1), - in.extent(0), - map, - map.extent(0), - out.data_handle(), - handle.get_stream()); + gather(in.data_handle(), + static_cast(in.extent(1)), + static_cast(in.extent(0)), + map.data_handle(), + static_cast(map.extent(0)), + out.data_handle(), + handle.get_stream()); } /** @@ -104,7 +104,7 @@ template in, raft::device_matrix_view out, - raft::device_vector_view map, + raft::device_vector_view map, MapTransformOp transform_op) { RAFT_EXPECTS(out.extent(0) == map.extent(0), @@ -112,11 +112,22 @@ void gather(const raft::handle_t& handle, RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); + /** + * template +void gather(const MatrixIteratorT in, + int D, + int N, + MapIteratorT map, + int map_length, + MatrixIteratorT out, + cudaStream_t stream) + + */ detail::gather(in.data_handle(), - in.extent(1), - in.extent(0), + static_cast(in.extent(1)), + static_cast(in.extent(0)), map, - map.extent(0), + static_cast(map.extent(0)), out.data_handle(), transform_op, handle.get_stream()); diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index 02086d52d2..c4871711a5 100644 --- 
a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -43,26 +43,27 @@ namespace raft::matrix { * @param [in] vecs zero or more vectors to be passed as arguments, * size of each vector is `alongLines ? lineLen : nLines`. */ -template +template void linewise_op(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, const bool alongLines, Lambda op, - raft::device_vector_view... vecs) + raft::device_vector_view... vecs) { - constexpr auto is_rowmajor = std::is_same_v; - constexpr auto is_colmajor = std::is_same_v; + constexpr auto is_rowmajor = std::is_same_v; + constexpr auto is_colmajor = std::is_same_v; - static_assert(is_rowmajor || is_colmajor, "layout for in and out must be either row or col major"); + static_assert(is_rowmajor || is_colmajor, + "layout for in and out must be either row or col major"); - const idx_t lineLen = is_rowmajor ? in.extent(1) : in.extent(0); - const idx_t nLines = is_rowmajor ? in.extent(0) : in.extent(1); + const idx_t lineLen = is_rowmajor ? in.extent(1) : in.extent(0); + const idx_t nLines = is_rowmajor ? 
in.extent(0) : in.extent(1); - RAFT_EXPECTS(out.extent(0) == in.extent(0) && - out.extent(1) == in.extent(1), "Input and output must have the same shape."); + RAFT_EXPECTS(out.extent(0) == in.extent(0) && out.extent(1) == in.extent(1), + "Input and output must have the same shape."); - detail::MatrixLinewiseOp<16, 256>::run( - out.data_handle(), in.data_handle(), lineLen, nLines, alongLines, op, stream, vecs...); + // detail::MatrixLinewiseOp<16, 256>::run( + // out.data_handle(), in.data_handle(), lineLen, nLines, alongLines, op, stream, vecs...); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh index ae2181757f..573d906723 100644 --- a/cpp/include/raft/matrix/ratio.cuh +++ b/cpp/include/raft/matrix/ratio.cuh @@ -36,7 +36,7 @@ void ratio(const raft::handle_t& handle, raft::device_matrix_view src, raft::device_matrix_view dest) { - RAFT_EXPECTS(src.size() == dst.size(), "Input and output matrices must be the same size."); + RAFT_EXPECTS(src.size() == dest.size(), "Input and output matrices must be the same size."); detail::ratio(handle, src.data_handle(), dest.data_handle(), src.size(), handle.get_stream()); } @@ -47,10 +47,10 @@ void ratio(const raft::handle_t& handle, * @param[in] handle * @param[inout] inout: input matrix */ - template - void ratio(const raft::handle_t& handle, - raft::device_matrix_view inout) - { - detail::ratio(handle, inout.data_handle(), inout.data_handle(), inout.size(), handle.get_stream()); - } +template +void ratio(const raft::handle_t& handle, raft::device_matrix_view inout) +{ + detail::ratio( + handle, inout.data_handle(), inout.data_handle(), inout.size(), handle.get_stream()); +} } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh index e4867aabd3..545c4592c5 100644 --- a/cpp/include/raft/matrix/reciprocal.cuh +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -33,7 +33,7 @@ namespace raft::matrix { * @{ 
*/ template -void reciprocal(const raft::handle_t &handle, +void reciprocal(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, math_t scalar, diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh index 8d5edf679b..b3791b0a7c 100644 --- a/cpp/include/raft/matrix/sqrt.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -31,8 +31,8 @@ namespace raft::matrix { */ template void sqrt(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out) + raft::device_matrix_view in, + raft::device_matrix_view out) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); detail::seqRoot(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); @@ -61,10 +61,10 @@ void sqrt(const raft::handle_t& handle, raft::device_matrix_view inout) */ template void weighted_sqrt(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - math_t scalar, - bool set_neg_zero = false) + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t scalar, + bool set_neg_zero = false) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); detail::seqRoot( @@ -81,9 +81,9 @@ void weighted_sqrt(const raft::handle_t& handle, */ template void weighted_sqrt(const raft::handle_t& handle, - raft::device_matrix_view inout, - math_t scalar, - bool set_neg_zero = false) + raft::device_matrix_view inout, + math_t scalar, + bool set_neg_zero = false) { detail::seqRoot(inout.data_handle(), scalar, inout.size(), handle.get_stream(), set_neg_zero); } diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu index b3abf93377..8851e170bb 100644 --- a/cpp/test/matrix/argmax.cu +++ b/cpp/test/matrix/argmax.cu @@ -17,80 +17,76 @@ #include "../test_utils.h" #include #include -#include -#include #include +#include +#include #include namespace raft { - namespace matrix { - - template - struct 
ArgMaxInputs { - const std::vector input_matrix; - const std::vector output_matrix; - std::size_t n_cols; - std::size_t n_rows; - }; - - template - ::std::ostream& operator<<(::std::ostream& os, const ArgMaxInputs& dims) - { - return os; - } - - template - class ArgMaxTest : public ::testing::TestWithParam> { - public: - ArgMaxTest() - : params(::testing::TestWithParam>::GetParam()) - {} - - void test() { - - auto input = raft::make_device_matrix(handle, params.n_rows, params.n_cols); - auto output = raft::make_device_vector(handle, params.n_rows); - auto expected = raft::make_device_vector(handle, params.n_rows); - - raft::copy(input.data_handle(), params.input_matrix.data(), params.n_rows * params.n_cols, handle.get_stream()); - raft::copy(expected.data_handle(), params.output_matrix.data(), params.n_rows * params.n_cols, handle.get_stream()); - - auto input_view = raft::make_device_matrix_view(input.data_handle(), params.n_rows, params.n_cols); - - raft::matrix::argmax(handle, input_view, output.view()); - - ASSERT_TRUE(devArrMatch(output.data_handle(), - expected.data_handle(), - params.n_rows, - Compare(), - handle.get_stream())); - - } - - protected: - raft::handle_t handle; - ArgMaxInputs params; - }; - - const std::vector> inputsf = { - { - {0.1f, 0.2f, 0.3f, 0.4f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.5f, 0.0f}, - {3, 0, 2}, - 3, 4} - }; - - const std::vector> inputsd = { - { - {0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1, 0.2, 0.3, 0.5, 0.0}, - {3, 0, 2}, - 3, 4} - }; - - typedef ArgMaxTest ArgMaxTestF; - TEST_P(ArgMaxTestF, Result) { test(); } - - typedef ArgMaxTest ArgMaxTestD; - TEST_P(ArgMaxTestD, Result) { test(); } +namespace matrix { + +template +struct ArgMaxInputs { + const std::vector input_matrix; + const std::vector output_matrix; + std::size_t n_cols; + std::size_t n_rows; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const ArgMaxInputs& dims) +{ + return os; +} + +template +class ArgMaxTest : public ::testing::TestWithParam> { + 
public: + ArgMaxTest() : params(::testing::TestWithParam>::GetParam()) {} + + void test() + { + auto input = raft::make_device_matrix(handle, params.n_rows, params.n_cols); + auto output = raft::make_device_vector(handle, params.n_rows); + auto expected = raft::make_device_vector(handle, params.n_rows); + + raft::copy(input.data_handle(), + params.input_matrix.data(), + params.n_rows * params.n_cols, + handle.get_stream()); + raft::copy(expected.data_handle(), + params.output_matrix.data(), + params.n_rows * params.n_cols, + handle.get_stream()); + + auto input_view = raft::make_device_matrix_view( + input.data_handle(), params.n_rows, params.n_cols); + + raft::matrix::argmax(handle, input_view, output.view()); + + ASSERT_TRUE(devArrMatch(output.data_handle(), + expected.data_handle(), + params.n_rows, + Compare(), + handle.get_stream())); + } + + protected: + raft::handle_t handle; + ArgMaxInputs params; +}; + +const std::vector> inputsf = { + {{0.1f, 0.2f, 0.3f, 0.4f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.5f, 0.0f}, {3, 0, 2}, 3, 4}}; + +const std::vector> inputsd = { + {{0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1, 0.2, 0.3, 0.5, 0.0}, {3, 0, 2}, 3, 4}}; + +typedef ArgMaxTest ArgMaxTestF; +TEST_P(ArgMaxTestF, Result) { test(); } + +typedef ArgMaxTest ArgMaxTestD; +TEST_P(ArgMaxTestD, Result) { test(); } INSTANTIATE_TEST_CASE_P(ArgMaxTest, ArgMaxTestF, ::testing::ValuesIn(inputsf)); INSTANTIATE_TEST_CASE_P(ArgMaxTest, ArgMaxTestD, ::testing::ValuesIn(inputsd)); diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index 270a9bb52f..aba1c4e1f0 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -47,8 +47,7 @@ struct columnSort { }; template -::std::ostream & operator<<(::std::ostream & os, - const columnSort& dims) +::std::ostream& operator<<(::std::ostream& os, const columnSort& dims) { return os; } @@ -103,13 +102,15 @@ class ColumnSort : public ::testing::TestWithParam> { if (params.testKeys) 
raft::update_device(keySortGolden.data(), &cKeyGolden[0], len, handle.get_stream()); - auto key_in_view = raft::make_device_matrix_view(keyIn.data(), params.n_row, params.n_col); - auto value_out_view = - raft::make_device_matrix_view(valueOut.data(), params.n_row, params.n_col); - auto key_sorted_view = - raft::make_device_matrix_view(keySorted.data(), params.n_row, params.n_col); + auto key_in_view = raft::make_device_matrix_view( + keyIn.data(), params.n_row, params.n_col); + auto value_out_view = raft::make_device_matrix_view( + valueOut.data(), params.n_row, params.n_col); + auto key_sorted_view = raft::make_device_matrix_view( + keySorted.data(), params.n_row, params.n_col); - raft::matrix::sort_cols_per_row(handle, key_in_view, value_out_view, std::make_optional(key_sorted_view)); + raft::matrix::sort_cols_per_row( + handle, key_in_view, value_out_view, std::make_optional(key_sorted_view)); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 8b95b28542..b81e89563f 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -46,23 +46,21 @@ void naiveGather( naiveGatherImpl(in, D, N, map, map_length, out); } -template +template void gatherLaunch(const raft::handle_t& handle, - MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, + const value_t* in, + idx_t D, + idx_t N, + map_t* map, + idx_t map_length, + value_t* out, cudaStream_t stream) { - typedef typename std::iterator_traits::value_type MapValueT; - - auto in_view = raft::make_device_matrix_view(in, N, D); - auto map_view = raft::make_device_vector_view(map, map_length); - auto out_view = raft::make_device_matrix_view(out, N, D); + auto in_view = raft::make_device_matrix_view(in, N, D); + auto out_view = raft::make_device_matrix_view(out, N, D); + auto map_view = raft::make_device_vector_view(map, map_length); - matrix::gather(handle, in_view, out_view, 
map_view); + raft::matrix::gather(handle, in_view, out_view, map_view); } struct GatherInputs { @@ -116,9 +114,18 @@ class GatherTest : public ::testing::TestWithParam { naiveGather(h_in.data(), ncols, nrows, h_map.data(), map_length, h_out.data()); raft::update_device(d_out_exp.data(), h_out.data(), map_length * ncols, stream); - // launch device version of the kernel - gatherLaunch( - handle, d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); + auto in_view = raft::make_device_matrix_view( + d_in.data(), nrows, ncols); + auto out_view = + raft::make_device_matrix_view(d_out_act.data(), nrows, ncols); + auto map_view = + raft::make_device_vector_view(d_map.data(), map_length); + + raft::matrix::gather(handle, in_view, out_view, map_view); + + // // launch device version of the kernel + // gatherLaunch( + // handle, d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); handle.sync_stream(stream); } diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 430d842d8e..22e884e33f 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include #include #include @@ -55,36 +55,31 @@ struct LinewiseTest : public ::testing::TestWithParam - void runLinewiseSum( - T* out, const T* in, const I lineLen, const I nLines, const T* vec) + template + void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, const T* vec) { auto f = [] __device__(T a, T b) -> T { return a + b; }; - - auto in_view = raft::make_device_matrix_view(in, nLines, lineLen); + auto in_view = raft::make_device_matrix_view(in, nLines, lineLen); auto out_view = raft::make_device_matrix_view(out, nLines, lineLen); - auto vec_view = raft::make_device_vector_view(vec, lineLen); - matrix::line_wise_op(handle, in_view, out_view, alongLines, f, vec); + auto vec_view = raft::make_device_vector_view(vec, lineLen); + 
matrix::linewise_op(handle, in_view, out_view, raft::is_row_major(in_view), f, vec_view); } - template - void runLinewiseSum(T* out, - const T* in, - const I lineLen, - const I nLines, - const T* vec1, - const T* vec2) + template + void runLinewiseSum( + T* out, const T* in, const I lineLen, const I nLines, const T* vec1, const T* vec2) { auto f = [] __device__(T a, T b, T c) -> T { return a + b + c; }; - constexpr auto layout = alongLines ? row_major : col_major; + auto in_view = raft::make_device_matrix_view(in, nLines, lineLen); + auto out_view = raft::make_device_matrix_view(out, nLines, lineLen); + auto vec1_view = raft::make_device_vector_view(vec1, lineLen); + auto vec2_view = raft::make_device_vector_view(vec2, lineLen); - auto in_view = raft::make_device_matrix_view(in, nLines, lineLen) - auto out_view = raft::make_device_matrix_view(out, nLines, lineLen); - - matrix::line_wise_op(handle, in_view, out_view, alongLines, f, vec1, vec2); + matrix::linewise_op( + handle, in_view, out_view, raft::is_row_major(in_view), f, vec1_view, vec2_view); } rmm::device_uvector genData(size_t workSizeBytes) @@ -163,10 +158,10 @@ struct LinewiseTest : public ::testing::TestWithParam(out, in, lineLen, nLines, vec1); + if (alongRows) { + runLinewiseSum(out, in, lineLen, nLines, vec1); } else { - runLinewiseSum(out, in, lineLen, nLines, vec1); + runLinewiseSum(out, in, lineLen, nLines, vec1); } } if (params.checkCorrectness) { @@ -179,12 +174,11 @@ struct LinewiseTest : public ::testing::TestWithParam(out, in, lineLen, nLines, vec1, vec2); + if (alongRows) { + runLinewiseSum(out, in, lineLen, nLines, vec1, vec2); } else { - runLinewiseSum(out, in, lineLen, nLines, vec1, vec2); - + runLinewiseSum(out, in, lineLen, nLines, vec1, vec2); } } if (params.checkCorrectness) { diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index a5e5ffc8dc..adfa45b84a 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -18,12 +18,12 @@ #include #include -#include 
#include #include +#include #include -#include #include +#include #include #include @@ -170,7 +170,8 @@ class MathTest : public ::testing::TestWithParam> { naiveSignFlip( in_sign_flip.data(), out_sign_flip_ref.data(), params.n_row, params.n_col, stream); - auto in_sign_flip_view = raft::make_device_matrix_view(in_sign_flip.data(), params.n_row, params.n_col); + auto in_sign_flip_view = raft::make_device_matrix_view( + in_sign_flip.data(), params.n_row, params.n_col); sign_flip(handle, in_sign_flip_view); // default threshold is 1e-15 @@ -180,27 +181,27 @@ class MathTest : public ::testing::TestWithParam> { update_device(in_recip_ref.data(), in_recip_ref_h.data(), 4, stream); T recip_scalar = T(1.0); - auto in_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); + auto in_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); auto out_recip_view = raft::make_device_matrix_view(out_recip.data(), 4, 1); // this `reciprocal()` has to go first bc next one modifies its input reciprocal(handle, in_recip_view, out_recip_view, recip_scalar); - auto inout_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); + auto inout_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); - reciprocal(handle, inout_recip_view, recip_scalar, true); + reciprocal(handle, inout_recip_view, recip_scalar, true); std::vector in_small_val_zero_h = {0.1, 1e-16, -1e-16, -0.1}; std::vector in_small_val_zero_ref_h = {0.1, 0.0, 0.0, -0.1}; - auto in_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); - auto inout_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); - auto out_smallzero_view = raft::make_device_matrix_view(out_smallzero.data(), 4, 1); + auto in_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); + auto inout_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); + auto out_smallzero_view = raft::make_device_matrix_view(out_smallzero.data(), 4, 1); 
update_device(in_smallzero.data(), in_small_val_zero_h.data(), 4, stream); update_device(out_smallzero_ref.data(), in_small_val_zero_ref_h.data(), 4, stream); - zero_small_values(handle, in_smallzero_view, out_smallzero_view); - zero_small_values(handle, inout_smallzero_view); + zero_small_values(handle, in_smallzero_view, out_smallzero_view); + zero_small_values(handle, inout_smallzero_view); handle.sync_stream(stream); } diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index f684ebfd45..fd7ca6b10e 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -63,8 +63,10 @@ class MatrixTest : public ::testing::TestWithParam> { int len = params.n_row * params.n_col; uniform(handle, r, in1.data(), len, T(-1.0), T(1.0)); - auto in1_view = raft::make_device_matrix_view(in1.data(), params.n_row, params.n_col); - auto in2_view = raft::make_device_matrix_view(in2.data(), params.n_row, params.n_col); + auto in1_view = raft::make_device_matrix_view( + in1.data(), params.n_row, params.n_col); + auto in2_view = + raft::make_device_matrix_view(in2.data(), params.n_row, params.n_col); copy(handle, in1_view, in2_view); // copy(in1, in1_revr, params.n_row, params.n_col); @@ -135,20 +137,27 @@ class MatrixCopyRowsTest : public ::testing::Test { void testCopyRows() { - auto input_view = raft::make_device_matrix_view(input.data(), n_rows, n_cols); - auto output_view = raft::make_device_matrix_view(output.data(), n_rows, n_cols); + device_matrix_view input_view = + raft::make_device_matrix_view( + input.data(), n_rows, n_cols); + device_matrix_view output_view = + raft::make_device_matrix_view(output.data(), n_rows, n_cols); - auto indices_view = raft::make_device_vector_view(indices.data(), n_selected); + device_vector_view indices_view = + raft::make_device_vector_view(indices.data(), n_selected); - copy_rows(handle, input_view, output_view, indices_view); + raft::matrix::copy_rows(handle, input_view, output_view, indices_view); 
EXPECT_TRUE(raft::devArrMatchHost( output_exp_colmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); - auto input_row_view = raft::make_device_matrix_view(input.data(), n_rows, n_cols); - auto output_row_view = raft::make_device_matrix_view(output.data(), n_rows, n_cols); + device_matrix_view input_row_view = + raft::make_device_matrix_view( + input.data(), n_rows, n_cols); + device_matrix_view output_row_view = + raft::make_device_matrix_view(output.data(), n_rows, n_cols); - copy_rows(handle, input_row_view, output_row_view, indices_view); + raft::matrix::copy_rows(handle, input_row_view, output_row_view, indices_view); EXPECT_TRUE(raft::devArrMatchHost( output_exp_rowmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); } From 8db49ab644014792610f689c193311df77ed256e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 27 Sep 2022 18:16:24 -0400 Subject: [PATCH 38/58] Still trying to figure out why gaqther isn't being invoked --- cpp/include/raft/matrix/copy.cuh | 6 ++---- cpp/include/raft/matrix/gather.cuh | 23 ++++++++++++++--------- cpp/test/matrix/gather.cu | 17 ----------------- cpp/test/matrix/matrix.cu | 10 +++++----- 4 files changed, 21 insertions(+), 35 deletions(-) diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index c99f86daac..9d2b8ed5bf 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -42,16 +42,14 @@ void copy_rows(const raft::handle_t& handle, "Input and output matrices must have same number of columns"); RAFT_EXPECTS(indices.extent(0) == out.extent(0), "Number of rows in output matrix must equal number of indices"); - bool in_rowmajor = raft::is_row_major(in); - bool out_rowmajor = raft::is_row_major(out); - detail::copyRows(in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), indices.data_handle(), indices.extent(0), - handle.get_stream()); + handle.get_stream(), + raft::is_row_major(in)); } /** diff --git 
a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 4415fb866c..f508c57136 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -16,12 +16,11 @@ #pragma once -#include #include +#include #include -namespace raft { -namespace matrix { +namespace raft::matrix { /** * @brief gather copies rows from a source matrix into a destination matrix according to a map. @@ -75,13 +74,20 @@ void gather(const raft::handle_t& handle, RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); - gather(in.data_handle(), + const matrix_t *in_ptr = in.data_handle(); + map_t *map_ptr = map.data_handle(); + matrix_t *out_ptr = out.data_handle(); + + cudaStream_t stream = handle.get_stream(); + + detail::gather( + in_ptr, static_cast(in.extent(1)), static_cast(in.extent(0)), - map.data_handle(), + map_ptr, static_cast(map.extent(0)), - out.data_handle(), - handle.get_stream()); + out_ptr, + stream); } /** @@ -358,5 +364,4 @@ void gather_if(const raft::handle_t& handle, handle.get_stream()); } -} // namespace matrix -} // namespace raft +} // namespace raft::matrix diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index b81e89563f..08f538424b 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -46,23 +46,6 @@ void naiveGather( naiveGatherImpl(in, D, N, map, map_length, out); } -template -void gatherLaunch(const raft::handle_t& handle, - const value_t* in, - idx_t D, - idx_t N, - map_t* map, - idx_t map_length, - value_t* out, - cudaStream_t stream) -{ - auto in_view = raft::make_device_matrix_view(in, N, D); - auto out_view = raft::make_device_matrix_view(out, N, D); - auto map_view = raft::make_device_vector_view(map, map_length); - - raft::matrix::gather(handle, in_view, out_view, map_view); -} - struct GatherInputs { uint32_t nrows; uint32_t ncols; diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 
fd7ca6b10e..e7b6373cad 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -137,13 +137,13 @@ class MatrixCopyRowsTest : public ::testing::Test { void testCopyRows() { - device_matrix_view input_view = + auto input_view = raft::make_device_matrix_view( input.data(), n_rows, n_cols); - device_matrix_view output_view = + auto output_view = raft::make_device_matrix_view(output.data(), n_rows, n_cols); - device_vector_view indices_view = + auto indices_view = raft::make_device_vector_view(indices.data(), n_selected); raft::matrix::copy_rows(handle, input_view, output_view, indices_view); @@ -151,10 +151,10 @@ class MatrixCopyRowsTest : public ::testing::Test { EXPECT_TRUE(raft::devArrMatchHost( output_exp_colmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); - device_matrix_view input_row_view = + auto input_row_view = raft::make_device_matrix_view( input.data(), n_rows, n_cols); - device_matrix_view output_row_view = + auto output_row_view = raft::make_device_matrix_view(output.data(), n_rows, n_cols); raft::matrix::copy_rows(handle, input_row_view, output_row_view, indices_view); From 4ea623fc53f00a47e2845042a2c2daaee91ba184 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Tue, 27 Sep 2022 18:17:31 -0400 Subject: [PATCH 39/58] Style fix --- cpp/include/raft/matrix/gather.cuh | 23 +++++++++++------------ cpp/test/matrix/matrix.cu | 10 ++++------ 2 files changed, 15 insertions(+), 18 deletions(-) diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index f508c57136..682a1208c1 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -16,8 +16,8 @@ #pragma once -#include #include +#include #include namespace raft::matrix { @@ -74,20 +74,19 @@ void gather(const raft::handle_t& handle, RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); - const matrix_t *in_ptr = in.data_handle(); - map_t *map_ptr = map.data_handle(); - matrix_t *out_ptr = out.data_handle(); + const matrix_t* in_ptr = in.data_handle(); + map_t* map_ptr = map.data_handle(); + matrix_t* out_ptr = out.data_handle(); cudaStream_t stream = handle.get_stream(); - detail::gather( - in_ptr, - static_cast(in.extent(1)), - static_cast(in.extent(0)), - map_ptr, - static_cast(map.extent(0)), - out_ptr, - stream); + detail::gather(in_ptr, + static_cast(in.extent(1)), + static_cast(in.extent(0)), + map_ptr, + static_cast(map.extent(0)), + out_ptr, + stream); } /** diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index e7b6373cad..2752de6431 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -137,9 +137,8 @@ class MatrixCopyRowsTest : public ::testing::Test { void testCopyRows() { - auto input_view = - raft::make_device_matrix_view( - input.data(), n_rows, n_cols); + auto input_view = raft::make_device_matrix_view( + input.data(), n_rows, n_cols); auto output_view = raft::make_device_matrix_view(output.data(), n_rows, n_cols); @@ -151,9 +150,8 @@ class MatrixCopyRowsTest : public ::testing::Test { EXPECT_TRUE(raft::devArrMatchHost( output_exp_colmajor, output.data(), n_selected * n_cols, raft::Compare(), 
stream)); - auto input_row_view = - raft::make_device_matrix_view( - input.data(), n_rows, n_cols); + auto input_row_view = raft::make_device_matrix_view( + input.data(), n_rows, n_cols); auto output_row_view = raft::make_device_matrix_view(output.data(), n_rows, n_cols); From 477d18ad08b30130a6b30d38efcbc030ff9c31d8 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 27 Sep 2022 19:43:14 -0400 Subject: [PATCH 40/58] Finisxhing up matrix --- .../raft/cluster/detail/single_linkage.cuh | 2 +- cpp/include/raft/cluster/single_linkage.cuh | 41 ++++++++++++++++++- .../raft/cluster/single_linkage_types.hpp | 30 ++++++++++---- cpp/include/raft/matrix/gather.cuh | 20 ++++----- cpp/test/sparse/linkage.cu | 2 +- 5 files changed, 73 insertions(+), 22 deletions(-) diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh index 7de942444e..9eee21b09c 100644 --- a/cpp/include/raft/cluster/detail/single_linkage.cuh +++ b/cpp/include/raft/cluster/detail/single_linkage.cuh @@ -54,7 +54,7 @@ void single_linkage(const raft::handle_t& handle, size_t m, size_t n, raft::distance::DistanceType metric, - linkage_output* out, + linkage_output* out, int c, size_t n_clusters) { diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 98735c74e4..7f0553a553 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -17,6 +17,7 @@ #include #include +#include namespace raft::cluster { @@ -48,11 +49,49 @@ void single_linkage(const raft::handle_t& handle, size_t m, size_t n, raft::distance::DistanceType metric, - linkage_output* out, + linkage_output* out, int c, size_t n_clusters) { detail::single_linkage( handle, X, m, n, metric, out, c, n_clusters); } + +/** + * Single-linkage clustering, capable of constructing a KNN graph to + * scale the algorithm beyond the n^2 memory consumption of implementations + * that use the 
fully-connected graph of pairwise distances by connecting + * a knn graph when k is not large enough to connect it. + + * @tparam value_idx + * @tparam value_t + * @tparam dist_type method to use for constructing connectivities graph + * @param[in] handle raft handle + * @param[in] X dense input matrix in row-major layout + * @param[in] m number of rows in X + * @param[in] n number of columns in X + * @param[in] metric distance metrix to use when constructing connectivities graph + * @param[out] out struct containing output dendrogram and cluster assignments + * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect + control of k. The algorithm will set `k = log(n) + c` + * @param[in] n_clusters number of clusters to assign data samples + */ +template +void single_linkage(const raft::handle_t& handle, + raft::device_matrix_view X, + raft::distance::DistanceType metric, + linkage_output& out, + int c, + size_t n_clusters) +{ + detail::single_linkage(handle, + X.data_handle(), + static_cast(X.extent(0)), + static_cast(X.extent(1)), + metric, + &out, + c, + n_clusters); +} + }; // namespace raft::cluster diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index 1c35cf5c68..d97e6afed3 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::cluster { enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; @@ -27,23 +29,33 @@ enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; * @tparam value_idx * @tparam value_t */ -template +template class linkage_output { public: - value_idx m; - value_idx n_clusters; + idx_t m; + idx_t n_clusters; + + idx_t n_leaves; + idx_t n_connected_components; - value_idx n_leaves; - value_idx n_connected_components; + // TODO: These will be made private in a future release + idx_t* labels; // size: m + idx_t* children; 
// size: (m-1, 2) - value_idx* labels; // size: m + raft::device_vector_view get_labels() + { + return raft::make_device_vector_view(labels, m); + } - value_idx* children; // size: (m-1, 2) + raft::device_matrix_view get_children() + { + return raft::make_device_matrix_view(children, m - 1, 2); + } }; -class linkage_output_int_float : public linkage_output { +class linkage_output_int_float : public linkage_output { }; -class linkage_output__int64_float : public linkage_output { +class linkage_output__int64_float : public linkage_output { }; }; // namespace raft::cluster \ No newline at end of file diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 682a1208c1..024d61cb41 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -74,19 +74,19 @@ void gather(const raft::handle_t& handle, RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); - const matrix_t* in_ptr = in.data_handle(); - map_t* map_ptr = map.data_handle(); - matrix_t* out_ptr = out.data_handle(); + matrix_t* in_ptr = const_cast(in.data_handle()); + map_t* map_ptr = map.data_handle(); + matrix_t* out_ptr = out.data_handle(); cudaStream_t stream = handle.get_stream(); - detail::gather(in_ptr, - static_cast(in.extent(1)), - static_cast(in.extent(0)), - map_ptr, - static_cast(map.extent(0)), - out_ptr, - stream); + raft::matrix::detail::gather(in_ptr, + static_cast(in.extent(1)), + static_cast(in.extent(0)), + map_ptr, + static_cast(map.extent(0)), + out_ptr, + stream); } /** diff --git a/cpp/test/sparse/linkage.cu b/cpp/test/sparse/linkage.cu index e9df5e3df0..6fa1d0461e 100644 --- a/cpp/test/sparse/linkage.cu +++ b/cpp/test/sparse/linkage.cu @@ -175,7 +175,7 @@ class LinkageTest : public ::testing::TestWithParam> { raft::copy(data.data(), params.data.data(), data.size(), stream); raft::copy(labels_ref.data(), params.expected_labels.data(), params.n_row, stream); - 
raft::hierarchy::linkage_output out_arrs; + raft::hierarchy::linkage_output out_arrs; out_arrs.labels = labels.data(); rmm::device_uvector out_children(params.n_row * 2, stream); From f63b458d5e4f4469363246427eb6bfe1bc5c3dc8 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 28 Sep 2022 15:25:47 -0400 Subject: [PATCH 41/58] Getting tests to pass. Still need to figure out linewise_op failures --- cpp/include/raft/matrix/argmax.cuh | 4 ++-- cpp/include/raft/matrix/linewise_op.cuh | 25 +++++++++++++++++++++---- cpp/test/matrix/argmax.cu | 20 ++++++++++---------- cpp/test/matrix/gather.cu | 2 +- cpp/test/matrix/matrix.cu | 8 ++++---- 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index 013e496ea1..9d15f1e94e 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -32,8 +32,8 @@ void argmax(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_vector_view out) { - RAFT_EXPECTS(static_cast(out.extent(1)) == in.extent(1), - "Size of output vector must equal number of columns in input matrix."); + RAFT_EXPECTS(static_cast(out.extent(0)) == in.extent(0), + "Size of output vector must equal number of rows in input matrix."); detail::argmax( in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream()); } diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index c4871711a5..848f8580b3 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -22,6 +22,12 @@ namespace raft::matrix { +// template +// args *extract_ptr(raft::device_vector_view vec, raft::device_vector_view... vecs) { +// vecs.data_handle(); +//} + /** * Run a function over matrix lines (rows or columns) with a variable number * row-vectors or column-vectors. 
@@ -43,13 +49,18 @@ namespace raft::matrix { * @param [in] vecs zero or more vectors to be passed as arguments, * size of each vector is `alongLines ? lineLen : nLines`. */ -template +template > void linewise_op(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, const bool alongLines, Lambda op, - raft::device_vector_view... vecs) + vec_t... vecs) { constexpr auto is_rowmajor = std::is_same_v; constexpr auto is_colmajor = std::is_same_v; @@ -63,7 +74,13 @@ void linewise_op(const raft::handle_t& handle, RAFT_EXPECTS(out.extent(0) == in.extent(0) && out.extent(1) == in.extent(1), "Input and output must have the same shape."); - // detail::MatrixLinewiseOp<16, 256>::run( - // out.data_handle(), in.data_handle(), lineLen, nLines, alongLines, op, stream, vecs...); + detail::MatrixLinewiseOp<16, 256>::run(out.data_handle(), + in.data_handle(), + lineLen, + nLines, + alongLines, + op, + handle.get_stream(), + vecs.data_handle()...); } } // namespace raft::matrix diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu index 8851e170bb..8fbcb1a38d 100644 --- a/cpp/test/matrix/argmax.cu +++ b/cpp/test/matrix/argmax.cu @@ -27,8 +27,8 @@ namespace matrix { template struct ArgMaxInputs { - const std::vector input_matrix; - const std::vector output_matrix; + std::vector input_matrix; + std::vector output_matrix; std::size_t n_cols; std::size_t n_rows; }; @@ -50,14 +50,14 @@ class ArgMaxTest : public ::testing::TestWithParam> { auto output = raft::make_device_vector(handle, params.n_rows); auto expected = raft::make_device_vector(handle, params.n_rows); - raft::copy(input.data_handle(), - params.input_matrix.data(), - params.n_rows * params.n_cols, - handle.get_stream()); - raft::copy(expected.data_handle(), - params.output_matrix.data(), - params.n_rows * params.n_cols, - handle.get_stream()); + raft::update_device(input.data_handle(), + params.input_matrix.data(), + params.n_rows * params.n_cols, + handle.get_stream()); + 
raft::update_device( + expected.data_handle(), params.output_matrix.data(), params.n_rows, handle.get_stream()); + + printf("Finished copy\n"); auto input_view = raft::make_device_matrix_view( input.data_handle(), params.n_rows, params.n_cols); diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index bb675bff04..7d626f5f3c 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -100,7 +100,7 @@ class GatherTest : public ::testing::TestWithParam { auto in_view = raft::make_device_matrix_view( d_in.data(), nrows, ncols); auto out_view = - raft::make_device_matrix_view(d_out_act.data(), nrows, ncols); + raft::make_device_matrix_view(d_out_act.data(), map_length, ncols); auto map_view = raft::make_device_vector_view(d_map.data(), map_length); diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 2752de6431..78391d5ff2 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -139,8 +139,8 @@ class MatrixCopyRowsTest : public ::testing::Test { { auto input_view = raft::make_device_matrix_view( input.data(), n_rows, n_cols); - auto output_view = - raft::make_device_matrix_view(output.data(), n_rows, n_cols); + auto output_view = raft::make_device_matrix_view( + output.data(), n_selected, n_cols); auto indices_view = raft::make_device_vector_view(indices.data(), n_selected); @@ -152,8 +152,8 @@ class MatrixCopyRowsTest : public ::testing::Test { auto input_row_view = raft::make_device_matrix_view( input.data(), n_rows, n_cols); - auto output_row_view = - raft::make_device_matrix_view(output.data(), n_rows, n_cols); + auto output_row_view = raft::make_device_matrix_view( + output.data(), n_selected, n_cols); raft::matrix::copy_rows(handle, input_row_view, output_row_view, indices_view); EXPECT_TRUE(raft::devArrMatchHost( From 7edf83ad413716a9138ca324e68881f89b96c355 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 28 Sep 2022 18:04:55 -0400 Subject: [PATCH 42/58] Cleaning up, docs, tests passing. ready for review --- cpp/include/raft/matrix/argmax.cuh | 10 +- cpp/include/raft/matrix/gather.cuh | 194 +++++++++++------------- cpp/include/raft/matrix/init.cuh | 16 +- cpp/include/raft/matrix/linewise_op.cuh | 11 +- cpp/include/raft/matrix/power.cuh | 34 +++-- cpp/include/raft/matrix/print.cuh | 12 +- cpp/include/raft/matrix/ratio.cuh | 24 +-- cpp/include/raft/matrix/reciprocal.cuh | 22 +-- cpp/include/raft/matrix/sign_flip.cuh | 6 +- cpp/include/raft/matrix/sqrt.cuh | 36 +++-- cpp/include/raft/matrix/threshold.cuh | 20 ++- cpp/test/matrix/argmax.cu | 2 +- cpp/test/matrix/gather.cu | 2 +- cpp/test/matrix/linewise_op.cu | 26 ++-- 14 files changed, 223 insertions(+), 192 deletions(-) diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index 9d15f1e94e..b7423b9ea4 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -23,16 +23,18 @@ namespace raft::matrix { /** * @brief Argmax: find the row idx with maximum value for each column + * @tparam math_t matrix element type + * @tparam idx_t integer type for matrix and vector indexing * @param[in] handle: raft handle * @param[in] in: input matrix of size (n_rows, n_cols) * @param[out] out: output vector of size n_cols */ -template +template void argmax(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_vector_view out) + raft::device_matrix_view in, + raft::device_vector_view out) { - RAFT_EXPECTS(static_cast(out.extent(0)) == in.extent(0), + RAFT_EXPECTS(out.extent(0) == in.extent(0), "Size of output vector must equal number of rows in input matrix."); detail::argmax( in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream()); diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 024d61cb41..58cacf9c73 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ 
b/cpp/include/raft/matrix/gather.cuh @@ -54,88 +54,70 @@ void gather(const MatrixIteratorT in, /** * @brief gather copies rows from a source matrix into a destination matrix according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param out Pointer to the output matrix (assumed to be row-major) - * @param map Pointer to the input sequence of gather locations + * @tparam matrix_t Matrix element type + * @tparam map_t Map vector type + * @tparam idx_t integer type used for indexing + * @param[in] in Input matrix (assumed to be row-major) + * @param[in] map Vector of gather locations + * @param[out] out Output matrix (assumed to be row-major) */ template void gather(const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view map) + raft::device_vector_view map, + raft::device_matrix_view out) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); - matrix_t* in_ptr = const_cast(in.data_handle()); - map_t* map_ptr = map.data_handle(); - matrix_t* out_ptr = out.data_handle(); - - cudaStream_t stream = handle.get_stream(); - - raft::matrix::detail::gather(in_ptr, - static_cast(in.extent(1)), - static_cast(in.extent(0)), - map_ptr, - static_cast(map.extent(0)), - out_ptr, - stream); + raft::matrix::detail::gather( + const_cast(in.data_handle()), // TODO: There's a better way to handle this + static_cast(in.extent(1)), + static_cast(in.extent(0)), + map.data_handle(), + static_cast(map.extent(0)), + out.data_handle(), + handle.get_stream()); } /** * @brief gather copies rows 
from a source matrix into a destination matrix according to a * transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * @tparam matrix_t Matrix type + * @tparam map_t Map vector type + * @tparam map_xform_t Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to idx_t (= int) type. + * @tparam idx_t integer type for indexing * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param map Pointer to the input sequence of gather locations - * @param out Pointer to the output matrix (assumed to be row-major) - * @param transform_op The transformation operation, transforms the map values to IndexT + * @param[in] in Input matrix (assumed to be row-major) + * @param[in] map Input vector of gather locations + * @param[out] out Output matrix (assumed to be row-major) + * @param[in] transform_op The transformation operation, transforms the map values to idx_t */ -template +template void gather(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view map, - MapTransformOp transform_op) + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view map, + map_xform_t transform_op) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); RAFT_EXPECTS(out.extent(1) == in.extent(1), "Number of columns in input and output matrices must be equal."); - /** - * template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, - cudaStream_t stream) - - */ - 
detail::gather(in.data_handle(), - static_cast(in.extent(1)), - static_cast(in.extent(0)), - map, - static_cast(map.extent(0)), - out.data_handle(), - transform_op, - handle.get_stream()); + detail::gather( + const_cast(in.data_handle()), // TODO: There's a better way to handle this + static_cast(in.extent(1)), + static_cast(in.extent(0)), + map, + static_cast(map.extent(0)), + out.data_handle(), + transform_op, + handle.get_stream()); } /** @@ -217,32 +199,30 @@ void gather_if(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a - * simple pointer type). - * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * @tparam matrix_t Matrix value type + * @tparam map_t Map vector type + * @tparam stencil_t Stencil vector type + * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result * type must be convertible to bool type. 
+ * @tparam idx_t integer type for indexing * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param map Pointer to the input sequence of gather locations - * @param stencil Pointer to the input sequence of stencil or predicate values - * @param out Pointer to the output matrix (assumed to be row-major) - * @param pred_op Predicate to apply to the stencil values + * @param[in] in Input matrix (assumed to be row-major) + * @param[in] map Input vector of gather locations + * @param[in] stencil Input vector of stencil or predicate values + * @param[out] out Output matrix (assumed to be row-major) + * @param[in] pred_op Predicate to apply to the stencil values */ -template +template void gather_if(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view map, - raft::device_vector_view stencil, - UnaryPredicateOp pred_op) + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view map, + raft::device_vector_view stencil, + unary_pred_t pred_op) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); @@ -251,7 +231,7 @@ void gather_if(const raft::handle_t& handle, RAFT_EXPECTS(map.extent(0) == stencil.extent(0), "Number of elements in stencil must equal number of elements in map"); - detail::gather_if(in.data_handle(), + detail::gather_if(const_cast(in.data_handle()), out.extent(1), out.extent(0), map.data_handle(), @@ -312,37 +292,35 @@ void gather_if(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). 
- * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a - * simple pointer type). - * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result + * @tparam matrix_t Matrix value type, for reading input matrix + * @tparam map_t Vector value type for map + * @tparam stencil_t Vector value type for stencil + * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result * type must be convertible to bool type. - * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * @tparam map_xform_t Unary lambda expression or operator type, map_xform_t's result + * type must be convertible to idx_t (= int) type. + * @tparam idx_t integer type for indexing * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param map Pointer to the input sequence of gather locations - * @param stencil Pointer to the input sequence of stencil or predicate values - * @param out Pointer to the output matrix (assumed to be row-major) - * @param pred_op Predicate to apply to the stencil values - * @param transform_op The transformation operation, transforms the map values to IndexT + * @param[in] in Input matrix (assumed to be row-major) + * @param[in] map Vector of gather locations + * @param[in] stencil Vector of stencil or predicate values + * @param[out] out Output matrix (assumed to be row-major) + * @param[in] pred_op Predicate to apply to the stencil values + * @param[in] transform_op The transformation operation, transforms the map values to idx_t */ -template +template void gather_if(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view map, - raft::device_vector_view stencil, - UnaryPredicateOp pred_op, - MapTransformOp transform_op) + raft::device_matrix_view in, + raft::device_matrix_view out, + 
raft::device_vector_view map, + raft::device_vector_view stencil, + unary_pred_t pred_op, + map_xform_t transform_op) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); @@ -351,7 +329,7 @@ void gather_if(const raft::handle_t& handle, RAFT_EXPECTS(map.extent(0) == stencil.extent(0), "Number of elements in stencil must equal number of elements in map"); - detail::gather_if(in.data_handle(), + detail::gather_if(const_cast(in.data_handle()), in.extent(1), in.extent(0), map.data_handle(), diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh index 0c6f45f904..37ea1dce1a 100644 --- a/cpp/include/raft/matrix/init.cuh +++ b/cpp/include/raft/matrix/init.cuh @@ -24,15 +24,17 @@ namespace raft::matrix { /** * @brief set values to scalar in matrix * @tparam math_t data-type upon which the math operation will be performed - * @param handle: raft handle - * @param in input matrix - * @param out output matrix. The result is stored in the out matrix - * @param scalar svalar value + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[in] in input matrix + * @param[out] out output matrix. The result is stored in the out matrix + * @param[in] scalar scalar value to fill matrix elements */ -template +template void fill(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, + raft::device_matrix_view in, + raft::device_matrix_view out, math_t scalar) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index 848f8580b3..5996e40f86 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -35,7 +35,12 @@ namespace raft::matrix { * depending on the matrix layout. 
* What matters is if the vectors are applied along lines (indices of vectors correspond to * indices within lines), or across lines (indices of vectors correspond to line numbers). - * + * @tparam m_t matrix elements type + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @tparam Lambda type of lambda function used for the operation + * @tparam vec_t variadic types of device_vector_view vectors (size m if alongRows, size n + * otherwise) * @param [out] out result of the operation; can be same as `in`; should be aligned the same * as `in` to allow faster vectorized memory transfers. * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. @@ -68,8 +73,8 @@ void linewise_op(const raft::handle_t& handle, static_assert(is_rowmajor || is_colmajor, "layout for in and out must be either row or col major"); - const idx_t lineLen = is_rowmajor ? in.extent(1) : in.extent(0); - const idx_t nLines = is_rowmajor ? in.extent(0) : in.extent(1); + const idx_t lineLen = is_rowmajor ? in.extent(0) : in.extent(1); + const idx_t nLines = is_rowmajor ? in.extent(1) : in.extent(0); RAFT_EXPECTS(out.extent(0) == in.extent(0) && out.extent(1) == in.extent(1), "Input and output must have the same shape."); diff --git a/cpp/include/raft/matrix/power.cuh b/cpp/include/raft/matrix/power.cuh index 60a3231bf0..320ca4fe0f 100644 --- a/cpp/include/raft/matrix/power.cuh +++ b/cpp/include/raft/matrix/power.cuh @@ -23,15 +23,18 @@ namespace raft::matrix { /** * @brief Power of every element in the input matrix + * @tparam math_t type of matrix elements + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) * @param[in] handle: raft handle * @param[in] in: input matrix * @param[out] out: output matrix. The result is stored in the out matrix * @param[in] scalar: every element is multiplied with scalar. 
*/ -template +template void weighted_power(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, + raft::device_matrix_view in, + raft::device_matrix_view out, math_t scalar) { RAFT_EXPECTS(in.size() == out.size(), "Size of input and output matrices must be equal"); @@ -40,12 +43,16 @@ void weighted_power(const raft::handle_t& handle, /** * @brief Power of every element in the input matrix (inplace) + * @tparam math_t matrix element type + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle * @param[inout] inout: input matrix and also the result is stored * @param[in] scalar: every element is multiplied with scalar. */ -template +template void weighted_power(const raft::handle_t& handle, - raft::device_matrix_view inout, + raft::device_matrix_view inout, math_t scalar) { detail::power(inout.data_handle(), scalar, inout.size(), handle.get_stream()); @@ -53,25 +60,32 @@ void weighted_power(const raft::handle_t& handle, /** * @brief Power of every element in the input matrix (inplace) + * @tparam math_t matrix element type + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle * @param[inout] inout: input matrix and also the result is stored */ -template -void power(const raft::handle_t& handle, raft::device_matrix_view inout) +template +void power(const raft::handle_t& handle, raft::device_matrix_view inout) { detail::power(inout.data_handle(), inout.size(), handle.get_stream()); } /** * @brief Power of every element in the input matrix + * @tparam math_t type used for matrix elements + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix (row or column major) * @param[in] handle: raft handle * @param[in] in: input matrix * @param[out] out: output matrix. 
The result is stored in the out matrix * @{ */ -template +template void power(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out) + raft::device_matrix_view in, + raft::device_matrix_view out) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size."); detail::power(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh index 5eef7e0fda..060cd8642c 100644 --- a/cpp/include/raft/matrix/print.cuh +++ b/cpp/include/raft/matrix/print.cuh @@ -25,10 +25,12 @@ namespace raft::matrix { /** * @brief Prints the data stored in GPU memory - * @param handle: raft handle - * @param in: input matrix - * @param h_separator: horizontal separator character - * @param v_separator: vertical separator character + * @tparam m_t type of matrix elements + * @tparam idx_t integer type used for indexing + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[in] h_separator: horizontal separator character + * @param[in] v_separator: vertical separator character */ template void print(const raft::handle_t& handle, @@ -42,6 +44,8 @@ void print(const raft::handle_t& handle, /** * @brief Prints the data stored in CPU memory + * @tparam m_t type of matrix elements + * @tparam idx_t integer type used for indexing * @param in: input matrix with column-major layout */ template diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh index 573d906723..635b8ec46d 100644 --- a/cpp/include/raft/matrix/ratio.cuh +++ b/cpp/include/raft/matrix/ratio.cuh @@ -24,17 +24,16 @@ namespace raft::matrix { /** * @brief ratio of every element over sum of input vector is calculated * @tparam math_t data-type upon which the math operation will be performed - * @tparam IdxType Integer type used to for addressing - * @param handle - * @param src: input matrix - * @param dest: output matrix. 
The result is stored in the dest matrix - * @param len: number elements of input matrix - * @param stream cuda stream + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle + * @param[in] src: input matrix + * @param[out] dest: output matrix. The result is stored in the dest matrix */ -template +template void ratio(const raft::handle_t& handle, - raft::device_matrix_view src, - raft::device_matrix_view dest) + raft::device_matrix_view src, + raft::device_matrix_view dest) { RAFT_EXPECTS(src.size() == dest.size(), "Input and output matrices must be the same size."); detail::ratio(handle, src.data_handle(), dest.data_handle(), src.size(), handle.get_stream()); @@ -43,12 +42,13 @@ void ratio(const raft::handle_t& handle, /** * @brief ratio of every element over sum of input vector is calculated * @tparam math_t data-type upon which the math operation will be performed - * @tparam IdxType Integer type used to for addressing + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) * @param[in] handle * @param[inout] inout: input matrix */ -template -void ratio(const raft::handle_t& handle, raft::device_matrix_view inout) +template +void ratio(const raft::handle_t& handle, raft::device_matrix_view inout) { detail::ratio( handle, inout.data_handle(), inout.data_handle(), inout.size(), handle.get_stream()); diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh index 545c4592c5..80f253c828 100644 --- a/cpp/include/raft/matrix/reciprocal.cuh +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -24,6 +24,7 @@ namespace raft::matrix { /** * @brief Reciprocal of every element in the input matrix * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing * @param handle: raft handle * @param in: input matrix and also the result is 
stored * @param out: output matrix. The result is stored in the out matrix @@ -32,10 +33,10 @@ namespace raft::matrix { * @param thres the threshold used to forcibly set inputs to zero * @{ */ -template +template void reciprocal(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, + raft::device_matrix_view in, + raft::device_matrix_view out, math_t scalar, bool setzero = false, math_t thres = 1e-15) @@ -48,15 +49,18 @@ void reciprocal(const raft::handle_t& handle, /** * @brief Reciprocal of every element in the input matrix (in place) * @tparam math_t data-type upon which the math operation will be performed - * @param inout: input matrix with in-place results - * @param scalar: every element is multiplied with scalar - * @param setzero round down to zero if the input is less the threshold - * @param thres the threshold used to forcibly set inputs to zero + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle to manage resources + * @param[inout] inout: input matrix with in-place results + * @param[in] scalar: every element is multiplied with scalar + * @param[in] setzero round down to zero if the input is less the threshold + * @param[in] thres the threshold used to forcibly set inputs to zero * @{ */ -template +template void reciprocal(const raft::handle_t& handle, - raft::device_matrix_view inout, + raft::device_matrix_view inout, math_t scalar, bool setzero = false, math_t thres = 1e-15) diff --git a/cpp/include/raft/matrix/sign_flip.cuh b/cpp/include/raft/matrix/sign_flip.cuh index f99c3111ab..01f8829c85 100644 --- a/cpp/include/raft/matrix/sign_flip.cuh +++ b/cpp/include/raft/matrix/sign_flip.cuh @@ -25,8 +25,10 @@ namespace raft::matrix { /** * @brief sign flip stabilizes the sign of col major eigen vectors. * The sign is flipped if the column has negative |max|. 
- * @param handle: raft handle - * @param inout: input matrix. Result also stored in this parameter + * @tparam math_t floating point type used for matrix elements + * @tparam idx_t integer type used for indexing + * @param[in] handle: raft handle + * @param[inout] inout: input matrix. Result also stored in this parameter */ template void sign_flip(const raft::handle_t& handle, diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh index b3791b0a7c..2c03a8672c 100644 --- a/cpp/include/raft/matrix/sqrt.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -25,14 +25,16 @@ namespace raft::matrix { /** * @brief Square root of every element in the input matrix * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) * @param[in] handle: raft handle * @param[in] in: input matrix and also the result is stored * @param[out] out: output matrix. 
The result is stored in the out matrix */ -template +template void sqrt(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out) + raft::device_matrix_view in, + raft::device_matrix_view out) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); detail::seqRoot(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); @@ -41,11 +43,13 @@ void sqrt(const raft::handle_t& handle, /** * @brief Square root of every element in the input matrix (in place) * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) * @param[in] handle: raft handle * @param[inout] inout: input matrix with in-place results */ -template -void sqrt(const raft::handle_t& handle, raft::device_matrix_view inout) +template +void sqrt(const raft::handle_t& handle, raft::device_matrix_view inout) { detail::seqRoot(inout.data_handle(), inout.size(), handle.get_stream()); } @@ -53,16 +57,18 @@ void sqrt(const raft::handle_t& handle, raft::device_matrix_view inout) /** * @brief Square root of every element in the input matrix * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) * @param[in] handle: raft handle * @param[in] in: input matrix and also the result is stored * @param[out] out: output matrix. 
The result is stored in the out matrix * @param[in] scalar: every element is multiplied with scalar * @param[in] set_neg_zero whether to set negative numbers to zero */ -template +template void weighted_sqrt(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, + raft::device_matrix_view in, + raft::device_matrix_view out, math_t scalar, bool set_neg_zero = false) { @@ -74,14 +80,16 @@ void weighted_sqrt(const raft::handle_t& handle, /** * @brief Square root of every element in the input matrix (in place) * @tparam math_t data-type upon which the math operation will be performed - * @param handle: raft handle - * @param inout: input matrix and also the result is stored - * @param scalar: every element is multiplied with scalar - * @param set_neg_zero whether to set negative numbers to zero + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[inout] inout: input matrix and also the result is stored + * @param[in] scalar: every element is multiplied with scalar + * @param[in] set_neg_zero whether to set negative numbers to zero */ -template +template void weighted_sqrt(const raft::handle_t& handle, - raft::device_matrix_view inout, + raft::device_matrix_view inout, math_t scalar, bool set_neg_zero = false) { diff --git a/cpp/include/raft/matrix/threshold.cuh b/cpp/include/raft/matrix/threshold.cuh index d13c55ea80..7540ceb3c6 100644 --- a/cpp/include/raft/matrix/threshold.cuh +++ b/cpp/include/raft/matrix/threshold.cuh @@ -24,15 +24,17 @@ namespace raft::matrix { /** * @brief sets the small values to zero based on a defined threshold * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) * @param handle: raft handle - * @param in: input matrix - * @param out: output matrix. 
The result is stored in the out matrix - * @param thres threshold to set values to zero + * @param[in] in: input matrix + * @param[out] out: output matrix. The result is stored in the out matrix + * @param[in] thres threshold to set values to zero */ -template +template void zero_small_values(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, + raft::device_matrix_view in, + raft::device_matrix_view out, math_t thres = 1e-15) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size"); @@ -43,13 +45,15 @@ void zero_small_values(const raft::handle_t& handle, /** * @brief sets the small values to zero in-place based on a defined threshold * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) * @param handle: raft handle * @param inout: input matrix and also the result is stored * @param thres: threshold */ -template +template void zero_small_values(const raft::handle_t& handle, - raft::device_matrix_view inout, + raft::device_matrix_view inout, math_t thres = 1e-15) { detail::setSmallValuesZero(inout.data_handle(), inout.size(), handle.get_stream(), thres); diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu index 8fbcb1a38d..87ebf7a290 100644 --- a/cpp/test/matrix/argmax.cu +++ b/cpp/test/matrix/argmax.cu @@ -47,7 +47,7 @@ class ArgMaxTest : public ::testing::TestWithParam> { void test() { auto input = raft::make_device_matrix(handle, params.n_rows, params.n_cols); - auto output = raft::make_device_vector(handle, params.n_rows); + auto output = raft::make_device_vector(handle, params.n_rows); auto expected = raft::make_device_vector(handle, params.n_rows); raft::update_device(input.data_handle(), diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 7d626f5f3c..c2f8d87b56 100644 --- a/cpp/test/matrix/gather.cu +++ 
b/cpp/test/matrix/gather.cu @@ -104,7 +104,7 @@ class GatherTest : public ::testing::TestWithParam { auto map_view = raft::make_device_vector_view(d_map.data(), map_length); - raft::matrix::gather(handle, in_view, out_view, map_view); + raft::matrix::gather(handle, in_view, map_view, out_view); // // launch device version of the kernel // gatherLaunch( diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 22e884e33f..9d3d5af51e 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -58,12 +58,16 @@ struct LinewiseTest : public ::testing::TestWithParam void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, const T* vec) { - auto f = [] __device__(T a, T b) -> T { return a + b; }; + auto f = [] __device__(T a, T b) -> T { return a + b; }; + constexpr auto rowmajor = std::is_same_v; - auto in_view = raft::make_device_matrix_view(in, nLines, lineLen); - auto out_view = raft::make_device_matrix_view(out, nLines, lineLen); + I m = rowmajor ? lineLen : nLines; + I n = rowmajor ? nLines : lineLen; - auto vec_view = raft::make_device_vector_view(vec, lineLen); + auto in_view = raft::make_device_matrix_view(in, m, n); + auto out_view = raft::make_device_matrix_view(out, m, n); + + auto vec_view = raft::make_device_vector_view(vec, m); matrix::linewise_op(handle, in_view, out_view, raft::is_row_major(in_view), f, vec_view); } @@ -71,12 +75,16 @@ struct LinewiseTest : public ::testing::TestWithParam T { return a + b + c; }; + auto f = [] __device__(T a, T b, T c) -> T { return a + b + c; }; + constexpr auto rowmajor = std::is_same_v; + + I m = rowmajor ? lineLen : nLines; + I n = rowmajor ? 
nLines : lineLen; - auto in_view = raft::make_device_matrix_view(in, nLines, lineLen); - auto out_view = raft::make_device_matrix_view(out, nLines, lineLen); - auto vec1_view = raft::make_device_vector_view(vec1, lineLen); - auto vec2_view = raft::make_device_vector_view(vec2, lineLen); + auto in_view = raft::make_device_matrix_view(in, m, n); + auto out_view = raft::make_device_matrix_view(out, m, n); + auto vec1_view = raft::make_device_vector_view(vec1, m); + auto vec2_view = raft::make_device_vector_view(vec2, m); matrix::linewise_op( handle, in_view, out_view, raft::is_row_major(in_view), f, vec1_view, vec2_view); From 5b705d798b5c9416791c6c09b452420d01a69281 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 28 Sep 2022 18:12:24 -0400 Subject: [PATCH 43/58] More docs cleanup --- cpp/include/raft/matrix/col_wise_sort.cuh | 29 +++++++++++++---------- cpp/include/raft/matrix/gather.cuh | 6 ++--- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 2a6ecf61a6..249b8a9406 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -56,17 +56,20 @@ void sort_cols_per_row(const InType* in, /** * @brief sort columns within each row of row-major input matrix and return sorted indexes * modelled as key-value sort with key being input matrix and value being index of values - * @param in: input matrix - * @param out: output value(index) matrix - * @param sorted_keys: Optional, output matrix for sorted keys (input) + * @tparam in_t: element type of input matrix + * @tparam out_t: element type of output matrix + * @tparam matrix_idx_t: integer type for matrix indexing + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output value(index) matrix + * @param[out] sorted_keys: Optional, output matrix for sorted keys (input) */ -template -void sort_cols_per_row( - const raft::handle_t& 
handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - std::optional> sorted_keys = - std::nullopt) +template +void sort_cols_per_row(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + std::optional> + sorted_keys = std::nullopt) { RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0), "Input and output matrices must have the same shape."); @@ -80,9 +83,9 @@ void sort_cols_per_row( size_t workspace_size = 0; bool alloc_workspace = false; - InType* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr; + in_t* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr; - detail::sortColumnsPerRow(in.data_handle(), + detail::sortColumnsPerRow(in.data_handle(), out.data_handle(), in.extent(0), in.extent(1), @@ -95,7 +98,7 @@ void sort_cols_per_row( if (alloc_workspace) { auto workspace = raft::make_device_vector(handle, workspace_size); - detail::sortColumnsPerRow(in.data_handle(), + detail::sortColumnsPerRow(in.data_handle(), out.data_handle(), in.extent(0), in.extent(1), diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 58cacf9c73..cab03fe52c 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -99,9 +99,9 @@ void gather(const raft::handle_t& handle, */ template void gather(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view map, + raft::device_matrix_view in, + raft::device_vector_view map, + raft::device_matrix_view out, map_xform_t transform_op) { RAFT_EXPECTS(out.extent(0) == map.extent(0), From b49c9f4a536fb5d8cb864a77585725752fa0a736 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Wed, 28 Sep 2022 18:31:01 -0400 Subject: [PATCH 44/58] Updates --- cpp/include/raft/matrix/col_wise_sort.cuh | 32 +++++++++++------------ 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 249b8a9406..a1ad3a8b36 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -86,27 +86,27 @@ void sort_cols_per_row(const raft::handle_t& handle, in_t* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr; detail::sortColumnsPerRow(in.data_handle(), - out.data_handle(), - in.extent(0), - in.extent(1), - alloc_workspace, - (void*)nullptr, - workspace_size, - handle.get_stream(), - keys); + out.data_handle(), + in.extent(0), + in.extent(1), + alloc_workspace, + (void*)nullptr, + workspace_size, + handle.get_stream(), + keys); if (alloc_workspace) { auto workspace = raft::make_device_vector(handle, workspace_size); detail::sortColumnsPerRow(in.data_handle(), - out.data_handle(), - in.extent(0), - in.extent(1), - alloc_workspace, - (void*)workspace.data_handle(), - workspace_size, - handle.get_stream(), - keys); + out.data_handle(), + in.extent(0), + in.extent(1), + alloc_workspace, + (void*)workspace.data_handle(), + workspace_size, + handle.get_stream(), + keys); } } From 5fa5bbf97a60882c6e7f825f887277c43b50bd01 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 29 Sep 2022 11:39:10 -0400 Subject: [PATCH 45/58] Implementing review feedback --- cpp/include/raft/cluster/single_linkage.cuh | 36 ++++++++++++--------- cpp/include/raft/core/mdspan.hpp | 25 +++++++------- cpp/test/CMakeLists.txt | 3 +- cpp/test/sparse/linkage.cu | 26 ++++++++------- 4 files changed, 51 insertions(+), 39 deletions(-) diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 7f0553a553..8e33b8389d 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -21,6 +21,8 @@ namespace raft::cluster { +constexpr int DEFAULT_CONST_C = 15; + /** * Single-linkage clustering, capable of constructing a KNN graph to * scale the algorithm beyond the n^2 memory consumption of implementations @@ -68,30 +70,34 @@ void single_linkage(const raft::handle_t& handle, * @tparam dist_type method to use for constructing connectivities graph * @param[in] handle raft handle * @param[in] X dense input matrix in row-major layout - * @param[in] m number of rows in X - * @param[in] n number of columns in X + * @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2) + * @param[out] labels output labels vector (size n_rows) * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[out] out struct containing output dendrogram and cluster assignments + * @param[in] n_clusters number of clusters to assign data samples * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect control of k. 
The algorithm will set `k = log(n) + c` - * @param[in] n_clusters number of clusters to assign data samples */ template void single_linkage(const raft::handle_t& handle, raft::device_matrix_view X, + raft::device_matrix_view dendrogram, + raft::device_vector_view labels, raft::distance::DistanceType metric, - linkage_output& out, - int c, - size_t n_clusters) + size_t n_clusters, + std::optional c = std::make_optional(DEFAULT_CONST_C)) { - detail::single_linkage(handle, - X.data_handle(), - static_cast(X.extent(0)), - static_cast(X.extent(1)), - metric, - &out, - c, - n_clusters); + linkage_output out_arrs; + out_arrs.children = dendrogram.data_handle(); + out_arrs.labels = labels.data_handle(); + + single_linkage(handle, + X.data_handle(), + static_cast(X.extent(0)), + static_cast(X.extent(1)), + metric, + &out_arrs, + c.has_value() ? c.value() : DEFAULT_CONST_C, + n_clusters); } }; // namespace raft::cluster diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 3d95bb54cd..2ec05473f2 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -267,7 +267,7 @@ constexpr bool is_row_major(mdspan template constexpr bool is_row_major(mdspan m) { - return m.is_exhaustive(); + return m.is_exhaustive() && m.stride(1) == typename Extents::index_type(1); } template @@ -291,13 +291,14 @@ constexpr bool is_col_major(mdspan template constexpr bool is_col_major(mdspan m) { - return m.is_exhaustive(); + return m.is_exhaustive() && m.stride(0) == typename Extents::index_type(1); } -template -constexpr bool is_matrix_view(mdspan> /* m */) +template +constexpr bool is_matrix_view( + mdspan, Layout, Accessor> /* m */) { - return true; + return sizeof...(Exts) == 2; } template @@ -306,10 +307,11 @@ constexpr bool is_matrix_view(mdspan m) return false; } -template -constexpr bool is_vector_view(mdspan> /* m */) +template +constexpr bool is_matrix_view( + mdspan, Layout, Accessor> /* m */) { - return true; + return 
sizeof...(Exts) == 1; } template @@ -318,10 +320,11 @@ constexpr bool is_vector_view(mdspan m) return false; } -template -constexpr bool is_scalar_view(mdspan> /* m */) +template +constexpr bool is_matrix_view( + mdspan, Layout, Accessor> /* m */) { - return true; + return sizeof...(Exts) == 0; } template diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index bbc35906b9..fe2504606b 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -82,6 +82,8 @@ if(BUILD_TESTS) PATH test/cluster/kmeans.cu test/cluster_solvers.cu + test/sparse/linkage.cu + OPTIONAL DIST NN ) ConfigureTest(NAME CORE_TEST @@ -213,7 +215,6 @@ if(BUILD_TESTS) test/sparse/connect_components.cu test/sparse/knn.cu test/sparse/knn_graph.cu - test/sparse/linkage.cu OPTIONAL DIST NN ) diff --git a/cpp/test/sparse/linkage.cu b/cpp/test/sparse/linkage.cu index 6fa1d0461e..ce5741d06b 100644 --- a/cpp/test/sparse/linkage.cu +++ b/cpp/test/sparse/linkage.cu @@ -24,6 +24,7 @@ #include #endif +#include #include #include @@ -175,23 +176,24 @@ class LinkageTest : public ::testing::TestWithParam> { raft::copy(data.data(), params.data.data(), data.size(), stream); raft::copy(labels_ref.data(), params.expected_labels.data(), params.n_row, stream); - raft::hierarchy::linkage_output out_arrs; - out_arrs.labels = labels.data(); - rmm::device_uvector out_children(params.n_row * 2, stream); - out_arrs.children = out_children.data(); - raft::handle_t handle; - raft::hierarchy::single_linkage( + + auto data_view = + raft::make_device_matrix_view(data.data(), params.n_row, params.n_col); + auto dendrogram_view = + raft::make_device_matrix_view(out_children.data(), params.n_row, 2); + auto labels_view = raft::make_device_vector_view(labels.data(), params.n_row); + + raft::cluster::single_linkage( handle, - data.data(), - params.n_row, - params.n_col, + data_view, + dendrogram_view, + labels_view, raft::distance::DistanceType::L2SqrtExpanded, - &out_arrs, - params.c, - params.n_clusters); + 
params.n_clusters, + std::make_optional(params.c)); handle.sync_stream(stream); From ce94f632da0a545482193da0b9b00d2538a830bc Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 29 Sep 2022 11:43:27 -0400 Subject: [PATCH 46/58] Renaming --- cpp/include/raft/core/mdspan.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 2ec05473f2..c801e7586c 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -308,7 +308,7 @@ constexpr bool is_matrix_view(mdspan m) } template -constexpr bool is_matrix_view( +constexpr bool is_vector_view( mdspan, Layout, Accessor> /* m */) { return sizeof...(Exts) == 1; @@ -321,7 +321,7 @@ constexpr bool is_vector_view(mdspan m) } template -constexpr bool is_matrix_view( +constexpr bool is_scalar_view( mdspan, Layout, Accessor> /* m */) { return sizeof...(Exts) == 0; From 9939464d435d049c04fdae3b2b2d0e19b640ff0a Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 29 Sep 2022 12:17:55 -0400 Subject: [PATCH 47/58] Updating docs to include [in] and [out] --- cpp/include/raft/spatial/knn/ball_cover.cuh | 24 ++++++++++---------- cpp/include/raft/spatial/knn/brute_force.cuh | 22 ++++++++---------- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index 704acafd45..ccc9f6e0e8 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -36,8 +36,8 @@ namespace knn { * @tparam value_t knn value type * @tparam int_t integral type for knn params * @tparam matrix_idx_t matrix indexing type - * @param handle library resource management handle - * @param index an empty (and not previous built) instance of BallCoverIndex + * @param[in] handle library resource management handle + * @param[inout] index an empty (and not previous built) instance of BallCoverIndex */ template inline void knn_merge_parts( @@ -83,10 +83,6 @@ inline void knn_merge_parts( * row- or column-major but the output matrices will always be in * row-major format. * - * @example - * - * - * * @param[in] handle the cuml handle to use * @param[in] index vector of device matrices (each size m_i*d) to be used as the knn index * @param[in] search matrix (size n*d) to be used for searching the index From 32775a6becc05b8c85b9db078883504dfb88ce4f Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Thu, 29 Sep 2022 16:56:29 -0400 Subject: [PATCH 48/58] Syncing handle after argmax prim --- cpp/test/matrix/argmax.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu index 87ebf7a290..70884af4de 100644 --- a/cpp/test/matrix/argmax.cu +++ b/cpp/test/matrix/argmax.cu @@ -57,13 +57,13 @@ class ArgMaxTest : public ::testing::TestWithParam> { raft::update_device( expected.data_handle(), params.output_matrix.data(), params.n_rows, handle.get_stream()); - printf("Finished copy\n"); - auto input_view = raft::make_device_matrix_view( input.data_handle(), params.n_rows, params.n_cols); raft::matrix::argmax(handle, input_view, output.view()); + handle.sync_stream(); + ASSERT_TRUE(devArrMatch(output.data_handle(), expected.data_handle(), params.n_rows, From 5363a2d998ad0c6bcfdde2d5f76c9b2958e91ce4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 30 Sep 2022 09:58:15 -0400 Subject: [PATCH 49/58] Removing defaults on template args --- cpp/include/raft/spatial/knn/ball_cover.cuh | 12 ++++++------ cpp/include/raft/spatial/knn/brute_force.cuh | 10 +++++----- .../raft/spatial/knn/epsilon_neighborhood.cuh | 2 +- cpp/include/raft/spatial/knn/ivf_flat.cuh | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index ccc9f6e0e8..0420e47cde 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -39,10 +39,10 @@ namespace knn { * @param[in] handle library resource management handle * @param[inout] index an empty (and not previous built) instance of BallCoverIndex */ -template + typename int_t, + typename matrix_idx_t> void rbc_build_index(const raft::handle_t& handle, BallCoverIndex& index) { @@ -87,10 +87,10 @@ void rbc_build_index(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking 
in the closest landmark. */ -template + typename int_t, + typename matrix_idx_t> void rbc_all_knn_query(const raft::handle_t& handle, BallCoverIndex& index, int_t k, diff --git a/cpp/include/raft/spatial/knn/brute_force.cuh b/cpp/include/raft/spatial/knn/brute_force.cuh index fb40ae13d1..c32a33d2e2 100644 --- a/cpp/include/raft/spatial/knn/brute_force.cuh +++ b/cpp/include/raft/spatial/knn/brute_force.cuh @@ -45,7 +45,7 @@ namespace raft::spatial::knn { * @param[in] k number of neighbors for each part * @param[in] translations optional vector of starting index mappings for each partition */ -template +template inline void knn_merge_parts( const raft::handle_t& handle, raft::device_matrix_view in_keys, @@ -95,10 +95,10 @@ inline void knn_merge_parts( * @param[in] translations starting offsets for partitions. should be the same size * as input vector. */ -template void brute_force_knn( diff --git a/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh index dce5f0f99d..53fe76fada 100644 --- a/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/epsilon_neighborhood.cuh @@ -76,7 +76,7 @@ void epsUnexpL2SqNeighborhood(bool* adj, * @param[in] eps defines epsilon neighborhood radius (should be passed as * squared as we compute L2-squared distance in this method) */ -template +template void eps_neighbors_l2sq(const raft::handle_t& handle, raft::device_matrix_view x, raft::device_matrix_view y, diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index 288834214f..bd8f916584 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -60,7 +60,7 @@ namespace raft::spatial::knn::ivf_flat { * * @return the constructed ivf-flat index */ -template +template inline auto build( const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim) -> index @@ -101,9 +101,9 
@@ inline auto build( * @return the constructed ivf-flat index */ template + typename idx_t, + typename int_t, + typename matrix_idx_t> auto build_index(const handle_t& handle, raft::device_matrix_view dataset, const index_params& params) -> index From dec2f813d8a55d22393f399128a2755f6f10575b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 30 Sep 2022 10:36:20 -0400 Subject: [PATCH 50/58] Adding limit-tests to build.sh. Removing default template args per reviews --- build.sh | 25 ++++++++-- cpp/include/raft/spatial/knn/ball_cover.cuh | 10 +--- cpp/include/raft/spatial/knn/ivf_flat.cuh | 51 ++++++++------------- cpp/test/spatial/ann_ivf_flat.cu | 18 ++++---- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/build.sh b/build.sh index e8dfa3e404..0caa823ca7 100755 --- a/build.sh +++ b/build.sh @@ -19,7 +19,7 @@ ARGS=$* REPODIR=$(cd $(dirname $0); pwd) VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean -v -g --install --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --buildfaiss --minimal-deps" -HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] +HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] where is: clean - remove all existing build artifacts and configuration (start over) libraft - build the raft C++ code only. Also builds the C-wrapper library @@ -40,6 +40,7 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=\"] [--cache-tool= +template void rbc_build_index(const raft::handle_t& handle, BallCoverIndex& index) { @@ -87,10 +84,7 @@ void rbc_build_index(const raft::handle_t& handle, * many datasets can still have great recall even by only * looking in the closest landmark. 
*/ -template +template void rbc_all_knn_query(const raft::handle_t& handle, BallCoverIndex& index, int_t k, diff --git a/cpp/include/raft/spatial/knn/ivf_flat.cuh b/cpp/include/raft/spatial/knn/ivf_flat.cuh index bd8f916584..58ca96d392 100644 --- a/cpp/include/raft/spatial/knn/ivf_flat.cuh +++ b/cpp/include/raft/spatial/knn/ivf_flat.cuh @@ -100,19 +100,16 @@ inline auto build( * * @return the constructed ivf-flat index */ -template +template auto build_index(const handle_t& handle, - raft::device_matrix_view dataset, + raft::device_matrix_view dataset, const index_params& params) -> index { return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset.data_handle(), static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); + static_cast(dataset.extent(1))); } /** @@ -191,15 +188,12 @@ inline auto extend(const handle_t& handle, * * @return the constructed extended ivf-flat index */ -template +template auto extend(const handle_t& handle, const index& orig_index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = - std::nullopt) -> index + raft::device_matrix_view new_vectors, + std::optional> new_indices = std::nullopt) + -> index { return raft::spatial::knn::ivf_flat::detail::extend( handle, @@ -248,15 +242,11 @@ inline void extend(const handle_t& handle, * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` * here to imply a continuous range `[0...n_rows)`. */ -template -void extend( - const handle_t& handle, - index* index, - raft::device_matrix_view new_vectors, - std::optional> new_indices = std::nullopt) +template +void extend(const handle_t& handle, + index* index, + raft::device_matrix_view new_vectors, + std::optional> new_indices = std::nullopt) { *index = extend(handle, *index, @@ -363,15 +353,12 @@ inline void search(const handle_t& handle, * @param[in] params configure the search * @param[in] k the number of neighbors to find for each query. 
*/ -template +template void search(const handle_t& handle, const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, const search_params& params, int_t k) { @@ -379,9 +366,9 @@ void search(const handle_t& handle, queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), "Number of rows in output neighbors and distances matrices must equal the number of queries."); - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1) && - neighbors.extent(1) == static_cast(k), - "Number of columns in output neighbors and distances matrices must equal k"); + RAFT_EXPECTS( + neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast(k), + "Number of columns in output neighbors and distances matrices must equal k"); RAFT_EXPECTS(queries.extent(1) == index.dim(), "Number of query dimensions should equal number of dimensions in the index."); diff --git a/cpp/test/spatial/ann_ivf_flat.cu b/cpp/test/spatial/ann_ivf_flat.cu index 4db7b85394..47ca92fc97 100644 --- a/cpp/test/spatial/ann_ivf_flat.cu +++ b/cpp/test/spatial/ann_ivf_flat.cu @@ -208,10 +208,10 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { index_params.add_data_on_build = false; index_params.kmeans_trainset_fraction = 0.5; - auto database_view = raft::make_device_matrix_view( + auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - auto index = ivf_flat::build_index(handle_, database_view, index_params); + auto index = ivf_flat::build_index(handle_, database_view, index_params); rmm::device_uvector vector_indices(ps.num_db_vecs, stream_); thrust::sequence(handle_.get_thrust_policy(), @@ -221,16 +221,16 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { int64_t half_of_data = ps.num_db_vecs / 2; - auto 
half_of_data_view = raft::make_device_matrix_view( + auto half_of_data_view = raft::make_device_matrix_view( (const DataT*)database.data(), half_of_data, ps.dim); - auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); + auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view); - ivf_flat::extend(handle_, - &index_2, - database.data() + half_of_data * ps.dim, - vector_indices.data() + half_of_data, - int64_t(ps.num_db_vecs) - half_of_data); + ivf_flat::extend(handle_, + &index_2, + database.data() + half_of_data * ps.dim, + vector_indices.data() + half_of_data, + int64_t(ps.num_db_vecs) - half_of_data); ivf_flat::search(handle_, search_params, From b70b44ba7ec7a668b4405600cb2ba8bf1f79cd89 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 30 Sep 2022 10:54:02 -0400 Subject: [PATCH 51/58] Updating docs --- cpp/include/raft/matrix/copy.cuh | 5 +---- cpp/include/raft/matrix/detail/print.hpp | 7 ++++--- cpp/include/raft/matrix/gather.cuh | 9 +++++---- cpp/include/raft/matrix/linewise_op.cuh | 2 +- cpp/include/raft/matrix/print.cuh | 14 ++++++++++---- 5 files changed, 21 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 9d2b8ed5bf..035370d5c8 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -73,12 +73,9 @@ void copy(const raft::handle_t& handle, /** * @brief copy matrix operation for column major matrices. First n_rows and * n_cols of input matrix "in" is copied to "out" matrix. 
+ * @param handle: raft handle for managing resources * @param in: input matrix - * @param in_n_rows: number of rows of input matrix * @param out: output matrix - * @param out_n_rows: number of rows of output matrix - * @param out_n_cols: number of columns of output matrix - * @param stream: cuda stream */ template void trunc_zero_origin(const raft::handle_t& handle, diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp index 0545d049ad..c8510ee1b6 100644 --- a/cpp/include/raft/matrix/detail/print.hpp +++ b/cpp/include/raft/matrix/detail/print.hpp @@ -35,13 +35,14 @@ namespace raft::matrix::detail { template -void printHost(const m_t* in, idx_t n_rows, idx_t n_cols) +void printHost(const m_t* in, idx_t n_rows, idx_t n_cols, char h_separator = ' ', + char v_separator = '\n', +) { for (idx_t i = 0; i < n_rows; i++) { for (idx_t j = 0; j < n_cols; j++) { - printf("%1.4f ", in[j * n_rows + i]); + printf("%1.4f%c", in[j * n_rows + i], j < n_cols - 1 ? h_separator : v_separator); } - printf("\n"); } } diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index cab03fe52c..571121f18b 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -57,9 +57,10 @@ void gather(const MatrixIteratorT in, * @tparam matrix_t Matrix element type * @tparam map_t Map vector type * @tparam idx_t integer type used for indexing + * @param[in] handle raft handle for managing resources * @param[in] in Input matrix (assumed to be row-major) * @param[in] map Vector of gather locations - * @param[out] out Output matrix (assumed to be row-major) + * @param[out] out Output matrix (assumed to be row-major) */ template void gather(const raft::handle_t& handle, @@ -91,7 +92,7 @@ void gather(const raft::handle_t& handle, * @tparam map_xform_t Unary lambda expression or operator type, MapTransformOp's result * type must be convertible to idx_t (= int) type. 
* @tparam idx_t integer type for indexing - * + * @param[in] handle raft handle for managing resources * @param[in] in Input matrix (assumed to be row-major) * @param[in] map Input vector of gather locations * @param[out] out Output matrix (assumed to be row-major) @@ -205,7 +206,7 @@ void gather_if(const MatrixIteratorT in, * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result * type must be convertible to bool type. * @tparam idx_t integer type for indexing - * + * @param[in] handle raft handle for managing resources * @param[in] in Input matrix (assumed to be row-major) * @param[in] map Input vector of gather locations * @param[in] stencil Input vector of stencil or predicate values @@ -300,7 +301,7 @@ void gather_if(const MatrixIteratorT in, * @tparam map_xform_t Unary lambda expression or operator type, map_xform_t's result * type must be convertible to idx_t (= int) type. * @tparam idx_t integer type for indexing - * + * @param[in] handle raft handle for managing resources * @param[in] in Input matrix (assumed to be row-major) * @param[in] map Vector of gather locations * @param[in] stencil Vector of stencil or predicate values diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index 5996e40f86..053e02cb4a 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -41,6 +41,7 @@ namespace raft::matrix { * @tparam Lambda type of lambda function used for the operation * @tparam vec_t variadic types of device_vector_view vectors (size m if alongRows, size n * otherwise) + * @param[in] handle raft handle for managing resources * @param [out] out result of the operation; can be same as `in`; should be aligned the same * as `in` to allow faster vectorized memory transfers. * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. @@ -50,7 +51,6 @@ namespace raft::matrix { * out[i, j] = op(in[i, j], vec1[i], vec2[i], ... 
veck[i]) if alongLines = true * out[i, j] = op(in[i, j], vec1[j], vec2[j], ... veck[j]) if alongLines = false * where matrix indexing is row-major ([i, j] = [i + lineLen * j]). - * @param [in] stream a cuda stream for the kernels * @param [in] vecs zero or more vectors to be passed as arguments, * size of each vector is `alongLines ? lineLen : nLines`. */ diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh index 060cd8642c..def9fc9182 100644 --- a/cpp/include/raft/matrix/print.cuh +++ b/cpp/include/raft/matrix/print.cuh @@ -43,14 +43,20 @@ void print(const raft::handle_t& handle, } /** - * @brief Prints the data stored in CPU memory + * @brief Prints the host data stored in CPU memory * @tparam m_t type of matrix elements * @tparam idx_t integer type used for indexing - * @param in: input matrix with column-major layout + * @param[in] handle raft handle for managing resources + * @param[in] in input matrix with column-major layout + * @param[in] h_separator: horizontal separator character + * @param[in] v_separator: vertical separator character */ template -void print(raft::host_matrix_view in) +void print(const raft::handle_t& handle, + raft::host_matrix_view in, + char h_separator = ' ', + char v_separator = '\n') { - detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); + detail::printHost(in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator); } } // namespace raft::matrix From e25caefa962d291d46d92e5887dc0fb3c95027e7 Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Fri, 30 Sep 2022 12:08:07 -0400 Subject: [PATCH 52/58] More style cleanup --- cpp/include/raft/matrix/detail/print.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp index c8510ee1b6..fc3d14861c 100644 --- a/cpp/include/raft/matrix/detail/print.hpp +++ b/cpp/include/raft/matrix/detail/print.hpp @@ -35,9 +35,8 @@ namespace raft::matrix::detail { template -void printHost(const m_t* in, idx_t n_rows, idx_t n_cols, char h_separator = ' ', - char v_separator = '\n', -) +void printHost( + const m_t* in, idx_t n_rows, idx_t n_cols, char h_separator = ' ', char v_separator = '\n', ) { for (idx_t i = 0; i < n_rows; i++) { for (idx_t j = 0; j < n_cols; j++) { From 37ae236256a33a2b1609d3c701fcec7a11d9b609 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 30 Sep 2022 15:18:33 -0400 Subject: [PATCH 53/58] Pulling out argmax for now since the test seems to be failing in centos. --- cpp/include/raft/matrix/argmax.cuh | 42 ----------- cpp/include/raft/matrix/matrix_types.hpp | 26 +++++++ cpp/include/raft/matrix/print.cuh | 33 +++----- cpp/include/raft/matrix/print.hpp | 9 ++- cpp/test/CMakeLists.txt | 1 - cpp/test/matrix/argmax.cu | 95 ------------------------ 6 files changed, 41 insertions(+), 165 deletions(-) delete mode 100644 cpp/include/raft/matrix/argmax.cuh create mode 100644 cpp/include/raft/matrix/matrix_types.hpp delete mode 100644 cpp/test/matrix/argmax.cu diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh deleted file mode 100644 index b7423b9ea4..0000000000 --- a/cpp/include/raft/matrix/argmax.cuh +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace raft::matrix { - -/** - * @brief Argmax: find the row idx with maximum value for each column - * @tparam math_t matrix element type - * @tparam idx_t integer type for matrix and vector indexing - * @param[in] handle: raft handle - * @param[in] in: input matrix of size (n_rows, n_cols) - * @param[out] out: output vector of size n_cols - */ -template -void argmax(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_vector_view out) -{ - RAFT_EXPECTS(out.extent(0) == in.extent(0), - "Size of output vector must equal number of rows in input matrix."); - detail::argmax( - in.data_handle(), in.extent(0), in.extent(1), out.data_handle(), handle.get_stream()); -} -} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/matrix_types.hpp b/cpp/include/raft/matrix/matrix_types.hpp new file mode 100644 index 0000000000..1f22154627 --- /dev/null +++ b/cpp/include/raft/matrix/matrix_types.hpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::matrix { + +struct print_separators { + char horizontal = ' '; + char vertical = '\n'; +}; + +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh index def9fc9182..4d3a8ca938 100644 --- a/cpp/include/raft/matrix/print.cuh +++ b/cpp/include/raft/matrix/print.cuh @@ -20,6 +20,7 @@ #include #include #include +#include namespace raft::matrix { @@ -29,34 +30,18 @@ namespace raft::matrix { * @tparam idx_t integer type used for indexing * @param[in] handle: raft handle * @param[in] in: input matrix - * @param[in] h_separator: horizontal separator character - * @param[in] v_separator: vertical separator character + * @param[in] separators: horizontal and vertical separator characters */ template void print(const raft::handle_t& handle, raft::device_matrix_view in, - char h_separator = ' ', - char v_separator = '\n') + print_separators& separators) { - detail::print( - in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator, handle.get_stream()); -} - -/** - * @brief Prints the host data stored in CPU memory - * @tparam m_t type of matrix elements - * @tparam idx_t integer type used for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in input matrix with column-major layout - * @param[in] h_separator: horizontal separator character - * @param[in] v_separator: vertical separator character - */ -template -void print(const raft::handle_t& handle, - raft::host_matrix_view in, - char h_separator = ' ', - char v_separator = '\n') -{ - detail::printHost(in.data_handle(), in.extent(0), in.extent(1), h_separator, v_separator); + detail::print(in.data_handle(), + in.extent(0), + in.extent(1), + separators.horizontal, + separators.vertical, + handle.get_stream()); } } // namespace raft::matrix diff --git 
a/cpp/include/raft/matrix/print.hpp b/cpp/include/raft/matrix/print.hpp index 66e939be0f..86c314ed44 100644 --- a/cpp/include/raft/matrix/print.hpp +++ b/cpp/include/raft/matrix/print.hpp @@ -18,16 +18,19 @@ #include #include +#include namespace raft::matrix { /** * @brief Prints the data stored in CPU memory - * @param in: input matrix with column-major layout + * @param[in] in: input matrix with column-major layout + * @param[in] separators: horizontal and vertical separator characters */ template -void print(raft::host_matrix_view in) +void print(raft::host_matrix_view in, print_separators& separators) { - detail::printHost(in.data_handle(), in.extent(0), in.extent(1)); + detail::printHost( + in.data_handle(), in.extent(0), in.extent(1), separators.horizontal, separators.vertical); } } // namespace raft::matrix diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index fe2504606b..a18a750e4b 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -157,7 +157,6 @@ if(BUILD_TESTS) ConfigureTest(NAME MATRIX_TEST PATH - test/matrix/argmax.cu test/matrix/gather.cu test/matrix/math.cu test/matrix/matrix.cu diff --git a/cpp/test/matrix/argmax.cu b/cpp/test/matrix/argmax.cu deleted file mode 100644 index 70884af4de..0000000000 --- a/cpp/test/matrix/argmax.cu +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "../test_utils.h" -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace matrix { - -template -struct ArgMaxInputs { - std::vector input_matrix; - std::vector output_matrix; - std::size_t n_cols; - std::size_t n_rows; -}; - -template -::std::ostream& operator<<(::std::ostream& os, const ArgMaxInputs& dims) -{ - return os; -} - -template -class ArgMaxTest : public ::testing::TestWithParam> { - public: - ArgMaxTest() : params(::testing::TestWithParam>::GetParam()) {} - - void test() - { - auto input = raft::make_device_matrix(handle, params.n_rows, params.n_cols); - auto output = raft::make_device_vector(handle, params.n_rows); - auto expected = raft::make_device_vector(handle, params.n_rows); - - raft::update_device(input.data_handle(), - params.input_matrix.data(), - params.n_rows * params.n_cols, - handle.get_stream()); - raft::update_device( - expected.data_handle(), params.output_matrix.data(), params.n_rows, handle.get_stream()); - - auto input_view = raft::make_device_matrix_view( - input.data_handle(), params.n_rows, params.n_cols); - - raft::matrix::argmax(handle, input_view, output.view()); - - handle.sync_stream(); - - ASSERT_TRUE(devArrMatch(output.data_handle(), - expected.data_handle(), - params.n_rows, - Compare(), - handle.get_stream())); - } - - protected: - raft::handle_t handle; - ArgMaxInputs params; -}; - -const std::vector> inputsf = { - {{0.1f, 0.2f, 0.3f, 0.4f, 0.4f, 0.3f, 0.2f, 0.1f, 0.2f, 0.3f, 0.5f, 0.0f}, {3, 0, 2}, 3, 4}}; - -const std::vector> inputsd = { - {{0.1, 0.2, 0.3, 0.4, 0.4, 0.3, 0.2, 0.1, 0.2, 0.3, 0.5, 0.0}, {3, 0, 2}, 3, 4}}; - -typedef ArgMaxTest ArgMaxTestF; -TEST_P(ArgMaxTestF, Result) { test(); } - -typedef ArgMaxTest ArgMaxTestD; -TEST_P(ArgMaxTestD, Result) { test(); } - -INSTANTIATE_TEST_CASE_P(ArgMaxTest, ArgMaxTestF, ::testing::ValuesIn(inputsf)); -INSTANTIATE_TEST_CASE_P(ArgMaxTest, ArgMaxTestD, ::testing::ValuesIn(inputsd)); - -} // namespace matrix -} // 
namespace raft From 9ce7d4f97b5469c17a6b8617e93f520ea68ef3ed Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 30 Sep 2022 18:23:07 -0400 Subject: [PATCH 54/58] Lots of updates from review feedback. --- BUILD.md | 26 +++++--- cpp/include/raft/matrix/col_wise_sort.cuh | 47 ++++++++++++-- cpp/include/raft/matrix/gather.cuh | 14 ++-- cpp/include/raft/matrix/power.cuh | 4 +- cpp/include/raft/matrix/ratio.cuh | 2 +- cpp/include/raft/matrix/sqrt.cuh | 4 +- cpp/include/raft/spatial/knn/ball_cover.cuh | 6 +- .../raft/spatial/knn/ball_cover_types.hpp | 65 +++++++++++++------ cpp/include/raft/spatial/knn/brute_force.cuh | 6 +- .../raft/spatial/knn/detail/ball_cover.cuh | 13 ++-- .../knn/detail/ball_cover/registers.cuh | 6 +- .../knn/specializations/ball_cover.cuh | 2 +- .../detail/ball_cover_lowdim.hpp | 8 +-- cpp/src/nn/specializations/ball_cover.cu | 2 +- .../detail/ball_cover_lowdim_pass_one_2d.cu | 2 +- .../detail/ball_cover_lowdim_pass_one_3d.cu | 4 +- .../detail/ball_cover_lowdim_pass_two_2d.cu | 2 +- .../detail/ball_cover_lowdim_pass_two_3d.cu | 2 +- cpp/test/matrix/gather.cu | 2 +- cpp/test/spatial/ann_ivf_flat.cu | 4 +- 20 files changed, 146 insertions(+), 75 deletions(-) diff --git a/BUILD.md b/BUILD.md index 0c7fdd7e82..f572b11848 100644 --- a/BUILD.md +++ b/BUILD.md @@ -5,6 +5,7 @@ - [Build Dependencies](#required_depenencies) - [Header-only C++](#install_header_only_cpp) - [C++ Shared Libraries](#shared_cpp_libs) + - [Improving Rebuild Times](#ccache) - [Googletests](#gtests) - [C++ Using Cmake](#cpp_using_cmake) - [Python](#python) @@ -29,7 +30,6 @@ In addition to the libraries included with cudatoolkit 11.0+, there are some oth - [RMM](https://github.com/rapidsai/rmm) corresponding to RAFT version. #### Optional -- [mdspan](https://github.com/rapidsai/mdspan) - On by default but can be disabled. - [Thrust](https://github.com/NVIDIA/thrust) v1.15 / [CUB](https://github.com/NVIDIA/cub) - On by default but can be disabled. 
- [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API. - [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0 @@ -53,11 +53,6 @@ The following example will download the needed dependencies and install the RAFT ./build.sh libraft --install ``` -The `--minimal-deps` flag can be used to install the headers with minimal dependencies: -```bash -./build.sh libraft --install --minimal-deps -``` - ### C++ Shared Libraries (optional) For larger projects which make heavy use of the pairwise distances or nearest neighbors APIs, shared libraries can be built to speed up compile times. These shared libraries can also significantly improve re-compile times both while developing RAFT and developing against the APIs. Build all of the available shared libraries by passing `--compile-libs` flag to `build.sh`: @@ -72,6 +67,14 @@ Individual shared libraries have their own flags and multiple can be used (thoug Add the `--install` flag to the above example to also install the shared libraries into `$INSTALL_PREFIX/lib`. +### `ccache` and `sccache` + +`ccache` and `sccache` can be used to better cache parts of the build when rebuilding frequently, such as when working on a new feature. You can also use `ccache` or `sccache` with `build.sh`: + +```bash +./build.sh libraft --cache-tool=ccache +``` + ### Tests Compile the tests using the `tests` target in `build.sh`. @@ -86,10 +89,17 @@ Test compile times can be improved significantly by using the optional shared li ./build.sh libraft tests --compile-libs ``` -To run C++ tests: +The tests are broken apart by algorithm category, so you will find several binaries in `cpp/build/` named `*_TEST`. + +For example, to run the distance tests: +```bash +./cpp/build/DISTANCE_TEST +``` + +It can take some time to compile all of the tests.
You can build individual tests by providing a semicolon-separated list to the `--limit-tests` option in `build.sh`: ```bash -./cpp/build/test_raft +./build.sh libraft tests --limit-tests="SPATIAL_TEST;DISTANCE_TEST;MATRIX_TEST" ``` ### Benchmarks diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index a1ad3a8b36..d26f5f73cf 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -22,8 +22,7 @@ #include #include -namespace raft { -namespace matrix { +namespace raft::matrix { /** * @brief sort columns within each row of row-major input matrix and return sorted indexes @@ -110,7 +109,47 @@ void sort_cols_per_row(const raft::handle_t& handle, } } -}; // end namespace matrix -}; // end namespace raft +namespace sort_cols_per_row_impl { +template +struct sorted_keys_alias { +}; + +template <> +struct sorted_keys_alias { + using type = double; +}; + +template +struct sorted_keys_alias< + std::optional>> { + using type = typename raft::device_matrix_view::value_type; +}; + +template +using sorted_keys_t = typename sorted_keys_alias::type; +} // namespace sort_cols_per_row_impl + +/** + * @brief Overload of `sort_cols_per_row` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `sort_cols_per_row`.
+ */ +template +void sort_cols_per_row(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + sorted_keys_vector_type sorted_keys) +{ + using sorted_keys_type = sort_cols_per_row_impl::sorted_keys_t< + std::remove_const_t>>; + std::optional> sorted_keys_opt = + std::forward(sorted_keys); + + sort_cols_per_row(handle, in, out, sorted_keys_opt); +} + +}; // end namespace raft::matrix #endif \ No newline at end of file diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 571121f18b..fa6e73de49 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -65,7 +65,7 @@ void gather(const MatrixIteratorT in, template void gather(const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_vector_view map, + raft::device_vector_view map, raft::device_matrix_view out) { RAFT_EXPECTS(out.extent(0) == map.extent(0), @@ -101,8 +101,8 @@ void gather(const raft::handle_t& handle, template void gather(const raft::handle_t& handle, raft::device_matrix_view in, - raft::device_vector_view map, - raft::device_matrix_view out, + raft::device_vector_view map, + raft::device_matrix_view out, map_xform_t transform_op) { RAFT_EXPECTS(out.extent(0) == map.extent(0), @@ -221,8 +221,8 @@ template in, raft::device_matrix_view out, - raft::device_vector_view map, - raft::device_vector_view stencil, + raft::device_vector_view map, + raft::device_vector_view stencil, unary_pred_t pred_op) { RAFT_EXPECTS(out.extent(0) == map.extent(0), @@ -318,8 +318,8 @@ template in, raft::device_matrix_view out, - raft::device_vector_view map, - raft::device_vector_view stencil, + raft::device_vector_view map, + raft::device_vector_view stencil, unary_pred_t pred_op, map_xform_t transform_op) { diff --git a/cpp/include/raft/matrix/power.cuh b/cpp/include/raft/matrix/power.cuh index 320ca4fe0f..4e2b3b7d72 100644 --- a/cpp/include/raft/matrix/power.cuh +++ b/cpp/include/raft/matrix/power.cuh 
@@ -33,7 +33,7 @@ namespace raft::matrix { */ template void weighted_power(const raft::handle_t& handle, - raft::device_matrix_view in, + raft::device_matrix_view in, raft::device_matrix_view out, math_t scalar) { @@ -84,7 +84,7 @@ void power(const raft::handle_t& handle, raft::device_matrix_view void power(const raft::handle_t& handle, - raft::device_matrix_view in, + raft::device_matrix_view in, raft::device_matrix_view out) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size."); diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh index 635b8ec46d..7895ea972f 100644 --- a/cpp/include/raft/matrix/ratio.cuh +++ b/cpp/include/raft/matrix/ratio.cuh @@ -32,7 +32,7 @@ namespace raft::matrix { */ template void ratio(const raft::handle_t& handle, - raft::device_matrix_view src, + raft::device_matrix_view src, raft::device_matrix_view dest) { RAFT_EXPECTS(src.size() == dest.size(), "Input and output matrices must be the same size."); diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh index 2c03a8672c..b371253690 100644 --- a/cpp/include/raft/matrix/sqrt.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -33,7 +33,7 @@ namespace raft::matrix { */ template void sqrt(const raft::handle_t& handle, - raft::device_matrix_view in, + raft::device_matrix_view in, raft::device_matrix_view out) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); @@ -67,7 +67,7 @@ void sqrt(const raft::handle_t& handle, raft::device_matrix_view void weighted_sqrt(const raft::handle_t& handle, - raft::device_matrix_view in, + raft::device_matrix_view in, raft::device_matrix_view out, math_t scalar, bool set_neg_zero = false) diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index 714b019fba..838c950115 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -203,7 +203,7 
@@ void rbc_all_knn_query(const raft::handle_t& handle, */ template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, int_t k, const value_t* query, int_t n_query_pts, @@ -272,7 +272,7 @@ void rbc_knn_query(const raft::handle_t& handle, */ template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, raft::device_matrix_view query, raft::device_matrix_view inds, raft::device_matrix_view dists, @@ -289,7 +289,7 @@ void rbc_knn_query(const raft::handle_t& handle, "Number of rows in output indices and distances matrices must equal number of rows " "in search matrix."); - RAFT_EXPECTS(query.extent(1) == index.get_R().extent(1), + RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1), "Number of columns in query and index matrices must match."); rbc_knn_query(handle, diff --git a/cpp/include/raft/spatial/knn/ball_cover_types.hpp b/cpp/include/raft/spatial/knn/ball_cover_types.hpp index 1dd45365b7..897bb4df5b 100644 --- a/cpp/include/raft/spatial/knn/ball_cover_types.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover_types.hpp @@ -58,13 +58,12 @@ class BallCoverIndex { * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) */ n_landmarks(sqrt(m_)), - R_indptr(std::move(raft::make_device_vector(handle, sqrt(m_) + 1))), - R_1nn_cols(std::move(raft::make_device_vector(handle, m_))), - R_1nn_dists(std::move(raft::make_device_vector(handle, m_))), - R_closest_landmark_dists( - std::move(raft::make_device_vector(handle, m_))), - R(std::move(raft::make_device_matrix(handle, sqrt(m_), n_))), - R_radius(std::move(raft::make_device_vector(handle, sqrt(m_)))), + R_indptr(raft::make_device_vector(handle, sqrt(m_) + 1)), + R_1nn_cols(raft::make_device_vector(handle, m_)), + R_1nn_dists(raft::make_device_vector(handle, m_)), + R_closest_landmark_dists(raft::make_device_vector(handle, m_)), + R(raft::make_device_matrix(handle, sqrt(m_), n_)), + 
R_radius(raft::make_device_vector(handle, sqrt(m_))), index_trained(false) { } @@ -83,20 +82,41 @@ class BallCoverIndex { * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) */ n_landmarks(sqrt(X_.extent(0))), - R_indptr( - std::move(raft::make_device_vector(handle, sqrt(X_.extent(0)) + 1))), - R_1nn_cols(std::move(raft::make_device_vector(handle, X_.extent(0)))), - R_1nn_dists(std::move(raft::make_device_vector(handle, X_.extent(0)))), - R_closest_landmark_dists( - std::move(raft::make_device_vector(handle, X_.extent(0)))), - R(std::move( - raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1)))), - R_radius( - std::move(raft::make_device_vector(handle, sqrt(X_.extent(0))))), + R_indptr(raft::make_device_vector(handle, sqrt(X_.extent(0)) + 1)), + R_1nn_cols(raft::make_device_vector(handle, X_.extent(0))), + R_1nn_dists(raft::make_device_vector(handle, X_.extent(0))), + R_closest_landmark_dists(raft::make_device_vector(handle, X_.extent(0))), + R(raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1))), + R_radius(raft::make_device_vector(handle, sqrt(X_.extent(0)))), index_trained(false) { } + auto get_R_indptr() const -> raft::device_vector_view + { + return R_indptr.view(); + } + auto get_R_1nn_cols() const -> raft::device_vector_view + { + return R_1nn_cols.view(); + } + auto get_R_1nn_dists() const -> raft::device_vector_view + { + return R_1nn_dists.view(); + } + auto get_R_radius() const -> raft::device_vector_view + { + return R_radius.view(); + } + auto get_R() const -> raft::device_matrix_view + { + return R.view(); + } + auto get_R_closest_landmark_dists() const -> raft::device_vector_view + { + return R_closest_landmark_dists.view(); + } + raft::device_vector_view get_R_indptr() { return R_indptr.view(); } raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } raft::device_vector_view get_R_1nn_dists() { return R_1nn_dists.view(); } @@ -106,8 +126,11 @@ class BallCoverIndex { { return 
R_closest_landmark_dists.view(); } - raft::device_matrix_view get_X() { return X; } + raft::device_matrix_view get_X() const { return X; } + + raft::distance::DistanceType get_metric() const { return metric; } + value_int get_n_landmarks() const { return n_landmarks; } bool is_index_trained() const { return index_trained; }; // This should only be set by internal functions @@ -115,9 +138,9 @@ class BallCoverIndex { const raft::handle_t& handle; - const value_int m; - const value_int n; - const value_int n_landmarks; + value_int m; + value_int n; + value_int n_landmarks; raft::device_matrix_view X; diff --git a/cpp/include/raft/spatial/knn/brute_force.cuh b/cpp/include/raft/spatial/knn/brute_force.cuh index c32a33d2e2..dda1e02eed 100644 --- a/cpp/include/raft/spatial/knn/brute_force.cuh +++ b/cpp/include/raft/spatial/knn/brute_force.cuh @@ -42,7 +42,6 @@ namespace raft::spatial::knn { * @param[out] out_keys matrix of output keys (size n_samples * k) * @param[out] out_values matrix of output values (size n_samples * k) * @param[in] n_samples number of rows in each part - * @param[in] k number of neighbors for each part * @param[in] translations optional vector of starting index mappings for each partition */ template @@ -53,7 +52,6 @@ inline void knn_merge_parts( raft::device_matrix_view out_keys, raft::device_matrix_view out_values, size_t n_samples, - int k, std::optional> translations = std::nullopt) { RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), @@ -61,7 +59,7 @@ inline void knn_merge_parts( RAFT_EXPECTS( out_keys.extent(0) == out_values.extent(0) == n_samples, "Number of rows in output keys and val matrices must equal number of rows in search matrix."); - RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == k, + RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1), "Number of columns in output indices and distances matrices must be equal to k"); auto n_parts = in_keys.extent(0) / 
n_samples; @@ -71,7 +69,7 @@ inline void knn_merge_parts( out_values.data_handle(), n_samples, n_parts, - k, + in_keys.extent(1), handle.get_stream(), translations.value_or(nullptr)); } diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index e65a895f60..94897daa22 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -174,15 +174,16 @@ void construct_landmark_1nn(const raft::handle_t& handle, */ template void k_closest_landmarks(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query_pts, value_int n_query_pts, value_int k, value_idx* R_knn_inds, value_t* R_knn_dists) { - std::vector input = {index.get_R().data_handle()}; - std::vector sizes = {index.n_landmarks}; + // TODO: Add const to the brute-force knn inputs + std::vector input = {const_cast(index.get_R().data_handle())}; + std::vector sizes = {index.n_landmarks}; brute_force_knn_impl(handle, input, @@ -196,7 +197,7 @@ void k_closest_landmarks(const raft::handle_t& handle, true, true, nullptr, - index.metric); + index.get_metric()); } /** @@ -240,7 +241,7 @@ template void perform_rbc_query(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, value_int n_query_pts, std::uint32_t k, @@ -470,7 +471,7 @@ template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, value_int k, const value_t* query, value_int n_query_pts, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index c0056e7137..112ab9f13c 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -331,7 +331,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, value_idx* 
out_inds, value_t* out_dists, value_int* dist_counter, - value_t* R_radius, + const value_t* R_radius, distance_func dfunc, float weight = 1.0) { @@ -472,7 +472,7 @@ template void rbc_low_dim_pass_one(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, const value_int n_query_rows, value_int k, @@ -604,7 +604,7 @@ template void rbc_low_dim_pass_two(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, const value_int n_query_rows, value_int k, diff --git a/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh b/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh index c859f2c5ec..a861375b2f 100644 --- a/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh @@ -34,7 +34,7 @@ extern template void rbc_build_index( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, std::uint32_t k, const float* query, std::uint32_t n_query_pts, diff --git a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp index afee3bd7a3..31df566b3f 100644 --- a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp +++ b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp @@ -25,7 +25,7 @@ namespace detail { extern template void rbc_low_dim_pass_one( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, @@ -39,7 +39,7 @@ extern template void rbc_low_dim_pass_one extern template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, @@ -53,7 +53,7 @@ extern template void rbc_low_dim_pass_two extern template 
void rbc_low_dim_pass_one( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, @@ -67,7 +67,7 @@ extern template void rbc_low_dim_pass_one extern template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/src/nn/specializations/ball_cover.cu b/cpp/src/nn/specializations/ball_cover.cu index 7473b65d25..15af9f6e68 100644 --- a/cpp/src/nn/specializations/ball_cover.cu +++ b/cpp/src/nn/specializations/ball_cover.cu @@ -37,7 +37,7 @@ template void rbc_build_index template void rbc_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, std::uint32_t k, const float* query, std::uint32_t n_query_pts, diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu index 8950ff8d5c..d2d729a52d 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu @@ -25,7 +25,7 @@ namespace detail { template void rbc_low_dim_pass_one( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu index 7b8b6ce9a2..0b32d43ba9 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu @@ -25,7 +25,7 @@ namespace detail { template void rbc_low_dim_pass_one( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t 
n_query_rows, std::uint32_t k, @@ -39,7 +39,7 @@ template void rbc_low_dim_pass_one( template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu index 29e8eec8c8..7c8f18859f 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu @@ -25,7 +25,7 @@ namespace detail { template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu index d6d4b356c8..1ef071033c 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu @@ -25,7 +25,7 @@ namespace detail { template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index c2f8d87b56..4b3244913b 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -102,7 +102,7 @@ class GatherTest : public ::testing::TestWithParam { auto out_view = raft::make_device_matrix_view(d_out_act.data(), map_length, ncols); auto map_view = - raft::make_device_vector_view(d_map.data(), map_length); + raft::make_device_vector_view(d_map.data(), map_length); raft::matrix::gather(handle, in_view, map_view, out_view); diff --git a/cpp/test/spatial/ann_ivf_flat.cu b/cpp/test/spatial/ann_ivf_flat.cu 
index 47ca92fc97..99a5a42824 100644 --- a/cpp/test/spatial/ann_ivf_flat.cu +++ b/cpp/test/spatial/ann_ivf_flat.cu @@ -219,7 +219,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs)); handle_.sync_stream(stream_); - int64_t half_of_data = ps.num_db_vecs / 2; + IdxT half_of_data = ps.num_db_vecs / 2; auto half_of_data_view = raft::make_device_matrix_view( (const DataT*)database.data(), half_of_data, ps.dim); @@ -230,7 +230,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam { &index_2, database.data() + half_of_data * ps.dim, vector_indices.data() + half_of_data, - int64_t(ps.num_db_vecs) - half_of_data); + IdxT(ps.num_db_vecs) - half_of_data); ivf_flat::search(handle_, search_params, From 3db8a27d0e6f6019a7dcd046d38268db6442db15 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 30 Sep 2022 19:53:50 -0400 Subject: [PATCH 55/58] Moving re-defined validation logic out of mdspan.hpp --- cpp/include/raft/core/mdspan.hpp | 112 --------------------- cpp/include/raft/matrix/copy.cuh | 1 + cpp/include/raft/matrix/linewise_op.cuh | 6 -- cpp/include/raft/util/input_validation.hpp | 87 ++++++++++++++++ 4 files changed, 88 insertions(+), 118 deletions(-) diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 7933322e64..a858633e07 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -255,116 +255,4 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, return unravel_index_impl(static_cast(idx), shape); } } - -template -constexpr bool is_row_or_column_major(mdspan /* m */) -{ - return false; -} - -template -constexpr bool is_row_or_column_major(mdspan /* m */) -{ - return true; -} - -template -constexpr bool is_row_or_column_major(mdspan /* m */) -{ - return true; -} - -template -constexpr bool is_row_or_column_major(mdspan m) -{ - return m.is_exhaustive(); -} - -template -constexpr bool is_row_major(mdspan /* m */) -{ - 
return false; -} - -template -constexpr bool is_row_major(mdspan /* m */) -{ - return false; -} - -template -constexpr bool is_row_major(mdspan /* m */) -{ - return true; -} - -template -constexpr bool is_row_major(mdspan m) -{ - return m.is_exhaustive() && m.stride(1) == typename Extents::index_type(1); -} - -template -constexpr bool is_col_major(mdspan /* m */) -{ - return false; -} - -template -constexpr bool is_col_major(mdspan /* m */) -{ - return true; -} - -template -constexpr bool is_col_major(mdspan /* m */) -{ - return false; -} - -template -constexpr bool is_col_major(mdspan m) -{ - return m.is_exhaustive() && m.stride(0) == typename Extents::index_type(1); -} - -template -constexpr bool is_matrix_view( - mdspan, Layout, Accessor> /* m */) -{ - return sizeof...(Exts) == 2; -} - -template -constexpr bool is_matrix_view(mdspan m) -{ - return false; -} - -template -constexpr bool is_vector_view( - mdspan, Layout, Accessor> /* m */) -{ - return sizeof...(Exts) == 1; -} - -template -constexpr bool is_vector_view(mdspan m) -{ - return false; -} - -template -constexpr bool is_scalar_view( - mdspan, Layout, Accessor> /* m */) -{ - return sizeof...(Exts) == 0; -} - -template -constexpr bool is_scalar_view(mdspan m) -{ - return false; -} - } // namespace raft diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 035370d5c8..5f1d16485c 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -18,6 +18,7 @@ #include #include +#include namespace raft::matrix { diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index 053e02cb4a..6b383b14f5 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -22,12 +22,6 @@ namespace raft::matrix { -// template -// args *extract_ptr(raft::device_vector_view vec, raft::device_vector_view... 
vecs) { -// vecs.data_handle(); -//} - /** * Run a function over matrix lines (rows or columns) with a variable number * row-vectors or column-vectors. diff --git a/cpp/include/raft/util/input_validation.hpp b/cpp/include/raft/util/input_validation.hpp index b34843f5e8..ab5264f900 100644 --- a/cpp/include/raft/util/input_validation.hpp +++ b/cpp/include/raft/util/input_validation.hpp @@ -42,4 +42,91 @@ constexpr bool is_row_or_column_major(mdspan +constexpr bool is_row_major(mdspan /* m */) +{ + return false; +} + +template +constexpr bool is_row_major(mdspan /* m */) +{ + return false; +} + +template +constexpr bool is_row_major(mdspan /* m */) +{ + return true; +} + +template +constexpr bool is_row_major(mdspan m) +{ + return m.is_exhaustive() && m.stride(1) == typename Extents::index_type(1); +} + +template +constexpr bool is_col_major(mdspan /* m */) +{ + return false; +} + +template +constexpr bool is_col_major(mdspan /* m */) +{ + return true; +} + +template +constexpr bool is_col_major(mdspan /* m */) +{ + return false; +} + +template +constexpr bool is_col_major(mdspan m) +{ + return m.is_exhaustive() && m.stride(0) == typename Extents::index_type(1); +} + +template +constexpr bool is_matrix_view( + mdspan, Layout, Accessor> /* m */) +{ + return sizeof...(Exts) == 2; +} + +template +constexpr bool is_matrix_view(mdspan m) +{ + return false; +} + +template +constexpr bool is_vector_view( + mdspan, Layout, Accessor> /* m */) +{ + return sizeof...(Exts) == 1; +} + +template +constexpr bool is_vector_view(mdspan m) +{ + return false; +} + +template +constexpr bool is_scalar_view( + mdspan, Layout, Accessor> /* m */) +{ + return sizeof...(Exts) == 0; +} + +template +constexpr bool is_scalar_view(mdspan m) +{ + return false; +} + }; // end namespace raft \ No newline at end of file From 4c06e91e8d637026ac435cffe7388e29ea79902c Mon Sep 17 00:00:00 2001 From: "Corey J. 
Nolet" Date: Sat, 1 Oct 2022 20:27:07 -0400 Subject: [PATCH 56/58] Review feedback --- cpp/include/raft/matrix/init.cuh | 5 +++-- cpp/include/raft/matrix/reciprocal.cuh | 21 +++++++++++++++------ cpp/include/raft/matrix/sqrt.cuh | 15 ++++++++++----- cpp/test/matrix/math.cu | 3 ++- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh index 37ea1dce1a..d6f34197a9 100644 --- a/cpp/include/raft/matrix/init.cuh +++ b/cpp/include/raft/matrix/init.cuh @@ -35,9 +35,10 @@ template void fill(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, - math_t scalar) + raft::host_scalar_view scalar) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); - detail::setValue(out.data_handle(), in.data_handle(), scalar, in.size(), handle.get_stream()); + detail::setValue( + out.data_handle(), in.data_handle(), *(scalar.data_handle()), in.size(), handle.get_stream()); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh index 80f253c828..e5d432bce6 100644 --- a/cpp/include/raft/matrix/reciprocal.cuh +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -37,13 +37,18 @@ template void reciprocal(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, - math_t scalar, + raft::host_scalar_view scalar, bool setzero = false, math_t thres = 1e-15) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have the same size."); - detail::reciprocal( - in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), setzero, thres); + detail::reciprocal(in.data_handle(), + out.data_handle(), + *(scalar.data_handle()), + in.size(), + handle.get_stream(), + setzero, + thres); } /** @@ -61,11 +66,15 @@ void reciprocal(const raft::handle_t& handle, template void reciprocal(const raft::handle_t& handle, raft::device_matrix_view 
inout, - math_t scalar, + raft::host_scalar_view scalar, bool setzero = false, math_t thres = 1e-15) { - detail::reciprocal( - inout.data_handle(), scalar, inout.size(), handle.get_stream(), setzero, thres); + detail::reciprocal(inout.data_handle(), + *(scalar.data_handle()), + inout.size(), + handle.get_stream(), + setzero, + thres); } } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh index b371253690..10c76c8cab 100644 --- a/cpp/include/raft/matrix/sqrt.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -69,12 +69,16 @@ template void weighted_sqrt(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, - math_t scalar, + raft::host_scalar_view scalar, bool set_neg_zero = false) { RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); - detail::seqRoot( - in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream(), set_neg_zero); + detail::seqRoot(in.data_handle(), + out.data_handle(), + *(scalar.data_handle()), + in.size(), + handle.get_stream(), + set_neg_zero); } /** @@ -90,10 +94,11 @@ void weighted_sqrt(const raft::handle_t& handle, template void weighted_sqrt(const raft::handle_t& handle, raft::device_matrix_view inout, - math_t scalar, + raft::host_scalar_view scalar, bool set_neg_zero = false) { - detail::seqRoot(inout.data_handle(), scalar, inout.size(), handle.get_stream(), set_neg_zero); + detail::seqRoot( + inout.data_handle(), *(scalar.data_handle()), inout.size(), handle.get_stream(), set_neg_zero); } } // namespace raft::matrix diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index adfa45b84a..8797709593 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -185,7 +185,8 @@ class MathTest : public ::testing::TestWithParam> { auto out_recip_view = raft::make_device_matrix_view(out_recip.data(), 4, 1); // this `reciprocal()` has to go first bc next one modifies its input - 
reciprocal(handle, in_recip_view, out_recip_view, recip_scalar); + reciprocal( + handle, in_recip_view, out_recip_view, raft::make_host_scalar_view(&recip_scalar)); auto inout_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); From 946bfcb87de2d9472b4de1281743ff9759d8c8ab Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 1 Oct 2022 20:29:44 -0400 Subject: [PATCH 57/58] More review feedback --- cpp/include/raft/matrix/init.cuh | 1 + cpp/include/raft/matrix/reciprocal.cuh | 1 + cpp/include/raft/matrix/sqrt.cuh | 1 + cpp/test/matrix/math.cu | 2 +- 4 files changed, 4 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh index d6f34197a9..0d17983e4b 100644 --- a/cpp/include/raft/matrix/init.cuh +++ b/cpp/include/raft/matrix/init.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh index e5d432bce6..0f8a486a82 100644 --- a/cpp/include/raft/matrix/reciprocal.cuh +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh index 10c76c8cab..5aeefe3002 100644 --- a/cpp/include/raft/matrix/sqrt.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index 8797709593..ad4a37825c 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -190,7 +190,7 @@ class MathTest : public ::testing::TestWithParam> { auto inout_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); - reciprocal(handle, inout_recip_view, recip_scalar, true); + reciprocal(handle, inout_recip_view, raft::make_host_scalar_view(&recip_scalar), true); std::vector in_small_val_zero_h = {0.1, 1e-16, -1e-16, -0.1}; std::vector in_small_val_zero_ref_h = {0.1, 0.0, 
0.0, -0.1}; From e0e5eae92870ea97603c5e3a38c8ef3aff2f3a63 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sat, 1 Oct 2022 20:42:32 -0400 Subject: [PATCH 58/58] Fixing style --- cpp/include/raft/matrix/init.cuh | 2 +- cpp/include/raft/matrix/reciprocal.cuh | 2 +- cpp/include/raft/matrix/sqrt.cuh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh index 0d17983e4b..e3a6c09fe6 100644 --- a/cpp/include/raft/matrix/init.cuh +++ b/cpp/include/raft/matrix/init.cuh @@ -16,8 +16,8 @@ #pragma once -#include #include +#include #include #include diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh index 0f8a486a82..c41ecfb999 100644 --- a/cpp/include/raft/matrix/reciprocal.cuh +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -16,8 +16,8 @@ #pragma once -#include #include +#include #include namespace raft::matrix { diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh index 5aeefe3002..302167480e 100644 --- a/cpp/include/raft/matrix/sqrt.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -16,8 +16,8 @@ #pragma once -#include #include +#include #include #include