Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Un-scale output distances #1499

Merged
merged 6 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

#pragma once

#include <raft/neighbors/detail/ivf_pq_search.cuh>
#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/cagra_types.hpp>
Expand Down Expand Up @@ -94,6 +97,22 @@ void search_main(raft::device_resources const& res,
_num_executed_iterations,
topk);
}

static_assert(std::is_same_v<DistanceT, float>,
"only float distances are supported at the moment");
float* dist_out = distances.data_handle();
const DistanceT* dist_in = distances.data_handle();
// The data is converted from T to DistanceT during distance computation, and the
// values are divided by kDivisor in the process. Here we restore the original scale.
constexpr float kScale = spatial::knn::detail::utils::config<T>::kDivisor /
spatial::knn::detail::utils::config<DistanceT>::kDivisor;
ivf_pq::detail::postprocess_distances(dist_out,
dist_in,
index.metric(),
distances.extent(0),
distances.extent(1),
kScale,
res.get_stream());
}
/** @} */ // end group cagra

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
#pragma once

#include <raft/spatial/knn/detail/ann_utils.cuh>

#include "device_common.hpp"
#include "hashmap.hpp"
#include "utils.hpp"
Expand Down Expand Up @@ -102,7 +104,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
const uint32_t kv = k + v;
// if (kv >= dataset_dim) break;
DISTANCE_T diff = query_buffer[device::swizzling(kv)];
diff -= static_cast<float>(dl_buff[e].data[v]) * device::fragment_scale<DATA_T>();
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
}
}
Expand Down Expand Up @@ -229,7 +231,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
const unsigned kv = k + v;
diff = query_buffer[device::swizzling(kv)];
}
diff -= static_cast<float>(dl_buff[e].data[v]) * device::fragment_scale<DATA_T>();
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
}
}
Expand Down
26 changes: 1 addition & 25 deletions cpp/include/raft/neighbors/detail/cagra/device_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,6 @@ namespace device {
// warpSize for compile time calculation
constexpr unsigned warp_size = 32;

// Per-element-type scaling factor used in distance computation.
// Only the explicit specializations below are defined; the primary template
// is declared without a body.
template <class T>
_RAFT_HOST_DEVICE constexpr float fragment_scale();

// float elements: scale factor of 1.
template <>
_RAFT_HOST_DEVICE constexpr float fragment_scale<float>()
{
  return 1.0f;
}

// half elements: scale factor of 1.
template <>
_RAFT_HOST_DEVICE constexpr float fragment_scale<half>()
{
  return 1.0f;
}

// uint8_t elements: scale factor of 1/256.
template <>
_RAFT_HOST_DEVICE constexpr float fragment_scale<uint8_t>()
{
  return 1.0f / 256.0f;
}

// int8_t elements: scale factor of 1/128.
template <>
_RAFT_HOST_DEVICE constexpr float fragment_scale<int8_t>()
{
  return 1.0f / 128.0f;
}

/** Xorshift random number generator.
*
* See https://en.wikipedia.org/wiki/Xorshift#xorshift for reference.
Expand All @@ -73,4 +49,4 @@ _RAFT_DEVICE inline T swizzling(T x)
}

} // namespace device
} // namespace raft::neighbors::experimental::cagra::detail
} // namespace raft::neighbors::experimental::cagra::detail
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
* limitations under the License.
*/
#pragma once

#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <algorithm>
#include <cassert>
#include <iostream>
Expand Down Expand Up @@ -204,7 +207,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) {
unsigned j = device::swizzling(i);
if (i < dataset_dim) {
query_buffer[j] = static_cast<float>(query_ptr[i]) * device::fragment_scale<DATA_T>();
query_buffer[j] = spatial::knn::detail::utils::mapping<float>{}(query_ptr[i]);
} else {
query_buffer[j] = 0.0;
}
Expand Down
19 changes: 13 additions & 6 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
* limitations under the License.
*/
#pragma once

#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <algorithm>
#include <cassert>
#include <iostream>
Expand Down Expand Up @@ -124,10 +127,12 @@ __global__ void random_pickup_kernel(
random_data_frag, dataset_ptr + (dataset_dim * seed_index), dataset_dim);

// Compute the norm of two data
const auto norm2 =
device::norm2<DISTANCE_T>(query_frag, random_data_frag, device::fragment_scale<DATA_T>()
/*, scale*/
);
const auto norm2 = device::norm2<DISTANCE_T>(
query_frag,
random_data_frag,
static_cast<float>(1.0 / spatial::knn::detail::utils::config<DATA_T>::kDivisor)
/*, scale*/
);

if (norm2 < best_norm2_team_local) {
best_norm2_team_local = norm2;
Expand Down Expand Up @@ -335,8 +340,10 @@ __global__ void compute_distance_to_child_nodes_kernel(
device::fragment<MAX_DATASET_DIM, DATA_T, TEAM_SIZE> frag_query;
device::load_vector_sync(frag_query, query_ptr + blockIdx.y * data_dim, data_dim);

const auto norm2 =
device::norm2<DISTANCE_T>(frag_target, frag_query, device::fragment_scale<DATA_T>());
const auto norm2 = device::norm2<DISTANCE_T>(
frag_target,
frag_query,
static_cast<float>(1.0 / spatial::knn::detail::utils::config<DATA_T>::kDivisor));

if (threadIdx.x % TEAM_SIZE == 0) {
result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
* limitations under the License.
*/
#pragma once

#include <raft/spatial/knn/detail/ann_utils.cuh>

#include <algorithm>
#include <cassert>
#include <iostream>
Expand Down Expand Up @@ -592,7 +595,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__
for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += BLOCK_SIZE) {
unsigned j = device::swizzling(i);
if (i < dataset_dim) {
query_buffer[j] = static_cast<float>(query_ptr[i]) * device::fragment_scale<DATA_T>();
query_buffer[j] = spatial::knn::detail::utils::mapping<float>{}(query_ptr[i]);
} else {
query_buffer[j] = 0.0;
}
Expand Down
16 changes: 14 additions & 2 deletions cpp/include/raft/spatial/knn/detail/ann_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <memory>
#include <optional>

#include <cuda_fp16.hpp>

namespace raft::spatial::knn::detail::utils {

/** Whether pointers are accessible on the device or on the host. */
Expand Down Expand Up @@ -136,12 +138,22 @@ struct with_mapped_memory_t {
/**
 * Per-type configuration record.
 *
 * The primary template is empty; each specialization supplies a `value_t`
 * alias and a `kDivisor` constant. `mapping<T>` derives its multiplier for a
 * source type `S` as `config<T>::kDivisor / config<S>::kDivisor`.
 */
template <typename T>
struct config {};

/** Specialization for `double`: kDivisor is 1.0. */
template <>
struct config<double> {
using value_t = double;
static constexpr double kDivisor = 1.0;
};
/** Specialization for `float`: kDivisor is 1.0. */
template <>
struct config<float> {
using value_t = float;
static constexpr double kDivisor = 1.0;
};
/** Specialization for `half`: kDivisor is 1.0. */
template <>
struct config<half> {
using value_t = half;
static constexpr double kDivisor = 1.0;
};
template <>
struct config<uint8_t> {
using value_t = uint32_t;
static constexpr double kDivisor = 256.0;
Expand Down Expand Up @@ -169,13 +181,13 @@ struct mapping {
* @{
*/
template <typename S>
HDI auto operator()(const S& x) const -> std::enable_if_t<std::is_same_v<S, T>, T>
HDI constexpr auto operator()(const S& x) const -> std::enable_if_t<std::is_same_v<S, T>, T>
{
return x;
};

template <typename S>
HDI auto operator()(const S& x) const -> std::enable_if_t<!std::is_same_v<S, T>, T>
HDI constexpr auto operator()(const S& x) const -> std::enable_if_t<!std::is_same_v<S, T>, T>
{
constexpr double kMult = config<T>::kDivisor / config<S>::kDivisor;
if constexpr (std::is_floating_point_v<S>) { return static_cast<T>(x * static_cast<S>(kMult)); }
Expand Down
2 changes: 2 additions & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ if(BUILD_TESTS)
NEIGHBORS_TEST
PATH
test/neighbors/ann_cagra/test_float_uint32_t.cu
test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
test/neighbors/ann_ivf_flat/test_float_int64_t.cu
test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu
test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu
Expand Down
6 changes: 5 additions & 1 deletion cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
protected:
void testCagra()
{
if (ps.dim * sizeof(DataT) % 8 != 0) {
GTEST_SKIP()
<< "CAGRA requires the input data rows to be aligned at least to 8 bytes for now.";
}
size_t queries_size = ps.n_queries * ps.k;
std::vector<IdxT> indices_Cagra(queries_size);
std::vector<IdxT> indices_naive(queries_size);
Expand Down Expand Up @@ -310,4 +314,4 @@ inline std::vector<AnnCagraInputs> generate_inputs()

const std::vector<AnnCagraInputs> inputs = generate_inputs();

} // namespace raft::neighbors::experimental::cagra
} // namespace raft::neighbors::experimental::cagra
28 changes: 28 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#include "../ann_cagra.cuh"

namespace raft::neighbors::experimental::cagra {

// AnnCagraTest fixture instantiated with <float, std::int8_t, std::uint32_t>
// (presumably distance, data and index types - confirm against the fixture).
using AnnCagraTestI8 = AnnCagraTest<float, std::int8_t, std::uint32_t>;

// Each parameterized case simply runs the fixture's testCagra() routine.
TEST_P(AnnCagraTestI8, AnnCagra) { this->testCagra(); }

// Parameter sets come from the `inputs` vector.
INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestI8, ::testing::ValuesIn(inputs));

}  // namespace raft::neighbors::experimental::cagra
28 changes: 28 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>

#include "../ann_cagra.cuh"

namespace raft::neighbors::experimental::cagra {

// AnnCagraTest fixture instantiated with <float, std::uint8_t, std::uint32_t>
// (presumably distance, data and index types - confirm against the fixture).
using AnnCagraTestU8 = AnnCagraTest<float, std::uint8_t, std::uint32_t>;

// Each parameterized case simply runs the fixture's testCagra() routine.
TEST_P(AnnCagraTestU8, AnnCagra) { this->testCagra(); }

// Parameter sets come from the `inputs` vector.
INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestU8, ::testing::ValuesIn(inputs));

}  // namespace raft::neighbors::experimental::cagra