From a26a297af55ad46c3f46efae78f470a33d0df251 Mon Sep 17 00:00:00 2001 From: Zara Lim Date: Fri, 25 Jun 2021 12:37:50 -0700 Subject: [PATCH] fix: enable time unit functions for interpreter (#7709) * fix: enable time unit functions for interpreter * add parity test --- .../ExpressionEvaluatorParityTest.java | 2 + .../execution/interpreter/TermCompiler.java | 2 +- .../interpreter/terms/LiteralTerms.java | 24 ++++++++++++ .../InterpretedExpressionTest.java | 38 +++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/execution/ExpressionEvaluatorParityTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/execution/ExpressionEvaluatorParityTest.java index 62483acedef4..efa5b5d3a1e7 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/execution/ExpressionEvaluatorParityTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/execution/ExpressionEvaluatorParityTest.java @@ -37,6 +37,7 @@ import io.confluent.ksql.util.KsqlParserTestUtil; import io.confluent.ksql.util.MetaStoreFixture; import java.math.BigDecimal; +import java.sql.Timestamp; import java.util.Collections; import java.util.List; import java.util.Map; @@ -148,6 +149,7 @@ public void shouldDereference() throws Exception { public void shouldDoUdfs() throws Exception { assertOrders("CONCAT('abc-', 'def')", "abc-def"); assertOrders("SPLIT('a-b-c', '-')", ImmutableList.of("a", "b", "c")); + assertOrders("TIMESTAMPADD(SECONDS, 1, '2020-01-01')", new Timestamp(1577836801000L)); assertOrdersError("SPLIT(123, '2')", compileTime("Function 'split' does not accept parameters (INTEGER, STRING)"), compileTime("Function 'split' does not accept parameters (INTEGER, STRING)")); diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/interpreter/TermCompiler.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/interpreter/TermCompiler.java index d56024025fb8..f58ff797fd98 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/interpreter/TermCompiler.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/interpreter/TermCompiler.java @@ -303,7 +303,7 @@ public Term visitLambdaVariable( @Override public Term visitIntervalUnit(final IntervalUnit exp, final Context context) { - return visitUnsupported(exp); + return LiteralTerms.of(exp.getUnit()); } @Override diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/interpreter/terms/LiteralTerms.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/interpreter/terms/LiteralTerms.java index 9bfe5ea6f3b0..b27de35d3f8d 100644 --- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/interpreter/terms/LiteralTerms.java +++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/interpreter/terms/LiteralTerms.java @@ -20,6 +20,7 @@ import io.confluent.ksql.schema.ksql.types.SqlTypes; import java.math.BigDecimal; import java.sql.Timestamp; +import java.util.concurrent.TimeUnit; @SuppressWarnings("checkstyle:ClassDataAbstractionCoupling") public final class LiteralTerms { @@ -54,6 +55,10 @@ public static Term of(final Timestamp value) { return new TimestampTermImpl(value); } + public static Term of(final TimeUnit value) { + return new IntervalUnitTermImpl(value); + } + public static NullTerm ofNull() { return new NullTerm(); } @@ -209,4 +214,23 @@ public SqlType getSqlType() { return SqlTypes.TIMESTAMP; } } + + public static class IntervalUnitTermImpl implements Term { + + private final TimeUnit value; + + public IntervalUnitTermImpl(final TimeUnit timeUnit) { + this.value = timeUnit; + } + + @Override + public Object getValue(final TermEvaluationContext context) { + return value; + } + + @Override + public SqlType getSqlType() { + return null; + } + } } diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/interpreter/InterpretedExpressionTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/interpreter/InterpretedExpressionTest.java index 6350f6b04c2e..9f95012ccacd 100644 --- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/interpreter/InterpretedExpressionTest.java +++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/interpreter/InterpretedExpressionTest.java @@ -52,6 +52,7 @@ import io.confluent.ksql.execution.expression.tree.InListExpression; import io.confluent.ksql.execution.expression.tree.InPredicate; import io.confluent.ksql.execution.expression.tree.IntegerLiteral; +import io.confluent.ksql.execution.expression.tree.IntervalUnit; import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate; import io.confluent.ksql.execution.expression.tree.IsNullPredicate; import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall; @@ -73,6 +74,7 @@ import io.confluent.ksql.function.types.ArrayType; import io.confluent.ksql.function.types.GenericType; import io.confluent.ksql.function.types.IntegerType; +import io.confluent.ksql.function.types.IntervalUnitType; import io.confluent.ksql.function.types.LambdaType; import io.confluent.ksql.function.udf.Kudf; import io.confluent.ksql.function.udf.UdfMetadata; @@ -88,6 +90,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.function.Function; import org.apache.kafka.connect.data.Schema; @@ -393,6 +396,31 @@ public void shouldHandleFunctionCalls_intParams() { assertThat(object, is(2)); } + @Test + public void shouldHandleFunctionCalls_intervalParam() { + // Given: + final UdfFactory udfFactory = mock(UdfFactory.class); + final KsqlScalarFunction udf = mock(KsqlScalarFunction.class); + when(udf.newInstance(any())).thenReturn(new toMillisUdf()); + givenUdf("FOO", udfFactory, udf); + when(udf.parameters()).thenReturn(ImmutableList.of(IntervalUnitType.INSTANCE, IntegerType.INSTANCE)); + + // When: + InterpretedExpression interpreter1 = interpreter( + new FunctionCall( + FunctionName.of("FOO"), + ImmutableList.of( + new IntervalUnit(TimeUnit.SECONDS), + new IntegerLiteral(1)) + ) + ); + final Object object = interpreter1.evaluate(ROW); + + + // Then: + assertThat(object, is(1000)); + } + @Test public void shouldEvaluateIsNullPredicate() { // Given: @@ -1099,4 +1127,14 @@ public Object evaluate(Object... args) { return result; } } + + private static class toMillisUdf implements Kudf { + + @Override + public Object evaluate(Object... args) { + TimeUnit a = (TimeUnit) args[0]; + int b = (Integer) args[1]; + return (int) a.toMillis(b); + } + } }