Skip to content

Commit

Permalink
Resolve test functions using standard mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
martint committed Mar 17, 2024
1 parent 161aa29 commit 5385cec
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.trino.Session;
import io.trino.operator.aggregation.TestingAggregationFunction;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.Plugin;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.OperatorType;
Expand All @@ -27,6 +28,7 @@
import io.trino.sql.gen.PageFunctionCompiler;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.testing.QueryRunner;
import io.trino.transaction.TransactionManager;

Expand All @@ -37,6 +39,7 @@
import java.util.stream.Collectors;

import static io.trino.SessionTestUtils.TEST_SESSION;
import static io.trino.metadata.InternalFunctionBundle.extractFunctions;
import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder;
import static io.trino.testing.TransactionBuilder.transaction;
import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager;
Expand All @@ -63,6 +66,21 @@ public TestingFunctionResolution(FunctionBundle functions)
metadata = plannerContext.getMetadata();
}

public TestingFunctionResolution(Plugin plugin)
{
transactionManager = createTestTransactionManager();

TestingPlannerContext.Builder builder = plannerContextBuilder()
.withTransactionManager(transactionManager)
.addFunctions(extractFunctions(plugin.getFunctions()));

plugin.getTypes().forEach(builder::addType);
plugin.getParametricTypes().forEach(builder::addParametricType);

plannerContext = builder.build();
metadata = plannerContext.getMetadata();
}

public TestingFunctionResolution(QueryRunner queryRunner)
{
this(queryRunner.getTransactionManager(), queryRunner.getPlannerContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,8 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.function.FunctionNullability;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.SymbolReference;
Expand All @@ -35,27 +30,18 @@

import java.util.Optional;

import static io.trino.metadata.GlobalFunctionCatalog.builtinFunctionName;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.plugin.geospatial.GeometryType.GEOMETRY;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.planner.assertions.PlanMatchPattern.spatialJoin;
import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.util.SpatialJoinUtils.ST_DISTANCE;

public class TestPruneSpatialJoinChildrenColumns
extends BaseRuleTest
{
// normally a test can just resolve the function from metadata, but the geo functions are in a plugin that is not visible to this module
public static final ResolvedFunction TEST_ST_DISTANCE_FUNCTION = new ResolvedFunction(
new BoundSignature(builtinFunctionName(ST_DISTANCE), BIGINT, ImmutableList.of(BIGINT, BIGINT)),
GlobalSystemConnector.CATALOG_HANDLE,
new FunctionId("st_distance"),
FunctionKind.SCALAR,
true,
new FunctionNullability(false, ImmutableList.of(false, false)),
ImmutableMap.of(),
ImmutableSet.of());
private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(new GeoPlugin());
private static final ResolvedFunction TEST_ST_DISTANCE_FUNCTION = FUNCTIONS.resolveFunction("st_distance", fromTypes(GEOMETRY, GEOMETRY));

@Test
public void testPruneOneChild()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.SymbolReference;
Expand All @@ -29,7 +31,8 @@

import java.util.Optional;

import static io.trino.plugin.geospatial.TestPruneSpatialJoinChildrenColumns.TEST_ST_DISTANCE_FUNCTION;
import static io.trino.plugin.geospatial.GeometryType.GEOMETRY;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.ir.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
import static io.trino.sql.planner.assertions.PlanMatchPattern.spatialJoin;
import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject;
Expand All @@ -38,6 +41,9 @@
public class TestPruneSpatialJoinColumns
extends BaseRuleTest
{
private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(new GeoPlugin());
private static final ResolvedFunction TEST_ST_DISTANCE_FUNCTION = FUNCTIONS.resolveFunction("st_distance", fromTypes(GEOMETRY, GEOMETRY));

@Test
public void notAllOutputsReferenced()
{
Expand Down

0 comments on commit 5385cec

Please sign in to comment.