Skip to content

Commit

Permalink
fix: enable time unit functions for interpreter (#7709)
Browse files Browse the repository at this point in the history
* fix: enable time unit functions for interpreter

* add parity test
  • Loading branch information
Zara Lim authored Jun 25, 2021
1 parent 7f3a892 commit a26a297
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.confluent.ksql.util.KsqlParserTestUtil;
import io.confluent.ksql.util.MetaStoreFixture;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -148,6 +149,7 @@ public void shouldDereference() throws Exception {
public void shouldDoUdfs() throws Exception {
assertOrders("CONCAT('abc-', 'def')", "abc-def");
assertOrders("SPLIT('a-b-c', '-')", ImmutableList.of("a", "b", "c"));
assertOrders("TIMESTAMPADD(SECONDS, 1, '2020-01-01')", new Timestamp(1577836801000L));
assertOrdersError("SPLIT(123, '2')",
compileTime("Function 'split' does not accept parameters (INTEGER, STRING)"),
compileTime("Function 'split' does not accept parameters (INTEGER, STRING)"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ public Term visitLambdaVariable(

@Override
public Term visitIntervalUnit(final IntervalUnit exp, final Context context) {
return visitUnsupported(exp);
return LiteralTerms.of(exp.getUnit());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.concurrent.TimeUnit;

@SuppressWarnings("checkstyle:ClassDataAbstractionCoupling")
public final class LiteralTerms {
Expand Down Expand Up @@ -54,6 +55,10 @@ public static Term of(final Timestamp value) {
return new TimestampTermImpl(value);
}

public static Term of(final TimeUnit value) {
return new IntervalUnitTermImpl(value);
}

public static NullTerm ofNull() {
return new NullTerm();
}
Expand Down Expand Up @@ -209,4 +214,23 @@ public SqlType getSqlType() {
return SqlTypes.TIMESTAMP;
}
}

public static class IntervalUnitTermImpl implements Term {

private final TimeUnit value;

public IntervalUnitTermImpl(final TimeUnit timeUnit) {
this.value = timeUnit;
}

@Override
public Object getValue(final TermEvaluationContext context) {
return value;
}

@Override
public SqlType getSqlType() {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import io.confluent.ksql.execution.expression.tree.InListExpression;
import io.confluent.ksql.execution.expression.tree.InPredicate;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IntervalUnit;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
import io.confluent.ksql.execution.expression.tree.LambdaFunctionCall;
Expand All @@ -73,6 +74,7 @@
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.IntegerType;
import io.confluent.ksql.function.types.IntervalUnitType;
import io.confluent.ksql.function.types.LambdaType;
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.function.udf.UdfMetadata;
Expand All @@ -88,6 +90,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.kafka.connect.data.Schema;
Expand Down Expand Up @@ -393,6 +396,31 @@ public void shouldHandleFunctionCalls_intParams() {
assertThat(object, is(2));
}

@Test
public void shouldHandleFunctionCalls_intervalParam() {
// Given:
final UdfFactory udfFactory = mock(UdfFactory.class);
final KsqlScalarFunction udf = mock(KsqlScalarFunction.class);
when(udf.newInstance(any())).thenReturn(new toMillisUdf());
givenUdf("FOO", udfFactory, udf);
when(udf.parameters()).thenReturn(ImmutableList.of(IntervalUnitType.INSTANCE, IntegerType.INSTANCE));

// When:
InterpretedExpression interpreter1 = interpreter(
new FunctionCall(
FunctionName.of("FOO"),
ImmutableList.of(
new IntervalUnit(TimeUnit.SECONDS),
new IntegerLiteral(1))
)
);
final Object object = interpreter1.evaluate(ROW);


// Then:
assertThat(object, is(1000));
}

@Test
public void shouldEvaluateIsNullPredicate() {
// Given:
Expand Down Expand Up @@ -1099,4 +1127,14 @@ public Object evaluate(Object... args) {
return result;
}
}

private static class toMillisUdf implements Kudf {

@Override
public Object evaluate(Object... args) {
TimeUnit a = (TimeUnit) args[0];
int b = (Integer) args[1];
return (int) a.toMillis(b);
}
}
}

0 comments on commit a26a297

Please sign in to comment.