Skip to content

Commit

Permalink
Remove createExpressionAnalyzer from StatementAnalyzerFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Dec 20, 2021
1 parent 8e40a44 commit c13f952
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,32 @@ public class ExpressionAnalyzer
private final Function<Node, ResolvedWindow> getResolvedWindow;
private final List<Field> sourceFields = new ArrayList<>();

private ExpressionAnalyzer(
PlannerContext plannerContext,
AccessControl accessControl,
StatementAnalyzerFactory statementAnalyzerFactory,
Analysis analysis,
Session session,
TypeProvider types,
WarningCollector warningCollector)
{
this(
plannerContext,
accessControl,
(node, correlationSupport) -> statementAnalyzerFactory.createStatementAnalyzer(
analysis,
session,
warningCollector,
correlationSupport),
session,
types,
analysis.getParameters(),
warningCollector,
analysis.isDescribe(),
analysis::getType,
analysis::getWindow);
}

ExpressionAnalyzer(
PlannerContext plannerContext,
AccessControl accessControl,
Expand Down Expand Up @@ -2664,6 +2690,7 @@ public static boolean isPatternRecognitionFunction(FunctionCall node)

public static ExpressionAnalysis analyzePatternRecognitionExpression(
Session session,
PlannerContext plannerContext,
StatementAnalyzerFactory statementAnalyzerFactory,
AccessControl accessControl,
Scope scope,
Expand All @@ -2672,7 +2699,7 @@ public static ExpressionAnalysis analyzePatternRecognitionExpression(
WarningCollector warningCollector,
Set<String> labels)
{
ExpressionAnalyzer analyzer = statementAnalyzerFactory.createExpressionAnalyzer(analysis, session, TypeProvider.empty(), warningCollector);
ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, TypeProvider.empty(), warningCollector);
analyzer.analyze(expression, scope, labels);

updateAnalysis(analysis, analyzer, session, accessControl);
Expand All @@ -2691,15 +2718,17 @@ public static ExpressionAnalysis analyzePatternRecognitionExpression(

public static ExpressionAnalysis analyzeExpressions(
Session session,
PlannerContext plannerContext,
StatementAnalyzerFactory statementAnalyzerFactory,
AccessControl accessControl,
TypeProvider types,
Iterable<Expression> expressions,
Map<NodeRef<Parameter>, Expression> parameters,
WarningCollector warningCollector,
QueryType queryType)
{
Analysis analysis = new Analysis(null, parameters, queryType);
ExpressionAnalyzer analyzer = statementAnalyzerFactory.createExpressionAnalyzer(analysis, session, types, warningCollector);
ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, types, warningCollector);
for (Expression expression : expressions) {
analyzer.analyze(
expression,
Expand All @@ -2722,6 +2751,7 @@ public static ExpressionAnalysis analyzeExpressions(

public static ExpressionAnalysis analyzeExpression(
Session session,
PlannerContext plannerContext,
StatementAnalyzerFactory statementAnalyzerFactory,
AccessControl accessControl,
Scope scope,
Expand All @@ -2730,7 +2760,7 @@ public static ExpressionAnalysis analyzeExpression(
WarningCollector warningCollector,
CorrelationSupport correlationSupport)
{
ExpressionAnalyzer analyzer = statementAnalyzerFactory.createExpressionAnalyzer(analysis, session, TypeProvider.empty(), warningCollector);
ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, TypeProvider.empty(), warningCollector);
analyzer.analyze(expression, scope, correlationSupport);

updateAnalysis(analysis, analyzer, session, accessControl);
Expand All @@ -2750,6 +2780,7 @@ public static ExpressionAnalysis analyzeExpression(

public static ExpressionAnalysis analyzeWindow(
Session session,
PlannerContext plannerContext,
StatementAnalyzerFactory statementAnalyzerFactory,
AccessControl accessControl,
Scope scope,
Expand All @@ -2759,7 +2790,7 @@ public static ExpressionAnalysis analyzeWindow(
ResolvedWindow window,
Node originalNode)
{
ExpressionAnalyzer analyzer = statementAnalyzerFactory.createExpressionAnalyzer(analysis, session, TypeProvider.empty(), warningCollector);
ExpressionAnalyzer analyzer = new ExpressionAnalyzer(plannerContext, accessControl, statementAnalyzerFactory, analysis, session, TypeProvider.empty(), warningCollector);
analyzer.analyzeWindow(window, scope, originalNode, correlationSupport);

updateAnalysis(analysis, analyzer, session, accessControl);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.trino.cost.CostCalculator;
import io.trino.cost.StatsCalculator;
import io.trino.execution.warnings.WarningCollector;
import io.trino.security.AccessControl;
import io.trino.spi.TrinoException;
import io.trino.sql.PlannerContext;
import io.trino.sql.SqlFormatter;
Expand Down Expand Up @@ -58,6 +59,7 @@ public class QueryExplainer
private final PlannerContext plannerContext;
private final AnalyzerFactory analyzerFactory;
private final StatementAnalyzerFactory statementAnalyzerFactory;
private final AccessControl accessControl;
private final StatsCalculator statsCalculator;
private final CostCalculator costCalculator;

Expand All @@ -67,6 +69,7 @@ public class QueryExplainer
PlannerContext plannerContext,
AnalyzerFactory analyzerFactory,
StatementAnalyzerFactory statementAnalyzerFactory,
AccessControl accessControl,
StatsCalculator statsCalculator,
CostCalculator costCalculator)
{
Expand All @@ -75,6 +78,7 @@ public class QueryExplainer
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
this.analyzerFactory = requireNonNull(analyzerFactory, "analyzerFactory is null");
this.statementAnalyzerFactory = requireNonNull(statementAnalyzerFactory, "statementAnalyzerFactory is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
}
Expand Down Expand Up @@ -163,7 +167,7 @@ public Plan getLogicalPlan(Session session, Statement statement, List<Expression
planOptimizers,
idAllocator,
plannerContext,
new TypeAnalyzer(statementAnalyzerFactory),
new TypeAnalyzer(plannerContext, statementAnalyzerFactory, accessControl),
statsCalculator,
costCalculator,
warningCollector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import io.trino.cost.CostCalculator;
import io.trino.cost.StatsCalculator;
import io.trino.security.AccessControl;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.PlanFragmenter;
import io.trino.sql.planner.PlanOptimizersFactory;
Expand All @@ -29,6 +30,7 @@ public class QueryExplainerFactory
private final PlanFragmenter planFragmenter;
private final PlannerContext plannerContext;
private final StatementAnalyzerFactory statementAnalyzerFactory;
private final AccessControl accessControl;
private final StatsCalculator statsCalculator;
private final CostCalculator costCalculator;

Expand All @@ -38,13 +40,15 @@ public QueryExplainerFactory(
PlanFragmenter planFragmenter,
PlannerContext plannerContext,
StatementAnalyzerFactory statementAnalyzerFactory,
AccessControl accessControl,
StatsCalculator statsCalculator,
CostCalculator costCalculator)
{
this.planOptimizersFactory = requireNonNull(planOptimizersFactory, "planOptimizersFactory is null");
this.planFragmenter = requireNonNull(planFragmenter, "planFragmenter is null");
this.plannerContext = requireNonNull(plannerContext, "metadata is null");
this.statementAnalyzerFactory = requireNonNull(statementAnalyzerFactory, "statementAnalyzerFactory is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
}
Expand All @@ -57,6 +61,7 @@ public QueryExplainer createQueryExplainer(AnalyzerFactory analyzerFactory)
plannerContext,
analyzerFactory,
statementAnalyzerFactory,
accessControl,
statsCalculator,
costCalculator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2025,6 +2025,7 @@ private ExpressionAnalysis analyzePatternRecognitionExpression(Expression expres

return ExpressionAnalyzer.analyzePatternRecognitionExpression(
session,
plannerContext,
statementAnalyzerFactory,
accessControl,
scope,
Expand Down Expand Up @@ -2070,7 +2071,9 @@ protected Scope visitSampledRelation(SampledRelation relation, Optional<Scope> s

Map<NodeRef<Expression>, Type> expressionTypes = ExpressionAnalyzer.analyzeExpressions(
session,
plannerContext,
statementAnalyzerFactory,
accessControl,
TypeProvider.empty(),
ImmutableList.of(samplePercentage),
analysis.getParameters(),
Expand Down Expand Up @@ -2757,6 +2760,7 @@ private void analyzeWindow(QuerySpecification querySpecification, ResolvedWindow
{
ExpressionAnalysis expressionAnalysis = ExpressionAnalyzer.analyzeWindow(
session,
plannerContext,
statementAnalyzerFactory,
accessControl,
scope,
Expand Down Expand Up @@ -3541,6 +3545,7 @@ private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope)
{
return ExpressionAnalyzer.analyzeExpression(
session,
plannerContext,
statementAnalyzerFactory,
accessControl,
scope,
Expand All @@ -3554,6 +3559,7 @@ private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope,
{
return ExpressionAnalyzer.analyzeExpression(
session,
plannerContext,
statementAnalyzerFactory,
accessControl,
scope,
Expand Down Expand Up @@ -3585,6 +3591,7 @@ private void analyzeRowFilter(String currentIdentity, Table table, QualifiedObje
try {
expressionAnalysis = ExpressionAnalyzer.analyzeExpression(
createViewSession(filter.getCatalog(), filter.getSchema(), Identity.forUser(filter.getIdentity()).build(), session.getPath()), // TODO: path should be included in row filter
plannerContext,
statementAnalyzerFactory,
accessControl,
scope,
Expand Down Expand Up @@ -3639,6 +3646,7 @@ private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObj
try {
expressionAnalysis = ExpressionAnalyzer.analyzeExpression(
createViewSession(mask.getCatalog(), mask.getSchema(), Identity.forUser(mask.getIdentity()).build(), session.getPath()), // TODO: path should be included in row filter
plannerContext,
statementAnalyzerFactory,
accessControl,
scope,
Expand Down Expand Up @@ -4061,7 +4069,9 @@ private List<Expression> analyzeOrderBy(Node node, List<SortItem> sortItems, Sco
expression = new FieldReference(toIntExact(ordinal - 1));
}

ExpressionAnalysis expressionAnalysis = ExpressionAnalyzer.analyzeExpression(session,
ExpressionAnalysis expressionAnalysis = ExpressionAnalyzer.analyzeExpression(
session,
plannerContext,
statementAnalyzerFactory,
accessControl,
orderByScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import io.trino.spi.security.GroupProvider;
import io.trino.sql.PlannerContext;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.planner.TypeProvider;

import javax.inject.Inject;

Expand Down Expand Up @@ -103,30 +102,6 @@ public StatementAnalyzer createStatementAnalyzer(
correlationSupport);
}

// this is only for the static factory methods on ExpressionAnalyzer, and should not be used for any other purpose
ExpressionAnalyzer createExpressionAnalyzer(
Analysis analysis,
Session session,
TypeProvider types,
WarningCollector warningCollector)
{
return new ExpressionAnalyzer(
plannerContext,
accessControl,
(node, correlationSupport) -> createStatementAnalyzer(
analysis,
session,
warningCollector,
correlationSupport),
session,
types,
analysis.getParameters(),
warningCollector,
analysis.isDescribe(),
analysis::getType,
analysis::getWindow);
}

public static StatementAnalyzerFactory createTestingStatementAnalyzerFactory(
PlannerContext plannerContext,
AccessControl accessControl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,16 +304,19 @@ private static boolean isBetween(Range range)
public static ExtractionResult getExtractionResult(PlannerContext plannerContext, Session session, Expression predicate, TypeProvider types)
{
// This is a limited type analyzer for the simple expressions used in this method
TypeAnalyzer typeAnalyzer = new TypeAnalyzer(new StatementAnalyzerFactory(
TypeAnalyzer typeAnalyzer = new TypeAnalyzer(
plannerContext,
new SqlParser(),
new AllowAllAccessControl(),
user -> ImmutableSet.of(),
new TableProceduresRegistry(),
new SessionPropertyManager(),
new TablePropertyManager(),
new AnalyzePropertyManager(),
new TableProceduresPropertyManager()));
new StatementAnalyzerFactory(
plannerContext,
new SqlParser(),
new AllowAllAccessControl(),
user -> ImmutableSet.of(),
new TableProceduresRegistry(),
new SessionPropertyManager(),
new TablePropertyManager(),
new AnalyzePropertyManager(),
new TableProceduresPropertyManager()),
new AllowAllAccessControl());
return new Visitor(plannerContext, session, types, typeAnalyzer).process(predicate, false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.AnalyzePropertyManager;
import io.trino.metadata.TablePropertyManager;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
Expand All @@ -42,18 +43,25 @@
*/
public class TypeAnalyzer
{
private final PlannerContext plannerContext;
private final StatementAnalyzerFactory statementAnalyzerFactory;
private final AccessControl accessControl;

@Inject
public TypeAnalyzer(StatementAnalyzerFactory statementAnalyzerFactory)
public TypeAnalyzer(PlannerContext plannerContext, StatementAnalyzerFactory statementAnalyzerFactory, AccessControl accessControl)
{
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
this.statementAnalyzerFactory = requireNonNull(statementAnalyzerFactory, "statementAnalyzerFactory is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
}

public Map<NodeRef<Expression>, Type> getTypes(Session session, TypeProvider inputTypes, Iterable<Expression> expressions)
{
return analyzeExpressions(session,
return analyzeExpressions(
session,
plannerContext,
statementAnalyzerFactory,
accessControl,
inputTypes,
expressions,
ImmutableMap.of(),
Expand All @@ -74,10 +82,13 @@ public Type getType(Session session, TypeProvider inputTypes, Expression express

public static TypeAnalyzer createTestingTypeAnalyzer(PlannerContext plannerContext)
{
return new TypeAnalyzer(createTestingStatementAnalyzerFactory(
return new TypeAnalyzer(
plannerContext,
new AllowAllAccessControl(),
new TablePropertyManager(),
new AnalyzePropertyManager()));
createTestingStatementAnalyzerFactory(
plannerContext,
new AllowAllAccessControl(),
new TablePropertyManager(),
new AnalyzePropertyManager()),
new AllowAllAccessControl());
}
}
Loading

0 comments on commit c13f952

Please sign in to comment.