diff --git a/src/main/cpp/src/parse_uri.cu b/src/main/cpp/src/parse_uri.cu index 13c4050404..0e6ea2690d 100644 --- a/src/main/cpp/src/parse_uri.cu +++ b/src/main/cpp/src/parse_uri.cu @@ -494,48 +494,42 @@ bool __device__ validate_fragment(string_view fragment) __device__ std::pair find_query_part(string_view haystack, string_view needle) { - auto const n_bytes = needle.size_bytes(); - auto const find_length = haystack.size_bytes() - n_bytes + 1; - - auto h = haystack.data(); - auto const end_h = haystack.data() + find_length; - auto n = needle.data(); - bool match = false; - while (h < end_h) { - match = false; // initialize to false to prevent empty query key - for (size_type jdx = 0; (jdx == 0 || match) && (jdx < n_bytes); ++jdx) { - match = (h[jdx] == n[jdx]); + auto const n_bytes = needle.size_bytes(); + auto h = haystack.data(); + auto const h_end = h + haystack.size_bytes(); + auto n = needle.data(); + + // stop matching early after it can no longer contain the string we are searching for + while (h + n_bytes < h_end) { + bool match_needle = true; + for (size_type jdx = 0; jdx < n_bytes; ++jdx) { + match_needle = (h[jdx] == n[jdx]); + if (!match_needle) { break; } } - if (match) { match = n_bytes < haystack.size_bytes() && h[n_bytes] == '='; } - if (match) { - // we don't care about the matched part, we want the string data after that. - h += n_bytes; - break; - } else { - // skip to the next param, which is after a &. - while (h < end_h && *h != '&') { + + if (match_needle && h[n_bytes] == '=') { + // we don't care about the matched part, we want the string data after '=' + h += n_bytes + 1; + + // rest of string until end or until '&' is query match + int match_len = 0; + auto start = h; + while (h < h_end && *h != '&') { + match_len++; h++; } - } - h++; - } - - // if not match or no value, return nothing - if (!match || *h != '=') { return {{}, false}; } - // skip over the = - h++; + return {{start, match_len}, true}; + } - // rest of string until end or until '&' is query match - auto const bytes_left = haystack.size_bytes() - (h - haystack.data()); - int match_len = 0; - auto start = h; - while (*h != '&' && match_len < bytes_left) { - ++match_len; - ++h; + // not match, skip to the next param if possible, which is after a &. + while (h + n_bytes < h_end && *h != '&') { + h++; + } + h++; // skip over the & if has, or point to h_end +1 } - return {{start, match_len}, true}; + return {{}, false}; } uri_parts __device__ validate_uri(const char* str, diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java index 8f9fcfd903..ffe7e9e946 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java @@ -143,7 +143,7 @@ void testQuery(String[] testData, String[] params) { String[] pairs = query.split("&"); for (String pair : pairs) { int idx = pair.indexOf("="); - if (idx > 0 && pair.substring(0, idx).equals(params[i])) { + if (idx >= 0 && pair.substring(0, idx).equals(params[i])) { subquery = pair.substring(idx + 1); break; } @@ -218,6 +218,7 @@ void parseURISparkTest() { "https://www.nvidia.com/?cat=12", "www.nvidia.com/vote.php?pid=50", "https://www.nvidia.com/vote.php?=50", + "https://www.nvidia.com/vote.php?query=50" }; String[] queries = { @@ -276,6 +277,7 @@ void parseURISparkTest() { "f", "query", "query", + "", "" };