Skip to content

Commit

Permalink
Never spool non-select queries
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Oct 28, 2024
1 parent 3b6a173 commit cca708b
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 5 deletions.
32 changes: 32 additions & 0 deletions core/trino-main/src/main/java/io/trino/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,38 @@ public Session withExchangeEncryption(Slice encryptionKey)
queryDataEncoding);
}

public Session withoutSpooling()
{
return new Session(
queryId,
querySpan,
transactionId,
clientTransactionSupport,
identity,
originalIdentity,
source,
catalog,
schema,
path,
traceToken,
timeZoneKey,
locale,
remoteUserAddress,
userAgent,
clientInfo,
clientTags,
clientCapabilities,
resourceEstimates,
start,
systemProperties,
catalogProperties,
sessionPropertyManager,
preparedStatements,
protocolHeaders,
exchangeEncryptionKey,
Optional.empty());
}

public ConnectorSession toConnectorSession()
{
return new FullConnectorSession(this, identity.toConnectorIdentity());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
import static io.trino.server.DynamicFilterService.DynamicFiltersStats;
import static io.trino.spi.StandardErrorCode.NOT_FOUND;
import static io.trino.spi.StandardErrorCode.USER_CANCELED;
import static io.trino.spi.resourcegroups.QueryType.SELECT;
import static io.trino.util.Ciphers.createRandomAesEncryptionKey;
import static io.trino.util.Ciphers.serializeAesEncryptionKey;
import static io.trino.util.Failures.toFailure;
Expand Down Expand Up @@ -308,6 +309,10 @@ static QueryStateMachine beginWithTicker(
session = session.withExchangeEncryption(serializeAesEncryptionKey(createRandomAesEncryptionKey()));
}

if (!queryType.map(SELECT::equals).orElse(false)) {
session = session.withoutSpooling();
}

Span querySpan = session.getQuerySpan();

querySpan.setAttribute(TrinoAttributes.QUERY_TYPE, queryType.map(Enum::name).orElse("UNKNOWN"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public class MaterializedResult
private final List<MaterializedRow> rows;
private final List<Type> types;
private final List<String> columnNames;
private final Optional<String> queryDataEncoding;
private final Map<String, String> setSessionProperties;
private final Set<String> resetSessionProperties;
private final Optional<String> updateType;
Expand All @@ -76,18 +77,19 @@ public class MaterializedResult

public MaterializedResult(List<MaterializedRow> rows, List<? extends Type> types)
{
this(rows, types, Optional.empty());
this(rows, types, Optional.empty(), Optional.empty());
}

public MaterializedResult(List<MaterializedRow> rows, List<? extends Type> types, Optional<List<String>> columnNames)
public MaterializedResult(List<MaterializedRow> rows, List<? extends Type> types, Optional<List<String>> columnNames, Optional<String> queryDataEncoding)
{
this(rows, types, columnNames.orElse(ImmutableList.of()), ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty());
this(rows, types, columnNames.orElse(ImmutableList.of()), queryDataEncoding, ImmutableMap.of(), ImmutableSet.of(), Optional.empty(), OptionalLong.empty(), ImmutableList.of(), Optional.empty());
}

public MaterializedResult(
List<MaterializedRow> rows,
List<? extends Type> types,
List<String> columnNames,
Optional<String> queryDataEncoding,
Map<String, String> setSessionProperties,
Set<String> resetSessionProperties,
Optional<String> updateType,
Expand All @@ -98,6 +100,7 @@ public MaterializedResult(
this.rows = ImmutableList.copyOf(requireNonNull(rows, "rows is null"));
this.types = ImmutableList.copyOf(requireNonNull(types, "types is null"));
this.columnNames = ImmutableList.copyOf(requireNonNull(columnNames, "columnNames is null"));
this.queryDataEncoding = requireNonNull(queryDataEncoding, "queryDataEncoding is null");
this.setSessionProperties = ImmutableMap.copyOf(requireNonNull(setSessionProperties, "setSessionProperties is null"));
this.resetSessionProperties = ImmutableSet.copyOf(requireNonNull(resetSessionProperties, "resetSessionProperties is null"));
this.updateType = requireNonNull(updateType, "updateType is null");
Expand Down Expand Up @@ -133,6 +136,11 @@ public List<String> getColumnNames()
return columnNames;
}

public Optional<String> getQueryDataEncoding()
{
return queryDataEncoding;
}

public Map<String, String> getSetSessionProperties()
{
return setSessionProperties;
Expand Down Expand Up @@ -280,6 +288,7 @@ public MaterializedResult toTestTypes()
.collect(toImmutableList()),
types,
columnNames,
queryDataEncoding,
setSessionProperties,
resetSessionProperties,
updateType,
Expand Down Expand Up @@ -360,6 +369,7 @@ public static class Builder
private final ConnectorSession session;
private final List<Type> types;
private final ImmutableList.Builder<MaterializedRow> rows = ImmutableList.builder();
private Optional<String> queryDataEncoding = Optional.empty();
private Optional<List<String>> columnNames = Optional.empty();

Builder(ConnectorSession session, List<Type> types)
Expand Down Expand Up @@ -422,9 +432,15 @@ public synchronized Builder columnNames(List<String> columnNames)
return this;
}

public synchronized Builder queryDataEncoding(String encoding)
{
this.queryDataEncoding = Optional.of(requireNonNull(encoding, "encoding is null"));
return this;
}

public synchronized MaterializedResult build()
{
return new MaterializedResult(rows.build(), types, columnNames);
return new MaterializedResult(rows.build(), types, columnNames, queryDataEncoding);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ private static MaterializedResult toMaterializedRows(DispatchQuery dispatchQuery
ImmutableList.of(),
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
queryInfo.getSetSessionProperties(),
queryInfo.getResetSessionProperties(),
Optional.ofNullable(queryInfo.getUpdateType()),
Expand All @@ -106,6 +107,7 @@ private static MaterializedResult toMaterializedRows(DispatchQuery dispatchQuery
materializedRows,
columnTypes,
columnNames,
Optional.empty(),
queryInfo.getSetSessionProperties(),
queryInfo.getResetSessionProperties(),
Optional.ofNullable(queryInfo.getUpdateType()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.client.Column;
import io.trino.client.QueryStatusInfo;
import io.trino.client.StatementClient;
import io.trino.client.spooling.EncodedQueryData;
import io.trino.metadata.MetadataUtil;
import io.trino.metadata.QualifiedObjectName;
import io.trino.metadata.QualifiedTablePrefix;
Expand Down Expand Up @@ -107,6 +108,9 @@ public ResultWithQueryId<T> execute(Session session, @Language("SQL") String sql
try (StatementClient client = statementClientFactory.create(httpClient, session, clientSession, sql)) {
while (client.isRunning()) {
resultsSession.addResults(client.currentStatusInfo(), client.currentRows());
if (client.currentData() instanceof EncodedQueryData encodedQueryData) {
resultsSession.setQueryDataEncoding(encodedQueryData.getEncoding());
}
client.advance();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,9 @@ default void setStatementStats(StatementStats statementStats) {}

void addResults(QueryStatusInfo statusInfo, ResultRows access);

default void setQueryDataEncoding(String encoding)
{
}

T build(Map<String, String> setSessionProperties, Set<String> resetSessionProperties);
}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private class MaterializedResultSession

private final AtomicReference<List<Type>> types = new AtomicReference<>();
private final AtomicReference<List<String>> columnNames = new AtomicReference<>();

private final AtomicReference<String> queryDataEncoding = new AtomicReference<>();
private final AtomicReference<Optional<String>> updateType = new AtomicReference<>(Optional.empty());
private final AtomicReference<OptionalLong> updateCount = new AtomicReference<>(OptionalLong.empty());
private final AtomicReference<List<Warning>> warnings = new AtomicReference<>(ImmutableList.of());
Expand Down Expand Up @@ -175,6 +175,12 @@ public void addResults(QueryStatusInfo statusInfo, ResultRows data)
rows.addAll(mappedCopy(data, dataToRow(types.get())));
}

@Override
public void setQueryDataEncoding(String encoding)
{
queryDataEncoding.set(encoding);
}

@Override
public MaterializedResult build(Map<String, String> setSessionProperties, Set<String> resetSessionProperties)
{
Expand All @@ -183,6 +189,7 @@ public MaterializedResult build(Map<String, String> setSessionProperties, Set<St
rows.build(),
types.get(),
columnNames.get(),
Optional.ofNullable(queryDataEncoding.get()),
setSessionProperties,
resetSessionProperties,
updateType.get(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.trino.testing.TestingTrinoClient;
import io.trino.tpch.TpchTable;
import okhttp3.OkHttpClient;
import org.junit.jupiter.api.Test;
import org.testcontainers.containers.localstack.LocalStackContainer;
import org.testcontainers.containers.localstack.LocalStackContainer.Service;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
Expand All @@ -45,6 +46,7 @@
import static io.trino.client.StatementClientFactory.newStatementClient;
import static io.trino.util.Ciphers.createRandomAesEncryptionKey;
import static java.util.Base64.getEncoder;
import static org.assertj.core.api.Assertions.assertThat;

public abstract class AbstractSpooledQueryDataDistributedQueries
extends AbstractTestEngineOnlyQueries
Expand Down Expand Up @@ -110,6 +112,30 @@ protected QueryRunner createQueryRunner()
return queryRunner;
}

@Test
public void testSpoolingDisabledForNonSelectQueries()
{
// Ensure that spooling is enabled for SELECT queries
assertThat(computeActual("SELECT * FROM nation").getQueryDataEncoding())
.hasValue(encoding());

// The rest of the cases are not meant to cover all possible query shapes
assertThat(computeActual("EXPLAIN SELECT * FROM nation").getQueryDataEncoding())
.isEmpty();

assertThat(computeActual("CREATE TABLE spooling_test (col INT)").getQueryDataEncoding())
.isEmpty();

assertThat(computeActual("INSERT INTO spooling_test (col) VALUES (2137)").getQueryDataEncoding())
.isEmpty();

assertThat(computeActual("SHOW SESSION").getQueryDataEncoding())
.isEmpty();

assertThat(computeActual("DROP TABLE spooling_test").getQueryDataEncoding())
.isEmpty();
}

private static TestingTrinoClient createClient(TestingTrinoServer testingTrinoServer, Session session, String encoding)
{
return new TestingTrinoClient(testingTrinoServer, new TestingStatementClientFactory() {
Expand Down

0 comments on commit cca708b

Please sign in to comment.