Skip to content

Commit

Permalink
Changes JSON reader's recovery option's behaviour to ignore all chara…
Browse files Browse the repository at this point in the history
…cters after a valid JSON record (#14279)

Closes #14226.

The new behavior of `JSON_LINES_RECOVER` is to ignore excess characters after the first valid JSON record on each JSON line.
```
{ "number": 1 } 
{ "number": 1 } xyz
{ "number": 1 } {}
{ "number": 1 } { "number": 4 }
```

**Implementation details:**
The JSON parser pushdown automaton was changed for `JSON_LINES_RECOVER` format such that when in state `PD_PVL` (`post-value`, "I have just finished parsing a value") and when the stack context is `ROOT` ("I'm not somewhere within a list or struct"), we just treat all characters as "white space" until encountering a newline character. `post-value` in stack context `ROOT` is exactly the condition we are in after having parsed the first valid record of a JSON line. _Thanks to @karthikeyann for suggesting to use `PD_PVL` as the capturing state._ 

As the stack context is generated upfront, we have to fix it up afterwards, setting the stack context to `ROOT` for all these excess characters. I.e., (`_` means `ROOT` stack context, `{` means within a `STRUCT` stack context):
```
in:    {"a":1}{"this is supposed to be ignored"}
stack: _{{{{{{_{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{{
```
Needs to be fixed up to become:
```
in:    {"a":1}{"this is supposed to be ignored"}
stack: _{{{{{{__________________________________
```

Authors:
  - Elias Stehle (https://github.com/elstehle)
  - Karthikeyan (https://github.com/karthikeyann)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Nghia Truong (https://github.com/ttnghia)
  - Karthikeyan (https://github.com/karthikeyann)

URL: #14279
  • Loading branch information
elstehle authored Oct 20, 2023
1 parent d36904b commit 50e2211
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 24 deletions.
4 changes: 2 additions & 2 deletions cpp/src/io/fst/lookup_tables.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -753,15 +753,15 @@ class TranslationOp {
RelativeOffsetT const relative_offset,
SymbolT const read_symbol) const
{
return translation_op(*this, state_id, match_id, relative_offset, read_symbol);
return translation_op(state_id, match_id, relative_offset, read_symbol);
}

template <typename StateIndexT, typename SymbolIndexT, typename SymbolT>
constexpr CUDF_HOST_DEVICE auto operator()(StateIndexT const state_id,
SymbolIndexT const match_id,
SymbolT const read_symbol) const
{
return translation_op(*this, state_id, match_id, read_symbol);
return translation_op(state_id, match_id, read_symbol);
}
};

Expand Down
165 changes: 146 additions & 19 deletions cpp/src/io/json/nested_json_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,98 @@ void check_input_size(std::size_t input_size)

namespace cudf::io::json {

// FST to help fixing the stack context of characters that follow the first record on each JSON line
namespace fix_stack_of_excess_chars {

// Type used to represent the target state in the transition table
using StateT = char;

// Type used to represent a symbol group id
using SymbolGroupT = uint8_t;

/**
 * @brief Definition of the DFA's states: tracks whether we are before, within, or past the first
 * JSON record on the current line.
 */
enum class dfa_states : StateT {
  // Before the first record on the JSON line
  BEFORE,
  // Within the first record on the JSON line
  WITHIN,
  // Excess data that follows the first record on the JSON line
  EXCESS,
  // Total number of states
  NUM_STATES
};

/**
 * @brief Definition of the symbol groups
 */
enum class dfa_symbol_group_id : SymbolGroupT {
  ROOT,              ///< Symbol for root stack context
  DELIMITER,         ///< Line delimiter symbol group
  OTHER,             ///< Symbol group that implicitly matches all other tokens
  NUM_SYMBOL_GROUPS  ///< Total number of symbol groups
};

constexpr auto TT_NUM_STATES     = static_cast<StateT>(dfa_states::NUM_STATES);
constexpr auto NUM_SYMBOL_GROUPS = static_cast<uint32_t>(dfa_symbol_group_id::NUM_SYMBOL_GROUPS);

/**
 * @brief Function object to map (input_symbol, stack_context) tuples to a symbol group.
 */
struct SymbolPairToSymbolGroupId {
  CUDF_HOST_DEVICE SymbolGroupT operator()(thrust::tuple<SymbolT, StackSymbolT> symbol) const
  {
    auto const input_symbol = thrust::get<0>(symbol);
    auto const stack_symbol = thrust::get<1>(symbol);
    // '\n' delimits JSON lines; '_' marks the root stack context; everything else is OTHER
    return static_cast<SymbolGroupT>(
      input_symbol == '\n'
        ? dfa_symbol_group_id::DELIMITER
        : (stack_symbol == '_' ? dfa_symbol_group_id::ROOT : dfa_symbol_group_id::OTHER));
  }
};

/**
 * @brief Translation function object that fixes the stack context of excess data that follows after
 * the first JSON record on each line: while the DFA is in state EXCESS, each character's stack
 * context is rewritten to '_' (root); otherwise the original stack symbol passes through unchanged.
 */
struct TransduceInputOp {
  // Returns the (possibly fixed-up) stack symbol emitted for the symbol at `relative_offset`
  template <typename RelativeOffsetT, typename SymbolT>
  constexpr CUDF_HOST_DEVICE StackSymbolT operator()(StateT const state_id,
                                                     SymbolGroupT const match_id,
                                                     RelativeOffsetT const relative_offset,
                                                     SymbolT const read_symbol) const
  {
    // Excess characters are forced into the root ('_') stack context
    if (state_id == static_cast<StateT>(dfa_states::EXCESS)) { return '_'; }
    // Pass through the original stack symbol (second element of the zipped (char, stack) input)
    return thrust::get<1>(read_symbol);
  }

  // Every input symbol maps to exactly one output symbol (pure 1:1 transduction)
  template <typename SymbolT>
  constexpr CUDF_HOST_DEVICE int32_t operator()(StateT const state_id,
                                                SymbolGroupT const match_id,
                                                SymbolT const read_symbol) const
  {
    constexpr int32_t single_output_item = 1;
    return single_output_item;
  }
};

// Aliases for readability of the transition table
constexpr auto TT_BEFORE = dfa_states::BEFORE;
constexpr auto TT_INSIDE = dfa_states::WITHIN;
constexpr auto TT_EXCESS = dfa_states::EXCESS;

// Transition table; column order must match dfa_symbol_group_id: ROOT, DELIMITER, OTHER
std::array<std::array<dfa_states, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> constexpr transition_table{
  {/* IN_STATE      ROOT       DELIMITER  OTHER */
   /* TT_BEFORE */ {{TT_BEFORE, TT_BEFORE, TT_INSIDE}},
   /* TT_INSIDE */ {{TT_EXCESS, TT_BEFORE, TT_INSIDE}},
   /* TT_EXCESS */ {{TT_EXCESS, TT_BEFORE, TT_EXCESS}}}};

// The DFA's starting state
constexpr auto start_state = static_cast<StateT>(dfa_states::BEFORE);
}  // namespace fix_stack_of_excess_chars

// FST to prune tokens of invalid lines for recovering JSON lines format
namespace token_filter {

Expand Down Expand Up @@ -146,9 +238,8 @@ struct UnwrapTokenFromSymbolOp {
* invalid lines.
*/
struct TransduceToken {
template <typename TransducerTableT, typename RelativeOffsetT, typename SymbolT>
constexpr CUDF_HOST_DEVICE SymbolT operator()(TransducerTableT const&,
StateT const state_id,
template <typename RelativeOffsetT, typename SymbolT>
constexpr CUDF_HOST_DEVICE SymbolT operator()(StateT const state_id,
SymbolGroupT const match_id,
RelativeOffsetT const relative_offset,
SymbolT const read_symbol) const
Expand All @@ -165,9 +256,8 @@ struct TransduceToken {
}
}

template <typename TransducerTableT, typename SymbolT>
constexpr CUDF_HOST_DEVICE int32_t operator()(TransducerTableT const&,
StateT const state_id,
template <typename SymbolT>
constexpr CUDF_HOST_DEVICE int32_t operator()(StateT const state_id,
SymbolGroupT const match_id,
SymbolT const read_symbol) const
{
Expand Down Expand Up @@ -643,6 +733,11 @@ auto get_transition_table(json_format_cfg_t format)
// PD_ANL describes the target state after a new line after encountering error state
auto const PD_ANL = (format == json_format_cfg_t::JSON_LINES_RECOVER) ? PD_BOV : PD_ERR;

// Target state after having parsed the first JSON value on a JSON line
// Spark has the special need to ignore everything that comes after the first JSON object
// on a JSON line instead of marking those as invalid
auto const PD_AFS = (format == json_format_cfg_t::JSON_LINES_RECOVER) ? PD_PVL : PD_ERR;

// First row: empty stack ("root" level of the JSON)
// Second row: '[' on top of stack (we're parsing a list value)
// Third row: '{' on top of stack (we're parsing a struct value)
Expand All @@ -668,7 +763,7 @@ auto get_transition_table(json_format_cfg_t format)
PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_BOV, PD_STR,
PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_STR, PD_BOV, PD_STR};
pda_tt[static_cast<StateT>(pda_state_t::PD_PVL)] = {
PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_ERR, PD_PVL, PD_BOV, PD_ERR,
PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_AFS, PD_PVL, PD_BOV, PD_AFS,
PD_ERR, PD_ERR, PD_ERR, PD_PVL, PD_ERR, PD_ERR, PD_BOV, PD_ERR, PD_PVL, PD_BOV, PD_ERR,
PD_ERR, PD_ERR, PD_PVL, PD_ERR, PD_ERR, PD_ERR, PD_BFN, PD_ERR, PD_PVL, PD_BOV, PD_ERR};
pda_tt[static_cast<StateT>(pda_state_t::PD_BFN)] = {
Expand Down Expand Up @@ -733,6 +828,18 @@ auto get_translation_table(bool recover_from_error)
return regular_tokens;
};

/**
 * @brief Helper function that returns `recovering_tokens` if `recover_from_error` is true and
 * returns `regular_tokens` otherwise. This is used to ignore excess characters after the first
 * value in the case of JSON lines that recover from invalid lines, as Spark ignores any excess
 * characters that follow the first record on a JSON line.
 *
 * @param regular_tokens Tokens emitted in strict (non-recovering) JSON lines mode
 * @param recovering_tokens Tokens emitted when recovering from invalid JSON lines
 */
auto alt_tokens = [recover_from_error](std::vector<char> regular_tokens,
                                       std::vector<char> recovering_tokens) {
  // In recovery mode, substitute the recovering tokens (e.g. no ErrorBegin for excess characters)
  if (recover_from_error) { return recovering_tokens; }
  return regular_tokens;
};

std::array<std::array<std::vector<char>, NUM_PDA_SGIDS>, PD_NUM_STATES> pda_tlt;
pda_tlt[static_cast<StateT>(pda_state_t::PD_BOV)] = {{ /*ROOT*/
{StructBegin}, // OPENING_BRACE
Expand Down Expand Up @@ -920,18 +1027,18 @@ auto get_translation_table(bool recover_from_error)
{}}}; // OTHER

pda_tlt[static_cast<StateT>(pda_state_t::PD_PVL)] = {
{ /*ROOT*/
{ErrorBegin}, // OPENING_BRACE
{ErrorBegin}, // OPENING_BRACKET
{ErrorBegin}, // CLOSING_BRACE
{ErrorBegin}, // CLOSING_BRACKET
{ErrorBegin}, // QUOTE
{ErrorBegin}, // ESCAPE
{ErrorBegin}, // COMMA
{ErrorBegin}, // COLON
{}, // WHITE_SPACE
nl_tokens({}, {}), // LINE_BREAK
{ErrorBegin}, // OTHER
{ /*ROOT*/
{alt_tokens({ErrorBegin}, {})}, // OPENING_BRACE
{alt_tokens({ErrorBegin}, {})}, // OPENING_BRACKET
{alt_tokens({ErrorBegin}, {})}, // CLOSING_BRACE
{alt_tokens({ErrorBegin}, {})}, // CLOSING_BRACKET
{alt_tokens({ErrorBegin}, {})}, // QUOTE
{alt_tokens({ErrorBegin}, {})}, // ESCAPE
{alt_tokens({ErrorBegin}, {})}, // COMMA
{alt_tokens({ErrorBegin}, {})}, // COLON
{}, // WHITE_SPACE
nl_tokens({}, {}), // LINE_BREAK
{alt_tokens({ErrorBegin}, {})}, // OTHER
/*LIST*/
{ErrorBegin}, // OPENING_BRACE
{ErrorBegin}, // OPENING_BRACKET
Expand Down Expand Up @@ -1446,6 +1553,26 @@ std::pair<rmm::device_uvector<PdaTokenT>, rmm::device_uvector<SymbolOffsetT>> ge
// character.
auto zip_in = thrust::make_zip_iterator(json_in.data(), stack_symbols.data());

// Spark, as the main stakeholder in the `recover_from_error` option, has the specific need to
// ignore any characters that follow the first value on each JSON line. This is an FST that
// fixes the stack context for those excess characters. That is, that all those excess characters
// will be interpreted in the root stack context
if (recover_from_error) {
auto fix_stack_of_excess_chars = fst::detail::make_fst(
fst::detail::make_symbol_group_lookup_op(
fix_stack_of_excess_chars::SymbolPairToSymbolGroupId{}),
fst::detail::make_transition_table(fix_stack_of_excess_chars::transition_table),
fst::detail::make_translation_functor(fix_stack_of_excess_chars::TransduceInputOp{}),
stream);
fix_stack_of_excess_chars.Transduce(zip_in,
static_cast<SymbolOffsetT>(json_in.size()),
stack_symbols.data(),
thrust::make_discard_iterator(),
thrust::make_discard_iterator(),
fix_stack_of_excess_chars::start_state,
stream);
}

constexpr auto max_translation_table_size =
tokenizer_pda::NUM_PDA_SGIDS *
static_cast<tokenizer_pda::StateT>(tokenizer_pda::pda_state_t::PD_NUM_STATES);
Expand Down
71 changes: 69 additions & 2 deletions cpp/tests/io/json_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,11 +1957,11 @@ TEST_F(JsonReaderTest, JSONLinesRecovering)
// 2 -> (invalid)
R"({"b":{"a":[321})"
"\n"
// 3 -> c: [1] (valid)
// 3 -> c: 1.2 (valid)
R"({"c":1.2})"
"\n"
"\n"
// 4 -> a: 123 (valid)
// 4 -> a: 4 (valid)
R"({"a":4})"
"\n"
// 5 -> (invalid)
Expand Down Expand Up @@ -2020,4 +2020,71 @@ TEST_F(JsonReaderTest, JSONLinesRecovering)
c_validity.cbegin()});
}

TEST_F(JsonReaderTest, JSONLinesRecoveringIgnoreExcessChars)
{
  /**
   * @brief Spark has the specific need to ignore extra characters that come after the first record
   * on a JSON line
   */
  std::string data =
    // 0 -> a: -2 (valid)
    R"({"a":-2}{})"
    "\n"
    // 1 -> (invalid)
    R"({"b":{}should_be_invalid})"
    "\n"
    // 2 -> b (valid)
    R"({"b":{"a":3} })"
    "\n"
    // 3 -> c: (valid)
    R"({"c":1.2 } )"
    "\n"
    "\n"
    // 4 -> (valid)
    R"({"a":4} 123)"
    "\n"
    // 5 -> (valid)
    R"({"a":5}//Comment after record)"
    "\n"
    // 6 -> (valid)
    R"({"a":6} //Comment after whitespace)"
    "\n"
    // 7 -> (invalid)
    R"({"a":5 //Invalid Comment within record})";

  // Write the test input to a temporary file so the reader exercises the file-based source path
  auto filepath = temp_env->get_temp_dir() + "RecoveringLinesExcessChars.json";
  {
    std::ofstream outfile(filepath, std::ofstream::out);
    outfile << data;
  }

  // Enable JSON lines mode with RECOVER_WITH_NULL, i.e., invalid lines become all-null rows
  cudf::io::json_reader_options in_options =
    cudf::io::json_reader_options::builder(cudf::io::source_info{filepath})
      .lines(true)
      .recovery_mode(cudf::io::json_recovery_mode_t::RECOVER_WITH_NULL);

  cudf::io::table_with_metadata result = cudf::io::read_json(in_options);

  // One row per input line (8 lines), with columns a (INT64), b (STRUCT), c (FLOAT64)
  EXPECT_EQ(result.tbl->num_columns(), 3);
  EXPECT_EQ(result.tbl->num_rows(), 8);
  EXPECT_EQ(result.tbl->get_column(0).type().id(), cudf::type_id::INT64);
  EXPECT_EQ(result.tbl->get_column(1).type().id(), cudf::type_id::STRUCT);
  EXPECT_EQ(result.tbl->get_column(2).type().id(), cudf::type_id::FLOAT64);

  // Per-row validity: a row is valid for a column iff that key appeared on a valid line
  std::vector<bool> a_validity{true, false, false, false, true, true, true, false};
  std::vector<bool> b_validity{false, false, true, false, false, false, false, false};
  std::vector<bool> c_validity{false, false, false, true, false, false, false, false};

  // Child column b->a
  auto b_a_col = int64_wrapper({0, 0, 3, 0, 0, 0, 0, 0});

  CUDF_TEST_EXPECT_COLUMNS_EQUAL(result.tbl->get_column(0),
                                 int64_wrapper{{-2, 0, 0, 0, 4, 5, 6, 0}, a_validity.cbegin()});
  CUDF_TEST_EXPECT_COLUMNS_EQUAL(
    result.tbl->get_column(1), cudf::test::structs_column_wrapper({b_a_col}, b_validity.cbegin()));
  CUDF_TEST_EXPECT_COLUMNS_EQUAL(
    result.tbl->get_column(2),
    float64_wrapper{{0.0, 0.0, 0.0, 1.2, 0.0, 0.0, 0.0, 0.0}, c_validity.cbegin()});
}

CUDF_TEST_PROGRAM_MAIN()
2 changes: 1 addition & 1 deletion cpp/tests/io/nested_json_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ TEST_F(JsonTest, RecoveringTokenStream)
{
// Test input. Inline comments used to indicate character indexes
// 012345678 <= line 0
std::string const input = R"({"a":-2},)"
std::string const input = R"({"a":2 {})"
// 9
"\n"
// 01234 <= line 1
Expand Down

0 comments on commit 50e2211

Please sign in to comment.