Skip to content

Commit

Permalink
Enable concat() string function to support multiple string arguments (
Browse files Browse the repository at this point in the history
#1279)

* Enable `concat()` string function to support multiple string arguments (#200)

Signed-off-by: Margarit Hakobyan <[email protected]>
  • Loading branch information
margarit-h authored Jan 27, 2023
1 parent a4f8066 commit 45fc371
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ private FunctionBuilder getFunctionBuilder(
List<ExprType> sourceTypes = functionSignature.getParamTypeList();
List<ExprType> targetTypes = resolvedSignature.getKey().getParamTypeList();
FunctionBuilder funcBuilder = resolvedSignature.getValue();
if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) {
if (isCastFunction(functionName)
|| FunctionSignature.isVarArgFunction(targetTypes)
|| sourceTypes.equals(targetTypes)) {
return funcBuilder;
}
return castArguments(sourceTypes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unreso
functionSignature));
}
Map.Entry<Integer, FunctionSignature> bestMatchEntry = functionMatchQueue.peek();
if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) {
if (FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())
&& (unresolvedSignature.getParamTypeList().isEmpty()
|| unresolvedSignature.getParamTypeList().size() > 9)) {
throw new ExpressionEvaluationException(
String.format("%s function expected 1-9 arguments, but got %d",
functionName, unresolvedSignature.getParamTypeList().size()));
}
if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())
&& !FunctionSignature.isVarArgFunction(bestMatchEntry.getValue().getParamTypeList())) {
throw new ExpressionEvaluationException(
String.format("%s function expected %s, but get %s", functionName,
formatFunctions(functionBundle.keySet()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;

import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
Expand Down Expand Up @@ -39,6 +41,10 @@ public int match(FunctionSignature functionSignature) {
|| paramTypeList.size() != functionTypeList.size()) {
return NOT_MATCH;
}
// TODO: improve to support regular and array type mixed, ex. func(int,string,array)
if (isVarArgFunction(functionTypeList)) {
return EXACTLY_MATCH;
}

int matchDegree = EXACTLY_MATCH;
for (int i = 0; i < paramTypeList.size(); i++) {
Expand All @@ -62,4 +68,11 @@ public String formatTypes() {
.map(ExprType::typeName)
.collect(Collectors.joining(",", "[", "]"));
}

/**
* util function - returns true if function has variable arguments.
*/
protected static boolean isVarArgFunction(List<ExprType> argTypes) {
return argTypes.size() == 1 && argTypes.get(0) == ARRAY;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,34 @@

package org.opensearch.sql.expression.text;

import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.expression.function.FunctionDSL.define;
import static org.opensearch.sql.expression.function.FunctionDSL.impl;
import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprStringValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.env.Environment;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.DefaultFunctionResolver;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.FunctionSignature;
import org.opensearch.sql.expression.function.SerializableBiFunction;
import org.opensearch.sql.expression.function.SerializableTriFunction;


/**
* The definition of text functions.
* 1) have the clear interface for function define.
Expand Down Expand Up @@ -141,16 +151,37 @@ private DefaultFunctionResolver upper() {
}

/**
* TODO: https://github.com/opendistro-for-elasticsearch/sql/issues/710
* Extend to accept variable argument amounts.
* Concatenates a list of Strings.
* Supports following signatures:
* (STRING, STRING) -> STRING
* (STRING, STRING, ...., STRING) -> STRING
*/
private DefaultFunctionResolver concat() {
return define(BuiltinFunctionName.CONCAT.getName(),
impl(nullMissingHandling((str1, str2) ->
new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING));
FunctionName concatFuncName = BuiltinFunctionName.CONCAT.getName();
return define(concatFuncName, funcName ->
Pair.of(
new FunctionSignature(concatFuncName, Collections.singletonList(ARRAY)),
(funcProp, args) -> new FunctionExpression(funcName, args) {
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
List<ExprValue> exprValues = args.stream()
.map(arg -> arg.valueOf(valueEnv)).collect(Collectors.toList());
if (exprValues.stream().anyMatch(ExprValue::isMissing)) {
return ExprValueUtils.missingValue();
}
if (exprValues.stream().anyMatch(ExprValue::isNull)) {
return ExprValueUtils.nullValue();
}
return new ExprStringValue(exprValues.stream()
.map(ExprValue::stringValue)
.collect(Collectors.joining()));
}

@Override
public ExprType type() {
return STRING;
}
}
));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collections;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -76,4 +80,53 @@ void resolve_function_not_match() {
assertEquals("add function expected {[INTEGER,INTEGER]}, but get [BOOLEAN,BOOLEAN]",
exception.getMessage());
}

@Test
void resolve_varargs_function_signature_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(functionSignature.getParamTypeList()).thenReturn(ImmutableList.of(STRING));
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue());
}

@Test
void resolve_varargs_no_args_function_signature_not_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
// Concat function with no arguments
when(functionSignature.getParamTypeList()).thenReturn(Collections.emptyList());

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
() -> resolver.resolve(functionSignature));
assertEquals("concat function expected 1-9 arguments, but got 0",
exception.getMessage());
}

@Test
void resolve_varargs_too_many_args_function_signature_not_match() {
functionName = FunctionName.of("concat");
when(functionSignature.match(bestMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL);
when(bestMatchFS.getParamTypeList()).thenReturn(ImmutableList.of(ARRAY));
// Concat function with more than 9 arguments
when(functionSignature.getParamTypeList()).thenReturn(ImmutableList
.of(STRING, STRING, STRING, STRING, STRING,
STRING, STRING, STRING, STRING, STRING));

DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName,
ImmutableMap.of(bestMatchFS, bestMatchBuilder));

ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
() -> resolver.resolve(functionSignature));
assertEquals("concat function expected 1-9 arguments, but got 10",
exception.getMessage());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public class TextFunctionTest extends ExpressionTestBase {
private static List<List<String>> CONCAT_STRING_LISTS = ImmutableList.of(
ImmutableList.of("hello", "world"),
ImmutableList.of("123", "5325"));
private static List<List<String>> CONCAT_STRING_LISTS_WITH_MANY_STRINGS = ImmutableList.of(
ImmutableList.of("he", "llo", "wo", "rld", "!"),
ImmutableList.of("0", "123", "53", "25", "7"));

interface SubstrSubstring {
FunctionExpression getFunction(SubstringInfo strInfo);
Expand Down Expand Up @@ -228,11 +231,13 @@ public void upper() {
@Test
void concat() {
CONCAT_STRING_LISTS.forEach(this::testConcatString);
CONCAT_STRING_LISTS_WITH_MANY_STRINGS.forEach(this::testConcatMultipleString);

when(nullRef.type()).thenReturn(STRING);
when(missingRef.type()).thenReturn(STRING);
assertEquals(missingValue(), eval(
DSL.concat(missingRef, DSL.literal("1"))));
// If any of the expressions is a NULL value, it returns NULL.
assertEquals(nullValue(), eval(
DSL.concat(nullRef, DSL.literal("1"))));
assertEquals(missingValue(), eval(
Expand Down Expand Up @@ -446,6 +451,22 @@ void testConcatString(List<String> strings, String delim) {
assertEquals(expected, eval(expression).stringValue());
}

void testConcatMultipleString(List<String> strings) {
String expected = null;
if (strings.stream().noneMatch(Objects::isNull)) {
expected = String.join("", strings);
}

FunctionExpression expression = DSL.concat(
DSL.literal(strings.get(0)),
DSL.literal(strings.get(1)),
DSL.literal(strings.get(2)),
DSL.literal(strings.get(3)),
DSL.literal(strings.get(4)));
assertEquals(STRING, expression.type());
assertEquals(expected, eval(expression).stringValue());
}

void testLengthString(String str) {
FunctionExpression expression = DSL.length(DSL.literal(new ExprStringValue(str)));
assertEquals(INTEGER, expression.type());
Expand Down
16 changes: 8 additions & 8 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2614,21 +2614,21 @@ CONCAT
Description
>>>>>>>>>>>

Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together.
Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together. If any of the expressions is a NULL value, it returns NULL.

Argument type: STRING, STRING
Argument type: STRING, STRING, ...., STRING

Return type: STRING

Example::

os> SELECT CONCAT('hello', 'world')
os> SELECT CONCAT('hello ', 'whole ', 'world', '!'), CONCAT('hello', 'world'), CONCAT('hello', null)
fetched rows / total rows = 1/1
+----------------------------+
| CONCAT('hello', 'world') |
|----------------------------|
| helloworld |
+----------------------------+
+--------------------------------------------+----------------------------+-------------------------+
| CONCAT('hello ', 'whole ', 'world', '!') | CONCAT('hello', 'world') | CONCAT('hello', null) |
|--------------------------------------------+----------------------------+-------------------------|
| hello whole world! | helloworld | null |
+--------------------------------------------+----------------------------+-------------------------+


CONCAT_WS
Expand Down
16 changes: 8 additions & 8 deletions docs/user/ppl/functions/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@ CONCAT
Description
>>>>>>>>>>>

Usage: CONCAT(str1, str2) returns str1 and str strings concatenated together.
Usage: CONCAT(str1, str2, ...., str_9) adds up to 9 strings together.

Argument type: STRING, STRING
Argument type: STRING, STRING, ...., STRING

Return type: STRING

Example::

os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world') | fields `CONCAT('hello', 'world')`
os> source=people | eval `CONCAT('hello', 'world')` = CONCAT('hello', 'world'), `CONCAT('hello ', 'whole ', 'world', '!')` = CONCAT('hello ', 'whole ', 'world', '!') | fields `CONCAT('hello', 'world')`, `CONCAT('hello ', 'whole ', 'world', '!')`
fetched rows / total rows = 1/1
+----------------------------+
| CONCAT('hello', 'world') |
|----------------------------|
| helloworld |
+----------------------------+
+----------------------------+--------------------------------------------+
| CONCAT('hello', 'world') | CONCAT('hello ', 'whole ', 'world', '!') |
|----------------------------+--------------------------------------------|
| helloworld | hello whole world! |
+----------------------------+--------------------------------------------+


CONCAT_WS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ public void testLtrim() throws IOException {

@Test
public void testConcat() throws IOException {
verifyQuery("concat", "", ", 'there'",
"hellothere", "worldthere", "helloworldthere");
verifyQuery("concat", "", ", 'there', 'all', '!'",
"hellothereall!", "worldthereall!", "helloworldthereall!");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public void testLtrim() throws IOException {

@Test
public void testConcat() throws IOException {
verifyQuery("concat('hello', 'whole', 'world', '!', '!')", "keyword", "hellowholeworld!!");
verifyQuery("concat('hello', 'world')", "keyword", "helloworld");
verifyQuery("concat('', 'hello')", "keyword", "hello");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ LOCATE('world', 'helloworld') as column
LOCATE('world', 'hello') as column
LOCATE('world', 'helloworld', 7) as column
REPLACE('helloworld', 'world', 'opensearch') as column
REPLACE('hello', 'world', 'opensearch') as column
REPLACE('hello', 'world', 'opensearch') as column
CONCAT('hello', 'world') as column
CONCAT('hello ', 'whole ', 'world', '!') as column

0 comments on commit 45fc371

Please sign in to comment.