From e8df03754021e3decfc6640b58bd7a0770b0c230 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Wed, 9 Aug 2023 21:17:17 +0200 Subject: [PATCH] Refactors JSON reader's pushdown automaton (#13716) This PR simplifies and cleans up the JSON reader's pushdown automaton. The pushdown automaton takes as input two arrays: 1. The JSON's input characters 2. The stack context for each character (`{` - `JSON object`, `[` - `JSON array`, `_` - `Root of JSON`) Previously, we were fusing the two arrays and materializing them straight to the symbol group id for each combination. A symbol group id serves as the column of the transition table. The symbol group ids array was then used as input to the finite state transducer (FST). After the [recent refactor of the FST](https://github.com/rapidsai/cudf/pull/13344) lookup tables, the FST has become more flexible. It now supports arbitrary iterators and the symbol group id lookup table (that maps a symbol to a symbol group id) can now be implemented by a simple function object. This PR takes advantage of the FST's ability to take fancy iterators. We now zip the `json_input` and `stack_context` symbols and pass that `zip_iterator` to the FST. 
Authors: - Elias Stehle (https://github.com/elstehle) - Vukasin Milovanovic (https://github.com/vuule) - Karthikeyan (https://github.com/karthikeyann) Approvers: - Karthikeyan (https://github.com/karthikeyann) - Vukasin Milovanovic (https://github.com/vuule) URL: https://github.com/rapidsai/cudf/pull/13716 --- cpp/src/io/fst/lookup_tables.cuh | 70 +++++++++++++++++++++++++++++- cpp/src/io/json/nested_json_gpu.cu | 43 +++++++----------- 2 files changed, 85 insertions(+), 28 deletions(-) diff --git a/cpp/src/io/fst/lookup_tables.cuh b/cpp/src/io/fst/lookup_tables.cuh index c4176d5673f..37c99453361 100644 --- a/cpp/src/io/fst/lookup_tables.cuh +++ b/cpp/src/io/fst/lookup_tables.cuh @@ -179,6 +179,74 @@ class SingleSymbolSmemLUT { } }; +/** + * @brief A simple symbol group lookup wrapper that uses a simple function object to + * retrieve the symbol group id for a symbol. + * + * @tparam SymbolGroupLookupOpT The function object type to return the symbol group for a given + * symbol + */ +template +class SymbolGroupLookupOp { + private: + struct _TempStorage {}; + + public: + using TempStorage = cub::Uninitialized<_TempStorage>; + + struct KernelParameter { + // Declare the member type that the DFA is going to instantiate + using LookupTableT = SymbolGroupLookupOp; + SymbolGroupLookupOpT sgid_lookup_op; + }; + + static KernelParameter InitDeviceSymbolGroupIdLut(SymbolGroupLookupOpT sgid_lookup_op) + { + return KernelParameter{sgid_lookup_op}; + } + + private: + _TempStorage& temp_storage; + SymbolGroupLookupOpT sgid_lookup_op; + + __device__ __forceinline__ _TempStorage& PrivateStorage() + { + __shared__ _TempStorage private_storage; + return private_storage; + } + + public: + CUDF_HOST_DEVICE SymbolGroupLookupOp(KernelParameter const& kernel_param, + TempStorage& temp_storage) + : temp_storage(temp_storage.Alias()), sgid_lookup_op(kernel_param.sgid_lookup_op) + { + } + + template + constexpr CUDF_HOST_DEVICE int32_t operator()(SymbolT_ const symbol) const + { + // Look 
up the symbol group for given symbol + return sgid_lookup_op(symbol); + } +}; + +/** + * @brief Prepares a simple symbol group lookup wrapper that uses a simple function object to + * retrieve the symbol group id for a symbol. + * + * @tparam FunctorT A function object type that must implement the signature `int32_t + * operator()(symbol)`, where `symbol` is a symbol from the input type. + * @param sgid_lookup_op A function object that must implement the signature `int32_t + * operator()(symbol)`, where `symbol` is a symbol from the input type. + * @return The kernel parameter of type SymbolGroupLookupOp::KernelParameter that is used to + * initialize a simple symbol group id lookup wrapper + */ +template +auto make_symbol_group_lookup_op(FunctorT sgid_lookup_op) +{ + return SymbolGroupLookupOp::InitDeviceSymbolGroupIdLut(sgid_lookup_op); +} + /** * @brief Creates a symbol group lookup table of type `SingleSymbolSmemLUT` that uses a two-staged * lookup approach. @p pre_map_op is a function object invoked with `(lut, symbol)` that must return @@ -830,7 +898,7 @@ class Dfa { }; /** - * @brief Creates a determninistic finite automaton (DFA) as specified by the triple of (symbol + * @brief Creates a deterministic finite automaton (DFA) as specified by the triple of (symbol * group, transition, translation)-lookup tables to be used with the finite-state transducer * algorithm. 
* diff --git a/cpp/src/io/json/nested_json_gpu.cu b/cpp/src/io/json/nested_json_gpu.cu index 0629ceb95c6..8552db9a719 100644 --- a/cpp/src/io/json/nested_json_gpu.cu +++ b/cpp/src/io/json/nested_json_gpu.cu @@ -477,7 +477,7 @@ static __constant__ PdaSymbolGroupIdT tos_sg_to_pda_sgid[] = { struct PdaSymbolToSymbolGroupId { template __device__ __forceinline__ PdaSymbolGroupIdT - operator()(thrust::tuple symbol_pair) + operator()(thrust::tuple symbol_pair) const { // The symbol read from the input auto symbol = thrust::get<0>(symbol_pair); @@ -1420,36 +1420,25 @@ std::pair, rmm::device_uvector> ge // Prepare for PDA transducer pass, merging input symbols with stack symbols auto const recover_from_error = (format == tokenizer_pda::json_format_cfg_t::JSON_LINES_RECOVER); - rmm::device_uvector pda_sgids = [json_in, stream, recover_from_error]() { - // Memory holding the top-of-stack stack context for the input - rmm::device_uvector stack_op_indices{json_in.size(), stream}; - - // Identify what is the stack context for each input character (JSON-root, struct, or list) - auto const stack_behavior = recover_from_error ? 
stack_behavior_t::ResetOnDelimiter - : stack_behavior_t::PushPopWithoutReset; - get_stack_context(json_in, stack_op_indices.data(), stack_behavior, stream); - - rmm::device_uvector pda_sgids{json_in.size(), stream}; - auto zip_in = thrust::make_zip_iterator(json_in.data(), stack_op_indices.data()); - thrust::transform(rmm::exec_policy(stream), - zip_in, - zip_in + json_in.size(), - pda_sgids.data(), - tokenizer_pda::PdaSymbolToSymbolGroupId{}); - return pda_sgids; - }(); - - // Instantiating PDA transducer - std::array, tokenizer_pda::NUM_PDA_SGIDS> pda_sgid_identity{}; - std::generate(std::begin(pda_sgid_identity), - std::end(pda_sgid_identity), - [i = char{0}]() mutable { return std::vector{i++}; }); + + // Memory holding the top-of-stack stack context for the input + rmm::device_uvector stack_symbols{json_in.size(), stream}; + + // Identify what is the stack context for each input character (JSON-root, struct, or list) + auto const stack_behavior = + recover_from_error ? stack_behavior_t::ResetOnDelimiter : stack_behavior_t::PushPopWithoutReset; + get_stack_context(json_in, stack_symbols.data(), stack_behavior, stream); + + // Input to the full pushdown automaton finite-state transducer, where an input symbol comprises + // the combination of a character from the JSON input together with the stack context for that + // character. 
+ auto zip_in = thrust::make_zip_iterator(json_in.data(), stack_symbols.data()); constexpr auto max_translation_table_size = tokenizer_pda::NUM_PDA_SGIDS * static_cast(tokenizer_pda::pda_state_t::PD_NUM_STATES); auto json_to_tokens_fst = fst::detail::make_fst( - fst::detail::make_symbol_group_lut(pda_sgid_identity), + fst::detail::make_symbol_group_lookup_op(tokenizer_pda::PdaSymbolToSymbolGroupId{}), fst::detail::make_transition_table(tokenizer_pda::get_transition_table(format)), fst::detail::make_translation_table( tokenizer_pda::get_translation_table(recover_from_error)), @@ -1473,7 +1462,7 @@ std::pair, rmm::device_uvector> ge rmm::device_uvector tokens_indices{ max_token_out_count + delimiter_offset, stream, mr}; - json_to_tokens_fst.Transduce(pda_sgids.begin(), + json_to_tokens_fst.Transduce(zip_in, static_cast(json_in.size()), tokens.data() + delimiter_offset, tokens_indices.data() + delimiter_offset,