Skip to content

Commit

Permalink
Remove session parameter from RowExpressionFormatter's constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
mbasmanova committed Sep 5, 2019
1 parent 7c28d0b commit bca2bce
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.facebook.presto.metadata.OperatorNotFoundException;
import com.facebook.presto.operator.StageExecutionDescriptor;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.ConnectorTableLayoutHandle;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.function.FunctionHandle;
Expand Down Expand Up @@ -110,6 +111,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -135,7 +137,7 @@ public class PlanPrinter
{
private final PlanRepresentation representation;
private final FunctionManager functionManager;
private final RowExpressionFormatter formatter;
private final Function<RowExpression, String> formatter;

private PlanPrinter(
PlanNode planRoot,
Expand Down Expand Up @@ -163,7 +165,10 @@ private PlanPrinter(
.sum(), MILLISECONDS));

this.representation = new PlanRepresentation(planRoot, types, totalCpuTime, totalScheduledTime);
this.formatter = new RowExpressionFormatter(session.toConnectorSession(), functionManager);

RowExpressionFormatter rowExpressionFormatter = new RowExpressionFormatter(functionManager);
ConnectorSession connectorSession = requireNonNull(session, "session is null").toConnectorSession();
this.formatter = rowExpression -> rowExpressionFormatter.formatRowExpression(connectorSession, rowExpression);

Visitor visitor = new Visitor(stageExecutionStrategy, types, estimatedStatsAndCosts, session, stats);
planRoot.accept(visitor, null);
Expand Down Expand Up @@ -364,7 +369,7 @@ public Void visitJoin(JoinNode node, Void context)
for (JoinNode.EquiJoinClause clause : node.getCriteria()) {
joinExpressions.add(JoinNodeUtils.toExpression(clause).toString());
}
node.getFilter().map(formatter::formatRowExpression).ifPresent(joinExpressions::add);
node.getFilter().map(formatter::apply).ifPresent(joinExpressions::add);

NodeRepresentation nodeOutput;
if (node.isCrossJoin()) {
Expand All @@ -379,7 +384,7 @@ public Void visitJoin(JoinNode node, Void context)

node.getDistributionType().ifPresent(distributionType -> nodeOutput.appendDetails("Distribution: %s", distributionType));
node.getSortExpressionContext(functionManager)
.ifPresent(sortContext -> nodeOutput.appendDetails("SortExpression[%s]", formatter.formatRowExpression(sortContext.getSortExpression())));
.ifPresent(sortContext -> nodeOutput.appendDetails("SortExpression[%s]", formatter.apply(sortContext.getSortExpression())));
node.getLeft().accept(this, context);
node.getRight().accept(this, context);

Expand All @@ -391,7 +396,7 @@ public Void visitSpatialJoin(SpatialJoinNode node, Void context)
{
NodeRepresentation nodeOutput = addNode(node,
node.getType().getJoinLabel(),
format("[%s]", formatter.formatRowExpression(node.getFilter())));
format("[%s]", formatter.apply(node.getFilter())));

nodeOutput.appendDetailsLine("Distribution: %s", node.getDistributionType());
node.getLeft().accept(this, context);
Expand Down Expand Up @@ -507,10 +512,10 @@ private String formatAggregation(AggregationNode.Aggregation aggregation)
builder.append("*");
}
else {
builder.append("(" + Joiner.on(",").join(aggregation.getArguments().stream().map(formatter::formatRowExpression).collect(toImmutableList())) + ")");
builder.append("(" + Joiner.on(",").join(aggregation.getArguments().stream().map(formatter::apply).collect(toImmutableList())) + ")");
}
builder.append(")");
aggregation.getFilter().ifPresent(filter -> builder.append(" WHERE " + formatter.formatRowExpression(filter)));
aggregation.getFilter().ifPresent(filter -> builder.append(" WHERE " + formatter.apply(filter)));
aggregation.getOrderBy().ifPresent(orderingScheme -> builder.append(" ORDER BY " + orderingScheme.toString()));
aggregation.getMask().ifPresent(mask -> builder.append(" (mask = " + mask + ")"));
return builder.toString();
Expand Down Expand Up @@ -596,7 +601,7 @@ public Void visitWindow(WindowNode node, Void context)
"%s := %s(%s) %s",
entry.getKey(),
call.getDisplayName(),
Joiner.on(", ").join(call.getArguments().stream().map(formatter::formatRowExpression).collect(toImmutableList())),
Joiner.on(", ").join(call.getArguments().stream().map(formatter::apply).collect(toImmutableList())),
frameInfo);
}
return processChildren(node, context);
Expand Down Expand Up @@ -669,7 +674,7 @@ public Void visitValues(ValuesNode node, Void context)
{
NodeRepresentation nodeOutput = addNode(node, "Values");
for (List<RowExpression> row : node.getRows()) {
nodeOutput.appendDetailsLine("(" + Joiner.on(", ").join(formatter.formatRowExpressions(row)) + ")");
nodeOutput.appendDetailsLine("(" + row.stream().map(formatter::apply).collect(Collectors.joining(", ")) + ")");
}
return null;
}
Expand Down Expand Up @@ -732,7 +737,7 @@ private Void visitScanFilterAndProjectInfo(
if (filterNode.isPresent()) {
operatorName += "Filter";
formatString += "filterPredicate = %s, ";
arguments.add(formatter.formatRowExpression(filterNode.get().getPredicate()));
arguments.add(formatter.apply(filterNode.get().getPredicate()));
}

if (formatString.length() > 1) {
Expand Down Expand Up @@ -1066,7 +1071,7 @@ private void printAssignments(NodeRepresentation nodeOutput, Assignments assignm
// skip identity assignments
continue;
}
nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), formatter.formatRowExpression(entry.getValue()));
nodeOutput.appendDetailsLine("%s := %s", entry.getKey(), formatter.apply(entry.getValue()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,82 +39,80 @@

public final class RowExpressionFormatter
{
private final ConnectorSession session;
private final FunctionMetadataManager functionMetadataManager;
private final StandardFunctionResolution standardFunctionResolution;

public RowExpressionFormatter(ConnectorSession session, FunctionManager functionManager)
public RowExpressionFormatter(FunctionManager functionManager)
{
this.session = requireNonNull(session, "session is null");
this.functionMetadataManager = requireNonNull(functionManager, "function manager is null");
this.standardFunctionResolution = new FunctionResolution(functionManager);
}

public String formatRowExpression(RowExpression expression)
public String formatRowExpression(ConnectorSession session, RowExpression expression)
{
return expression.accept(new Formatter(), null);
return expression.accept(new Formatter(), requireNonNull(session, "session is null"));
}

public List<String> formatRowExpressions(List<RowExpression> rowExpressions)
private List<String> formatRowExpressions(ConnectorSession session, List<RowExpression> rowExpressions)
{
return rowExpressions.stream().map(this::formatRowExpression).collect(toList());
return rowExpressions.stream().map(rowExpression -> formatRowExpression(session, rowExpression)).collect(toList());
}

public class Formatter
implements RowExpressionVisitor<String, Void>
implements RowExpressionVisitor<String, ConnectorSession>
{
@Override
public String visitCall(CallExpression node, Void context)
public String visitCall(CallExpression node, ConnectorSession session)
{
if (standardFunctionResolution.isArithmeticFunction(node.getFunctionHandle()) || standardFunctionResolution.isComparisonFunction(node.getFunctionHandle())) {
String operation = functionMetadataManager.getFunctionMetadata(node.getFunctionHandle()).getOperatorType().get().getOperator();
return String.join(" " + operation + " ", formatRowExpressions(node.getArguments()).stream().map(e -> "(" + e + ")").collect(toImmutableList()));
return String.join(" " + operation + " ", formatRowExpressions(session, node.getArguments()).stream().map(e -> "(" + e + ")").collect(toImmutableList()));
}
else if (standardFunctionResolution.isCastFunction(node.getFunctionHandle())) {
return String.format("CAST(%s AS %s)", formatRowExpression(node.getArguments().get(0)), node.getType().getDisplayName());
return String.format("CAST(%s AS %s)", formatRowExpression(session, node.getArguments().get(0)), node.getType().getDisplayName());
}
else if (standardFunctionResolution.isNegateFunction(node.getFunctionHandle())) {
return "-(" + formatRowExpression(node.getArguments().get(0)) + ")";
return "-(" + formatRowExpression(session, node.getArguments().get(0)) + ")";
}
else if (standardFunctionResolution.isSubscriptFunction(node.getFunctionHandle())) {
return formatRowExpression(node.getArguments().get(0)) + "[" + formatRowExpression(node.getArguments().get(1)) + "]";
return formatRowExpression(session, node.getArguments().get(0)) + "[" + formatRowExpression(session, node.getArguments().get(1)) + "]";
}
else if (standardFunctionResolution.isBetweenFunction(node.getFunctionHandle())) {
List<String> formattedExpresions = formatRowExpressions(node.getArguments());
List<String> formattedExpresions = formatRowExpressions(session, node.getArguments());
return String.format("%s BETWEEN (%s) AND (%s)", formattedExpresions.get(0), formattedExpresions.get(1), formattedExpresions.get(2));
}
return node.getDisplayName() + "(" + String.join(", ", formatRowExpressions(node.getArguments())) + ")";
return node.getDisplayName() + "(" + String.join(", ", formatRowExpressions(session, node.getArguments())) + ")";
}

@Override
public String visitSpecialForm(SpecialFormExpression node, Void context)
public String visitSpecialForm(SpecialFormExpression node, ConnectorSession session)
{
if (node.getForm().equals(SpecialFormExpression.Form.AND) || node.getForm().equals(SpecialFormExpression.Form.OR)) {
return String.join(" " + node.getForm() + " ", formatRowExpressions(node.getArguments()).stream().map(e -> "(" + e + ")").collect(toImmutableList()));
return String.join(" " + node.getForm() + " ", formatRowExpressions(session, node.getArguments()).stream().map(e -> "(" + e + ")").collect(toImmutableList()));
}
return node.getForm().name() + "(" + String.join(", ", formatRowExpressions(node.getArguments())) + ")";
return node.getForm().name() + "(" + String.join(", ", formatRowExpressions(session, node.getArguments())) + ")";
}

@Override
public String visitInputReference(InputReferenceExpression node, Void context)
public String visitInputReference(InputReferenceExpression node, ConnectorSession session)
{
return node.toString();
}

@Override
public String visitLambda(LambdaDefinitionExpression node, Void context)
public String visitLambda(LambdaDefinitionExpression node, ConnectorSession session)
{
return "(" + String.join(", ", node.getArguments()) + ") -> " + formatRowExpression(node.getBody());
return "(" + String.join(", ", node.getArguments()) + ") -> " + formatRowExpression(session, node.getBody());
}

@Override
public String visitVariableReference(VariableReferenceExpression node, Void context)
public String visitVariableReference(VariableReferenceExpression node, ConnectorSession session)
{
return node.getName();
}

@Override
public String visitConstant(ConstantExpression node, Void context)
public String visitConstant(ConstantExpression node, ConnectorSession session)
{
Object value = LiteralInterpreter.evaluate(session, node);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.Session;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
Expand Down Expand Up @@ -72,12 +73,14 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Maps.immutableEnumMap;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public final class GraphvizPrinter
{
Expand Down Expand Up @@ -216,13 +219,15 @@ private static class NodePrinter
private static final int MAX_NAME_WIDTH = 100;
private final StringBuilder output;
private final PlanNodeIdGenerator idGenerator;
private final RowExpressionFormatter formatter;
private final Function<RowExpression, String> formatter;

public NodePrinter(StringBuilder output, PlanNodeIdGenerator idGenerator, Session session, FunctionManager functionManager)
{
this.output = output;
this.idGenerator = idGenerator;
this.formatter = new RowExpressionFormatter(session.toConnectorSession(), functionManager);
RowExpressionFormatter rowExpressionFormatter = new RowExpressionFormatter(functionManager);
ConnectorSession connectorSession = requireNonNull(session, "session is null").toConnectorSession();
this.formatter = rowExpression -> rowExpressionFormatter.formatRowExpression(connectorSession, rowExpression);
}

@Override
Expand Down Expand Up @@ -392,7 +397,7 @@ public Void visitGroupId(GroupIdNode node, Void context)
@Override
public Void visitFilter(FilterNode node, Void context)
{
String expression = formatter.formatRowExpression(node.getPredicate());
String expression = formatter.apply(node.getPredicate());
printNode(node, "Filter", expression, NODE_COLORS.get(NodeType.FILTER));
return node.getSource().accept(this, context);
}
Expand All @@ -407,7 +412,7 @@ public Void visitProject(ProjectNode node, Void context)
// skip identity assignments
continue;
}
builder.append(format("%s := %s\\n", entry.getKey(), formatter.formatRowExpression(entry.getValue())));
builder.append(format("%s := %s\\n", entry.getKey(), formatter.apply(entry.getValue())));
}

printNode(node, "Project", builder.toString(), NODE_COLORS.get(NodeType.PROJECT));
Expand Down Expand Up @@ -508,7 +513,7 @@ public Void visitSemiJoin(SemiJoinNode node, Void context)
@Override
public Void visitSpatialJoin(SpatialJoinNode node, Void context)
{
printNode(node, node.getType().getJoinLabel(), formatter.formatRowExpression(node.getFilter()), NODE_COLORS.get(NodeType.JOIN));
printNode(node, node.getType().getJoinLabel(), formatter.apply(node.getFilter()), NODE_COLORS.get(NodeType.JOIN));

node.getLeft().accept(this, context);
node.getRight().accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public class TestRowExpressionFormatter
{
private static final TypeManager typeManager = new TypeRegistry();
private static final FunctionManager functionManager = new FunctionManager(typeManager, new BlockEncodingManager(typeManager), new FeaturesConfig());
private static final RowExpressionFormatter FORMATTER = new RowExpressionFormatter(TEST_SESSION.toConnectorSession(), functionManager);
private static final RowExpressionFormatter FORMATTER = new RowExpressionFormatter(functionManager);
private static final VariableReferenceExpression C_BIGINT = new VariableReferenceExpression("c_bigint", BIGINT);
private static final VariableReferenceExpression C_BIGINT_ARRAY = new VariableReferenceExpression("c_bigint_array", new ArrayType(BIGINT));

Expand Down Expand Up @@ -320,6 +320,6 @@ private static CallExpression createCallExpression(OperatorType type)

private static String format(RowExpression expression)
{
return FORMATTER.formatRowExpression(expression);
return FORMATTER.formatRowExpression(TEST_SESSION.toConnectorSession(), expression);
}
}

0 comments on commit bca2bce

Please sign in to comment.