Change interface for FST to not need temp storage
elstehle committed Jul 13, 2022
1 parent 239f138 commit 39cff80
Showing 2 changed files with 49 additions and 45 deletions.
55 changes: 35 additions & 20 deletions cpp/src/io/fst/lookup_tables.cuh
@@ -17,8 +17,8 @@
#pragma once

#include <cudf/types.hpp>
#include <io/utilities/hostdevice_vector.hpp>
#include <io/fst/device_dfa.cuh>
#include <io/utilities/hostdevice_vector.hpp>

#include <cub/cub.cuh>

@@ -485,26 +485,41 @@ class Dfa {
typename TransducedIndexOutItT,
typename TransducedCountOutItT,
typename OffsetT>
cudaError_t Transduce(void* d_temp_storage,
size_t& temp_storage_bytes,
SymbolT const* d_chars,
OffsetT num_chars,
TransducedOutItT d_out_it,
TransducedIndexOutItT d_out_idx_it,
TransducedCountOutItT d_num_transduced_out_it,
const uint32_t seed_state = 0,
cudaStream_t stream = 0)
void Transduce(SymbolT const* d_chars,
OffsetT num_chars,
TransducedOutItT d_out_it,
TransducedIndexOutItT d_out_idx_it,
TransducedCountOutItT d_num_transduced_out_it,
const uint32_t seed_state,
rmm::cuda_stream_view stream)
{
return DeviceTransduce(d_temp_storage,
temp_storage_bytes,
this->get_device_view(),
d_chars,
num_chars,
d_out_it,
d_out_idx_it,
d_num_transduced_out_it,
seed_state,
stream);
std::size_t temp_storage_bytes = 0;
rmm::device_buffer temp_storage{};
DeviceTransduce(nullptr,
temp_storage_bytes,
this->get_device_view(),
d_chars,
num_chars,
d_out_it,
d_out_idx_it,
d_num_transduced_out_it,
seed_state,
stream);

if (temp_storage.size() < temp_storage_bytes) {
temp_storage.resize(temp_storage_bytes, stream);
}

DeviceTransduce(temp_storage.data(),
temp_storage_bytes,
this->get_device_view(),
d_chars,
num_chars,
d_out_it,
d_out_idx_it,
d_num_transduced_out_it,
seed_state,
stream);
}

private:
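Note on the change above: the removed d_temp_storage / temp_storage_bytes parameters followed the usual CUB convention, in which the caller invokes the algorithm once with a null temporary-storage pointer to query the required scratch size, allocates that many bytes, and then invokes it again to do the real work. The new Transduce performs both DeviceTransduce calls internally: first with nullptr to obtain temp_storage_bytes, then with a freshly allocated rmm::device_buffer. For reference, here is a minimal sketch of that two-phase pattern, using cub::DeviceScan::ExclusiveSum as a stand-in for DeviceTransduce (illustrative only, not part of this commit):

#include <cub/cub.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>

#include <cstddef>

// Two-phase CUB pattern: query the scratch size with a null pointer,
// allocate that many bytes, then run the algorithm for real.
void exclusive_sum(int const* d_in, int* d_out, int num_items, rmm::cuda_stream_view stream)
{
  std::size_t temp_storage_bytes = 0;

  // Phase 1: a null temp-storage pointer only computes temp_storage_bytes
  cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, d_in, d_out, num_items, stream.value());

  // Allocate the requested scratch space on the caller's stream
  rmm::device_buffer temp_storage{temp_storage_bytes, stream};

  // Phase 2: the same call with real scratch memory performs the scan
  cub::DeviceScan::ExclusiveSum(
    temp_storage.data(), temp_storage_bytes, d_in, d_out, num_items, stream.value());
}

Moving this pattern inside Transduce simplifies every call site; callers can no longer reuse one scratch buffer across invocations, but, as the test change below shows, the test previously re-allocated the buffer per call anyway.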
39 changes: 14 additions & 25 deletions cpp/tests/io/fst/fst_test.cu
@@ -217,8 +217,10 @@ TEST_F(FstTest, GroundTruth)
input += input;

// Prepare input & output buffers
constexpr std::size_t single_item = 1;
rmm::device_uvector<SymbolT> d_input(input.size(), stream_view);
hostdevice_vector<SymbolT> output_gpu(input.size(), stream_view);
hostdevice_vector<SymbolOffsetT> output_gpu_size(single_item, stream_view);
hostdevice_vector<SymbolOffsetT> out_indexes_gpu(input.size(), stream_view);
ASSERT_CUDA_SUCCEEDED(cudaMemcpyAsync(
d_input.data(), input.data(), input.size() * sizeof(SymbolT), cudaMemcpyHostToDevice, stream));
@@ -228,32 +230,19 @@ TEST_F(FstTest, GroundTruth)

std::size_t temp_storage_bytes = 0;

// Query temporary storage requirements
ASSERT_CUDA_SUCCEEDED(parser.Transduce(nullptr,
temp_storage_bytes,
d_input.data(),
static_cast<SymbolOffsetT>(d_input.size()),
output_gpu.device_ptr(),
out_indexes_gpu.device_ptr(),
cub::DiscardOutputIterator<int32_t>{},
start_state,
stream));

// Allocate device-side temporary storage & run algorithm
rmm::device_buffer temp_storage{temp_storage_bytes, stream_view};
ASSERT_CUDA_SUCCEEDED(parser.Transduce(temp_storage.data(),
temp_storage_bytes,
d_input.data(),
static_cast<SymbolOffsetT>(d_input.size()),
output_gpu.device_ptr(),
out_indexes_gpu.device_ptr(),
cub::DiscardOutputIterator<int32_t>{},
start_state,
stream));
parser.Transduce(d_input.data(),
static_cast<SymbolOffsetT>(d_input.size()),
output_gpu.device_ptr(),
out_indexes_gpu.device_ptr(),
output_gpu_size.device_ptr(),
start_state,
stream);

// Async copy results from device to host
output_gpu.device_to_host(stream_view);
out_indexes_gpu.device_to_host(stream_view);
output_gpu_size.device_to_host(stream_view);

// Prepare CPU-side results for verification
std::string output_cpu{};
@@ -275,13 +264,13 @@ TEST_F(FstTest, GroundTruth)
cudaStreamSynchronize(stream);

// Verify results
ASSERT_EQ(output_gpu.size(), output_cpu.size());
ASSERT_EQ(output_gpu_size[0], output_cpu.size());
ASSERT_EQ(out_indexes_gpu.size(), out_index_cpu.size());
for (std::size_t i = 0; i < output_gpu.size(); i++) {
ASSERT_EQ(output_gpu.host_ptr()[i], output_cpu[i]) << "Mismatch at index #" << i;
for (std::size_t i = 0; i < output_cpu.size(); i++) {
ASSERT_EQ(output_gpu[i], output_cpu[i]) << "Mismatch at index #" << i;
}
for (std::size_t i = 0; i < out_indexes_gpu.size(); i++) {
ASSERT_EQ(out_indexes_gpu.host_ptr()[i], out_index_cpu[i]) << "Mismatch at index #" << i;
ASSERT_EQ(out_indexes_gpu[i], out_index_cpu[i]) << "Mismatch at index #" << i;
}
}

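Taken together, the two files reduce the caller's workflow to a single Transduce call with no temporary-storage bookkeeping. Below is a condensed sketch of the post-change usage, written as a free-standing helper; the run_fst name and the DfaT/SymbolT/SymbolOffsetT template parameters are illustrative, not part of the commit, and assume a parser, input buffer, and start_state prepared as in fst_test.cu:

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

#include <cuda_runtime.h>

#include <cstdint>

// Runs the FST over d_input and returns the number of transduced symbols.
template <typename DfaT, typename SymbolT, typename SymbolOffsetT>
SymbolOffsetT run_fst(DfaT& parser,
                      rmm::device_uvector<SymbolT> const& d_input,
                      rmm::device_uvector<SymbolT>& d_output,
                      rmm::device_uvector<SymbolOffsetT>& d_output_indexes,
                      uint32_t start_state,
                      rmm::cuda_stream_view stream)
{
  // Single-element device buffer receiving the number of transduced symbols
  rmm::device_uvector<SymbolOffsetT> d_output_size(1, stream);

  // One call: temporary storage is now sized and allocated inside Transduce()
  parser.Transduce(d_input.data(),
                   static_cast<SymbolOffsetT>(d_input.size()),
                   d_output.data(),
                   d_output_indexes.data(),
                   d_output_size.data(),
                   start_state,
                   stream);

  // Copy the transduced count back to the host and wait for it
  SymbolOffsetT h_output_size = 0;
  cudaMemcpyAsync(&h_output_size,
                  d_output_size.data(),
                  sizeof(SymbolOffsetT),
                  cudaMemcpyDeviceToHost,
                  stream.value());
  cudaStreamSynchronize(stream.value());
  return h_output_size;
}

The test itself uses hostdevice_vector for the same device-to-host round trip; the plain cudaMemcpyAsync here only keeps the sketch self-contained.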
