From b387c8e10d3cfc1d4ae7893f5d54d735718c710a Mon Sep 17 00:00:00 2001 From: Yury-Fridlyand Date: Tue, 1 Nov 2022 13:46:34 -0700 Subject: [PATCH] Fix `FLOAT` -> `DOUBLE` cast. * Update tests. * Update numeric -> `BOOLEAN` cast in `LuceneQuery` to comply with `TypeCastOperator`. Signed-off-by: Yury-Fridlyand --- .../operator/convert/TypeCastOperator.java | 2 +- .../convert/TypeCastOperatorTest.java | 4 +- .../script/filter/lucene/LuceneQuery.java | 2 +- .../script/filter/FilterQueryBuilderTest.java | 80 +++++++++++++------ 4 files changed, 60 insertions(+), 28 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java index 23508406ac..8f904bfbf7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java @@ -125,7 +125,7 @@ private static DefaultFunctionResolver castToFloat() { impl(nullMissingHandling( (v) -> new ExprFloatValue(Float.valueOf(v.stringValue()))), FLOAT, STRING), impl(nullMissingHandling( - (v) -> new ExprFloatValue(v.longValue())), FLOAT, DOUBLE), + (v) -> new ExprFloatValue(v.floatValue())), FLOAT, DOUBLE), impl(nullMissingHandling( (v) -> new ExprFloatValue(v.booleanValue() ? 1f : 0f)), FLOAT, BOOLEAN) ); diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/convert/TypeCastOperatorTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/convert/TypeCastOperatorTest.java index c2ca793e39..f791b7d86a 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/convert/TypeCastOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/convert/TypeCastOperatorTest.java @@ -48,8 +48,8 @@ class TypeCastOperatorTest { private static Stream numberData() { return Stream.of(new ExprByteValue(3), new ExprShortValue(3), - new ExprIntegerValue(3), new ExprLongValue(3L), new ExprFloatValue(3f), - new ExprDoubleValue(3D)); + new ExprIntegerValue(3), new ExprLongValue(3L), new ExprFloatValue(3.14f), + new ExprDoubleValue(3.1415D)); } private static Stream stringData() { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java index 289772d6b6..aa27fffcbc 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java @@ -172,7 +172,7 @@ private ExprValue cast(FunctionExpression castFunction) { }) .put(BuiltinFunctionName.CAST_TO_BOOLEAN.getName(), expr -> { if (ExprCoreType.numberTypes().contains(expr.type())) { - return expr.valueOf(null).doubleValue() == 1 + return expr.valueOf(null).doubleValue() != 0 ? ExprBooleanValue.of(true) : ExprBooleanValue.of(false); } else if (expr.type().equals(ExprCoreType.STRING)) { return ExprBooleanValue.of(Boolean.valueOf(expr.valueOf(null).stringValue())); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index 75ddd1dd93..c0656c20d8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -48,7 +48,6 @@ import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -64,18 +63,18 @@ class FilterQueryBuilderTest { private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); private static Stream numericCastSource() { - return Stream.of(literal((byte) 1), literal((short) 1), literal( - 1), literal(1L), literal(1F), literal(1D), literal(true), literal("1")); + return Stream.of(literal((byte) 1), literal((short) -1), literal( + 1), literal(21L), literal(3.14F), literal(3.1415D), literal(true), literal("1")); } private static Stream booleanTrueCastSource() { - return Stream.of(literal((byte) 1), literal((short) 1), literal( - 1), literal(1L), literal(1F), literal(1D), literal(true), literal("true")); + return Stream.of(literal((byte) 1), literal((short) -1), literal( + 1), literal(42L), literal(3.14F), literal(3.1415D), literal(true), literal("true")); } private static Stream booleanFalseCastSource() { return Stream.of(literal((byte) 0), literal((short) 0), literal( - 0), literal(0L), literal(0F), literal(0D), literal(false), literal("false")); + 0), literal(0L), literal(0.0F), literal(0.0D), literal(false), literal("false")); } @Mock @@ -937,93 +936,126 @@ void cast_to_string_in_filter() { dsl.equal(ref("string_value", STRING), dsl.castString(literal("1"))))); } + private Float castToFloat(Object o) { + if (o instanceof Number) { + return ((Number)o).floatValue(); + } + if (o instanceof String) { + return Float.parseFloat((String) o); + } + if (o instanceof Boolean) { + return ((Boolean)o) ? 1F : 0F; + } + // unreachable code + throw new IllegalArgumentException(); + } + + private Integer castToInteger(Object o) { + if (o instanceof Number) { + return ((Number)o).intValue(); + } + if (o instanceof String) { + return Integer.parseInt((String) o); + } + if (o instanceof Boolean) { + return ((Boolean)o) ? 1 : 0; + } + // unreachable code + throw new IllegalArgumentException(); + } + @ParameterizedTest(name = "castByte({0})") @MethodSource({"numericCastSource"}) void cast_to_byte_in_filter(LiteralExpression expr) { - assertJsonEquals( + assertJsonEquals(String.format( "{\n" + " \"term\" : {\n" + " \"byte_value\" : {\n" - + " \"value\" : 1,\n" + + " \"value\" : %d,\n" + " \"boost\" : 1.0\n" + " }\n" + " }\n" - + "}", + + "}", castToInteger(expr.valueOf(null).value())), buildQuery(dsl.equal(ref("byte_value", BYTE), dsl.castByte(expr)))); } @ParameterizedTest(name = "castShort({0})") @MethodSource({"numericCastSource"}) void cast_to_short_in_filter(LiteralExpression expr) { - assertJsonEquals( + assertJsonEquals(String.format( "{\n" + " \"term\" : {\n" + " \"short_value\" : {\n" - + " \"value\" : 1,\n" + + " \"value\" : %d,\n" + " \"boost\" : 1.0\n" + " }\n" + " }\n" - + "}", + + "}", castToInteger(expr.valueOf(null).value())), buildQuery(dsl.equal(ref("short_value", SHORT), dsl.castShort(expr)))); } @ParameterizedTest(name = "castInt({0})") @MethodSource({"numericCastSource"}) void cast_to_int_in_filter(LiteralExpression expr) { - assertJsonEquals( + assertJsonEquals(String.format( "{\n" + " \"term\" : {\n" + " \"integer_value\" : {\n" - + " \"value\" : 1,\n" + + " \"value\" : %d,\n" + " \"boost\" : 1.0\n" + " }\n" + " }\n" - + "}", + + "}", castToInteger(expr.valueOf(null).value())), buildQuery(dsl.equal(ref("integer_value", INTEGER), dsl.castInt(expr)))); } @ParameterizedTest(name = "castLong({0})") @MethodSource({"numericCastSource"}) void cast_to_long_in_filter(LiteralExpression expr) { - assertJsonEquals( + assertJsonEquals(String.format( "{\n" + " \"term\" : {\n" + " \"long_value\" : {\n" - + " \"value\" : 1,\n" + + " \"value\" : %d,\n" + " \"boost\" : 1.0\n" + " }\n" + " }\n" - + "}", + + "}", castToInteger(expr.valueOf(null).value())), buildQuery(dsl.equal(ref("long_value", LONG), dsl.castLong(expr)))); } @ParameterizedTest(name = "castFloat({0})") @MethodSource({"numericCastSource"}) void cast_to_float_in_filter(LiteralExpression expr) { - assertJsonEquals( + assertJsonEquals(String.format( "{\n" + " \"term\" : {\n" + " \"float_value\" : {\n" - + " \"value\" : 1.0,\n" + + " \"value\" : %f,\n" + " \"boost\" : 1.0\n" + " }\n" + " }\n" - + "}", + + "}", castToFloat(expr.valueOf(null).value())), buildQuery(dsl.equal(ref("float_value", FLOAT), dsl.castFloat(expr)))); } @ParameterizedTest(name = "castDouble({0})") @MethodSource({"numericCastSource"}) void cast_to_double_in_filter(LiteralExpression expr) { - assertJsonEquals( + // double values affected by floating point imprecision, so we can't compare them in json + // (Double)(Float)3.14 -> 3.14000010490417 + assertEquals(castToFloat(expr.valueOf(null).value()), + dsl.castDouble(expr).valueOf(null).doubleValue(), 0.00001); + + assertJsonEquals(String.format( "{\n" + " \"term\" : {\n" + " \"double_value\" : {\n" - + " \"value\" : 1.0,\n" + + " \"value\" : %2.20f,\n" + " \"boost\" : 1.0\n" + " }\n" + " }\n" - + "}", + + "}", dsl.castDouble(expr).valueOf(null).doubleValue()), buildQuery(dsl.equal(ref("double_value", DOUBLE), dsl.castDouble(expr)))); }