From 1638d2ebde1e7a20c13bfaf1f84990b5a4abfb97 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Thu, 14 Sep 2023 23:13:45 -0700 Subject: [PATCH] Add FlatHashStrategyCompiler --- .../main/java/io/trino/operator/FlatHash.java | 4 +- .../io/trino/operator/FlatHashStrategy.java | 316 +------------- .../operator/FlatHashStrategyCompiler.java | 400 ++++++++++++++++++ .../java/io/trino/sql/gen/JoinCompiler.java | 5 +- 4 files changed, 414 insertions(+), 311 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHash.java b/core/trino-main/src/main/java/io/trino/operator/FlatHash.java index 05e8ae38d0c3..51f25addc2f0 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHash.java @@ -164,13 +164,11 @@ public void appendTo(int groupId, BlockBuilder[] blockBuilders) int recordOffset = getRecordOffset(index); byte[] variableWidthChunk = EMPTY_CHUNK; - int variableWidthOffset = 0; if (variableWidthData != null) { variableWidthChunk = variableWidthData.getChunk(records, recordOffset); - variableWidthOffset = VariableWidthData.getChunkOffset(records, recordOffset); } - flatHashStrategy.readFlat(records, recordOffset + recordValueOffset, variableWidthChunk, variableWidthOffset, blockBuilders); + flatHashStrategy.readFlat(records, recordOffset + recordValueOffset, variableWidthChunk, blockBuilders); if (hasPrecomputedHash) { BIGINT.writeLong(blockBuilders[blockBuilders.length - 1], (long) LONG_HANDLE.get(records, recordOffset + recordHashOffset)); } diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategy.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategy.java index 3d6b66616fa2..c4ddeee500af 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategy.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategy.java @@ -13,325 +13,29 @@ */ package io.trino.operator; -import com.google.common.collect.ImmutableList; -import io.trino.operator.scalar.CombineHashFunction; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.type.Type; -import io.trino.spi.type.TypeOperators; -import java.lang.invoke.MethodHandle; -import java.lang.invoke.MethodHandles; -import java.util.List; - -import static com.google.common.base.Throwables.throwIfUnchecked; -import static io.trino.operator.VariableWidthData.EMPTY_CHUNK; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; -import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; -import static io.trino.spi.function.InvocationConvention.simpleConvention; -import static io.trino.spi.type.TypeUtils.NULL_HASH_CODE; -import static io.trino.sql.planner.optimizations.HashGenerationOptimizer.INITIAL_HASH_VALUE; -import static java.lang.Math.toIntExact; -import static java.lang.invoke.MethodHandles.arrayElementGetter; -import static java.lang.invoke.MethodHandles.collectArguments; -import static java.lang.invoke.MethodHandles.constant; -import static java.lang.invoke.MethodHandles.dropArguments; -import static java.lang.invoke.MethodHandles.guardWithTest; -import static java.lang.invoke.MethodHandles.insertArguments; -import static java.lang.invoke.MethodHandles.lookup; -import static java.lang.invoke.MethodHandles.permuteArguments; -import static java.lang.invoke.MethodType.methodType; -import static java.util.Objects.requireNonNull; - -public class FlatHashStrategy +public interface FlatHashStrategy { - private static final MethodHandle READ_FLAT_FIELD_IS_NULL; - private static final MethodHandle READ_FLAT_NULL_FIELD; - private static final MethodHandle WRITE_FLAT_NULL_FIELD; - private static final MethodHandle FLAT_IS_NULL; - private static final MethodHandle BLOCK_IS_NULL; - private static final MethodHandle LOGICAL_OR; - private static final MethodHandle BOOLEAN_NOT_EQUALS; - private static final MethodHandle INTEGER_ADD; - - static { - try { - MethodHandles.Lookup lookup = lookup(); - READ_FLAT_FIELD_IS_NULL = lookup.findStatic(FlatHashStrategy.class, "readFlatFieldIsNull", methodType(boolean.class, int.class, byte[].class, int.class)); - READ_FLAT_NULL_FIELD = lookup.findVirtual(BlockBuilder.class, "appendNull", methodType(BlockBuilder.class)).asType(methodType(void.class, BlockBuilder.class)); - WRITE_FLAT_NULL_FIELD = dropArguments( - dropArguments( - lookup.findStatic(FlatHashStrategy.class, "writeFieldNull", methodType(void.class, int.class, byte[].class, int.class)), - 3, - byte[].class, - int.class), - 1, - Block.class, - int.class); - FLAT_IS_NULL = lookup.findStatic(FlatHashStrategy.class, "flatIsNull", methodType(boolean.class, int.class, byte[].class, int.class)); - BLOCK_IS_NULL = lookup.findVirtual(Block.class, "isNull", methodType(boolean.class, int.class)); - LOGICAL_OR = lookup.findStatic(Boolean.class, "logicalOr", methodType(boolean.class, boolean.class, boolean.class)); - BOOLEAN_NOT_EQUALS = lookup.findStatic(FlatHashStrategy.class, "booleanNotEquals", methodType(boolean.class, boolean.class, boolean.class)); - INTEGER_ADD = lookup.findStatic(FlatHashStrategy.class, "integerAdd", methodType(int.class, int.class, int.class)); - } - catch (ReflectiveOperationException e) { - throw new RuntimeException(e); - } - } - - private final List types; - private final boolean anyVariableWidth; - private final int totalFlatFixedLength; - private final List readFlatMethods; - private final List writeFlatMethods; - private final List hashFlatMethods; - private final List hashBlockMethods; - private final List distinctFlatBlockMethods; - - public FlatHashStrategy(List types, TypeOperators typeOperators) - { - this.types = ImmutableList.copyOf(requireNonNull(types, "types is null")); - ImmutableList.Builder readFlatMethods = ImmutableList.builder(); - ImmutableList.Builder writeFlatMethods = ImmutableList.builder(); - ImmutableList.Builder distinctFlatBlockMethods = ImmutableList.builder(); - ImmutableList.Builder hashFlatMethods = ImmutableList.builder(); - ImmutableList.Builder hashBlockMethods = ImmutableList.builder(); - - try { - MethodHandle readFlatNullField = dropArguments(READ_FLAT_NULL_FIELD, 0, byte[].class, int.class, byte[].class); - - int[] fieldIsNullOffsets = new int[types.size()]; - int[] fieldFixedOffsets = new int[types.size()]; - - int variableWidthCount = (int) types.stream().filter(Type::isFlatVariableWidth).count(); - - int fixedOffset = 0; - for (int i = 0; i < types.size(); i++) { - Type type = types.get(i); - fieldIsNullOffsets[i] = fixedOffset; - fixedOffset++; - fieldFixedOffsets[i] = fixedOffset; - fixedOffset += type.getFlatFixedSize(); - } - totalFlatFixedLength = fixedOffset; - anyVariableWidth = variableWidthCount > 0; - - for (int i = 0; i < types.size(); i++) { - Type type = types.get(i); - - MethodHandle readFlat = typeOperators.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, FLAT)); - readFlat = toAbsoluteFlatArgument(type, readFlat, 0, fieldFixedOffsets[i]); - readFlat = guardWithTest( - insertArguments(READ_FLAT_FIELD_IS_NULL, 0, fieldIsNullOffsets[i]), - readFlatNullField, - readFlat); - readFlat = collectArguments(readFlat, 3, insertArguments(arrayElementGetter(BlockBuilder[].class), 1, i)); - readFlatMethods.add(readFlat); - - MethodHandle writeFlat = typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)); - // add the field fixed offset to the base fixed offset - writeFlat = collectArguments(writeFlat, 3, insertArguments(INTEGER_ADD, 1, fieldFixedOffsets[i])); - writeFlat = guardWithTest( - BLOCK_IS_NULL, - insertArguments(WRITE_FLAT_NULL_FIELD, 0, fieldIsNullOffsets[i]), - writeFlat); - writeFlat = collectArguments(writeFlat, 0, insertArguments(arrayElementGetter(Block[].class), 1, i)); - writeFlatMethods.add(writeFlat); - - MethodHandle distinctFlatBlock = typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)); - distinctFlatBlock = toAbsoluteFlatArgument(type, distinctFlatBlock, 0, fieldFixedOffsets[i]); - distinctFlatBlock = guardWithTest( - LOGICAL_OR, - dropArguments(BOOLEAN_NOT_EQUALS, 2, byte[].class, int.class, byte[].class, Block.class, int.class), - dropArguments(distinctFlatBlock, 0, boolean.class, boolean.class)); - distinctFlatBlock = collectArguments(distinctFlatBlock, 1, BLOCK_IS_NULL); - distinctFlatBlock = collectArguments(distinctFlatBlock, 0, insertArguments(FLAT_IS_NULL, 0, fieldIsNullOffsets[i])); - distinctFlatBlock = permuteArguments( - distinctFlatBlock, - methodType(boolean.class, byte[].class, int.class, byte[].class, Block.class, int.class), - 0, 1, 3, 4, 0, 1, 2, 3, 4); - distinctFlatBlock = collectArguments(distinctFlatBlock, 3, insertArguments(arrayElementGetter(Block[].class), 1, i)); - distinctFlatBlockMethods.add(distinctFlatBlock); + boolean isAnyVariableWidth(); - MethodHandle hashFlat = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)); - hashFlat = toAbsoluteFlatArgument(type, hashFlat, 0, fieldFixedOffsets[i]); - hashFlat = guardWithTest( - insertArguments(FLAT_IS_NULL, 0, fieldIsNullOffsets[i]), - dropArguments(constant(long.class, NULL_HASH_CODE), 0, byte[].class, int.class, byte[].class), - hashFlat); - hashFlatMethods.add(hashFlat); + int getTotalFlatFixedLength(); - MethodHandle hashBlock = typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)); - hashBlock = guardWithTest( - BLOCK_IS_NULL, - dropArguments(constant(long.class, NULL_HASH_CODE), 0, Block.class, int.class), - hashBlock); - hashBlock = collectArguments(hashBlock, 0, insertArguments(arrayElementGetter(Block[].class), 1, i)); - hashBlockMethods.add(hashBlock); - } - } - catch (ReflectiveOperationException e) { - throw new RuntimeException(e); - } + int getTotalVariableWidth(Block[] blocks, int position); - this.readFlatMethods = readFlatMethods.build(); - this.writeFlatMethods = writeFlatMethods.build(); - this.distinctFlatBlockMethods = distinctFlatBlockMethods.build(); - this.hashFlatMethods = hashFlatMethods.build(); - this.hashBlockMethods = hashBlockMethods.build(); - } + void readFlat(byte[] fixedChunk, int fixedOffset, byte[] variableChunk, BlockBuilder[] blockBuilders); - public boolean isAnyVariableWidth() - { - return anyVariableWidth; - } + void writeFlat(Block[] blocks, int position, byte[] fixedChunk, int fixedOffset, byte[] variableChunk, int variableOffset); - public int getTotalFlatFixedLength() - { - return totalFlatFixedLength; - } - - public int getTotalVariableWidth(Block[] blocks, int position) - { - if (!anyVariableWidth) { - return 0; - } - - long variableWidth = 0; - for (int i = 0; i < types.size(); i++) { - Type type = types.get(i); - Block block = blocks[i]; - - if (type.isFlatVariableWidth()) { - variableWidth += type.getFlatVariableWidthSize(block, position); - } - } - return toIntExact(variableWidth); - } - - public void readFlat(byte[] fixedChunk, int fixedOffset, byte[] variableChunk, int variableOffset, BlockBuilder[] blockBuilders) - { - try { - for (MethodHandle readFlatMethod : readFlatMethods) { - readFlatMethod.invokeExact(fixedChunk, fixedOffset, variableChunk, blockBuilders); - } - } - catch (Throwable throwable) { - throwIfUnchecked(throwable); - throw new RuntimeException(throwable); - } - } - - public void writeFlat(Block[] blocks, int position, byte[] fixedChunk, int fixedOffset, byte[] variableChunk, int variableOffset) - { - try { - for (int i = 0, writeFlatMethodsSize = writeFlatMethods.size(); i < writeFlatMethodsSize; i++) { - writeFlatMethods.get(i).invokeExact(blocks, position, fixedChunk, fixedOffset, variableChunk, variableOffset); - Type type = types.get(i); - if (type.isFlatVariableWidth() && !blocks[i].isNull(position)) { - variableOffset += type.getFlatVariableWidthSize(blocks[i], position); - } - } - } - catch (Throwable throwable) { - throwIfUnchecked(throwable); - throw new RuntimeException(throwable); - } - } - - public boolean valueNotDistinctFrom( + boolean valueNotDistinctFrom( byte[] leftFixedChunk, int leftFixedOffset, byte[] leftVariableChunk, Block[] rightBlocks, - int rightPosition) - { - try { - for (MethodHandle distinctFlatBlockMethod : distinctFlatBlockMethods) { - if ((boolean) distinctFlatBlockMethod.invokeExact(leftFixedChunk, leftFixedOffset, leftVariableChunk, rightBlocks, rightPosition)) { - return false; - } - } - return true; - } - catch (Throwable throwable) { - throwIfUnchecked(throwable); - throw new RuntimeException(throwable); - } - } - - public long hash(Block[] blocks, int position) - { - try { - long result = INITIAL_HASH_VALUE; - for (MethodHandle hashBlockMethod : hashBlockMethods) { - result = CombineHashFunction.getHash(result, (long) hashBlockMethod.invokeExact(blocks, position)); - } - return result; - } - catch (Throwable throwable) { - throwIfUnchecked(throwable); - throw new RuntimeException(throwable); - } - } - - public long hash(byte[] fixedChunk, int fixedOffset, byte[] variableChunk) - { - try { - long result = INITIAL_HASH_VALUE; - for (MethodHandle hashFlat : hashFlatMethods) { - result = CombineHashFunction.getHash(result, (long) hashFlat.invokeExact(fixedChunk, fixedOffset, variableChunk)); - } - return result; - } - catch (Throwable throwable) { - throwIfUnchecked(throwable); - throw new RuntimeException(throwable); - } - } - - private static MethodHandle toAbsoluteFlatArgument(Type type, MethodHandle methodHandle, int argument, int fixedPosition) - throws ReflectiveOperationException - { - // offset the fixed position by the field offset - methodHandle = collectArguments(methodHandle, argument + 1, insertArguments(INTEGER_ADD, 1, fixedPosition)); - - // for fixed types, hard code a reference to the empty slice - if (!type.isFlatVariableWidth()) { - methodHandle = insertArguments(methodHandle, argument + 2, (Object) EMPTY_CHUNK); - methodHandle = dropArguments(methodHandle, argument + 2, byte[].class); - } - return methodHandle; - } - - private static boolean readFlatFieldIsNull(int fieldNullOffset, byte[] fixedChunk, int fixedOffset) - { - return fixedChunk[fixedOffset + fieldNullOffset] != 0; - } - - private static void writeFieldNull(int fieldNullOffset, byte[] fixedChunk, int fixedOffset) - { - fixedChunk[fixedOffset + fieldNullOffset] = 1; - } - - private static boolean flatIsNull( - int fieldNullOffset, - byte[] fixedChunk, - int fixedOffset) - { - return fixedChunk[fixedOffset + fieldNullOffset] != 0; - } + int rightPosition); - private static boolean booleanNotEquals(boolean left, boolean right) - { - return left != right; - } + long hash(Block[] blocks, int position); - private static int integerAdd(int left, int right) - { - return left + right; - } + long hash(byte[] fixedChunk, int fixedOffset, byte[] variableChunk); } diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java new file mode 100644 index 000000000000..6097ea2add20 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java @@ -0,0 +1,400 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.collect.ImmutableList; +import io.airlift.bytecode.BytecodeBlock; +import io.airlift.bytecode.ClassDefinition; +import io.airlift.bytecode.FieldDefinition; +import io.airlift.bytecode.MethodDefinition; +import io.airlift.bytecode.Parameter; +import io.airlift.bytecode.Scope; +import io.airlift.bytecode.Variable; +import io.airlift.bytecode.control.IfStatement; +import io.trino.operator.scalar.CombineHashFunction; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; +import io.trino.sql.gen.CallSiteBinder; + +import java.lang.invoke.MethodHandle; +import java.util.ArrayList; +import java.util.List; + +import static io.airlift.bytecode.Access.FINAL; +import static io.airlift.bytecode.Access.PRIVATE; +import static io.airlift.bytecode.Access.PUBLIC; +import static io.airlift.bytecode.Access.STATIC; +import static io.airlift.bytecode.Access.a; +import static io.airlift.bytecode.Parameter.arg; +import static io.airlift.bytecode.ParameterizedType.type; +import static io.airlift.bytecode.expression.BytecodeExpressions.add; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantBoolean; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantFalse; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; +import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; +import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; +import static io.airlift.bytecode.expression.BytecodeExpressions.not; +import static io.airlift.bytecode.expression.BytecodeExpressions.notEqual; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.BLOCK_BUILDER; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; +import static io.trino.spi.function.InvocationConvention.simpleConvention; +import static io.trino.spi.type.TypeUtils.NULL_HASH_CODE; +import static io.trino.sql.gen.Bootstrap.BOOTSTRAP_METHOD; +import static io.trino.sql.gen.BytecodeUtils.loadConstant; +import static io.trino.sql.gen.SqlTypeBytecodeExpression.constantType; +import static io.trino.sql.planner.optimizations.HashGenerationOptimizer.INITIAL_HASH_VALUE; +import static io.trino.util.CompilerUtils.defineClass; +import static io.trino.util.CompilerUtils.makeClassName; + +public final class FlatHashStrategyCompiler +{ + private FlatHashStrategyCompiler() {} + + public static FlatHashStrategy compileFlatHashStrategy(List types, TypeOperators typeOperators) + { + boolean anyVariableWidth = (int) types.stream().filter(Type::isFlatVariableWidth).count() > 0; + + List keyFields = new ArrayList<>(); + int fixedOffset = 0; + for (int i = 0; i < types.size(); i++) { + Type type = types.get(i); + keyFields.add(new KeyField( + i, + type, + fixedOffset, + fixedOffset + 1, + typeOperators.getReadValueOperator(type, simpleConvention(BLOCK_BUILDER, FLAT)), + typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)), + typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)), + typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)), + typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL)))); + fixedOffset += 1 + type.getFlatFixedSize(); + } + + CallSiteBinder callSiteBinder = new CallSiteBinder(); + ClassDefinition definition = new ClassDefinition( + a(PUBLIC, FINAL), + makeClassName("FlatHashStrategy"), + type(Object.class), + type(FlatHashStrategy.class)); + + // the 'types' field is not used, but it makes debugging easier + // this is an instance field because a static field doesn't seem to show up in the IntelliJ debugger + FieldDefinition typesField = definition.declareField(a(PRIVATE, FINAL), "types", type(List.class, Type.class)); + MethodDefinition constructor = definition.declareConstructor(a(PUBLIC)); + constructor + .getBody() + .append(constructor.getThis()) + .invokeConstructor(Object.class) + .append(constructor.getThis().setField(typesField, loadConstant(callSiteBinder, ImmutableList.copyOf(types), List.class))) + .ret(); + + definition.declareMethod(a(PUBLIC), "isAnyVariableWidth", type(boolean.class)).getBody() + .append(constantBoolean(anyVariableWidth).ret()); + + definition.declareMethod(a(PUBLIC), "getTotalFlatFixedLength", type(int.class)).getBody() + .append(constantInt(fixedOffset).ret()); + + generateGetTotalVariableWidth(definition, keyFields, callSiteBinder); + + generateReadFlat(definition, keyFields, callSiteBinder); + generateWriteFlat(definition, keyFields, callSiteBinder); + generateNotDistinctFromMethod(definition, keyFields, callSiteBinder); + generateHashBlock(definition, keyFields, callSiteBinder); + generateHashFlat(definition, keyFields, callSiteBinder); + + try { + return defineClass(definition, FlatHashStrategy.class, callSiteBinder.getBindings(), FlatHashStrategyCompiler.class.getClassLoader()) + .getConstructor() + .newInstance(); + } + catch (ReflectiveOperationException e) { + throw new RuntimeException(e); + } + } + + private static void generateGetTotalVariableWidth(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter blocks = arg("blocks", type(Block[].class)); + Parameter position = arg("position", type(int.class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "getTotalVariableWidth", + type(int.class), + blocks, + position); + BytecodeBlock body = methodDefinition.getBody(); + + Scope scope = methodDefinition.getScope(); + Variable variableWidth = scope.declareVariable("variableWidth", body, constantLong(0)); + + for (KeyField keyField : keyFields) { + Type type = keyField.type(); + if (type.isFlatVariableWidth()) { + body.append(new IfStatement() + .condition(not(blocks.getElement(keyField.index()).invoke("isNull", boolean.class, position))) + .ifTrue(variableWidth.set(add( + variableWidth, + constantType(callSiteBinder, type).invoke("getFlatVariableWidthSize", int.class, blocks.getElement(keyField.index()), position).cast(long.class))))); + } + } + body.append(invokeStatic(Math.class, "toIntExact", int.class, variableWidth).ret()); + } + + private static void generateReadFlat(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter fixedChunk = arg("fixedChunk", type(byte[].class)); + Parameter fixedOffset = arg("fixedOffset", type(int.class)); + Parameter variableChunk = arg("variableChunk", type(byte[].class)); + Parameter blockBuilders = arg("blockBuilders", type(BlockBuilder[].class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "readFlat", + type(void.class), + fixedChunk, + fixedOffset, + variableChunk, + blockBuilders); + BytecodeBlock body = methodDefinition.getBody(); + + for (KeyField keyField : keyFields) { + body.append(new IfStatement() + .condition(notEqual(fixedChunk.getElement(add(fixedOffset, constantInt(keyField.fieldIsNullOffset()))).cast(int.class), constantInt(0))) + .ifTrue(blockBuilders.getElement(keyField.index()).invoke("appendNull", BlockBuilder.class).pop()) + .ifFalse(new BytecodeBlock() + .append(invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(keyField.readFlatMethod()).getBindingId()), + "readFlat", + void.class, + fixedChunk, + add(fixedOffset, constantInt(keyField.fieldFixedOffset())), + variableChunk, + blockBuilders.getElement(keyField.index()))))); + } + body.ret(); + } + + private static void generateWriteFlat(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter blocks = arg("blocks", type(Block[].class)); + Parameter position = arg("position", type(int.class)); + Parameter fixedChunk = arg("fixedChunk", type(byte[].class)); + Parameter fixedOffset = arg("fixedOffset", type(int.class)); + Parameter variableChunk = arg("variableChunk", type(byte[].class)); + Parameter variableOffset = arg("variableOffset", type(int.class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "writeFlat", + type(void.class), + blocks, + position, + fixedChunk, + fixedOffset, + variableChunk, + variableOffset); + BytecodeBlock body = methodDefinition.getBody(); + for (KeyField keyField : keyFields) { + BytecodeBlock writeNonNullFlat = new BytecodeBlock() + .append(invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(keyField.writeFlatMethod()).getBindingId()), + "writeFlat", + void.class, + blocks.getElement(keyField.index()), + position, + fixedChunk, + add(fixedOffset, constantInt(keyField.fieldFixedOffset())), + variableChunk, + variableOffset)); + if (keyField.type().isFlatVariableWidth()) { + // variableOffset += type.getFlatVariableWidthSize(blocks[i], position); + writeNonNullFlat.append(variableOffset.set(add(variableOffset, constantType(callSiteBinder, keyField.type()).invoke( + "getFlatVariableWidthSize", + int.class, + blocks.getElement(keyField.index()), + position)))); + } + body.append(new IfStatement() + .condition(blocks.getElement(keyField.index()).invoke("isNull", boolean.class, position)) + .ifTrue(fixedChunk.setElement(add(fixedOffset, constantInt(keyField.fieldIsNullOffset())), constantInt(1).cast(byte.class))) + .ifFalse(writeNonNullFlat)); + } + body.ret(); + } + + private static void generateNotDistinctFromMethod(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter leftFixedChunk = arg("leftFixedChunk", type(byte[].class)); + Parameter leftFixedOffset = arg("leftFixedOffset", type(int.class)); + Parameter leftVariableChunk = arg("leftVariableChunk", type(byte[].class)); + Parameter rightBlocks = arg("rightBlocks", type(Block[].class)); + Parameter rightPosition = arg("rightPosition", type(int.class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "valueNotDistinctFrom", + type(boolean.class), + leftFixedChunk, + leftFixedOffset, + leftVariableChunk, + rightBlocks, + rightPosition); + BytecodeBlock body = methodDefinition.getBody(); + + for (KeyField keyField : keyFields) { + MethodDefinition distinctFromMethod = generateDistinctFromMethod(definition, keyField, callSiteBinder); + body.append(new IfStatement() + .condition(invokeStatic(distinctFromMethod, leftFixedChunk, leftFixedOffset, leftVariableChunk, rightBlocks.getElement(keyField.index()), rightPosition)) + .ifTrue(constantFalse().ret())); + } + body.append(constantTrue().ret()); + } + + private static MethodDefinition generateDistinctFromMethod(ClassDefinition definition, KeyField keyField, CallSiteBinder callSiteBinder) + { + Parameter leftFixedChunk = arg("leftFixedChunk", type(byte[].class)); + Parameter leftFixedOffset = arg("leftFixedOffset", type(int.class)); + Parameter leftVariableChunk = arg("leftVariableChunk", type(byte[].class)); + Parameter rightBlock = arg("rightBlock", type(Block.class)); + Parameter rightPosition = arg("rightPosition", type(int.class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC, STATIC), + "valueDistinctFrom" + keyField.index(), + type(boolean.class), + leftFixedChunk, + leftFixedOffset, + leftVariableChunk, + rightBlock, + rightPosition); + BytecodeBlock body = methodDefinition.getBody(); + Scope scope = methodDefinition.getScope(); + + Variable leftIsNull = scope.declareVariable("leftIsNull", body, notEqual(leftFixedChunk.getElement(add(leftFixedOffset, constantInt(keyField.fieldIsNullOffset()))).cast(int.class), constantInt(0))); + Variable rightIsNull = scope.declareVariable("rightIsNull", body, rightBlock.invoke("isNull", boolean.class, rightPosition)); + + // if (leftIsNull) { + // return !rightIsNull; + // } + body.append(new IfStatement() + .condition(leftIsNull) + .ifTrue(not(rightIsNull).ret())); + + // if (rightIsNull) { + // return true; + // } + body.append(new IfStatement() + .condition(rightIsNull) + .ifTrue(constantTrue().ret())); + + body.append(invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(keyField.distinctFlatBlockMethod()).getBindingId()), + "distinctFrom", + boolean.class, + leftFixedChunk, + add(leftFixedOffset, constantInt(keyField.fieldFixedOffset())), + leftVariableChunk, + rightBlock, + rightPosition) + .ret()); + return methodDefinition; + } + + private static void generateHashBlock(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter blocks = arg("blocks", type(Block[].class)); + Parameter position = arg("position", type(int.class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "hash", + type(long.class), + blocks, + position); + BytecodeBlock body = methodDefinition.getBody(); + + Scope scope = methodDefinition.getScope(); + Variable result = scope.declareVariable("result", body, constantLong(INITIAL_HASH_VALUE)); + Variable hash = scope.declareVariable(long.class, "hash"); + Variable block = scope.declareVariable(Block.class, "block"); + + for (KeyField keyField : keyFields) { + body.append(block.set(blocks.getElement(keyField.index()))); + body.append(new IfStatement() + .condition(block.invoke("isNull", boolean.class, position)) + .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) + .ifFalse(hash.set(invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(keyField.hashBlockMethod()).getBindingId()), + "hash", + long.class, + block, + position)))); + body.append(result.set(invokeStatic(CombineHashFunction.class, "getHash", long.class, result, hash))); + } + body.append(result.ret()); + } + + private static void generateHashFlat(ClassDefinition definition, List keyFields, CallSiteBinder callSiteBinder) + { + Parameter fixedChunk = arg("fixedChunk", type(byte[].class)); + Parameter fixedOffset = arg("fixedOffset", type(int.class)); + Parameter variableChunk = arg("variableChunk", type(byte[].class)); + MethodDefinition methodDefinition = definition.declareMethod( + a(PUBLIC), + "hash", + type(long.class), + fixedChunk, + fixedOffset, + variableChunk); + BytecodeBlock body = methodDefinition.getBody(); + + Scope scope = methodDefinition.getScope(); + Variable result = scope.declareVariable("result", body, constantLong(INITIAL_HASH_VALUE)); + Variable hash = scope.declareVariable(long.class, "hash"); + + for (KeyField keyField : keyFields) { + body.append(new IfStatement() + .condition(notEqual(fixedChunk.getElement(add(fixedOffset, constantInt(keyField.fieldIsNullOffset()))).cast(int.class), constantInt(0))) + .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) + .ifFalse(hash.set(invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(keyField.hashFlatMethod()).getBindingId()), + "hash", + long.class, + fixedChunk, + add(fixedOffset, constantInt(keyField.fieldFixedOffset())), + variableChunk)))); + body.append(result.set(invokeStatic(CombineHashFunction.class, "getHash", long.class, result, hash))); + } + body.append(result.ret()); + } + + private record KeyField( + int index, + Type type, + int fieldIsNullOffset, + int fieldFixedOffset, + MethodHandle readFlatMethod, + MethodHandle writeFlatMethod, + MethodHandle distinctFlatBlockMethod, + MethodHandle hashFlatMethod, + MethodHandle hashBlockMethod) {} +} diff --git a/core/trino-main/src/main/java/io/trino/sql/gen/JoinCompiler.java b/core/trino-main/src/main/java/io/trino/sql/gen/JoinCompiler.java index 0bd3ce3a3061..12030ab0de59 100644 --- a/core/trino-main/src/main/java/io/trino/sql/gen/JoinCompiler.java +++ b/core/trino-main/src/main/java/io/trino/sql/gen/JoinCompiler.java @@ -89,6 +89,7 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.notEqual; import static io.airlift.bytecode.expression.BytecodeExpressions.setStatic; import static io.trino.cache.SafeCaches.buildNonEvictableCache; +import static io.trino.operator.FlatHashStrategyCompiler.compileFlatHashStrategy; import static io.trino.operator.join.JoinUtils.getSingleBigintJoinChannel; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; @@ -138,7 +139,7 @@ public JoinCompiler(TypeOperators typeOperators, boolean enableSingleChannelBigi CacheBuilder.newBuilder() .recordStats() .maximumSize(1000), - CacheLoader.from(key -> new FlatHashStrategy(key, typeOperators))); + CacheLoader.from(key -> compileFlatHashStrategy(key, typeOperators))); } @Managed @@ -158,7 +159,7 @@ public CacheStatsMBean getHashStrategiesStats() // This should be in a separate cache, but it is convenient during the transition to keep this in the join compiler public FlatHashStrategy getFlatHashStrategy(List types) { - return flatHashStrategies.getUnchecked(types); + return flatHashStrategies.getUnchecked(ImmutableList.copyOf(types)); } public LookupSourceSupplierFactory compileLookupSourceFactory(List types, List joinChannels, Optional sortChannel, Optional> outputChannels)