Skip to content

Commit

Permalink
Converts remaining west-const declarations to east-const & unifies StateIndexT
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Jul 13, 2022
1 parent 5f1c4b5 commit e6f8def
Showing 1 changed file with 53 additions and 54 deletions.
107 changes: 53 additions & 54 deletions cpp/src/io/fst/agent_dfa.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

namespace cudf::io::fst::detail {

/// Type used to enumerate (and index into) the states defined by a DFA
using StateIndexT = uint32_t;

//-----------------------------------------------------------------------------
// DFA-SIMULATION STATE COMPOSITION FUNCTORS
//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -105,7 +108,7 @@ template <int32_t NUM_INSTANCES, typename TransitionTableT>
class StateVectorTransitionOp {
public:
__host__ __device__ __forceinline__ StateVectorTransitionOp(
TransitionTableT const& transition_table, std::array<int32_t, NUM_INSTANCES>& state_vector)
TransitionTableT const& transition_table, std::array<StateIndexT, NUM_INSTANCES>& state_vector)
: transition_table(transition_table), state_vector(state_vector)
{
}
Expand All @@ -120,18 +123,18 @@ class StateVectorTransitionOp {
}

public:
std::array<int32_t, NUM_INSTANCES>& state_vector;
const TransitionTableT& transition_table;
std::array<StateIndexT, NUM_INSTANCES>& state_vector;
TransitionTableT const& transition_table;
};

template <typename CallbackOpT, typename TransitionTableT>
struct StateTransitionOp {
int32_t state;
const TransitionTableT& transition_table;
StateIndexT state;
TransitionTableT const& transition_table;
CallbackOpT& callback_op;

__host__ __device__ __forceinline__ StateTransitionOp(TransitionTableT const& transition_table,
int32_t state,
StateIndexT state,
CallbackOpT& callback_op)
: transition_table(transition_table), state(state), callback_op(callback_op)
{
Expand All @@ -142,7 +145,7 @@ struct StateTransitionOp {
SymbolIndexT const& read_symbol_id)
{
// Remember what state we were in before we made the transition
int32_t previous_state = state;
StateIndexT previous_state = state;

state = transition_table(state, read_symbol_id);
callback_op.ReadSymbol(character_index, previous_state, state, read_symbol_id);
Expand All @@ -152,7 +155,6 @@ struct StateTransitionOp {
template <typename AgentDFAPolicy, typename SymbolItT, typename OffsetT>
struct AgentDFA {
using SymbolIndexT = uint32_t;
using StateIndexT = uint32_t;
using AliasedLoadT = uint32_t;
using CharT = typename std::iterator_traits<SymbolItT>::value_type;

Expand Down Expand Up @@ -200,22 +202,21 @@ struct AgentDFA {
{
}

template <int32_t NUM_SYMBOLS, // The net (excluding overlap) number of characters to be parsed
typename SymbolMatcherT, // The symbol matcher returning the matched symbol and its
// length
typename CallbackOpT, // Callback operator
template <int32_t NUM_SYMBOLS,
typename SymbolMatcherT,
typename CallbackOpT,
int32_t IS_FULL_BLOCK>
__device__ __forceinline__ static void ThreadParse(const SymbolMatcherT& symbol_matcher,
const CharT* chars,
const SymbolIndexT& max_num_chars,
__device__ __forceinline__ static void ThreadParse(SymbolMatcherT const& symbol_matcher,
CharT const* chars,
SymbolIndexT const& max_num_chars,
CallbackOpT callback_op,
cub::Int2Type<IS_FULL_BLOCK> /*IS_FULL_BLOCK*/)
{
// Iterate over symbols
#pragma unroll
for (int32_t i = 0; i < NUM_SYMBOLS; ++i) {
if (IS_FULL_BLOCK || threadIdx.x * SYMBOLS_PER_THREAD + i < max_num_chars) {
uint32_t matched_id = symbol_matcher(chars[i]);
auto matched_id = symbol_matcher(chars[i]);
callback_op.ReadSymbol(i, matched_id);
}
}
Expand All @@ -226,9 +227,9 @@ struct AgentDFA {
typename StateTransitionOpT,
int32_t IS_FULL_BLOCK>
__device__ __forceinline__ void GetThreadStateTransitions(
const SymbolMatcherT& symbol_matcher,
const CharT* chars,
const SymbolIndexT& max_num_chars,
SymbolMatcherT const& symbol_matcher,
CharT const* chars,
SymbolIndexT const& max_num_chars,
StateTransitionOpT& state_transition_op,
cub::Int2Type<IS_FULL_BLOCK> /*IS_FULL_BLOCK*/)
{
Expand All @@ -239,15 +240,15 @@ struct AgentDFA {
//---------------------------------------------------------------------
// LOADING FULL BLOCK OF CHARACTERS, NON-ALIASED
//---------------------------------------------------------------------
__device__ __forceinline__ void LoadBlock(const CharT* d_chars,
const OffsetT block_offset,
const OffsetT num_total_symbols,
__device__ __forceinline__ void LoadBlock(CharT const* d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
cub::Int2Type<true> /*IS_FULL_BLOCK*/,
cub::Int2Type<1> /*ALIGNMENT*/)
{
CharT thread_chars[SYMBOLS_PER_THREAD];

const CharT* d_block_symbols = d_chars + block_offset;
CharT const* d_block_symbols = d_chars + block_offset;
cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_block_symbols, thread_chars);

#pragma unroll
Expand All @@ -259,9 +260,9 @@ struct AgentDFA {
//---------------------------------------------------------------------
// LOADING PARTIAL BLOCK OF CHARACTERS, NON-ALIASED
//---------------------------------------------------------------------
__device__ __forceinline__ void LoadBlock(const CharT* d_chars,
const OffsetT block_offset,
const OffsetT num_total_symbols,
__device__ __forceinline__ void LoadBlock(CharT const* d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
cub::Int2Type<false> /*IS_FULL_BLOCK*/,
cub::Int2Type<1> /*ALIGNMENT*/)
{
Expand All @@ -272,7 +273,7 @@ struct AgentDFA {
// Last unit to be loaded is IDIV_CEIL(#SYM, SYMBOLS_PER_UNIT)
OffsetT num_total_chars = num_total_symbols - block_offset;

const CharT* d_block_symbols = d_chars + block_offset;
CharT const* d_block_symbols = d_chars + block_offset;
cub::LoadDirectStriped<BLOCK_THREADS>(
threadIdx.x, d_block_symbols, thread_chars, num_total_chars);

Expand All @@ -285,16 +286,16 @@ struct AgentDFA {
//---------------------------------------------------------------------
// LOADING FULL BLOCK OF CHARACTERS, ALIASED
//---------------------------------------------------------------------
__device__ __forceinline__ void LoadBlock(const CharT* d_chars,
const OffsetT block_offset,
const OffsetT num_total_symbols,
__device__ __forceinline__ void LoadBlock(CharT const* d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
cub::Int2Type<true> /*IS_FULL_BLOCK*/,
cub::Int2Type<sizeof(AliasedLoadT)> /*ALIGNMENT*/)
{
AliasedLoadT thread_units[UINTS_PER_THREAD];

const AliasedLoadT* d_block_symbols =
reinterpret_cast<const AliasedLoadT*>(d_chars + block_offset);
AliasedLoadT const* d_block_symbols =
reinterpret_cast<AliasedLoadT const*>(d_chars + block_offset);
cub::LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, d_block_symbols, thread_units);

#pragma unroll
Expand All @@ -306,9 +307,9 @@ struct AgentDFA {
//---------------------------------------------------------------------
// LOADING PARTIAL BLOCK OF CHARACTERS, ALIASED
//---------------------------------------------------------------------
__device__ __forceinline__ void LoadBlock(const CharT* d_chars,
const OffsetT block_offset,
const OffsetT num_total_symbols,
__device__ __forceinline__ void LoadBlock(CharT const* d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
cub::Int2Type<false> /*IS_FULL_BLOCK*/,
cub::Int2Type<sizeof(AliasedLoadT)> /*ALIGNMENT*/)
{
Expand All @@ -320,8 +321,8 @@ struct AgentDFA {
OffsetT num_total_units =
CUB_QUOTIENT_CEILING(num_total_symbols - block_offset, sizeof(AliasedLoadT));

const AliasedLoadT* d_block_symbols =
reinterpret_cast<const AliasedLoadT*>(d_chars + block_offset);
AliasedLoadT const* d_block_symbols =
reinterpret_cast<AliasedLoadT const*>(d_chars + block_offset);
cub::LoadDirectStriped<BLOCK_THREADS>(
threadIdx.x, d_block_symbols, thread_units, num_total_units);

Expand All @@ -334,9 +335,9 @@ struct AgentDFA {
//---------------------------------------------------------------------
// LOADING BLOCK OF CHARACTERS: DISPATCHER
//---------------------------------------------------------------------
__device__ __forceinline__ void LoadBlock(const CharT* d_chars,
const OffsetT block_offset,
const OffsetT num_total_symbols)
__device__ __forceinline__ void LoadBlock(CharT const* d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols)
{
// Check if pointer is aligned to four bytes
if (((uintptr_t)(const void*)(d_chars + block_offset) % 4) == 0) {
Expand All @@ -360,12 +361,12 @@ struct AgentDFA {

template <int32_t NUM_STATES, typename SymbolMatcherT, typename TransitionTableT>
__device__ __forceinline__ void GetThreadStateTransitionVector(
const SymbolMatcherT& symbol_matcher,
const TransitionTableT& transition_table,
const CharT* d_chars,
const OffsetT block_offset,
const OffsetT num_total_symbols,
std::array<int32_t, NUM_STATES>& state_vector)
SymbolMatcherT const& symbol_matcher,
TransitionTableT const& transition_table,
CharT const* d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
std::array<StateIndexT, NUM_STATES>& state_vector)
{
using StateVectorTransitionOpT = StateVectorTransitionOp<NUM_STATES, TransitionTableT>;

Expand Down Expand Up @@ -405,7 +406,7 @@ struct AgentDFA {
CharT const* d_chars,
OffsetT const block_offset,
OffsetT const num_total_symbols,
int32_t& state,
StateIndexT& state,
CallbackOpT& callback_op,
cub::Int2Type<BYPASS_LOAD> /**/)
{
Expand Down Expand Up @@ -459,19 +460,17 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
void SimulateDFAKernel(DfaT dfa,
SymbolItT d_chars,
OffsetT const num_chars,
uint32_t seed_state,
StateIndexT seed_state,
StateVectorT* __restrict__ d_thread_state_transition,
TileStateT tile_state,
OutOffsetScanTileState offset_tile_state,
TransducedOutItT transduced_out_it,
TransducedIndexOutItT transduced_out_idx_it,
TransducedCountOutItT d_num_transduced_out_it)
{
using StateIndexT = uint32_t;

using AgentDfaSimT = AgentDFA<AgentDFAPolicy, SymbolItT, OffsetT>;

static constexpr uint32_t NUM_STATES = DfaT::MAX_NUM_STATES;
static constexpr int32_t NUM_STATES = DfaT::MAX_NUM_STATES;

enum {
BLOCK_THREADS = AgentDFAPolicy::BLOCK_THREADS,
Expand Down Expand Up @@ -509,7 +508,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
// Stage 1: Compute the state-transition vector
if (IS_TRANS_VECTOR_PASS || IS_SINGLE_PASS) {
// Keeping track of the state for each of the <NUM_STATES> state machines
std::array<int32_t, NUM_STATES> state_vector;
std::array<StateIndexT, NUM_STATES> state_vector;

// Initialize the seed state transition vector with the identity vector
#pragma unroll
Expand Down Expand Up @@ -539,7 +538,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__

// Stage 2: Perform FSM simulation
if ((!IS_TRANS_VECTOR_PASS) || IS_SINGLE_PASS) {
int32_t state = 0;
StateIndexT state = 0;

//------------------------------------------------------------------------------
// SINGLE-PASS:
Expand Down Expand Up @@ -599,7 +598,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) __global__
TransducedIndexOutItT>
callback_wrapper(transducer_table, transduced_out_it, transduced_out_idx_it);

int32_t t_start_state = state;
StateIndexT t_start_state = state;
agent_dfa.GetThreadStateTransitions(symbol_matcher,
transition_table,
d_chars,
Expand Down

0 comments on commit e6f8def

Please sign in to comment.