Skip to content

Commit

Permalink
[BEAM-13099] Replace call to RelNode.metadata with BeamRelMetadataQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
TheNeuralBit committed Oct 28, 2021
1 parent 4cdbe8d commit 0e5451f
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.stream.Collectors;
import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters.Kind;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.RelMdNodeStats;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamLogicalConvention;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
Expand Down Expand Up @@ -202,6 +203,8 @@ public BeamRelNode convertToBeamRel(String sqlStatement, QueryParameters queryPa
NonCumulativeCostImpl.SOURCE,
RelMdNodeStats.SOURCE,
root.rel.getCluster().getMetadataProvider())));

root.rel.getCluster().setMetadataQuerySupplier(BeamRelMetadataQuery::instance);
RelMetadataQuery.THREAD_PROVIDERS.set(
JaninoRelMetadataProvider.of(root.rel.getCluster().getMetadataProvider()));
root.rel.getCluster().invalidateMetadataQuery();
Expand Down Expand Up @@ -233,26 +236,29 @@ public MetadataDef<BuiltInMetadata.NonCumulativeCost> getDef() {

@SuppressWarnings("UnusedDeclaration")
public RelOptCost getNonCumulativeCost(RelNode rel, RelMetadataQuery mq) {
assert mq instanceof BeamRelMetadataQuery;
BeamRelMetadataQuery bmq = (BeamRelMetadataQuery) mq;

// This is called by generated code in Calcite's MetadataQuery.
// If the rel is a Calcite rel, or we are in the JDBC path and the cost factory is not set yet,
// we should use Calcite's cost estimation.
if (!(rel instanceof BeamRelNode)) {
return rel.computeSelfCost(rel.getCluster().getPlanner(), mq);
return rel.computeSelfCost(rel.getCluster().getPlanner(), bmq);
}

// Currently we do nothing in this case, however, we can plug our own cost estimation method
// here and based on the design we also need to remove the cached values

// We need to first remove the cached values.
List<Table.Cell<RelNode, List, Object>> costKeys =
mq.map.cellSet().stream()
bmq.map.cellSet().stream()
.filter(entry -> entry.getValue() instanceof BeamCostModel)
.filter(entry -> ((BeamCostModel) entry.getValue()).isInfinite())
.collect(Collectors.toList());

costKeys.forEach(cell -> mq.map.remove(cell.getRowKey(), cell.getColumnKey()));
costKeys.forEach(cell -> bmq.map.remove(cell.getRowKey(), cell.getColumnKey()));

return ((BeamRelNode) rel).beamComputeSelfCost(rel.getCluster().getPlanner(), mq);
return ((BeamRelNode) rel).beamComputeSelfCost(rel.getCluster().getPlanner(), bmq);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl.planner;

import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.JaninoRelMetadataProvider;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;

/**
 * A {@link RelMetadataQuery} that additionally exposes {@link NodeStats} lookups through {@link
 * #getNodeStats(RelNode)}.
 *
 * <p>Instances are created through {@link #instance()}; the constructor is private.
 */
public class BeamRelMetadataQuery extends RelMetadataQuery {
  /**
   * Handler used to answer node-stats requests. It is not final because it is replaced with a
   * revised handler whenever the current one reports {@code NoHandler} for a rel class.
   */
  private NodeStatsMetadata.Handler nodeStatsMetadataHandler =
      initialHandler(NodeStatsMetadata.Handler.class);

  /** Use {@link #instance()} to obtain a query object. */
  private BeamRelMetadataQuery() {}

  /**
   * Creates a fresh metadata query.
   *
   * @return a new {@code BeamRelMetadataQuery}; every call returns a distinct object
   */
  public static BeamRelMetadataQuery instance() {
    return new BeamRelMetadataQuery();
  }

  /**
   * Asks the node-stats handler for the {@link NodeStats} of the given relational node.
   *
   * @param relNode the node whose stats are requested
   * @return whatever the handler's {@code getNodeStats} yields for {@code relNode}
   */
  public NodeStats getNodeStats(RelNode relNode) {
    // The retry-forever shape mirrors the logic of Calcite's RelMetadataQuery.get* methods:
    // when the handler does not know the rel class, swap in a revised handler and ask again.
    while (true) {
      try {
        return nodeStatsMetadataHandler.getNodeStats(relNode, this);
      } catch (JaninoRelMetadataProvider.NoHandler noHandler) {
        nodeStatsMetadataHandler = revise(noHandler.relClass, NodeStatsMetadata.DEF);
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ public MetadataDef<NodeStatsMetadata> getDef() {

@SuppressWarnings("UnusedDeclaration")
public NodeStats getNodeStats(RelNode rel, RelMetadataQuery mq) {
assert mq instanceof BeamRelMetadataQuery;
BeamRelMetadataQuery bmq = (BeamRelMetadataQuery) mq;

if (rel instanceof BeamRelNode) {
return this.getBeamNodeStats((BeamRelNode) rel, mq);
return this.getBeamNodeStats((BeamRelNode) rel, bmq);
}

// We can later define custom methods for all different RelNodes to prevent hitting this point.
Expand All @@ -61,7 +63,7 @@ public NodeStats getNodeStats(RelNode rel, RelMetadataQuery mq) {
return NodeStats.UNKNOWN;
}

private NodeStats getBeamNodeStats(BeamRelNode rel, RelMetadataQuery mq) {
private NodeStats getBeamNodeStats(BeamRelNode rel, BeamRelMetadataQuery mq) {

// Removing the unknown results.
// Calcite caches previous results in mq.map. This is done to prevent cyclic calls of this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStatsMetadata;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.plan.volcano.RelSubset;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.metadata.RelMetadataQuery;

/** Utilities for {@code BeamRelNode}. */
@SuppressWarnings({
Expand Down Expand Up @@ -103,8 +102,8 @@ public static RelNode getInput(RelNode input) {
return result;
}

public static NodeStats getNodeStats(RelNode input, RelMetadataQuery mq) {
public static NodeStats getNodeStats(RelNode input, BeamRelMetadataQuery mq) {
input = getInput(input);
return input.metadata(NodeStatsMetadata.class, mq).getNodeStats();
return mq.getNodeStats(input);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ public void testSubsetHavingBest() {
// tests if we are actually testing what we want.
Assert.assertTrue(root instanceof RelSubset);

NodeStats estimates = BeamSqlRelUtils.getNodeStats(root, root.getCluster().getMetadataQuery());
NodeStats estimates =
BeamSqlRelUtils.getNodeStats(
root, ((BeamRelMetadataQuery) root.getCluster().getMetadataQuery()));
Assert.assertFalse(estimates.isUnknown());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner;
import org.apache.beam.sdk.extensions.sql.impl.SqlConversionException;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRuleSets;
import org.apache.beam.sdk.extensions.sql.impl.planner.RelMdNodeStats;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamLogicalConvention;
Expand Down Expand Up @@ -210,6 +211,9 @@ private BeamRelNode convertToBeamRelInternal(String sql, QueryParameters queryPa
NonCumulativeCostImpl.SOURCE,
RelMdNodeStats.SOURCE,
root.rel.getCluster().getMetadataProvider())));

root.rel.getCluster().setMetadataQuerySupplier(BeamRelMetadataQuery::instance);

RelMetadataQuery.THREAD_PROVIDERS.set(
JaninoRelMetadataProvider.of(root.rel.getCluster().getMetadataProvider()));
root.rel.getCluster().invalidateMetadataQuery();
Expand Down

0 comments on commit 0e5451f

Please sign in to comment.