Fix table function execution without partitioning (v2) #21558

Open · wants to merge 3 commits into master
@@ -38,7 +38,7 @@ public final class WorkProcessorUtils
{
private WorkProcessorUtils() {}

static <T> Iterator<T> iteratorFrom(WorkProcessor<T> processor)
public static <T> Iterator<T> iteratorFrom(WorkProcessor<T> processor)
{
requireNonNull(processor, "processor is null");
return new AbstractIterator<>()
@@ -58,7 +58,7 @@ protected T computeNext()
};
}

static <T> Iterator<Optional<T>> yieldingIteratorFrom(WorkProcessor<T> processor)
public static <T> Iterator<Optional<T>> yieldingIteratorFrom(WorkProcessor<T> processor)
{
return new YieldingIterator<>(processor);
}
@@ -95,7 +95,7 @@ protected Optional<T> computeNext()
}
}

static <T> WorkProcessor<T> fromIterator(Iterator<T> iterator)
public static <T> WorkProcessor<T> fromIterator(Iterator<T> iterator)
{
requireNonNull(iterator, "iterator is null");
return create(() -> {
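For context, a minimal usage sketch of the helpers made public above. It is not part of this diff; it assumes only the signatures visible in these hunks (fromIterator wrapping an Iterator into a WorkProcessor, yieldingIteratorFrom returning an empty Optional when the processor yields) plus a hypothetical processPage consumer.

import io.trino.operator.WorkProcessor;
import io.trino.operator.WorkProcessorUtils;
import io.trino.spi.Page;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;

class WorkProcessorUtilsUsageSketch
{
    // Hypothetical round trip: wrap an Iterator<Page> into a WorkProcessor, then drain it
    // through the yielding iterator, where an empty Optional signals that the processor
    // yielded and the caller should hand back control before retrying.
    static void drain(List<Page> pages, Consumer<Page> processPage)
    {
        WorkProcessor<Page> processor = WorkProcessorUtils.fromIterator(pages.iterator());
        Iterator<Optional<Page>> yielding = WorkProcessorUtils.yieldingIteratorFrom(processor);
        while (yielding.hasNext()) {
            Optional<Page> next = yielding.next();
            if (next.isEmpty()) {
                continue; // processor yielded; real callers would return instead of spinning
            }
            processPage.accept(next.get());
        }
    }
}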
@@ -41,7 +41,7 @@
* a table function having KEEP WHEN EMPTY property must have single distribution.
*/
public class EmptyTableFunctionPartition
implements TableFunctionPartition
implements TableFunctionInput
{
private final TableFunctionDataProcessor tableFunction;
private final int properChannelsCount;
@@ -52,7 +52,7 @@
import static java.util.Objects.requireNonNull;

public class RegularTableFunctionPartition
implements TableFunctionPartition
implements TableFunctionInput
{
private final PagesIndex pagesIndex;
private final int partitionStart;
@@ -0,0 +1,125 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.function;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.operator.PagesIndex;
import io.trino.operator.WorkProcessor;
import io.trino.operator.WorkProcessor.ProcessState;
import io.trino.spi.Page;
import io.trino.spi.function.table.TableFunctionDataProcessor;
import io.trino.spi.function.table.TableFunctionProcessorState;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.concurrent.MoreFutures.toListenableFuture;
import static io.trino.operator.WorkProcessorUtils.yieldingIteratorFrom;
import static java.util.Objects.requireNonNull;

public class StreamTableFunctionInput
implements WorkProcessor.Process<WorkProcessor<Page>>
{
private final TableFunctionDataProcessor tableFunction;
private final int properChannelsCount;
private final int passThroughSourcesCount;
private final List<List<Integer>> requiredChannels;
private final Optional<Map<Integer, Integer>> markerChannels;
private final List<RegularTableFunctionPartition.PassThroughColumnSpecification> passThroughSpecifications;

private final Iterator<Optional<Page>> inputPages;
private final PagesIndex pagesIndex;
private boolean finished;

public StreamTableFunctionInput(
TableFunctionDataProcessor tableFunction,
int properChannelsCount,
int passThroughSourcesCount,
List<List<Integer>> requiredChannels,
Optional<Map<Integer, Integer>> markerChannels,
List<RegularTableFunctionPartition.PassThroughColumnSpecification> passThroughSpecifications,
WorkProcessor<Page> inputPages,
PagesIndex pagesIndex)
{
this.tableFunction = requireNonNull(tableFunction, "tableFunction is null");
this.properChannelsCount = properChannelsCount;
this.passThroughSourcesCount = passThroughSourcesCount;
this.requiredChannels = requiredChannels.stream()
.map(ImmutableList::copyOf)
.collect(toImmutableList());
this.markerChannels = markerChannels.map(ImmutableMap::copyOf);
this.passThroughSpecifications = ImmutableList.copyOf(requireNonNull(passThroughSpecifications, "passThroughSpecifications is null"));
this.inputPages = yieldingIteratorFrom(requireNonNull(inputPages, "inputPages is null"));
this.pagesIndex = requireNonNull(pagesIndex, "pagesIndex is null");
}

@Override
public ProcessState<WorkProcessor<Page>> process()
{
if (finished) {
return ProcessState.finished();
}

if (inputPages.hasNext()) {
Optional<Page> next = inputPages.next();
if (next.isEmpty()) {
return ProcessState.yielded();
}
Page currentInputPage = next.get();
pagesIndex.clear();
pagesIndex.addPage(currentInputPage);
return ProcessState.ofResult(new RegularTableFunctionPartition(
pagesIndex,
0,
pagesIndex.getPositionCount(),
new TableFunctionDataProcessor()
[Review comment from @tbaeg (Member), May 30, 2024]
Do we need to wrap in a new anonymous class?
Locally, I directly passed the tableFunction, which also resolved the TestJsonTable failures for me.
(A sketch of this suggestion appears after this file's diff.)

{
@Override
public TableFunctionProcessorState process(@org.jetbrains.annotations.Nullable List<Optional<Page>> input)
{
if (input == null) {
// end of page
return TableFunctionProcessorState.Finished.FINISHED;
}
TableFunctionProcessorState process = tableFunction.process(input);
if (process instanceof TableFunctionProcessorState.Finished) {
finished = true;
}
return process;
}
},
properChannelsCount,
passThroughSourcesCount,
requiredChannels,
markerChannels,
passThroughSpecifications)
.toOutputPages());
}

// finish
finished = true;
return ProcessState.ofResult(WorkProcessor.create(() -> {
TableFunctionProcessorState state = tableFunction.process(null);
return switch (state) {
case TableFunctionProcessorState.Processed processed -> ProcessState.ofResult(processed.getResult());
case TableFunctionProcessorState.Blocked blocked -> ProcessState.blocked(toListenableFuture(blocked.getFuture()));
case TableFunctionProcessorState.Finished __ -> ProcessState.finished();
};
}));
}
}
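For reference, here is a sketch of what the review comment above suggests inside process(): hand the delegate processor to RegularTableFunctionPartition directly rather than wrapping it. This reflects the reviewer's local experiment, not code from this PR; note that the anonymous wrapper in the PR is also what flips the finished flag when the delegate reports Finished, so a direct pass-through would need to observe completion some other way.

// Sketch of the review suggestion (assumption, not part of this PR): replace the
// return statement in process() with a direct pass-through of the delegate.
return ProcessState.ofResult(new RegularTableFunctionPartition(
        pagesIndex,
        0,
        pagesIndex.getPositionCount(),
        tableFunction, // instead of the anonymous TableFunctionDataProcessor wrapper
        properChannelsCount,
        passThroughSourcesCount,
        requiredChannels,
        markerChannels,
        passThroughSpecifications)
        .toOutputPages());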
@@ -16,7 +16,7 @@
import io.trino.operator.WorkProcessor;
import io.trino.spi.Page;

public interface TableFunctionPartition
public interface TableFunctionInput
{
WorkProcessor<Page> toOutputPages();
}
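Since the renamed interface has a single abstract method, an implementation can be as small as a lambda. A toy illustration (not from this PR), built only from the helper shown earlier in this diff:

// Toy TableFunctionInput that produces no output pages, for illustration only.
TableFunctionInput noOutput = () -> WorkProcessorUtils.fromIterator(java.util.Collections.<Page>emptyIterator());
WorkProcessor<Page> pages = noOutput.toOutputPages();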
@@ -52,6 +52,7 @@
import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

public class TableFunctionOperator
implements Operator
@@ -93,7 +94,7 @@ public static class TableFunctionOperatorFactory
private final boolean pruneWhenEmpty;

// partitioning channels from all sources
private final List<Integer> partitionChannels;
private final Optional<List<Integer>> partitionChannels;

// subset of partition channels that are already grouped
private final List<Integer> prePartitionedChannels;
@@ -125,7 +126,7 @@ public TableFunctionOperatorFactory(
Optional<Map<Integer, Integer>> markerChannels,
List<PassThroughColumnSpecification> passThroughSpecifications,
boolean pruneWhenEmpty,
List<Integer> partitionChannels,
Optional<List<Integer>> partitionChannels,
List<Integer> prePartitionedChannels,
List<Integer> sortChannels,
List<SortOrder> sortOrders,
@@ -142,12 +143,20 @@ public TableFunctionOperatorFactory(
requireNonNull(passThroughSpecifications, "passThroughSpecifications is null");
requireNonNull(partitionChannels, "partitionChannels is null");
requireNonNull(prePartitionedChannels, "prePartitionedChannels is null");
checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels");
requireNonNull(sortChannels, "sortChannels is null");
requireNonNull(sortOrders, "sortOrders is null");
checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders");
checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels");
checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped");
if (partitionChannels.isPresent()) {
checkArgument(partitionChannels.get().containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels");
checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders");
checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels");
checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels.get())), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped");
}
else {
checkArgument(prePartitionedChannels.isEmpty(), "prePartitionedChannels must be empty when partitionChannels is absent");
checkArgument(sortChannels.isEmpty(), "sortChannels must be empty when partitionChannels is absent");
checkArgument(sortOrders.isEmpty(), "sortOrders must be empty when partitionChannels is absent");
checkArgument(preSortedPrefix == 0, "preSortedPrefix must be zero when partitionChannels is absent");
}
requireNonNull(sourceTypes, "sourceTypes is null");
requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");

@@ -164,7 +173,7 @@ public TableFunctionOperatorFactory(
this.markerChannels = markerChannels.map(ImmutableMap::copyOf);
this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications);
this.pruneWhenEmpty = pruneWhenEmpty;
this.partitionChannels = ImmutableList.copyOf(partitionChannels);
this.partitionChannels = partitionChannels.map(ImmutableList::copyOf);
this.prePartitionedChannels = ImmutableList.copyOf(prePartitionedChannels);
this.sortChannels = ImmutableList.copyOf(sortChannels);
this.sortOrders = ImmutableList.copyOf(sortOrders);
@@ -250,7 +259,7 @@ public TableFunctionOperator(
Optional<Map<Integer, Integer>> markerChannels,
List<PassThroughColumnSpecification> passThroughSpecifications,
boolean pruneWhenEmpty,
List<Integer> partitionChannels,
Optional<List<Integer>> partitionChannels,
List<Integer> prePartitionedChannels,
List<Integer> sortChannels,
List<SortOrder> sortOrders,
@@ -268,12 +277,20 @@ public TableFunctionOperator(
requireNonNull(passThroughSpecifications, "passThroughSpecifications is null");
requireNonNull(partitionChannels, "partitionChannels is null");
requireNonNull(prePartitionedChannels, "prePartitionedChannels is null");
checkArgument(partitionChannels.containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels");
requireNonNull(sortChannels, "sortChannels is null");
requireNonNull(sortOrders, "sortOrders is null");
checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders");
checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels");
checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels)), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped");
if (partitionChannels.isPresent()) {
checkArgument(partitionChannels.get().containsAll(prePartitionedChannels), "prePartitionedChannels must be a subset of partitionChannels");
checkArgument(sortChannels.size() == sortOrders.size(), "The number of sort channels must be equal to the number of sort orders");
checkArgument(preSortedPrefix <= sortChannels.size(), "The number of pre-sorted channels must be lower or equal to the number of sort channels");
checkArgument(preSortedPrefix == 0 || ImmutableSet.copyOf(prePartitionedChannels).equals(ImmutableSet.copyOf(partitionChannels.get())), "preSortedPrefix can only be greater than zero if all partition channels are pre-grouped");
}
else {
checkArgument(prePartitionedChannels.isEmpty(), "prePartitionedChannels must be empty when partitionChannels is absent");
checkArgument(sortChannels.isEmpty(), "sortChannels must be empty when partitionChannels is absent");
checkArgument(sortOrders.isEmpty(), "sortOrders must be empty when partitionChannels is absent");
checkArgument(preSortedPrefix == 0, "preSortedPrefix must be zero when partitionChannels is absent");
}
requireNonNull(sourceTypes, "sourceTypes is null");
requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");

@@ -283,23 +300,36 @@ public TableFunctionOperator(
this.processEmptyInput = !pruneWhenEmpty;

PagesIndex pagesIndex = pagesIndexFactory.newPagesIndex(sourceTypes, expectedPositions);
HashStrategies hashStrategies = new HashStrategies(pagesIndex, partitionChannels, prePartitionedChannels, sortChannels, sortOrders, preSortedPrefix);

this.outputPages = pageBuffer.pages()
.transform(new PartitionAndSort(pagesIndex, hashStrategies, processEmptyInput))
.flatMap(groupPagesIndex -> pagesIndexToTableFunctionPartitions(
groupPagesIndex,
hashStrategies,
tableFunctionProvider,
session,
functionHandle,
properChannelsCount,
passThroughSourcesCount,
requiredChannels,
markerChannels,
passThroughSpecifications,
processEmptyInput))
.flatMap(TableFunctionPartition::toOutputPages);
if (partitionChannels.isEmpty()) {
this.outputPages = WorkProcessor.create(new StreamTableFunctionInput(
tableFunctionProvider.getDataProcessor(session, functionHandle),
properChannelsCount,
passThroughSourcesCount,
requiredChannels,
markerChannels,
passThroughSpecifications,
pageBuffer.pages(),
pagesIndex))
.flatMap(identity());
}
else {
HashStrategies hashStrategies = new HashStrategies(pagesIndex, partitionChannels.get(), prePartitionedChannels, sortChannels, sortOrders, preSortedPrefix);
this.outputPages = pageBuffer.pages()
.transform(new PartitionAndSort(pagesIndex, hashStrategies, processEmptyInput))
.flatMap(groupPagesIndex -> pagesIndexToTableFunctionPartitions(
groupPagesIndex,
hashStrategies,
tableFunctionProvider,
session,
functionHandle,
properChannelsCount,
passThroughSourcesCount,
requiredChannels,
markerChannels,
passThroughSpecifications,
processEmptyInput))
.flatMap(TableFunctionInput::toOutputPages);
}
}

@Override
@@ -530,7 +560,7 @@ private static int findGroupEnd(PagesIndex pagesIndex, PagesHashStrategy pagesHa
return findEndPosition(startPosition, pagesIndex.getPositionCount(), (firstPosition, secondPosition) -> pagesIndex.positionIdenticalToPosition(pagesHashStrategy, firstPosition, secondPosition));
}

private WorkProcessor<TableFunctionPartition> pagesIndexToTableFunctionPartitions(
private WorkProcessor<TableFunctionInput> pagesIndexToTableFunctionPartitions(
PagesIndex pagesIndex,
HashStrategies hashStrategies,
TableFunctionProcessorProvider tableFunctionProvider,
@@ -553,7 +583,7 @@ private WorkProcessor<TableFunctionPartition> pagesIndexToTableFunctionPartition
private boolean processEmpty = processEmptyInput;

@Override
public WorkProcessor.ProcessState<TableFunctionPartition> process()
public WorkProcessor.ProcessState<TableFunctionInput> process()
{
if (partitionStart == pagesIndex.getPositionCount()) {
if (processEmpty && pagesIndex.getPositionCount() == 0) {
@@ -1660,10 +1660,9 @@ public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode
}
}

List<Integer> partitionChannels = node.getSpecification()
Optional<List<Integer>> partitionChannels = node.getSpecification()
.map(DataOrganizationSpecification::partitionBy)
.map(list -> getChannelsForSymbols(list, source.getLayout()))
.orElse(ImmutableList.of());
.map(list -> getChannelsForSymbols(list, source.getLayout()));

List<Integer> sortChannels = ImmutableList.of();
List<SortOrder> sortOrders = ImmutableList.of();
@@ -22,16 +22,13 @@
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.testing.TestingSession.testSessionBuilder;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;

@TestInstance(PER_CLASS)
public class TestFilterStatsRule
extends BaseStatsCalculatorTest
{