forked from NVIDIA/spark-rapids-jni
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add regex rewrite kernel to find literal[a,b]{x,y} in a string (#1)
* wip Signed-off-by: Haoyang Li <[email protected]> * wip Signed-off-by: Haoyang Li <[email protected]> * wip Signed-off-by: Haoyang Li <[email protected]> * wip Signed-off-by: Haoyang Li <[email protected]> * support range filter Signed-off-by: Haoyang Li <[email protected]> * clean up Signed-off-by: Haoyang Li <[email protected]> * wip Signed-off-by: Haoyang Li <[email protected]> * change some names Signed-off-by: Haoyang Li <[email protected]> * clean up Signed-off-by: Haoyang Li <[email protected]> * address comments Signed-off-by: Haoyang Li <[email protected]> --------- Signed-off-by: Haoyang Li <[email protected]>
- Loading branch information
1 parent
2881ab2
commit 5156fe2
Showing
6 changed files
with
316 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "cudf_jni_apis.hpp" | ||
#include "dtype_utils.hpp" | ||
#include "jni_utils.hpp" | ||
#include "regex_rewrite_utils.hpp" | ||
|
||
extern "C" { | ||
|
||
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_RegexRewriteUtils_literalRangePattern( | ||
JNIEnv* env, jclass, jlong input, jlong target, jint d, jint start, jint end) | ||
{ | ||
JNI_NULL_CHECK(env, input, "input column is null", 0); | ||
JNI_NULL_CHECK(env, target, "target is null", 0); | ||
|
||
try { | ||
cudf::jni::auto_set_device(env); | ||
|
||
cudf::column_view* cv = reinterpret_cast<cudf::column_view*>(input); | ||
cudf::strings_column_view scv(*cv); | ||
cudf::string_scalar* ss_scalar = reinterpret_cast<cudf::string_scalar*>(target); | ||
return cudf::jni::release_as_jlong( | ||
spark_rapids_jni::literal_range_pattern(scv, *ss_scalar, d, start, end)); | ||
} | ||
CATCH_STD(env, 0); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <cudf/column/column_device_view.cuh> | ||
#include <cudf/column/column_factories.hpp> | ||
#include <cudf/detail/iterator.cuh> | ||
#include <cudf/detail/null_mask.hpp> | ||
#include <cudf/detail/nvtx/ranges.hpp> | ||
#include <cudf/strings/detail/utf8.hpp> | ||
#include <cudf/strings/string_view.cuh> | ||
#include <cudf/strings/strings_column_view.hpp> | ||
#include <cudf/utilities/default_stream.hpp> | ||
|
||
#include <rmm/cuda_stream_view.hpp> | ||
#include <rmm/exec_policy.hpp> | ||
|
||
#include <thrust/iterator/counting_iterator.h> | ||
#include <thrust/transform.h> | ||
|
||
namespace spark_rapids_jni { | ||
|
||
namespace { | ||
|
||
struct literal_range_pattern_fn { | ||
__device__ bool operator()( | ||
cudf::string_view d_string, cudf::string_view d_prefix, int range_len, int start, int end) | ||
{ | ||
int const n = d_string.length(), m = d_prefix.length(); | ||
for (int i = 0; i <= n - m - range_len; i++) { | ||
bool match = true; | ||
for (int j = 0; j < m; j++) { | ||
if (d_string[i + j] != d_prefix[j]) { | ||
match = false; | ||
break; | ||
} | ||
} | ||
if (match) { | ||
for (int j = 0; j < range_len; j++) { | ||
auto code_point = cudf::strings::detail::utf8_to_codepoint(d_string[i + m + j]); | ||
if (code_point < start || code_point > end) { | ||
match = false; | ||
break; | ||
} | ||
} | ||
if (match) { return true; } | ||
} | ||
} | ||
return false; | ||
} | ||
}; | ||
|
||
std::unique_ptr<cudf::column> find_literal_range_pattern(cudf::strings_column_view const& strings, | ||
cudf::string_scalar const& prefix, | ||
int const range_len, | ||
int const start, | ||
int const end, | ||
rmm::cuda_stream_view stream, | ||
rmm::device_async_resource_ref mr) | ||
{ | ||
auto const strings_count = strings.size(); | ||
if (strings_count == 0) return cudf::make_empty_column(cudf::type_id::BOOL8); | ||
|
||
CUDF_EXPECTS(prefix.is_valid(stream), "Parameter prefix must be valid."); | ||
|
||
auto const d_prefix = cudf::string_view(prefix.data(), prefix.size()); | ||
auto const strings_column = cudf::column_device_view::create(strings.parent(), stream); | ||
auto const d_strings = *strings_column; | ||
|
||
auto results = make_numeric_column(cudf::data_type{cudf::type_id::BOOL8}, | ||
strings_count, | ||
cudf::detail::copy_bitmask(strings.parent(), stream, mr), | ||
strings.null_count(), | ||
stream, | ||
mr); | ||
auto const d_results = results->mutable_view().data<bool>(); | ||
// set the bool values by evaluating the passed function | ||
thrust::transform( | ||
rmm::exec_policy(stream), | ||
thrust::make_counting_iterator<cudf::size_type>(0), | ||
thrust::make_counting_iterator<cudf::size_type>(strings_count), | ||
d_results, | ||
[d_strings, d_prefix, range_len, start, end] __device__(cudf::size_type idx) { | ||
if (!d_strings.is_null(idx)) { | ||
return bool{literal_range_pattern_fn{}( | ||
d_strings.element<cudf::string_view>(idx), d_prefix, range_len, start, end)}; | ||
} | ||
return false; | ||
}); | ||
results->set_null_count(strings.null_count()); | ||
return results; | ||
} | ||
|
||
} // namespace | ||
|
||
/** | ||
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means | ||
* a literal string followed by a range of characters in the range of start to end, with at least | ||
* len characters. | ||
* | ||
* @param strings Column of strings to check for literal. | ||
* @param literal UTF-8 encoded string to check in strings column. | ||
* @param len Minimum number of characters to check after the literal. | ||
* @param start Minimum UTF-8 codepoint value to check for in the range. | ||
* @param end Maximum UTF-8 codepoint value to check for in the range. | ||
* @param stream CUDA stream used for device memory operations and kernel launches. | ||
* @param mr Device memory resource used to allocate the returned column's device memory. | ||
*/ | ||
std::unique_ptr<cudf::column> literal_range_pattern(cudf::strings_column_view const& input, | ||
cudf::string_scalar const& prefix, | ||
int const range_len, | ||
int const start, | ||
int const end, | ||
rmm::cuda_stream_view stream, | ||
rmm::device_async_resource_ref mr) | ||
{ | ||
CUDF_FUNC_RANGE(); | ||
return find_literal_range_pattern(input, prefix, range_len, start, end, stream, mr); | ||
} | ||
|
||
} // namespace spark_rapids_jni |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cudf/scalar/scalar_factories.hpp> | ||
#include <cudf/strings/strings_column_view.hpp> | ||
#include <cudf/utilities/default_stream.hpp> | ||
|
||
namespace spark_rapids_jni { | ||
/** | ||
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means | ||
* a literal string followed by a range of characters in the range of start to end, with at least | ||
* len characters. | ||
* | ||
* @param strings Column of strings to check for literal. | ||
* @param literal UTF-8 encoded string to check in strings column. | ||
* @param len Minimum number of characters to check after the literal. | ||
* @param start Minimum UTF-8 codepoint value to check for in the range. | ||
* @param end Maximum UTF-8 codepoint value to check for in the range. | ||
* @param stream CUDA stream used for device memory operations and kernel launches. | ||
* @param mr Device memory resource used to allocate the returned column's device memory. | ||
*/ | ||
std::unique_ptr<cudf::column> literal_range_pattern( | ||
cudf::strings_column_view const& input, | ||
cudf::string_scalar const& literal, | ||
int const len, | ||
int const start, | ||
int const end, | ||
rmm::cuda_stream_view stream = rmm::cuda_stream_default, | ||
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()); | ||
} // namespace spark_rapids_jni |
44 changes: 44 additions & 0 deletions
44
src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.jni; | ||
|
||
import ai.rapids.cudf.*; | ||
|
||
public class RegexRewriteUtils { | ||
static { | ||
NativeDepsLoader.loadNativeDeps(); | ||
} | ||
|
||
/** | ||
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means | ||
* a literal string followed by a range of characters in the range of start to end, with at least | ||
* len characters. | ||
* | ||
* @param strings Column of strings to check for literal. | ||
* @param literal UTF-8 encoded string to check in strings column. | ||
* @param len Minimum number of characters to check after the literal. | ||
* @param start Minimum UTF-8 codepoint value to check for in the range. | ||
* @param end Maximum UTF-8 codepoint value to check for in the range. | ||
* @return ColumnVector of booleans where true indicates the string contains the pattern. | ||
*/ | ||
public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len, int start, int end) { | ||
assert(input.getType().equals(DType.STRING)) : "column must be a String"; | ||
return new ColumnVector(literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len, start, end)); | ||
} | ||
|
||
private static native long literalRangePattern(long input, long literal, int len, int start, int end); | ||
} |
51 changes: 51 additions & 0 deletions
51
src/test/java/com/nvidia/spark/rapids/jni/RegexRewriteUtilsTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.jni; | ||
|
||
import ai.rapids.cudf.ColumnVector; | ||
import ai.rapids.cudf.Scalar; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; | ||
|
||
public class RegexRewriteUtilsTest { | ||
|
||
@Test | ||
void testLiteralRangePattern() { | ||
int d = 3; | ||
try (ColumnVector inputCv = ColumnVector.fromStrings( | ||
"abc123", "aabc123", "aabc12", "abc1232", "aabc1232"); | ||
Scalar pattern = Scalar.fromString("abc"); | ||
ColumnVector expected = ColumnVector.fromBooleans(true, true, false, true, true); | ||
ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 48, 57)) { | ||
assertColumnsAreEqual(expected, actual); | ||
} | ||
} | ||
|
||
@Test | ||
void testLiteralRangePatternChinese() { | ||
int d = 2; | ||
try (ColumnVector inputCv = ColumnVector.fromStrings( | ||
"数据砖块", "火花-急流英伟达", "英伟达Nvidia", "火花-急流"); | ||
Scalar pattern = Scalar.fromString("英"); | ||
ColumnVector expected = ColumnVector.fromBooleans(false, true, true, false); | ||
ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 19968, 40869)) { | ||
assertColumnsAreEqual(expected, actual); | ||
} | ||
} | ||
|
||
} |