Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support implicit type conversion from string to boolean #166

Merged
26 changes: 25 additions & 1 deletion core/src/main/java/org/opensearch/sql/ast/expression/Cast.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
package org.opensearch.sql.ast.expression;

import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_BOOLEAN;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_BYTE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_DATE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_DOUBLE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_FLOAT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_INT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_LONG;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_SHORT;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_STRING;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_TIME;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.CAST_TO_TIMESTAMP;
Expand All @@ -49,6 +51,7 @@
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.function.FunctionName;

/**
Expand All @@ -60,9 +63,11 @@
@ToString
public class Cast extends UnresolvedExpression {

private static Map<String, FunctionName> CONVERTED_TYPE_FUNCTION_NAME_MAP =
private static final Map<String, FunctionName> CONVERTED_TYPE_FUNCTION_NAME_MAP =
new ImmutableMap.Builder<String, FunctionName>()
.put("string", CAST_TO_STRING.getName())
.put("byte", CAST_TO_BYTE.getName())
.put("short", CAST_TO_SHORT.getName())
.put("int", CAST_TO_INT.getName())
.put("integer", CAST_TO_INT.getName())
.put("long", CAST_TO_LONG.getName())
Expand All @@ -84,6 +89,25 @@ public class Cast extends UnresolvedExpression {
*/
private final UnresolvedExpression convertedType;

/**
* Check if the given function name is a cast function or not.
* @param name function name
* @return true if cast function, otherwise false.
*/
public static boolean isCastFunction(FunctionName name) {
return CONVERTED_TYPE_FUNCTION_NAME_MAP.containsValue(name);
}

/**
* Get the cast function name for a given target data type.
* @param targetType target data type
* @return cast function name corresponding
*/
public static FunctionName getCastFunctionName(ExprType targetType) {
String type = targetType.typeName().toLowerCase(Locale.ROOT);
return CONVERTED_TYPE_FUNCTION_NAME_MAP.get(type);
}

/**
* Get the converted type.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@

package org.opensearch.sql.data.type;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -62,16 +63,15 @@ public enum ExprCoreType implements ExprType {
FLOAT(LONG),
DOUBLE(FLOAT),

/**
* Boolean.
*/
BOOLEAN(UNDEFINED),

/**
* String.
*/
STRING(UNDEFINED),

/**
* Boolean.
*/
BOOLEAN(STRING),

/**
* Date.
Expand Down Expand Up @@ -108,6 +108,16 @@ public enum ExprCoreType implements ExprType {
.put(STRING, "keyword")
.build();

private static final Set<ExprType> NUMBER_TYPES =
new ImmutableSet.Builder<ExprType>()
.add(BYTE)
.add(SHORT)
.add(INTEGER)
.add(LONG)
.add(FLOAT)
.add(DOUBLE)
.build();

ExprCoreType(ExprCoreType... compatibleTypes) {
for (ExprCoreType subType : compatibleTypes) {
subType.parents.add(this);
Expand Down Expand Up @@ -139,7 +149,7 @@ public static List<ExprCoreType> coreTypes() {
.collect(Collectors.toList());
}

public static List<ExprType> numberTypes() {
return ImmutableList.of(INTEGER, LONG, FLOAT, DOUBLE);
public static Set<ExprType> numberTypes() {
return NUMBER_TYPES;
}
}
10 changes: 10 additions & 0 deletions core/src/main/java/org/opensearch/sql/data/type/ExprType.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ default boolean isCompatible(ExprType other) {
}
}

/**
* Should cast this type to other type or not. By default, cast is always required
* if the given type is different from this type.
* @param other other data type
* @return true if cast is required, otherwise false
*/
default boolean shouldCast(ExprType other) {
return !this.equals(other);
}

/**
* Get the parent type.
*/
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,16 @@ public FunctionExpression castString(Expression value) {
.compile(BuiltinFunctionName.CAST_TO_STRING.getName(), Arrays.asList(value));
}

public FunctionExpression castByte(Expression value) {
return (FunctionExpression) repository
.compile(BuiltinFunctionName.CAST_TO_BYTE.getName(), Arrays.asList(value));
}

public FunctionExpression castShort(Expression value) {
return (FunctionExpression) repository
.compile(BuiltinFunctionName.CAST_TO_SHORT.getName(), Arrays.asList(value));
}

public FunctionExpression castInt(Expression value) {
return (FunctionExpression) repository
.compile(BuiltinFunctionName.CAST_TO_INT.getName(), Arrays.asList(value));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ public enum BuiltinFunctionName {
* Data Type Convert Function.
*/
CAST_TO_STRING(FunctionName.of("cast_to_string")),
CAST_TO_BYTE(FunctionName.of("cast_to_byte")),
CAST_TO_SHORT(FunctionName.of("cast_to_short")),
CAST_TO_INT(FunctionName.of("cast_to_int")),
CAST_TO_LONG(FunctionName.of("cast_to_long")),
CAST_TO_FLOAT(FunctionName.of("cast_to_float")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,19 @@

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.ast.expression.Cast.getCastFunctionName;
import static org.opensearch.sql.ast.expression.Cast.isCastFunction;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.expression.Expression;

Expand Down Expand Up @@ -47,15 +56,70 @@ public FunctionImplementation compile(FunctionName functionName, List<Expression
* Resolve the {@link FunctionBuilder} in Builtin Function Repository.
*
* @param functionSignature {@link FunctionSignature}
* @return {@link FunctionBuilder}
* @return Original function builder if it's a cast function or all arguments have expected types.
* Otherwise wrap its arguments by cast function as needed.
*/
public FunctionBuilder resolve(FunctionSignature functionSignature) {
FunctionName functionName = functionSignature.getFunctionName();
if (functionResolverMap.containsKey(functionName)) {
return functionResolverMap.get(functionName).resolve(functionSignature);
Pair<FunctionSignature, FunctionBuilder> resolvedSignature =
dai-chen marked this conversation as resolved.
Show resolved Hide resolved
functionResolverMap.get(functionName).resolve(functionSignature);

List<ExprType> sourceTypes = functionSignature.getParamTypeList();
List<ExprType> targetTypes = resolvedSignature.getKey().getParamTypeList();
FunctionBuilder funcBuilder = resolvedSignature.getValue();
if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) {
return funcBuilder;
}
return castArguments(sourceTypes, targetTypes, funcBuilder);
} else {
throw new ExpressionEvaluationException(
String.format("unsupported function name: %s", functionName.getFunctionName()));
}
}

/**
* Wrap resolved function builder's arguments by cast function to cast input expression value
* to value of target type at runtime. For example, suppose unresolved signature is
* equal(BOOL,STRING) and its resolved function builder is F with signature equal(BOOL,BOOL).
* In this case, wrap F and return equal(BOOL, cast_to_bool(STRING)).
*/
private FunctionBuilder castArguments(List<ExprType> sourceTypes,
List<ExprType> targetTypes,
FunctionBuilder funcBuilder) {
return arguments -> {
List<Expression> argsCasted = new ArrayList<>();
for (int i = 0; i < arguments.size(); i++) {
Expression arg = arguments.get(i);
ExprType sourceType = sourceTypes.get(i);
ExprType targetType = targetTypes.get(i);

if (isCastRequired(sourceType, targetType)) {
argsCasted.add(cast(arg, targetType));
} else {
argsCasted.add(arg);
}
}
return funcBuilder.apply(argsCasted);
};
}

private boolean isCastRequired(ExprType sourceType, ExprType targetType) {
// TODO: Remove this special case after fixing all failed UTs
if (ExprCoreType.numberTypes().contains(sourceType)
&& ExprCoreType.numberTypes().contains(targetType)) {
return false;
}
return sourceType.shouldCast(targetType);
}

private Expression cast(Expression arg, ExprType targetType) {
FunctionName castFunctionName = getCastFunctionName(targetType);
if (castFunctionName == null) {
throw new ExpressionEvaluationException(StringUtils.format(
"Type conversion to type %s is not supported", targetType));
}
return (Expression) compile(castFunctionName, ImmutableList.of(arg));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Singular;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.exception.ExpressionEvaluationException;

/**
Expand All @@ -41,8 +42,10 @@ public class FunctionResolver {
* If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it.
* If applying the widening rule, found the most match one, return it.
* If nothing found, throw {@link ExpressionEvaluationException}
*
* @return function signature and its builder
*/
public FunctionBuilder resolve(FunctionSignature unresolvedSignature) {
public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unresolvedSignature) {
PriorityQueue<Map.Entry<Integer, FunctionSignature>> functionMatchQueue = new PriorityQueue<>(
Map.Entry.comparingByKey());

Expand All @@ -59,7 +62,8 @@ public FunctionBuilder resolve(FunctionSignature unresolvedSignature) {
unresolvedSignature.formatTypes()
));
} else {
return functionBundle.get(bestMatchEntry.getValue());
FunctionSignature resolvedSignature = bestMatchEntry.getValue();
return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@
import java.util.stream.Stream;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.data.model.ExprBooleanValue;
import org.opensearch.sql.data.model.ExprByteValue;
import org.opensearch.sql.data.model.ExprDateValue;
import org.opensearch.sql.data.model.ExprDoubleValue;
import org.opensearch.sql.data.model.ExprFloatValue;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprLongValue;
import org.opensearch.sql.data.model.ExprShortValue;
import org.opensearch.sql.data.model.ExprStringValue;
import org.opensearch.sql.data.model.ExprTimeValue;
import org.opensearch.sql.data.model.ExprTimestampValue;
Expand All @@ -68,6 +70,8 @@ public class TypeCastOperator {
*/
public static void register(BuiltinFunctionRepository repository) {
repository.register(castToString());
repository.register(castToByte());
repository.register(castToShort());
repository.register(castToInt());
repository.register(castToLong());
repository.register(castToFloat());
Expand All @@ -92,6 +96,28 @@ private static FunctionResolver castToString() {
);
}

private static FunctionResolver castToByte() {
return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BYTE.getName(),
impl(nullMissingHandling(
(v) -> new ExprByteValue(Short.valueOf(v.stringValue()))), BYTE, STRING),
impl(nullMissingHandling(
(v) -> new ExprByteValue(v.shortValue())), BYTE, DOUBLE),
impl(nullMissingHandling(
(v) -> new ExprByteValue(v.booleanValue() ? 1 : 0)), BYTE, BOOLEAN)
);
}

private static FunctionResolver castToShort() {
return FunctionDSL.define(BuiltinFunctionName.CAST_TO_SHORT.getName(),
impl(nullMissingHandling(
(v) -> new ExprShortValue(Short.valueOf(v.stringValue()))), SHORT, STRING),
impl(nullMissingHandling(
(v) -> new ExprShortValue(v.shortValue())), SHORT, DOUBLE),
impl(nullMissingHandling(
(v) -> new ExprShortValue(v.booleanValue() ? 1 : 0)), SHORT, BOOLEAN)
);
}

private static FunctionResolver castToInt() {
return FunctionDSL.define(BuiltinFunctionName.CAST_TO_INT.getName(),
impl(nullMissingHandling(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE;
import static org.opensearch.sql.data.type.ExprCoreType.FLOAT;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
Expand All @@ -58,6 +59,11 @@ public void isCompatible() {
assertTrue(FLOAT.isCompatible(LONG));
assertTrue(FLOAT.isCompatible(INTEGER));
assertTrue(FLOAT.isCompatible(SHORT));
assertTrue(BOOLEAN.isCompatible(STRING));
}

@Test
public void isNotCompatible() {
assertFalse(INTEGER.isCompatible(DOUBLE));
assertFalse(STRING.isCompatible(DOUBLE));
assertFalse(INTEGER.isCompatible(UNKNOWN));
Expand All @@ -69,6 +75,13 @@ public void isCompatibleWithUndefined() {
ExprCoreType.coreTypes().forEach(type -> assertFalse(UNDEFINED.isCompatible(type)));
}

@Test
public void shouldCast() {
assertTrue(UNDEFINED.shouldCast(STRING));
assertTrue(STRING.shouldCast(BOOLEAN));
assertFalse(STRING.shouldCast(STRING));
}

@Test
public void getParent() {
assertThat(((ExprType) () -> "test").getParent(), Matchers.contains(UNKNOWN));
Expand Down
Loading