Skip to content

Commit

Permalink
Add benchmark for array filter object
Browse files Browse the repository at this point in the history
  • Loading branch information
hackeryang authored and raunaqmorarka committed Mar 29, 2023
1 parent daa80f6 commit de458d6
Showing 1 changed file with 148 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.Signature;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.sql.gen.ExpressionCompiler;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.VariableReferenceExpression;
import io.trino.sql.tree.QualifiedName;
import io.trino.type.FunctionType;
Expand All @@ -58,19 +60,24 @@

import static com.google.common.base.Throwables.throwIfUnchecked;
import static com.google.common.base.Verify.verify;
import static io.trino.block.BlockAssertions.createRandomBlockForType;
import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static io.trino.operator.scalar.BenchmarkArrayFilter.ExactArrayFilterFunction.EXACT_ARRAY_FILTER_FUNCTION;
import static io.trino.operator.scalar.BenchmarkArrayFilter.ExactArrayFilterObjectFunction.EXACT_ARRAY_FILTER_OBJECT_FUNCTION;
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.OperatorType.LESS_THAN;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.TypeSignature.arrayType;
import static io.trino.spi.type.TypeSignature.functionType;
import static io.trino.spi.type.TypeUtils.readNativeValue;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.relational.Expressions.constant;
import static io.trino.sql.relational.Expressions.field;
import static io.trino.sql.relational.SpecialForm.Form.DEREFERENCE;
import static io.trino.testing.TestingConnectorSession.SESSION;
import static io.trino.util.Reflection.methodHandle;
import static java.lang.Boolean.TRUE;
Expand All @@ -88,9 +95,11 @@ public class BenchmarkArrayFilter
private static final int ARRAY_SIZE = 4;
private static final int NUM_TYPES = 1;
private static final List<Type> TYPES = ImmutableList.of(BIGINT);
private static final List<Type> ROW_TYPES = ImmutableList.of(RowType.anonymous(ImmutableList.of(BIGINT, DOUBLE)));

static {
verify(NUM_TYPES == TYPES.size());
verify(NUM_TYPES == ROW_TYPES.size());
}

@Benchmark
Expand All @@ -105,6 +114,18 @@ public List<Optional<Page>> benchmark(BenchmarkData data)
data.getPage()));
}

@Benchmark
@OperationsPerInvocation(POSITIONS * ARRAY_SIZE * NUM_TYPES)
public List<Optional<Page>> benchmarkObject(RowBenchmarkData data)
{
return ImmutableList.copyOf(
data.getPageProcessor().process(
SESSION,
new DriverYieldSignal(),
newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()),
data.getPage()));
}

@SuppressWarnings("FieldMayBeFinal")
@State(Scope.Thread)
public static class BenchmarkData
Expand Down Expand Up @@ -172,15 +193,85 @@ public Page getPage()
}
}

@SuppressWarnings("FieldMayBeFinal")
@State(Scope.Thread)
public static class RowBenchmarkData
{
@Param({"filter", "exact_filter"})
private String name = "filter";

private Page page;
private PageProcessor pageProcessor;

@Setup
public void setup()
{
TestingFunctionResolution functionResolution = new TestingFunctionResolution(InternalFunctionBundle.builder().function(EXACT_ARRAY_FILTER_OBJECT_FUNCTION).build());
ExpressionCompiler compiler = functionResolution.getExpressionCompiler();
ImmutableList.Builder<RowExpression> projectionsBuilder = ImmutableList.builder();
Block[] blocks = new Block[ROW_TYPES.size()];
for (int i = 0; i < ROW_TYPES.size(); i++) {
Type elementType = ROW_TYPES.get(i);
ArrayType arrayType = new ArrayType(elementType);
ResolvedFunction resolvedFunction = functionResolution.resolveFunction(
QualifiedName.of(name),
fromTypes(arrayType, new FunctionType(ROW_TYPES, BOOLEAN)));
ResolvedFunction lessThan = functionResolution.resolveOperator(LESS_THAN, ImmutableList.of(BIGINT, BIGINT));

projectionsBuilder.add(new CallExpression(resolvedFunction, ImmutableList.of(
field(0, arrayType),
new LambdaDefinitionExpression(
ImmutableList.of(elementType),
ImmutableList.of("x"),
new CallExpression(
lessThan,
ImmutableList.of(
constant(0L, BIGINT),
new SpecialForm(
DEREFERENCE,
BIGINT,
new VariableReferenceExpression("x", elementType),
constant(0, INTEGER))))))));
blocks[i] = createChannel(POSITIONS, arrayType);
}

ImmutableList<RowExpression> projections = projectionsBuilder.build();
pageProcessor = compiler.compilePageProcessor(Optional.empty(), projections).get();
page = new Page(blocks);
}

private static Block createChannel(int positionCount, ArrayType arrayType)
{
return createRandomBlockForType(arrayType, positionCount, 0.2F);
}

public PageProcessor getPageProcessor()
{
return pageProcessor;
}

public Page getPage()
{
return page;
}
}

public static void main(String[] args)
throws Exception
{
// assure the benchmarks are valid before running
BenchmarkData data = new BenchmarkData();
data.setup();
new BenchmarkArrayFilter().benchmark(data);
BenchmarkArrayFilter benchmarkArrayFilter = new BenchmarkArrayFilter();
benchmarkArrayFilter.benchmark(data);

Benchmarks.benchmark(BenchmarkArrayFilter.class).run();
RowBenchmarkData rowData = new RowBenchmarkData();
rowData.setup();
benchmarkArrayFilter.benchmarkObject(rowData);

Benchmarks.benchmark(BenchmarkArrayFilter.class)
.withOptions(optionsBuilder -> optionsBuilder.jvmArgs("-Xmx4g"))
.run();
}

public static final class ExactArrayFilterFunction
Expand Down Expand Up @@ -237,4 +328,59 @@ public static Block filter(Type type, Block block, MethodHandle function)
return resultBuilder.build();
}
}

public static final class ExactArrayFilterObjectFunction
extends SqlScalarFunction
{
public static final ExactArrayFilterObjectFunction EXACT_ARRAY_FILTER_OBJECT_FUNCTION = new ExactArrayFilterObjectFunction();

private static final MethodHandle METHOD_HANDLE = methodHandle(ExactArrayFilterObjectFunction.class, "filterObject", Type.class, Block.class, MethodHandle.class);

private ExactArrayFilterObjectFunction()
{
super(FunctionMetadata.scalarBuilder()
.signature(Signature.builder()
.name("exact_filter")
.typeVariable("T")
.returnType(arrayType(new TypeSignature("T")))
.argumentType(arrayType(new TypeSignature("T")))
.argumentType(functionType(new TypeSignature("T"), BOOLEAN.getTypeSignature()))
.build())
.nondeterministic()
.description("return array containing elements that match the given predicate")
.build());
}

@Override
protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
{
Type type = ((ArrayType) boundSignature.getReturnType()).getElementType();
return new ChoicesSpecializedSqlScalarFunction(
boundSignature,
FAIL_ON_NULL,
ImmutableList.of(NEVER_NULL, NEVER_NULL),
METHOD_HANDLE.bindTo(type));
}

public static Block filterObject(Type type, Block block, MethodHandle function)
{
int positionCount = block.getPositionCount();
BlockBuilder resultBuilder = type.createBlockBuilder(null, positionCount);
for (int position = 0; position < positionCount; position++) {
Object input = type.getObject(block, position);
Boolean keep;
try {
keep = (Boolean) function.invokeExact(input);
}
catch (Throwable t) {
throwIfUnchecked(t);
throw new RuntimeException(t);
}
if (TRUE.equals(keep)) {
type.appendTo(block, position, resultBuilder);
}
}
return resultBuilder.build();
}
}
}

0 comments on commit de458d6

Please sign in to comment.