Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split construction of large rows into separate methods #23721

Merged
merged 1 commit into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
*/
package io.trino.sql.gen;

import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.trino.metadata.FunctionManager;
Expand All @@ -40,26 +43,33 @@ public class BytecodeGeneratorContext
private final CachedInstanceBinder cachedInstanceBinder;
private final FunctionManager functionManager;
private final Variable wasNull;
private final ClassDefinition classDefinition;
private final List<Parameter> contextArguments; // arguments that need to be propagated to generated methods to be able to resolve underlying references, session, etc.

public BytecodeGeneratorContext(
RowExpressionCompiler rowExpressionCompiler,
Scope scope,
CallSiteBinder callSiteBinder,
CachedInstanceBinder cachedInstanceBinder,
FunctionManager functionManager)
FunctionManager functionManager,
ClassDefinition classDefinition,
List<Parameter> contextArguments)
{
requireNonNull(rowExpressionCompiler, "rowExpressionCompiler is null");
requireNonNull(cachedInstanceBinder, "cachedInstanceBinder is null");
requireNonNull(scope, "scope is null");
requireNonNull(callSiteBinder, "callSiteBinder is null");
requireNonNull(functionManager, "functionManager is null");
requireNonNull(classDefinition, "classDefinition is null");

this.rowExpressionCompiler = rowExpressionCompiler;
this.scope = scope;
this.callSiteBinder = callSiteBinder;
this.cachedInstanceBinder = cachedInstanceBinder;
this.functionManager = functionManager;
this.wasNull = scope.getVariable("wasNull");
this.classDefinition = classDefinition;
this.contextArguments = ImmutableList.copyOf(contextArguments);
}

public Scope getScope()
Expand Down Expand Up @@ -110,4 +120,29 @@ public Variable wasNull()
{
return wasNull;
}

public ClassDefinition getClassDefinition()
{
return classDefinition;
}

public RowExpressionCompiler getRowExpressionCompiler()
{
return rowExpressionCompiler;
}

public CachedInstanceBinder getCachedInstanceBinder()
{
return cachedInstanceBinder;
}

public FunctionManager getFunctionManager()
{
return functionManager;
}

public List<Parameter> getContextArguments()
{
return contextArguments;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.sql.gen;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.bytecode.BytecodeBlock;
Expand Down Expand Up @@ -236,11 +237,13 @@ private void generateFilterMethod(
Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull");

RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler(cursor),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(session, cursor));

LabelNode end = new LabelNode("end");
method.getBody()
Expand Down Expand Up @@ -276,11 +279,13 @@ private void generateProjectMethod(
Variable wasNullVariable = scope.declareVariable(type(boolean.class), "wasNull");

RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler(cursor),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(session, cursor, output));

method.getBody()
.comment("boolean wasNull = false;")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,13 @@ private void generateFilterMethod(
scope.declareVariable("session", body, method.getThis().getField(sessionField));

RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler(callSiteBinder, leftPosition, leftPage, rightPosition, rightPage, leftBlocksSize),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(leftPage, leftPosition, rightPage, rightPosition));

BytecodeNode visitorBody = compiler.compile(filter, scope);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,13 @@ public static CompiledLambda preGenerateLambdaExpression(
}

RowExpressionCompiler innerExpressionCompiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
variableReferenceCompiler(parameterMapBuilder.buildOrThrow()),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
parameters.build());

return defineLambdaMethod(
innerExpressionCompiler,
Expand Down Expand Up @@ -266,18 +268,30 @@ public static Class<? extends Supplier<Object>> compileLambdaProvider(LambdaDefi
scope.declareVariable("session", body, method.getThis().getField(sessionField));

RowExpressionCompiler rowExpressionCompiler = new RowExpressionCompiler(
lambdaProviderClassDefinition,
callSiteBinder,
cachedInstanceBinder,
variableReferenceCompiler(ImmutableMap.of()),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of());

List<Parameter> parameters = new ArrayList<>();
parameters.add(arg("session", ConnectorSession.class));
for (int i = 0; i < lambdaExpression.arguments().size(); i++) {
Symbol argument = lambdaExpression.arguments().get(i);
Class<?> type = Primitives.wrap(argument.type().getJavaType());
parameters.add(arg("lambda_" + i + "_" + BytecodeUtils.sanitizeName(argument.name()), type));
}

BytecodeGeneratorContext generatorContext = new BytecodeGeneratorContext(
rowExpressionCompiler,
scope,
callSiteBinder,
cachedInstanceBinder,
functionManager);
functionManager,
lambdaProviderClassDefinition,
parameters);

body.append(
generateLambda(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,13 @@ private MethodDefinition generateEvaluateMethod(

Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse());
RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompilerProjection(callSiteBinder),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(session, position));

body.append(thisVariable.getField(blockBuilder))
.append(compiler.compile(projection, scope))
Expand Down Expand Up @@ -543,11 +545,13 @@ private MethodDefinition generateFilterMethod(

Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse());
RowExpressionCompiler compiler = new RowExpressionCompiler(
classDefinition,
callSiteBinder,
cachedInstanceBinder,
fieldReferenceCompiler(callSiteBinder),
functionManager,
compiledLambdaMap);
compiledLambdaMap,
ImmutableList.of(page, position));

Variable result = scope.declareVariable(boolean.class, "result");
body.append(compiler.compile(filter, scope))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.IfStatement;
Expand All @@ -28,8 +31,12 @@
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;

import java.util.ArrayList;
import java.util.List;

import static io.airlift.bytecode.Access.PUBLIC;
import static io.airlift.bytecode.Access.a;
import static io.airlift.bytecode.Parameter.arg;
import static io.airlift.bytecode.ParameterizedType.type;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse;
import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt;
Expand All @@ -49,6 +56,9 @@ public class RowConstructorCodeGenerator
// Arbitrary value chosen to balance the code size vs performance trade off. Not perf tested.
private static final int MEGAMORPHIC_FIELD_COUNT = 64;

// number of fields to initialize in a single method for large rows
private static final int LARGE_ROW_BATCH_SIZE = 100;

public RowConstructorCodeGenerator(SpecialForm specialForm)
{
requireNonNull(specialForm, "specialForm is null");
Expand Down Expand Up @@ -109,18 +119,58 @@ private BytecodeNode generateExpressionForLargeRows(BytecodeGeneratorContext con
BytecodeBlock block = new BytecodeBlock().setDescription("Constructor for " + rowType);
CallSiteBinder binder = context.getCallSiteBinder();
Scope scope = context.getScope();
List<Type> types = rowType.getTypeParameters();

Variable fieldBuilders = scope.getOrCreateTempVariable(BlockBuilder[].class);
block.append(fieldBuilders.set(invokeStatic(RowConstructorCodeGenerator.class, "createFieldBlockBuildersForSingleRow", BlockBuilder[].class, constantType(binder, rowType))));

Variable blockBuilder = scope.getOrCreateTempVariable(BlockBuilder.class);
for (int i = 0; i < arguments.size(); ++i) {
for (int i = 0; i < arguments.size(); i += LARGE_ROW_BATCH_SIZE) {
MethodDefinition partialRowConstructor = generatePartialRowConstructor(i, Math.min(i + LARGE_ROW_BATCH_SIZE, arguments.size()), context);
block.getVariable(scope.getThis());
for (Parameter argument : context.getContextArguments()) {
block.getVariable(argument);
}
block.getVariable(fieldBuilders);
block.invokeVirtual(partialRowConstructor);
}
scope.releaseTempVariableForReuse(blockBuilder);

block.append(invokeStatic(RowConstructorCodeGenerator.class, "createSqlRowFromFieldBuildersForSingleRow", SqlRow.class, fieldBuilders));
scope.releaseTempVariableForReuse(fieldBuilders);
block.append(context.wasNull().set(constantFalse()));
return block;
}

private MethodDefinition generatePartialRowConstructor(int start, int end, BytecodeGeneratorContext parentContext)
{
ClassDefinition classDefinition = parentContext.getClassDefinition();
CallSiteBinder binder = parentContext.getCallSiteBinder();

Parameter fieldBuilders = arg("fieldBuilders", BlockBuilder[].class);

List<Parameter> parameters = new ArrayList<>(parentContext.getContextArguments());
parameters.add(fieldBuilders);

MethodDefinition methodDefinition = classDefinition.declareMethod(
a(PUBLIC),
"partialRowConstructor" + System.identityHashCode(this) + "_" + start,
type(void.class),
parameters);

Scope scope = methodDefinition.getScope();
BytecodeBlock block = methodDefinition.getBody();
scope.declareVariable("wasNull", block, constantFalse());

BytecodeGeneratorContext context = new BytecodeGeneratorContext(parentContext.getRowExpressionCompiler(), scope, binder, parentContext.getCachedInstanceBinder(), parentContext.getFunctionManager(), classDefinition, parentContext.getContextArguments());
Variable blockBuilder = scope.getOrCreateTempVariable(BlockBuilder.class);
List<Type> types = rowType.getTypeParameters();
for (int i = start; i < end; i++) {
Type fieldType = types.get(i);

block.append(blockBuilder.set(fieldBuilders.getElement(constantInt(i))));

block.comment("Clean wasNull and Generate + " + i + "-th field of row");

block.append(context.wasNull().set(constantFalse()));
block.append(context.generate(arguments.get(i)));
Variable field = scope.getOrCreateTempVariable(fieldType.getJavaType());
Expand All @@ -131,12 +181,9 @@ private BytecodeNode generateExpressionForLargeRows(BytecodeGeneratorContext con
.ifFalse(constantType(binder, fieldType).writeValue(blockBuilder, field).pop()));
scope.releaseTempVariableForReuse(field);
}
scope.releaseTempVariableForReuse(blockBuilder);

block.append(invokeStatic(RowConstructorCodeGenerator.class, "createSqlRowFromFieldBuildersForSingleRow", SqlRow.class, fieldBuilders));
scope.releaseTempVariableForReuse(fieldBuilders);
block.append(context.wasNull().set(constantFalse()));
return block;
block.ret();
return methodDefinition;
}

@UsedByGeneratedCode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.trino.metadata.FunctionManager;
Expand All @@ -31,6 +33,7 @@
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.VariableReferenceExpression;

import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand All @@ -45,24 +48,30 @@

public class RowExpressionCompiler
{
private final ClassDefinition classDefinition;
private final CallSiteBinder callSiteBinder;
private final CachedInstanceBinder cachedInstanceBinder;
private final RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler;
private final FunctionManager functionManager;
private final Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap;
private final List<Parameter> contextArguments; // arguments that need to be propagates to generated methods

public RowExpressionCompiler(
ClassDefinition classDefinition,
CallSiteBinder callSiteBinder,
CachedInstanceBinder cachedInstanceBinder,
RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler,
FunctionManager functionManager,
Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap)
Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap,
List<Parameter> contextArguments)
{
this.classDefinition = classDefinition;
this.callSiteBinder = callSiteBinder;
this.cachedInstanceBinder = cachedInstanceBinder;
this.fieldReferenceCompiler = fieldReferenceCompiler;
this.functionManager = functionManager;
this.compiledLambdaMap = compiledLambdaMap;
this.contextArguments = ImmutableList.copyOf(contextArguments);
}

public BytecodeNode compile(RowExpression rowExpression, Scope scope)
Expand All @@ -86,7 +95,9 @@ public BytecodeNode visitCall(CallExpression call, Context context)
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
functionManager);
functionManager,
classDefinition,
contextArguments);

return generatorContext.generateFullCall(call.resolvedFunction(), call.arguments());
}
Expand Down Expand Up @@ -116,7 +127,9 @@ public BytecodeNode visitSpecialForm(SpecialForm specialForm, Context context)
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
functionManager);
functionManager,
classDefinition,
contextArguments);

return generator.generateExpression(generatorContext);
}
Expand Down Expand Up @@ -178,7 +191,9 @@ public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context conte
context.getScope(),
callSiteBinder,
cachedInstanceBinder,
functionManager);
functionManager,
classDefinition,
contextArguments);

return generateLambda(
generatorContext,
Expand Down
Loading