Skip to content

Commit

Permalink
feat: implement correct logic for nested lambdas and more complex lam…
Browse files Browse the repository at this point in the history
…bda expressions
  • Loading branch information
stevenpyzhang committed Feb 25, 2021
1 parent 6bc9dff commit e881860
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,23 +183,33 @@ public Void visitLikePredicate(final LikePredicate node, final TypeContext conte

@Override
public Void visitFunctionCall(final FunctionCall node, final TypeContext context) {
final List<SqlArgument> argumentTypes = new ArrayList<>();
final FunctionName functionName = node.getName();

final TypeContext contextCopy = context.getCopy();
final List<SqlArgument> argumentTypes = new ArrayList<>();
final List<TypeContext> typeContextsForChildren = new ArrayList<>();
final boolean hasLambda = node.hasLambdaFunctionCallArguments();
for (final Expression argExpr : node.getArguments()) {
final TypeContext childContext = context.getCopy();
final TypeContext childContext;
if (argExpr instanceof LambdaFunctionCall) {
childContext = contextCopy.getCopy();
} else {
childContext = context.getCopy();
}

typeContextsForChildren.add(childContext);
final SqlType resolvedArgType =
expressionTypeManager.getExpressionSqlType(argExpr, childContext);
expressionTypeManager.getExpressionSqlType(argExpr, childContext.getCopy());

if (argExpr instanceof LambdaFunctionCall) {
argumentTypes.add(
SqlArgument.of(
SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType())));
SqlLambda.of(contextCopy.getLambdaInputTypes(), resolvedArgType)));
} else {
argumentTypes.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
context.visitType(resolvedArgType);
contextCopy.visitType(resolvedArgType);
}
}
}
Expand All @@ -211,8 +221,9 @@ public Void visitFunctionCall(final FunctionCall node, final TypeContext context
function.newInstance(ksqlConfig)
);

for (final Expression argExpr : node.getArguments()) {
process(argExpr, context.getCopy());
final List<Expression> arguments = node.getArguments();
for (int i = 0; i < arguments.size(); i++) {
process(arguments.get(i), typeContextsForChildren.get(i));
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,21 +452,32 @@ public Pair<String, SqlType> visitFunctionCall(
final String instanceName = funNameToCodeName.apply(functionName);

final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());

final TypeContext contextCopy = context.getCopy();
final List<SqlArgument> argumentSchemas = new ArrayList<>();
final List<TypeContext> typeContextsForChildren = new ArrayList<>();
final boolean hasLambda = node.hasLambdaFunctionCallArguments();

for (final Expression argExpr : node.getArguments()) {
final TypeContext childContext = context.getCopy();
final TypeContext childContext;
if (argExpr instanceof LambdaFunctionCall) {
childContext = contextCopy.getCopy();
} else {
childContext = context.getCopy();
}

typeContextsForChildren.add(childContext.getCopy());
final SqlType resolvedArgType =
expressionTypeManager.getExpressionSqlType(argExpr, childContext);
if (argExpr instanceof LambdaFunctionCall) {
argumentSchemas.add(
SqlArgument.of(
SqlLambda.of(context.getLambdaInputTypes(), childContext.getSqlType())));
SqlLambda.of(contextCopy.getLambdaInputTypes(), resolvedArgType)));
} else {
argumentSchemas.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
context.visitType(resolvedArgType);
contextCopy.visitType(resolvedArgType);
}
}
}
Expand Down Expand Up @@ -494,7 +505,7 @@ public Pair<String, SqlType> visitFunctionCall(
}

joiner.add(
process(convertArgument(arg, sqlType, paramType), context.getCopy())
process(convertArgument(arg, sqlType, paramType),typeContextsForChildren.get(i))
.getLeft());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,22 +474,30 @@ public Void visitFunctionCall(
final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());

final List<SqlArgument> argTypes = new ArrayList<>();
final TypeContext contextCopy = expressionTypeContext.getCopy();

final boolean hasLambda = node.hasLambdaFunctionCallArguments();
for (final Expression expression : node.getArguments()) {
final TypeContext childContext = expressionTypeContext.getCopy();
final TypeContext childContext;
if (expression instanceof LambdaFunctionCall) {
childContext = contextCopy.getCopy();
} else {
childContext = expressionTypeContext.getCopy();
}
process(expression, childContext);
final SqlType resolvedArgType = childContext.getSqlType();

if (expression instanceof LambdaFunctionCall) {
argTypes.add(
SqlArgument.of(
SqlLambda.of(expressionTypeContext.getLambdaInputTypes(),
childContext.getSqlType())));
SqlLambda.of(
contextCopy.getLambdaInputTypes(),
resolvedArgType)));
} else {
argTypes.add(SqlArgument.of(resolvedArgType));
// for lambdas - we save the type information to resolve the lambda generics
if (hasLambda) {
expressionTypeContext.visitType(resolvedArgType);
contextCopy.visitType(resolvedArgType);
}
}
}
Expand Down
Loading

0 comments on commit e881860

Please sign in to comment.