Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Byte-Pair-Encoding usage of cuco static-map for storing merge-pairs #13807

Merged
merged 31 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
11a8dd0
Fix Byte-Pair-Encoding usage of cuco static-map for storing merge-pairs
davidwendt Aug 2, 2023
890c0f6
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 2, 2023
b3e01f0
Use linear probing + use proper key sentinel
PointKernel Aug 2, 2023
ac73b10
Use != operator for map iterators
PointKernel Aug 2, 2023
326c2fa
Use access operator
PointKernel Aug 2, 2023
5d4223e
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 3, 2023
fceefc4
use linear-probe scheme
davidwendt Aug 3, 2023
b93b1cc
Merge branch 'fix-bpe-static-map' of github.com:davidwendt/cudf into …
PointKernel Aug 3, 2023
bfa1a13
Merge branch 'branch-23.10' into fix-bpe-static-map
PointKernel Aug 3, 2023
860b1b8
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 3, 2023
148fc83
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 5, 2023
91ab4fb
keep merge-pairs column-device-view alive
davidwendt Aug 6, 2023
6a4fcf6
cleanup col-device-view declaration in impl class
davidwendt Aug 6, 2023
0008bd2
cleanup code; add comments
davidwendt Aug 7, 2023
40755df
Update cuco git tag
PointKernel Aug 7, 2023
cda11b8
Merge branch 'fix-bpe-static-map' of github.com:davidwendt/cudf into …
PointKernel Aug 7, 2023
353971f
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 9, 2023
baf54e5
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 10, 2023
440b8b5
reuse already created col-dev-view
davidwendt Aug 11, 2023
d6234c6
Revert temporary CMake changes
PointKernel Aug 11, 2023
caf5e68
remove unneeded include
davidwendt Aug 11, 2023
38171e3
Merge branch 'fix-bpe-static-map' of github.com:davidwendt/cudf into …
davidwendt Aug 11, 2023
42f1b92
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 11, 2023
06136dd
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 13, 2023
d884817
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 14, 2023
f262d74
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 15, 2023
7ffb1ab
Merge branch 'fix-bpe-static-map' of github.com:davidwendt/cudf into …
davidwendt Aug 15, 2023
708d8fd
fix some comments
davidwendt Aug 16, 2023
7054430
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 16, 2023
24548cd
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 16, 2023
8b5aba7
Merge branch 'branch-23.10' into fix-bpe-static-map
davidwendt Aug 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions cpp/include/nvtext/bpe_tokenize.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -61,19 +61,6 @@ struct bpe_merge_pairs {
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

~bpe_merge_pairs();

/**
* @brief Returns the number of merge pairs in the table.
*
* @return The number of merge pairs in the table
*/
cudf::size_type get_size();
/**
* @brief Returns the number of unique merge pairs in the table.
*
* @return The number of unique merge pairs in the table
*/
std::size_t get_map_size();
};

/**
Expand Down
41 changes: 17 additions & 24 deletions cpp/src/text/subword/bpe_tokenizer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ __device__ cudf::string_view get_first_token(cudf::string_view const& d_str)
*
* @see The byte_pair_encoding_fn::operator() function below for details.
*/
template <typename MapRefType>
struct byte_pair_encoding_fn {
cudf::column_device_view const d_merges;
cudf::column_device_view const d_strings;
merge_pairs_map_type::device_view const d_map;
MapRefType const d_map;
cudf::size_type* d_sizes; // output size of encoded string
string_hasher_type const hasher;
cudf::size_type* d_byte_indices;
Expand Down Expand Up @@ -136,17 +137,13 @@ struct byte_pair_encoding_fn {
}

/**
* @brief Compute the hash over the input strings.
* @brief Lookup the pair of strings in the d_map/d_merges
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
*
* The input strings are combined with a space to produce hash for matching
* a merge pair within the `d_map`.
*
* @param lhs First string.
* @param rhs Second string.
* @return The hash value to match with `d_map`.
* @param lhs Left half of the string
* @param rhs Right half of the string
* @return Position of merge pair within d_map
*/
__device__ cudf::hash_value_type compute_hash(cudf::string_view const& lhs,
cudf::string_view const& rhs)
__device__ auto get_merge_pair(cudf::string_view const& lhs, cudf::string_view const& rhs)
{
__shared__ char shmem[48 * 1024]; // max for Pascal
auto const total_size = lhs.size_bytes() + rhs.size_bytes() + 1;
Expand All @@ -155,7 +152,7 @@ struct byte_pair_encoding_fn {
// Edge case check.
// Empirically found only two merge pair strings that were greater than 70 bytes
// and they both looked like ignorable errors. Double check this analysis with Vibhu.
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
if (thread_memory_size < total_size) { return 0; }
if (thread_memory_size < total_size) { return d_map.end(); }

// build the target string in shared memory
char* ptr = &shmem[threadIdx.x * thread_memory_size];
Expand All @@ -165,8 +162,8 @@ struct byte_pair_encoding_fn {
memcpy(ptr + lhs.size_bytes(), " ", 1);
memcpy(ptr + lhs.size_bytes() + 1, rhs.data(), rhs.size_bytes());

auto const d_hash_str = cudf::string_view(ptr, total_size);
return hasher(d_hash_str); // return the hash for the temp string
auto const d_str = cudf::string_view(ptr, total_size);
return d_map.find(d_str);
}

/**
Expand Down Expand Up @@ -233,11 +230,10 @@ struct byte_pair_encoding_fn {
auto const rhs = next_substr(itr, end, d_str);
if (rhs.empty()) break; // no more adjacent pairs

auto const hash = compute_hash(lhs, rhs);
auto const map_itr = d_map.find(hash, thrust::identity<cudf::hash_value_type>{});
auto const map_itr = get_merge_pair(lhs, rhs);
if (map_itr != d_map.end()) {
// found a match; record the rank (and other min_ vars)
auto const rank = static_cast<cudf::size_type>(map_itr->second);
auto const rank = map_itr->second;
if (rank < min_rank) {
min_rank = rank;
min_itr = itr;
Expand Down Expand Up @@ -354,12 +350,12 @@ std::unique_ptr<cudf::column> byte_pair_encoding(
bpe_merge_pairs::bpe_merge_pairs_impl const& merge_pairs,
rmm::cuda_stream_view stream)
{
CUDF_EXPECTS(!merge_pairs.get_merge_pairs().is_empty(), "Merge pairs table must not be empty");
auto const d_merges = merge_pairs.get_merge_pairs();
CUDF_EXPECTS(d_merges.size() > 0, "Merge pairs table must not be empty");

// build working vector to hold index values per byte
rmm::device_uvector<cudf::size_type> d_byte_indices(input.chars().size(), stream);

auto const d_merges = cudf::column_device_view::create(merge_pairs.get_merge_pairs(), stream);
auto const d_strings = cudf::column_device_view::create(input.parent(), stream);

auto offsets = cudf::make_numeric_column(cudf::data_type{cudf::type_to_id<cudf::size_type>()},
Expand All @@ -369,12 +365,9 @@ std::unique_ptr<cudf::column> byte_pair_encoding(
rmm::mr::get_current_device_resource());
auto d_offsets = offsets->mutable_view().data<cudf::size_type>();

byte_pair_encoding_fn fn{*d_merges,
*d_strings,
merge_pairs.get_merge_pairs_map(),
d_offsets,
string_hasher_type{},
d_byte_indices.data()};
auto map_ref = merge_pairs.get_merge_pairs_ref();
byte_pair_encoding_fn<decltype(map_ref)> fn{
d_merges, *d_strings, map_ref, d_offsets, string_hasher_type{}, d_byte_indices.data()};
thrust::for_each_n(
rmm::exec_policy(stream), thrust::make_counting_iterator<cudf::size_type>(0), input.size(), fn);

Expand Down
71 changes: 64 additions & 7 deletions cpp/src/text/subword/bpe_tokenizer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,96 @@
#include <hash/hash_allocator.cuh>

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/hashing/detail/murmurhash3_x86_32.cuh>
#include <cudf/strings/string_view.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/polymorphic_allocator.hpp>

#include <cuco/operator.hpp>
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
#include <cuco/static_map.cuh>

#include <cstdint>
#include <type_traits>

namespace nvtext {
namespace detail {

using hash_value_type = uint32_t;
using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32<cudf::string_view>;

/**
* @brief Hasher function used for building and using the cuco static-map
*
* This takes advantage of heterogeneous lookup feature in cuco static-map which
* allows inserting with one type (index) and looking up with a different type (string).
davidwendt marked this conversation as resolved.
Show resolved Hide resolved
*/
struct bpe_hasher {
cudf::column_device_view const d_strings;
string_hasher_type hasher{};
// used by insert
__device__ hash_value_type operator()(cudf::size_type index) const
{
return hasher(d_strings.element<cudf::string_view>(index));
}
// used by find
__device__ hash_value_type operator()(cudf::string_view const& s) const { return hasher(s); }
};

/**
* @brief Equal function used for building and using the cuco static-map
*
* This takes advantage of heterogeneous lookup feature in cuco static-map which
* allows inserting with one type (index) and looking up with a different type (string).
*/
struct bpe_equal {
cudf::column_device_view const d_strings;
// used by insert
__device__ bool operator()(cudf::size_type lhs, cudf::size_type rhs) const noexcept
{
return d_strings.element<cudf::string_view>(lhs) == d_strings.element<cudf::string_view>(rhs);
}
// used by find
__device__ bool operator()(cudf::size_type lhs, cudf::string_view const& rhs) const noexcept
{
return d_strings.element<cudf::string_view>(lhs) == rhs;
}
};

using hash_table_allocator_type = rmm::mr::stream_allocator_adaptor<default_allocator<char>>;

using merge_pairs_map_type = cuco::static_map<cudf::hash_value_type,
cudf::size_type,
cuda::thread_scope_device,
hash_table_allocator_type>;
using probe_scheme = cuco::experimental::linear_probing<1, bpe_hasher>;

using string_hasher_type = cudf::hashing::detail::MurmurHash3_x86_32<cudf::string_view>;
using merge_pairs_map_type = cuco::experimental::static_map<cudf::size_type,
cudf::size_type,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
bpe_equal,
probe_scheme,
hash_table_allocator_type>;

} // namespace detail

// since column_device_view::create returns is a little more than
// std::unique_ptr<column_device_view> this helper simplifies the return type in a more maintainable
// way
using col_device_view = std::invoke_result_t<decltype(&cudf::column_device_view::create),
cudf::column_view,
rmm::cuda_stream_view>;

struct bpe_merge_pairs::bpe_merge_pairs_impl {
std::unique_ptr<cudf::column> const merge_pairs;
col_device_view const d_merge_pairs;
std::unique_ptr<detail::merge_pairs_map_type> merge_pairs_map;

bpe_merge_pairs_impl(std::unique_ptr<cudf::column>&& merge_pairs,
col_device_view&& d_merge_pairs,
std::unique_ptr<detail::merge_pairs_map_type>&& merge_pairs_map);

auto get_merge_pairs() const { return merge_pairs->view(); }
auto get_merge_pairs_map() const { return merge_pairs_map->get_device_view(); }
auto const get_merge_pairs() const { return *d_merge_pairs; }
auto get_merge_pairs_ref() const { return merge_pairs_map->ref(cuco::experimental::op::find); }
};

} // namespace nvtext
48 changes: 16 additions & 32 deletions cpp/src/text/subword/load_merges_file.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,8 @@

namespace nvtext {
namespace detail {

namespace {

struct make_pair_function {
/**
* @brief Hash the merge pair entry
*/
__device__ cuco::pair<cudf::hash_value_type, cudf::size_type> operator()(cudf::size_type idx)
{
auto const result = _hasher(d_strings.element<cudf::string_view>(idx));
return cuco::make_pair(result, idx);
}

string_hasher_type const _hasher;
cudf::column_device_view const d_strings;
};

/**
* @brief Loads a text file of merge-pairs into a strings column.
*
Expand Down Expand Up @@ -101,36 +86,34 @@ std::unique_ptr<cudf::column> load_file_to_column(std::string const& filename_me
}

std::unique_ptr<detail::merge_pairs_map_type> initialize_merge_pairs_map(
cudf::strings_column_view const& input, rmm::cuda_stream_view stream)
cudf::column_device_view const& input, rmm::cuda_stream_view stream)
{
// Ensure capacity is at least (size/0.7) as documented here:
// https://github.com/NVIDIA/cuCollections/blob/6ec8b6dcdeceea07ab4456d32461a05c18864411/include/cuco/static_map.cuh#L179-L182
auto merge_pairs_map = std::make_unique<merge_pairs_map_type>(
static_cast<size_t>(input.size() * 2), // capacity is 2x;
cuco::empty_key{std::numeric_limits<cudf::hash_value_type>::max()},
cuco::empty_key{-1},
cuco::empty_value{-1}, // empty value is not used
bpe_equal{input},
probe_scheme{bpe_hasher{input}},
hash_table_allocator_type{default_allocator<char>{}, stream},
stream.value());

auto d_strings = cudf::column_device_view::create(input.parent(), stream);
make_pair_function pair_func{string_hasher_type{}, *d_strings};
auto iter = cudf::detail::make_counting_transform_iterator(0, pair_func);
auto iter = cudf::detail::make_counting_transform_iterator(
0, [] __device__(cudf::size_type idx) { return cuco::make_pair(idx, idx); });

merge_pairs_map->insert(iter,
iter + input.size(),
thrust::identity<cudf::hash_value_type>{},
thrust::equal_to<cudf::hash_value_type>{},
stream.value());
merge_pairs_map->insert_async(iter, iter + input.size(), stream.value());

return merge_pairs_map;
}

std::unique_ptr<bpe_merge_pairs::bpe_merge_pairs_impl> create_bpe_merge_pairs_impl(
std::unique_ptr<cudf::column>&& input, rmm::cuda_stream_view stream)
{
auto merge_pairs = initialize_merge_pairs_map(cudf::strings_column_view(input->view()), stream);
return std::make_unique<nvtext::bpe_merge_pairs::bpe_merge_pairs_impl>(std::move(input),
std::move(merge_pairs));
auto d_input = cudf::column_device_view::create(input->view(), stream);
auto merge_pairs = initialize_merge_pairs_map(*d_input, stream);
return std::make_unique<nvtext::bpe_merge_pairs::bpe_merge_pairs_impl>(
std::move(input), std::move(d_input), std::move(merge_pairs));
}

std::unique_ptr<bpe_merge_pairs::bpe_merge_pairs_impl> create_bpe_merge_pairs_impl(
Expand Down Expand Up @@ -163,8 +146,12 @@ std::unique_ptr<bpe_merge_pairs> load_merge_pairs_file(std::string const& filena

bpe_merge_pairs::bpe_merge_pairs_impl::bpe_merge_pairs_impl(
std::unique_ptr<cudf::column>&& merge_pairs,
std::unique_ptr<cudf::column_device_view, std::function<void(cudf::column_device_view*)>>&&
d_merge_pairs,
std::unique_ptr<detail::merge_pairs_map_type>&& merge_pairs_map)
: merge_pairs(std::move(merge_pairs)), merge_pairs_map(std::move(merge_pairs_map))
: merge_pairs(std::move(merge_pairs)),
d_merge_pairs(std::move(d_merge_pairs)),
merge_pairs_map(std::move(merge_pairs_map))
{
}

Expand All @@ -184,7 +171,4 @@ bpe_merge_pairs::bpe_merge_pairs(cudf::strings_column_view const& input,

bpe_merge_pairs::~bpe_merge_pairs() = default;

cudf::size_type bpe_merge_pairs::get_size() { return impl->merge_pairs->size(); }
std::size_t bpe_merge_pairs::get_map_size() { return impl->merge_pairs_map->get_size(); }

} // namespace nvtext