From af65d52c7d4ca41606482926bdcc001644b7d108 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Tue, 2 Jan 2024 17:37:12 -0800 Subject: [PATCH] JSON quote normalization (#14545) The goal of this PR is to address [PR 10004](https://github.com/rapidsai/cudf/issues/10004) by supporting parsing of JSON files containing single quotes for field/value strings. Authors: - Shruti Shivakumar (https://github.com/shrshi) - Nghia Truong (https://github.com/ttnghia) Approvers: - Nghia Truong (https://github.com/ttnghia) - Mike Wilson (https://github.com/hyperbolic2346) - Elias Stehle (https://github.com/elstehle) URL: https://github.com/rapidsai/cudf/pull/14545 --- cpp/tests/CMakeLists.txt | 3 +- cpp/tests/io/fst/quote_normalization_test.cu | 331 +++++++++++++++++++ 2 files changed, 333 insertions(+), 1 deletion(-) create mode 100644 cpp/tests/io/fst/quote_normalization_test.cu diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 40d745338f4..d0abcc225d1 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -306,6 +306,7 @@ ConfigureTest( PERCENT 30 ) target_link_libraries(DATA_CHUNK_SOURCE_TEST PRIVATE ZLIB::ZLIB) +ConfigureTest(QUOTE_NORMALIZATION_TEST io/fst/quote_normalization_test.cu) ConfigureTest(LOGICAL_STACK_TEST io/fst/logical_stack_test.cu) ConfigureTest(FST_TEST io/fst/fst_test.cu) ConfigureTest(TYPE_INFERENCE_TEST io/type_inference_test.cu) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu new file mode 100644 index 00000000000..e2636ab029f --- /dev/null +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace { + +// Type used to represent the atomic symbol type used within the finite-state machine +// TODO: type aliasing to be declared in a common header for better maintainability and +// pre-empt future bugs +using SymbolT = char; +using StateT = char; + +// Type sufficiently large to index symbols within the input and output (may be unsigned) +using SymbolOffsetT = uint32_t; +enum class dfa_states : char { TT_OOS = 0U, TT_DQS, TT_SQS, TT_DEC, TT_SEC, TT_NUM_STATES }; +enum class dfa_symbol_group_id : uint32_t { + DOUBLE_QUOTE_CHAR, ///< Quote character SG: " + SINGLE_QUOTE_CHAR, ///< Quote character SG: ' + ESCAPE_CHAR, ///< Escape character SG: '\' + NEWLINE_CHAR, ///< Newline character SG: '\n' + OTHER_SYMBOLS, ///< SG implicitly matching all other characters + NUM_SYMBOL_GROUPS ///< Total number of symbol groups +}; + +// Aliases for readability of the transition table +constexpr auto TT_OOS = dfa_states::TT_OOS; +constexpr auto TT_DQS = dfa_states::TT_DQS; +constexpr auto TT_SQS = dfa_states::TT_SQS; +constexpr auto TT_DEC = dfa_states::TT_DEC; +constexpr auto TT_SEC = dfa_states::TT_SEC; +constexpr auto TT_NUM_STATES = static_cast(dfa_states::TT_NUM_STATES); +constexpr auto NUM_SYMBOL_GROUPS = static_cast(dfa_symbol_group_id::NUM_SYMBOL_GROUPS); + +// The i-th string representing all the characters of a symbol group +std::array, NUM_SYMBOL_GROUPS - 1> const qna_sgs{ + {{'\"'}, {'\''}, {'\\'}, {'\n'}}}; + +// Transition table +std::array, TT_NUM_STATES> const qna_state_tt{{ + /* IN_STATE " ' \ \n OTHER */ + /* TT_OOS */ {{TT_DQS, TT_SQS, TT_OOS, TT_OOS, TT_OOS}}, + /* TT_DQS */ {{TT_OOS, TT_DQS, TT_DEC, TT_OOS, TT_DQS}}, + /* TT_SQS */ {{TT_SQS, TT_OOS, TT_SEC, TT_OOS, TT_SQS}}, + /* TT_DEC */ {{TT_DQS, TT_DQS, TT_DQS, TT_OOS, TT_DQS}}, + /* TT_SEC */ {{TT_SQS, TT_SQS, TT_SQS, TT_OOS, TT_SQS}}, +}}; + +// The DFA's starting state +constexpr char start_state = static_cast(TT_OOS); + +struct TransduceToNormalizedQuotes { + /** + * @brief Returns the -th output symbol on the transition (state_id, match_id). + */ + template + constexpr CUDF_HOST_DEVICE SymbolT operator()(StateT const state_id, + SymbolGroupT const match_id, + RelativeOffsetT const relative_offset, + SymbolT const read_symbol) const + { + // -------- TRANSLATION TABLE ------------ + // Let the alphabet set be Sigma + // --------------------------------------- + // ---------- NON-SPECIAL CASES: ---------- + // Output symbol same as input symbol + // state | read_symbol -> output_symbol + // DQS | Sigma -> Sigma + // DEC | Sigma -> Sigma + // OOS | Sigma\{'} -> Sigma\{'} + // SQS | Sigma\{', "} -> Sigma\{', "} + // ---------- SPECIAL CASES: -------------- + // Input symbol translates to output symbol + // OOS | {'} -> {"} + // SQS | {'} -> {"} + // SQS | {"} -> {\"} + // SQS | {\} -> + // SEC | {'} -> {'} + // SEC | Sigma\{'} -> {\*} + + // Whether this transition translates to the escape sequence: \" + const bool outputs_escape_sequence = + (state_id == static_cast(dfa_states::TT_SQS)) && + (match_id == static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR)); + // Case when a double quote needs to be replaced by the escape sequence: \" + if (outputs_escape_sequence) { return (relative_offset == 0) ? '\\' : '"'; } + // Case when a single quote needs to be replaced by a double quote + if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SQS)) || + (state_id == static_cast(dfa_states::TT_OOS)))) { + return '"'; + } + // Case when the read symbol is an escape character - the actual translation for \ for some + // symbol is handled by transitions from SEC. For now, there is no output for this + // transition + if ((match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SQS)))) { + return 0; + } + // Case when an escaped single quote in an input single-quoted string needs to be replaced by an + // unescaped single quote + if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SEC)))) { + return '\''; + } + // Case when an escaped symbol that is not a single-quote needs to be replaced with \ + if (state_id == static_cast(dfa_states::TT_SEC)) { + return (relative_offset == 0) ? '\\' : read_symbol; + } + // In all other cases we simply output the input symbol + return read_symbol; + } + + /** + * @brief Returns the number of output characters for a given transition. During quote + * normalization, we always emit one output character (i.e., either the input character or the + * single quote-input replaced by a double quote), except when we need to escape a double quote + * that was previously inside a single-quoted string. + */ + template + constexpr CUDF_HOST_DEVICE int32_t operator()(StateT const state_id, + SymbolGroupT const match_id, + SymbolT const read_symbol) const + { + // Whether this transition translates to the escape sequence: \" + const bool sqs_outputs_escape_sequence = + (state_id == static_cast(dfa_states::TT_SQS)) && + (match_id == static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR)); + // Number of characters to output on this transition + if (sqs_outputs_escape_sequence) { return 2; } + // Whether this transition translates to the escape sequence \ or unescaped ' + const bool sec_outputs_escape_sequence = + (state_id == static_cast(dfa_states::TT_SEC)) && + (match_id != static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)); + // Number of characters to output on this transition + if (sec_outputs_escape_sequence) { return 2; } + // Whether this transition translates to no output + const bool sqs_outputs_nop = + (state_id == static_cast(dfa_states::TT_SQS)) && + (match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)); + // Number of characters to output on this transition + if (sqs_outputs_nop) { return 0; } + return 1; + } +}; + +} // namespace + +// Base test fixture for tests +struct FstTest : public cudf::test::BaseFixture {}; + +void run_test(std::string& input, std::string& output) +{ + // Prepare cuda stream for data transfers & kernels + rmm::cuda_stream stream{}; + rmm::cuda_stream_view stream_view(stream); + + auto parser = cudf::io::fst::detail::make_fst( + cudf::io::fst::detail::make_symbol_group_lut(qna_sgs), + cudf::io::fst::detail::make_transition_table(qna_state_tt), + cudf::io::fst::detail::make_translation_functor(TransduceToNormalizedQuotes{}), + stream); + + auto d_input_scalar = cudf::make_string_scalar(input, stream_view); + auto& d_input = static_cast&>(*d_input_scalar); + + // Prepare input & output buffers + constexpr std::size_t single_item = 1; + cudf::detail::hostdevice_vector output_gpu(input.size() * 2, stream_view); + cudf::detail::hostdevice_vector output_gpu_size(single_item, stream_view); + + // Allocate device-side temporary storage & run algorithm + parser.Transduce(d_input.data(), + static_cast(d_input.size()), + output_gpu.device_ptr(), + thrust::make_discard_iterator(), + output_gpu_size.device_ptr(), + start_state, + stream_view); + + // Async copy results from device to host + output_gpu.device_to_host_async(stream_view); + output_gpu_size.device_to_host_async(stream_view); + + // Make sure results have been copied back to host + stream.synchronize(); + + // Verify results + ASSERT_EQ(output_gpu_size[0], output.size()); + CUDF_TEST_EXPECT_VECTOR_EQUAL(output_gpu, output, output.size()); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization1) +{ + std::string input = R"({"A":'TEST"'})"; + std::string output = R"({"A":"TEST\""})"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization2) +{ + std::string input = R"({'A':"TEST'"} ['OTHER STUFF'])"; + std::string output = R"({"A":"TEST'"} ["OTHER STUFF"])"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization3) +{ + std::string input = R"(['{"A": "B"}',"{'A': 'B'}"])"; + std::string output = R"(["{\"A\": \"B\"}","{'A': 'B'}"])"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization4) +{ + std::string input = R"({"ain't ain't a word and you ain't supposed to say it":'"""""""""""'})"; + std::string output = + R"({"ain't ain't a word and you ain't supposed to say it":"\"\"\"\"\"\"\"\"\"\"\""})"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization5) +{ + std::string input = R"({"\"'\"'\"'\"'":'"\'"\'"\'"\'"'})"; + std::string output = R"({"\"'\"'\"'\"'":"\"'\"'\"'\"'\""})"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization6) +{ + std::string input = R"([{"ABC':'CBA":'XYZ":"ZXY'}])"; + std::string output = R"([{"ABC':'CBA":"XYZ\":\"ZXY"}])"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization7) +{ + std::string input = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; + std::string output = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization8) +{ + std::string input = R"(['\t','\\t','\\','\\\"\'\\\\','\n','\b','\u0012'])"; + std::string output = R"(["\t","\\t","\\","\\\"'\\\\","\n","\b","\u0012"])"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid1) +{ + std::string input = R"(["THIS IS A TEST'])"; + std::string output = R"(["THIS IS A TEST'])"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid2) +{ + std::string input = R"(['THIS IS A TEST"])"; + std::string output = R"(["THIS IS A TEST\"])"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid3) +{ + std::string input = R"({"MORE TEST'N":'RESUL})"; + std::string output = R"({"MORE TEST'N":"RESUL})"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid4) +{ + std::string input = R"({"NUMBER":100'0,'STRING':'SOMETHING'})"; + std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING"})"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid5) +{ + std::string input = R"({'NUMBER':100"0,"STRING":"SOMETHING"})"; + std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING"})"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid6) +{ + std::string input = R"({'a':'\\''})"; + std::string output = R"({"a":"\\""})"; + run_test(input, output); +} + +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid7) +{ + std::string input = R"(}'a': 'b'{)"; + std::string output = R"(}"a": "b"{)"; + run_test(input, output); +} + +CUDF_TEST_PROGRAM_MAIN()