[FEA] Add bitset for ANN pre-filtering and deletion #1803

Merged
merged 21 commits into from
Sep 26, 2023
Changes from 7 commits
202 changes: 202 additions & 0 deletions cpp/include/raft/util/bitset.cuh
@@ -0,0 +1,202 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/map.cuh>
#include <thrust/for_each.h>

namespace raft::utils {
/**
* @defgroup bitset Bitset
* @{
*/
/**
* @brief View of a RAFT Bitset.
*
* This lightweight structure stores a pointer to a bitset in device memory with its length.
* It provides a test() device function to check if a given index is set in the bitset.
*
* @tparam IdxT Indexing type used. Default is uint32_t.
*/
template <typename IdxT = uint32_t>
struct bitset_view {
using BitsetT = uint32_t;
IdxT bitset_element_size = sizeof(BitsetT) * 8;

_RAFT_HOST_DEVICE bitset_view(BitsetT* bitset_ptr, IdxT bitset_len)
: bitset_ptr_{bitset_ptr}, bitset_len_{bitset_len}
{
}
/**
* @brief Create a bitset view from a device vector view of the bitset.
*
* @param bitset_span Device vector view of the bitset
*/
_RAFT_HOST_DEVICE bitset_view(raft::device_vector_view<BitsetT, IdxT> bitset_span)
Review comment (Contributor):

It seems like most of the logic is host-device compatible. Any reason why we restrict ourselves to an on-device bitset? I ask in particular because I can imagine several common workflows where we construct the bitset on host and then apply it on device by copying the underlying memory. As a matter of fact, this is exactly what we do in FIL.

Separately, it seems like a span, rather than an mdspan, is sufficient for what we're using it for. If we're going to use an mdspan anyway, is there a specific reason to restrict ourselves to 1D? If not, is there some other reason not to just accept a span?

: bitset_ptr_{bitset_span.data_handle()}, bitset_len_{bitset_span.extent(0)}
{
}
/**
* @brief Device function to test if a given index is set in the bitset.
*
* @param sample_index Single index to test
* @return bool True if index has not been unset in the bitset
*/
inline _RAFT_DEVICE auto test(const IdxT sample_index) const -> bool
{
const BitsetT bit_element = bitset_ptr_[sample_index / bitset_element_size];
const IdxT bit_index = sample_index % bitset_element_size;
const bool is_bit_set = (bit_element & (BitsetT{1} << bit_index)) != 0;
return is_bit_set;
}
/**
* @brief Get the device pointer to the bitset.
*/
inline _RAFT_HOST_DEVICE auto get_bitset_ptr() -> BitsetT* { return bitset_ptr_; }
inline _RAFT_HOST_DEVICE auto get_bitset_ptr() const -> const BitsetT* { return bitset_ptr_; }
/**
* @brief Get the length of the bitset representation.
*/
inline _RAFT_HOST_DEVICE auto get_bitset_len() const -> IdxT { return bitset_len_; }

private:
BitsetT* bitset_ptr_;
IdxT bitset_len_;
};

/**
* @brief RAFT Bitset.
*
* This structure encapsulates a bitset in device memory. It provides a view() method to get a
* device-usable lightweight view of the bitset.
* Each index is represented by a single bit in the bitset. Storage is allocated in whole 32-bit
* words, so the total number of bytes used is 4 * ceil(bitset_len / 32).
* The underlying type of the bitset array is uint32_t.
* @tparam IdxT Indexing type used. Default is uint32_t.
*/
template <typename IdxT = uint32_t>
struct bitset {
using BitsetT = uint32_t;
IdxT bitset_element_size = sizeof(BitsetT) * 8;

/**
* @brief Construct a new bitset object with a list of indices to unset.
*
* @param res RAFT resources
* @param mask_index List of indices to unset in the bitset
* @param bitset_len Length of the bitset
*/
bitset(const raft::resources& res,
raft::device_vector_view<const IdxT, IdxT> mask_index,
IdxT bitset_len)
: bitset_{raft::make_device_vector<BitsetT, IdxT>(
res, raft::ceildiv(bitset_len, bitset_element_size))}
{
RAFT_EXPECTS(mask_index.extent(0) <= bitset_len, "Mask index cannot be larger than bitset len");
cudaMemsetAsync(bitset_.data_handle(),
Review comment (Contributor):

Would it be more efficient to construct the data on host and then just transfer it all at once, rather than setting and unsetting the data like this? It would come at the cost of host memory footprint, but I suspect that it would be worth it for large bitsets.

0xff,
bitset_.size() * sizeof(BitsetT),
resource::get_cuda_stream(res));
bitset_unset(res, view(), mask_index);
}

/**
* @brief Construct a new bitset object
*
* @param res RAFT resources
* @param bitset_len Length of the bitset
*/
bitset(const raft::resources& res, IdxT bitset_len)
: bitset_{raft::make_device_vector<BitsetT, IdxT>(
res, raft::ceildiv(bitset_len, bitset_element_size))}
{
cudaMemsetAsync(bitset_.data_handle(),
0xff,
bitset_.size() * sizeof(BitsetT),
resource::get_cuda_stream(res));
}
// Disable copy constructor
bitset(const bitset&) = delete;
bitset(bitset&&) = default;
bitset& operator=(const bitset&) = delete;
bitset& operator=(bitset&&) = default;

/**
* @brief Create a device-usable view of the bitset.
*
* @return bitset_view<IdxT>
*/
inline auto view() -> raft::utils::bitset_view<IdxT> { return bitset_view<IdxT>(bitset_.view()); }
[[nodiscard]] inline auto view() const -> raft::utils::bitset_view<IdxT>
{
return bitset_view<IdxT>(bitset_.view());
}

private:
raft::device_vector<BitsetT, IdxT> bitset_;
};

/**
* @brief Function to unset a list of indices in a bitset.
*
* @tparam IdxT Indexing type used.
* @param res RAFT resources
* @param bitset_view_ View of the bitset
* @param mask_index indices to remove from the bitset
*/
template <typename IdxT>
void bitset_unset(const raft::resources& res,
raft::utils::bitset_view<IdxT> bitset_view_,
raft::device_vector_view<const IdxT, IdxT> mask_index)
{
auto* bitset_ptr = bitset_view_.get_bitset_ptr();
thrust::for_each_n(resource::get_thrust_policy(res),
mask_index.data_handle(),
mask_index.extent(0),
[bitset_ptr] __device__(const IdxT sample_index) {
const IdxT bit_element = sample_index / 32;
const IdxT bit_index = sample_index % 32;
const uint32_t bitmask = ~(uint32_t{1} << bit_index);
atomicAnd(bitset_ptr + bit_element, bitmask);
});
}

/**
* @brief Function to test a list of indices in a bitset.
*
* @tparam IdxT Indexing type
* @tparam OutputT Output type of the test. Default is bool.
* @param res RAFT resources
* @param bitset_view_ View of the bitset
* @param queries List of indices to test
* @param output List of outputs
*/
template <typename IdxT, typename OutputT = bool>
void bitset_test(const raft::resources& res,
const raft::utils::bitset_view<IdxT> bitset_view_,
raft::device_vector_view<const IdxT, IdxT> queries,
raft::device_vector_view<OutputT, IdxT> output)
{
RAFT_EXPECTS(output.extent(0) == queries.extent(0), "Output and queries must be same size");
raft::linalg::map(
res, output, [=] __device__(IdxT query) { return OutputT(bitset_view_.test(query)); }, queries);
}
/** @} */
} // end namespace raft::utils
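
The word and bit arithmetic used by `bitset_view::test` and `bitset_unset` above can be illustrated with a host-only sketch. It is a minimal stand-in, not RAFT code: the name `host_bitset_sketch` and its members are hypothetical, a `std::vector` replaces device memory, and a plain `&=` replaces `atomicAnd` because the sketch is single-threaded.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-only stand-in for the device bitset; not part of RAFT.
struct host_bitset_sketch {
  static constexpr std::size_t bits_per_word = sizeof(std::uint32_t) * 8;  // 32

  // All bits start set, mirroring the 0xff memset in the bitset constructors.
  explicit host_bitset_sketch(std::size_t num_bits)
    : num_bits_{num_bits},
      words_((num_bits + bits_per_word - 1) / bits_per_word, 0xffffffffu)
  {
  }

  // Clear one bit: same word/bit split as bitset_unset, without the atomic.
  void unset(std::size_t index)
  {
    assert(index < num_bits_);
    words_[index / bits_per_word] &= ~(std::uint32_t{1} << (index % bits_per_word));
  }

  // True while the index has not been unset, as in bitset_view::test.
  bool test(std::size_t index) const
  {
    assert(index < num_bits_);
    return (words_[index / bits_per_word] &
            (std::uint32_t{1} << (index % bits_per_word))) != 0;
  }

  // Backing storage is whole 32-bit words: 4 * ceil(num_bits / 32) bytes.
  std::size_t size_bytes() const { return words_.size() * sizeof(std::uint32_t); }

 private:
  std::size_t num_bits_;
  std::vector<std::uint32_t> words_;
};
```

With 100 bits the sketch allocates four words (16 bytes), and unsetting index 31 leaves its neighbours 30 and 32 untouched, since they live in a different bit position or a different word.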
5 changes: 5 additions & 0 deletions cpp/include/raft/util/memory_pool-inl.hpp
@@ -25,6 +25,10 @@

namespace raft {

/**
* @defgroup memory_pool Memory Pool
* @{
*/
/**
* @brief Get a pointer to a pooled memory resource within the scope of the lifetime of the returned
* unique pointer.
@@ -73,4 +77,5 @@ RAFT_INLINE_CONDITIONAL std::unique_ptr<rmm::mr::device_memory_resource> get_poo
return pool_res;
}

/** @} */
} // namespace raft
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -423,6 +423,7 @@ if(BUILD_TESTS)
PATH
test/core/seive.cu
test/util/bitonic_sort.cu
test/util/bitset.cu
test/util/cudart_utils.cpp
test/util/device_atomics.cu
test/util/integer_utils.cpp
154 changes: 154 additions & 0 deletions cpp/test/util/bitset.cu
@@ -0,0 +1,154 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.cuh"

#include <raft/random/rng.cuh>
#include <raft/util/bitset.cuh>

#include <gtest/gtest.h>

#include <algorithm>
#include <numeric>

namespace raft::utils {

struct test_spec {
int bitset_len;
int mask_len;
int query_len;
};

auto operator<<(std::ostream& os, const test_spec& ss) -> std::ostream&
{
os << "bitset{bitset_len: " << ss.bitset_len << ", mask_len: " << ss.mask_len << ", query_len: " << ss.query_len << "}";
return os;
}

template <typename T>
void add_cpu_bitset(std::vector<uint32_t>& bitset, const std::vector<T>& mask_idx)
{
for (size_t i = 0; i < mask_idx.size(); i++) {
auto idx = mask_idx[i];
bitset[idx / 32] &= ~(uint32_t{1} << (idx % 32));
}
}

template <typename T>
void create_cpu_bitset(std::vector<uint32_t>& bitset, const std::vector<T>& mask_idx)
{
for (size_t i = 0; i < bitset.size(); i++) {
bitset[i] = 0xffffffff;
}
add_cpu_bitset<T>(bitset, mask_idx);
}

template <typename T>
void test_cpu_bitset(const std::vector<uint32_t>& bitset,
const std::vector<T>& queries,
std::vector<uint8_t>& result)
{
for (size_t i = 0; i < queries.size(); i++) {
result[i] = uint8_t((bitset[queries[i] / 32] & (uint32_t{1} << (queries[i] % 32))) != 0);
}
}

template <typename T>
class BitsetTest : public testing::TestWithParam<test_spec> {
protected:
const test_spec spec;
std::vector<uint32_t> bitset_result;
std::vector<uint32_t> bitset_ref;
raft::resources res;

public:
explicit BitsetTest()
: spec(testing::TestWithParam<test_spec>::GetParam()),
bitset_result(raft::ceildiv(spec.bitset_len, 32)),
bitset_ref(raft::ceildiv(spec.bitset_len, 32))
{
}

void run()
{
auto stream = resource::get_cuda_stream(res);

// generate input and mask
raft::random::RngState rng(42);
auto mask_device = raft::make_device_vector<T, T>(res, spec.mask_len);
std::vector<T> mask_cpu(spec.mask_len);
raft::random::uniformInt(res, rng, mask_device.view(), T(0), T(spec.bitset_len));
update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream);
resource::sync_stream(res, stream);

// calculate the results
auto test_bitset =
raft::utils::bitset<T>(res, raft::make_const_mdspan(mask_device.view()), T(spec.bitset_len));
update_host(
bitset_result.data(), test_bitset.view().get_bitset_ptr(), bitset_result.size(), stream);

// calculate the reference
create_cpu_bitset(bitset_ref, mask_cpu);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<T>()));

auto query_device = raft::make_device_vector<T, T>(res, spec.query_len);
auto result_device = raft::make_device_vector<uint8_t, T>(res, spec.query_len);
auto query_cpu = std::vector<T>(spec.query_len);
auto result_cpu = std::vector<uint8_t>(spec.query_len);
auto result_ref = std::vector<uint8_t>(spec.query_len);

// Create queries and verify the test results
raft::random::uniformInt(res, rng, query_device.view(), T(0), T(spec.bitset_len));
update_host(query_cpu.data(), query_device.data_handle(), query_device.extent(0), stream);
raft::utils::bitset_test(
res, test_bitset.view(), raft::make_const_mdspan(query_device.view()), result_device.view());
update_host(result_cpu.data(), result_device.data_handle(), result_device.extent(0), stream);
test_cpu_bitset(bitset_ref, query_cpu, result_ref);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(result_cpu, result_ref, Compare<bool>()));

// Unset more samples in the bitset and re-test
raft::random::uniformInt(res, rng, mask_device.view(), T(0), T(spec.bitset_len));
update_host(mask_cpu.data(), mask_device.data_handle(), mask_device.extent(0), stream);
resource::sync_stream(res, stream);
raft::utils::bitset_unset<T>(res, test_bitset.view(), mask_device.view());
update_host(
bitset_result.data(), test_bitset.view().get_bitset_ptr(), bitset_result.size(), stream);

add_cpu_bitset(bitset_ref, mask_cpu);
resource::sync_stream(res, stream);
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<T>()));
}
};

auto inputs = ::testing::Values(test_spec{32, 5, 10},
test_spec{100, 30, 10},
test_spec{1024, 55, 100},
test_spec{10000, 1000, 1000},
test_spec{1 << 15, 1 << 3, 1 << 12},
test_spec{1 << 15, 1 << 14, 1 << 13},
test_spec{1 << 25, 1 << 23, 1 << 14});

using Uint32 = BitsetTest<uint32_t>;
TEST_P(Uint32, Run) { run(); }
INSTANTIATE_TEST_CASE_P(BitsetTest, Uint32, inputs);

using Uint64 = BitsetTest<uint64_t>;
TEST_P(Uint64, Run) { run(); }
INSTANTIATE_TEST_CASE_P(BitsetTest, Uint64, inputs);

} // namespace raft::utils
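
The CPU reference helpers in this test also hint at the alternative raised in the review of the constructor: build the word array on the host in one pass and transfer it once. A minimal sketch of the host half follows. The function name `build_host_words` is hypothetical, and silently skipping out-of-range indices is an assumption of this sketch, not behaviour taken from the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of RAFT: build the bitset words on the host.
// Every bit starts set; each index in mask_index clears its bit.
std::vector<std::uint32_t> build_host_words(std::size_t bitset_len,
                                            const std::vector<std::size_t>& mask_index)
{
  constexpr std::size_t bits_per_word = 32;
  std::vector<std::uint32_t> words((bitset_len + bits_per_word - 1) / bits_per_word,
                                   0xffffffffu);
  for (std::size_t idx : mask_index) {
    // Assumption of this sketch: indices past the end are ignored.
    if (idx >= bitset_len) { continue; }
    words[idx / bits_per_word] &= ~(std::uint32_t{1} << (idx % bits_per_word));
  }
  return words;
}
```

The returned buffer could then be moved to the device with a single host-to-device copy (for example `cudaMemcpyAsync` on the resource's stream), at the cost of holding `4 * ceil(bitset_len / 32)` bytes in host memory. Duplicate indices are harmless because clearing a bit is idempotent.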
3 changes: 2 additions & 1 deletion docs/source/cpp_api.rst
Original file line number Diff line number Diff line change
@@ -18,4 +18,5 @@ C++ API
cpp_api/random.rst
cpp_api/solver.rst
cpp_api/sparse.rst
cpp_api/stats.rst
cpp_api/stats.rst
cpp_api/utils.rst