Skip to content

Commit

Permalink
Add regex rewrite kernel to find literal[a,b]{x,y} in a string (#1)
Browse files Browse the repository at this point in the history
* wip

Signed-off-by: Haoyang Li <[email protected]>

* wip

Signed-off-by: Haoyang Li <[email protected]>

* wip

Signed-off-by: Haoyang Li <[email protected]>

* wip

Signed-off-by: Haoyang Li <[email protected]>

* support range filter

Signed-off-by: Haoyang Li <[email protected]>

* clean up

Signed-off-by: Haoyang Li <[email protected]>

* wip

Signed-off-by: Haoyang Li <[email protected]>

* change some names

Signed-off-by: Haoyang Li <[email protected]>

* clean up

Signed-off-by: Haoyang Li <[email protected]>

* address comments

Signed-off-by: Haoyang Li <[email protected]>

---------

Signed-off-by: Haoyang Li <[email protected]>
  • Loading branch information
thirtiseven authored May 21, 2024
1 parent 2881ab2 commit 5156fe2
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ add_library(
src/MapUtilsJni.cpp
src/NativeParquetJni.cpp
src/ParseURIJni.cpp
src/RegexRewriteUtilsJni.cpp
src/RowConversionJni.cpp
src/SparkResourceAdaptorJni.cpp
src/StringDigitsPatternJni.cpp
Expand All @@ -202,6 +203,7 @@ add_library(
src/murmur_hash.cu
src/padding_partition.cu
src/parse_uri.cu
src/regex_rewrite_utils.cu
src/row_conversion.cu
src/string_digits_pattern.cu
src/timezones.cu
Expand Down
41 changes: 41 additions & 0 deletions src/main/cpp/src/RegexRewriteUtilsJni.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
#include "jni_utils.hpp"
#include "regex_rewrite_utils.hpp"

extern "C" {

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_RegexRewriteUtils_literalRangePattern(
JNIEnv* env, jclass, jlong input, jlong target, jint d, jint start, jint end)
{
JNI_NULL_CHECK(env, input, "input column is null", 0);
JNI_NULL_CHECK(env, target, "target is null", 0);

try {
cudf::jni::auto_set_device(env);

cudf::column_view* cv = reinterpret_cast<cudf::column_view*>(input);
cudf::strings_column_view scv(*cv);
cudf::string_scalar* ss_scalar = reinterpret_cast<cudf::string_scalar*>(target);
return cudf::jni::release_as_jlong(
spark_rapids_jni::literal_range_pattern(scv, *ss_scalar, d, start, end));
}
CATCH_STD(env, 0);
}
}
133 changes: 133 additions & 0 deletions src/main/cpp/src/regex_rewrite_utils.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/iterator.cuh>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/strings/detail/utf8.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

namespace spark_rapids_jni {

namespace {

struct literal_range_pattern_fn {
__device__ bool operator()(
cudf::string_view d_string, cudf::string_view d_prefix, int range_len, int start, int end)
{
int const n = d_string.length(), m = d_prefix.length();
for (int i = 0; i <= n - m - range_len; i++) {
bool match = true;
for (int j = 0; j < m; j++) {
if (d_string[i + j] != d_prefix[j]) {
match = false;
break;
}
}
if (match) {
for (int j = 0; j < range_len; j++) {
auto code_point = cudf::strings::detail::utf8_to_codepoint(d_string[i + m + j]);
if (code_point < start || code_point > end) {
match = false;
break;
}
}
if (match) { return true; }
}
}
return false;
}
};

std::unique_ptr<cudf::column> find_literal_range_pattern(cudf::strings_column_view const& strings,
cudf::string_scalar const& prefix,
int const range_len,
int const start,
int const end,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto const strings_count = strings.size();
if (strings_count == 0) return cudf::make_empty_column(cudf::type_id::BOOL8);

CUDF_EXPECTS(prefix.is_valid(stream), "Parameter prefix must be valid.");

auto const d_prefix = cudf::string_view(prefix.data(), prefix.size());
auto const strings_column = cudf::column_device_view::create(strings.parent(), stream);
auto const d_strings = *strings_column;

auto results = make_numeric_column(cudf::data_type{cudf::type_id::BOOL8},
strings_count,
cudf::detail::copy_bitmask(strings.parent(), stream, mr),
strings.null_count(),
stream,
mr);
auto const d_results = results->mutable_view().data<bool>();
// set the bool values by evaluating the passed function
thrust::transform(
rmm::exec_policy(stream),
thrust::make_counting_iterator<cudf::size_type>(0),
thrust::make_counting_iterator<cudf::size_type>(strings_count),
d_results,
[d_strings, d_prefix, range_len, start, end] __device__(cudf::size_type idx) {
if (!d_strings.is_null(idx)) {
return bool{literal_range_pattern_fn{}(
d_strings.element<cudf::string_view>(idx), d_prefix, range_len, start, end)};
}
return false;
});
results->set_null_count(strings.null_count());
return results;
}

} // namespace

/**
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means
* a literal string followed by a range of characters in the range of start to end, with at least
* len characters.
*
* @param strings Column of strings to check for literal.
* @param literal UTF-8 encoded string to check in strings column.
* @param len Minimum number of characters to check after the literal.
* @param start Minimum UTF-8 codepoint value to check for in the range.
* @param end Maximum UTF-8 codepoint value to check for in the range.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*/
std::unique_ptr<cudf::column> literal_range_pattern(cudf::strings_column_view const& input,
cudf::string_scalar const& prefix,
int const range_len,
int const start,
int const end,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_FUNC_RANGE();
return find_literal_range_pattern(input, prefix, range_len, start, end, stream, mr);
}

} // namespace spark_rapids_jni
45 changes: 45 additions & 0 deletions src/main/cpp/src/regex_rewrite_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>

namespace spark_rapids_jni {
/**
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means
* a literal string followed by a range of characters in the range of start to end, with at least
* len characters.
*
* @param strings Column of strings to check for literal.
* @param literal UTF-8 encoded string to check in strings column.
* @param len Minimum number of characters to check after the literal.
* @param start Minimum UTF-8 codepoint value to check for in the range.
* @param end Maximum UTF-8 codepoint value to check for in the range.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*/
std::unique_ptr<cudf::column> literal_range_pattern(
cudf::strings_column_view const& input,
cudf::string_scalar const& literal,
int const len,
int const start,
int const end,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource());
} // namespace spark_rapids_jni
44 changes: 44 additions & 0 deletions src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.*;

public class RegexRewriteUtils {
static {
NativeDepsLoader.loadNativeDeps();
}

/**
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means
* a literal string followed by a range of characters in the range of start to end, with at least
* len characters.
*
* @param strings Column of strings to check for literal.
* @param literal UTF-8 encoded string to check in strings column.
* @param len Minimum number of characters to check after the literal.
* @param start Minimum UTF-8 codepoint value to check for in the range.
* @param end Maximum UTF-8 codepoint value to check for in the range.
* @return ColumnVector of booleans where true indicates the string contains the pattern.
*/
public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len, int start, int end) {
assert(input.getType().equals(DType.STRING)) : "column must be a String";
return new ColumnVector(literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len, start, end));
}

private static native long literalRangePattern(long input, long literal, int len, int start, int end);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Scalar;
import org.junit.jupiter.api.Test;

import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;

public class RegexRewriteUtilsTest {

@Test
void testLiteralRangePattern() {
int d = 3;
try (ColumnVector inputCv = ColumnVector.fromStrings(
"abc123", "aabc123", "aabc12", "abc1232", "aabc1232");
Scalar pattern = Scalar.fromString("abc");
ColumnVector expected = ColumnVector.fromBooleans(true, true, false, true, true);
ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 48, 57)) {
assertColumnsAreEqual(expected, actual);
}
}

@Test
void testLiteralRangePatternChinese() {
int d = 2;
try (ColumnVector inputCv = ColumnVector.fromStrings(
"数据砖块", "火花-急流英伟达", "英伟达Nvidia", "火花-急流");
Scalar pattern = Scalar.fromString("英");
ColumnVector expected = ColumnVector.fromBooleans(false, true, true, false);
ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 19968, 40869)) {
assertColumnsAreEqual(expected, actual);
}
}

}

0 comments on commit 5156fe2

Please sign in to comment.