Skip to content

Commit

Permalink
refactored lookup tables
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Jul 16, 2022
1 parent fe06f0b commit 9dfd4ad
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 207 deletions.
94 changes: 35 additions & 59 deletions cpp/src/io/fst/symbol_lut.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

#pragma once

#include <cudf/types.hpp>
#include <io/utilities/hostdevice_vector.hpp>

#include <cub/cub.cuh>

#include <algorithm>
Expand All @@ -34,105 +37,78 @@ namespace detail {
* @tparam SymbolT The symbol type being passed in to lookup the corresponding symbol group id
*/
template <typename SymbolT>
struct SingleSymbolSmemLUT {
//------------------------------------------------------------------------------
// DEFAULT TYPEDEFS
//------------------------------------------------------------------------------
class SingleSymbolSmemLUT {
private:
// Type used for representing a symbol group id (i.e., what we return for a given symbol)
using SymbolGroupIdT = uint8_t;

//------------------------------------------------------------------------------
// DERIVED CONFIGURATIONS
//------------------------------------------------------------------------------
/// Number of entries for every lookup (e.g., for 8-bit Symbol this is 256)
static constexpr uint32_t NUM_ENTRIES_PER_LUT = 0x01U << (sizeof(SymbolT) * 8U);

//------------------------------------------------------------------------------
// TYPEDEFS
//------------------------------------------------------------------------------

struct _TempStorage {
// d_match_meta_data[symbol] -> symbol group index
SymbolGroupIdT match_meta_data[NUM_ENTRIES_PER_LUT];
// sym_to_sgid[symbol] -> symbol group index
SymbolGroupIdT sym_to_sgid[NUM_ENTRIES_PER_LUT];
};

public:
struct KernelParameter {
// d_match_meta_data[min(symbol,num_valid_entries)] -> symbol group index
SymbolGroupIdT num_valid_entries;
// sym_to_sgid[min(symbol,num_valid_entries)] -> symbol group index
SymbolT num_valid_entries;

// d_match_meta_data[symbol] -> symbol group index
SymbolGroupIdT* d_match_meta_data;
// sym_to_sgid[symbol] -> symbol group index
SymbolGroupIdT sym_to_sgid[NUM_ENTRIES_PER_LUT];
};

struct TempStorage : cub::Uninitialized<_TempStorage> {
};
using TempStorage = cub::Uninitialized<_TempStorage>;

//------------------------------------------------------------------------------
// HELPER METHODS
//------------------------------------------------------------------------------
/**
* @brief Populates the symbol-to-symbol-group-id lookup table and copies it to the device
*
* @param[in] d_temp_storage Device-side temporary storage that can be used to store the lookup
* table. If no storage is provided it will return the temporary storage requirements in \p
* d_temp_storage_bytes.
* @param[in,out] d_temp_storage_bytes Amount of device-side temporary storage that can be used in
* the number of bytes
* @param[out] sgid_init A hostdevice_vector that will be populated
* @param[in] symbol_strings Array of strings, where the i-th string holds all symbols
* (characters!) that correspond to the i-th symbol group index
* @param[out] kernel_param The kernel parameter object to be initialized with the given mapping
* of symbols to symbol group ids.
* @param[in] stream The stream that shall be used to cudaMemcpyAsync the lookup table
* @return
*/
template <typename SymbolGroupItT>
__host__ __forceinline__ static cudaError_t PrepareLUT(void* d_temp_storage,
size_t& d_temp_storage_bytes,
SymbolGroupItT const& symbol_strings,
KernelParameter& kernel_param,
cudaStream_t stream = 0)
static void InitDeviceSymbolGroupIdLut(hostdevice_vector<KernelParameter>& sgid_init,
SymbolGroupItT const& symbol_strings,
rmm::cuda_stream_view stream)
{
// The symbol group index to be returned if none of the given symbols match
SymbolGroupIdT no_match_id = symbol_strings.size();

std::vector<SymbolGroupIdT> lut(NUM_ENTRIES_PER_LUT);
// The symbol with the largest value that is mapped to a symbol group id
SymbolGroupIdT max_base_match_val = 0;

// Initialize all entries: by default we return the no-match-id
for (uint32_t i = 0; i < NUM_ENTRIES_PER_LUT; ++i) {
lut[i] = no_match_id;
}
std::fill(&sgid_init.host_ptr()->sym_to_sgid[0],
&sgid_init.host_ptr()->sym_to_sgid[NUM_ENTRIES_PER_LUT],
no_match_id);

// Set up lookup table
uint32_t sg_id = 0;
// Iterate over the symbol groups
for (auto const& sg_symbols : symbol_strings) {
// Iterate over all symbols that belong to the current symbol group
for (auto const& sg_symbol : sg_symbols) {
max_base_match_val = std::max(max_base_match_val, static_cast<SymbolGroupIdT>(sg_symbol));
lut[sg_symbol] = sg_id;
sgid_init.host_ptr()->sym_to_sgid[static_cast<int32_t>(sg_symbol)] = sg_id;
}
sg_id++;
}

// Initialize the out-of-bounds lookup: d_match_meta_data[max_base_match_val+1] -> no_match_id
lut[max_base_match_val + 1] = no_match_id;
// Initialize the out-of-bounds lookup: sym_to_sgid[max_base_match_val+1] -> no_match_id
sgid_init.host_ptr()->sym_to_sgid[max_base_match_val + 1] = no_match_id;

// Alias memory / return memory requirements
kernel_param.num_valid_entries = max_base_match_val + 2;
if (d_temp_storage) {
cudaError_t error = cudaMemcpyAsync(d_temp_storage,
lut.data(),
kernel_param.num_valid_entries * sizeof(SymbolGroupIdT),
cudaMemcpyHostToDevice,
stream);

kernel_param.d_match_meta_data = reinterpret_cast<SymbolGroupIdT*>(d_temp_storage);
return error;
} else {
d_temp_storage_bytes = kernel_param.num_valid_entries * sizeof(SymbolGroupIdT);
return cudaSuccess;
}
// TODO I think this could be +1?
sgid_init.host_ptr()->num_valid_entries = max_base_match_val + 2;

return cudaSuccess;
sgid_init.host_to_device(stream);
}

//------------------------------------------------------------------------------
Expand All @@ -150,29 +126,29 @@ struct SingleSymbolSmemLUT {
return private_storage;
}

__host__ __device__ __forceinline__ SingleSymbolSmemLUT(KernelParameter const& kernel_param,
TempStorage& temp_storage)
constexpr CUDF_HOST_DEVICE SingleSymbolSmemLUT(KernelParameter const& kernel_param,
TempStorage& temp_storage)
: temp_storage(temp_storage.Alias()), num_valid_entries(kernel_param.num_valid_entries)
{
// GPU-side init
#if CUB_PTX_ARCH > 0
for (int32_t i = threadIdx.x; i < kernel_param.num_valid_entries; i += blockDim.x) {
this->temp_storage.match_meta_data[i] = kernel_param.d_match_meta_data[i];
this->temp_storage.sym_to_sgid[i] = kernel_param.sym_to_sgid[i];
}
__syncthreads();

#else
// CPU-side init
for (std::size_t i = 0; i < kernel_param.num_luts; i++) {
this->temp_storage.match_meta_data[i] = kernel_param.d_match_meta_data[i];
this->temp_storage.sym_to_sgid[i] = kernel_param.sym_to_sgid[i];
}
#endif
}

__host__ __device__ __forceinline__ int32_t operator()(SymbolT const symbol) const
constexpr CUDF_HOST_DEVICE int32_t operator()(SymbolT const symbol) const
{
// Look up the symbol group for given symbol
return temp_storage.match_meta_data[min(symbol, num_valid_entries - 1)];
return temp_storage.sym_to_sgid[min(symbol, num_valid_entries - 1)];
}
};

Expand Down
109 changes: 35 additions & 74 deletions cpp/src/io/fst/transition_table.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

#pragma once

#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
#include <io/utilities/hostdevice_vector.hpp>

#include <cub/cub.cuh>

#include <cstdint>
Expand All @@ -25,103 +29,50 @@ namespace io {
namespace fst {
namespace detail {

template <int MAX_NUM_SYMBOLS, int MAX_NUM_STATES>
struct TransitionTable {
//------------------------------------------------------------------------------
// DEFAULT TYPEDEFS
//------------------------------------------------------------------------------
template <int32_t MAX_NUM_SYMBOLS, int32_t MAX_NUM_STATES>
class TransitionTable {
private:
// Type used
using ItemT = char;

struct TransitionVectorWrapper {
const ItemT* data;

__host__ __device__ TransitionVectorWrapper(const ItemT* data) : data(data) {}

__host__ __device__ __forceinline__ uint32_t Get(int32_t index) const { return data[index]; }
};

//------------------------------------------------------------------------------
// TYPEDEFS
//------------------------------------------------------------------------------
using TransitionVectorT = TransitionVectorWrapper;

struct _TempStorage {
//
ItemT transitions[MAX_NUM_STATES * MAX_NUM_SYMBOLS];
};

struct TempStorage : cub::Uninitialized<_TempStorage> {
};
public:
using TempStorage = cub::Uninitialized<_TempStorage>;

struct KernelParameter {
ItemT* transitions;
ItemT transitions[MAX_NUM_STATES * MAX_NUM_SYMBOLS];
};

using LoadAliasT = std::uint32_t;

static constexpr std::size_t NUM_AUX_MEM_BYTES =
CUB_QUOTIENT_CEILING(MAX_NUM_STATES * MAX_NUM_SYMBOLS * sizeof(ItemT), sizeof(LoadAliasT)) *
sizeof(LoadAliasT);

//------------------------------------------------------------------------------
// HELPER METHODS
//------------------------------------------------------------------------------
__host__ static cudaError_t CreateTransitionTable(
void* d_temp_storage,
size_t& temp_storage_bytes,
const std::vector<std::vector<int>>& trans_table,
KernelParameter& kernel_param,
cudaStream_t stream = 0)
static void InitDeviceTransitionTable(hostdevice_vector<KernelParameter>& transition_table_init,
const std::vector<std::vector<int>>& trans_table,
rmm::cuda_stream_view stream)
{
if (!d_temp_storage) {
temp_storage_bytes = NUM_AUX_MEM_BYTES;
return cudaSuccess;
}

// trans_vectors[symbol][state] -> new_state
ItemT trans_vectors[MAX_NUM_STATES * MAX_NUM_SYMBOLS];

// trans_table[state][symbol] -> new state
for (std::size_t state = 0; state < trans_table.size(); ++state) {
for (std::size_t symbol = 0; symbol < trans_table[state].size(); ++symbol) {
trans_vectors[symbol * MAX_NUM_STATES + state] = trans_table[state][symbol];
transition_table_init.host_ptr()->transitions[symbol * MAX_NUM_STATES + state] =
trans_table[state][symbol];
}
}

kernel_param.transitions = static_cast<ItemT*>(d_temp_storage);

// Copy transition table to device
return cudaMemcpyAsync(
d_temp_storage, trans_vectors, NUM_AUX_MEM_BYTES, cudaMemcpyHostToDevice, stream);
transition_table_init.host_to_device(stream);
}

//------------------------------------------------------------------------------
// MEMBER VARIABLES
//------------------------------------------------------------------------------
_TempStorage& temp_storage;

__device__ __forceinline__ _TempStorage& PrivateStorage()
{
__shared__ _TempStorage private_storage;
return private_storage;
}

//------------------------------------------------------------------------------
// CONSTRUCTOR
//------------------------------------------------------------------------------
__host__ __device__ __forceinline__ TransitionTable(const KernelParameter& kernel_param,
TempStorage& temp_storage)
constexpr CUDF_HOST_DEVICE TransitionTable(const KernelParameter& kernel_param,
TempStorage& temp_storage)
: temp_storage(temp_storage.Alias())
{
#if CUB_PTX_ARCH > 0
for (int i = threadIdx.x; i < CUB_QUOTIENT_CEILING(NUM_AUX_MEM_BYTES, sizeof(LoadAliasT));
i += blockDim.x) {
reinterpret_cast<LoadAliasT*>(this->temp_storage.transitions)[i] =
reinterpret_cast<LoadAliasT*>(kernel_param.transitions)[i];
for (int i = threadIdx.x; i < MAX_NUM_STATES * MAX_NUM_SYMBOLS; i += blockDim.x) {
this->temp_storage.transitions[i] = kernel_param.transitions[i];
}
__syncthreads();
#else
for (int i = 0; i < kernel_param.num_luts; i++) {
for (int i = 0; i < MAX_NUM_STATES * MAX_NUM_SYMBOLS; i++) {
this->temp_storage.transitions[i] = kernel_param.transitions[i];
}
#endif
Expand All @@ -136,11 +87,21 @@ struct TransitionTable {
* @return
*/
template <typename StateIndexT, typename SymbolIndexT>
__host__ __device__ __forceinline__ int32_t operator()(StateIndexT state_id,
SymbolIndexT match_id) const
constexpr CUDF_HOST_DEVICE int32_t operator()(StateIndexT const state_id,
SymbolIndexT const match_id) const
{
return temp_storage.transitions[match_id * MAX_NUM_STATES + state_id];
}
}

private:
_TempStorage& temp_storage;

__device__ __forceinline__ _TempStorage& PrivateStorage()
{
__shared__ _TempStorage private_storage;

return private_storage;
}
};

} // namespace detail
Expand Down
Loading

0 comments on commit 9dfd4ad

Please sign in to comment.