Skip to content

Commit

Permalink
Refactors JSON reader's pushdown automaton (#13716)
Browse files Browse the repository at this point in the history
This PR simplifies and cleans up the JSON reader's pushdown automaton. 

The pushdown automaton takes as input two arrays:
1. The JSON's input characters
2. The stack context for each character (`{` - `JSON object`, `[` - `JSON array`, `_` - `Root of JSON`)

Previously, we were fusing the two arrays and materializing them straight to the symbol group id for each combination. A symbol group id serves as the column of the transition table. The symbol group ids array was then used as input to the finite state transducer (FST).

After the [recent refactor of the FST](#13344) lookup tables, the FST has become more flexible. It now supports arbitrary iterators and the symbol group id lookup table (that maps a symbol to a symbol group id) can now be implemented by a simple function object. 

This PR takes advantage of the FST's ability to take fancy iterators. We now zip the `json_input` and `stack_context` symbols and pass that `zip_iterator` to the FST.

Authors:
  - Elias Stehle (https://github.com/elstehle)
  - Vukasin Milovanovic (https://github.com/vuule)
  - Karthikeyan (https://github.com/karthikeyann)

Approvers:
  - Karthikeyan (https://github.com/karthikeyann)
  - Vukasin Milovanovic (https://github.com/vuule)

URL: #13716
  • Loading branch information
elstehle authored Aug 9, 2023
1 parent 378acc5 commit e8df037
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 28 deletions.
70 changes: 69 additions & 1 deletion cpp/src/io/fst/lookup_tables.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,74 @@ class SingleSymbolSmemLUT {
}
};

/**
 * @brief Symbol group lookup "table" that is backed by a callable rather than by data: every
 * lookup is forwarded to a user-provided function object that maps a symbol to its symbol
 * group id.
 *
 * @tparam SymbolGroupLookupOpT Type of the function object that returns the symbol group id for
 * a given symbol
 */
template <typename SymbolGroupLookupOpT>
class SymbolGroupLookupOp {
 private:
  // This lookup keeps no scratch data of its own, hence the empty storage type
  struct _TempStorage {};

 public:
  using TempStorage = cub::Uninitialized<_TempStorage>;

  /**
   * @brief Parameter bundle from which the lookup object is constructed.
   */
  struct KernelParameter {
    // Member type that the DFA is going to instantiate for this kernel parameter
    using LookupTableT = SymbolGroupLookupOp;
    // Function object mapping a symbol to its symbol group id
    SymbolGroupLookupOpT sgid_lookup_op;
  };

  /**
   * @brief Wraps the given function object into a `KernelParameter`.
   *
   * @param sgid_lookup_op Function object that maps a symbol to its symbol group id
   * @return The kernel parameter holding a copy of @p sgid_lookup_op
   */
  static KernelParameter InitDeviceSymbolGroupIdLut(SymbolGroupLookupOpT sgid_lookup_op)
  {
    KernelParameter param{sgid_lookup_op};
    return param;
  }

  /**
   * @brief Constructs the lookup object from its kernel parameter and temporary storage.
   *
   * @param kernel_param Kernel parameter providing the function object to copy
   * @param temp_storage Temporary storage that this object keeps an alias to
   */
  CUDF_HOST_DEVICE SymbolGroupLookupOp(KernelParameter const& kernel_param,
                                       TempStorage& temp_storage)
    : storage_alias(temp_storage.Alias()), lookup_op(kernel_param.sgid_lookup_op)
  {
  }

  /**
   * @brief Returns the symbol group id of @p symbol by invoking the wrapped function object.
   *
   * @tparam SymbolT_ Type of the symbol being looked up
   * @param symbol The symbol to look up
   * @return Whatever the wrapped function object returns for @p symbol, converted to `int32_t`
   */
  template <typename SymbolT_>
  constexpr CUDF_HOST_DEVICE int32_t operator()(SymbolT_ const symbol) const
  {
    return lookup_op(symbol);
  }

 private:
  // NOTE: declaration order matters, the constructor initializes the members in this order
  _TempStorage& storage_alias;
  SymbolGroupLookupOpT lookup_op;

  // Hands out a `__shared__` instance of the (empty) storage type; not referenced by any other
  // member of this class
  __device__ __forceinline__ _TempStorage& PrivateStorage()
  {
    __shared__ _TempStorage block_private_storage;
    return block_private_storage;
  }
};

/**
* @brief Prepares a simple symbol group lookup wrapper that uses a simple function object to
* retrieve the symbol group id for a symbol.
*
* @tparam FunctorT A function object type that must implement the signature `int32_t
* operator()(symbol)`, where `symbol` is a symbol from the input type.
* @param sgid_lookup_op A function object that must implement the signature `int32_t
* operator()(symbol)`, where `symbol` is a symbol from the input type.
* @return The kernel parameter of type SymbolGroupLookupOp::KernelParameter that is used to
* initialize a simple symbol group id lookup wrapper
*/
template <typename FunctorT>
auto make_symbol_group_lookup_op(FunctorT sgid_lookup_op)
{
return SymbolGroupLookupOp<FunctorT>::InitDeviceSymbolGroupIdLut(sgid_lookup_op);
}

/**
* @brief Creates a symbol group lookup table of type `SingleSymbolSmemLUT` that uses a two-staged
* lookup approach. @p pre_map_op is a function object invoked with `(lut, symbol)` that must return
Expand Down Expand Up @@ -830,7 +898,7 @@ class Dfa {
};

/**
* @brief Creates a determninistic finite automaton (DFA) as specified by the triple of (symbol
* @brief Creates a deterministic finite automaton (DFA) as specified by the triple of (symbol
* group, transition, translation)-lookup tables to be used with the finite-state transducer
* algorithm.
*
Expand Down
43 changes: 16 additions & 27 deletions cpp/src/io/json/nested_json_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ static __constant__ PdaSymbolGroupIdT tos_sg_to_pda_sgid[] = {
struct PdaSymbolToSymbolGroupId {
template <typename SymbolT, typename StackSymbolT>
__device__ __forceinline__ PdaSymbolGroupIdT
operator()(thrust::tuple<SymbolT, StackSymbolT> symbol_pair)
operator()(thrust::tuple<SymbolT, StackSymbolT> symbol_pair) const
{
// The symbol read from the input
auto symbol = thrust::get<0>(symbol_pair);
Expand Down Expand Up @@ -1420,36 +1420,25 @@ std::pair<rmm::device_uvector<PdaTokenT>, rmm::device_uvector<SymbolOffsetT>> ge

// Prepare for PDA transducer pass, merging input symbols with stack symbols
auto const recover_from_error = (format == tokenizer_pda::json_format_cfg_t::JSON_LINES_RECOVER);
rmm::device_uvector<PdaSymbolGroupIdT> pda_sgids = [json_in, stream, recover_from_error]() {
// Memory holding the top-of-stack stack context for the input
rmm::device_uvector<StackSymbolT> stack_op_indices{json_in.size(), stream};

// Identify what is the stack context for each input character (JSON-root, struct, or list)
auto const stack_behavior = recover_from_error ? stack_behavior_t::ResetOnDelimiter
: stack_behavior_t::PushPopWithoutReset;
get_stack_context(json_in, stack_op_indices.data(), stack_behavior, stream);

rmm::device_uvector<PdaSymbolGroupIdT> pda_sgids{json_in.size(), stream};
auto zip_in = thrust::make_zip_iterator(json_in.data(), stack_op_indices.data());
thrust::transform(rmm::exec_policy(stream),
zip_in,
zip_in + json_in.size(),
pda_sgids.data(),
tokenizer_pda::PdaSymbolToSymbolGroupId{});
return pda_sgids;
}();

// Instantiating PDA transducer
std::array<std::vector<char>, tokenizer_pda::NUM_PDA_SGIDS> pda_sgid_identity{};
std::generate(std::begin(pda_sgid_identity),
std::end(pda_sgid_identity),
[i = char{0}]() mutable { return std::vector<char>{i++}; });

// Memory holding the top-of-stack stack context for the input
rmm::device_uvector<StackSymbolT> stack_symbols{json_in.size(), stream};

// Identify what is the stack context for each input character (JSON-root, struct, or list)
auto const stack_behavior =
recover_from_error ? stack_behavior_t::ResetOnDelimiter : stack_behavior_t::PushPopWithoutReset;
get_stack_context(json_in, stack_symbols.data(), stack_behavior, stream);

// Input to the full pushdown automaton finite-state transducer, where an input symbol comprises
// the combination of a character from the JSON input together with the stack context for that
// character.
auto zip_in = thrust::make_zip_iterator(json_in.data(), stack_symbols.data());

constexpr auto max_translation_table_size =
tokenizer_pda::NUM_PDA_SGIDS *
static_cast<tokenizer_pda::StateT>(tokenizer_pda::pda_state_t::PD_NUM_STATES);
auto json_to_tokens_fst = fst::detail::make_fst(
fst::detail::make_symbol_group_lut(pda_sgid_identity),
fst::detail::make_symbol_group_lookup_op(tokenizer_pda::PdaSymbolToSymbolGroupId{}),
fst::detail::make_transition_table(tokenizer_pda::get_transition_table(format)),
fst::detail::make_translation_table<max_translation_table_size>(
tokenizer_pda::get_translation_table(recover_from_error)),
Expand All @@ -1473,7 +1462,7 @@ std::pair<rmm::device_uvector<PdaTokenT>, rmm::device_uvector<SymbolOffsetT>> ge
rmm::device_uvector<SymbolOffsetT> tokens_indices{
max_token_out_count + delimiter_offset, stream, mr};

json_to_tokens_fst.Transduce(pda_sgids.begin(),
json_to_tokens_fst.Transduce(zip_in,
static_cast<SymbolOffsetT>(json_in.size()),
tokens.data() + delimiter_offset,
tokens_indices.data() + delimiter_offset,
Expand Down

0 comments on commit e8df037

Please sign in to comment.