From 6055c7e13ac8f763537f10bcfec0fe4c80b39146 Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Mon, 21 Aug 2023 15:28:03 -0700 Subject: [PATCH] Fix `ASCII` function and groom UT for text functions. (#301) (#1895) * Fix `ASCII` function and groom UT for text functions. * Code cleanup. --------- Signed-off-by: Yury-Fridlyand --- .../sql/expression/text/TextFunction.java | 3 +- .../sql/expression/text/TextFunctionTest.java | 304 ++++++------------ 2 files changed, 95 insertions(+), 212 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index 1cf7f64867..d670843551 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -416,7 +416,8 @@ private static ExprValue exprLeft(ExprValue expr, ExprValue length) { } private static ExprValue exprAscii(ExprValue expr) { - return new ExprIntegerValue((int) expr.stringValue().charAt(0)); + return new ExprIntegerValue( + expr.stringValue().length() == 0 ? 
0 : (int) expr.stringValue().charAt(0)); } private static ExprValue exprLocate(ExprValue subStr, ExprValue str) { diff --git a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java index 84ae0b844f..b58f3031b7 100644 --- a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java @@ -6,6 +6,8 @@ package org.opensearch.sql.expression.text; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.missingValue; import static org.opensearch.sql.data.model.ExprValueUtils.nullValue; @@ -17,13 +19,12 @@ import java.util.List; import java.util.Objects; import java.util.stream.Collectors; +import java.util.stream.Stream; import lombok.AllArgsConstructor; import lombok.Getter; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; @@ -31,48 +32,52 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionTestBase; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.env.Environment; -@ExtendWith(MockitoExtension.class) public class TextFunctionTest extends ExpressionTestBase { - @Mock Environment env; - - @Mock Expression nullRef; - - @Mock Expression missingRef; - - private static List 
SUBSTRING_STRINGS = - ImmutableList.of( - new SubstringInfo("", 1, 1, ""), - new SubstringInfo("Quadratically", 5, null, "ratically"), - new SubstringInfo("foobarbar", 4, null, "barbar"), - new SubstringInfo("Quadratically", 5, 6, "ratica"), - new SubstringInfo("Quadratically", 5, 600, "ratically"), - new SubstringInfo("Quadratically", 500, 1, ""), - new SubstringInfo("Quadratically", 500, null, ""), - new SubstringInfo("Sakila", -3, null, "ila"), - new SubstringInfo("Sakila", -5, 3, "aki"), - new SubstringInfo("Sakila", -4, 2, "ki"), - new SubstringInfo("Quadratically", 0, null, ""), - new SubstringInfo("Sakila", 0, 2, ""), - new SubstringInfo("Sakila", 2, 0, ""), - new SubstringInfo("Sakila", 0, 0, "")); - private static List UPPER_LOWER_STRINGS = - ImmutableList.of( - "test", " test", "test ", " test ", "TesT", "TEST", " TEST", "TEST ", " TEST ", " ", ""); - private static List STRING_PATTERN_PAIRS = - ImmutableList.of( - new StringPatternPair("Michael!", "Michael!"), - new StringPatternPair("hello", "world"), - new StringPatternPair("world", "hello")); - private static List TRIM_STRINGS = - ImmutableList.of(" test", " test", "test ", "test", " test ", "", " "); - private static List> CONCAT_STRING_LISTS = - ImmutableList.of(ImmutableList.of("hello", "world"), ImmutableList.of("123", "5325")); - private static List> CONCAT_STRING_LISTS_WITH_MANY_STRINGS = - ImmutableList.of( - ImmutableList.of("he", "llo", "wo", "rld", "!"), - ImmutableList.of("0", "123", "53", "25", "7")); + + private static Stream getStringsForSubstr() { + return Stream.of( + new SubstringInfo("", 1, 1, ""), + new SubstringInfo("Quadratically", 5, null, "ratically"), + new SubstringInfo("foobarbar", 4, null, "barbar"), + new SubstringInfo("Quadratically", 5, 6, "ratica"), + new SubstringInfo("Quadratically", 5, 600, "ratically"), + new SubstringInfo("Quadratically", 500, 1, ""), + new SubstringInfo("Quadratically", 500, null, ""), + new SubstringInfo("Sakila", -3, null, "ila"), + new 
SubstringInfo("Sakila", -5, 3, "aki"), + new SubstringInfo("Sakila", -4, 2, "ki"), + new SubstringInfo("Quadratically", 0, null, ""), + new SubstringInfo("Sakila", 0, 2, ""), + new SubstringInfo("Sakila", 2, 0, ""), + new SubstringInfo("Sakila", 0, 0, "")); + } + + private static Stream getStringsForUpperAndLower() { + return Stream.of( + "test", " test", "test ", " test ", "TesT", "TEST", " TEST", "TEST ", " TEST ", " ", ""); + } + + private static Stream getStringsForComparison() { + return Stream.of( + new StringPatternPair("Michael!", "Michael!"), + new StringPatternPair("hello", "world"), + new StringPatternPair("world", "hello")); + } + + private static Stream getStringsForTrim() { + return Stream.of(" test", " test", "test ", "test", " test ", "", " "); + } + + private static Stream> getStringsForConcat() { + return Stream.of(ImmutableList.of("hello", "world"), ImmutableList.of("123", "5325")); + } + + private static Stream> getMultipleStringsForConcat() { + return Stream.of( + ImmutableList.of("he", "llo", "wo", "rld", "!"), + ImmutableList.of("0", "123", "53", "25", "7")); + } interface SubstrSubstring { FunctionExpression getFunction(SubstringInfo strInfo); @@ -130,30 +135,11 @@ static class SubstringInfo { String res; } - @BeforeEach - public void setup() { - when(nullRef.valueOf(env)).thenReturn(nullValue()); - when(missingRef.valueOf(env)).thenReturn(missingValue()); - } - - @Test - public void substrSubstring() { - SUBSTRING_STRINGS.forEach(s -> substrSubstringTest(s, new Substr())); - SUBSTRING_STRINGS.forEach(s -> substrSubstringTest(s, new Substring())); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.substr(missingRef, DSL.literal(1)))); - assertEquals(nullValue(), eval(DSL.substr(nullRef, DSL.literal(1)))); - assertEquals(missingValue(), eval(DSL.substring(missingRef, DSL.literal(1)))); - assertEquals(nullValue(), eval(DSL.substring(nullRef, DSL.literal(1)))); - - 
when(nullRef.type()).thenReturn(INTEGER); - when(missingRef.type()).thenReturn(INTEGER); - assertEquals(missingValue(), eval(DSL.substr(DSL.literal("hello"), missingRef))); - assertEquals(nullValue(), eval(DSL.substr(DSL.literal("hello"), nullRef))); - assertEquals(missingValue(), eval(DSL.substring(DSL.literal("hello"), missingRef))); - assertEquals(nullValue(), eval(DSL.substring(DSL.literal("hello"), nullRef))); + @ParameterizedTest + @MethodSource("getStringsForSubstr") + void substrSubstring(SubstringInfo s) { + substrSubstringTest(s, new Substr()); + substrSubstringTest(s, new Substring()); } void substrSubstringTest(SubstringInfo strInfo, SubstrSubstring substrSubstring) { @@ -162,79 +148,41 @@ void substrSubstringTest(SubstringInfo strInfo, SubstrSubstring substrSubstring) assertEquals(strInfo.getRes(), eval(expr).stringValue()); } - @Test - public void ltrim() { - TRIM_STRINGS.forEach(this::ltrimString); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.ltrim(missingRef))); - assertEquals(nullValue(), eval(DSL.ltrim(nullRef))); - } - - @Test - public void rtrim() { - TRIM_STRINGS.forEach(this::rtrimString); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.ltrim(missingRef))); - assertEquals(nullValue(), eval(DSL.ltrim(nullRef))); - } - - @Test - public void trim() { - TRIM_STRINGS.forEach(this::trimString); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.ltrim(missingRef))); - assertEquals(nullValue(), eval(DSL.ltrim(nullRef))); - } - - void ltrimString(String str) { + @ParameterizedTest + @MethodSource("getStringsForTrim") + void ltrim(String str) { FunctionExpression expression = DSL.ltrim(DSL.literal(str)); assertEquals(STRING, expression.type()); assertEquals(str.stripLeading(), eval(expression).stringValue()); 
} - void rtrimString(String str) { + @ParameterizedTest + @MethodSource("getStringsForTrim") + void rtrim(String str) { FunctionExpression expression = DSL.rtrim(DSL.literal(str)); assertEquals(STRING, expression.type()); assertEquals(str.stripTrailing(), eval(expression).stringValue()); } - void trimString(String str) { + @ParameterizedTest + @MethodSource("getStringsForTrim") + void trim(String str) { FunctionExpression expression = DSL.trim(DSL.literal(str)); assertEquals(STRING, expression.type()); assertEquals(str.trim(), eval(expression).stringValue()); } - @Test - public void lower() { - UPPER_LOWER_STRINGS.forEach(this::testLowerString); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.lower(missingRef))); - assertEquals(nullValue(), eval(DSL.lower(nullRef))); - } - - @Test - public void upper() { - UPPER_LOWER_STRINGS.forEach(this::testUpperString); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.upper(missingRef))); - assertEquals(nullValue(), eval(DSL.upper(nullRef))); - } - - @Test - void concat() { - CONCAT_STRING_LISTS.forEach(this::testConcatString); - CONCAT_STRING_LISTS_WITH_MANY_STRINGS.forEach(this::testConcatMultipleString); + @ParameterizedTest + @MethodSource("getStringsForConcat") + void concat(List strings) { + testConcatString(strings); + // Since `concat` isn't wrapped with `nullMissingHandling` (which has its own tests), + // we have to test the cases with NULL and MISSING values here + Expression nullRef = mock(Expression.class); + Expression missingRef = mock(Expression.class); + when(nullRef.valueOf(any())).thenReturn(nullValue()); + when(missingRef.valueOf(any())).thenReturn(missingValue()); when(nullRef.type()).thenReturn(STRING); when(missingRef.type()).thenReturn(STRING); assertEquals(missingValue(), eval(DSL.concat(missingRef, DSL.literal("1")))); @@ -244,43 +192,10 @@ 
void concat() { assertEquals(nullValue(), eval(DSL.concat(DSL.literal("1"), nullRef))); } - @Test - void concat_ws() { - CONCAT_STRING_LISTS.forEach(s -> testConcatString(s, ",")); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals( - missingValue(), eval(DSL.concat_ws(missingRef, DSL.literal("1"), DSL.literal("1")))); - assertEquals(nullValue(), eval(DSL.concat_ws(nullRef, DSL.literal("1"), DSL.literal("1")))); - assertEquals( - missingValue(), eval(DSL.concat_ws(DSL.literal("1"), missingRef, DSL.literal("1")))); - assertEquals(nullValue(), eval(DSL.concat_ws(DSL.literal("1"), nullRef, DSL.literal("1")))); - assertEquals( - missingValue(), eval(DSL.concat_ws(DSL.literal("1"), DSL.literal("1"), missingRef))); - assertEquals(nullValue(), eval(DSL.concat_ws(DSL.literal("1"), DSL.literal("1"), nullRef))); - } - - @Test - void length() { - UPPER_LOWER_STRINGS.forEach(this::testLengthString); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.length(missingRef))); - assertEquals(nullValue(), eval(DSL.length(nullRef))); - } - - @Test - void strcmp() { - STRING_PATTERN_PAIRS.forEach(this::testStcmpString); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.strcmp(missingRef, missingRef))); - assertEquals(nullValue(), eval(DSL.strcmp(nullRef, nullRef))); - assertEquals(missingValue(), eval(DSL.strcmp(nullRef, missingRef))); - assertEquals(missingValue(), eval(DSL.strcmp(missingRef, nullRef))); + @ParameterizedTest + @MethodSource("getStringsForConcat") + void concat_ws(List strings) { + testConcatString(strings, ","); } @Test @@ -302,14 +217,6 @@ void right() { expression = DSL.right(DSL.literal(""), DSL.literal(10)); assertEquals(STRING, expression.type()); assertEquals("", eval(expression).value()); - - when(nullRef.type()).thenReturn(STRING); - 
when(missingRef.type()).thenReturn(INTEGER); - assertEquals(missingValue(), eval(DSL.right(nullRef, missingRef))); - assertEquals(nullValue(), eval(DSL.right(nullRef, DSL.literal(new ExprIntegerValue(1))))); - - when(nullRef.type()).thenReturn(INTEGER); - assertEquals(nullValue(), eval(DSL.right(DSL.literal(new ExprStringValue("value")), nullRef))); } @Test @@ -331,14 +238,6 @@ void left() { expression = DSL.left(DSL.literal(""), DSL.literal(10)); assertEquals(STRING, expression.type()); assertEquals("", eval(expression).value()); - - when(nullRef.type()).thenReturn(STRING); - when(missingRef.type()).thenReturn(INTEGER); - assertEquals(missingValue(), eval(DSL.left(nullRef, missingRef))); - assertEquals(nullValue(), eval(DSL.left(nullRef, DSL.literal(new ExprIntegerValue(1))))); - - when(nullRef.type()).thenReturn(INTEGER); - assertEquals(nullValue(), eval(DSL.left(DSL.literal(new ExprStringValue("value")), nullRef))); } @Test @@ -346,11 +245,7 @@ void ascii() { FunctionExpression expression = DSL.ascii(DSL.literal(new ExprStringValue("hello"))); assertEquals(INTEGER, expression.type()); assertEquals(104, eval(expression).integerValue()); - - when(nullRef.type()).thenReturn(STRING); - assertEquals(nullValue(), eval(DSL.ascii(nullRef))); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.ascii(missingRef))); + assertEquals(0, DSL.ascii(DSL.literal("")).valueOf().integerValue()); } @Test @@ -362,14 +257,6 @@ void locate() { expression = DSL.locate(DSL.literal("world"), DSL.literal("helloworldworld"), DSL.literal(7)); assertEquals(INTEGER, expression.type()); assertEquals(11, eval(expression).integerValue()); - - when(nullRef.type()).thenReturn(STRING); - assertEquals(nullValue(), eval(DSL.locate(nullRef, DSL.literal("hello")))); - assertEquals(nullValue(), eval(DSL.locate(nullRef, DSL.literal("hello"), DSL.literal(1)))); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.locate(missingRef, 
DSL.literal("hello")))); - assertEquals( - missingValue(), eval(DSL.locate(missingRef, DSL.literal("hello"), DSL.literal(1)))); } @Test @@ -382,11 +269,6 @@ void position() { expression = DSL.position(DSL.literal("abc"), DSL.literal("hello world")); assertEquals(INTEGER, expression.type()); assertEquals(0, eval(expression).integerValue()); - - when(nullRef.type()).thenReturn(STRING); - assertEquals(nullValue(), eval(DSL.position(nullRef, DSL.literal("hello")))); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.position(missingRef, DSL.literal("hello")))); } @Test @@ -395,11 +277,6 @@ void replace() { DSL.replace(DSL.literal("helloworld"), DSL.literal("world"), DSL.literal("opensearch")); assertEquals(STRING, expression.type()); assertEquals("helloopensearch", eval(expression).stringValue()); - - when(nullRef.type()).thenReturn(STRING); - assertEquals(nullValue(), eval(DSL.replace(nullRef, DSL.literal("a"), DSL.literal("b")))); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.replace(missingRef, DSL.literal("a"), DSL.literal("b")))); } @Test @@ -407,11 +284,6 @@ void reverse() { FunctionExpression expression = DSL.reverse(DSL.literal("abcde")); assertEquals(STRING, expression.type()); assertEquals("edcba", eval(expression).stringValue()); - - when(nullRef.type()).thenReturn(STRING); - assertEquals(nullValue(), eval(DSL.reverse(nullRef))); - when(missingRef.type()).thenReturn(STRING); - assertEquals(missingValue(), eval(DSL.reverse(missingRef))); } void testConcatString(List strings) { @@ -435,6 +307,8 @@ void testConcatString(List strings, String delim) { assertEquals(expected, eval(expression).stringValue()); } + @ParameterizedTest + @MethodSource("getMultipleStringsForConcat") void testConcatMultipleString(List strings) { String expected = null; if (strings.stream().noneMatch(Objects::isNull)) { @@ -452,13 +326,17 @@ void testConcatMultipleString(List strings) { assertEquals(expected, 
eval(expression).stringValue()); } - void testLengthString(String str) { + @ParameterizedTest + @MethodSource("getStringsForUpperAndLower") + void length(String str) { FunctionExpression expression = DSL.length(DSL.literal(new ExprStringValue(str))); assertEquals(INTEGER, expression.type()); assertEquals(str.getBytes().length, eval(expression).integerValue()); } - void testStcmpString(StringPatternPair stringPatternPair) { + @ParameterizedTest + @MethodSource("getStringsForComparison") + void strcmp(StringPatternPair stringPatternPair) { FunctionExpression expression = DSL.strcmp( DSL.literal(new ExprStringValue(stringPatternPair.getStr())), @@ -467,19 +345,23 @@ void testStcmpString(StringPatternPair stringPatternPair) { assertEquals(stringPatternPair.strCmpTest(), eval(expression).integerValue()); } - void testLowerString(String str) { + @ParameterizedTest + @MethodSource("getStringsForUpperAndLower") + void lower(String str) { FunctionExpression expression = DSL.lower(DSL.literal(new ExprStringValue(str))); assertEquals(STRING, expression.type()); assertEquals(stringValue(str.toLowerCase()), eval(expression)); } - void testUpperString(String str) { + @ParameterizedTest + @MethodSource("getStringsForUpperAndLower") + void upper(String str) { FunctionExpression expression = DSL.upper(DSL.literal(new ExprStringValue(str))); assertEquals(STRING, expression.type()); assertEquals(stringValue(str.toUpperCase()), eval(expression)); } private ExprValue eval(Expression expression) { - return expression.valueOf(env); + return expression.valueOf(); } }