diff --git a/cpp/include/cudf/strings/string_view.cuh b/cpp/include/cudf/strings/string_view.cuh index 29062167f11..2e9fa6513b6 100644 --- a/cpp/include/cudf/strings/string_view.cuh +++ b/cpp/include/cudf/strings/string_view.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -121,6 +121,13 @@ __device__ inline string_view::const_iterator::const_iterator(const string_view& { } +__device__ inline string_view::const_iterator::const_iterator(string_view const& str, + size_type pos, + size_type offset) + : p{str.data()}, bytes{str.size_bytes()}, char_pos{pos}, byte_pos{offset} +{ +} + __device__ inline string_view::const_iterator& string_view::const_iterator::operator++() { if (byte_pos < bytes) @@ -244,7 +251,7 @@ __device__ inline string_view::const_iterator string_view::begin() const __device__ inline string_view::const_iterator string_view::end() const { - return const_iterator(*this, length()); + return const_iterator(*this, length(), size_bytes()); } // @endcond @@ -338,11 +345,14 @@ __device__ inline size_type string_view::find_impl(const char* str, size_type pos, size_type count) const { - if (!str || pos < 0) return npos; auto const nchars = length(); + if (!str || pos < 0 || pos > nchars) return npos; if (count < 0) count = nchars; - auto const spos = byte_offset(pos); - auto const epos = byte_offset(std::min(pos + count, nchars)); + + // use iterator to help reduce character/byte counting + auto itr = begin() + pos; + auto const spos = itr.byte_offset(); + auto const epos = ((pos + count) < nchars) ? (itr + count).byte_offset() : size_bytes(); auto const find_length = (epos - spos) - bytes + 1; @@ -352,7 +362,9 @@ __device__ inline size_type string_view::find_impl(const char* str, for (size_type jdx = 0; match && (jdx < bytes); ++jdx) { match = (ptr[jdx] == str[jdx]); } - if (match) { return character_offset(forward ? (idx + spos) : (epos - bytes - idx)); } + if (match) { return forward ? pos : character_offset(epos - bytes - idx); } + // use pos to record the current find position + pos += strings::detail::is_begin_utf8_char(*ptr); forward ? ++ptr : --ptr; } return npos; diff --git a/cpp/include/cudf/strings/string_view.hpp b/cpp/include/cudf/strings/string_view.hpp index 28f9d57e9bd..5a709447f0e 100644 --- a/cpp/include/cudf/strings/string_view.hpp +++ b/cpp/include/cudf/strings/string_view.hpp @@ -104,10 +104,12 @@ class string_view { [[nodiscard]] __device__ inline size_type byte_offset() const; private: + friend class string_view; const char* p{}; size_type bytes{}; size_type char_pos{}; size_type byte_pos{}; + __device__ inline const_iterator(string_view const& str, size_type pos, size_type offset); /// @endcond }; diff --git a/cpp/tests/strings/find_tests.cpp b/cpp/tests/strings/find_tests.cpp index ea36f5280e9..bd336540e0c 100644 --- a/cpp/tests/strings/find_tests.cpp +++ b/cpp/tests/strings/find_tests.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -41,6 +42,8 @@ TEST_F(StringsFindTest, Find) {1, 1, 0, 1, 1, 1}); auto results = cudf::strings::find(strings_view, cudf::string_scalar("é")); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + results = cudf::strings::rfind(strings_view, cudf::string_scalar("é")); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); } { cudf::test::fixed_width_column_wrapper expected({3, -1, -1, 0, -1, -1}, @@ -211,6 +214,14 @@ TEST_F(StringsFindTest, EmptyTarget) CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); results = cudf::strings::ends_with(strings_view, cudf::string_scalar("")); CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected); + + cudf::test::fixed_width_column_wrapper expected_find({0, 0, 0, 0, 0, 0}, + {1, 1, 0, 1, 1, 1}); + results = cudf::strings::find(strings_view, cudf::string_scalar("")); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, expected_find); + auto expected_rfind = cudf::strings::count_characters(strings_view); + results = cudf::strings::rfind(strings_view, cudf::string_scalar("")); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(*results, *expected_rfind); } TEST_F(StringsFindTest, AllEmpty)