Skip to content

Commit

Permalink
WIP for cast_strings_to_dates_legacy
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Oct 21, 2024
1 parent 1a7d192 commit dcb463e
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 10 deletions.
104 changes: 94 additions & 10 deletions src/main/cpp/src/json_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ std::tuple<std::unique_ptr<cudf::column>, std::unique_ptr<rmm::device_buffer>, c
thrust::make_counting_iterator(0L),
thrust::make_counting_iterator(input.size() * static_cast<int64_t>(cudf::detail::warp_size)),
[input = *d_input_ptr,
output = thrust::make_zip_iterator(thrust::make_tuple(
is_valid_input.begin(), is_null_or_empty.begin()))] __device__(int64_t tidx) {
output = thrust::make_zip_iterator(is_valid_input.begin(),
is_null_or_empty.begin())] __device__(int64_t tidx) {
// Execute one warp per row to minimize thread divergence.
if ((tidx % cudf::detail::warp_size) != 0) { return; }
auto const idx = tidx / cudf::detail::warp_size;
Expand Down Expand Up @@ -333,8 +333,8 @@ std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> cast_strings
cudf::detail::offsetalator_factory::make_input_iterator(input_sv.offsets());
auto const d_input_ptr = cudf::column_device_view::create(input, stream);
auto const is_valid_it = cudf::detail::make_validity_iterator<true>(*d_input_ptr);
auto const output_it = thrust::make_zip_iterator(
thrust::make_tuple(output->mutable_view().begin<bool>(), validity.begin()));
auto const output_it =
thrust::make_zip_iterator(output->mutable_view().begin<bool>(), validity.begin());
thrust::tabulate(
rmm::exec_policy_nosync(stream),
output_it,
Expand Down Expand Up @@ -445,32 +445,36 @@ std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> cast_strings
}

// TODO: mr
auto const removed_quotes = remove_quotes(input, false, stream, mr);
auto const removed_quotes = remove_quotes(input, false, stream, mr);
auto const removed_quotes_cv = removed_quotes->view();
auto const input_sv = cudf::strings_column_view{removed_quotes_cv};
auto const d_input_ptr = cudf::column_device_view::create(removed_quotes_cv, stream);
auto const is_valid_it = cudf::detail::make_validity_iterator<true>(*d_input_ptr);

auto const input_sv = cudf::strings_column_view{removed_quotes->view()};
auto const regex_prog = cudf::strings::regex_program::create(
date_regex, cudf::strings::regex_flags::DEFAULT, cudf::strings::capture_groups::NON_CAPTURE);
auto const is_matched = cudf::strings::matches_re(input_sv, *regex_prog, stream);
auto const is_timestamp = cudf::strings::is_timestamp(input_sv, date_format, stream);
auto const d_is_matched = is_matched->view().begin<bool>();
auto const d_is_timestamp = is_timestamp->view().begin<bool>();

auto const d_input_ptr = cudf::column_device_view::create(removed_quotes->view(), stream);
auto const is_valid_it = cudf::detail::make_validity_iterator<true>(*d_input_ptr);
auto const invalid_count = thrust::count_if(
rmm::exec_policy(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(string_count),
[is_valid = is_valid_it, is_matched = d_is_matched, is_timestamp = d_is_timestamp] __device__(
auto idx) { return is_valid[idx] && (!is_matched[idx] || !is_timestamp[idx]); });
auto idx) {
// The row is invalid if it is valid (non-null) but failed at least one check.
return is_valid[idx] && (!is_matched[idx] || !is_timestamp[idx]);
});

if (invalid_count == 0) {
auto output = cudf::strings::to_timestamps(
input_sv, cudf::data_type{cudf::type_id::TIMESTAMP_DAYS}, date_format, stream, mr);
return {std::move(output), rmm::device_uvector<bool>(0, stream)};
}

// From here we have invalid_count > 0.

if (error_if_invalid) { return {nullptr, rmm::device_uvector<bool>(0, stream)}; }

auto const input_offsets_it =
Expand Down Expand Up @@ -525,6 +529,67 @@ std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> cast_strings
return {std::move(output), std::move(validity)};
}

// Legacy-mode variant of the string-to-date cast (detail implementation).
//
// NOTE(review): this block is work in progress. It is incomplete and does not compile as written;
// every known gap is flagged with a NOTE(review) comment below.
//
// Behavior visible in this body:
//  - An empty input returns an empty TIMESTAMP_DAYS column plus an empty validity vector.
//  - Quotes are stripped with `remove_quotes`, then every row is checked against three
//    (regex, format) pairs named ymd / ym / y (presumably year-month-day, year-month and
//    year-only layouts -- confirm).
//  - If no non-null row fails all three checks, the input is converted with
//    `cudf::strings::to_timestamps` and an empty validity vector is returned.
//  - Otherwise, when `error_if_invalid` is true, a null column pointer is returned.
//
// NOTE(review): `special_dates` is not referenced anywhere in this body yet.
std::pair<std::unique_ptr<cudf::column>, rmm::device_uvector<bool>> cast_strings_to_dates_legacy(
cudf::column_view const& input,
std::vector<std::pair<std::string, int64_t>> const& special_dates,
bool error_if_invalid,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto const string_count = input.size();
if (string_count == 0) {
return {cudf::make_empty_column(cudf::data_type{cudf::type_id::TIMESTAMP_DAYS}),
rmm::device_uvector<bool>(0, stream)};
}

// TODO: mr
auto const removed_quotes = remove_quotes(input, false, stream, mr);
auto const removed_quotes_cv = removed_quotes->view();
auto const input_sv = cudf::strings_column_view{removed_quotes_cv};
auto const d_input_ptr = cudf::column_device_view::create(removed_quotes_cv, stream);
// Per-row iterator that yields true for non-null rows of the quote-stripped column.
auto const is_valid_it = cudf::detail::make_validity_iterator<true>(*d_input_ptr);

// Runs both checks for one (regex, format) pair over the whole column: a regex match and an
// `is_timestamp` check against the given format.
// NOTE(review): the lambda has no trailing return type, so the braced `return {...}` below cannot
// be deduced; it needs an explicit pair return type or `std::make_pair(...)`.
auto const check_input = [&](std::string const& date_regex, std::string const& date_format) {
auto const regex_prog = cudf::strings::regex_program::create(
date_regex, cudf::strings::regex_flags::DEFAULT, cudf::strings::capture_groups::NON_CAPTURE);
return {cudf::strings::matches_re(input_sv, *regex_prog, stream),
cudf::strings::is_timestamp(input_sv, date_format, stream)};
};

// NOTE(review): `check_input` takes two parameters but is called with none here; the regex and
// format string for each of the three layouts still have to be supplied.
auto const [is_matched_ymd, is_timestamp_ymd] = check_input();
auto const [is_matched_ym, is_timestamp_ym] = check_input();
auto const [is_matched_y, is_timestamp_y] = check_input();

// Zip the six boolean result columns so that one dereference yields, in order:
// (matched, is_timestamp) for ymd, then for ym, then for y.
auto const is_valid_format_it = thrust::make_zip_iterator(is_matched_ymd->view().begin<bool>(),
is_timestamp_ymd->view().begin<bool>(),
is_matched_ym->view().begin<bool>(),
is_timestamp_ym->view().begin<bool>(),
is_matched_y->view().begin<bool>(),
is_timestamp_y->view().begin<bool>());
// Count the non-null rows that fail the checks for all three layouts.
auto const invalid_count = thrust::count_if(
rmm::exec_policy(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(string_count),
[is_valid = is_valid_it, is_valid_format = is_valid_format_it] __device__(auto idx) {
// Null rows are never counted as invalid.
// NOTE(review): `return 0` deduces `int` while the return below deduces `bool`; mixed deduced
// return types are ill-formed in a lambda without a trailing return type. Should be
// `return false`.
if (!is_valid[idx]) { return 0; }
auto const valid_format = is_valid_format[idx];
// The row is invalid only if it is non-null and failed the checks for all 3 formats.
// A format passes only when both its regex match and its is_timestamp flag are true.
return (!thrust::get<0>(valid_format) || !thrust::get<1>(valid_format)) &&
(!thrust::get<2>(valid_format) || !thrust::get<3>(valid_format)) &&
(!thrust::get<4>(valid_format) || !thrust::get<5>(valid_format));
});

if (invalid_count == 0) {
// TODO
// NOTE(review): `date_format` is not declared in this function (it is only a parameter of the
// `check_input` lambda). Presumably a single format cannot convert rows that use different
// layouts -- confirm how the three layouts are meant to be converted.
auto output = cudf::strings::to_timestamps(
input_sv, cudf::data_type{cudf::type_id::TIMESTAMP_DAYS}, date_format, stream, mr);
return {std::move(output), rmm::device_uvector<bool>(0, stream)};
}
// From here we have invalid_count > 0.

// A null column pointer is how this function reports invalid input when errors are requested.
if (error_if_invalid) { return {nullptr, rmm::device_uvector<bool>(0, stream)}; }
// NOTE(review): when invalid_count > 0 and error_if_invalid is false, control reaches the end of
// this non-void function without a return statement (undefined behavior). The handling that
// builds the output column and per-row validity for this case is still missing.
}

// TODO there is a bug here around 0 https://github.com/NVIDIA/spark-rapids/issues/10898
std::unique_ptr<cudf::column> cast_strings_to_decimals(cudf::column_view const& input,
int precision,
Expand Down Expand Up @@ -939,6 +1004,25 @@ std::unique_ptr<cudf::column> cast_strings_to_dates(cudf::column_view const& inp
return std::move(output);
}

// Public entry point for the legacy string-to-date cast.
//
// Forwards all arguments to detail::cast_strings_to_dates_legacy, which yields an output column
// and a per-row boolean validity vector. A null output column is passed straight back to the
// caller as nullptr. Otherwise the validity vector is turned into a null mask, which is attached
// to the output column only when it marks at least one row as null.
std::unique_ptr<cudf::column> cast_strings_to_dates_legacy(
  cudf::column_view const& input,
  std::vector<std::pair<std::string, int64_t>> const& special_dates,
  bool error_if_invalid,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr)
{
  CUDF_FUNC_RANGE();

  auto cast_result =
    detail::cast_strings_to_dates_legacy(input, special_dates, error_if_invalid, stream, mr);
  auto& converted    = cast_result.first;
  auto& row_validity = cast_result.second;

  // The detail routine signals failure with a null column pointer; propagate it unchanged.
  if (!converted) { return nullptr; }

  // Build a bitmask from the boolean flags; `.second` is the number of null rows.
  auto mask_and_count = cudf::detail::valid_if(
    row_validity.begin(), row_validity.end(), thrust::identity{}, stream, mr);
  if (mask_and_count.second > 0) {
    converted->set_null_mask(std::move(mask_and_count.first), mask_and_count.second);
  }

  // `converted` is a reference into `cast_result`, so it must be moved out explicitly.
  return std::move(converted);
}

std::unique_ptr<cudf::column> cast_strings_to_decimals(cudf::column_view const& input,
int precision,
int scale,
Expand Down
7 changes: 7 additions & 0 deletions src/main/cpp/src/json_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ std::unique_ptr<cudf::column> cast_strings_to_dates(
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

/**
 * @brief Cast a strings column to dates; presumably the legacy-mode counterpart of
 * `cast_strings_to_dates` (confirm against the caller).
 *
 * NOTE(review): the implementation in json_utils.cu is still work in progress. The notes below
 * describe only what that implementation currently shows.
 *
 * @param input Column to convert
 * @param special_dates List of (string, int64_t) pairs; presumably special date strings and the
 *        values they map to. Not yet used by the implementation -- confirm intended semantics
 * @param error_if_invalid When true, the implementation returns nullptr if any non-null row
 *        fails validation
 * @param stream Stream forwarded to the underlying implementation; defaults to
 *        `cudf::get_default_stream()`
 * @param mr Memory resource forwarded to the underlying implementation; defaults to
 *        `rmm::mr::get_current_device_resource()`
 * @return The converted column, or nullptr when the implementation produces no output column
 */
std::unique_ptr<cudf::column> cast_strings_to_dates_legacy(
cudf::column_view const& input,
std::vector<std::pair<std::string, int64_t>> const& special_dates,
bool error_if_invalid,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());

std::unique_ptr<cudf::column> remove_quotes(
cudf::column_view const& input,
bool nullify_if_not_quoted,
Expand Down

0 comments on commit dcb463e

Please sign in to comment.