Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add regex rewrite kernel to find literal[a,b]{x,y} in a string #2041

Merged
merged 18 commits into from
May 21, 2024
Merged
2 changes: 2 additions & 0 deletions src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ add_library(
src/MapUtilsJni.cpp
src/NativeParquetJni.cpp
src/ParseURIJni.cpp
src/RegexRewriteUtilsJni.cpp
src/RowConversionJni.cpp
src/SparkResourceAdaptorJni.cpp
src/ZOrderJni.cpp
Expand All @@ -178,6 +179,7 @@ add_library(
src/map_utils.cu
src/murmur_hash.cu
src/parse_uri.cu
src/regex_rewrite_utils.cu
src/row_conversion.cu
src/timezones.cu
src/utilities.cu
Expand Down
41 changes: 41 additions & 0 deletions src/main/cpp/src/RegexRewriteUtilsJni.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
#include "jni_utils.hpp"
#include "regex_rewrite_utils.hpp"

extern "C" {

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_RegexRewriteUtils_literalRangePattern(
JNIEnv* env, jclass, jlong input, jlong target, jint d, jint start, jint end)
{
JNI_NULL_CHECK(env, input, "input column is null", 0);
JNI_NULL_CHECK(env, target, "target is null", 0);

try {
cudf::jni::auto_set_device(env);

cudf::column_view* cv = reinterpret_cast<cudf::column_view*>(input);
cudf::strings_column_view scv(*cv);
cudf::string_scalar* ss_scalar = reinterpret_cast<cudf::string_scalar*>(target);
return cudf::jni::release_as_jlong(
spark_rapids_jni::literal_range_pattern(scv, *ss_scalar, d, start, end));
}
CATCH_STD(env, 0);
}
}
133 changes: 133 additions & 0 deletions src/main/cpp/src/regex_rewrite_utils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/strings/detail/utf8.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

namespace spark_rapids_jni {

namespace {

struct literal_range_pattern_fn {
__device__ bool operator()(
thirtiseven marked this conversation as resolved.
Show resolved Hide resolved
cudf::string_view d_string, cudf::string_view d_prefix, int range_len, int start, int end) const
{
int const n = d_string.length(), m = d_prefix.length();
for (int i = 0; i <= n - m - range_len; i++) {
bool match = true;
for (int j = 0; j < m; j++) {
if (d_string[i + j] != d_prefix[j]) {
match = false;
break;
}
}
if (match) {
for (int j = 0; j < range_len; j++) {
auto code_point = cudf::strings::detail::utf8_to_codepoint(d_string[i + m + j]);
if (code_point < start || code_point > end) {
match = false;
break;
}
}
if (match) { return true; }
}
}
return false;
}
};

std::unique_ptr<cudf::column> find_literal_range_pattern(cudf::strings_column_view const& strings,
cudf::string_scalar const& prefix,
int const range_len,
int const start,
int const end,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto const strings_count = strings.size();
if (strings_count == 0) { return cudf::make_empty_column(cudf::type_id::BOOL8); }

CUDF_EXPECTS(prefix.is_valid(stream), "Parameter prefix must be valid.");

auto const d_prefix = cudf::string_view(prefix.data(), prefix.size());
auto const strings_column = cudf::column_device_view::create(strings.parent(), stream);
auto const d_strings = *strings_column;

auto results = make_numeric_column(cudf::data_type{cudf::type_id::BOOL8},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
stream,
mr);
Comment on lines +82 to +87
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect that the file was not properly formated (using clang-format).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It passed clang-format in the pre-commit hook

auto const d_results = results->mutable_view().data<bool>();
// set the bool values by evaluating the passed function
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<cudf::size_type>(0),
thrust::make_counting_iterator<cudf::size_type>(strings_count),
d_results,
[d_strings, d_prefix, range_len, start, end, check_fn = literal_range_pattern_fn{}] __device__(
cudf::size_type idx) {
if (!d_strings.is_null(idx)) {
return check_fn(d_strings.element<cudf::string_view>(idx), d_prefix, range_len, start, end);
}
return false;
});
results->set_null_count(strings.null_count());
return results;
}

} // namespace

/**
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means
* a literal string followed by a range of characters in the range of start to end, with at least
* len characters.
*
* @param strings Column of strings to check for literal.
* @param literal UTF-8 encoded string to check in strings column.
* @param len Minimum number of characters to check after the literal.
* @param start Minimum UTF-8 codepoint value to check for in the range.
* @param end Maximum UTF-8 codepoint value to check for in the range.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*/
std::unique_ptr<cudf::column> literal_range_pattern(cudf::strings_column_view const& input,
cudf::string_scalar const& prefix,
int const range_len,
int const start,
int const end,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
return find_literal_range_pattern(input, prefix, range_len, start, end, stream, mr);
}

} // namespace spark_rapids_jni
45 changes: 45 additions & 0 deletions src/main/cpp/src/regex_rewrite_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>

namespace spark_rapids_jni {
/**
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means
* a literal string followed by a range of characters in the range of start to end, with at least
* len characters.
*
* @param strings Column of strings to check for literal.
* @param literal UTF-8 encoded string to check in strings column.
* @param len Minimum number of characters to check after the literal.
* @param start Minimum UTF-8 codepoint value to check for in the range.
* @param end Maximum UTF-8 codepoint value to check for in the range.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*/
std::unique_ptr<cudf::column> literal_range_pattern(
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
cudf::strings_column_view const& input,
cudf::string_scalar const& literal,
int const len,
int const start,
int const end,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
} // namespace spark_rapids_jni
44 changes: 44 additions & 0 deletions src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.*;

public class RegexRewriteUtils {
static {
NativeDepsLoader.loadNativeDeps();
}

/**
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means
* a literal string followed by a range of characters in the range of start to end, with at least
* len characters.
*
* @param strings Column of strings to check for literal.
* @param literal UTF-8 encoded string to check in strings column.
* @param len Minimum number of characters to check after the literal.
* @param start Minimum UTF-8 codepoint value to check for in the range.
* @param end Maximum UTF-8 codepoint value to check for in the range.
* @return ColumnVector of booleans where true indicates the string contains the pattern.
*/
public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len, int start, int end) {
assert(input.getType().equals(DType.STRING)) : "column must be a String";
return new ColumnVector(literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len, start, end));
}

private static native long literalRangePattern(long input, long literal, int len, int start, int end);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Scalar;
import org.junit.jupiter.api.Test;

import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;

public class RegexRewriteUtilsTest {

@Test
void testLiteralRangePattern() {
int d = 3;
try (ColumnVector inputCv = ColumnVector.fromStrings(
"abc123", "aabc123", "aabc12", "abc1232", "aabc1232");
Scalar pattern = Scalar.fromString("abc");
ColumnVector expected = ColumnVector.fromBooleans(true, true, false, true, true);
ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 48, 57)) {
assertColumnsAreEqual(expected, actual);
}
}

@Test
void testLiteralRangePatternChinese() {
int d = 2;
try (ColumnVector inputCv = ColumnVector.fromStrings(
"数据砖块", "火花-急流英伟达", "英伟达Nvidia", "火花-急流");
Scalar pattern = Scalar.fromString("英");
ColumnVector expected = ColumnVector.fromBooleans(false, true, true, false);
ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 19968, 40869)) {
assertColumnsAreEqual(expected, actual);
}
}

}
Loading