diff --git a/cpp/src/io/fst/symbol_lut.cuh b/cpp/src/io/fst/symbol_lut.cuh index 08d5f4db58d..abf71a7fbea 100644 --- a/cpp/src/io/fst/symbol_lut.cuh +++ b/cpp/src/io/fst/symbol_lut.cuh @@ -16,6 +16,9 @@ #pragma once +#include +#include + #include #include @@ -34,38 +37,29 @@ namespace detail { * @tparam SymbolT The symbol type being passed in to lookup the corresponding symbol group id */ template -struct SingleSymbolSmemLUT { - //------------------------------------------------------------------------------ - // DEFAULT TYPEDEFS - //------------------------------------------------------------------------------ +class SingleSymbolSmemLUT { + private: // Type used for representing a symbol group id (i.e., what we return for a given symbol) using SymbolGroupIdT = uint8_t; - //------------------------------------------------------------------------------ - // DERIVED CONFIGURATIONS - //------------------------------------------------------------------------------ /// Number of entries for every lookup (e.g., for 8-bit Symbol this is 256) static constexpr uint32_t NUM_ENTRIES_PER_LUT = 0x01U << (sizeof(SymbolT) * 8U); - //------------------------------------------------------------------------------ - // TYPEDEFS - //------------------------------------------------------------------------------ - struct _TempStorage { - // d_match_meta_data[symbol] -> symbol group index - SymbolGroupIdT match_meta_data[NUM_ENTRIES_PER_LUT]; + // sym_to_sgid[symbol] -> symbol group index + SymbolGroupIdT sym_to_sgid[NUM_ENTRIES_PER_LUT]; }; + public: struct KernelParameter { - // d_match_meta_data[min(symbol,num_valid_entries)] -> symbol group index - SymbolGroupIdT num_valid_entries; + // sym_to_sgid[min(symbol,num_valid_entries)] -> symbol group index + SymbolT num_valid_entries; - // d_match_meta_data[symbol] -> symbol group index - SymbolGroupIdT* d_match_meta_data; + // sym_to_sgid[symbol] -> symbol group index + SymbolGroupIdT sym_to_sgid[NUM_ENTRIES_PER_LUT]; }; - struct TempStorage : cub::Uninitialized<_TempStorage> { - }; + using TempStorage = cub::Uninitialized<_TempStorage>; //------------------------------------------------------------------------------ // HELPER METHODS @@ -73,66 +67,48 @@ struct SingleSymbolSmemLUT { /** * @brief * - * @param[in] d_temp_storage Device-side temporary storage that can be used to store the lookup - * table. If no storage is provided it will return the temporary storage requirements in \p - * d_temp_storage_bytes. - * @param[in,out] d_temp_storage_bytes Amount of device-side temporary storage that can be used in - * the number of bytes + * @param[out] sgid_init A hostdevice_vector that will be populated * @param[in] symbol_strings Array of strings, where the i-th string holds all symbols * (characters!) that correspond to the i-th symbol group index - * @param[out] kernel_param The kernel parameter object to be initialized with the given mapping - * of symbols to symbol group ids. * @param[in] stream The stream that shall be used to cudaMemcpyAsync the lookup table * @return */ template - __host__ __forceinline__ static cudaError_t PrepareLUT(void* d_temp_storage, - size_t& d_temp_storage_bytes, - SymbolGroupItT const& symbol_strings, - KernelParameter& kernel_param, - cudaStream_t stream = 0) + static void InitDeviceSymbolGroupIdLut(hostdevice_vector& sgid_init, + SymbolGroupItT const& symbol_strings, + rmm::cuda_stream_view stream) { // The symbol group index to be returned if none of the given symbols match SymbolGroupIdT no_match_id = symbol_strings.size(); - std::vector lut(NUM_ENTRIES_PER_LUT); + // The symbol with the largest value that is mapped to a symbol group id SymbolGroupIdT max_base_match_val = 0; // Initialize all entries: by default we return the no-match-id - for (uint32_t i = 0; i < NUM_ENTRIES_PER_LUT; ++i) { - lut[i] = no_match_id; - } + std::fill(&sgid_init.host_ptr()->sym_to_sgid[0], + &sgid_init.host_ptr()->sym_to_sgid[NUM_ENTRIES_PER_LUT], + no_match_id); // Set up lookup table uint32_t sg_id = 0; + // Iterate over the symbol groups for (auto const& sg_symbols : symbol_strings) { + // Iterate over all symbols that belong to the current symbol group for (auto const& sg_symbol : sg_symbols) { max_base_match_val = std::max(max_base_match_val, static_cast(sg_symbol)); - lut[sg_symbol] = sg_id; + sgid_init.host_ptr()->sym_to_sgid[static_cast(sg_symbol)] = sg_id; } sg_id++; } - // Initialize the out-of-bounds lookup: d_match_meta_data[max_base_match_val+1] -> no_match_id - lut[max_base_match_val + 1] = no_match_id; + // Initialize the out-of-bounds lookup: sym_to_sgid[max_base_match_val+1] -> no_match_id + sgid_init.host_ptr()->sym_to_sgid[max_base_match_val + 1] = no_match_id; // Alias memory / return memory requiremenets - kernel_param.num_valid_entries = max_base_match_val + 2; - if (d_temp_storage) { - cudaError_t error = cudaMemcpyAsync(d_temp_storage, - lut.data(), - kernel_param.num_valid_entries * sizeof(SymbolGroupIdT), - cudaMemcpyHostToDevice, - stream); - - kernel_param.d_match_meta_data = reinterpret_cast(d_temp_storage); - return error; - } else { - d_temp_storage_bytes = kernel_param.num_valid_entries * sizeof(SymbolGroupIdT); - return cudaSuccess; - } + // TODO I think this could be +1? + sgid_init.host_ptr()->num_valid_entries = max_base_match_val + 2; - return cudaSuccess; + sgid_init.host_to_device(stream); } //------------------------------------------------------------------------------ @@ -150,29 +126,29 @@ struct SingleSymbolSmemLUT { return private_storage; } - __host__ __device__ __forceinline__ SingleSymbolSmemLUT(KernelParameter const& kernel_param, - TempStorage& temp_storage) + constexpr CUDF_HOST_DEVICE SingleSymbolSmemLUT(KernelParameter const& kernel_param, + TempStorage& temp_storage) : temp_storage(temp_storage.Alias()), num_valid_entries(kernel_param.num_valid_entries) { // GPU-side init #if CUB_PTX_ARCH > 0 for (int32_t i = threadIdx.x; i < kernel_param.num_valid_entries; i += blockDim.x) { - this->temp_storage.match_meta_data[i] = kernel_param.d_match_meta_data[i]; + this->temp_storage.sym_to_sgid[i] = kernel_param.sym_to_sgid[i]; } __syncthreads(); #else // CPU-side init for (std::size_t i = 0; i < kernel_param.num_luts; i++) { - this->temp_storage.match_meta_data[i] = kernel_param.d_match_meta_data[i]; + this->temp_storage.sym_to_sgid[i] = kernel_param.sym_to_sgid[i]; } #endif } - __host__ __device__ __forceinline__ int32_t operator()(SymbolT const symbol) const + constexpr CUDF_HOST_DEVICE int32_t operator()(SymbolT const symbol) const { // Look up the symbol group for given symbol - return temp_storage.match_meta_data[min(symbol, num_valid_entries - 1)]; + return temp_storage.sym_to_sgid[min(symbol, num_valid_entries - 1)]; } }; diff --git a/cpp/src/io/fst/transition_table.cuh b/cpp/src/io/fst/transition_table.cuh index 97fef03d8af..5eccb926974 100644 --- a/cpp/src/io/fst/transition_table.cuh +++ b/cpp/src/io/fst/transition_table.cuh @@ -16,6 +16,10 @@ #pragma once +#include +#include +#include + #include #include @@ -25,103 +29,50 @@ namespace io { namespace fst { namespace detail { -template -struct TransitionTable { - //------------------------------------------------------------------------------ - // DEFAULT TYPEDEFS - //------------------------------------------------------------------------------ +template +class TransitionTable { + private: + // Type used using ItemT = char; - struct TransitionVectorWrapper { - const ItemT* data; - - __host__ __device__ TransitionVectorWrapper(const ItemT* data) : data(data) {} - - __host__ __device__ __forceinline__ uint32_t Get(int32_t index) const { return data[index]; } - }; - - //------------------------------------------------------------------------------ - // TYPEDEFS - //------------------------------------------------------------------------------ - using TransitionVectorT = TransitionVectorWrapper; - struct _TempStorage { - // ItemT transitions[MAX_NUM_STATES * MAX_NUM_SYMBOLS]; }; - struct TempStorage : cub::Uninitialized<_TempStorage> { - }; + public: + using TempStorage = cub::Uninitialized<_TempStorage>; struct KernelParameter { - ItemT* transitions; + ItemT transitions[MAX_NUM_STATES * MAX_NUM_SYMBOLS]; }; - using LoadAliasT = std::uint32_t; - - static constexpr std::size_t NUM_AUX_MEM_BYTES = - CUB_QUOTIENT_CEILING(MAX_NUM_STATES * MAX_NUM_SYMBOLS * sizeof(ItemT), sizeof(LoadAliasT)) * - sizeof(LoadAliasT); - - //------------------------------------------------------------------------------ - // HELPER METHODS - //------------------------------------------------------------------------------ - __host__ static cudaError_t CreateTransitionTable( - void* d_temp_storage, - size_t& temp_storage_bytes, - const std::vector>& trans_table, - KernelParameter& kernel_param, - cudaStream_t stream = 0) + static void InitDeviceTransitionTable(hostdevice_vector& transition_table_init, + const std::vector>& trans_table, + rmm::cuda_stream_view stream) { - if (!d_temp_storage) { - temp_storage_bytes = NUM_AUX_MEM_BYTES; - return cudaSuccess; - } - - // trans_vectors[symbol][state] -> new_state - ItemT trans_vectors[MAX_NUM_STATES * MAX_NUM_SYMBOLS]; - // trans_table[state][symbol] -> new state for (std::size_t state = 0; state < trans_table.size(); ++state) { for (std::size_t symbol = 0; symbol < trans_table[state].size(); ++symbol) { - trans_vectors[symbol * MAX_NUM_STATES + state] = trans_table[state][symbol]; + transition_table_init.host_ptr()->transitions[symbol * MAX_NUM_STATES + state] = + trans_table[state][symbol]; } } - kernel_param.transitions = static_cast(d_temp_storage); - // Copy transition table to device - return cudaMemcpyAsync( - d_temp_storage, trans_vectors, NUM_AUX_MEM_BYTES, cudaMemcpyHostToDevice, stream); + transition_table_init.host_to_device(stream); } - //------------------------------------------------------------------------------ - // MEMBER VARIABLES - //------------------------------------------------------------------------------ - _TempStorage& temp_storage; - - __device__ __forceinline__ _TempStorage& PrivateStorage() - { - __shared__ _TempStorage private_storage; - return private_storage; - } - - //------------------------------------------------------------------------------ - // CONSTRUCTOR - //------------------------------------------------------------------------------ - __host__ __device__ __forceinline__ TransitionTable(const KernelParameter& kernel_param, - TempStorage& temp_storage) + constexpr CUDF_HOST_DEVICE TransitionTable(const KernelParameter& kernel_param, + TempStorage& temp_storage) : temp_storage(temp_storage.Alias()) { #if CUB_PTX_ARCH > 0 - for (int i = threadIdx.x; i < CUB_QUOTIENT_CEILING(NUM_AUX_MEM_BYTES, sizeof(LoadAliasT)); - i += blockDim.x) { - reinterpret_cast(this->temp_storage.transitions)[i] = - reinterpret_cast(kernel_param.transitions)[i]; + for (int i = threadIdx.x; i < MAX_NUM_STATES * MAX_NUM_SYMBOLS; i += blockDim.x) { + this->temp_storage.transitions[i] = kernel_param.transitions[i]; } __syncthreads(); #else - for (int i = 0; i < kernel_param.num_luts; i++) { + for (int i = 0; i < MAX_NUM_STATES * MAX_NUM_SYMBOLS; i++) { this->temp_storage.transitions[i] = kernel_param.transitions[i]; } #endif @@ -136,11 +87,21 @@ struct TransitionTable { * @return */ template - __host__ __device__ __forceinline__ int32_t operator()(StateIndexT state_id, - SymbolIndexT match_id) const + constexpr CUDF_HOST_DEVICE int32_t operator()(StateIndexT const state_id, + SymbolIndexT const match_id) const { return temp_storage.transitions[match_id * MAX_NUM_STATES + state_id]; - } + } + + private: + _TempStorage& temp_storage; + + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + + return private_storage; + } }; } // namespace detail diff --git a/cpp/src/io/fst/translation_table.cuh b/cpp/src/io/fst/translation_table.cuh index bfbfd41e3f0..89da994606c 100644 --- a/cpp/src/io/fst/translation_table.cuh +++ b/cpp/src/io/fst/translation_table.cuh @@ -16,7 +16,12 @@ #pragma once -#include "in_reg_array.cuh" +#include +#include +#include +#include + +#include "rmm/device_uvector.hpp" #include @@ -28,10 +33,10 @@ namespace fst { namespace detail { /** - * @brief Lookup table mapping (old_state, symbol_group_id) transitions to a sequence of symbols to - * output + * @brief Lookup table mapping (old_state, symbol_group_id) transitions to a sequence of symbols + * that the finite-state transducer is supposed to output for each transition * - * @tparam OutSymbolT The symbol type being returned + * @tparam OutSymbolT The symbol type being output * @tparam OutSymbolOffsetT Type sufficiently large to index into the lookup table of output symbols * @tparam MAX_NUM_SYMBOLS The maximum number of symbols being output by a single state transition * @tparam MAX_NUM_STATES The maximum number of states that this lookup table shall support @@ -42,57 +47,35 @@ template -struct TransducerLookupTable { - //------------------------------------------------------------------------------ - // TYPEDEFS - //------------------------------------------------------------------------------ +class TransducerLookupTable { + private: struct _TempStorage { OutSymbolOffsetT out_offset[MAX_NUM_STATES * MAX_NUM_SYMBOLS + 1]; OutSymbolT out_symbols[MAX_TABLE_SIZE]; }; - struct TempStorage : cub::Uninitialized<_TempStorage> { - }; + public: + using TempStorage = cub::Uninitialized<_TempStorage>; struct KernelParameter { - OutSymbolOffsetT* d_trans_offsets; - OutSymbolT* d_out_symbols; + OutSymbolOffsetT d_out_offsets[MAX_NUM_STATES * MAX_NUM_SYMBOLS + 1]; + OutSymbolT d_out_symbols[MAX_TABLE_SIZE]; }; - //------------------------------------------------------------------------------ - // HELPER METHODS - //------------------------------------------------------------------------------ - __host__ static cudaError_t CreateTransitionTable( - void* d_temp_storage, - size_t& temp_storage_bytes, - const std::vector>>& trans_table, - KernelParameter& kernel_param, - cudaStream_t stream = 0) + /** + * @brief Initializes the translation table (both the host and device parts) + */ + static void InitDeviceTranslationTable( + hostdevice_vector& translation_table_init, + std::vector>> const& trans_table, + rmm::cuda_stream_view stream) { - enum { MEM_OFFSETS = 0, MEM_OUT_SYMBOLS, NUM_ALLOCATIONS }; - - size_t allocation_sizes[NUM_ALLOCATIONS] = {}; - void* allocations[NUM_ALLOCATIONS] = {}; - allocation_sizes[MEM_OFFSETS] = - (MAX_NUM_STATES * MAX_NUM_SYMBOLS + 1) * sizeof(OutSymbolOffsetT); - allocation_sizes[MEM_OUT_SYMBOLS] = MAX_TABLE_SIZE * sizeof(OutSymbolT); - - // Alias the temporary allocations from the single storage blob (or compute the necessary size - // of the blob) - cudaError_t error = - cub::AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes); - if (error) return error; - - // Return if the caller is simply requesting the size of the storage allocation - if (d_temp_storage == nullptr) return cudaSuccess; - std::vector out_symbols; out_symbols.reserve(MAX_TABLE_SIZE); std::vector out_symbol_offsets; out_symbol_offsets.reserve(MAX_NUM_STATES * MAX_NUM_SYMBOLS + 1); out_symbol_offsets.push_back(0); - int st = 0; // Iterate over the states in the transition table for (auto const& state_trans : trans_table) { uint32_t num_added = 0; @@ -103,7 +86,6 @@ struct TransducerLookupTable { out_symbol_offsets.push_back(out_symbols.size()); num_added++; } - st++; // Copy the last offset for all symbols (to guarantee a proper lookup for omitted symbols of // this state) @@ -115,30 +97,21 @@ struct TransducerLookupTable { } // Check whether runtime-provided table size exceeds the compile-time given max. table size - if (out_symbols.size() > MAX_TABLE_SIZE) { return cudaErrorInvalidValue; } - - kernel_param.d_trans_offsets = static_cast(allocations[MEM_OFFSETS]); - kernel_param.d_out_symbols = static_cast(allocations[MEM_OUT_SYMBOLS]); - - // Copy out symbols - error = cudaMemcpyAsync(kernel_param.d_trans_offsets, - out_symbol_offsets.data(), - out_symbol_offsets.size() * sizeof(out_symbol_offsets[0]), - cudaMemcpyHostToDevice, - stream); - if (error) { return error; } - - // Copy offsets into output symbols - return cudaMemcpyAsync(kernel_param.d_out_symbols, - out_symbols.data(), - out_symbols.size() * sizeof(out_symbols[0]), - cudaMemcpyHostToDevice, - stream); + if (out_symbols.size() > MAX_TABLE_SIZE) { CUDF_FAIL("Unsupported translation table"); } + + // Prepare host-side data to be copied and passed to the device + std::copy(std::cbegin(out_symbol_offsets), + std::cend(out_symbol_offsets), + translation_table_init.host_ptr()->d_out_offsets); + std::copy(std::cbegin(out_symbols), + std::cend(out_symbols), + translation_table_init.host_ptr()->d_out_symbols); + + // Copy data to device + translation_table_init.host_to_device(stream); } - //------------------------------------------------------------------------------ - // MEMBER VARIABLES - //------------------------------------------------------------------------------ + private: _TempStorage& temp_storage; __device__ __forceinline__ _TempStorage& PrivateStorage() @@ -147,17 +120,19 @@ struct TransducerLookupTable { return private_storage; } - //------------------------------------------------------------------------------ - // CONSTRUCTOR - //------------------------------------------------------------------------------ - __host__ __device__ __forceinline__ TransducerLookupTable(const KernelParameter& kernel_param, - TempStorage& temp_storage) + public: + /** + * @brief Synchronizes the thread block, if called from device, and, hence, requires all threads + * of the thread block to call the constructor + */ + CUDF_HOST_DEVICE TransducerLookupTable(KernelParameter const& kernel_param, + TempStorage& temp_storage) : temp_storage(temp_storage.Alias()) { constexpr uint32_t num_offsets = MAX_NUM_STATES * MAX_NUM_SYMBOLS + 1; #if CUB_PTX_ARCH > 0 for (int i = threadIdx.x; i < num_offsets; i += blockDim.x) { - this->temp_storage.out_offset[i] = kernel_param.d_trans_offsets[i]; + this->temp_storage.out_offset[i] = kernel_param.d_out_offsets[i]; } // Make sure all threads in the block can read out_symbol_offsets[num_offsets - 1] from shared // memory @@ -168,7 +143,7 @@ struct TransducerLookupTable { __syncthreads(); #else for (int i = 0; i < num_offsets; i++) { - this->temp_storage.out_symbol_offsets[i] = kernel_param.d_trans_offsets[i]; + this->temp_storage.out_symbol_offsets[i] = kernel_param.d_out_offsets[i]; } for (int i = 0; i < this->temp_storage.out_symbol_offsets[i]; i++) { this->temp_storage.out_symbols[i] = kernel_param.d_out_symbols[i]; @@ -177,17 +152,17 @@ struct TransducerLookupTable { } template - __host__ __device__ __forceinline__ OutSymbolT operator()(StateIndexT state_id, - SymbolIndexT match_id, - RelativeOffsetT relative_offset) const + constexpr CUDF_HOST_DEVICE OutSymbolT operator()(StateIndexT const state_id, + SymbolIndexT const match_id, + RelativeOffsetT const relative_offset) const { auto offset = temp_storage.out_offset[state_id * MAX_NUM_SYMBOLS + match_id] + relative_offset; return temp_storage.out_symbols[offset]; } template - __host__ __device__ __forceinline__ OutSymbolOffsetT operator()(StateIndexT state_id, - SymbolIndexT match_id) const + constexpr CUDF_HOST_DEVICE OutSymbolOffsetT operator()(StateIndexT const state_id, + SymbolIndexT const match_id) const { return temp_storage.out_offset[state_id * MAX_NUM_SYMBOLS + match_id + 1] - temp_storage.out_offset[state_id * MAX_NUM_SYMBOLS + match_id];