Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors JSON reader's pushdown automaton #13716

Merged
merged 9 commits into from
Aug 9, 2023
66 changes: 66 additions & 0 deletions cpp/src/io/fst/lookup_tables.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,72 @@ class SingleSymbolSmemLUT {
}
};

/**
* @brief A simple symbol group lookup wrapper that uses a simple function object to
* retrieve the symbol group id for a symbol.
*
* @tparam SymbolGroupLookupOpT The function object type to return the symbol group for a given
* symbol
*/
template <typename SymbolGroupLookupOpT>
class SymbolGroupLookupOp {
private:
struct _TempStorage {};
elstehle marked this conversation as resolved.
Show resolved Hide resolved

public:
using TempStorage = cub::Uninitialized<_TempStorage>;

struct KernelParameter {
using LookupTableT = SymbolGroupLookupOp<SymbolGroupLookupOpT>;
elstehle marked this conversation as resolved.
Show resolved Hide resolved
SymbolGroupLookupOpT sgid_lookup_op;
};

static KernelParameter InitDeviceSymbolGroupIdLut(SymbolGroupLookupOpT sgid_lookup_op)
{
return KernelParameter{sgid_lookup_op};
}

private:
_TempStorage& temp_storage;
SymbolGroupLookupOpT sgid_lookup_op;

__device__ __forceinline__ _TempStorage& PrivateStorage()
{
__shared__ _TempStorage private_storage;
return private_storage;
}

public:
CUDF_HOST_DEVICE SymbolGroupLookupOp(KernelParameter const& kernel_param,
TempStorage& temp_storage)
: temp_storage(temp_storage.Alias()), sgid_lookup_op(kernel_param.sgid_lookup_op)
{
}

template <typename SymbolT_>
constexpr CUDF_HOST_DEVICE int32_t operator()(SymbolT_ const symbol) const
{
// Look up the symbol group for given symbol
return sgid_lookup_op(symbol);
}
};

/**
 * @brief Creates the kernel parameter for a symbol group lookup wrapper that uses a function
 * object to retrieve the symbol group id for a symbol.
 *
 * @tparam FunctorT A function object type that must implement the signature `int32_t
 * operator()(symbol)`, where `symbol` is a symbol from the input type.
 * @param sgid_lookup_op A function object that must implement the signature `int32_t
 * operator()(symbol)`, where `symbol` is a symbol from the input type.
 * @return The `SymbolGroupLookupOp<FunctorT>::KernelParameter` that holds @p sgid_lookup_op and
 * from which a `SymbolGroupLookupOp<FunctorT>` can be constructed
 */
template <typename FunctorT>
auto make_symbol_group_lookup_op(FunctorT sgid_lookup_op)
{
  return SymbolGroupLookupOp<FunctorT>::InitDeviceSymbolGroupIdLut(sgid_lookup_op);
}

/**
* @brief Creates a symbol group lookup table of type `SingleSymbolSmemLUT` that uses a two-staged
* lookup approach. @p pre_map_op is a function object invoked with `(lut, symbol)` that must return
Expand Down
43 changes: 16 additions & 27 deletions cpp/src/io/json/nested_json_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ static __constant__ PdaSymbolGroupIdT tos_sg_to_pda_sgid[] = {
struct PdaSymbolToSymbolGroupId {
template <typename SymbolT, typename StackSymbolT>
__device__ __forceinline__ PdaSymbolGroupIdT
operator()(thrust::tuple<SymbolT, StackSymbolT> symbol_pair)
operator()(thrust::tuple<SymbolT, StackSymbolT> symbol_pair) const
{
// The symbol read from the input
auto symbol = thrust::get<0>(symbol_pair);
Expand Down Expand Up @@ -1420,36 +1420,25 @@ std::pair<rmm::device_uvector<PdaTokenT>, rmm::device_uvector<SymbolOffsetT>> ge

// Prepare for PDA transducer pass, merging input symbols with stack symbols
auto const recover_from_error = (format == tokenizer_pda::json_format_cfg_t::JSON_LINES_RECOVER);
rmm::device_uvector<PdaSymbolGroupIdT> pda_sgids = [json_in, stream, recover_from_error]() {
// Memory holding the top-of-stack stack context for the input
rmm::device_uvector<StackSymbolT> stack_op_indices{json_in.size(), stream};

// Identify what is the stack context for each input character (JSON-root, struct, or list)
auto const stack_behavior = recover_from_error ? stack_behavior_t::ResetOnDelimiter
: stack_behavior_t::PushPopWithoutReset;
get_stack_context(json_in, stack_op_indices.data(), stack_behavior, stream);

rmm::device_uvector<PdaSymbolGroupIdT> pda_sgids{json_in.size(), stream};
auto zip_in = thrust::make_zip_iterator(json_in.data(), stack_op_indices.data());
thrust::transform(rmm::exec_policy(stream),
zip_in,
zip_in + json_in.size(),
pda_sgids.data(),
tokenizer_pda::PdaSymbolToSymbolGroupId{});
return pda_sgids;
}();

// Instantiating PDA transducer
std::array<std::vector<char>, tokenizer_pda::NUM_PDA_SGIDS> pda_sgid_identity{};
std::generate(std::begin(pda_sgid_identity),
std::end(pda_sgid_identity),
[i = char{0}]() mutable { return std::vector<char>{i++}; });

// Memory holding the top-of-stack stack context for the input
rmm::device_uvector<StackSymbolT> stack_op_indices{json_in.size(), stream};

// Identify what is the stack context for each input character (JSON-root, struct, or list)
auto const stack_behavior =
recover_from_error ? stack_behavior_t::ResetOnDelimiter : stack_behavior_t::PushPopWithoutReset;
get_stack_context(json_in, stack_op_indices.data(), stack_behavior, stream);

// Input to the full pushdown automaton finite-state transducer, where an input symbol comprises
// the combination of a character from the JSON input together with the stack context for that
// character.
auto zip_in = thrust::make_zip_iterator(json_in.data(), stack_op_indices.data());

constexpr auto max_translation_table_size =
tokenizer_pda::NUM_PDA_SGIDS *
static_cast<tokenizer_pda::StateT>(tokenizer_pda::pda_state_t::PD_NUM_STATES);
auto json_to_tokens_fst = fst::detail::make_fst(
fst::detail::make_symbol_group_lut(pda_sgid_identity),
fst::detail::make_symbol_group_lookup_op(tokenizer_pda::PdaSymbolToSymbolGroupId{}),
fst::detail::make_transition_table(tokenizer_pda::get_transition_table(format)),
fst::detail::make_translation_table<max_translation_table_size>(
tokenizer_pda::get_translation_table(recover_from_error)),
Expand All @@ -1473,7 +1462,7 @@ std::pair<rmm::device_uvector<PdaTokenT>, rmm::device_uvector<SymbolOffsetT>> ge
rmm::device_uvector<SymbolOffsetT> tokens_indices{
max_token_out_count + delimiter_offset, stream, mr};

json_to_tokens_fst.Transduce(pda_sgids.begin(),
json_to_tokens_fst.Transduce(zip_in,
static_cast<SymbolOffsetT>(json_in.size()),
tokens.data() + delimiter_offset,
tokens_indices.data() + delimiter_offset,
Expand Down