forked from rapidsai/cudf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
531 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cub/cub.cuh> | ||
|
||
#include <algorithm> | ||
#include <cstdint> | ||
#include <vector> | ||
|
||
namespace cudf { | ||
namespace io { | ||
namespace fst { | ||
namespace detail { | ||
/** | ||
* @brief Class template that can be plugged into the finite-state machine to look up the symbol | ||
* group index for a given symbol. Class template does not support multi-symbol lookups (i.e., no | ||
* look-ahead). | ||
* | ||
* @tparam SymbolT The symbol type being passed in to lookup the corresponding symbol group id | ||
*/ | ||
template <typename SymbolT>
struct SingleSymbolSmemLUT {
  //------------------------------------------------------------------------------
  // DEFAULT TYPEDEFS
  //------------------------------------------------------------------------------
  // Type used for representing a symbol group id (i.e., what we return for a given symbol)
  using SymbolGroupIdT = uint8_t;

  //------------------------------------------------------------------------------
  // DERIVED CONFIGURATIONS
  //------------------------------------------------------------------------------
  /// Number of entries for every lookup (e.g., for 8-bit Symbol this is 256)
  static constexpr uint32_t NUM_ENTRIES_PER_LUT = 0x01U << (sizeof(SymbolT) * 8U);

  //------------------------------------------------------------------------------
  // TYPEDEFS
  //------------------------------------------------------------------------------

  struct _TempStorage {
    // match_meta_data[symbol] -> symbol group index
    SymbolGroupIdT match_meta_data[NUM_ENTRIES_PER_LUT];
  };

  struct KernelParameter {
    // Number of valid entries in d_match_meta_data; lookups clamp to this bound, i.e.,
    // d_match_meta_data[min(symbol, num_valid_entries - 1)] -> symbol group index
    SymbolGroupIdT num_valid_entries;

    // Device pointer: d_match_meta_data[symbol] -> symbol group index
    SymbolGroupIdT* d_match_meta_data;
  };

  struct TempStorage : cub::Uninitialized<_TempStorage> {
  };

  //------------------------------------------------------------------------------
  // HELPER METHODS
  //------------------------------------------------------------------------------
  /**
   * @brief Initializes the lookup table that maps a symbol to its symbol group id. Follows the
   * usual CUB-style two-phase protocol: when \p d_temp_storage is null, only the temporary
   * storage requirement is reported; otherwise the table is copied to the provided device memory.
   *
   * @param[in] d_temp_storage Device-side temporary storage that can be used to store the lookup
   * table. If no storage is provided it will return the temporary storage requirements in \p
   * d_temp_storage_bytes.
   * @param[in,out] d_temp_storage_bytes Amount of device-side temporary storage that can be used
   * in the number of bytes
   * @param[in] symbol_strings Array of strings, where the i-th string holds all symbols
   * (characters!) that correspond to the i-th symbol group index
   * @param[out] kernel_param The kernel parameter object to be initialized with the given mapping
   * of symbols to symbol group ids.
   * @param[in] stream The stream that shall be used to cudaMemcpyAsync the lookup table
   * @return cudaSuccess, or the error returned by the underlying async copy
   */
  template <typename SymbolGroupItT>
  __host__ __forceinline__ static cudaError_t PrepareLUT(void* d_temp_storage,
                                                         size_t& d_temp_storage_bytes,
                                                         SymbolGroupItT const& symbol_strings,
                                                         KernelParameter& kernel_param,
                                                         cudaStream_t stream = 0)
  {
    // The symbol group index to be returned if none of the given symbols match
    SymbolGroupIdT const no_match_id = symbol_strings.size();

    std::vector<SymbolGroupIdT> lut(NUM_ENTRIES_PER_LUT);

    // Largest symbol value that is explicitly mapped to a symbol group (widened to uint32_t so
    // the +1/+2 arithmetic below cannot wrap)
    uint32_t max_base_match_val = 0;

    // Initialize all entries: by default we return the no-match-id
    std::fill(lut.begin(), lut.end(), no_match_id);

    // Set up lookup table: the i-th string of symbol_strings maps each of its symbols to sg_id i
    uint32_t sg_id = 0;
    for (auto const& sg_symbols : symbol_strings) {
      for (auto const& sg_symbol : sg_symbols) {
        uint32_t const symbol_val = static_cast<SymbolGroupIdT>(sg_symbol);
        max_base_match_val        = std::max(max_base_match_val, symbol_val);
        lut[symbol_val]           = sg_id;
      }
      sg_id++;
    }

    // Initialize the out-of-bounds lookup: lut[max_base_match_val + 1] -> no_match_id.
    // Guarded: if the largest mapped symbol is the last LUT entry there is no room for the
    // sentinel (the original unconditionally wrote one past the end of the table in that case),
    // and every in-range lookup is already covered by the fill above.
    if (max_base_match_val + 1 < NUM_ENTRIES_PER_LUT) { lut[max_base_match_val + 1] = no_match_id; }

    // NOTE(review): num_valid_entries is a SymbolGroupIdT (uint8_t), so this assignment silently
    // truncates if max_base_match_val + 2 > 255 — TODO confirm mapped symbol values stay < 254
    kernel_param.num_valid_entries = max_base_match_val + 2;
    if (d_temp_storage) {
      // Copy the valid prefix of the lookup table to the provided device storage
      cudaError_t error = cudaMemcpyAsync(d_temp_storage,
                                          lut.data(),
                                          kernel_param.num_valid_entries * sizeof(SymbolGroupIdT),
                                          cudaMemcpyHostToDevice,
                                          stream);

      kernel_param.d_match_meta_data = reinterpret_cast<SymbolGroupIdT*>(d_temp_storage);
      return error;
    }

    // Query-only call: report the number of bytes required
    d_temp_storage_bytes = kernel_param.num_valid_entries * sizeof(SymbolGroupIdT);
    return cudaSuccess;
  }

  //------------------------------------------------------------------------------
  // MEMBER VARIABLES
  //------------------------------------------------------------------------------
  _TempStorage& temp_storage;
  SymbolGroupIdT num_valid_entries;

  //------------------------------------------------------------------------------
  // CONSTRUCTOR
  //------------------------------------------------------------------------------
  // Fallback block-shared storage used when the caller does not pass TempStorage explicitly
  __device__ __forceinline__ _TempStorage& PrivateStorage()
  {
    __shared__ _TempStorage private_storage;
    return private_storage;
  }

  /**
   * @brief Stages the lookup table from device memory into the given temporary storage. On the
   * GPU this is a block-cooperative copy (all threads of the block must reach the constructor,
   * which ends in __syncthreads()).
   *
   * @param kernel_param Parameters previously initialized by PrepareLUT
   * @param temp_storage Storage (shared memory on the GPU) backing the staged lookup table
   */
  __host__ __device__ __forceinline__ SingleSymbolSmemLUT(KernelParameter const& kernel_param,
                                                          TempStorage& temp_storage)
    : temp_storage(temp_storage.Alias()), num_valid_entries(kernel_param.num_valid_entries)
  {
    // GPU-side init
#if CUB_PTX_ARCH > 0
    for (int32_t i = threadIdx.x; i < kernel_param.num_valid_entries; i += blockDim.x) {
      this->temp_storage.match_meta_data[i] = kernel_param.d_match_meta_data[i];
    }
    __syncthreads();

#else
    // CPU-side init (fix: the original read kernel_param.num_luts, a member that does not exist
    // on KernelParameter — the valid-entry count is num_valid_entries)
    for (uint32_t i = 0; i < kernel_param.num_valid_entries; i++) {
      this->temp_storage.match_meta_data[i] = kernel_param.d_match_meta_data[i];
    }
#endif
  }

  /**
   * @brief Looks up the symbol group id for the given symbol; symbols beyond the valid range are
   * clamped to the sentinel entry, which holds the no-match id.
   */
  __host__ __device__ __forceinline__ int32_t operator()(SymbolT const symbol) const
  {
    // Look up the symbol group for given symbol
    return temp_storage.match_meta_data[min(symbol, num_valid_entries - 1)];
  }
};
|
||
} // namespace detail | ||
} // namespace fst | ||
} // namespace io | ||
} // namespace cudf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cub/cub.cuh>

#include <cstdint>
#include <vector>
|
||
namespace cudf { | ||
namespace io { | ||
namespace fst { | ||
namespace detail { | ||
|
||
template <int MAX_NUM_SYMBOLS, int MAX_NUM_STATES>
struct TransitionTable {
  //------------------------------------------------------------------------------
  // DEFAULT TYPEDEFS
  //------------------------------------------------------------------------------
  // Type of a single transition-table entry (the new state id)
  using ItemT = char;

  struct TransitionVectorWrapper {
    const ItemT* data;

    __host__ __device__ TransitionVectorWrapper(const ItemT* data) : data(data) {}

    __host__ __device__ __forceinline__ uint32_t Get(int32_t index) const { return data[index]; }
  };

  //------------------------------------------------------------------------------
  // TYPEDEFS
  //------------------------------------------------------------------------------
  using TransitionVectorT = TransitionVectorWrapper;

  // Type used for wide (4-byte) aliased loads/stores when staging the table
  using LoadAliasT = std::uint32_t;

  // Table size rounded UP to a multiple of sizeof(LoadAliasT) so it can be moved with whole
  // LoadAliasT-wide accesses
  static constexpr std::size_t NUM_AUX_MEM_BYTES =
    CUB_QUOTIENT_CEILING(MAX_NUM_STATES * MAX_NUM_SYMBOLS * sizeof(ItemT), sizeof(LoadAliasT)) *
    sizeof(LoadAliasT);

  struct _TempStorage {
    // transitions[symbol * MAX_NUM_STATES + state] -> new state.
    // Fix: sized and aligned to NUM_AUX_MEM_BYTES (not MAX_NUM_STATES * MAX_NUM_SYMBOLS) so the
    // LoadAliasT-wide copy in the constructor cannot write past the end of the array when the
    // logical table size is not a multiple of sizeof(LoadAliasT).
    alignas(LoadAliasT) ItemT transitions[NUM_AUX_MEM_BYTES];
  };

  struct TempStorage : cub::Uninitialized<_TempStorage> {
  };

  struct KernelParameter {
    // Device pointer to the staged transition table (NUM_AUX_MEM_BYTES bytes)
    ItemT* transitions;
  };

  //------------------------------------------------------------------------------
  // HELPER METHODS
  //------------------------------------------------------------------------------
  /**
   * @brief Flattens and transposes the given state-transition table and copies it to device
   * memory. Follows the usual CUB-style two-phase protocol: when \p d_temp_storage is null, only
   * the temporary storage requirement is reported.
   *
   * @param[in] d_temp_storage Device memory for the table, or nullptr to query the size
   * @param[in,out] temp_storage_bytes Set to the required number of bytes on the query call
   * @param[in] trans_table trans_table[state][symbol] -> new state
   * @param[out] kernel_param Receives the device pointer to the staged table
   * @param[in] stream Stream used for the async host-to-device copy
   * @return cudaSuccess, or the error returned by the underlying async copy
   */
  __host__ static cudaError_t CreateTransitionTable(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    const std::vector<std::vector<int>>& trans_table,
    KernelParameter& kernel_param,
    cudaStream_t stream = 0)
  {
    if (!d_temp_storage) {
      temp_storage_bytes = NUM_AUX_MEM_BYTES;
      return cudaSuccess;
    }

    // trans_vectors[symbol * MAX_NUM_STATES + state] -> new_state.
    // Fix: sized to NUM_AUX_MEM_BYTES (the copy below moves that many bytes, which previously
    // read past the end of this buffer) and zero-initialized so no uninitialized stack bytes are
    // copied to the device for state/symbol slots absent from trans_table.
    ItemT trans_vectors[NUM_AUX_MEM_BYTES] = {};

    // trans_table[state][symbol] -> new state (transposed into symbol-major order so that all
    // states for one symbol are contiguous)
    for (std::size_t state = 0; state < trans_table.size(); ++state) {
      for (std::size_t symbol = 0; symbol < trans_table[state].size(); ++symbol) {
        trans_vectors[symbol * MAX_NUM_STATES + state] = trans_table[state][symbol];
      }
    }

    kernel_param.transitions = static_cast<ItemT*>(d_temp_storage);

    // Copy transition table to device
    return cudaMemcpyAsync(
      d_temp_storage, trans_vectors, NUM_AUX_MEM_BYTES, cudaMemcpyHostToDevice, stream);
  }

  //------------------------------------------------------------------------------
  // MEMBER VARIABLES
  //------------------------------------------------------------------------------
  _TempStorage& temp_storage;

  // Fallback block-shared storage used when the caller does not pass TempStorage explicitly
  __device__ __forceinline__ _TempStorage& PrivateStorage()
  {
    __shared__ _TempStorage private_storage;
    return private_storage;
  }

  //------------------------------------------------------------------------------
  // CONSTRUCTOR
  //------------------------------------------------------------------------------
  /**
   * @brief Stages the transition table from device memory into the given temporary storage. On
   * the GPU this is a block-cooperative, LoadAliasT-wide copy (all threads of the block must
   * reach the constructor, which ends in __syncthreads()).
   *
   * @param kernel_param Parameters previously initialized by CreateTransitionTable
   * @param temp_storage Storage (shared memory on the GPU) backing the staged table
   */
  __host__ __device__ __forceinline__ TransitionTable(const KernelParameter& kernel_param,
                                                      TempStorage& temp_storage)
    : temp_storage(temp_storage.Alias())
  {
#if CUB_PTX_ARCH > 0
    for (int i = threadIdx.x; i < CUB_QUOTIENT_CEILING(NUM_AUX_MEM_BYTES, sizeof(LoadAliasT));
         i += blockDim.x) {
      reinterpret_cast<LoadAliasT*>(this->temp_storage.transitions)[i] =
        reinterpret_cast<LoadAliasT*>(kernel_param.transitions)[i];
    }
    __syncthreads();
#else
    // CPU-side init (fix: the original read kernel_param.num_luts, a member that does not exist
    // on this KernelParameter — copy the full table instead)
    for (std::size_t i = 0; i < MAX_NUM_STATES * MAX_NUM_SYMBOLS; i++) {
      this->temp_storage.transitions[i] = kernel_param.transitions[i];
    }
#endif
  }

  /**
   * @brief Returns a random-access iterator to lookup all the state transitions for one specific
   * symbol from an arbitrary old_state, i.e., it[old_state] -> new_state.
   *
   * @param state_id The DFA's current state index from which we'll transition
   * @param match_id The symbol group id of the symbol that we just read in
   * @return The new state after consuming the given symbol group in the given state
   */
  template <typename StateIndexT, typename SymbolIndexT>
  __host__ __device__ __forceinline__ int32_t operator()(StateIndexT state_id,
                                                         SymbolIndexT match_id) const
  {
    return temp_storage.transitions[match_id * MAX_NUM_STATES + state_id];
  }
};
|
||
} // namespace detail | ||
} // namespace fst | ||
} // namespace io | ||
} // namespace cudf |
Oops, something went wrong.