Add connector SPI for scale writers options #18561

Merged: 1 commit, Aug 17, 2023
core/trino-main/src/main/java/io/trino/metadata/Metadata.java (11 additions, 0 deletions)
@@ -47,6 +47,7 @@
import io.trino.spi.connector.TableFunctionApplicationResult;
import io.trino.spi.connector.TableScanRedirectApplicationResult;
import io.trino.spi.connector.TopNApplicationResult;
+import io.trino.spi.connector.WriterScalingOptions;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.function.FunctionMetadata;
@@ -751,4 +752,14 @@ default boolean isMaterializedView(Session session, QualifiedObjectName viewName)
* Note: It is ignored when retry policy is set to TASK
*/
OptionalInt getMaxWriterTasks(Session session, String catalogName);

+/**
+ * Returns writer scaling options for the specified table. This variant is used when no table handle is available yet, such as during CTAS.
+ */
+WriterScalingOptions getNewTableWriterScalingOptions(Session session, QualifiedObjectName tableName, Map<String, Object> tableProperties);
+
+/**
+ * Returns writer scaling options for inserts into the specified table.
+ */
+WriterScalingOptions getInsertWriterScalingOptions(Session session, TableHandle tableHandle);
}
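
For connector authors, the ConnectorMetadata counterparts take a ConnectorSession plus a SchemaTableName and table properties (for new tables) or a ConnectorTableHandle (for inserts), as the MetadataManager delegation in the next file shows. Below is a minimal sketch of a connector opting in; the class name is hypothetical, and the three-argument WriterScalingOptions constructor is an assumption inferred from the accessors this diff uses (isWriterTasksScalingEnabled, isPerTaskWriterScalingEnabled, perTaskMaxScaledWriterCount).

    import java.util.Map;
    import java.util.Optional;

    import io.trino.spi.connector.ConnectorMetadata;
    import io.trino.spi.connector.ConnectorSession;
    import io.trino.spi.connector.ConnectorTableHandle;
    import io.trino.spi.connector.SchemaTableName;
    import io.trino.spi.connector.WriterScalingOptions;

    // Hypothetical connector; only the scaling-related methods are shown.
    public class ExampleScalingConnectorMetadata
            implements ConnectorMetadata
    {
        @Override
        public WriterScalingOptions getNewTableWriterScalingOptions(ConnectorSession session, SchemaTableName tableName, Map<String, Object> tableProperties)
        {
            // Allow scaling across tasks and within a task, but cap each task at
            // 4 scaled writers (the three-argument constructor is an assumption).
            return new WriterScalingOptions(true, true, Optional.of(4));
        }

        @Override
        public WriterScalingOptions getInsertWriterScalingOptions(ConnectorSession session, ConnectorTableHandle tableHandle)
        {
            // Opt out of writer scaling entirely for inserts into existing tables.
            return new WriterScalingOptions(false, false, Optional.empty());
        }
    }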
core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java
@@ -83,6 +83,7 @@
import io.trino.spi.connector.TableFunctionApplicationResult;
import io.trino.spi.connector.TableScanRedirectApplicationResult;
import io.trino.spi.connector.TopNApplicationResult;
+import io.trino.spi.connector.WriterScalingOptions;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.function.AggregationFunctionMetadata;
@@ -2671,6 +2672,22 @@ public OptionalInt getMaxWriterTasks(Session session, String catalogName)
return catalogMetadata.getMetadata(session).getMaxWriterTasks(session.toConnectorSession(catalogHandle));
}

+@Override
+public WriterScalingOptions getNewTableWriterScalingOptions(Session session, QualifiedObjectName tableName, Map<String, Object> tableProperties)
+{
+    CatalogMetadata catalogMetadata = getCatalogMetadataForWrite(session, tableName.getCatalogName());
+    CatalogHandle catalogHandle = catalogMetadata.getCatalogHandle(session, tableName);
+    ConnectorMetadata metadata = catalogMetadata.getMetadataFor(session, catalogHandle);
+    return metadata.getNewTableWriterScalingOptions(session.toConnectorSession(catalogHandle), tableName.asSchemaTableName(), tableProperties);
+}
+
+@Override
+public WriterScalingOptions getInsertWriterScalingOptions(Session session, TableHandle tableHandle)
+{
+    ConnectorMetadata metadata = getMetadataForWrite(session, tableHandle.getCatalogHandle());
+    return metadata.getInsertWriterScalingOptions(session.toConnectorSession(tableHandle.getCatalogHandle()), tableHandle.getConnectorHandle());
+}

private Optional<ConnectorTableVersion> toConnectorVersion(Optional<TableVersion> version)
{
Optional<ConnectorTableVersion> connectorVersion = Optional.empty();
core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
@@ -161,6 +161,7 @@
import io.trino.spi.connector.DynamicFilter;
import io.trino.spi.connector.RecordSet;
import io.trino.spi.connector.SortOrder;
+import io.trino.spi.connector.WriterScalingOptions;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionId;
@@ -379,6 +380,7 @@
import static io.trino.util.SpatialJoinUtils.ST_WITHIN;
import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialComparisons;
import static io.trino.util.SpatialJoinUtils.extractSupportedSpatialFunctions;
+import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
@@ -3280,7 +3282,11 @@ public PhysicalOperation visitRefreshMaterializedView(RefreshMaterializedViewNode node, LocalExecutionPlanContext context)
public PhysicalOperation visitTableWriter(TableWriterNode node, LocalExecutionPlanContext context)
{
// Set table writer count
-int maxWriterCount = getWriterCount(session, node.getPartitioningScheme(), node.getSource());
+int maxWriterCount = getWriterCount(
+        session,
+        node.getTarget().getWriterScalingOptions(metadata, session),
+        node.getPartitioningScheme(),
+        node.getSource());
context.setDriverInstanceCount(maxWriterCount);
context.taskContext.setMaxWriterCount(maxWriterCount);

@@ -3438,7 +3444,11 @@ public PhysicalOperation visitSimpleTableExecuteNode(SimpleTableExecuteNode node, LocalExecutionPlanContext context)
public PhysicalOperation visitTableExecute(TableExecuteNode node, LocalExecutionPlanContext context)
{
// Set table writer count
-int maxWriterCount = getWriterCount(session, node.getPartitioningScheme(), node.getSource());
+int maxWriterCount = getWriterCount(
+        session,
+        node.getTarget().getWriterScalingOptions(metadata, session),
+        node.getPartitioningScheme(),
+        node.getSource());
context.setDriverInstanceCount(maxWriterCount);
context.taskContext.setMaxWriterCount(maxWriterCount);

@@ -3465,7 +3475,7 @@ public PhysicalOperation visitTableExecute(TableExecuteNode node, LocalExecutionPlanContext context)
return new PhysicalOperation(operatorFactory, outputMapping.buildOrThrow(), context, source);
}

-private int getWriterCount(Session session, Optional<PartitioningScheme> partitioningScheme, PlanNode source)
+private int getWriterCount(Session session, WriterScalingOptions connectorScalingOptions, Optional<PartitioningScheme> partitioningScheme, PlanNode source)
{
// This check is required because we don't know which writer count to use when exchange is
// single distribution. It could be possible that when scaling is enabled, a single distribution is
@@ -3475,12 +3485,24 @@ private int getWriterCount(Session session, Optional<PartitioningScheme> partitioningScheme, PlanNode source)
return 1;
}

-// The default value of partitioned writer count is 32 which is high enough to use it
-// for both cases when scaling is enabled or not. Additionally, it doesn't lead to too many
-// small files since when scaling is disabled only single writer will handle a single partition.
-return partitioningScheme
-        .map(scheme -> getTaskPartitionedWriterCount(session))
-        .orElseGet(() -> isLocalScaledWriterExchange(source) ? getTaskScaleWritersMaxWriterCount(session) : getTaskWriterCount(session));
+if (partitioningScheme.isPresent()) {
+    // The default partitioned writer count is 32, which is high enough whether
+    // scaling is enabled or not. It also does not produce too many small files,
+    // since with scaling disabled a single writer handles each partition.
+    if (isLocalScaledWriterExchange(source)) {
+        return connectorScalingOptions.perTaskMaxScaledWriterCount()
+                .map(writerCount -> min(writerCount, getTaskPartitionedWriterCount(session)))
+                .orElse(getTaskPartitionedWriterCount(session));
+    }
+    return getTaskPartitionedWriterCount(session);
+}
+
+if (isLocalScaledWriterExchange(source)) {
+    return connectorScalingOptions.perTaskMaxScaledWriterCount()
+            .map(writerCount -> min(writerCount, getTaskScaleWritersMaxWriterCount(session)))
+            .orElse(getTaskScaleWritersMaxWriterCount(session));
+}
+return getTaskWriterCount(session);
}

private boolean isSingleGatheringExchange(PlanNode node)
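
To make the capping above concrete: a connector-provided perTaskMaxScaledWriterCount can only lower the session-level writer count, never raise it. A self-contained illustration with hypothetical numbers (32 standing in for the session limit returned by getTaskScaleWritersMaxWriterCount):

    import java.util.Optional;

    class WriterCountExample
    {
        public static void main(String[] args)
        {
            // Connector's perTaskMaxScaledWriterCount(), as consumed above.
            Optional<Integer> connectorCap = Optional.of(8);
            // Session-level limit (hypothetical value).
            int sessionLimit = 32;

            int effective = connectorCap
                    .map(cap -> Math.min(cap, sessionLimit))
                    .orElse(sessionLimit);
            System.out.println(effective); // 8: the connector cap wins; an empty cap would yield 32
        }
    }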
core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java
@@ -944,7 +944,11 @@ private RelationPlan createTableExecutePlan(Analysis analysis, TableExecute statement)
.map(ColumnMetadata::getName)
.collect(toImmutableList());

-TableWriterNode.TableExecuteTarget tableExecuteTarget = new TableWriterNode.TableExecuteTarget(executeHandle, Optional.empty(), tableName.asSchemaTableName());
+TableWriterNode.TableExecuteTarget tableExecuteTarget = new TableWriterNode.TableExecuteTarget(
+        executeHandle,
+        Optional.empty(),
+        tableName.asSchemaTableName(),
+        metadata.getInsertWriterScalingOptions(session, tableHandle));

Optional<TableLayout> layout = metadata.getLayoutForTableExecute(session, executeHandle);

core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
@@ -33,6 +33,7 @@
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.connector.GroupingProperty;
import io.trino.spi.connector.LocalProperty;
+import io.trino.spi.connector.WriterScalingOptions;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.Partitioning;
@@ -760,20 +761,22 @@ public PlanWithProperties visitMergeWriter(MergeWriterNode node, PreferredProperties preferredProperties)

private PlanWithProperties getWriterPlanWithProperties(Optional<PartitioningScheme> partitioningScheme, PlanWithProperties newSource, TableWriterNode.WriterTarget writerTarget)
{
+WriterScalingOptions scalingOptions = writerTarget.getWriterScalingOptions(plannerContext.getMetadata(), session);
if (partitioningScheme.isEmpty()) {
// use maxWritersTasks to set PartitioningScheme.partitionCount field to limit number of tasks that will take part in executing writing stage
int maxWriterTasks = writerTarget.getMaxWriterTasks(plannerContext.getMetadata(), session).orElse(getMaxWriterTaskCount(session));
Optional<Integer> maxWritersNodesCount = getRetryPolicy(session) != RetryPolicy.TASK
? Optional.of(Math.min(maxWriterTasks, getMaxWriterTaskCount(session)))
: Optional.empty();
-if (scaleWriters) {
+if (scaleWriters && scalingOptions.isWriterTasksScalingEnabled()) {
partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols(), Optional.empty(), false, Optional.empty(), maxWritersNodesCount));
}
else if (redistributeWrites) {
partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), newSource.getNode().getOutputSymbols(), Optional.empty(), false, Optional.empty(), maxWritersNodesCount));
}
}
else if (scaleWriters
+&& scalingOptions.isWriterTasksScalingEnabled()
&& writerTarget.supportsMultipleWritersPerPartition(plannerContext.getMetadata(), session)
// do not insert an exchange if partitioning is compatible
&& !newSource.getProperties().isCompatibleTablePartitioningWith(partitioningScheme.get().getPartitioning(), false, plannerContext.getMetadata(), session)) {
core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java
@@ -23,6 +23,7 @@
import io.trino.spi.connector.ConstantProperty;
import io.trino.spi.connector.GroupingProperty;
import io.trino.spi.connector.LocalProperty;
+import io.trino.spi.connector.WriterScalingOptions;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningHandle;
@@ -686,13 +687,14 @@ public PlanWithProperties visitSimpleTableExecuteNode(SimpleTableExecuteNode node, StreamPreferredProperties parentPreferences)
@Override
public PlanWithProperties visitTableWriter(TableWriterNode node, StreamPreferredProperties parentPreferences)
{
+WriterScalingOptions scalingOptions = node.getTarget().getWriterScalingOptions(plannerContext.getMetadata(), session);
return visitTableWriter(
node,
node.getPartitioningScheme(),
node.getSource(),
parentPreferences,
node.getTarget(),
-isTaskScaleWritersEnabled(session));
+isTaskScaleWritersEnabled(session) && scalingOptions.isPerTaskWriterScalingEnabled());
}

@Override
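
Taken together, the gate in AddExchanges (isWriterTasksScalingEnabled) and the gate in AddLocalExchanges (isPerTaskWriterScalingEnabled) let a connector toggle the two scaling dimensions independently. A sketch of options that scale across writer tasks but not within a task, again assuming the constructor shape used in the earlier example:

    import java.util.Optional;

    import io.trino.spi.connector.WriterScalingOptions;

    class TasksOnlyScalingExample
    {
        // Scale out by adding writer tasks, but keep each task's writer count
        // fixed, e.g. to bound the number of open files per task.
        static final WriterScalingOptions TASKS_ONLY = new WriterScalingOptions(
                true,              // isWriterTasksScalingEnabled()
                false,             // isPerTaskWriterScalingEnabled()
                Optional.empty()); // perTaskMaxScaledWriterCount(): unused here
    }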
core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java
@@ -208,7 +208,8 @@ public WriterTarget getWriterTarget(PlanNode node)
return new TableExecuteTarget(
target.getExecuteHandle(),
findTableScanHandleForTableExecute(((TableExecuteNode) node).getSource()),
-target.getSchemaTableName());
+target.getSchemaTableName(),
+target.getWriterScalingOptions());
}

if (node instanceof MergeWriterNode mergeWriterNode) {
@@ -244,14 +245,16 @@ private WriterTarget createWriterTarget(WriterTarget target)
metadata.beginCreateTable(session, create.getCatalog(), create.getTableMetadata(), create.getLayout()),
create.getTableMetadata().getTable(),
target.supportsMultipleWritersPerPartition(metadata, session),
-target.getMaxWriterTasks(metadata, session));
+target.getMaxWriterTasks(metadata, session),
+target.getWriterScalingOptions(metadata, session));
}
if (target instanceof InsertReference insert) {
return new InsertTarget(
metadata.beginInsert(session, insert.getHandle(), insert.getColumns()),
metadata.getTableName(session, insert.getHandle()).getSchemaTableName(),
target.supportsMultipleWritersPerPartition(metadata, session),
-target.getMaxWriterTasks(metadata, session));
+target.getMaxWriterTasks(metadata, session),
+target.getWriterScalingOptions(metadata, session));
}
if (target instanceof MergeTarget merge) {
MergeHandle mergeHandle = metadata.beginMerge(session, merge.getHandle());
@@ -266,11 +269,12 @@ private WriterTarget createWriterTarget(WriterTarget target)
refreshMV.getStorageTableHandle(),
metadata.beginRefreshMaterializedView(session, refreshMV.getStorageTableHandle(), refreshMV.getSourceTableHandles()),
metadata.getTableName(session, refreshMV.getStorageTableHandle()).getSchemaTableName(),
-refreshMV.getSourceTableHandles());
+refreshMV.getSourceTableHandles(),
+refreshMV.getWriterScalingOptions(metadata, session));
}
if (target instanceof TableExecuteTarget tableExecute) {
BeginTableExecuteResult<TableExecuteHandle, TableHandle> result = metadata.beginTableExecute(session, tableExecute.getExecuteHandle(), tableExecute.getMandatorySourceHandle());
-return new TableExecuteTarget(result.getTableExecuteHandle(), Optional.of(result.getSourceHandle()), tableExecute.getSchemaTableName());
+return new TableExecuteTarget(result.getTableExecuteHandle(), Optional.of(result.getSourceHandle()), tableExecute.getSchemaTableName(), tableExecute.getWriterScalingOptions());
}
throw new IllegalArgumentException("Unhandled target type: " + target.getClass().getSimpleName());
}
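
TableWriterNode itself is not part of this diff, but the call sites above pin down the shape of WriterTarget.getWriterScalingOptions(Metadata, Session), and the rewritten targets (InsertTarget, CreateTarget, TableExecuteTarget) now carry options resolved up front. A plausible sketch of how a pre-begin target might resolve them; the method body is an assumption, not the PR's actual source:

    import io.trino.Session;
    import io.trino.metadata.Metadata;
    import io.trino.metadata.TableHandle;
    import io.trino.spi.connector.WriterScalingOptions;

    // Sketch only; names follow the call sites in this diff.
    abstract class WriterTargetSketch
    {
        abstract WriterScalingOptions getWriterScalingOptions(Metadata metadata, Session session);
    }

    class InsertReferenceSketch
            extends WriterTargetSketch
    {
        private final TableHandle handle;

        InsertReferenceSketch(TableHandle handle)
        {
            this.handle = handle;
        }

        @Override
        WriterScalingOptions getWriterScalingOptions(Metadata metadata, Session session)
        {
            // Delegate to the new SPI plumbing added in MetadataManager above.
            return metadata.getInsertWriterScalingOptions(session, handle);
        }
    }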