Skip to content

Commit

Permalink
refactored lookup tables
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Jul 13, 2022
1 parent 355d1e4 commit 39a6b65
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 317 deletions.
3 changes: 0 additions & 3 deletions cpp/src/io/fst/agent_dfa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ class StateVectorTransitionOp : public StateTransitionCallbackOp {
__host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index,
SymbolIndexT const read_symbol_id) const
{
using TransitionVectorT = typename TransitionTableT::TransitionVectorT;

for (int32_t i = 0; i < NUM_INSTANCES; ++i) {
state_vector.Set(i, transition_table(state_vector.Get(i), read_symbol_id));
}
Expand Down Expand Up @@ -185,7 +183,6 @@ struct StateTransitionOp {
__host__ __device__ __forceinline__ void ReadSymbol(const CharIndexT& character_index,
const SymbolIndexT& read_symbol_id)
{
using TransitionVectorT = typename TransitionTableT::TransitionVectorT;
old_state_vector = state_vector;
state_vector.Set(0, transition_table(state_vector.Get(0), read_symbol_id));
callback_op.ReadSymbol(character_index, old_state_vector, state_vector, read_symbol_id);
Expand Down
192 changes: 89 additions & 103 deletions cpp/src/io/fst/device_dfa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
*/
#pragma once

#include "cub/util_type.cuh"
#include "dispatch_dfa.cuh"

#include <io/utilities/hostdevice_vector.hpp>
#include <src/io/fst/symbol_lut.cuh>
#include <src/io/fst/transition_table.cuh>
#include <src/io/fst/translation_table.cuh>
Expand Down Expand Up @@ -95,140 +96,121 @@ cudaError_t DeviceTransduce(void* d_temp_storage,
stream);
}

/**
* @brief Helper class to facilitate the specification and instantiation of a DFA (i.e., the
* transition table and its number of states, the mapping of symbols to symbol groups, and the
* translation table that specifies which state transitions cause which output to be written).
*
* @tparam OutSymbolT The symbol type being output by the finite-state transducer
* @tparam NUM_SYMBOLS The number of symbol groups amongst which to differentiate (one dimension of
* the transition table)
* @tparam TT_NUM_STATES The number of states defined by the DFA (the other dimension of the
* transition table)
*/
template <typename OutSymbolT, int32_t NUM_SYMBOLS, int32_t TT_NUM_STATES>
class Dfa {
template <typename SymbolGroupIdLookupT,
typename TransitionTableT,
typename TranslationTableT,
int32_t NUM_STATES>
class dfa_device_view {
private:
using sgid_lut_init_t = typename SymbolGroupIdLookupT::KernelParameter;
using transition_table_init_t = typename TransitionTableT::KernelParameter;
using translation_table_init_t = typename TranslationTableT::KernelParameter;

public:
// The maximum number of states supported by this DFA instance
// This is a value queried by the DFA simulation algorithm
static constexpr int32_t MAX_NUM_STATES = TT_NUM_STATES;
static constexpr int32_t MAX_NUM_STATES = NUM_STATES;

private:
// Symbol-group id lookup table
using MatcherT = detail::SingleSymbolSmemLUT<char>;
using MatcherInitT = typename MatcherT::KernelParameter;

// Transition table
using TransitionTableT = detail::TransitionTable<NUM_SYMBOLS + 1, TT_NUM_STATES>;
using TransitionTableInitT = typename TransitionTableT::KernelParameter;

// Translation lookup table
using OutSymbolOffsetT = uint32_t;
using TransducerTableT = detail::TransducerLookupTable<OutSymbolT,
OutSymbolOffsetT,
NUM_SYMBOLS + 1,
TT_NUM_STATES,
(NUM_SYMBOLS + 1) * TT_NUM_STATES>;
using TransducerTableInitT = typename TransducerTableT::KernelParameter;

// Private members (passed between host/device)
/// Information to initialize the device-side lookup table that maps symbol -> symbol group id
MatcherInitT symbol_matcher_init;

/// Information to initialize the device-side transition table
TransitionTableInitT tt_init;

/// Information to initialize the device-side translation table
TransducerTableInitT tt_out_init;

public:
//---------------------------------------------------------------------
// DEVICE-SIDE MEMBER FUNCTIONS
//---------------------------------------------------------------------
using SymbolGroupStorageT = typename MatcherT::TempStorage;
using SymbolGroupStorageT = typename SymbolGroupIdLookupT::TempStorage;
using TransitionTableStorageT = typename TransitionTableT::TempStorage;
using TranslationTableStorageT = typename TransducerTableT::TempStorage;
using TranslationTableStorageT = typename TranslationTableT::TempStorage;

__device__ auto InitSymbolGroupLUT(SymbolGroupStorageT& temp_storage)
{
return MatcherT(symbol_matcher_init, temp_storage);
return SymbolGroupIdLookupT(*d_sgid_lut_init, temp_storage);
}

__device__ auto InitTransitionTable(TransitionTableStorageT& temp_storage)
{
return TransitionTableT(tt_init, temp_storage);
return TransitionTableT(*d_transition_table_init, temp_storage);
}

__device__ auto InitTranslationTable(TranslationTableStorageT& temp_storage)
{
return TransducerTableT(tt_out_init, temp_storage);
return TranslationTableT(*d_translation_table_init, temp_storage);
}

//---------------------------------------------------------------------
// HOST-SIDE MEMBER FUNCTIONS
//---------------------------------------------------------------------
template <typename StateIdT, typename SymbolGroupIdItT>
cudaError_t Init(SymbolGroupIdItT const& symbol_vec,
std::vector<std::vector<StateIdT>> const& tt_vec,
std::vector<std::vector<std::vector<OutSymbolT>>> const& out_tt_vec,
cudaStream_t stream = 0)
dfa_device_view(sgid_lut_init_t const* d_sgid_lut_init,
transition_table_init_t const* d_transition_table_init,
translation_table_init_t const* d_translation_table_init)
: d_sgid_lut_init(d_sgid_lut_init),
d_transition_table_init(d_transition_table_init),
d_translation_table_init(d_translation_table_init)
{
cudaError_t error = cudaSuccess;

enum : uint32_t { MEM_SYMBOL_MATCHER = 0, MEM_TT, MEM_OUT_TT, NUM_ALLOCATIONS };
}

size_t allocation_sizes[NUM_ALLOCATIONS] = {0};
void* allocations[NUM_ALLOCATIONS] = {0};
private:
sgid_lut_init_t const* d_sgid_lut_init;
transition_table_init_t const* d_transition_table_init;
translation_table_init_t const* d_translation_table_init;
};

// Memory requirements: lookup table
error = MatcherT::PrepareLUT(
nullptr, allocation_sizes[MEM_SYMBOL_MATCHER], symbol_vec, symbol_matcher_init);
if (error) return error;
/**
* @brief Helper class to facilitate the specification and instantiation of a DFA (i.e., the
* transition table and its number of states, the mapping of symbols to symbol groups, and the
* translation table that specifies which state transitions cause which output to be written).
*
* @tparam OutSymbolT The symbol type being output by the finite-state transducer
* @tparam NUM_SYMBOLS The number of symbol groups amongst which to differentiate (one dimension of
* the transition table)
* @tparam NUM_STATES The number of states defined by the DFA (the other dimension of the
* transition table)
*/
template <typename OutSymbolT, int32_t NUM_SYMBOLS, int32_t NUM_STATES>
class Dfa {
public:
// The maximum number of states supported by this DFA instance
// This is a value queried by the DFA simulation algorithm
static constexpr int32_t MAX_NUM_STATES = NUM_STATES;

// Memory requirements: transition table
error =
TransitionTableT::CreateTransitionTable(nullptr, allocation_sizes[MEM_TT], tt_vec, tt_init);
if (error) return error;
private:
// Symbol-group id lookup table
using SymbolGroupIdLookupT = detail::SingleSymbolSmemLUT<char>;
using SymbolGroupIdInitT = typename SymbolGroupIdLookupT::KernelParameter;

// Memory requirements: transducer table
error = TransducerTableT::CreateTransitionTable(
nullptr, allocation_sizes[MEM_OUT_TT], out_tt_vec, tt_out_init);
if (error) return error;
// Transition table
using TransitionTableT = detail::TransitionTable<NUM_SYMBOLS + 1, NUM_STATES>;
using TransitionTableInitT = typename TransitionTableT::KernelParameter;

// Memory requirements: total memory
size_t temp_storage_bytes = 0;
error = cub::AliasTemporaries(nullptr, temp_storage_bytes, allocations, allocation_sizes);
if (error) return error;
// Translation lookup table
using OutSymbolOffsetT = uint32_t;
using TranslationTableT = detail::TransducerLookupTable<OutSymbolT,
OutSymbolOffsetT,
NUM_SYMBOLS + 1,
NUM_STATES,
(NUM_SYMBOLS + 1) * NUM_STATES>;
using TranslationTableInitT = typename TranslationTableT::KernelParameter;

auto get_device_view()
{
return dfa_device_view<SymbolGroupIdLookupT, TransitionTableT, TranslationTableT, NUM_STATES>{
sgid_init.d_begin(), transition_table_init.d_begin(), translation_table_init.d_begin()};
}

// Allocate memory
void* d_temp_storage = nullptr;
error = cudaMalloc(&d_temp_storage, temp_storage_bytes);
if (error) return error;
public:
template <typename StateIdT, typename SymbolGroupIdItT>
Dfa(SymbolGroupIdItT const& symbol_vec,
std::vector<std::vector<StateIdT>> const& tt_vec,
std::vector<std::vector<std::vector<OutSymbolT>>> const& out_tt_vec,
cudaStream_t stream)
{
constexpr std::size_t single_item = 1;

// Alias memory
error =
cub::AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes);
if (error) return error;
sgid_init = hostdevice_vector<SymbolGroupIdInitT>{single_item, stream};
transition_table_init = hostdevice_vector<TransitionTableInitT>{single_item, stream};
translation_table_init = hostdevice_vector<TranslationTableInitT>{single_item, stream};

// Initialize symbol group lookup table
error = MatcherT::PrepareLUT(allocations[MEM_SYMBOL_MATCHER],
allocation_sizes[MEM_SYMBOL_MATCHER],
symbol_vec,
symbol_matcher_init,
stream);
if (error) return error;
// Initialize symbol group id lookup table
SymbolGroupIdLookupT::InitDeviceSymbolGroupIdLut(sgid_init, symbol_vec, stream);

// Initialize state transition table
error = TransitionTableT::CreateTransitionTable(
allocations[MEM_TT], allocation_sizes[MEM_TT], tt_vec, tt_init, stream);
if (error) return error;
TransitionTableT::InitDeviceTransitionTable(transition_table_init, tt_vec, stream);

// Initialize finite-state transducer lookup table
error = TransducerTableT::CreateTransitionTable(
allocations[MEM_OUT_TT], allocation_sizes[MEM_OUT_TT], out_tt_vec, tt_out_init, stream);
if (error) return error;

return error;
TranslationTableT::InitDeviceTranslationTable(translation_table_init, out_tt_vec, stream);
}

template <typename SymbolT,
Expand All @@ -248,7 +230,7 @@ class Dfa {
{
return DeviceTransduce(d_temp_storage,
temp_storage_bytes,
*this,
this->get_device_view(),
d_chars,
num_chars,
d_out_it,
Expand All @@ -257,8 +239,12 @@ class Dfa {
seed_state,
stream);
}
};

private:
hostdevice_vector<SymbolGroupIdInitT> sgid_init{};
hostdevice_vector<TransitionTableInitT> transition_table_init{};
hostdevice_vector<TranslationTableInitT> translation_table_init{};
};
} // namespace fst
} // namespace io
} // namespace cudf
Loading

0 comments on commit 39a6b65

Please sign in to comment.