Skip to content

Commit

Permalink
Simplify WindowFunctionSupplier interface
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Dec 17, 2021
1 parent 59c1c30 commit 9c6a7e2
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 126 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,16 @@ private void resetAccumulator()
public static WindowFunctionSupplier supplier(Signature signature, InternalAggregationFunction function)
{
requireNonNull(function, "function is null");
return new AbstractWindowFunctionSupplier(signature, null, function.getLambdaInterfaces())
return new WindowFunctionSupplier()
{
@Override
protected WindowFunction newWindowFunction(List<Integer> inputs, boolean ignoreNulls, List<LambdaProvider> lambdaProviders)
public List<Class<?>> getLambdaInterfaces()
{
return function.getLambdaInterfaces();
}

@Override
public WindowFunction createWindowFunction(List<Integer> inputs, boolean ignoreNulls, List<LambdaProvider> lambdaProviders)
{
return new AggregateWindowFunction(function, inputs, lambdaProviders);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,18 @@

import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import io.trino.metadata.Signature;
import io.trino.operator.aggregation.LambdaProvider;
import io.trino.spi.function.Description;
import io.trino.spi.function.ValueWindowFunction;
import io.trino.spi.function.WindowFunction;
import io.trino.spi.type.Type;

import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Constructor;
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

public class ReflectionWindowFunctionSupplier<T extends WindowFunction>
extends AbstractWindowFunctionSupplier
public class ReflectionWindowFunctionSupplier
implements WindowFunctionSupplier
{
private enum ConstructorType
{
Expand All @@ -39,22 +35,18 @@ private enum ConstructorType
INPUTS_IGNORE_NULLS
}

private final Constructor<T> constructor;
private final int argumentCount;
private final Constructor<? extends WindowFunction> constructor;
private final ConstructorType constructorType;

public ReflectionWindowFunctionSupplier(String name, Type returnType, List<? extends Type> argumentTypes, Class<T> type)
public ReflectionWindowFunctionSupplier(int argumentCount, Class<? extends WindowFunction> type)
{
this(new Signature(name, returnType.getTypeSignature(), Lists.transform(argumentTypes, Type::getTypeSignature)), type);
}

public ReflectionWindowFunctionSupplier(Signature signature, Class<T> type)
{
super(signature, getDescription(requireNonNull(type, "type is null")), ImmutableList.of());
this.argumentCount = argumentCount;
try {
Constructor<T> constructor;
Constructor<? extends WindowFunction> constructor;
ConstructorType constructorType;

if (signature.getArgumentTypes().isEmpty()) {
if (argumentCount == 0) {
constructor = type.getConstructor();
constructorType = ConstructorType.NO_INPUTS;
}
Expand Down Expand Up @@ -83,27 +75,30 @@ else if (ValueWindowFunction.class.isAssignableFrom(type)) {
}

@Override
protected T newWindowFunction(List<Integer> inputs, boolean ignoreNulls, List<LambdaProvider> lambdaProviders)
public List<Class<?>> getLambdaInterfaces()
{
return ImmutableList.of();
}

@Override
public WindowFunction createWindowFunction(List<Integer> argumentChannels, boolean ignoreNulls, List<LambdaProvider> lambdaProviders)
{
requireNonNull(argumentChannels, "inputs is null");
checkArgument(argumentChannels.size() == argumentCount, "Expected %s arguments, but got %s", argumentCount, argumentChannels.size());

try {
switch (constructorType) {
case NO_INPUTS:
return constructor.newInstance();
case INPUTS:
return constructor.newInstance(inputs);
return constructor.newInstance(argumentChannels);
case INPUTS_IGNORE_NULLS:
return constructor.newInstance(inputs, ignoreNulls);
return constructor.newInstance(argumentChannels, ignoreNulls);
}
throw new VerifyException("Unhandled constructor type: " + constructorType);
}
catch (ReflectiveOperationException e) {
throw new RuntimeException(e);
}
}

private static String getDescription(AnnotatedElement annotatedElement)
{
Description description = annotatedElement.getAnnotation(Description.class);
return (description == null) ? null : description.value();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import io.trino.metadata.Signature;
import io.trino.metadata.SqlFunction;

import static com.google.common.base.Strings.nullToEmpty;
import java.util.Optional;

import static io.trino.metadata.FunctionKind.WINDOW;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
Expand All @@ -31,17 +32,16 @@ public class SqlWindowFunction
private final WindowFunctionSupplier supplier;
private final FunctionMetadata functionMetadata;

public SqlWindowFunction(WindowFunctionSupplier supplier, boolean deprecated)
public SqlWindowFunction(Signature signature, Optional<String> description, boolean deprecated, WindowFunctionSupplier supplier)
{
this.supplier = requireNonNull(supplier, "supplier is null");
Signature signature = supplier.getSignature();
functionMetadata = new FunctionMetadata(
signature,
signature.getName(),
new FunctionNullability(true, nCopies(signature.getArgumentTypes().size(), true)),
false,
true,
nullToEmpty(supplier.getDescription()),
description.orElse(""),
WINDOW,
deprecated);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
import com.google.common.collect.ImmutableSet;
import io.trino.metadata.Signature;
import io.trino.metadata.TypeVariableConstraint;
import io.trino.spi.function.Description;
import io.trino.spi.function.WindowFunction;
import io.trino.spi.function.WindowFunctionSignature;
import io.trino.spi.type.TypeSignature;

import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -61,7 +63,10 @@ private static SqlWindowFunction parse(Class<? extends WindowFunction> clazz, Wi
argumentTypes,
false);

Optional<String> description = Optional.ofNullable(clazz.getAnnotation(Description.class)).map(Description::value);

boolean deprecated = clazz.getAnnotationsByType(Deprecated.class).length > 0;
return new SqlWindowFunction(new ReflectionWindowFunctionSupplier<>(signature, clazz), deprecated);

return new SqlWindowFunction(signature, description, deprecated, new ReflectionWindowFunctionSupplier(window.argumentTypes().length, clazz));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,13 @@
*/
package io.trino.operator.window;

import io.trino.metadata.Signature;
import io.trino.operator.aggregation.LambdaProvider;
import io.trino.spi.function.WindowFunction;

import java.util.List;

public interface WindowFunctionSupplier
{
Signature getSignature();

String getDescription();

WindowFunction createWindowFunction(List<Integer> argumentChannels, boolean ignoreNulls, List<LambdaProvider> lambdaProviders);

List<Class<?>> getLambdaInterfaces();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,25 @@ public class TestWindowOperator
private static final FrameInfo UNBOUNDED_FRAME = new FrameInfo(RANGE, UNBOUNDED_PRECEDING, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());

public static final List<WindowFunctionDefinition> ROW_NUMBER = ImmutableList.of(
window(new ReflectionWindowFunctionSupplier<>("row_number", BIGINT, ImmutableList.of(), RowNumberFunction.class), BIGINT, UNBOUNDED_FRAME, false, ImmutableList.of()));
window(new ReflectionWindowFunctionSupplier(0, RowNumberFunction.class), BIGINT, UNBOUNDED_FRAME, false, ImmutableList.of()));

public static final List<WindowFunctionDefinition> RANK = ImmutableList.of(
window(new ReflectionWindowFunctionSupplier<>("rank", BIGINT, ImmutableList.of(), RankFunction.class), BIGINT, UNBOUNDED_FRAME, false, ImmutableList.of()));
window(new ReflectionWindowFunctionSupplier(0, RankFunction.class), BIGINT, UNBOUNDED_FRAME, false, ImmutableList.of()));

private static final List<WindowFunctionDefinition> FIRST_VALUE = ImmutableList.of(
window(new ReflectionWindowFunctionSupplier<>("first_value", VARCHAR, ImmutableList.<Type>of(VARCHAR), FirstValueFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1));
window(new ReflectionWindowFunctionSupplier(1, FirstValueFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1));

private static final List<WindowFunctionDefinition> LAST_VALUE = ImmutableList.of(
window(new ReflectionWindowFunctionSupplier<>("last_value", VARCHAR, ImmutableList.<Type>of(VARCHAR), LastValueFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1));
window(new ReflectionWindowFunctionSupplier(1, LastValueFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1));

private static final List<WindowFunctionDefinition> NTH_VALUE = ImmutableList.of(
window(new ReflectionWindowFunctionSupplier<>("nth_value", VARCHAR, ImmutableList.of(VARCHAR, BIGINT), NthValueFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1, 3));
window(new ReflectionWindowFunctionSupplier(2, NthValueFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1, 3));

private static final List<WindowFunctionDefinition> LAG = ImmutableList.of(
window(new ReflectionWindowFunctionSupplier<>("lag", VARCHAR, ImmutableList.of(VARCHAR, BIGINT, VARCHAR), LagFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1, 3, 4));
window(new ReflectionWindowFunctionSupplier(3, LagFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1, 3, 4));

private static final List<WindowFunctionDefinition> LEAD = ImmutableList.of(
window(new ReflectionWindowFunctionSupplier<>("lead", VARCHAR, ImmutableList.of(VARCHAR, BIGINT, VARCHAR), LeadFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1, 3, 4));
window(new ReflectionWindowFunctionSupplier(3, LeadFunction.class), VARCHAR, UNBOUNDED_FRAME, false, ImmutableList.of(), 1, 3, 4));

private ExecutorService executor;
private ScheduledExecutorService scheduledExecutor;
Expand Down

0 comments on commit 9c6a7e2

Please sign in to comment.