diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 79cbb6198f..5902d1405f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -16,6 +16,9 @@ #pragma once +#include +#include + #include #include #include @@ -94,6 +97,22 @@ void search_main(raft::device_resources const& res, _num_executed_iterations, topk); } + + static_assert(std::is_same_v, + "only float distances are supported at the moment"); + float* dist_out = distances.data_handle(); + const DistanceT* dist_in = distances.data_handle(); + // We're converting the data from T to DistanceT during distance computation + // and divide the values by kDivisor. Here we restore the original scale. + constexpr float kScale = spatial::knn::detail::utils::config::kDivisor / + spatial::knn::detail::utils::config::kDivisor; + ivf_pq::detail::postprocess_distances(dist_out, + dist_in, + index.metric(), + distances.extent(0), + distances.extent(1), + kScale, + res.get_stream()); } /** @} */ // end group cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 29c841c0b5..52e5c62169 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "device_common.hpp" #include "hashmap.hpp" #include "utils.hpp" @@ -102,7 +104,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( const uint32_t kv = k + v; // if (kv >= dataset_dim) break; DISTANCE_T diff = query_buffer[device::swizzling(kv)]; - diff -= static_cast(dl_buff[e].data[v]) * device::fragment_scale(); + diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); norm2 += diff * diff; } } @@ -229,7 +231,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in const unsigned kv = k + v; diff = query_buffer[device::swizzling(kv)]; } - diff -= static_cast(dl_buff[e].data[v]) * device::fragment_scale(); + diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); norm2 += diff * diff; } } diff --git a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp index 20f30d9f11..f9c81f3d25 100644 --- a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp @@ -27,30 +27,6 @@ namespace device { // warpSize for compile time calculation constexpr unsigned warp_size = 32; -// scaling factor for distance computation -template -_RAFT_HOST_DEVICE constexpr float fragment_scale(); -template <> -_RAFT_HOST_DEVICE constexpr float fragment_scale() -{ - return 1.0; -}; -template <> -_RAFT_HOST_DEVICE constexpr float fragment_scale() -{ - return 1.0; -}; -template <> -_RAFT_HOST_DEVICE constexpr float fragment_scale() -{ - return 1.0 / 256.0; -}; -template <> -_RAFT_HOST_DEVICE constexpr float fragment_scale() -{ - return 1.0 / 128.0; -}; - /** Xorshift rondem number generator. * * See https://en.wikipedia.org/wiki/Xorshift#xorshift for reference. @@ -73,4 +49,4 @@ _RAFT_DEVICE inline T swizzling(T x) } } // namespace device -} // namespace raft::neighbors::experimental::cagra::detail \ No newline at end of file +} // namespace raft::neighbors::experimental::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index 6148441bd0..99553632ac 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -14,6 +14,9 @@ * limitations under the License. */ #pragma once + +#include + #include #include #include @@ -204,7 +207,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { unsigned j = device::swizzling(i); if (i < dataset_dim) { - query_buffer[j] = static_cast(query_ptr[i]) * device::fragment_scale(); + query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); } else { query_buffer[j] = 0.0; } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 629bed2aee..e3e9c8a655 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -14,6 +14,9 @@ * limitations under the License. */ #pragma once + +#include + #include #include #include @@ -124,10 +127,12 @@ __global__ void random_pickup_kernel( random_data_frag, dataset_ptr + (dataset_dim * seed_index), dataset_dim); // Compute the norm of two data - const auto norm2 = - device::norm2(query_frag, random_data_frag, device::fragment_scale() - /*, scale*/ - ); + const auto norm2 = device::norm2( + query_frag, + random_data_frag, + static_cast(1.0 / spatial::knn::detail::utils::config::kDivisor) + /*, scale*/ + ); if (norm2 < best_norm2_team_local) { best_norm2_team_local = norm2; @@ -335,8 +340,10 @@ __global__ void compute_distance_to_child_nodes_kernel( device::fragment frag_query; device::load_vector_sync(frag_query, query_ptr + blockIdx.y * data_dim, data_dim); - const auto norm2 = - device::norm2(frag_target, frag_query, device::fragment_scale()); + const auto norm2 = device::norm2( + frag_target, + frag_query, + static_cast(1.0 / spatial::knn::detail::utils::config::kDivisor)); if (threadIdx.x % TEAM_SIZE == 0) { result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index fc87b952b0..531b30ba85 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -14,6 +14,9 @@ * limitations under the License. */ #pragma once + +#include + #include #include #include @@ -592,7 +595,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) { unsigned j = device::swizzling(i); if (i < dataset_dim) { - query_buffer[j] = static_cast(query_ptr[i]) * device::fragment_scale(); + query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); } else { query_buffer[j] = 0.0; } diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index dd291251b4..850b741dfd 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -29,6 +29,8 @@ #include #include +#include + namespace raft::spatial::knn::detail::utils { /** Whether pointers are accessible on the device or on the host. */ @@ -136,12 +138,22 @@ struct with_mapped_memory_t { template struct config {}; +template <> +struct config { + using value_t = double; + static constexpr double kDivisor = 1.0; +}; template <> struct config { using value_t = float; static constexpr double kDivisor = 1.0; }; template <> +struct config { + using value_t = half; + static constexpr double kDivisor = 1.0; +}; +template <> struct config { using value_t = uint32_t; static constexpr double kDivisor = 256.0; @@ -169,13 +181,13 @@ struct mapping { * @{ */ template - HDI auto operator()(const S& x) const -> std::enable_if_t, T> + HDI constexpr auto operator()(const S& x) const -> std::enable_if_t, T> { return x; }; template - HDI auto operator()(const S& x) const -> std::enable_if_t, T> + HDI constexpr auto operator()(const S& x) const -> std::enable_if_t, T> { constexpr double kMult = config::kDivisor / config::kDivisor; if constexpr (std::is_floating_point_v) { return static_cast(x * static_cast(kMult)); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 7f45a6dd22..88ad7772c2 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -314,6 +314,8 @@ if(BUILD_TESTS) NEIGHBORS_TEST PATH test/neighbors/ann_cagra/test_float_uint32_t.cu + test/neighbors/ann_cagra/test_int8_t_uint32_t.cu + test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 8b8aa21fc9..f9df1f724f 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -82,6 +82,10 @@ class AnnCagraTest : public ::testing::TestWithParam { protected: void testCagra() { + if (ps.dim * sizeof(DataT) % 8 != 0) { + GTEST_SKIP() + << "CAGRA requires the input data rows to be aligned at least to 8 bytes for now."; + } size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); std::vector indices_naive(queries_size); @@ -310,4 +314,4 @@ inline std::vector generate_inputs() const std::vector inputs = generate_inputs(); -} // namespace raft::neighbors::experimental::cagra \ No newline at end of file +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu new file mode 100644 index 0000000000..f148ebc186 --- /dev/null +++ b/cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_cagra.cuh" + +namespace raft::neighbors::experimental::cagra { + +typedef AnnCagraTest AnnCagraTestI8; +TEST_P(AnnCagraTestI8, AnnCagra) { this->testCagra(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu new file mode 100644 index 0000000000..087d7cec71 --- /dev/null +++ b/cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../ann_cagra.cuh" + +namespace raft::neighbors::experimental::cagra { + +typedef AnnCagraTest AnnCagraTestU8; +TEST_P(AnnCagraTestU8, AnnCagra) { this->testCagra(); } + +INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8, ::testing::ValuesIn(inputs)); + +} // namespace raft::neighbors::experimental::cagra