diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index d30abc0b8f..169067bfdd 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -162,6 +162,7 @@ add_library( src/MapUtilsJni.cpp src/NativeParquetJni.cpp src/ParseURIJni.cpp + src/RegexRewriteUtilsJni.cpp src/RowConversionJni.cpp src/SparkResourceAdaptorJni.cpp src/ZOrderJni.cpp @@ -178,6 +179,7 @@ add_library( src/map_utils.cu src/murmur_hash.cu src/parse_uri.cu + src/regex_rewrite_utils.cu src/row_conversion.cu src/timezones.cu src/utilities.cu diff --git a/src/main/cpp/src/RegexRewriteUtilsJni.cpp b/src/main/cpp/src/RegexRewriteUtilsJni.cpp new file mode 100644 index 0000000000..28f346582c --- /dev/null +++ b/src/main/cpp/src/RegexRewriteUtilsJni.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cudf_jni_apis.hpp" +#include "dtype_utils.hpp" +#include "jni_utils.hpp" +#include "regex_rewrite_utils.hpp" + +extern "C" { + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_RegexRewriteUtils_literalRangePattern( + JNIEnv* env, jclass, jlong input, jlong target, jint d, jint start, jint end) +{ + JNI_NULL_CHECK(env, input, "input column is null", 0); + JNI_NULL_CHECK(env, target, "target is null", 0); + + try { + cudf::jni::auto_set_device(env); + + cudf::column_view* cv = reinterpret_cast(input); + cudf::strings_column_view scv(*cv); + cudf::string_scalar* ss_scalar = reinterpret_cast(target); + return cudf::jni::release_as_jlong( + spark_rapids_jni::literal_range_pattern(scv, *ss_scalar, d, start, end)); + } + CATCH_STD(env, 0); +} +} diff --git a/src/main/cpp/src/regex_rewrite_utils.cu b/src/main/cpp/src/regex_rewrite_utils.cu new file mode 100644 index 0000000000..2735b134f9 --- /dev/null +++ b/src/main/cpp/src/regex_rewrite_utils.cu @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace spark_rapids_jni { + +namespace { + +struct literal_range_pattern_fn { + __device__ bool operator()( + cudf::string_view d_string, cudf::string_view d_prefix, int range_len, int start, int end) const + { + int const n = d_string.length(), m = d_prefix.length(); + for (int i = 0; i <= n - m - range_len; i++) { + bool match = true; + for (int j = 0; j < m; j++) { + if (d_string[i + j] != d_prefix[j]) { + match = false; + break; + } + } + if (match) { + for (int j = 0; j < range_len; j++) { + auto code_point = cudf::strings::detail::utf8_to_codepoint(d_string[i + m + j]); + if (code_point < start || code_point > end) { + match = false; + break; + } + } + if (match) { return true; } + } + } + return false; + } +}; + +std::unique_ptr find_literal_range_pattern(cudf::strings_column_view const& strings, + cudf::string_scalar const& prefix, + int const range_len, + int const start, + int const end, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const strings_count = strings.size(); + if (strings_count == 0) { return cudf::make_empty_column(cudf::type_id::BOOL8); } + + CUDF_EXPECTS(prefix.is_valid(stream), "Parameter prefix must be valid."); + + auto const d_prefix = cudf::string_view(prefix.data(), prefix.size()); + auto const strings_column = cudf::column_device_view::create(strings.parent(), stream); + auto const d_strings = *strings_column; + + auto results = make_numeric_column(cudf::data_type{cudf::type_id::BOOL8}, + strings_count, + cudf::detail::copy_bitmask(strings.parent(), stream, mr), + strings.null_count(), + stream, + mr); + auto const d_results = results->mutable_view().data(); + // set the bool values by evaluating the passed function + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(strings_count), + d_results, + [d_strings, d_prefix, range_len, start, end, check_fn = literal_range_pattern_fn{}] __device__( + cudf::size_type idx) { + if (!d_strings.is_null(idx)) { + return check_fn(d_strings.element(idx), d_prefix, range_len, start, end); + } + return false; + }); + results->set_null_count(strings.null_count()); + return results; +} + +} // namespace + +/** + * @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means + * a literal string followed by a range of characters in the range of start to end, with at least + * len characters. + * + * @param strings Column of strings to check for literal. + * @param literal UTF-8 encoded string to check in strings column. + * @param len Minimum number of characters to check after the literal. + * @param start Minimum UTF-8 codepoint value to check for in the range. + * @param end Maximum UTF-8 codepoint value to check for in the range. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory. + */ +std::unique_ptr literal_range_pattern(cudf::strings_column_view const& input, + cudf::string_scalar const& prefix, + int const range_len, + int const start, + int const end, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + CUDF_FUNC_RANGE(); + return find_literal_range_pattern(input, prefix, range_len, start, end, stream, mr); +} + +} // namespace spark_rapids_jni diff --git a/src/main/cpp/src/regex_rewrite_utils.hpp b/src/main/cpp/src/regex_rewrite_utils.hpp new file mode 100644 index 0000000000..e5e500b180 --- /dev/null +++ b/src/main/cpp/src/regex_rewrite_utils.hpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace spark_rapids_jni { +/** + * @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means + * a literal string followed by a range of characters in the range of start to end, with at least + * len characters. + * + * @param strings Column of strings to check for literal. + * @param literal UTF-8 encoded string to check in strings column. + * @param len Minimum number of characters to check after the literal. + * @param start Minimum UTF-8 codepoint value to check for in the range. + * @param end Maximum UTF-8 codepoint value to check for in the range. + * @param stream CUDA stream used for device memory operations and kernel launches. + * @param mr Device memory resource used to allocate the returned column's device memory. + */ +std::unique_ptr literal_range_pattern( + cudf::strings_column_view const& input, + cudf::string_scalar const& literal, + int const len, + int const start, + int const end, + rmm::cuda_stream_view stream = rmm::cuda_stream_default, + rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()); +} // namespace spark_rapids_jni diff --git a/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java new file mode 100644 index 0000000000..9277c3e0f9 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.*; + +public class RegexRewriteUtils { + static { + NativeDepsLoader.loadNativeDeps(); + } + +/** + * @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means + * a literal string followed by a range of characters in the range of start to end, with at least + * len characters. + * + * @param strings Column of strings to check for literal. + * @param literal UTF-8 encoded string to check in strings column. + * @param len Minimum number of characters to check after the literal. + * @param start Minimum UTF-8 codepoint value to check for in the range. + * @param end Maximum UTF-8 codepoint value to check for in the range. + * @return ColumnVector of booleans where true indicates the string contains the pattern. + */ + public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len, int start, int end) { + assert(input.getType().equals(DType.STRING)) : "column must be a String"; + return new ColumnVector(literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len, start, end)); + } + + private static native long literalRangePattern(long input, long literal, int len, int start, int end); +} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/RegexRewriteUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/RegexRewriteUtilsTest.java new file mode 100644 index 0000000000..243967055a --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/RegexRewriteUtilsTest.java @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.Scalar; +import org.junit.jupiter.api.Test; + +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; + +public class RegexRewriteUtilsTest { + + @Test + void testLiteralRangePattern() { + int d = 3; + try (ColumnVector inputCv = ColumnVector.fromStrings( + "abc123", "aabc123", "aabc12", "abc1232", "aabc1232"); + Scalar pattern = Scalar.fromString("abc"); + ColumnVector expected = ColumnVector.fromBooleans(true, true, false, true, true); + ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 48, 57)) { + assertColumnsAreEqual(expected, actual); + } + } + + @Test + void testLiteralRangePatternChinese() { + int d = 2; + try (ColumnVector inputCv = ColumnVector.fromStrings( + "数据砖块", "火花-急流英伟达", "英伟达Nvidia", "火花-急流"); + Scalar pattern = Scalar.fromString("英"); + ColumnVector expected = ColumnVector.fromBooleans(false, true, true, false); + ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 19968, 40869)) { + assertColumnsAreEqual(expected, actual); + } + } + +}