diff --git a/cpp/benchmarks/io/fst.cu b/cpp/benchmarks/io/fst.cu index ad19bdfdfcb..31f1bf8e70f 100644 --- a/cpp/benchmarks/io/fst.cu +++ b/cpp/benchmarks/io/fst.cu @@ -95,7 +95,9 @@ void BM_FST_JSON(nvbench::state& state) auto parser = cudf::io::fst::detail::make_fst( cudf::io::fst::detail::make_symbol_group_lut(pda_sgs), cudf::io::fst::detail::make_transition_table(pda_state_tt), - cudf::io::fst::detail::make_translation_table(pda_out_tt), + cudf::io::fst::detail::make_translation_table(pda_out_tt), stream); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); @@ -134,7 +136,9 @@ void BM_FST_JSON_no_outidx(nvbench::state& state) auto parser = cudf::io::fst::detail::make_fst( cudf::io::fst::detail::make_symbol_group_lut(pda_sgs), cudf::io::fst::detail::make_transition_table(pda_state_tt), - cudf::io::fst::detail::make_translation_table(pda_out_tt), + cudf::io::fst::detail::make_translation_table(pda_out_tt), stream); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); @@ -171,7 +175,9 @@ void BM_FST_JSON_no_out(nvbench::state& state) auto parser = cudf::io::fst::detail::make_fst( cudf::io::fst::detail::make_symbol_group_lut(pda_sgs), cudf::io::fst::detail::make_transition_table(pda_state_tt), - cudf::io::fst::detail::make_translation_table(pda_out_tt), + cudf::io::fst::detail::make_translation_table(pda_out_tt), stream); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); @@ -209,7 +215,9 @@ void BM_FST_JSON_no_str(nvbench::state& state) auto parser = cudf::io::fst::detail::make_fst( cudf::io::fst::detail::make_symbol_group_lut(pda_sgs), cudf::io::fst::detail::make_transition_table(pda_state_tt), - cudf::io::fst::detail::make_translation_table(pda_out_tt), + cudf::io::fst::detail::make_translation_table(pda_out_tt), stream); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); diff --git a/cpp/src/io/fst/agent_dfa.cuh b/cpp/src/io/fst/agent_dfa.cuh index 2171764decd..bc5b94e2718 100644 --- a/cpp/src/io/fst/agent_dfa.cuh +++ b/cpp/src/io/fst/agent_dfa.cuh @@ -18,7 +18,9 @@ #include "in_reg_array.cuh" #include +#include #include +#include #include namespace cudf::io::fst::detail { @@ -44,9 +46,10 @@ using StateIndexT = uint32_t; template struct VectorCompositeOp { template - __host__ __device__ __forceinline__ VectorT operator()(VectorT const& lhs, VectorT const& rhs) + __device__ __forceinline__ VectorT operator()(VectorT const& lhs, VectorT const& rhs) { VectorT res{}; +#pragma unroll for (int32_t i = 0; i < NUM_ITEMS; ++i) { res.Set(i, rhs.Get(lhs.Get(i))); } @@ -57,61 +60,275 @@ struct VectorCompositeOp { /** * @brief A class whose ReadSymbol member function is invoked for each symbol being read from the * input tape. The wrapper class looks up whether a state transition caused by a symbol is supposed - * to emit any output symbol (the "transduced" output) and, if so, keeps track of how many symbols - * it intends to write out and writing out such symbols to the given output iterators. + * to emit any output symbol (the "transduced" output) and, if so, keeps track of *how many* symbols + * it intends to write out. + */ +template +class DFACountCallbackWrapper { + public: + __device__ __forceinline__ DFACountCallbackWrapper(TransducerTableT transducer_table) + : transducer_table(transducer_table) + { + } + + template + __device__ __forceinline__ void Init(OffsetT const&) + { + out_count = 0; + } + + template + __device__ __forceinline__ void ReadSymbol(CharIndexT const character_index, + StateIndexT const old_state, + StateIndexT const new_state, + SymbolIndexT const symbol_id, + SymbolT const read_symbol) + { + uint32_t const count = transducer_table(old_state, symbol_id, read_symbol); + out_count += count; + } + + __device__ __forceinline__ void TearDown() {} + TransducerTableT const transducer_table; + uint32_t out_count{}; +}; + +/** + * @brief A class whose ReadSymbol member function is invoked for each symbol being read from the + * input tape. The wrapper class looks up whether a state transition caused by a symbol is supposed + * to emit any output symbol (the "transduced" output) and, if so, writes out such symbols to the + * given output iterators. * + * @tparam MaxTranslatedOutChars The maximum number of symbols that are written on a any given state + * transition * @tparam TransducerTableT The type implementing a transducer table that can be used for looking up * the symbols that are supposed to be emitted on a given state transition. - * @tparam TransducedOutItT A Random-access output iterator type to which symbols returned by the + * @tparam TransducedOutItT A random-access output iterator type to which symbols returned by the * transducer table are assignable. - * @tparam TransducedIndexOutItT A Random-access output iterator type to which indexes are written. + * @tparam TransducedIndexOutItT A random-access output iterator type to which indexes are written. */ -template -class DFASimulationCallbackWrapper { +template +class DFAWriteCallbackWrapper { public: - __host__ __device__ __forceinline__ DFASimulationCallbackWrapper( - TransducerTableT transducer_table, TransducedOutItT out_it, TransducedIndexOutItT out_idx_it) - : transducer_table(transducer_table), out_it(out_it), out_idx_it(out_idx_it), write(false) + __device__ __forceinline__ DFAWriteCallbackWrapper(TransducerTableT transducer_table, + TransducedOutItT out_it, + TransducedIndexOutItT out_idx_it, + uint32_t out_offset, + uint32_t /*tile_out_offset*/, + uint32_t /*tile_in_offset*/, + uint32_t /*tile_out_count*/) + : transducer_table(transducer_table), + out_it(out_it), + out_idx_it(out_idx_it), + out_offset(out_offset) { } template - __host__ __device__ __forceinline__ void Init(OffsetT const& offset) + __device__ __forceinline__ void Init(OffsetT const& in_offset) + { + this->in_offset = in_offset; + } + + template + __device__ __forceinline__ + typename ::cuda::std::enable_if<(MaxTranslatedOutChars_ <= 2), void>::type + ReadSymbol(CharIndexT const character_index, + StateIndexT const old_state, + StateIndexT const new_state, + SymbolIndexT const symbol_id, + SymbolT const read_symbol, + cub::Int2Type /*MaxTranslatedOutChars*/) + { + uint32_t const count = transducer_table(old_state, symbol_id, read_symbol); + +#pragma unroll + for (uint32_t out_char = 0; out_char < MaxTranslatedOutChars_; out_char++) { + if (out_char < count) { + out_it[out_offset + out_char] = + transducer_table(old_state, symbol_id, out_char, read_symbol); + out_idx_it[out_offset + out_char] = in_offset + character_index; + } + } + out_offset += count; + } + + template + __device__ __forceinline__ + typename ::cuda::std::enable_if<(MaxTranslatedOutChars_ > 2), void>::type + ReadSymbol(CharIndexT const character_index, + StateIndexT const old_state, + StateIndexT const new_state, + SymbolIndexT const symbol_id, + SymbolT const read_symbol, + cub::Int2Type) { - this->offset = offset; - if (!write) out_count = 0; + uint32_t const count = transducer_table(old_state, symbol_id, read_symbol); + + for (uint32_t out_char = 0; out_char < count; out_char++) { + out_it[out_offset + out_char] = transducer_table(old_state, symbol_id, out_char, read_symbol); + out_idx_it[out_offset + out_char] = in_offset + character_index; + } + out_offset += count; } template - __host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const character_index, - StateIndexT const old_state, - StateIndexT const new_state, - SymbolIndexT const symbol_id, - SymbolT const read_symbol) + __device__ __forceinline__ void ReadSymbol(CharIndexT const character_index, + StateIndexT const old_state, + StateIndexT const new_state, + SymbolIndexT const symbol_id, + SymbolT const read_symbol) + { + ReadSymbol(character_index, + old_state, + new_state, + symbol_id, + read_symbol, + cub::Int2Type{}); + } + + __device__ __forceinline__ void TearDown() {} + + public: + TransducerTableT const transducer_table; + TransducedOutItT out_it; + TransducedIndexOutItT out_idx_it; + uint32_t out_offset; + uint32_t in_offset; +}; + +/** + * @brief A class whose ReadSymbol member function is invoked for each symbol being read from the + * input tape. The wrapper class looks up whether a state transition caused by a symbol is supposed + * to emit any output symbol (the "transduced" output) and, if so, writes out such symbols to the + * given output iterators. This class uses a shared memory-backed write buffer to coalesce writes to + * global memory. + * + * @tparam DiscardIndexOutput Whether to discard the indexes instead of writing them to the given + * output iterator + * @tparam DiscardTranslatedOutput Whether to discard the translated output symbols instead of + * writing them to the given output iterator + * @tparam NumWriteBufferItems The number of items to allocate in shared memory for the write + * buffer. + * @tparam OutputT The type of the translated items + * @tparam TransducerTableT The type implementing a transducer table that can be used for looking up + * the symbols that are supposed to be emitted on a given state transition. + * @tparam TransducedOutItT A random-access output iterator type to which symbols returned by the + * transducer table are assignable. + * @tparam TransducedIndexOutItT A random-access output iterator type to which indexes are written. + */ +template +class WriteCoalescingCallbackWrapper { + struct TempStorage_Offsets { + uint16_t compacted_offset[NumWriteBufferItems]; + }; + struct TempStorage_Symbols { + OutputT compacted_symbols[NumWriteBufferItems]; + }; + using offset_cache_t = + ::cuda::std::conditional_t; + using symbol_cache_t = ::cuda::std:: + conditional_t, TempStorage_Symbols>; + struct TempStorage_ : offset_cache_t, symbol_cache_t {}; + + __device__ __forceinline__ TempStorage_& PrivateStorage() + { + __shared__ TempStorage private_storage; + return private_storage.Alias(); + } + TempStorage_& temp_storage; + + public: + struct TempStorage : cub::Uninitialized {}; + + __device__ __forceinline__ WriteCoalescingCallbackWrapper(TransducerTableT transducer_table, + TransducedOutItT out_it, + TransducedIndexOutItT out_idx_it, + uint32_t thread_out_offset, + uint32_t tile_out_offset, + uint32_t tile_in_offset, + uint32_t tile_out_count) + : temp_storage(PrivateStorage()), + transducer_table(transducer_table), + out_it(out_it), + out_idx_it(out_idx_it), + thread_out_offset(thread_out_offset), + tile_out_offset(tile_out_offset), + tile_in_offset(tile_in_offset), + tile_out_count(tile_out_count) + { + } + + template + __device__ __forceinline__ void Init(OffsetT const& offset) + { + this->in_offset = offset; + } + + template + __device__ __forceinline__ void ReadSymbol(CharIndexT const character_index, + StateIndexT const old_state, + StateIndexT const new_state, + SymbolIndexT const symbol_id, + SymbolT const read_symbol) { uint32_t const count = transducer_table(old_state, symbol_id, read_symbol); - if (write) { -#if defined(__CUDA_ARCH__) -#pragma unroll 1 -#endif - for (uint32_t out_char = 0; out_char < count; out_char++) { - out_it[out_count + out_char] = + for (uint32_t out_char = 0; out_char < count; out_char++) { + if constexpr (!DiscardIndexOutput) { + temp_storage.compacted_offset[thread_out_offset + out_char - tile_out_offset] = + in_offset + character_index - tile_in_offset; + } + if constexpr (!DiscardTranslatedOutput) { + temp_storage.compacted_symbols[thread_out_offset + out_char - tile_out_offset] = transducer_table(old_state, symbol_id, out_char, read_symbol); - out_idx_it[out_count + out_char] = offset + character_index; } } - out_count += count; + thread_out_offset += count; } - __host__ __device__ __forceinline__ void TearDown() {} + __device__ __forceinline__ void TearDown() + { + __syncthreads(); + if constexpr (!DiscardTranslatedOutput) { + for (uint32_t out_char = threadIdx.x; out_char < tile_out_count; out_char += blockDim.x) { + out_it[tile_out_offset + out_char] = temp_storage.compacted_symbols[out_char]; + } + } + if constexpr (!DiscardIndexOutput) { + for (uint32_t out_char = threadIdx.x; out_char < tile_out_count; out_char += blockDim.x) { + out_idx_it[tile_out_offset + out_char] = + temp_storage.compacted_offset[out_char] + tile_in_offset; + } + } + __syncthreads(); + } public: TransducerTableT const transducer_table; TransducedOutItT out_it; TransducedIndexOutItT out_idx_it; - uint32_t out_count; - uint32_t offset; - bool write; + uint32_t thread_out_offset; + uint32_t tile_out_offset; + uint32_t tile_in_offset; + uint32_t in_offset; + uint32_t tile_out_count; }; /** @@ -125,17 +342,18 @@ class DFASimulationCallbackWrapper { template class StateVectorTransitionOp { public: - __host__ __device__ __forceinline__ StateVectorTransitionOp( + __device__ __forceinline__ StateVectorTransitionOp( TransitionTableT const& transition_table, std::array& state_vector) : transition_table(transition_table), state_vector(state_vector) { } template - __host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index, - SymbolIndexT const& read_symbol_id, - SymbolT const& read_symbol) const + __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index, + SymbolIndexT const& read_symbol_id, + SymbolT const& read_symbol) const { +#pragma unroll for (int32_t i = 0; i < NUM_INSTANCES; ++i) { state_vector[i] = transition_table(state_vector[i], read_symbol_id); } @@ -152,17 +370,17 @@ struct StateTransitionOp { TransitionTableT const& transition_table; CallbackOpT& callback_op; - __host__ __device__ __forceinline__ StateTransitionOp(TransitionTableT const& transition_table, - StateIndexT state, - CallbackOpT& callback_op) + __device__ __forceinline__ StateTransitionOp(TransitionTableT const& transition_table, + StateIndexT state, + CallbackOpT& callback_op) : transition_table(transition_table), state(state), callback_op(callback_op) { } template - __host__ __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index, - SymbolIndexT const& read_symbol_id, - SymbolT const& read_symbol) + __device__ __forceinline__ void ReadSymbol(CharIndexT const& character_index, + SymbolIndexT const& read_symbol_id, + SymbolT const& read_symbol) { // Remember what state we were in before we made the transition StateIndexT previous_state = state; @@ -420,7 +638,7 @@ struct AgentDFA { __syncthreads(); // Thread's symbols - CharT* t_chars = &temp_storage.chars[threadIdx.x * SYMBOLS_PER_THREAD]; + CharT const* t_chars = &temp_storage.chars[threadIdx.x * SYMBOLS_PER_THREAD]; // Parse thread's symbols and transition the state-vector if (is_full_block) { @@ -538,6 +756,43 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) CUDF_KERNEL // The state transition vector passed on to the second stage of the algorithm StateVectorT out_state_vector; + using OutSymbolT = typename DfaT::OutSymbolT; + // static constexpr int32_t MIN_TRANSLATED_OUT = DfaT::MIN_TRANSLATED_OUT; + static constexpr int32_t num_max_translated_out = DfaT::MAX_TRANSLATED_OUT; + static constexpr bool discard_out_index = + ::cuda::std::is_same>::value; + static constexpr bool discard_out_it = + ::cuda::std::is_same>::value; + using NonWriteCoalescingT = + DFAWriteCallbackWrapper; + + using WriteCoalescingT = + WriteCoalescingCallbackWrapper; + + static constexpr bool is_translation_pass = (!IS_TRANS_VECTOR_PASS) || IS_SINGLE_PASS; + + // Use write-coalescing only if the worst-case output size per tile fits into shared memory + static constexpr bool can_use_smem_cache = + (sizeof(typename WriteCoalescingT::TempStorage) + sizeof(typename AgentDfaSimT::TempStorage) + + sizeof(typename DfaT::SymbolGroupStorageT) + sizeof(typename DfaT::TransitionTableStorageT) + + sizeof(typename DfaT::TranslationTableStorageT)) < (48 * 1024); + static constexpr bool use_smem_cache = + is_translation_pass and + (sizeof(typename WriteCoalescingT::TempStorage) <= AgentDFAPolicy::SMEM_THRESHOLD) and + can_use_smem_cache; + + using DFASimulationCallbackWrapperT = + typename cub::If::Type; + // Stage 1: Compute the state-transition vector if (IS_TRANS_VECTOR_PASS || IS_SINGLE_PASS) { // Keeping track of the state for each of the state machines @@ -576,7 +831,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) CUDF_KERNEL // -> first block/tile: write out block aggregate as the "tile's" inclusive (i.e., the one that // incorporates all preceding blocks/tiles results) //------------------------------------------------------------------------------ - if (IS_SINGLE_PASS) { + if constexpr (IS_SINGLE_PASS) { uint32_t tile_idx = blockIdx.x; using StateVectorCompositeOpT = VectorCompositeOp; @@ -623,10 +878,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) CUDF_KERNEL } // Perform finite-state machine simulation, computing size of transduced output - DFASimulationCallbackWrapper - callback_wrapper(transducer_table, transduced_out_it, transduced_out_idx_it); + DFACountCallbackWrapper count_chars_callback_op{transducer_table}; StateIndexT t_start_state = state; agent_dfa.GetThreadStateTransitions(symbol_matcher, @@ -635,7 +887,7 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) CUDF_KERNEL blockIdx.x * SYMBOLS_PER_BLOCK, num_chars, state, - callback_wrapper, + count_chars_callback_op, cub::Int2Type()); __syncthreads(); @@ -650,15 +902,18 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) CUDF_KERNEL __shared__ typename OffsetPrefixScanCallbackOpT_::TempStorage prefix_callback_temp_storage; uint32_t tile_idx = blockIdx.x; + uint32_t tile_out_offset{}; + uint32_t tile_out_count{}; + uint32_t thread_out_offset{}; if (tile_idx == 0) { OffsetT block_aggregate = 0; OutOffsetBlockScan(scan_temp_storage) - .ExclusiveScan(callback_wrapper.out_count, - callback_wrapper.out_count, + .ExclusiveScan(count_chars_callback_op.out_count, + thread_out_offset, static_cast(0), cub::Sum{}, block_aggregate); - + tile_out_count = block_aggregate; if (threadIdx.x == 0 /*and not IS_LAST_TILE*/) { offset_tile_state.SetInclusive(0, block_aggregate); } @@ -671,22 +926,28 @@ __launch_bounds__(int32_t(AgentDFAPolicy::BLOCK_THREADS)) CUDF_KERNEL offset_tile_state, prefix_callback_temp_storage, cub::Sum{}, tile_idx); OutOffsetBlockScan(scan_temp_storage) - .ExclusiveScan( - callback_wrapper.out_count, callback_wrapper.out_count, cub::Sum{}, prefix_op); - + .ExclusiveScan(count_chars_callback_op.out_count, thread_out_offset, cub::Sum{}, prefix_op); + tile_out_offset = prefix_op.GetExclusivePrefix(); + tile_out_count = prefix_op.GetBlockAggregate(); if (tile_idx == gridDim.x - 1 && threadIdx.x == 0) { *d_num_transduced_out_it = prefix_op.GetInclusivePrefix(); } } - callback_wrapper.write = true; + DFASimulationCallbackWrapperT write_translated_callback_op{transducer_table, + transduced_out_it, + transduced_out_idx_it, + thread_out_offset, + tile_out_offset, + blockIdx.x * SYMBOLS_PER_BLOCK, + tile_out_count}; agent_dfa.GetThreadStateTransitions(symbol_matcher, transition_table, d_chars, blockIdx.x * SYMBOLS_PER_BLOCK, num_chars, t_start_state, - callback_wrapper, + write_translated_callback_op, cub::Int2Type()); } } diff --git a/cpp/src/io/fst/dispatch_dfa.cuh b/cpp/src/io/fst/dispatch_dfa.cuh index be63ec6539f..ef5e9c8a78f 100644 --- a/cpp/src/io/fst/dispatch_dfa.cuh +++ b/cpp/src/io/fst/dispatch_dfa.cuh @@ -37,6 +37,11 @@ struct AgentDFAPolicy { // The number of symbols processed by each thread static constexpr int32_t ITEMS_PER_THREAD = _ITEMS_PER_THREAD; + + // If the shared memory-backed write buffer exceeds this threshold, the FST will skip buffering + // the output in a write buffer and instead immediately write out to global memory, potentially + // resulting in non-coalesced writes + static constexpr std::size_t SMEM_THRESHOLD = 24 * 1024; }; /** @@ -49,7 +54,7 @@ struct DeviceFSMPolicy { struct Policy900 : cub::ChainedPolicy<900, Policy900, Policy900> { enum { BLOCK_THREADS = 128, - ITEMS_PER_THREAD = 32, + ITEMS_PER_THREAD = 16, }; using AgentDFAPolicy = AgentDFAPolicy; diff --git a/cpp/src/io/fst/lookup_tables.cuh b/cpp/src/io/fst/lookup_tables.cuh index 5532a7f994b..ae1f81fd541 100644 --- a/cpp/src/io/fst/lookup_tables.cuh +++ b/cpp/src/io/fst/lookup_tables.cuh @@ -367,18 +367,18 @@ class TransitionTable { template static KernelParameter InitDeviceTransitionTable( - std::array, MAX_NUM_STATES> const& translation_table) + std::array, MAX_NUM_STATES> const& transition_table) { KernelParameter init_data{}; - // translation_table[state][symbol] -> new state - for (std::size_t state = 0; state < translation_table.size(); ++state) { - for (std::size_t symbol = 0; symbol < translation_table[state].size(); ++symbol) { + // transition_table[state][symbol] -> new state + for (std::size_t state = 0; state < transition_table.size(); ++state) { + for (std::size_t symbol = 0; symbol < transition_table[state].size(); ++symbol) { CUDF_EXPECTS( - static_cast(translation_table[state][symbol]) <= + static_cast(transition_table[state][symbol]) <= std::numeric_limits::max(), "Target state index value exceeds value representable by the transition table's type"); init_data.transitions[symbol * MAX_NUM_STATES + state] = - static_cast(translation_table[state][symbol]); + static_cast(transition_table[state][symbol]); } } @@ -494,6 +494,10 @@ class dfa_device_view { // This is a value queried by the DFA simulation algorithm static constexpr int32_t MAX_NUM_STATES = NUM_STATES; + using OutSymbolT = typename TranslationTableT::OutSymbolT; + static constexpr int32_t MIN_TRANSLATED_OUT = TranslationTableT::MIN_TRANSLATED_OUT; + static constexpr int32_t MAX_TRANSLATED_OUT = TranslationTableT::MAX_TRANSLATED_OUT; + using SymbolGroupStorageT = std::conditional_t::value, typename SymbolGroupIdLookupT::TempStorage, typename cub::NullType>; @@ -542,24 +546,33 @@ class dfa_device_view { * @tparam OutSymbolT The symbol type being output * @tparam OutSymbolOffsetT Type sufficiently large to index into the lookup table of output * symbols - * @tparam MAX_NUM_SYMBOLS The maximum number of symbols being output by a single state transition + * @tparam MAX_NUM_SYMBOLS The maximum number of symbol groups supported by this lookup table * @tparam MAX_NUM_STATES The maximum number of states that this lookup table shall support + * @tparam MIN_TRANSLATED_OUT_ The minimum number of symbols being output by a single state + * transition + * @tparam MAX_TRANSLATED_OUT_ The maximum number of symbols being output by a single state + * transition * @tparam MAX_TABLE_SIZE The maximum number of items in the lookup table of output symbols - * be used. */ -template class TransducerLookupTable { private: struct _TempStorage { OutSymbolOffsetT out_offset[MAX_NUM_STATES * MAX_NUM_SYMBOLS + 1]; - OutSymbolT out_symbols[MAX_TABLE_SIZE]; + OutSymbolT_ out_symbols[MAX_TABLE_SIZE]; }; public: + using OutSymbolT = OutSymbolT_; + static constexpr int32_t MIN_TRANSLATED_OUT = MIN_TRANSLATED_OUT_; + static constexpr int32_t MAX_TRANSLATED_OUT = MAX_TRANSLATED_OUT_; + using TempStorage = cub::Uninitialized<_TempStorage>; struct KernelParameter { @@ -567,6 +580,8 @@ class TransducerLookupTable { OutSymbolOffsetT, MAX_NUM_SYMBOLS, MAX_NUM_STATES, + MIN_TRANSLATED_OUT, + MAX_TRANSLATED_OUT, MAX_TABLE_SIZE>; OutSymbolOffsetT d_out_offsets[MAX_NUM_STATES * MAX_NUM_SYMBOLS + 1]; @@ -686,14 +701,19 @@ class TransducerLookupTable { * sequence of symbols that the finite-state transducer is supposed to output for each transition. * * @tparam MAX_TABLE_SIZE The maximum number of items in the lookup table of output symbols - * be used + * @tparam MIN_TRANSLATED_OUT The minimum number of symbols being output by a single state + * transition + * @tparam MAX_TRANSLATED_OUT The maximum number of symbols being output by a single state + * transition * @tparam OutSymbolT The symbol type being output - * @tparam MAX_NUM_SYMBOLS The maximum number of symbols being output by a single state transition + * @tparam MAX_NUM_SYMBOLS The maximum number of symbol groups supported by this lookup table * @tparam MAX_NUM_STATES The maximum number of states that this lookup table shall support * @param translation_table The translation table * @return A translation table of type `TransducerLookupTable`. */ template @@ -705,20 +725,30 @@ auto make_translation_table(std::array, MAX_N OutSymbolOffsetT, MAX_NUM_SYMBOLS, MAX_NUM_STATES, + MIN_TRANSLATED_OUT, + MAX_TRANSLATED_OUT, MAX_TABLE_SIZE>; return translation_table_t::InitDeviceTranslationTable(translation_table); } -template +template class TranslationOp { private: struct _TempStorage {}; public: + using OutSymbolT = OutSymbolT_; + static constexpr int32_t MIN_TRANSLATED_OUT = MIN_TRANSLATED_OUT_; + static constexpr int32_t MAX_TRANSLATED_OUT = MAX_TRANSLATED_OUT_; + using TempStorage = cub::Uninitialized<_TempStorage>; struct KernelParameter { - using LookupTableT = TranslationOp; + using LookupTableT = + TranslationOp; TranslationOpT translation_op; }; @@ -772,6 +802,10 @@ class TranslationOp { * * @tparam FunctorT A function object type that must implement two signatures: (1) with `(state_id, * match_id, read_symbol)` and (2) with `(state_id, match_id, relative_offset, read_symbol)` + * @tparam MIN_TRANSLATED_SYMBOLS The minimum number of translated output symbols for any given + * input symbol + * @tparam MAX_TRANSLATED_SYMBOLS The maximum number of translated output symbols for any given + * input symbol * @param map_op A function object that must implement two signatures: (1) with `(state_id, * match_id, read_symbol)` and (2) with `(state_id, match_id, relative_offset, read_symbol)`. * Invocations of the first signature, (1), must return the number of symbols that are emitted for @@ -779,10 +813,14 @@ class TranslationOp { * that transition, where `i` corresponds to `relative_offse` * @return A translation table of type `TranslationO` */ -template +template auto make_translation_functor(FunctorT map_op) { - return TranslationOp::InitDeviceTranslationTable(map_op); + return TranslationOp:: + InitDeviceTranslationTable(map_op); } /** diff --git a/cpp/src/io/json/json_normalization.cu b/cpp/src/io/json/json_normalization.cu index ca56a12eb36..760b2214365 100644 --- a/cpp/src/io/json/json_normalization.cu +++ b/cpp/src/io/json/json_normalization.cu @@ -302,11 +302,14 @@ void normalize_single_quotes(datasource::owning_buffer( + normalize_quotes::TransduceToNormalizedQuotes{}), + stream); rmm::device_uvector outbuf(indata.size() * 2, stream, mr); rmm::device_scalar outbuf_size(stream, mr); @@ -327,11 +330,14 @@ void normalize_whitespace(datasource::owning_buffer rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - auto parser = fst::detail::make_fst( - fst::detail::make_symbol_group_lut(normalize_whitespace::wna_sgs), - fst::detail::make_transition_table(normalize_whitespace::wna_state_tt), - fst::detail::make_translation_functor(normalize_whitespace::TransduceToNormalizedWS{}), - stream); + static constexpr std::int32_t min_out = 0; + static constexpr std::int32_t max_out = 2; + auto parser = + fst::detail::make_fst(fst::detail::make_symbol_group_lut(normalize_whitespace::wna_sgs), + fst::detail::make_transition_table(normalize_whitespace::wna_state_tt), + fst::detail::make_translation_functor( + normalize_whitespace::TransduceToNormalizedWS{}), + stream); rmm::device_uvector outbuf(indata.size(), stream, mr); rmm::device_scalar outbuf_size(stream, mr); diff --git a/cpp/src/io/json/nested_json_gpu.cu b/cpp/src/io/json/nested_json_gpu.cu index a007754ef4f..8decaf034f3 100644 --- a/cpp/src/io/json/nested_json_gpu.cu +++ b/cpp/src/io/json/nested_json_gpu.cu @@ -1455,11 +1455,14 @@ void get_stack_context(device_span json_in, constexpr auto max_translation_table_size = to_stack_op::NUM_SYMBOL_GROUPS * to_stack_op::TT_NUM_STATES; - auto json_to_stack_ops_fst = fst::detail::make_fst( + static constexpr auto min_translated_out = 0; + static constexpr auto max_translated_out = 1; + auto json_to_stack_ops_fst = fst::detail::make_fst( fst::detail::make_symbol_group_lut(to_stack_op::get_sgid_lut(delimiter)), fst::detail::make_transition_table(to_stack_op::get_transition_table(stack_behavior)), - fst::detail::make_translation_table( - to_stack_op::get_translation_table(stack_behavior)), + fst::detail:: + make_translation_table( + to_stack_op::get_translation_table(stack_behavior)), stream); // "Search" for relevant occurrence of brackets and braces that indicate the beginning/end @@ -1507,11 +1510,12 @@ std::pair, rmm::device_uvector> pr // Instantiate FST for post-processing the token stream to remove all tokens that belong to an // invalid JSON line token_filter::UnwrapTokenFromSymbolOp sgid_op{}; - auto filter_fst = - fst::detail::make_fst(fst::detail::make_symbol_group_lut(token_filter::symbol_groups, sgid_op), - fst::detail::make_transition_table(token_filter::transition_table), - fst::detail::make_translation_functor(token_filter::TransduceToken{}), - stream); + using symbol_t = thrust::tuple; + auto filter_fst = fst::detail::make_fst( + fst::detail::make_symbol_group_lut(token_filter::symbol_groups, sgid_op), + fst::detail::make_transition_table(token_filter::transition_table), + fst::detail::make_translation_functor(token_filter::TransduceToken{}), + stream); auto const mr = rmm::mr::get_current_device_resource(); rmm::device_scalar d_num_selected_tokens(stream, mr); @@ -1598,7 +1602,8 @@ std::pair, rmm::device_uvector> ge fst::detail::make_symbol_group_lookup_op( fix_stack_of_excess_chars::SymbolPairToSymbolGroupId{delimiter}), fst::detail::make_transition_table(fix_stack_of_excess_chars::transition_table), - fst::detail::make_translation_functor(fix_stack_of_excess_chars::TransduceInputOp{}), + fst::detail::make_translation_functor( + fix_stack_of_excess_chars::TransduceInputOp{}), stream); fix_stack_of_excess_chars.Transduce(zip_in, static_cast(json_in.size()), @@ -1619,7 +1624,7 @@ std::pair, rmm::device_uvector> ge auto json_to_tokens_fst = fst::detail::make_fst( fst::detail::make_symbol_group_lookup_op(tokenizer_pda::PdaSymbolToSymbolGroupId{delimiter}), fst::detail::make_transition_table(tokenizer_pda::get_transition_table(format)), - fst::detail::make_translation_table( + fst::detail::make_translation_table( tokenizer_pda::get_translation_table(recover_from_error)), stream); diff --git a/cpp/tests/io/fst/common.hpp b/cpp/tests/io/fst/common.hpp index 382d21fabb8..0177300eda9 100644 --- a/cpp/tests/io/fst/common.hpp +++ b/cpp/tests/io/fst/common.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -69,6 +69,8 @@ std::array, TT_NUM_STATES> const pda_s /* TT_ESC */ {{TT_STR, TT_STR, TT_STR, TT_STR, TT_STR, TT_STR, TT_STR}}}}; // Translation table (i.e., for each transition, what are the symbols that we output) +static constexpr auto min_translated_out = 1; +static constexpr auto max_translated_out = 1; std::array, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const pda_out_tt{ {/* IN_STATE { [ } ] " \ OTHER */ /* TT_OOS */ {{{'{'}, {'['}, {'}'}, {']'}, {'x'}, {'x'}, {'x'}}}, diff --git a/cpp/tests/io/fst/fst_test.cu b/cpp/tests/io/fst/fst_test.cu index 4df0d3ae04d..8a8d3d39e0f 100644 --- a/cpp/tests/io/fst/fst_test.cu +++ b/cpp/tests/io/fst/fst_test.cu @@ -169,7 +169,9 @@ TEST_F(FstTest, GroundTruth) auto parser = cudf::io::fst::detail::make_fst( cudf::io::fst::detail::make_symbol_group_lut(pda_sgs), cudf::io::fst::detail::make_transition_table(pda_state_tt), - cudf::io::fst::detail::make_translation_table(pda_out_tt), + cudf::io::fst::detail::make_translation_table(pda_out_tt), stream); // Allocate device-side temporary storage & run algorithm