-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add regex rewrite kernel to find literal[a,b]{x,y}
in a string
#2041
Merged
Merged
Changes from 17 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
f54ea85
wip
thirtiseven 6fffca8
wip
thirtiseven 64d3784
wip
thirtiseven 5e8fe37
Merge branch 'branch-24.06' into str_dig
thirtiseven 7b8d375
wip
thirtiseven a92beea
Merge branch 'branch-24.06' into str_dig
thirtiseven 21c19be
Merge branch 'branch-24.06' into str_dig
thirtiseven 28f0ad5
Merge branch 'branch-24.06' into str_dig
thirtiseven d30e025
support range filter
thirtiseven 8293324
clean up
thirtiseven 03ba4fa
Merge branch 'branch-24.06' into str_dig
thirtiseven ac6575c
wip
thirtiseven c335a8d
change some names
thirtiseven cf9fde7
clean up
thirtiseven 9e7b06a
address comments
thirtiseven ee96c44
Apply suggestions from code review
thirtiseven 0be1c23
format
thirtiseven 7e46ab2
fix build
thirtiseven File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "cudf_jni_apis.hpp" | ||
#include "dtype_utils.hpp" | ||
#include "jni_utils.hpp" | ||
#include "regex_rewrite_utils.hpp" | ||
|
||
extern "C" { | ||
|
||
/**
 * JNI entry point backing `RegexRewriteUtils.literalRangePattern`.
 *
 * `input` and `target` are native handles that are reinterpreted as a
 * `cudf::column_view` and a `cudf::string_scalar` respectively; `d`, `start`
 * and `end` are forwarded unchanged to `spark_rapids_jni::literal_range_pattern`.
 * The resulting column is released to the caller as a `jlong` handle. `0` is the
 * fallback value handed to the null-check and exception-translation macros.
 */
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_RegexRewriteUtils_literalRangePattern(
  JNIEnv* env, jclass, jlong input, jlong target, jint d, jint start, jint end)
{
  JNI_NULL_CHECK(env, input, "input column is null", 0);
  JNI_NULL_CHECK(env, target, "target is null", 0);

  try {
    cudf::jni::auto_set_device(env);

    // Recover the native objects behind the handles passed from Java.
    auto const* input_view = reinterpret_cast<cudf::column_view const*>(input);
    auto const* literal    = reinterpret_cast<cudf::string_scalar const*>(target);
    cudf::strings_column_view const strings_view(*input_view);

    return cudf::jni::release_as_jlong(
      spark_rapids_jni::literal_range_pattern(strings_view, *literal, d, start, end));
  }
  CATCH_STD(env, 0);
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <cudf/column/column_device_view.cuh> | ||
#include <cudf/column/column_factories.hpp> | ||
#include <cudf/detail/iterator.cuh> | ||
#include <cudf/detail/null_mask.hpp> | ||
#include <cudf/detail/nvtx/ranges.hpp> | ||
#include <cudf/strings/detail/utf8.hpp> | ||
#include <cudf/strings/string_view.cuh> | ||
#include <cudf/strings/strings_column_view.hpp> | ||
#include <cudf/utilities/default_stream.hpp> | ||
|
||
#include <rmm/cuda_stream_view.hpp> | ||
#include <rmm/exec_policy.hpp> | ||
|
||
#include <thrust/iterator/counting_iterator.h> | ||
#include <thrust/transform.h> | ||
|
||
namespace spark_rapids_jni { | ||
|
||
namespace { | ||
|
||
struct literal_range_pattern_fn { | ||
__device__ bool operator() | ||
const(cudf::string_view d_string, cudf::string_view d_prefix, int range_len, int start, int end) | ||
{ | ||
int const n = d_string.length(), m = d_prefix.length(); | ||
for (int i = 0; i <= n - m - range_len; i++) { | ||
bool match = true; | ||
for (int j = 0; j < m; j++) { | ||
if (d_string[i + j] != d_prefix[j]) { | ||
match = false; | ||
break; | ||
} | ||
} | ||
if (match) { | ||
for (int j = 0; j < range_len; j++) { | ||
auto code_point = cudf::strings::detail::utf8_to_codepoint(d_string[i + m + j]); | ||
if (code_point < start || code_point > end) { | ||
match = false; | ||
break; | ||
} | ||
} | ||
if (match) { return true; } | ||
} | ||
} | ||
return false; | ||
} | ||
}; | ||
|
||
/**
 * @brief Builds a BOOL8 column with one entry per input row, set by `literal_range_pattern_fn`.
 *
 * Returns an empty BOOL8 column when `strings` has no rows. Otherwise `prefix` must be valid.
 * The output reuses a copy of the input's null mask and null count; null rows are written
 * as `false` in the data buffer.
 *
 * @param strings   Strings column to search.
 * @param prefix    Literal that must precede the ranged characters.
 * @param range_len Number of characters after the literal that must be in range.
 * @param start     Lowest accepted code point (inclusive).
 * @param end       Highest accepted code point (inclusive).
 * @param stream    CUDA stream used for device memory operations and kernel launches.
 * @param mr        Device memory resource used to allocate the returned column's device memory.
 * @return BOOL8 column of per-row match results.
 */
std::unique_ptr<cudf::column> find_literal_range_pattern(cudf::strings_column_view const& strings,
                                                         cudf::string_scalar const& prefix,
                                                         int const range_len,
                                                         int const start,
                                                         int const end,
                                                         rmm::cuda_stream_view stream,
                                                         rmm::device_async_resource_ref mr)
{
  auto const num_rows = strings.size();
  if (num_rows == 0) { return cudf::make_empty_column(cudf::type_id::BOOL8); }

  CUDF_EXPECTS(prefix.is_valid(stream), "Parameter prefix must be valid.");

  auto const d_prefix  = cudf::string_view(prefix.data(), prefix.size());
  auto const d_column  = cudf::column_device_view::create(strings.parent(), stream);
  auto const d_strings = *d_column;

  // Output shares the input's nullability: same mask contents, same null count.
  auto output = make_numeric_column(cudf::data_type{cudf::type_id::BOOL8},
                                    num_rows,
                                    cudf::detail::copy_bitmask(strings.parent(), stream, mr),
                                    strings.null_count(),
                                    stream,
                                    mr);
  auto const d_output = output->mutable_view().data<bool>();

  // Evaluate the pattern once per row; null rows get a placeholder `false`.
  auto const first_row = thrust::make_counting_iterator<cudf::size_type>(0);
  thrust::transform(
    rmm::exec_policy(stream),
    first_row,
    first_row + num_rows,
    d_output,
    [d_strings, d_prefix, range_len, start, end] __device__(cudf::size_type idx) -> bool {
      if (d_strings.is_null(idx)) { return false; }
      return literal_range_pattern_fn{}(
        d_strings.element<cudf::string_view>(idx), d_prefix, range_len, start, end);
    });

  output->set_null_count(strings.null_count());
  return output;
}
|
||
} // namespace | ||
|
||
/** | ||
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means | ||
* a literal string followed by a range of characters in the range of start to end, with at least | ||
* len characters. | ||
* | ||
* @param strings Column of strings to check for literal. | ||
* @param literal UTF-8 encoded string to check in strings column. | ||
* @param len Minimum number of characters to check after the literal. | ||
* @param start Minimum UTF-8 codepoint value to check for in the range. | ||
* @param end Maximum UTF-8 codepoint value to check for in the range. | ||
* @param stream CUDA stream used for device memory operations and kernel launches. | ||
* @param mr Device memory resource used to allocate the returned column's device memory. | ||
*/ | ||
std::unique_ptr<cudf::column> literal_range_pattern(cudf::strings_column_view const& input, | ||
cudf::string_scalar const& prefix, | ||
int const range_len, | ||
int const start, | ||
int const end, | ||
rmm::cuda_stream_view stream, | ||
rmm::device_async_resource_ref mr) | ||
{ | ||
CUDF_FUNC_RANGE(); | ||
return find_literal_range_pattern(input, prefix, range_len, start, end, stream, mr); | ||
} | ||
|
||
} // namespace spark_rapids_jni |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cudf/scalar/scalar_factories.hpp> | ||
#include <cudf/strings/strings_column_view.hpp> | ||
#include <cudf/utilities/default_stream.hpp> | ||
|
||
namespace spark_rapids_jni { | ||
/** | ||
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means | ||
* a literal string followed by a range of characters in the range of start to end, with at least | ||
* len characters. | ||
* | ||
* @param strings Column of strings to check for literal. | ||
* @param literal UTF-8 encoded string to check in strings column. | ||
* @param len Minimum number of characters to check after the literal. | ||
* @param start Minimum UTF-8 codepoint value to check for in the range. | ||
* @param end Maximum UTF-8 codepoint value to check for in the range. | ||
* @param stream CUDA stream used for device memory operations and kernel launches. | ||
* @param mr Device memory resource used to allocate the returned column's device memory. | ||
*/ | ||
std::unique_ptr<cudf::column> literal_range_pattern( | ||
ttnghia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
cudf::strings_column_view const& input, | ||
cudf::string_scalar const& literal, | ||
int const len, | ||
int const start, | ||
int const end, | ||
rmm::cuda_stream_view stream = rmm::cuda_stream_default, | ||
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()); | ||
} // namespace spark_rapids_jni |
44 changes: 44 additions & 0 deletions
44
src/main/java/com/nvidia/spark/rapids/jni/RegexRewriteUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.jni; | ||
|
||
import ai.rapids.cudf.*; | ||
|
||
public class RegexRewriteUtils { | ||
static { | ||
NativeDepsLoader.loadNativeDeps(); | ||
} | ||
|
||
/** | ||
* @brief Check if input string contains regex pattern `literal[start-end]{len,}`, which means | ||
* a literal string followed by a range of characters in the range of start to end, with at least | ||
* len characters. | ||
* | ||
* @param strings Column of strings to check for literal. | ||
* @param literal UTF-8 encoded string to check in strings column. | ||
* @param len Minimum number of characters to check after the literal. | ||
* @param start Minimum UTF-8 codepoint value to check for in the range. | ||
* @param end Maximum UTF-8 codepoint value to check for in the range. | ||
* @return ColumnVector of booleans where true indicates the string contains the pattern. | ||
*/ | ||
public static ColumnVector literalRangePattern(ColumnVector input, Scalar literal, int len, int start, int end) { | ||
assert(input.getType().equals(DType.STRING)) : "column must be a String"; | ||
return new ColumnVector(literalRangePattern(input.getNativeView(), CudfAccessor.getScalarHandle(literal), len, start, end)); | ||
} | ||
|
||
private static native long literalRangePattern(long input, long literal, int len, int start, int end); | ||
} |
51 changes: 51 additions & 0 deletions
51
src/test/java/com/nvidia/spark/rapids/jni/RegexRewriteUtilsTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.jni; | ||
|
||
import ai.rapids.cudf.ColumnVector; | ||
import ai.rapids.cudf.Scalar; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; | ||
|
||
public class RegexRewriteUtilsTest { | ||
|
||
@Test | ||
void testLiteralRangePattern() { | ||
int d = 3; | ||
try (ColumnVector inputCv = ColumnVector.fromStrings( | ||
"abc123", "aabc123", "aabc12", "abc1232", "aabc1232"); | ||
Scalar pattern = Scalar.fromString("abc"); | ||
ColumnVector expected = ColumnVector.fromBooleans(true, true, false, true, true); | ||
ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 48, 57)) { | ||
assertColumnsAreEqual(expected, actual); | ||
} | ||
} | ||
|
||
@Test | ||
void testLiteralRangePatternChinese() { | ||
int d = 2; | ||
try (ColumnVector inputCv = ColumnVector.fromStrings( | ||
"数据砖块", "火花-急流英伟达", "英伟达Nvidia", "火花-急流"); | ||
Scalar pattern = Scalar.fromString("英"); | ||
ColumnVector expected = ColumnVector.fromBooleans(false, true, true, false); | ||
ColumnVector actual = RegexRewriteUtils.literalRangePattern(inputCv, pattern, d, 19968, 40869)) { | ||
assertColumnsAreEqual(expected, actual); | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect that the file was not properly formatted (using
clang-format
).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It passed clang-format in the pre-commit hook