Skip to content

Commit

Permalink
Fix Byte-Pair-Encoding usage of cuco static-map for storing merge-pai…
Browse files Browse the repository at this point in the history
…rs (#13807)

Switches to `cuco::experimental::static_map` for storing the unique merge-pair strings so that they can be looked up by `string_view`. This takes advantage of the heterogeneous-lookup feature of `static_map`, which allows inserting with one key type (the index of a string entry) and looking up with a different type (a string). The map hashes the string at each index when storing that index, so a lookup by string hashes to the same entry; colliding or duplicate entries are resolved by comparing the lookup string with the string at the stored row index.

Authors:
  - David Wendt (https://github.com/davidwendt)
  - Yunsong Wang (https://github.com/PointKernel)

Approvers:
  - Yunsong Wang (https://github.com/PointKernel)
  - Bradley Dice (https://github.com/bdice)

URL: #13807
  • Loading branch information
davidwendt authored Aug 17, 2023
1 parent 41f0caf commit f543dfa
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 78 deletions.
15 changes: 1 addition & 14 deletions cpp/include/nvtext/bpe_tokenize.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -61,19 +61,6 @@ struct bpe_merge_pairs {
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

~bpe_merge_pairs();

/**
* @brief Returns the number of merge pairs in the table.
*
* @return The number of merge pairs in the table
*/
cudf::size_type get_size();
/**
* @brief Returns the number of unique merge pairs in the table.
*
* @return The number of unique merge pairs in the table
*/
std::size_t get_map_size();
};

/**
Expand Down
43 changes: 18 additions & 25 deletions cpp/src/text/subword/bpe_tokenizer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ __device__ cudf::string_view get_first_token(cudf::string_view const& d_str)
*
* @see The byte_pair_encoding_fn::operator() function below for details.
*/
template <typename MapRefType>
struct byte_pair_encoding_fn {
cudf::column_device_view const d_merges;
cudf::column_device_view const d_strings;
merge_pairs_map_type::device_view const d_map;
MapRefType const d_map;
cudf::size_type* d_sizes; // output size of encoded string
string_hasher_type const hasher;
cudf::size_type* d_byte_indices;
Expand Down Expand Up @@ -136,26 +137,22 @@ struct byte_pair_encoding_fn {
}

/**
* @brief Compute the hash over the input strings.
* @brief Look up the pair of strings in the d_map/d_merges
*
* The input strings are combined with a space to produce hash for matching
* a merge pair within the `d_map`.
*
* @param lhs First string.
* @param rhs Second string.
* @return The hash value to match with `d_map`.
* @param lhs Left half of the string
* @param rhs Right half of the string
* @return Position of merge pair within d_map
*/
__device__ cudf::hash_value_type compute_hash(cudf::string_view const& lhs,
cudf::string_view const& rhs)
__device__ auto get_merge_pair(cudf::string_view const& lhs, cudf::string_view const& rhs)
{
__shared__ char shmem[48 * 1024]; // max for Pascal
auto const total_size = lhs.size_bytes() + rhs.size_bytes() + 1;
auto const thread_memory_size = static_cast<cudf::size_type>(sizeof(shmem) / blockDim.x);

// Edge case check.
// Empirically found only two merge pair strings that were greater than 70 bytes
// and they both looked like ignorable errors. Double check this analysis with Vibhu.
if (thread_memory_size < total_size) { return 0; }
// and they both looked like ignorable errors.
if (thread_memory_size < total_size) { return d_map.end(); }

// build the target string in shared memory
char* ptr = &shmem[threadIdx.x * thread_memory_size];
Expand All @@ -165,8 +162,8 @@ struct byte_pair_encoding_fn {
memcpy(ptr + lhs.size_bytes(), " ", 1);
memcpy(ptr + lhs.size_bytes() + 1, rhs.data(), rhs.size_bytes());

auto const d_hash_str = cudf::string_view(ptr, total_size);
return hasher(d_hash_str); // return the hash for the temp string
auto const d_str = cudf::string_view(ptr, total_size);
return d_map.find(d_str);
}

/**
Expand Down Expand Up @@ -233,11 +230,10 @@ struct byte_pair_encoding_fn {
auto const rhs = next_substr(itr, end, d_str);
if (rhs.empty()) break; // no more adjacent pairs

auto const hash = compute_hash(lhs, rhs);
auto const map_itr = d_map.find(hash, thrust::identity<cudf::hash_value_type>{});
auto const map_itr = get_merge_pair(lhs, rhs);
if (map_itr != d_map.end()) {
// found a match; record the rank (and other min_ vars)
auto const rank = static_cast<cudf::size_type>(map_itr->second);
auto const rank = map_itr->second;
if (rank < min_rank) {
min_rank = rank;
min_itr = itr;
Expand Down Expand Up @@ -354,12 +350,12 @@ std::unique_ptr<cudf::column> byte_pair_encoding(
bpe_merge_pairs::bpe_merge_pairs_impl const& merge_pairs,
rmm::cuda_stream_view stream)
{
CUDF_EXPECTS(!merge_pairs.get_merge_pairs().is_empty(), "Merge pairs table must not be empty");
auto const d_merges = merge_pairs.get_merge_pairs();
CUDF_EXPECTS(d_merges.size() > 0, "Merge pairs table must not be empty");

// build working vector to hold index values per byte
rmm::device_uvector<cudf::size_type> d_byte_indices(input.chars().size(), stream);

auto const d_merges = cudf::column_device_view::create(merge_pairs.get_merge_pairs(), stream);
auto const d_strings = cudf::column_device_view::create(input.parent(), stream);

auto offsets = cudf::make_numeric_column(cudf::data_type{cudf::type_to_id<cudf::size_type>()},
Expand All @@ -369,12 +365,9 @@ std::unique_ptr<cudf::column> byte_pair_encoding(
rmm::mr::get_current_device_resource());
auto d_offsets = offsets->mutable_view().data<cudf::size_type>();

byte_pair_encoding_fn fn{*d_merges,
*d_strings,
merge_pairs.get_merge_pairs_map(),
d_offsets,
string_hasher_type{},
d_byte_indices.data()};
auto map_ref = merge_pairs.get_merge_pairs_ref();
byte_pair_encoding_fn<decltype(map_ref)> fn{
d_merges, *d_strings, map_ref, d_offsets, string_hasher_type{}, d_byte_indices.data()};
thrust::for_each_n(
rmm::exec_policy(stream), thrust::make_counting_iterator<cudf::size_type>(0), input.size(), fn);

Expand Down
70 changes: 63 additions & 7 deletions cpp/src/text/subword/bpe_tokenizer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
#include <hash/hash_allocator.cuh>

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/hashing/detail/murmurhash3_x86_32.cuh>
#include <cudf/strings/string_view.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
Expand All @@ -30,30 +32,84 @@
#include <cuco/static_map.cuh>

#include <cstdint>
#include <type_traits>

namespace nvtext {
namespace detail {

using hash_value_type = uint32_t;
using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32<cudf::string_view>;

/**
 * @brief Hash functor given to the cuco static-map probing scheme
 *
 * The static-map supports heterogeneous lookup: entries are inserted using one
 * key type (a row index) and searched for using a different type (a string).
 * Both overloads hash string data with the same `hasher`, so an index and the
 * string stored at that index produce the same hash value.
 */
struct bpe_hasher {
  cudf::column_device_view const d_strings;  // strings column the row indices refer to
  string_hasher_type hasher{};

  /**
   * @brief Hash used by insert: hashes the string at row `index` of `d_strings`
   *
   * @param index Row index into `d_strings`
   * @return Hash of the string found at `index`
   */
  __device__ hash_value_type operator()(cudf::size_type index) const
  {
    // resolve the row to its string and reuse the string overload
    return (*this)(d_strings.element<cudf::string_view>(index));
  }

  /**
   * @brief Hash used by find: hashes the given string directly
   *
   * @param s String to hash
   * @return Hash of `s`
   */
  __device__ hash_value_type operator()(cudf::string_view const& s) const
  {
    return hasher(s);
  }
};

/**
 * @brief Key-equality functor given to the cuco static-map
 *
 * The static-map supports heterogeneous lookup: entries are inserted using one
 * key type (a row index) and searched for using a different type (a string).
 * In both overloads the comparison is made on string contents, with row
 * indices resolved through `d_strings`.
 */
struct bpe_equal {
  cudf::column_device_view const d_strings;  // strings column the row indices refer to

  /**
   * @brief Comparison used by insert: compares the strings at two row indices
   *
   * @param lhs Row index into `d_strings`
   * @param rhs Row index into `d_strings`
   * @return true if the strings at `lhs` and `rhs` are equal
   */
  __device__ bool operator()(cudf::size_type lhs, cudf::size_type rhs) const noexcept
  {
    // resolve the right-hand row to its string and reuse the index/string overload
    return (*this)(lhs, d_strings.element<cudf::string_view>(rhs));
  }

  /**
   * @brief Comparison used by find: compares a stored row against a string
   *
   * @param lhs Row index into `d_strings`
   * @param rhs String being looked up
   * @return true if the string at `lhs` is equal to `rhs`
   */
  __device__ bool operator()(cudf::size_type lhs, cudf::string_view const& rhs) const noexcept
  {
    auto const d_lhs = d_strings.element<cudf::string_view>(lhs);
    return d_lhs == rhs;
  }
};

using hash_table_allocator_type = rmm::mr::stream_allocator_adaptor<default_allocator<char>>;

using merge_pairs_map_type = cuco::static_map<cudf::hash_value_type,
cudf::size_type,
cuda::thread_scope_device,
hash_table_allocator_type>;
using probe_scheme = cuco::experimental::linear_probing<1, bpe_hasher>;

using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32<cudf::string_view>;
using merge_pairs_map_type = cuco::experimental::static_map<cudf::size_type,
cudf::size_type,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
bpe_equal,
probe_scheme,
hash_table_allocator_type>;

} // namespace detail

// column_device_view::create returns more than a plain
// std::unique_ptr<column_device_view> (the unique_ptr also carries a deleter type),
// so this alias deduces the return type from the function itself rather than
// spelling it out, which is easier to maintain.
using col_device_view = std::invoke_result_t<decltype(&cudf::column_device_view::create),
cudf::column_view,
rmm::cuda_stream_view>;

struct bpe_merge_pairs::bpe_merge_pairs_impl {
std::unique_ptr<cudf::column> const merge_pairs;
col_device_view const d_merge_pairs;
std::unique_ptr<detail::merge_pairs_map_type> merge_pairs_map;

bpe_merge_pairs_impl(std::unique_ptr<cudf::column>&& merge_pairs,
col_device_view&& d_merge_pairs,
std::unique_ptr<detail::merge_pairs_map_type>&& merge_pairs_map);

auto get_merge_pairs() const { return merge_pairs->view(); }
auto get_merge_pairs_map() const { return merge_pairs_map->get_device_view(); }
auto const get_merge_pairs() const { return *d_merge_pairs; }
auto get_merge_pairs_ref() const { return merge_pairs_map->ref(cuco::experimental::op::find); }
};

} // namespace nvtext
48 changes: 16 additions & 32 deletions cpp/src/text/subword/load_merges_file.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,8 @@

namespace nvtext {
namespace detail {

namespace {

struct make_pair_function {
  /**
   * @brief Build the map entry for one merge-pair row
   *
   * @param idx Row index into `d_strings`
   * @return Pair holding the hash of the string at `idx` and `idx` itself
   */
  __device__ cuco::pair<cudf::hash_value_type, cudf::size_type> operator()(cudf::size_type idx)
  {
    auto const d_str = d_strings.element<cudf::string_view>(idx);
    return cuco::make_pair(_hasher(d_str), idx);
  }

  string_hasher_type const _hasher;          // hash function applied to each string
  cudf::column_device_view const d_strings;  // strings column holding the merge pairs
};

/**
* @brief Loads a text file of merge-pairs into a strings column.
*
Expand Down Expand Up @@ -101,36 +86,34 @@ std::unique_ptr<cudf::column> load_file_to_column(std::string const& filename_me
}

std::unique_ptr<detail::merge_pairs_map_type> initialize_merge_pairs_map(
cudf::strings_column_view const& input, rmm::cuda_stream_view stream)
cudf::column_device_view const& input, rmm::cuda_stream_view stream)
{
// Ensure capacity is at least (size/0.7) as documented here:
// https://github.com/NVIDIA/cuCollections/blob/6ec8b6dcdeceea07ab4456d32461a05c18864411/include/cuco/static_map.cuh#L179-L182
auto merge_pairs_map = std::make_unique<merge_pairs_map_type>(
static_cast<size_t>(input.size() * 2), // capacity is 2x;
cuco::empty_key{std::numeric_limits<cudf::hash_value_type>::max()},
cuco::empty_key{-1},
cuco::empty_value{-1}, // empty value is not used
bpe_equal{input},
probe_scheme{bpe_hasher{input}},
hash_table_allocator_type{default_allocator<char>{}, stream},
stream.value());

auto d_strings = cudf::column_device_view::create(input.parent(), stream);
make_pair_function pair_func{string_hasher_type{}, *d_strings};
auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func);
auto iter = cudf::detail::make_counting_transform_iterator(
0, [] __device__(cudf::size_type idx) { return cuco::make_pair(idx, idx); });

merge_pairs_map->insert(iter,
iter + input.size(),
thrust::identity<cudf::hash_value_type>{},
thrust::equal_to<cudf::hash_value_type>{},
stream.value());
merge_pairs_map->insert_async(iter, iter + input.size(), stream.value());

return merge_pairs_map;
}

std::unique_ptr<bpe_merge_pairs::bpe_merge_pairs_impl> create_bpe_merge_pairs_impl(
std::unique_ptr<cudf::column>&& input, rmm::cuda_stream_view stream)
{
auto merge_pairs = initialize_merge_pairs_map(cudf::strings_column_view(input->view()), stream);
return std::make_unique<nvtext::bpe_merge_pairs::bpe_merge_pairs_impl>(std::move(input),
std::move(merge_pairs));
auto d_input = cudf::column_device_view::create(input->view(), stream);
auto merge_pairs = initialize_merge_pairs_map(*d_input, stream);
return std::make_unique<nvtext::bpe_merge_pairs::bpe_merge_pairs_impl>(
std::move(input), std::move(d_input), std::move(merge_pairs));
}

std::unique_ptr<bpe_merge_pairs::bpe_merge_pairs_impl> create_bpe_merge_pairs_impl(
Expand Down Expand Up @@ -163,8 +146,12 @@ std::unique_ptr<bpe_merge_pairs> load_merge_pairs_file(std::string const& filena

bpe_merge_pairs::bpe_merge_pairs_impl::bpe_merge_pairs_impl(
std::unique_ptr<cudf::column>&& merge_pairs,
std::unique_ptr<cudf::column_device_view, std::function<void(cudf::column_device_view*)>>&&
d_merge_pairs,
std::unique_ptr<detail::merge_pairs_map_type>&& merge_pairs_map)
: merge_pairs(std::move(merge_pairs)), merge_pairs_map(std::move(merge_pairs_map))
: merge_pairs(std::move(merge_pairs)),
d_merge_pairs(std::move(d_merge_pairs)),
merge_pairs_map(std::move(merge_pairs_map))
{
}

Expand All @@ -184,7 +171,4 @@ bpe_merge_pairs::bpe_merge_pairs(cudf::strings_column_view const& input,

bpe_merge_pairs::~bpe_merge_pairs() = default;

cudf::size_type bpe_merge_pairs::get_size() { return impl->merge_pairs->size(); }
std::size_t bpe_merge_pairs::get_map_size() { return impl->merge_pairs_map->get_size(); }

} // namespace nvtext

0 comments on commit f543dfa

Please sign in to comment.