Skip to content

Commit

Permalink
simple quote normalization test
Browse files Browse the repository at this point in the history
  • Loading branch information
shrshi committed Dec 1, 2023
1 parent 5e58e71 commit fe9b582
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 1 deletion.
3 changes: 2 additions & 1 deletion cpp/src/io/fst/lookup_tables.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,8 @@ class TransducerLookupTable {
}

// Check whether runtime-provided table size exceeds the compile-time given max. table size
CUDF_EXPECTS(out_symbols.size() <= MAX_TABLE_SIZE, "Unsupported translation table");
// TODO: Support for multicharater translations?
// CUDF_EXPECTS(out_symbols.size() <= MAX_TABLE_SIZE, "Unsupported translation table");

// Prepare host-side data to be copied and passed to the device
std::copy(
Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ ConfigureTest(
PERCENT 30
)
target_link_libraries(DATA_CHUNK_SOURCE_TEST PRIVATE ZLIB::ZLIB)
ConfigureTest(QUOTE_NORMALIZATION_TEST io/fst/quote_normalization_test.cu)
ConfigureTest(LOGICAL_STACK_TEST io/fst/logical_stack_test.cu)
ConfigureTest(FST_TEST io/fst/fst_test.cu)
ConfigureTest(TYPE_INFERENCE_TEST io/type_inference_test.cu)
Expand Down
244 changes: 244 additions & 0 deletions cpp/tests/io/fst/quote_normalization_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <io/fst/lookup_tables.cuh>
#include <io/utilities/hostdevice_vector.hpp>

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/cudf_gtest.hpp>

#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/repeat_strings.hpp>
#include <cudf/types.hpp>

#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/device_uvector.hpp>

#include <cstdlib>
#include <string>
#include <vector>

namespace {

//------------------------------------------------------------------------------
// CPU-BASED IMPLEMENTATIONS FOR VERIFICATION
//------------------------------------------------------------------------------
/**
* @brief CPU-based implementation of a finite-state transducer (FST).
*
* @tparam InputItT Forward input iterator type to symbols fed into the FST
* @tparam StateT Type representing states of the finite-state machine
* @tparam SymbolGroupLutT Sequence container of symbol groups. Each symbol group is a sequence
* container to symbols within that group.
* @tparam TransitionTableT Two-dimensional container type
* @tparam TransducerTableT Two-dimensional container type
* @tparam OutputItT Forward output iterator type
* @tparam IndexOutputItT Forward output iterator type
* @param[in] begin Forward iterator to the beginning of the symbol sequence
* @param[in] end Forward iterator to one past the last element of the symbol sequence
* @param[in] init_state The starting state of the finite-state machine
* @param[in] symbol_group_lut Sequence container of symbol groups. Each symbol group is a sequence
* container to symbols within that group. The index of the symbol group containing a symbol being
* read will be used as symbol_gid of the transition and translation tables.
* @param[in] transition_table The two-dimensional transition table, i.e.,
* transition_table[state][symbol_gid] -> new_state
* @param[in] translation_table The two-dimensional transducer table, i.e.,
* translation_table[state][symbol_gid] -> range_of_output_symbols
* @param[out] out_tape A forward output iterator to which the transduced input will be written
* @param[out] out_index_tape A forward output iterator to which indexes of the symbols that
* actually caused some output are written to
* @return A pair of iterators to one past the last element of (1) the transduced output symbol
* sequence and (2) the indexes of
*/
template <typename InputItT,
typename StateT,
typename SymbolGroupLutT,
typename TransitionTableT,
typename TransducerTableT,
typename OutputItT,
typename IndexOutputItT>
static std::pair<OutputItT, IndexOutputItT> fst_baseline(InputItT begin,
InputItT end,
StateT const& init_state,
SymbolGroupLutT symbol_group_lut,
TransitionTableT transition_table,
TransducerTableT translation_table,
OutputItT out_tape,
IndexOutputItT out_index_tape)
{
// Initialize "FSM" with starting state
StateT state = init_state;

// To track the symbol offset within the input that caused the FST to output
std::size_t in_offset = 0;
for (auto it = begin; it < end; it++) {
// The symbol currently being read
auto const& symbol = *it;

// Iterate over symbol groups and search for the first symbol group containing the current
// symbol, if no match is found we use cend(symbol_group_lut) as the "catch-all" symbol group
auto symbol_group_it =
std::find_if(std::cbegin(symbol_group_lut), std::cend(symbol_group_lut), [symbol](auto& sg) {
return std::find(std::cbegin(sg), std::cend(sg), symbol) != std::cend(sg);
});
auto symbol_group = std::distance(std::cbegin(symbol_group_lut), symbol_group_it);

// Output the translated symbols to the output tape
out_tape = std::copy(std::cbegin(translation_table[state][symbol_group]),
std::cend(translation_table[state][symbol_group]),
out_tape);

auto out_size = std::distance(std::cbegin(translation_table[state][symbol_group]),
std::cend(translation_table[state][symbol_group]));

out_index_tape = std::fill_n(out_index_tape, out_size, in_offset);

// Transition the state of the finite-state machine
state = static_cast<char>(transition_table[state][symbol_group]);

// Continue with next symbol from input tape
in_offset++;
}
return {out_tape, out_index_tape};
}
} // namespace

// Base test fixture for tests
struct FstTest : public cudf::test::BaseFixture {};

TEST_F(FstTest, GroundTruth_QuoteNormalizationSimple)
{
// Type used to represent the atomic symbol type used within the finite-state machine
using SymbolT = char;

// Type sufficiently large to index symbols within the input and output (may be unsigned)
using SymbolOffsetT = uint32_t;

// Prepare cuda stream for data transfers & kernels
rmm::cuda_stream stream{};
rmm::cuda_stream_view stream_view(stream);

// Test input
std::string input = R"({"A" : 'TEST"'})";
auto d_input_scalar = cudf::make_string_scalar(input);
auto& d_input = static_cast<cudf::scalar_type_t<std::string>&>(*d_input_scalar);

// Prepare input & output buffers
constexpr std::size_t single_item = 1;
cudf::detail::hostdevice_vector<SymbolT> output_gpu(input.size() * 2, stream_view);
cudf::detail::hostdevice_vector<SymbolOffsetT> output_gpu_size(single_item, stream_view);
cudf::detail::hostdevice_vector<SymbolOffsetT> out_indexes_gpu(input.size(), stream_view);

enum class dfa_states : char { TT_OOS = 0U, TT_DQS, TT_SQS, TT_DEC, TT_SEC, TT_NUM_STATES };

enum class dfa_symbol_group_id : uint32_t {
OPENING_BRACE, ///< Opening brace SG: {
OPENING_BRACKET, ///< Opening bracket SG: [
CLOSING_BRACE, ///< Closing brace SG: }
CLOSING_BRACKET, ///< Closing bracket SG: ]
DOUBLE_QUOTE_CHAR, ///< Quote character SG: "
SINGLE_QUOTE_CHAR, ///< Quote character SG: '
ESCAPE_CHAR, ///< Escape character SG: '\'
OTHER_SYMBOLS, ///< SG implicitly matching all other characters
NUM_SYMBOL_GROUPS ///< Total number of symbol groups
};

// Aliases for readability of the transition table
constexpr auto TT_OOS = dfa_states::TT_OOS;
constexpr auto TT_DQS = dfa_states::TT_DQS;
constexpr auto TT_SQS = dfa_states::TT_SQS;
constexpr auto TT_DEC = dfa_states::TT_DEC;
constexpr auto TT_SEC = dfa_states::TT_SEC;

constexpr auto TT_NUM_STATES = static_cast<char>(dfa_states::TT_NUM_STATES);
constexpr auto NUM_SYMBOL_GROUPS = static_cast<uint32_t>(dfa_symbol_group_id::NUM_SYMBOL_GROUPS);

// The i-th string representing all the characters of a symbol group
std::array<std::string, NUM_SYMBOL_GROUPS - 1> const qna_sgs{"{", "[", "}", "]", "\"", "'", "\\"};

// Transition table
// Does not support JSON lines
std::array<std::array<dfa_states, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const qna_state_tt{{
/* IN_STATE { [ } ] " ' \ OTHER */
/* TT_OOS */ {{TT_OOS, TT_OOS, TT_OOS, TT_OOS, TT_DQS, TT_SQS, TT_OOS, TT_OOS}},
/* TT_DQS */ {{TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_OOS, TT_DQS, TT_DEC, TT_DQS}},
/* TT_SQS */ {{TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_OOS, TT_SEC, TT_SQS}},
/* TT_DEC */ {{TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS, TT_DQS}},
/* TT_SEC */ {{TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS, TT_SQS}},
}};

// Translation table (i.e., for each transition, what are the symbols that we output)
std::array<std::array<std::vector<char>, NUM_SYMBOL_GROUPS>, TT_NUM_STATES> const qna_out_tt{
{/* IN_STATE { [ } ] " ' \ OTHER */
/* TT_OOS */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'"'}, {'\\'}, {'x'}}},
/* TT_DQS */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}},
/* TT_SQS */ {{{'{'}, {'['}, {'}'}, {']'}, {'\\', '"'}, {'"'}, {'\\'}, {'x'}}},
/* TT_DEC */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}},
/* TT_SEC */ {{{'{'}, {'['}, {'}'}, {']'}, {'"'}, {'\''}, {'\\'}, {'x'}}}}};

// The DFA's starting state
constexpr char start_state = static_cast<char>(TT_OOS);

// Run algorithm
auto parser = cudf::io::fst::detail::make_fst(
cudf::io::fst::detail::make_symbol_group_lut(qna_sgs),
cudf::io::fst::detail::make_transition_table(qna_state_tt),
cudf::io::fst::detail::make_translation_table<TT_NUM_STATES * NUM_SYMBOL_GROUPS>(qna_out_tt),
stream);

// Allocate device-side temporary storage & run algorithm
parser.Transduce(d_input.data(),
static_cast<SymbolOffsetT>(d_input.size()),
output_gpu.device_ptr(),
out_indexes_gpu.device_ptr(),
output_gpu_size.device_ptr(),
start_state,
stream.value());

// Async copy results from device to host
output_gpu.device_to_host_async(stream.view());
out_indexes_gpu.device_to_host_async(stream.view());
output_gpu_size.device_to_host_async(stream.view());

// Prepare CPU-side results for verification
std::string output_cpu{};
std::vector<SymbolOffsetT> out_index_cpu{};
output_cpu.reserve(input.size());
out_index_cpu.reserve(input.size());

// Run CPU-side algorithm
fst_baseline(std::begin(input),
std::end(input),
start_state,
qna_sgs,
qna_state_tt,
qna_out_tt,
std::back_inserter(output_cpu),
std::back_inserter(out_index_cpu));

// Make sure results have been copied back to host
stream.synchronize();

// Verify results
ASSERT_EQ(output_gpu_size[0], output_cpu.size());
CUDF_TEST_EXPECT_VECTOR_EQUAL(output_gpu, output_cpu, output_cpu.size());
// TODO: indexing for multicharacter translations
// CUDF_TEST_EXPECT_VECTOR_EQUAL(out_indexes_gpu, out_index_cpu, output_cpu.size());
}

CUDF_TEST_PROGRAM_MAIN()

0 comments on commit fe9b582

Please sign in to comment.