Skip to content

Commit

Permalink
Change language function FunctionId to hash of IrRoutine
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Apr 2, 2024
1 parent 0a8ec34 commit bbaa1ee
Show file tree
Hide file tree
Showing 7 changed files with 377 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
import com.google.inject.Inject;
import io.trino.Session;
Expand All @@ -26,6 +27,7 @@
import io.trino.security.ViewAccessControl;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.connector.CatalogSchemaName;
import io.trino.spi.connector.ConnectorSession;
Expand All @@ -48,6 +50,7 @@
import io.trino.sql.routine.SqlRoutineAnalysis;
import io.trino.sql.routine.SqlRoutineAnalyzer;
import io.trino.sql.routine.SqlRoutineCompiler;
import io.trino.sql.routine.SqlRoutineHash;
import io.trino.sql.routine.SqlRoutinePlanner;
import io.trino.sql.routine.ir.IrRoutine;
import io.trino.sql.tree.FunctionSpecification;
Expand Down Expand Up @@ -80,16 +83,18 @@ public class LanguageFunctionManager
private final SqlParser parser;
private final TypeManager typeManager;
private final GroupProvider groupProvider;
private final BlockEncodingSerde blockEncodingSerde;
private SqlRoutineAnalyzer analyzer;
private SqlRoutinePlanner planner;
private final Map<QueryId, QueryFunctions> queryFunctions = new ConcurrentHashMap<>();

@Inject
public LanguageFunctionManager(SqlParser parser, TypeManager typeManager, GroupProvider groupProvider)
public LanguageFunctionManager(SqlParser parser, TypeManager typeManager, GroupProvider groupProvider, BlockEncodingSerde blockEncodingSerde)
{
this.parser = requireNonNull(parser, "parser is null");
this.typeManager = requireNonNull(typeManager, "typeManager is null");
this.groupProvider = requireNonNull(groupProvider, "groupProvider is null");
this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
}

// There is a circular dependency between LanguageFunctionManager and MetadataManager.
Expand Down Expand Up @@ -211,9 +216,8 @@ public static boolean isTrinoSqlLanguageFunction(FunctionId functionId)

private static FunctionId createSqlLanguageFunctionId(QueryId queryId, String sql)
{
// TODO: The function ID should be a hash of the IrRoutine, not the SQL text
// QueryId is added to the FunctionID to ensures this exact planned IrRoutine is used for the function.
// This breaks caching of the function implementation across queries, but it is necessary to ensure correctness.
// QueryId is added to the FunctionID to ensure it is unique across queries.
// After the function is analyzed and planned, the FunctionId is replaced with a hash of the IrRoutine.
String hash = Hashing.sha256().hashUnencodedChars(sql).toString();
return new FunctionId(SQL_FUNCTION_PREFIX + queryId + "_" + hash);
}
Expand Down Expand Up @@ -265,21 +269,18 @@ public synchronized List<FunctionMetadata> getFunctions(CatalogHandle catalogHan

public FunctionId analyzeAndPlan(FunctionId functionId, AccessControl accessControl)
{
if (usedFunctions.containsKey(functionId)) {
return functionId;
}

LanguageFunctionImplementation function = implementationsById.get(functionId);
checkArgument(function != null, "Unknown function implementation: %s", functionId);

// verify the function and check permissions of nexted function calls
IrRoutine routine = function.analyzeAndPlan(accessControl);
function.analyzeAndPlan(accessControl);

// todo generate a new FunctionId based on a hash of the routine
IrRoutine routine = function.getRoutine();
FunctionId resolvedFunctionId = function.getResolvedFunctionId();

// mark the function as used, so it is serialized for workers
usedFunctions.put(functionId, routine);
return functionId;
usedFunctions.put(resolvedFunctionId, routine);
return resolvedFunctionId;
}

public Optional<ScalarFunctionImplementation> specialize(FunctionId functionId, InvocationConvention invocationConvention, FunctionManager functionManager)
Expand Down Expand Up @@ -381,6 +382,7 @@ private class LanguageFunctionImplementation
private final Optional<String> owner;
private final Optional<RunAsIdentityLoader> identityLoader;
private IrRoutine routine;
private FunctionId resolvedFunctionId;
private boolean analyzing;

private LanguageFunctionImplementation(QueryId queryId, String sql, SqlPath path, Optional<String> owner, Optional<RunAsIdentityLoader> identityLoader)
Expand All @@ -397,31 +399,47 @@ public FunctionMetadata getFunctionMetadata()
return functionMetadata;
}

public void verifyForCreate(FunctionManager functionManager, AccessControl accessControl)
public synchronized void verifyForCreate(FunctionManager functionManager, AccessControl accessControl)
{
checkState(identityLoader.isEmpty(), "create should not enforce security");
IrRoutine routine = analyzeAndPlan(accessControl);
analyzeAndPlan(accessControl);
new SqlRoutineCompiler(functionManager).compile(routine);
}

private synchronized IrRoutine analyzeAndPlan(AccessControl accessControl)
private synchronized void analyzeAndPlan(AccessControl accessControl)
{
if (routine != null) {
return routine;
return;
}
if (analyzing) {
throw new TrinoException(NOT_SUPPORTED, "Recursive language functions are not supported: %s%s".formatted(functionMetadata.getCanonicalName(), functionMetadata.getSignature()));
}

analyzing = true;

FunctionContext context = functionContext(accessControl);
SqlRoutineAnalysis analysis = analyzer.analyze(context.session(), context.accessControl(), functionSpecification);
routine = planner.planSqlFunction(session, functionSpecification, analysis);

Hasher hasher = Hashing.sha256().newHasher();
SqlRoutineHash.hash(routine, hasher, blockEncodingSerde);
resolvedFunctionId = new FunctionId(SQL_FUNCTION_PREFIX + hasher.hash());

analyzing = false;
}

public synchronized IrRoutine getRoutine()
{
checkState(routine != null, "function not yet analyzed");
return routine;
}

public synchronized FunctionId getResolvedFunctionId()
{
checkState(routine != null, "function not yet analyzed");
return resolvedFunctionId;
}

private FunctionContext functionContext(AccessControl accessControl)
{
if (identityLoader.isEmpty() || isRunAsInvoker(functionSpecification)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.trino.spi.ErrorCode;
import io.trino.spi.QueryId;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.AggregationApplicationResult;
import io.trino.spi.connector.Assignment;
Expand Down Expand Up @@ -2823,7 +2824,8 @@ public MetadataManager build()
}

if (languageFunctionManager == null) {
languageFunctionManager = new LanguageFunctionManager(new SqlParser(), typeManager, user -> ImmutableSet.of());
BlockEncodingSerde blockEncodingSerde = new InternalBlockEncodingSerde(new BlockEncodingManager(), typeManager);
languageFunctionManager = new LanguageFunctionManager(new SqlParser(), typeManager, user -> ImmutableSet.of(), blockEncodingSerde);
}

return new MetadataManager(
Expand Down
Loading

0 comments on commit bbaa1ee

Please sign in to comment.