From 3c7ac1ee2e12d1e687487ce59c867ea3702af8fb Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Fri, 1 Dec 2023 23:22:55 +0000 Subject: [PATCH 01/11] simple quote normalization test --- cpp/src/io/fst/lookup_tables.cuh | 3 +- cpp/tests/CMakeLists.txt | 1 + cpp/tests/io/fst/quote_normalization_test.cu | 244 +++++++++++++++++++ 3 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 cpp/tests/io/fst/quote_normalization_test.cu diff --git a/cpp/src/io/fst/lookup_tables.cuh b/cpp/src/io/fst/lookup_tables.cuh index 42036b79751..32e227613ab 100644 --- a/cpp/src/io/fst/lookup_tables.cuh +++ b/cpp/src/io/fst/lookup_tables.cuh @@ -609,7 +609,8 @@ class TransducerLookupTable { } // Check whether runtime-provided table size exceeds the compile-time given max. table size - CUDF_EXPECTS(out_symbols.size() <= MAX_TABLE_SIZE, "Unsupported translation table"); + // TODO: Support for multicharater translations? + // CUDF_EXPECTS(out_symbols.size() <= MAX_TABLE_SIZE, "Unsupported translation table"); // Prepare host-side data to be copied and passed to the device std::copy( diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b35c72b9e9d..cc766e4aeda 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -304,6 +304,7 @@ ConfigureTest( PERCENT 30 ) target_link_libraries(DATA_CHUNK_SOURCE_TEST PRIVATE ZLIB::ZLIB) +ConfigureTest(QUOTE_NORMALIZATION_TEST io/fst/quote_normalization_test.cu) ConfigureTest(LOGICAL_STACK_TEST io/fst/logical_stack_test.cu) ConfigureTest(FST_TEST io/fst/fst_test.cu) ConfigureTest(TYPE_INFERENCE_TEST io/type_inference_test.cu) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu new file mode 100644 index 00000000000..fefe34d5141 --- /dev/null +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +//------------------------------------------------------------------------------ +// CPU-BASED IMPLEMENTATIONS FOR VERIFICATION +//------------------------------------------------------------------------------ +/** + * @brief CPU-based implementation of a finite-state transducer (FST). + * + * @tparam InputItT Forward input iterator type to symbols fed into the FST + * @tparam StateT Type representing states of the finite-state machine + * @tparam SymbolGroupLutT Sequence container of symbol groups. Each symbol group is a sequence + * container to symbols within that group. + * @tparam TransitionTableT Two-dimensional container type + * @tparam TransducerTableT Two-dimensional container type + * @tparam OutputItT Forward output iterator type + * @tparam IndexOutputItT Forward output iterator type + * @param[in] begin Forward iterator to the beginning of the symbol sequence + * @param[in] end Forward iterator to one past the last element of the symbol sequence + * @param[in] init_state The starting state of the finite-state machine + * @param[in] symbol_group_lut Sequence container of symbol groups. Each symbol group is a sequence + * container to symbols within that group. The index of the symbol group containing a symbol being + * read will be used as symbol_gid of the transition and translation tables. + * @param[in] transition_table The two-dimensional transition table, i.e., + * transition_table[state][symbol_gid] -> new_state + * @param[in] translation_table The two-dimensional transducer table, i.e., + * translation_table[state][symbol_gid] -> range_of_output_symbols + * @param[out] out_tape A forward output iterator to which the transduced input will be written + * @param[out] out_index_tape A forward output iterator to which indexes of the symbols that + * actually caused some output are written to + * @return A pair of iterators to one past the last element of (1) the transduced output symbol + * sequence and (2) the indexes of + */ +template +static std::pair fst_baseline(InputItT begin, + InputItT end, + StateT const& init_state, + SymbolGroupLutT symbol_group_lut, + TransitionTableT transition_table, + TransducerTableT translation_table, + OutputItT out_tape, + IndexOutputItT out_index_tape) +{ + // Initialize "FSM" with starting state + StateT state = init_state; + + // To track the symbol offset within the input that caused the FST to output + std::size_t in_offset = 0; + for (auto it = begin; it < end; it++) { + // The symbol currently being read + auto const& symbol = *it; + + // Iterate over symbol groups and search for the first symbol group containing the current + // symbol, if no match is found we use cend(symbol_group_lut) as the "catch-all" symbol group + auto symbol_group_it = + std::find_if(std::cbegin(symbol_group_lut), std::cend(symbol_group_lut), [symbol](auto& sg) { + return std::find(std::cbegin(sg), std::cend(sg), symbol) != std::cend(sg); + }); + auto symbol_group = std::distance(std::cbegin(symbol_group_lut), symbol_group_it); + + // Output the translated symbols to the output tape + out_tape = std::copy(std::cbegin(translation_table[state][symbol_group]), + std::cend(translation_table[state][symbol_group]), + out_tape); + + auto out_size = std::distance(std::cbegin(translation_table[state][symbol_group]), + std::cend(translation_table[state][symbol_group])); + + out_index_tape = std::fill_n(out_index_tape, out_size, in_offset); + + // Transition the state of the finite-state machine + state = static_cast(transition_table[state][symbol_group]); + + // Continue with next symbol from input tape + in_offset++; + } + return {out_tape, out_index_tape}; +} +} // namespace + +// Base test fixture for tests +struct FstTest : public cudf::test::BaseFixture {}; + +TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple) +{ + // Type used to represent the atomic symbol type used within the finite-state machine + using SymbolT = char; + + // Type sufficiently large to index symbols within the input and output (may be unsigned) + using SymbolOffsetT = uint32_t; + + // Prepare cuda stream for data transfers & kernels + rmm::cuda_stream stream{}; + rmm::cuda_stream_view stream_view(stream); + + // Test input + std::string input = R"({"A" : 'TEST"'})"; + auto d_input_scalar = cudf::make_string_scalar(input); + auto& d_input = static_cast&>(*d_input_scalar); + + // Prepare input & output buffers + constexpr std::size_t single_item = 1; + cudf::detail::hostdevice_vector output_gpu(input.size() * 2, stream_view); + cudf::detail::hostdevice_vector output_gpu_size(single_item, stream_view); + cudf::detail::hostdevice_vector out_indexes_gpu(input.size(), stream_view); + + enum class dfa_states : char { TT_OOS = 0U, TT_DQS, TT_SQS, TT_DEC, TT_SEC, TT_NUM_STATES }; + + enum class dfa_symbol_group_id : uint32_t { + OPENING_BRACE, ///< Opening brace SG: { + OPENING_BRACKET, ///< Opening bracket SG: [ + CLOSING_BRACE, ///< Closing brace SG: } + CLOSING_BRACKET, ///< Closing bracket SG: ] + DOUBLE_QUOTE_CHAR, ///< Quote character SG: " + SINGLE_QUOTE_CHAR, ///< Quote character SG: ' + ESCAPE_CHAR, ///< Escape character SG: '\' + OTHER_SYMBOLS, ///< SG implicitly matching all other characters + NUM_SYMBOL_GROUPS ///< Total number of symbol groups + }; + + // Aliases for readability of the transition table + constexpr auto TT_OOS = dfa_states::TT_OOS; + constexpr auto TT_DQS = dfa_states::TT_DQS; + constexpr auto TT_SQS = dfa_states::TT_SQS; + constexpr auto TT_DEC = dfa_states::TT_DEC; + constexpr auto TT_SEC = dfa_states::TT_SEC; + + constexpr auto TT_NUM_STATES = static_cast(dfa_states::TT_NUM_STATES); + constexpr auto NUM_SYMBOL_GROUPS = static_cast(dfa_symbol_group_id::NUM_SYMBOL_GROUPS); + + // The i-th string representing all the characters of a symbol group + std::array const qna_sgs{"{", "[", "}", "]", "\"", "'", "\\"}; + + // Transition table + // Does not support JSON lines + std::array, TT_NUM_STATES> const qna_state_tt{{ + /* IN_STATE { [ } ] " ' \ OTHER */ + /* TT_OOS */ {{TT_OOS, TT_OOS, TT_OOS, TT_OOS, TT_DQS, TT_SQS, TT_OOS, TT_OOS}}, + /* TT_DQS */ {{TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_OOS, TT_DQS, TT_DEC, TT_DQS}}, + /* TT_SQS */ {{TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_OOS, TT_SEC, TT_SQS}}, + /* TT_DEC */ {{TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS}}, + /* TT_SEC */ {{TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS}}, + }}; + + // Translation table (i.e., for each transition, what are the symbols that we output) + std::array, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const qna_out_tt{ + {/* IN_STATE { [ } ] " ' \ OTHER */ + /* TT_OOS */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'"'}, {'\\'}, {'x'}}}, + /* TT_DQS */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}}, + /* TT_SQS */ {{{'{'}, {'['}, {'}'}, {']'}, {'\\', '"'}, {'"'}, {'\\'}, {'x'}}}, + /* TT_DEC */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}}, + /* TT_SEC */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}}}}; + + // The DFA's starting state + constexpr char start_state = static_cast(TT_OOS); + + // Run algorithm + auto parser = cudf::io::fst::detail::make_fst( + cudf::io::fst::detail::make_symbol_group_lut(qna_sgs), + cudf::io::fst::detail::make_transition_table(qna_state_tt), + cudf::io::fst::detail::make_translation_table(qna_out_tt), + stream); + + // Allocate device-side temporary storage & run algorithm + parser.Transduce(d_input.data(), + static_cast(d_input.size()), + output_gpu.device_ptr(), + out_indexes_gpu.device_ptr(), + output_gpu_size.device_ptr(), + start_state, + stream.value()); + + // Async copy results from device to host + output_gpu.device_to_host_async(stream.view()); + out_indexes_gpu.device_to_host_async(stream.view()); + output_gpu_size.device_to_host_async(stream.view()); + + // Prepare CPU-side results for verification + std::string output_cpu{}; + std::vector out_index_cpu{}; + output_cpu.reserve(input.size()); + out_index_cpu.reserve(input.size()); + + // Run CPU-side algorithm + fst_baseline(std::begin(input), + std::end(input), + start_state, + qna_sgs, + qna_state_tt, + qna_out_tt, + std::back_inserter(output_cpu), + std::back_inserter(out_index_cpu)); + + // Make sure results have been copied back to host + stream.synchronize(); + + // Verify results + ASSERT_EQ(output_gpu_size[0], output_cpu.size()); + CUDF_TEST_EXPECT_VECTOR_EQUAL(output_gpu, output_cpu, output_cpu.size()); + // TODO: indexing for multicharacter translations + // CUDF_TEST_EXPECT_VECTOR_EQUAL(out_indexes_gpu, out_index_cpu, output_cpu.size()); +} + +CUDF_TEST_PROGRAM_MAIN() From f77bd260cb670f1dcdb969dc3b80207bf320b5bb Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Sat, 2 Dec 2023 00:48:01 +0000 Subject: [PATCH 02/11] added more test cases --- cpp/tests/io/fst/quote_normalization_test.cu | 60 +++++++++++++++----- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index fefe34d5141..275261f392d 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -121,7 +121,7 @@ static std::pair fst_baseline(InputItT begin, // Base test fixture for tests struct FstTest : public cudf::test::BaseFixture {}; -TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple) +void run_test(std::string& input) { // Type used to represent the atomic symbol type used within the finite-state machine using SymbolT = char; @@ -133,17 +133,7 @@ TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple) rmm::cuda_stream stream{}; rmm::cuda_stream_view stream_view(stream); - // Test input - std::string input = R"({"A" : 'TEST"'})"; - auto d_input_scalar = cudf::make_string_scalar(input); - auto& d_input = static_cast&>(*d_input_scalar); - - // Prepare input & output buffers - constexpr std::size_t single_item = 1; - cudf::detail::hostdevice_vector output_gpu(input.size() * 2, stream_view); - cudf::detail::hostdevice_vector output_gpu_size(single_item, stream_view); - cudf::detail::hostdevice_vector out_indexes_gpu(input.size(), stream_view); - + // Run algorithm enum class dfa_states : char { TT_OOS = 0U, TT_DQS, TT_SQS, TT_DEC, TT_SEC, TT_NUM_STATES }; enum class dfa_symbol_group_id : uint32_t { @@ -194,13 +184,21 @@ TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple) // The DFA's starting state constexpr char start_state = static_cast(TT_OOS); - // Run algorithm auto parser = cudf::io::fst::detail::make_fst( cudf::io::fst::detail::make_symbol_group_lut(qna_sgs), cudf::io::fst::detail::make_transition_table(qna_state_tt), cudf::io::fst::detail::make_translation_table(qna_out_tt), stream); + auto d_input_scalar = cudf::make_string_scalar(input); + auto& d_input = static_cast&>(*d_input_scalar); + + // Prepare input & output buffers + constexpr std::size_t single_item = 1; + cudf::detail::hostdevice_vector output_gpu(input.size() * 2, stream_view); + cudf::detail::hostdevice_vector output_gpu_size(single_item, stream_view); + cudf::detail::hostdevice_vector out_indexes_gpu(input.size(), stream_view); + // Allocate device-side temporary storage & run algorithm parser.Transduce(d_input.data(), static_cast(d_input.size()), @@ -236,9 +234,45 @@ TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple) // Verify results ASSERT_EQ(output_gpu_size[0], output_cpu.size()); + std::cout << output_cpu << std::endl; CUDF_TEST_EXPECT_VECTOR_EQUAL(output_gpu, output_cpu, output_cpu.size()); // TODO: indexing for multicharacter translations // CUDF_TEST_EXPECT_VECTOR_EQUAL(out_indexes_gpu, out_index_cpu, output_cpu.size()); } +TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple1) +{ + std::string input = R"({"A":'TEST"'})"; + run_test(input); +} +TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple2) +{ + std::string input = R"({'A':"TEST'"} ['OTHER STUFF'])"; + run_test(input); +} +TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple3) +{ + std::string input = R"(['{"A": "B"}',"{'A': 'B'}"])"; + run_test(input); +} +TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple4) +{ + std::string input = R"({"ain't ain't a word and you ain't supposed to say it":'"""""""""""'})"; + run_test(input); +} +TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple5) +{ + std::string input = R"({"\"'\"'\"'\"'":'"\'"\'"\'"\'"'})"; + run_test(input); +} +TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple6) +{ + std::string input = R"([{"ABC':'CBA":'XYZ":"ZXY'}])"; + run_test(input); +} +TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple7) +{ + std::string input = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; + run_test(input); +} CUDF_TEST_PROGRAM_MAIN() From 48287ae86bb0b25b59e4a8a562aee2beb5d8fa0b Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Thu, 7 Dec 2023 20:05:22 +0000 Subject: [PATCH 03/11] edge case handling; translation table -> translation functor --- cpp/src/io/fst/lookup_tables.cuh | 3 +- cpp/tests/io/fst/quote_normalization_test.cu | 396 +++++++++++-------- 2 files changed, 222 insertions(+), 177 deletions(-) diff --git a/cpp/src/io/fst/lookup_tables.cuh b/cpp/src/io/fst/lookup_tables.cuh index 32e227613ab..42036b79751 100644 --- a/cpp/src/io/fst/lookup_tables.cuh +++ b/cpp/src/io/fst/lookup_tables.cuh @@ -609,8 +609,7 @@ class TransducerLookupTable { } // Check whether runtime-provided table size exceeds the compile-time given max. table size - // TODO: Support for multicharater translations? - // CUDF_EXPECTS(out_symbols.size() <= MAX_TABLE_SIZE, "Unsupported translation table"); + CUDF_EXPECTS(out_symbols.size() <= MAX_TABLE_SIZE, "Unsupported translation table"); // Prepare host-side data to be copied and passed to the device std::copy( diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index 275261f392d..0134a512bd7 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -35,159 +35,169 @@ namespace { -//------------------------------------------------------------------------------ -// CPU-BASED IMPLEMENTATIONS FOR VERIFICATION -//------------------------------------------------------------------------------ -/** - * @brief CPU-based implementation of a finite-state transducer (FST). - * - * @tparam InputItT Forward input iterator type to symbols fed into the FST - * @tparam StateT Type representing states of the finite-state machine - * @tparam SymbolGroupLutT Sequence container of symbol groups. Each symbol group is a sequence - * container to symbols within that group. - * @tparam TransitionTableT Two-dimensional container type - * @tparam TransducerTableT Two-dimensional container type - * @tparam OutputItT Forward output iterator type - * @tparam IndexOutputItT Forward output iterator type - * @param[in] begin Forward iterator to the beginning of the symbol sequence - * @param[in] end Forward iterator to one past the last element of the symbol sequence - * @param[in] init_state The starting state of the finite-state machine - * @param[in] symbol_group_lut Sequence container of symbol groups. Each symbol group is a sequence - * container to symbols within that group. The index of the symbol group containing a symbol being - * read will be used as symbol_gid of the transition and translation tables. - * @param[in] transition_table The two-dimensional transition table, i.e., - * transition_table[state][symbol_gid] -> new_state - * @param[in] translation_table The two-dimensional transducer table, i.e., - * translation_table[state][symbol_gid] -> range_of_output_symbols - * @param[out] out_tape A forward output iterator to which the transduced input will be written - * @param[out] out_index_tape A forward output iterator to which indexes of the symbols that - * actually caused some output are written to - * @return A pair of iterators to one past the last element of (1) the transduced output symbol - * sequence and (2) the indexes of - */ -template -static std::pair fst_baseline(InputItT begin, - InputItT end, - StateT const& init_state, - SymbolGroupLutT symbol_group_lut, - TransitionTableT transition_table, - TransducerTableT translation_table, - OutputItT out_tape, - IndexOutputItT out_index_tape) -{ - // Initialize "FSM" with starting state - StateT state = init_state; - - // To track the symbol offset within the input that caused the FST to output - std::size_t in_offset = 0; - for (auto it = begin; it < end; it++) { - // The symbol currently being read - auto const& symbol = *it; - - // Iterate over symbol groups and search for the first symbol group containing the current - // symbol, if no match is found we use cend(symbol_group_lut) as the "catch-all" symbol group - auto symbol_group_it = - std::find_if(std::cbegin(symbol_group_lut), std::cend(symbol_group_lut), [symbol](auto& sg) { - return std::find(std::cbegin(sg), std::cend(sg), symbol) != std::cend(sg); - }); - auto symbol_group = std::distance(std::cbegin(symbol_group_lut), symbol_group_it); - - // Output the translated symbols to the output tape - out_tape = std::copy(std::cbegin(translation_table[state][symbol_group]), - std::cend(translation_table[state][symbol_group]), - out_tape); - - auto out_size = std::distance(std::cbegin(translation_table[state][symbol_group]), - std::cend(translation_table[state][symbol_group])); - - out_index_tape = std::fill_n(out_index_tape, out_size, in_offset); - - // Transition the state of the finite-state machine - state = static_cast(transition_table[state][symbol_group]); - - // Continue with next symbol from input tape - in_offset++; +// Type used to represent the atomic symbol type used within the finite-state machine +using SymbolT = char; +using StateT = char; +// Type sufficiently large to index symbols within the input and output (may be unsigned) +using SymbolOffsetT = uint32_t; +enum class dfa_states : char { TT_OOS = 0U, TT_DQS, TT_SQS, TT_DEC, TT_SEC, TT_NUM_STATES }; +enum class dfa_symbol_group_id : uint32_t { + DOUBLE_QUOTE_CHAR, ///< Quote character SG: " + SINGLE_QUOTE_CHAR, ///< Quote character SG: ' + ESCAPE_CHAR, ///< Escape character SG: '\' + NEWLINE_CHAR, ///< Newline character SG: '\n' + OTHER_SYMBOLS, ///< SG implicitly matching all other characters + NUM_SYMBOL_GROUPS ///< Total number of symbol groups +}; +// Aliases for readability of the transition table +constexpr auto TT_OOS = dfa_states::TT_OOS; +constexpr auto TT_DQS = dfa_states::TT_DQS; +constexpr auto TT_SQS = dfa_states::TT_SQS; +constexpr auto TT_DEC = dfa_states::TT_DEC; +constexpr auto TT_SEC = dfa_states::TT_SEC; +constexpr auto TT_NUM_STATES = static_cast(dfa_states::TT_NUM_STATES); +constexpr auto NUM_SYMBOL_GROUPS = static_cast(dfa_symbol_group_id::NUM_SYMBOL_GROUPS); +// The i-th string representing all the characters of a symbol group +// TODO: We should reinstantiate this lookup table approach once https://github.com/rapidsai/cudf/pull/14561 is merged +// std::array, NUM_SYMBOL_GROUPS - 1> const qna_sgs{ +// {{'\"'}, {'\''}, {'\\'}, {'\n'}}}; +// Temporary workaround: +struct SymbolToSymbolGroup { + CUDF_HOST_DEVICE uint32_t operator()(SymbolT symbol) const + { + switch (symbol) { + case '\"': return static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR); + case '\'': return static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR); + case '\\': return static_cast(dfa_symbol_group_id::ESCAPE_CHAR); + case '\n': return static_cast(dfa_symbol_group_id::NEWLINE_CHAR); + default: return static_cast(dfa_symbol_group_id::OTHER_SYMBOLS); + } + }; +}; +// Transition table +std::array, TT_NUM_STATES> const qna_state_tt{{ + /* IN_STATE " ' \ \n OTHER */ + /* TT_OOS */ {{TT_DQS, TT_SQS, TT_OOS, TT_OOS, TT_OOS}}, + /* TT_DQS */ {{TT_OOS, TT_DQS, TT_DEC, TT_OOS, TT_DQS}}, + /* TT_SQS */ {{TT_SQS, TT_OOS, TT_SEC, TT_OOS, TT_SQS}}, + /* TT_DEC */ {{TT_DQS, TT_DQS, TT_DQS, TT_OOS, TT_DQS}}, + /* TT_SEC */ {{TT_SQS, TT_SQS, TT_SQS, TT_OOS, TT_SQS}}, +}}; +// The DFA's starting state +constexpr char start_state = static_cast(TT_OOS); +struct TransduceToNormalizedQuotes { + /** + * @brief Returns the -th output symbol on the transition (state_id, match_id). + */ + template + constexpr CUDF_HOST_DEVICE SymbolT operator()(StateT const state_id, + SymbolGroupT const match_id, + RelativeOffsetT const relative_offset, + SymbolT const read_symbol) const + { + // -------- TRANSLATION TABLE ------------ + // state | read_symbol -> output_symbols + // DQS | * -> * + // DEC | * -> * + // OOS | ' -> " + // OOS | * -> * + // SQS | " -> \" + // SQS | ' -> " + // SQS | * -> * + // SEC | * -> * + // ---------- SPECIAL CASES: -------------- (anything else translates to input symbol) + // OOS | ' -> " + // SQS | " -> \" + // SQS | ' -> " + // SQS | \\ -> + // SEC | ' -> ' + // SEC | Sigma\{'} -> \* + // Whether this transition translates to the escape sequence: \" + const bool outputs_escape_sequence = + (state_id == static_cast(dfa_states::TT_SQS)) && + (match_id == static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR)); + // Case when a double quote needs to be replaced by the escape sequence: \" + if (outputs_escape_sequence) { + return (relative_offset == 0) ? '\\' : '"'; + } + // Case when a single quote needs to be replaced by a double quote + else if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SQS)) || + (state_id == static_cast(dfa_states::TT_OOS)))) { + return '"'; + } + // Case when the read symbol is an escape character - the actual translation for \ for some symbol is handled by + // transitions from SEC. For now, there is no output for this transition + else if ((match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SQS)))) { + return 0; + } + // Case when an escaped single quote in an input single-quoted string needs to be replaced by an unescaped single quote + else if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SEC)))) { + return '\''; + } + // Case when an escaped symbol that is not a single-quote needs to be replaced with \ + else if (state_id == static_cast(dfa_states::TT_SEC)) { + return (relative_offset == 0) ? '\\' : read_symbol; + } + // In all other cases we simply output the input symbol + /* + else if (!((match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SQS))))) { + return read_symbol; + } + */ + else return read_symbol; } - return {out_tape, out_index_tape}; -} + /** + * @brief Returns the number of output characters for a given transition. During quote + * normalization, we always emit one output character (i.e., either the input character or the + * single quote-input replaced by a double quote), except when we need to escape a double quote + * that was previously inside a single-quoted string. + */ + template + constexpr CUDF_HOST_DEVICE int32_t operator()(StateT const state_id, + SymbolGroupT const match_id, + SymbolT const read_symbol) const + { + // Whether this transition translates to the escape sequence: \" + const bool sqs_outputs_escape_sequence = + (state_id == static_cast(dfa_states::TT_SQS)) && + (match_id == static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR)); + // Number of characters to output on this transition + if(sqs_outputs_escape_sequence) return 2; + // Whether this transition translates to the escape sequence \ or unescaped ' + const bool sec_outputs_escape_sequence = + (state_id == static_cast(dfa_states::TT_SEC)) && + (match_id != static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)); + // Number of characters to output on this transition + if(sec_outputs_escape_sequence) return 2; + // Whether this transition translates to no output + const bool sqs_outputs_nop = + (state_id == static_cast(dfa_states::TT_SQS)) && + (match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)); + // Number of characters to output on this transition + if(sqs_outputs_nop) return 0; + return 1; + } +}; + } // namespace // Base test fixture for tests struct FstTest : public cudf::test::BaseFixture {}; -void run_test(std::string& input) +void run_test(std::string& input, std::string &output) { - // Type used to represent the atomic symbol type used within the finite-state machine - using SymbolT = char; - - // Type sufficiently large to index symbols within the input and output (may be unsigned) - using SymbolOffsetT = uint32_t; - // Prepare cuda stream for data transfers & kernels rmm::cuda_stream stream{}; rmm::cuda_stream_view stream_view(stream); - // Run algorithm - enum class dfa_states : char { TT_OOS = 0U, TT_DQS, TT_SQS, TT_DEC, TT_SEC, TT_NUM_STATES }; - - enum class dfa_symbol_group_id : uint32_t { - OPENING_BRACE, ///< Opening brace SG: { - OPENING_BRACKET, ///< Opening bracket SG: [ - CLOSING_BRACE, ///< Closing brace SG: } - CLOSING_BRACKET, ///< Closing bracket SG: ] - DOUBLE_QUOTE_CHAR, ///< Quote character SG: " - SINGLE_QUOTE_CHAR, ///< Quote character SG: ' - ESCAPE_CHAR, ///< Escape character SG: '\' - OTHER_SYMBOLS, ///< SG implicitly matching all other characters - NUM_SYMBOL_GROUPS ///< Total number of symbol groups - }; - - // Aliases for readability of the transition table - constexpr auto TT_OOS = dfa_states::TT_OOS; - constexpr auto TT_DQS = dfa_states::TT_DQS; - constexpr auto TT_SQS = dfa_states::TT_SQS; - constexpr auto TT_DEC = dfa_states::TT_DEC; - constexpr auto TT_SEC = dfa_states::TT_SEC; - - constexpr auto TT_NUM_STATES = static_cast(dfa_states::TT_NUM_STATES); - constexpr auto NUM_SYMBOL_GROUPS = static_cast(dfa_symbol_group_id::NUM_SYMBOL_GROUPS); - - // The i-th string representing all the characters of a symbol group - std::array const qna_sgs{"{", "[", "}", "]", "\"", "'", "\\"}; - - // Transition table - // Does not support JSON lines - std::array, TT_NUM_STATES> const qna_state_tt{{ - /* IN_STATE { [ } ] " ' \ OTHER */ - /* TT_OOS */ {{TT_OOS, TT_OOS, TT_OOS, TT_OOS, TT_DQS, TT_SQS, TT_OOS, TT_OOS}}, - /* TT_DQS */ {{TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_OOS, TT_DQS, TT_DEC, TT_DQS}}, - /* TT_SQS */ {{TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_OOS, TT_SEC, TT_SQS}}, - /* TT_DEC */ {{TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS}}, - /* TT_SEC */ {{TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS}}, - }}; - - // Translation table (i.e., for each transition, what are the symbols that we output) - std::array, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const qna_out_tt{ - {/* IN_STATE { [ } ] " ' \ OTHER */ - /* TT_OOS */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'"'}, {'\\'}, {'x'}}}, - /* TT_DQS */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}}, - /* TT_SQS */ {{{'{'}, {'['}, {'}'}, {']'}, {'\\', '"'}, {'"'}, {'\\'}, {'x'}}}, - /* TT_DEC */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}}, - /* TT_SEC */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}}}}; - - // The DFA's starting state - constexpr char start_state = static_cast(TT_OOS); - auto parser = cudf::io::fst::detail::make_fst( - cudf::io::fst::detail::make_symbol_group_lut(qna_sgs), + cudf::io::fst::detail::make_symbol_group_lookup_op(SymbolToSymbolGroup{}), cudf::io::fst::detail::make_transition_table(qna_state_tt), - cudf::io::fst::detail::make_translation_table(qna_out_tt), + cudf::io::fst::detail::make_translation_functor(TransduceToNormalizedQuotes{}), stream); auto d_input_scalar = cudf::make_string_scalar(input); @@ -213,66 +223,102 @@ void run_test(std::string& input) out_indexes_gpu.device_to_host_async(stream.view()); output_gpu_size.device_to_host_async(stream.view()); - // Prepare CPU-side results for verification - std::string output_cpu{}; - std::vector out_index_cpu{}; - output_cpu.reserve(input.size()); - out_index_cpu.reserve(input.size()); - - // Run CPU-side algorithm - fst_baseline(std::begin(input), - std::end(input), - start_state, - qna_sgs, - qna_state_tt, - qna_out_tt, - std::back_inserter(output_cpu), - std::back_inserter(out_index_cpu)); - // Make sure results have been copied back to host stream.synchronize(); // Verify results - ASSERT_EQ(output_gpu_size[0], output_cpu.size()); - std::cout << output_cpu << std::endl; - CUDF_TEST_EXPECT_VECTOR_EQUAL(output_gpu, output_cpu, output_cpu.size()); - // TODO: indexing for multicharacter translations - // CUDF_TEST_EXPECT_VECTOR_EQUAL(out_indexes_gpu, out_index_cpu, output_cpu.size()); + std::cout << "Expected output: " << output << std::endl << "Computed output: "; + for(size_t i = 0; i < output_gpu_size[0]; i++) + std::cout << output_gpu[i]; + std::cout << std::endl; + ASSERT_EQ(output_gpu_size[0], output.size()); + CUDF_TEST_EXPECT_VECTOR_EQUAL(output_gpu, output, output.size()); } -TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple1) +TEST_F(FstTest, GroundTruth_QuoteNormalization1) { std::string input = R"({"A":'TEST"'})"; - run_test(input); + std::string output = R"({"A":"TEST\""})"; + run_test(input, output); + } -TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple2) +TEST_F(FstTest, GroundTruth_QuoteNormalization2) { std::string input = R"({'A':"TEST'"} ['OTHER STUFF'])"; - run_test(input); + std::string output = R"({"A":"TEST'"} ["OTHER STUFF"])"; + run_test(input, output); } -TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple3) +TEST_F(FstTest, GroundTruth_QuoteNormalization3) { std::string input = R"(['{"A": "B"}',"{'A': 'B'}"])"; - run_test(input); + std::string output = R"(["{\"A\": \"B\"}","{'A': 'B'}"])"; + run_test(input, output); } -TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple4) +TEST_F(FstTest, GroundTruth_QuoteNormalization4) { std::string input = R"({"ain't ain't a word and you ain't supposed to say it":'"""""""""""'})"; - run_test(input); + std::string output = R"({"ain't ain't a word and you ain't supposed to say it":"\"\"\"\"\"\"\"\"\"\"\""})"; + run_test(input, output); } -TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple5) +TEST_F(FstTest, GroundTruth_QuoteNormalization5) { std::string input = R"({"\"'\"'\"'\"'":'"\'"\'"\'"\'"'})"; - run_test(input); + std::string output = R"({"\"'\"'\"'\"'":"\"'\"'\"'\"'\""})"; + run_test(input, output); } -TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple6) +TEST_F(FstTest, GroundTruth_QuoteNormalization6) { std::string input = R"([{"ABC':'CBA":'XYZ":"ZXY'}])"; - run_test(input); + std::string output = R"([{"ABC':'CBA":"XYZ\":\"ZXY"}])"; + run_test(input, output); } -TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple7) +TEST_F(FstTest, GroundTruth_QuoteNormalization7) { std::string input = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; - run_test(input); + std::string output = R"(["\t","\\t","\\","\\'\"\\\\","\n","\b"])"; + run_test(input, output); } +TEST_F(FstTest, GroundTruth_QuoteNormalization8) +{ + std::string input = R"(['\t','\\t','\\','\\\"\'\\\\','\n','\b','\u0012'])"; + std::string output = R"(["\t","\\t","\\","\\\"'\\\\","\n","\b","\u0012"])"; + run_test(input, output); +} +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid1) +{ + std::string input = R"(["THIS IS A TEST'])"; + std::string output = R"(["THIS IS A TEST'])"; + run_test(input, output); +} +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid2) +{ + std::string input = R"(['THIS IS A TEST"])"; + std::string output = R"(["THIS IS A TEST\"])"; + run_test(input, output); +} +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid3) +{ + std::string input = R"({"MORE TEST'N":'RESUL})"; + std::string output = R"({"MORE TEST'N":"RESUL})"; + run_test(input, output); +} +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid4) +{ + std::string input = R"({"NUMBER":100'0,'STRING':'SOMETHING'})"; + std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING'})"; + run_test(input, output); +} +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid5) +{ + std::string input = R"({'NUMBER':100"0,"STRING":"SOMETHING"})"; + std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING"})"; + run_test(input, output); +} +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid6) +{ + std::string input = R"({'a':'\\''})"; + std::string output = R"({"a":"\\""})"; + run_test(input, output); +} + CUDF_TEST_PROGRAM_MAIN() From 0aa5f30dca2ffd3f8f16e64c9d0151ae4f450fc5 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Thu, 7 Dec 2023 20:16:41 +0000 Subject: [PATCH 04/11] forgot to run pre-commit style check --- cpp/tests/io/fst/quote_normalization_test.cu | 80 ++++++++++---------- 1 file changed, 42 insertions(+), 38 deletions(-) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index 0134a512bd7..58bc1b211f1 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -50,16 +50,17 @@ enum class dfa_symbol_group_id : uint32_t { NUM_SYMBOL_GROUPS ///< Total number of symbol groups }; // Aliases for readability of the transition table -constexpr auto TT_OOS = dfa_states::TT_OOS; -constexpr auto TT_DQS = dfa_states::TT_DQS; -constexpr auto TT_SQS = dfa_states::TT_SQS; -constexpr auto TT_DEC = dfa_states::TT_DEC; -constexpr auto TT_SEC = dfa_states::TT_SEC; +constexpr auto TT_OOS = dfa_states::TT_OOS; +constexpr auto TT_DQS = dfa_states::TT_DQS; +constexpr auto TT_SQS = dfa_states::TT_SQS; +constexpr auto TT_DEC = dfa_states::TT_DEC; +constexpr auto TT_SEC = dfa_states::TT_SEC; constexpr auto TT_NUM_STATES = static_cast(dfa_states::TT_NUM_STATES); constexpr auto NUM_SYMBOL_GROUPS = static_cast(dfa_symbol_group_id::NUM_SYMBOL_GROUPS); // The i-th string representing all the characters of a symbol group -// TODO: We should reinstantiate this lookup table approach once https://github.com/rapidsai/cudf/pull/14561 is merged -// std::array, NUM_SYMBOL_GROUPS - 1> const qna_sgs{ +// TODO: We should reinstantiate this lookup table approach once +// https://github.com/rapidsai/cudf/pull/14561 is merged std::array, +// NUM_SYMBOL_GROUPS - 1> const qna_sgs{ // {{'\"'}, {'\''}, {'\\'}, {'\n'}}}; // Temporary workaround: struct SymbolToSymbolGroup { @@ -109,9 +110,9 @@ struct TransduceToNormalizedQuotes { // OOS | ' -> " // SQS | " -> \" // SQS | ' -> " - // SQS | \\ -> - // SEC | ' -> ' - // SEC | Sigma\{'} -> \* + // SQS | \\ -> + // SEC | ' -> ' + // SEC | Sigma\{'} -> \* // Whether this transition translates to the escape sequence: \" const bool outputs_escape_sequence = (state_id == static_cast(dfa_states::TT_SQS)) && @@ -126,18 +127,20 @@ struct TransduceToNormalizedQuotes { (state_id == static_cast(dfa_states::TT_OOS)))) { return '"'; } - // Case when the read symbol is an escape character - the actual translation for \ for some symbol is handled by - // transitions from SEC. For now, there is no output for this transition + // Case when the read symbol is an escape character - the actual translation for \ for some + // symbol is handled by transitions from SEC. For now, there is no output for this + // transition else if ((match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)) && - ((state_id == static_cast(dfa_states::TT_SQS)))) { - return 0; + ((state_id == static_cast(dfa_states::TT_SQS)))) { + return 0; } - // Case when an escaped single quote in an input single-quoted string needs to be replaced by an unescaped single quote + // Case when an escaped single quote in an input single-quoted string needs to be replaced by an + // unescaped single quote else if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && - ((state_id == static_cast(dfa_states::TT_SEC)))) { + ((state_id == static_cast(dfa_states::TT_SEC)))) { return '\''; } - // Case when an escaped symbol that is not a single-quote needs to be replaced with \ + // Case when an escaped symbol that is not a single-quote needs to be replaced with \ else if (state_id == static_cast(dfa_states::TT_SEC)) { return (relative_offset == 0) ? '\\' : read_symbol; } @@ -148,7 +151,8 @@ struct TransduceToNormalizedQuotes { return read_symbol; } */ - else return read_symbol; + else + return read_symbol; } /** * @brief Returns the number of output characters for a given transition. During quote @@ -166,19 +170,19 @@ struct TransduceToNormalizedQuotes { (state_id == static_cast(dfa_states::TT_SQS)) && (match_id == static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR)); // Number of characters to output on this transition - if(sqs_outputs_escape_sequence) return 2; + if (sqs_outputs_escape_sequence) return 2; // Whether this transition translates to the escape sequence \ or unescaped ' const bool sec_outputs_escape_sequence = (state_id == static_cast(dfa_states::TT_SEC)) && (match_id != static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)); // Number of characters to output on this transition - if(sec_outputs_escape_sequence) return 2; + if (sec_outputs_escape_sequence) return 2; // Whether this transition translates to no output const bool sqs_outputs_nop = (state_id == static_cast(dfa_states::TT_SQS)) && (match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)); // Number of characters to output on this transition - if(sqs_outputs_nop) return 0; + if (sqs_outputs_nop) return 0; return 1; } }; @@ -188,7 +192,7 @@ struct TransduceToNormalizedQuotes { // Base test fixture for tests struct FstTest : public cudf::test::BaseFixture {}; -void run_test(std::string& input, std::string &output) +void run_test(std::string& input, std::string& output) { // Prepare cuda stream for data transfers & kernels rmm::cuda_stream stream{}; @@ -228,7 +232,7 @@ void run_test(std::string& input, std::string &output) // Verify results std::cout << "Expected output: " << output << std::endl << "Computed output: "; - for(size_t i = 0; i < output_gpu_size[0]; i++) + for (size_t i = 0; i < output_gpu_size[0]; i++) std::cout << output_gpu[i]; std::cout << std::endl; ASSERT_EQ(output_gpu_size[0], output.size()); @@ -237,86 +241,86 @@ void run_test(std::string& input, std::string &output) TEST_F(FstTest, GroundTruth_QuoteNormalization1) { - std::string input = R"({"A":'TEST"'})"; + std::string input = R"({"A":'TEST"'})"; std::string output = R"({"A":"TEST\""})"; run_test(input, output); - } TEST_F(FstTest, GroundTruth_QuoteNormalization2) { - std::string input = R"({'A':"TEST'"} ['OTHER STUFF'])"; + std::string input = R"({'A':"TEST'"} ['OTHER STUFF'])"; std::string output = R"({"A":"TEST'"} ["OTHER STUFF"])"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization3) { - std::string input = R"(['{"A": "B"}',"{'A': 'B'}"])"; + std::string input = R"(['{"A": "B"}',"{'A': 'B'}"])"; std::string output = R"(["{\"A\": \"B\"}","{'A': 'B'}"])"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization4) { std::string input = R"({"ain't ain't a word and you ain't supposed to say it":'"""""""""""'})"; - std::string output = R"({"ain't ain't a word and you ain't supposed to say it":"\"\"\"\"\"\"\"\"\"\"\""})"; + std::string output = + R"({"ain't ain't a word and you ain't supposed to say it":"\"\"\"\"\"\"\"\"\"\"\""})"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization5) { - std::string input = R"({"\"'\"'\"'\"'":'"\'"\'"\'"\'"'})"; + std::string input = R"({"\"'\"'\"'\"'":'"\'"\'"\'"\'"'})"; std::string output = R"({"\"'\"'\"'\"'":"\"'\"'\"'\"'\""})"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization6) { - std::string input = R"([{"ABC':'CBA":'XYZ":"ZXY'}])"; + std::string input = R"([{"ABC':'CBA":'XYZ":"ZXY'}])"; std::string output = R"([{"ABC':'CBA":"XYZ\":\"ZXY"}])"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization7) { - std::string input = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; + std::string input = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; std::string output = R"(["\t","\\t","\\","\\'\"\\\\","\n","\b"])"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization8) { - std::string input = R"(['\t','\\t','\\','\\\"\'\\\\','\n','\b','\u0012'])"; + std::string input = R"(['\t','\\t','\\','\\\"\'\\\\','\n','\b','\u0012'])"; std::string output = R"(["\t","\\t","\\","\\\"'\\\\","\n","\b","\u0012"])"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid1) { - std::string input = R"(["THIS IS A TEST'])"; + std::string input = R"(["THIS IS A TEST'])"; std::string output = R"(["THIS IS A TEST'])"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid2) { - std::string input = R"(['THIS IS A TEST"])"; + std::string input = R"(['THIS IS A TEST"])"; std::string output = R"(["THIS IS A TEST\"])"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid3) { - std::string input = R"({"MORE TEST'N":'RESUL})"; + std::string input = R"({"MORE TEST'N":'RESUL})"; std::string output = R"({"MORE TEST'N":"RESUL})"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid4) { - std::string input = R"({"NUMBER":100'0,'STRING':'SOMETHING'})"; + std::string input = R"({"NUMBER":100'0,'STRING':'SOMETHING'})"; std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING'})"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid5) { - std::string input = R"({'NUMBER':100"0,"STRING":"SOMETHING"})"; + std::string input = R"({'NUMBER':100"0,"STRING":"SOMETHING"})"; std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING"})"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid6) { - std::string input = R"({'a':'\\''})"; + std::string input = R"({'a':'\\''})"; std::string output = R"({"a":"\\""})"; run_test(input, output); } From cf50ee7724789d9cb40d17c6e16bbe6808d3fef7 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Mon, 11 Dec 2023 23:39:40 +0000 Subject: [PATCH 05/11] fixing test cases after discussion --- cpp/tests/io/fst/quote_normalization_test.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index 58bc1b211f1..7dcec74595b 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -279,7 +279,7 @@ TEST_F(FstTest, GroundTruth_QuoteNormalization6) TEST_F(FstTest, GroundTruth_QuoteNormalization7) { std::string input = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; - std::string output = R"(["\t","\\t","\\","\\'\"\\\\","\n","\b"])"; + std::string output = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization8) @@ -309,7 +309,7 @@ TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid3) TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid4) { std::string input = R"({"NUMBER":100'0,'STRING':'SOMETHING'})"; - std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING'})"; + std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING"})"; run_test(input, output); } TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid5) From 5f7798b926f3125cc4f85fbaeca7ef62991dcf46 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Tue, 12 Dec 2023 22:33:17 +0000 Subject: [PATCH 06/11] addressing PR feedback - back to lut approach rather than functor --- cpp/tests/io/fst/quote_normalization_test.cu | 43 +++++++------------- 1 file changed, 15 insertions(+), 28 deletions(-) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index 7dcec74595b..e4a2dd61829 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ #include #include +#include + #include #include #include @@ -57,24 +59,11 @@ constexpr auto TT_DEC = dfa_states::TT_DEC; constexpr auto TT_SEC = dfa_states::TT_SEC; constexpr auto TT_NUM_STATES = static_cast(dfa_states::TT_NUM_STATES); constexpr auto NUM_SYMBOL_GROUPS = static_cast(dfa_symbol_group_id::NUM_SYMBOL_GROUPS); + // The i-th string representing all the characters of a symbol group -// TODO: We should reinstantiate this lookup table approach once -// https://github.com/rapidsai/cudf/pull/14561 is merged std::array, -// NUM_SYMBOL_GROUPS - 1> const qna_sgs{ -// {{'\"'}, {'\''}, {'\\'}, {'\n'}}}; -// Temporary workaround: -struct SymbolToSymbolGroup { - CUDF_HOST_DEVICE uint32_t operator()(SymbolT symbol) const - { - switch (symbol) { - case '\"': return static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR); - case '\'': return static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR); - case '\\': return static_cast(dfa_symbol_group_id::ESCAPE_CHAR); - case '\n': return static_cast(dfa_symbol_group_id::NEWLINE_CHAR); - default: return static_cast(dfa_symbol_group_id::OTHER_SYMBOLS); - } - }; -}; +std::array, NUM_SYMBOL_GROUPS - 1> const qna_sgs{ + {{'\"'}, {'\''}, {'\\'}, {'\n'}}}; + // Transition table std::array, TT_NUM_STATES> const qna_state_tt{{ /* IN_STATE " ' \ \n OTHER */ @@ -145,12 +134,6 @@ struct TransduceToNormalizedQuotes { return (relative_offset == 0) ? '\\' : read_symbol; } // In all other cases we simply output the input symbol - /* - else if (!((match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)) && - ((state_id == static_cast(dfa_states::TT_SQS))))) { - return read_symbol; - } - */ else return read_symbol; } @@ -199,7 +182,7 @@ void run_test(std::string& input, std::string& output) rmm::cuda_stream_view stream_view(stream); auto parser = cudf::io::fst::detail::make_fst( - cudf::io::fst::detail::make_symbol_group_lookup_op(SymbolToSymbolGroup{}), + cudf::io::fst::detail::make_symbol_group_lut(qna_sgs), cudf::io::fst::detail::make_transition_table(qna_state_tt), cudf::io::fst::detail::make_translation_functor(TransduceToNormalizedQuotes{}), stream); @@ -211,20 +194,18 @@ void run_test(std::string& input, std::string& output) constexpr std::size_t single_item = 1; cudf::detail::hostdevice_vector output_gpu(input.size() * 2, stream_view); cudf::detail::hostdevice_vector output_gpu_size(single_item, stream_view); - cudf::detail::hostdevice_vector out_indexes_gpu(input.size(), stream_view); // Allocate device-side temporary storage & run algorithm parser.Transduce(d_input.data(), static_cast(d_input.size()), output_gpu.device_ptr(), - out_indexes_gpu.device_ptr(), + thrust::make_discard_iterator(), output_gpu_size.device_ptr(), start_state, stream.value()); // Async copy results from device to host output_gpu.device_to_host_async(stream.view()); - out_indexes_gpu.device_to_host_async(stream.view()); output_gpu_size.device_to_host_async(stream.view()); // Make sure results have been copied back to host @@ -324,5 +305,11 @@ TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid6) std::string output = R"({"a":"\\""})"; run_test(input, output); } +TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid7) +{ + std::string input = R"(}'a': 'b'{)"; + std::string output = R"(}"a": "b"{)"; + run_test(input, output); +} CUDF_TEST_PROGRAM_MAIN() From 470543ddec8c1f3009eab71fb21dfe21913d7844 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Thu, 14 Dec 2023 20:54:54 +0000 Subject: [PATCH 07/11] using the right stream! --- cpp/tests/io/fst/quote_normalization_test.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index e4a2dd61829..53a0a0e7a8d 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -187,7 +187,7 @@ void run_test(std::string& input, std::string& output) cudf::io::fst::detail::make_translation_functor(TransduceToNormalizedQuotes{}), stream); - auto d_input_scalar = cudf::make_string_scalar(input); + auto d_input_scalar = cudf::make_string_scalar(input, stream_view); auto& d_input = static_cast&>(*d_input_scalar); // Prepare input & output buffers @@ -202,11 +202,11 @@ void run_test(std::string& input, std::string& output) thrust::make_discard_iterator(), output_gpu_size.device_ptr(), start_state, - stream.value()); + stream_view); // Async copy results from device to host - output_gpu.device_to_host_async(stream.view()); - output_gpu_size.device_to_host_async(stream.view()); + output_gpu.device_to_host_async(stream_view); + output_gpu_size.device_to_host_async(stream_view); // Make sure results have been copied back to host stream.synchronize(); From 77c46135eac6f2e47a552524e1341a50b835f753 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Thu, 14 Dec 2023 23:25:00 +0000 Subject: [PATCH 08/11] clarifying the translation table - splitting into special and non-special characters --- cpp/tests/io/fst/quote_normalization_test.cu | 34 +++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index 53a0a0e7a8d..ad0f1b39294 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -86,22 +86,24 @@ struct TransduceToNormalizedQuotes { SymbolT const read_symbol) const { // -------- TRANSLATION TABLE ------------ - // state | read_symbol -> output_symbols - // DQS | * -> * - // DEC | * -> * - // OOS | ' -> " - // OOS | * -> * - // SQS | " -> \" - // SQS | ' -> " - // SQS | * -> * - // SEC | * -> * - // ---------- SPECIAL CASES: -------------- (anything else translates to input symbol) - // OOS | ' -> " - // SQS | " -> \" - // SQS | ' -> " - // SQS | \\ -> - // SEC | ' -> ' - // SEC | Sigma\{'} -> \* + // Let the alphabet set be Sigma + // --------------------------------------- + // ---------- NON-SPECIAL CASES: ---------- + // Output symbol same as input symbol + // state | read_symbol -> output_symbol + // DQS | Sigma -> Sigma + // DEC | Sigma -> Sigma + // OOS | Sigma\{'} -> Sigma\{'} + // SQS | Sigma\{', "} -> Sigma\{', "} + // ---------- SPECIAL CASES: -------------- + // Input symbol translates to output symbol + // OOS | {'} -> {"} + // SQS | {'} -> {"} + // SQS | {"} -> {\"} + // SQS | {\} -> + // SEC | {'} -> {'} + // SEC | Sigma\{'} -> {\*} + // Whether this transition translates to the escape sequence: \" const bool outputs_escape_sequence = (state_id == static_cast(dfa_states::TT_SQS)) && From 124ab533ec6484f0e50e2f80045d6f06a5e20702 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Fri, 15 Dec 2023 20:24:49 +0000 Subject: [PATCH 09/11] addressing PR reviews --- cpp/tests/io/fst/quote_normalization_test.cu | 50 +++++++++++++------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index ad0f1b39294..c9492665396 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -38,8 +38,11 @@ namespace { // Type used to represent the atomic symbol type used within the finite-state machine +// TODO: type aliasing to be declared in a common header for better maintainability and +// pre-empt future bugs using SymbolT = char; using StateT = char; + // Type sufficiently large to index symbols within the input and output (may be unsigned) using SymbolOffsetT = uint32_t; enum class dfa_states : char { TT_OOS = 0U, TT_DQS, TT_SQS, TT_DEC, TT_SEC, TT_NUM_STATES }; @@ -51,6 +54,7 @@ enum class dfa_symbol_group_id : uint32_t { OTHER_SYMBOLS, ///< SG implicitly matching all other characters NUM_SYMBOL_GROUPS ///< Total number of symbol groups }; + // Aliases for readability of the transition table constexpr auto TT_OOS = dfa_states::TT_OOS; constexpr auto TT_DQS = dfa_states::TT_DQS; @@ -73,8 +77,10 @@ std::array, TT_NUM_STATES> const qna_s /* TT_DEC */ {{TT_DQS, TT_DQS, TT_DQS, TT_OOS, TT_DQS}}, /* TT_SEC */ {{TT_SQS, TT_SQS, TT_SQS, TT_OOS, TT_SQS}}, }}; + // The DFA's starting state constexpr char start_state = static_cast(TT_OOS); + struct TransduceToNormalizedQuotes { /** * @brief Returns the -th output symbol on the transition (state_id, match_id). @@ -109,36 +115,34 @@ struct TransduceToNormalizedQuotes { (state_id == static_cast(dfa_states::TT_SQS)) && (match_id == static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR)); // Case when a double quote needs to be replaced by the escape sequence: \" - if (outputs_escape_sequence) { - return (relative_offset == 0) ? '\\' : '"'; - } + if (outputs_escape_sequence) { return (relative_offset == 0) ? '\\' : '"'; } // Case when a single quote needs to be replaced by a double quote - else if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && - ((state_id == static_cast(dfa_states::TT_SQS)) || - (state_id == static_cast(dfa_states::TT_OOS)))) { + if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SQS)) || + (state_id == static_cast(dfa_states::TT_OOS)))) { return '"'; } // Case when the read symbol is an escape character - the actual translation for \ for some // symbol is handled by transitions from SEC. For now, there is no output for this // transition - else if ((match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)) && - ((state_id == static_cast(dfa_states::TT_SQS)))) { + if ((match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SQS)))) { return 0; } // Case when an escaped single quote in an input single-quoted string needs to be replaced by an // unescaped single quote - else if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && - ((state_id == static_cast(dfa_states::TT_SEC)))) { + if ((match_id == static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)) && + ((state_id == static_cast(dfa_states::TT_SEC)))) { return '\''; } // Case when an escaped symbol that is not a single-quote needs to be replaced with \ - else if (state_id == static_cast(dfa_states::TT_SEC)) { + if (state_id == static_cast(dfa_states::TT_SEC)) { return (relative_offset == 0) ? '\\' : read_symbol; } // In all other cases we simply output the input symbol - else - return read_symbol; + return read_symbol; } + /** * @brief Returns the number of output characters for a given transition. During quote * normalization, we always emit one output character (i.e., either the input character or the @@ -155,19 +159,19 @@ struct TransduceToNormalizedQuotes { (state_id == static_cast(dfa_states::TT_SQS)) && (match_id == static_cast(dfa_symbol_group_id::DOUBLE_QUOTE_CHAR)); // Number of characters to output on this transition - if (sqs_outputs_escape_sequence) return 2; + if (sqs_outputs_escape_sequence) { return 2; } // Whether this transition translates to the escape sequence \ or unescaped ' const bool sec_outputs_escape_sequence = (state_id == static_cast(dfa_states::TT_SEC)) && (match_id != static_cast(dfa_symbol_group_id::SINGLE_QUOTE_CHAR)); // Number of characters to output on this transition - if (sec_outputs_escape_sequence) return 2; + if (sec_outputs_escape_sequence) { return 2; } // Whether this transition translates to no output const bool sqs_outputs_nop = (state_id == static_cast(dfa_states::TT_SQS)) && (match_id == static_cast(dfa_symbol_group_id::ESCAPE_CHAR)); // Number of characters to output on this transition - if (sqs_outputs_nop) return 0; + if (sqs_outputs_nop) { return 0; } return 1; } }; @@ -228,18 +232,21 @@ TEST_F(FstTest, GroundTruth_QuoteNormalization1) std::string output = R"({"A":"TEST\""})"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization2) { std::string input = R"({'A':"TEST'"} ['OTHER STUFF'])"; std::string output = R"({"A":"TEST'"} ["OTHER STUFF"])"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization3) { std::string input = R"(['{"A": "B"}',"{'A': 'B'}"])"; std::string output = R"(["{\"A\": \"B\"}","{'A': 'B'}"])"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization4) { std::string input = R"({"ain't ain't a word and you ain't supposed to say it":'"""""""""""'})"; @@ -247,66 +254,77 @@ TEST_F(FstTest, GroundTruth_QuoteNormalization4) R"({"ain't ain't a word and you ain't supposed to say it":"\"\"\"\"\"\"\"\"\"\"\""})"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization5) { std::string input = R"({"\"'\"'\"'\"'":'"\'"\'"\'"\'"'})"; std::string output = R"({"\"'\"'\"'\"'":"\"'\"'\"'\"'\""})"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization6) { std::string input = R"([{"ABC':'CBA":'XYZ":"ZXY'}])"; std::string output = R"([{"ABC':'CBA":"XYZ\":\"ZXY"}])"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization7) { std::string input = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; std::string output = R"(["\t","\\t","\\","\\\'\"\\\\","\n","\b"])"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization8) { std::string input = R"(['\t','\\t','\\','\\\"\'\\\\','\n','\b','\u0012'])"; std::string output = R"(["\t","\\t","\\","\\\"'\\\\","\n","\b","\u0012"])"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid1) { std::string input = R"(["THIS IS A TEST'])"; std::string output = R"(["THIS IS A TEST'])"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid2) { std::string input = R"(['THIS IS A TEST"])"; std::string output = R"(["THIS IS A TEST\"])"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid3) { std::string input = R"({"MORE TEST'N":'RESUL})"; std::string output = R"({"MORE TEST'N":"RESUL})"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid4) { std::string input = R"({"NUMBER":100'0,'STRING':'SOMETHING'})"; std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING"})"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid5) { std::string input = R"({'NUMBER':100"0,"STRING":"SOMETHING"})"; std::string output = R"({"NUMBER":100"0,"STRING":"SOMETHING"})"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid6) { std::string input = R"({'a':'\\''})"; std::string output = R"({"a":"\\""})"; run_test(input, output); } + TEST_F(FstTest, GroundTruth_QuoteNormalization_Invalid7) { std::string input = R"(}'a': 'b'{)"; From 83da20343d22768c46e4477f6586cc9bdbed1e85 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Tue, 2 Jan 2024 19:28:46 +0000 Subject: [PATCH 10/11] removed stdout logging --- cpp/tests/io/fst/quote_normalization_test.cu | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index c9492665396..07069a6c70e 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -218,10 +218,6 @@ void run_test(std::string& input, std::string& output) stream.synchronize(); // Verify results - std::cout << "Expected output: " << output << std::endl << "Computed output: "; - for (size_t i = 0; i < output_gpu_size[0]; i++) - std::cout << output_gpu[i]; - std::cout << std::endl; ASSERT_EQ(output_gpu_size[0], output.size()); CUDF_TEST_EXPECT_VECTOR_EQUAL(output_gpu, output, output.size()); } From 367473b59363b504ce3a8b3abf36e9ab3758e0b3 Mon Sep 17 00:00:00 2001 From: Shruti Shivakumar Date: Tue, 2 Jan 2024 19:55:40 +0000 Subject: [PATCH 11/11] pre-commit run --- cpp/tests/CMakeLists.txt | 2 +- cpp/tests/io/fst/quote_normalization_test.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 8d478ba1a4f..d0abcc225d1 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at diff --git a/cpp/tests/io/fst/quote_normalization_test.cu b/cpp/tests/io/fst/quote_normalization_test.cu index 07069a6c70e..e2636ab029f 100644 --- a/cpp/tests/io/fst/quote_normalization_test.cu +++ b/cpp/tests/io/fst/quote_normalization_test.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.