Skip to content

Commit

Permalink
[BEAM-13099] Use RelMetadataQuery subclass rather than calling RelNode.metadata (#15820)
Browse files Browse the repository at this point in the history

[BEAM-13099] Use RelMetadataQuery subclass rather than calling RelNode.metadata
  • Loading branch information
TheNeuralBit authored Oct 28, 2021
2 parents 4cdbe8d + 60eaa82 commit cd4b7f3
Show file tree
Hide file tree
Showing 37 changed files with 208 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.stream.Collectors;
import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters.Kind;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.RelMdNodeStats;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamLogicalConvention;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
Expand Down Expand Up @@ -202,6 +203,8 @@ public BeamRelNode convertToBeamRel(String sqlStatement, QueryParameters queryPa
NonCumulativeCostImpl.SOURCE,
RelMdNodeStats.SOURCE,
root.rel.getCluster().getMetadataProvider())));

root.rel.getCluster().setMetadataQuerySupplier(BeamRelMetadataQuery::instance);
RelMetadataQuery.THREAD_PROVIDERS.set(
JaninoRelMetadataProvider.of(root.rel.getCluster().getMetadataProvider()));
root.rel.getCluster().invalidateMetadataQuery();
Expand Down Expand Up @@ -233,26 +236,29 @@ public MetadataDef<BuiltInMetadata.NonCumulativeCost> getDef() {

/**
 * Returns the non-cumulative (self) cost of {@code rel}.
 *
 * <p>Nodes that are not {@link BeamRelNode}s are costed through their own {@code
 * computeSelfCost}. For Beam nodes, cached {@link BeamCostModel} values that are infinite are
 * first removed from the metadata query's {@code map}, and the cost is then obtained from {@code
 * beamComputeSelfCost}.
 *
 * @param rel the node to cost
 * @param mq the metadata query; asserted to be, and cast to, a {@link BeamRelMetadataQuery}
 * @return the self cost of {@code rel}
 */
@SuppressWarnings("UnusedDeclaration")
public RelOptCost getNonCumulativeCost(RelNode rel, RelMetadataQuery mq) {
assert mq instanceof BeamRelMetadataQuery;
BeamRelMetadataQuery bmq = (BeamRelMetadataQuery) mq;

// This is called by generated code in Calcite's MetadataQuery.
// If the rel is a Calcite rel, or we are on the JDBC path and the cost factory is not set yet,
// we should use Calcite's cost estimation.
if (!(rel instanceof BeamRelNode)) {
return rel.computeSelfCost(rel.getCluster().getPlanner(), mq);
return rel.computeSelfCost(rel.getCluster().getPlanner(), bmq);
}

// Currently we do nothing in this case, however, we can plug our own cost estimation method
// here and based on the design we also need to remove the cached values

// We need to first remove the cached values.
// Only cells whose value is an infinite BeamCostModel are selected for removal.
List<Table.Cell<RelNode, List, Object>> costKeys =
mq.map.cellSet().stream()
bmq.map.cellSet().stream()
.filter(entry -> entry.getValue() instanceof BeamCostModel)
.filter(entry -> ((BeamCostModel) entry.getValue()).isInfinite())
.collect(Collectors.toList());

// The cells were collected into a list above, so the map is not modified while it is streamed.
costKeys.forEach(cell -> mq.map.remove(cell.getRowKey(), cell.getColumnKey()));
costKeys.forEach(cell -> bmq.map.remove(cell.getRowKey(), cell.getColumnKey()));

return ((BeamRelNode) rel).beamComputeSelfCost(rel.getCluster().getPlanner(), mq);
return ((BeamRelNode) rel).beamComputeSelfCost(rel.getCluster().getPlanner(), bmq);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.planner;

import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.JaninoRelMetadataProvider;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;

/**
 * A {@link RelMetadataQuery} that additionally exposes Beam's {@link NodeStats} metadata through
 * {@link #getNodeStats(RelNode)}.
 *
 * <p>Instances are obtained from {@link #instance()}; the constructor is private.
 */
public class BeamRelMetadataQuery extends RelMetadataQuery {
  /** Current handler for node-stats lookups; replaced whenever a lookup reports no handler. */
  private NodeStatsMetadata.Handler nodeStatsHandler =
      initialHandler(NodeStatsMetadata.Handler.class);

  private BeamRelMetadataQuery() {}

  /**
   * Creates a fresh query object.
   *
   * @return a new {@code BeamRelMetadataQuery}; every call returns a distinct instance
   */
  public static BeamRelMetadataQuery instance() {
    return new BeamRelMetadataQuery();
  }

  /**
   * Looks up the {@link NodeStats} of the given node.
   *
   * @param relNode the node whose statistics are requested
   * @return the statistics reported by the node-stats metadata handler
   */
  public NodeStats getNodeStats(RelNode relNode) {
    // Retry-until-success loop, mirroring the get* methods of Calcite's RelMetadataQuery: when
    // the current handler signals NoHandler, a revised handler is installed and the call is
    // attempted again.
    while (true) {
      try {
        return nodeStatsHandler.getNodeStats(relNode, this);
      } catch (JaninoRelMetadataProvider.NoHandler noHandler) {
        nodeStatsHandler = revise(noHandler.relClass, NodeStatsMetadata.DEF);
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ public MetadataDef<NodeStatsMetadata> getDef() {

@SuppressWarnings("UnusedDeclaration")
public NodeStats getNodeStats(RelNode rel, RelMetadataQuery mq) {
assert mq instanceof BeamRelMetadataQuery;
BeamRelMetadataQuery bmq = (BeamRelMetadataQuery) mq;

if (rel instanceof BeamRelNode) {
return this.getBeamNodeStats((BeamRelNode) rel, mq);
return this.getBeamNodeStats((BeamRelNode) rel, bmq);
}

// We can later define custom methods for all different RelNodes to prevent hitting this point.
Expand All @@ -61,7 +63,7 @@ public NodeStats getNodeStats(RelNode rel, RelMetadataQuery mq) {
return NodeStats.UNKNOWN;
}

private NodeStats getBeamNodeStats(BeamRelNode rel, RelMetadataQuery mq) {
private NodeStats getBeamNodeStats(BeamRelNode rel, BeamRelMetadataQuery mq) {

// Removing the unknown results.
// Calcite caches previous results in mq.map. This is done to prevent cyclic calls of this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.plan.RelOptCluster;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.plan.RelOptPlanner;
Expand Down Expand Up @@ -55,7 +56,7 @@ public int getLimitCountOfSortRel() {
}

@Override
public NodeStats estimateNodeStats(RelMetadataQuery mq) {
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
NodeStats inputStat = BeamSqlRelUtils.getNodeStats(input, mq);
double selectivity = estimateFilterSelectivity(getInput(), program, mq);

Expand All @@ -78,7 +79,7 @@ private static double estimateFilterSelectivity(
}

/**
 * Computes this node's cost from the estimated statistics of its input.
 *
 * @param planner the planner (not referenced in this method)
 * @param mq metadata query used to look up the input's {@link NodeStats}
 * @return a cost built from the input's row count and rate
 */
@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {
NodeStats inputStat = BeamSqlRelUtils.getNodeStats(this.input, mq);
return BeamCostModel.FACTORY.makeCost(inputStat.getRowCount(), inputStat.getRate());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import java.util.List;
import org.apache.beam.sdk.extensions.sql.impl.BeamSqlPipelineOptions;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.transform.agg.AggregationCombineFnAdapter;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.transforms.Group;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
Expand All @@ -58,7 +58,6 @@
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelWriter;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.Aggregate;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
Expand Down Expand Up @@ -91,7 +90,7 @@ public BeamAggregationRel(
}

@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {

NodeStats inputStat = BeamSqlRelUtils.getNodeStats(this.input, mq);
inputStat = computeWindowingCostEffect(inputStat);
Expand All @@ -111,7 +110,7 @@ public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery
}

@Override
public NodeStats estimateNodeStats(RelMetadataQuery mq) {
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {

NodeStats inputEstimate = BeamSqlRelUtils.getNodeStats(this.input, mq);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.rule.BeamIOSinkRule;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
Expand All @@ -39,7 +40,6 @@
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.prepare.Prepare;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.TableModify;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.sql2rel.RelStructuredTypeFlattener;

Expand Down Expand Up @@ -77,12 +77,12 @@ public BeamIOSinkRel(
}

/**
 * Estimates this node's statistics as exactly those of its input.
 *
 * @param mq metadata query used to look up the input's {@link NodeStats}
 * @return the input's {@link NodeStats}, unchanged
 */
@Override
public NodeStats estimateNodeStats(RelMetadataQuery mq) {
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
return BeamSqlRelUtils.getNodeStats(this.input, mq);
}

/**
 * Computes this node's cost from the estimated statistics of its input.
 *
 * @param planner the planner (not referenced in this method)
 * @param mq metadata query used to look up the input's {@link NodeStats}
 * @return a cost built from the input's row count and rate
 */
@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {
NodeStats inputEstimates = BeamSqlRelUtils.getNodeStats(this.input, mq);
return BeamCostModel.FACTORY.makeCost(inputEstimates.getRowCount(), inputEstimates.getRate());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.beam.sdk.extensions.sql.impl.BeamCalciteTable;
import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTableFilter;
Expand Down Expand Up @@ -88,7 +89,7 @@ public double estimateRowCount(RelMetadataQuery mq) {
}

@Override
public NodeStats estimateNodeStats(RelMetadataQuery mq) {
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
BeamTableStatistics rowCountStatistics = calciteTable.getStatistic();
double window =
(beamTable.isBounded() == PCollection.IsBounded.BOUNDED)
Expand Down Expand Up @@ -130,7 +131,7 @@ public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
}

/**
 * Computes this node's cost from its own estimated statistics.
 *
 * @param planner the planner (not referenced in this method)
 * @param mq metadata query used to look up this node's {@link NodeStats}
 * @return a cost built from this node's estimated row count and rate
 */
@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {
NodeStats estimates = BeamSqlRelUtils.getNodeStats(this, mq);
return BeamCostModel.FACTORY.makeCost(estimates.getRowCount(), estimates.getRate());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.List;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
Expand All @@ -30,7 +31,6 @@
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.Intersect;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.SetOp;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;

/**
* {@code BeamRelNode} to replace a {@code Intersect} node.
Expand All @@ -55,7 +55,7 @@ public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() {
}

@Override
public NodeStats estimateNodeStats(RelMetadataQuery mq) {
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
// This takes the minimum of the inputs for all the estimate factors.
double minimumRows = Double.POSITIVE_INFINITY;
double minimumWindowSize = Double.POSITIVE_INFINITY;
Expand All @@ -72,7 +72,7 @@ public NodeStats estimateNodeStats(RelMetadataQuery mq) {
}

@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {

NodeStats inputsStatSummation =
inputs.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Set;
import org.apache.beam.sdk.extensions.sql.BeamSqlSeekableTable;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
import org.apache.beam.sdk.values.PCollection;
Expand All @@ -35,7 +36,6 @@
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.CorrelationId;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.Join;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.JoinRelType;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexFieldAccess;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexInputRef;
Expand Down Expand Up @@ -111,7 +111,7 @@ public static boolean seekable(BeamRelNode relNode) {
}

@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {
NodeStats leftEstimates = BeamSqlRelUtils.getNodeStats(this.left, mq);
NodeStats rightEstimates = BeamSqlRelUtils.getNodeStats(this.right, mq);
NodeStats selfEstimates = BeamSqlRelUtils.getNodeStats(this, mq);
Expand All @@ -120,7 +120,7 @@ public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery
}

@Override
public NodeStats estimateNodeStats(RelMetadataQuery mq) {
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
double selectivity = mq.getSelectivity(this, getCondition());
NodeStats leftEstimates = BeamSqlRelUtils.getNodeStats(this.left, mq);
NodeStats rightEstimates = BeamSqlRelUtils.getNodeStats(this.right, mq);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.beam.sdk.extensions.sql.impl.cep.OrderKey;
import org.apache.beam.sdk.extensions.sql.impl.nfa.NFA;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
Expand All @@ -54,7 +55,6 @@
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelCollation;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.Match;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rex.RexNode;
Expand Down Expand Up @@ -110,12 +110,12 @@ public BeamMatchRel(
}

/**
 * Returns a fixed "tiny" cost for this node; neither argument is consulted.
 *
 * @param planner the planner (not referenced in this method)
 * @param mq the metadata query (not referenced in this method)
 * @return the value of {@code BeamCostModel.FACTORY.makeTinyCost()}
 */
@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {
return BeamCostModel.FACTORY.makeTinyCost(); // return constant costModel for now
}

@Override
public NodeStats estimateNodeStats(RelMetadataQuery mq) {
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
// a simple way of getting some estimate data
// to be examined further
NodeStats inputEstimate = BeamSqlRelUtils.getNodeStats(input, mq);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.util.List;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
Expand All @@ -30,7 +31,6 @@
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.Minus;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.core.SetOp;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;

/**
* {@code BeamRelNode} to replace a {@code Minus} node.
Expand All @@ -45,7 +45,7 @@ public BeamMinusRel(
}

@Override
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
public BeamCostModel beamComputeSelfCost(RelOptPlanner planner, BeamRelMetadataQuery mq) {
NodeStats inputsEstimatesSummation =
inputs.stream()
.map(input -> BeamSqlRelUtils.getNodeStats(input, mq))
Expand All @@ -66,7 +66,7 @@ public PTransform<PCollectionList<Row>, PCollection<Row>> buildPTransform() {
}

@Override
public NodeStats estimateNodeStats(RelMetadataQuery mq) {
public NodeStats estimateNodeStats(BeamRelMetadataQuery mq) {
NodeStats firstInputEstimates = BeamSqlRelUtils.getNodeStats(inputs.get(0), mq);
// The first input minus half of the others. (We are assuming half of them have intersection)
for (int i = 1; i < inputs.size(); i++) {
Expand Down
Loading

0 comments on commit cd4b7f3

Please sign in to comment.