From d351e5c4197acf7c7ab215ea7555926cb2d1f5b8 Mon Sep 17 00:00:00 2001 From: Elias Stehle <3958403+elstehle@users.noreply.github.com> Date: Thu, 14 Jul 2022 09:17:59 -0700 Subject: [PATCH] addresses style review comments & fixes a todo --- cpp/src/io/fst/lookup_tables.cuh | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/cpp/src/io/fst/lookup_tables.cuh b/cpp/src/io/fst/lookup_tables.cuh index a0a9f81a302..208890d28d3 100644 --- a/cpp/src/io/fst/lookup_tables.cuh +++ b/cpp/src/io/fst/lookup_tables.cuh @@ -41,7 +41,7 @@ class SingleSymbolSmemLUT { // Type used for representing a symbol group id (i.e., what we return for a given symbol) using SymbolGroupIdT = uint8_t; - /// Number of entries for every lookup (e.g., for 8-bit Symbol this is 256) + // Number of entries for every lookup (e.g., for 8-bit Symbol this is 256) static constexpr uint32_t NUM_ENTRIES_PER_LUT = 0x01U << (sizeof(SymbolT) * 8U); struct _TempStorage { @@ -60,9 +60,6 @@ class SingleSymbolSmemLUT { using TempStorage = cub::Uninitialized<_TempStorage>; - //------------------------------------------------------------------------------ - // HELPER METHODS - //------------------------------------------------------------------------------ /** * @brief * @@ -104,21 +101,14 @@ class SingleSymbolSmemLUT { sgid_init.host_ptr()->sym_to_sgid[max_base_match_val + 1] = no_match_id; // Alias memory / return memory requiremenets - // TODO I think this could be +1? - sgid_init.host_ptr()->num_valid_entries = max_base_match_val + 2; + sgid_init.host_ptr()->num_valid_entries = max_base_match_val + 1; sgid_init.host_to_device(stream); } - //------------------------------------------------------------------------------ - // MEMBER VARIABLES - //------------------------------------------------------------------------------ _TempStorage& temp_storage; SymbolGroupIdT num_valid_entries; - //------------------------------------------------------------------------------ - // CONSTRUCTOR - //------------------------------------------------------------------------------ __device__ __forceinline__ _TempStorage& PrivateStorage() { __shared__ _TempStorage private_storage; @@ -170,17 +160,17 @@ class TransitionTable { template ()})>> static void InitDeviceTransitionTable(hostdevice_vector& transition_table_init, - std::vector> const& trans_table, + std::vector> const& translation_table, rmm::cuda_stream_view stream) { - // trans_table[state][symbol] -> new state - for (std::size_t state = 0; state < trans_table.size(); ++state) { - for (std::size_t symbol = 0; symbol < trans_table[state].size(); ++symbol) { + // translation_table[state][symbol] -> new state + for (std::size_t state = 0; state < translation_table.size(); ++state) { + for (std::size_t symbol = 0; symbol < translation_table[state].size(); ++symbol) { CUDF_EXPECTS( - trans_table[state][symbol] <= std::numeric_limits::max(), + translation_table[state][symbol] <= std::numeric_limits::max(), "Target state index value exceeds value representable by the transition table's type"); transition_table_init.host_ptr()->transitions[symbol * MAX_NUM_STATES + state] = - trans_table[state][symbol]; + translation_table[state][symbol]; } } @@ -314,7 +304,7 @@ class TransducerLookupTable { */ static void InitDeviceTranslationTable( hostdevice_vector& translation_table_init, - std::vector>> const& trans_table, + std::vector>> const& translation_table, rmm::cuda_stream_view stream) { std::vector out_symbols; @@ -324,7 +314,7 @@ class TransducerLookupTable { out_symbol_offsets.push_back(0); // Iterate over the states in the transition table - for (auto const& state_trans : trans_table) { + for (auto const& state_trans : translation_table) { uint32_t num_added = 0; // Iterate over the symbols in the transition table for (auto const& symbol_out : state_trans) {