Skip to content

Commit

Permalink
Support locate function
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Nov 8, 2024
1 parent f2e137a commit 82dad84
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 49 deletions.
20 changes: 20 additions & 0 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,26 @@ String Functions
SELECT levenshtein('kitten', 'sitting', 10); -- 3
SELECT levenshtein('kitten', 'sitting', 2); -- -1

.. spark:function:: locate(substring, string[, start]) -> integer
Returns the position of the first occurrence of ``substring`` in given ``string``,
searching from position ``start`` (inclusive). The given ``start`` and return value
are 1-based. ``start`` defaults to 1 when omitted.
Returns 0 if ``start`` is NULL. Otherwise, returns NULL if ``substring`` or ``string`` is NULL.
Returns 1 if ``substring`` is empty, regardless of ``start``.
Otherwise, returns 0 if ``start`` is less than 1 or greater than the size of ``string``.
Returns 0 if ``substring`` is not found in ``string``. ::

SELECT locate('aa', 'aaads'); -- 1
SELECT locate('aa', 'aaads', -1); -- 0
SELECT locate('aa', 'aaads', 2); -- 2
SELECT locate('aa', 'aaads', 6); -- 0
SELECT locate('aa', 'aaads', NULL); -- 0
SELECT locate('', 'aaads', 1); -- 1
SELECT locate('', 'aaads', 9); -- 1
SELECT locate('', ''); -- 1
SELECT locate('aa', ''); -- 0
SELECT locate(NULL, NULL, NULL); -- 0
SELECT locate(NULL, NULL, 1); -- NULL

.. spark:function:: lower(string) -> string
Returns string with all characters changed to lowercase. ::
Expand Down
47 changes: 36 additions & 11 deletions velox/functions/lib/string/StringImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,27 +199,52 @@ std::vector<int32_t> stringToCodePoints(const T& inputString) {

/// Returns the starting position in characters of the Nth instance(counting
/// from the left if lpos==true and from the end otherwise) of the substring in
/// string. Positions start with 1. If not found, 0 is returned. If subString is
/// empty result is 1.
/// string after position 'start'. Positions start with 1. If not
/// found, 0 is returned. If subString is empty result is 1.
/// @param instance The instance of the substring to find in string.
/// @param start The start position in characters to search for the substring in
/// string.
template <bool isAscii, bool lpos = true, typename T>
FOLLY_ALWAYS_INLINE int64_t
stringPosition(const T& string, const T& subString, int64_t instance = 0) {
FOLLY_ALWAYS_INLINE int64_t stringPosition(
const T& string,
const T& subString,
int64_t instance = 0,
int64_t start = 1) {
VELOX_USER_CHECK_GT(instance, 0, "'instance' must be a positive number");
if (subString.size() == 0) {
return 1;
}
if (start < 1 || start > string.size()) {
return 0;
}

int64_t startByteIndex;
if constexpr (isAscii) {
startByteIndex = start - 1;
} else {
// Calculate the byte index of the start character.
const char* pos = string.data();
int64_t numCharacters = 0;
while (numCharacters < start - 1) {
if (!utf_cont(*pos)) {
numCharacters++;
}
pos++;
}
startByteIndex = pos - string.data();
}

// The string to search for substring.
auto view = std::string_view(
string.data() + startByteIndex, string.size() - startByteIndex);

int64_t byteIndex = -1;
if constexpr (lpos) {
byteIndex = findNthInstanceByteIndexFromStart(
std::string_view(string.data(), string.size()),
std::string_view(subString.data(), subString.size()),
instance);
view, std::string_view(subString.data(), subString.size()), instance);
} else {
byteIndex = findNthInstanceByteIndexFromEnd(
std::string_view(string.data(), string.size()),
std::string_view(subString.data(), subString.size()),
instance);
view, std::string_view(subString.data(), subString.size()), instance);
}

if (byteIndex == -1) {
Expand All @@ -228,7 +253,7 @@ stringPosition(const T& string, const T& subString, int64_t instance = 0) {

// Return the number of characters from the beginning of the string to
// byteIndex.
return length<isAscii>(std::string_view(string.data(), byteIndex)) + 1;
return length<isAscii>(std::string_view(view.data(), byteIndex)) + start;
}

/// Replace replaced with replacement in inputString and write results to
Expand Down
90 changes: 54 additions & 36 deletions velox/functions/lib/string/tests/StringImplTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,99 +399,117 @@ TEST_F(StringImplTest, overlappedStringPosition) {
auto testValidInputAsciiLpos = [](const std::string& string,
const std::string& substr,
const int64_t instance,
const int64_t start,
const int64_t expectedPosition) {
auto result = stringPosition</*isAscii*/ true, true>(
StringView(string), StringView(substr), instance);
StringView(string), StringView(substr), instance, start);
ASSERT_EQ(result, expectedPosition);
};
auto testValidInputAsciiRpos = [](const std::string& string,
const std::string& substr,
const int64_t instance,
const int64_t start,
const int64_t expectedPosition) {
auto result = stringPosition</*isAscii*/ true, false>(
StringView(string), StringView(substr), instance);
StringView(string), StringView(substr), instance, start);
ASSERT_EQ(result, expectedPosition);
};

auto testValidInputUnicodeLpos = [](const std::string& string,
const std::string& substr,
const int64_t instance,
const int64_t start,
const int64_t expectedPosition) {
auto result = stringPosition</*isAscii*/ false, true>(
StringView(string), StringView(substr), instance);
StringView(string), StringView(substr), instance, start);
ASSERT_EQ(result, expectedPosition);
};

auto testValidInputUnicodeRpos = [](const std::string& string,
const std::string& substr,
const int64_t instance,
const int64_t start,
const int64_t expectedPosition) {
auto result = stringPosition</*isAscii*/ false, false>(
StringView(string), StringView(substr), instance);
StringView(string), StringView(substr), instance, start);
ASSERT_EQ(result, expectedPosition);
};

testValidInputAsciiLpos("aaa", "aa", 2, 2L);
testValidInputAsciiRpos("aaa", "aa", 2, 1L);

testValidInputAsciiLpos("|||", "||", 2, 2L);
testValidInputAsciiRpos("|||", "||", 2, 1L);

testValidInputUnicodeLpos("😋😋😋", "😋😋", 2, 2L);
testValidInputUnicodeRpos("😋😋😋", "😋😋", 2, 1L);

testValidInputUnicodeLpos("你你你", "你你", 2, 2L);
testValidInputUnicodeRpos("你你你", "你你", 2, 1L);
testValidInputAsciiLpos("aaa", "aa", 2, 1, 2L);
testValidInputAsciiLpos("aaaaaaa", "aa", 2, 3, 4L);
testValidInputAsciiRpos("aaa", "aa", 2, 1, 1L);
testValidInputAsciiRpos("aaaaaaa", "aa", 2, 2, 5L);

testValidInputAsciiLpos("|||", "||", 2, 1, 2L);
testValidInputAsciiLpos("|||||||", "||", 2, 4, 5L);
testValidInputAsciiRpos("|||", "||", 2, 1, 1L);
testValidInputAsciiRpos("|||||||", "||", 2, 4, 5L);

testValidInputUnicodeLpos("😋😋😋", "😋😋", 2, 1, 2L);
testValidInputUnicodeLpos("😋😋😋😋😋", "😋😋", 2, 4, 0L);
testValidInputUnicodeRpos("😋😋😋", "😋😋", 2, 2, 0L);
testValidInputUnicodeRpos("😋😋😋😋😋", "😋😋", 2, 4, 0L);

testValidInputUnicodeLpos("你你你", "你你", 2, 1, 2L);
testValidInputUnicodeLpos("你你你你你你", "你你", 2, 4, 5L);
testValidInputUnicodeRpos("你你你", "你你", 2, 1, 1L);
testValidInputUnicodeRpos("你你你你你你", "你你", 2, 4, 4L);
}

TEST_F(StringImplTest, stringPosition) {
auto testValidInputAscii = [](const std::string& string,
const std::string& substr,
const int64_t instance,
const int64_t start,
const int64_t expectedPosition) {
ASSERT_EQ(
stringPosition</*isAscii*/ true>(
StringView(string), StringView(substr), instance),
StringView(string), StringView(substr), instance, start),
expectedPosition);
ASSERT_EQ(
stringPosition</*isAscii*/ false>(
StringView(string), StringView(substr), instance),
StringView(string), StringView(substr), instance, start),
expectedPosition);
};

auto testValidInputUnicode = [](const std::string& string,
const std::string& substr,
const int64_t instance,
const int64_t start,
const int64_t expectedPosition) {
ASSERT_EQ(
stringPosition</*isAscii*/ false>(
StringView(string), StringView(substr), instance),
StringView(string), StringView(substr), instance, start),
expectedPosition);
ASSERT_EQ(
stringPosition</*isAscii*/ false>(
StringView(string), StringView(substr), instance),
StringView(string), StringView(substr), instance, start),
expectedPosition);
};

testValidInputAscii("high", "ig", 1, 2L);
testValidInputAscii("high", "igx", 1, 0L);
testValidInputAscii("Quadratically", "a", 1, 3L);
testValidInputAscii("foobar", "foobar", 1, 1L);
testValidInputAscii("foobar", "obar", 1, 3L);
testValidInputAscii("zoo!", "!", 1, 4L);
testValidInputAscii("x", "", 1, 1L);
testValidInputAscii("", "", 1, 1L);
testValidInputAscii("abc/xyz/foo/bar", "/", 3, 12L);

testValidInputUnicode("\u4FE1\u5FF5,\u7231,\u5E0C\u671B", "\u7231", 1, 4L);
testValidInputAscii("high", "ig", 1, 1, 2L);
testValidInputAscii("high", "igx", 1, 1, 0L);
testValidInputAscii("Quadratically", "a", 1, 1, 3L);
testValidInputAscii("foobar", "foobar", 1, 1, 1L);
testValidInputAscii("foobar", "obar", 1, 1, 3L);
testValidInputAscii("zoo!", "!", 1, 1, 4L);
testValidInputAscii("x", "", 1, 1, 1L);
testValidInputAscii("", "", 1, 1, 1L);
testValidInputAscii("abc/xyz/foo/bar", "/", 3, 1, 12L);
testValidInputAscii("abc/xyz/foo/bar", "/", 2, 5, 12L);

testValidInputUnicode("\u4FE1\u5FF5,\u7231,\u5E0C\u671B", "\u7231", 1, 1, 4L);
testValidInputUnicode(
"\u4FE1\u5FF5,\u7231,\u5E0C\u671B", "\u5E0C\u671B", 1, 1, 6L);
testValidInputUnicode("\u4FE1\u5FF5,\u7231,\u5E0C\u671B", "nice", 1, 1, 0L);
testValidInputUnicode(
"\u4FE1\u5FF5,\u7231,\u5E0C\u671B", "\u5E0C\u671B", 1, 6L);
testValidInputUnicode("\u4FE1\u5FF5,\u7231,\u5E0C\u671B", "nice", 1, 0L);
"\u4FE1\u5FF5,\u7231,\u5E0C\u671B,\u7231", "\u7231", 1, 6, 9L);

testValidInputUnicode("abc/xyz/foo/bar", "/", 1, 4L);
testValidInputUnicode("abc/xyz/foo/bar", "/", 2, 8L);
testValidInputUnicode("abc/xyz/foo/bar", "/", 3, 12L);
testValidInputUnicode("abc/xyz/foo/bar", "/", 4, 0L);
testValidInputUnicode("abc/xyz/foo/bar", "/", 1, 1, 4L);
testValidInputUnicode("abc/xyz/foo/bar", "/", 2, 1, 8L);
testValidInputUnicode("abc/xyz/foo/bar", "/", 3, 1, 12L);
testValidInputUnicode("abc/xyz/foo/bar", "/", 4, 1, 0L);
testValidInputAscii("abc/xyz/foo/bar", "/", 1, 13, 0L);

EXPECT_THROW(
stringPosition</*isAscii*/ false>(
Expand Down
6 changes: 5 additions & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,17 @@ void registerFunctions(const std::string& prefix) {
registerCompareFunctions(prefix);
registerBitwiseFunctions(prefix);

// String sreach function
// String search function
registerFunction<StartsWithFunction, bool, Varchar, Varchar>(
{prefix + "startswith"});
registerFunction<EndsWithFunction, bool, Varchar, Varchar>(
{prefix + "endswith"});
registerFunction<ContainsFunction, bool, Varchar, Varchar>(
{prefix + "contains"});
registerFunction<LocateFunction, int32_t, Varchar, Varchar>(
{prefix + "locate"});
registerFunction<LocateFunction, int32_t, Varchar, Varchar, int32_t>(
{prefix + "locate"});

registerFunction<TrimSpaceFunction, Varchar, Varchar>({prefix + "trim"});
registerFunction<TrimFunction, Varchar, Varchar, Varchar>({prefix + "trim"});
Expand Down
57 changes: 56 additions & 1 deletion velox/functions/sparksql/String.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ struct StartsWithFunction {
result = false;
} else {
result = str1.substr(0, str2.length()) == str2;
;
}
return true;
}
Expand Down Expand Up @@ -293,6 +292,62 @@ struct EndsWithFunction {
}
};

/// locate(substring, string) -> integer
/// locate(substring, string, start) -> integer
///
/// Searches 'string' for 'substring', beginning at the 1-based character
/// position 'start' (1 when 'start' is omitted), and produces the position
/// computed by stringImpl::stringPosition for the first instance. In the
/// three-argument form a NULL 'start' produces 0; otherwise a NULL 'substring'
/// or 'string' produces NULL.
template <typename T>
struct LocateFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  /// Three-argument form. Declared nullable so that a NULL 'start' can be
  /// mapped to 0 instead of a NULL result.
  FOLLY_ALWAYS_INLINE bool callNullable(
      out_type<int32_t>& result,
      const arg_type<Varchar>* subString,
      const arg_type<Varchar>* string,
      const arg_type<int32_t>* start) {
    // The NULL check on 'start' intentionally comes first: it wins even when
    // the other arguments are NULL as well.
    if (!start) {
      result = 0;
      return true;
    }
    if (!subString || !string) {
      return false;
    }
    result = positionOf<false /*isAscii*/>(*string, *subString, *start);
    return true;
  }

  /// ASCII variant of the three-argument form.
  FOLLY_ALWAYS_INLINE bool callAscii(
      out_type<int32_t>& result,
      const arg_type<Varchar>& subString,
      const arg_type<Varchar>& string,
      const arg_type<int32_t>& start) {
    result = positionOf<true /*isAscii*/>(string, subString, start);
    return true;
  }

  /// Two-argument form: searches from the first character.
  FOLLY_ALWAYS_INLINE bool call(
      out_type<int32_t>& result,
      const arg_type<Varchar>& subString,
      const arg_type<Varchar>& string) {
    result = positionOf<false /*isAscii*/>(string, subString, kDefaultStart);
    return true;
  }

  /// ASCII variant of the two-argument form.
  FOLLY_ALWAYS_INLINE bool callAscii(
      out_type<int32_t>& result,
      const arg_type<Varchar>& subString,
      const arg_type<Varchar>& string) {
    result = positionOf<true /*isAscii*/>(string, subString, kDefaultStart);
    return true;
  }

 private:
  // locate always looks for the first instance of the substring.
  static constexpr int64_t kFirstInstance = 1;

  // Start position used when the caller does not pass one.
  static constexpr int64_t kDefaultStart = 1;

  // Single forwarding point to stringImpl::stringPosition shared by all entry
  // points. Note the argument order flips here: locate takes the substring
  // first, stringPosition takes the searched string first.
  template <bool isAscii>
  FOLLY_ALWAYS_INLINE static int64_t positionOf(
      const arg_type<Varchar>& string,
      const arg_type<Varchar>& subString,
      int64_t start) {
    return stringImpl::stringPosition<isAscii>(
        string, subString, kFirstInstance, start);
  }
};

/// Returns the substring from str before count occurrences of the delimiter
/// delim. If count is positive, everything to the left of the final delimiter
/// (counting from the left) is returned. If count is negative, everything to
Expand Down
39 changes: 39 additions & 0 deletions velox/functions/sparksql/tests/StringTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,45 @@ TEST_F(StringTest, substring) {
EXPECT_EQ(substringWithLength("da\u6570\u636Eta", -3, 2), "\u636Et");
}

TEST_F(StringTest, locate) {
  // Evaluates locate(substr, str) by default. The three-argument overload is
  // used when a start position is supplied, or when 'withStart' is set so that
  // a NULL start can be passed explicitly.
  const auto locate = [&](const std::optional<std::string>& substr,
                          const std::optional<std::string>& str,
                          const std::optional<int32_t>& start = std::nullopt,
                          bool withStart = false) {
    const bool useStartOverload = start.has_value() || withStart;
    if (useStartOverload) {
      return evaluateOnce<int32_t>("locate(c0, c1, c2)", substr, str, start);
    }
    return evaluateOnce<int32_t>("locate(c0, c1)", substr, str);
  };

  // ASCII input: matches, misses, and start positions in and out of range.
  EXPECT_EQ(locate("aa", "aaads"), 1);
  EXPECT_EQ(locate("aa", "aaads", 0), 0);
  EXPECT_EQ(locate("aa", "aaads", 2), 2);
  EXPECT_EQ(locate("aa", "aaads", 3), 0);
  EXPECT_EQ(locate("aa", "aaads", -3), 0);
  EXPECT_EQ(locate("de", "aaads"), 0);
  EXPECT_EQ(locate("de", "aaads", 2), 0);
  EXPECT_EQ(locate("abc", "abcdddabcabc", 6), 7);

  // Empty substring and/or empty string.
  EXPECT_EQ(locate("", ""), 1);
  EXPECT_EQ(locate("", "", 3), 1);
  EXPECT_EQ(locate("", "aaads"), 1);
  EXPECT_EQ(locate("", "aaads", 9), 1);
  EXPECT_EQ(locate("aa", ""), 0);
  EXPECT_EQ(locate("aa", "", 2), 0);

  // NULL handling: a NULL start gives 0; otherwise a NULL substring or string
  // gives NULL.
  EXPECT_EQ(locate("zz", "aaads", std::nullopt, true), 0);
  EXPECT_EQ(locate("aa", std::nullopt), std::nullopt);
  EXPECT_EQ(locate(std::nullopt, "aaads"), std::nullopt);
  EXPECT_EQ(locate(std::nullopt, std::nullopt, -1), std::nullopt);
  EXPECT_EQ(locate(std::nullopt, std::nullopt, std::nullopt, true), 0);

  // Multi-byte characters: positions are counted in characters.
  EXPECT_EQ(locate("\u7231", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B"), 4);
  EXPECT_EQ(locate("\u7231", "\u4FE1\u5FF5,\u7231,\u5E0C\u671B", 0), 0);
  EXPECT_EQ(
      locate("\u4FE1", "\u4FE1\u5FF5,\u4FE1\u7231,\u4FE1\u5E0C\u671B", 2), 4);
  EXPECT_EQ(
      locate("\u4FE1", "\u4FE1\u5FF5,\u4FE1\u7231,\u4FE1\u5E0C\u671B", 8), 0);
}

TEST_F(StringTest, substringIndex) {
const auto substringIndex =
[&](const std::string& str, const std::string& delim, int32_t count) {
Expand Down

0 comments on commit 82dad84

Please sign in to comment.