From 36c08fde6eb0a35461d38c1760f33998192ce0f1 Mon Sep 17 00:00:00 2001
From: "Mateusz \"Serafin\" Gajewski"
Date: Mon, 12 Feb 2024 13:12:39 +0100
Subject: [PATCH] Spooling protocol extension

https://github.com/trinodb/trino/issues/22662
---
 client/trino-cli/pom.xml | 12 +
 .../java/io/trino/cli/TestQueryRunner.java | 8 +-
 client/trino-client/pom.xml | 10 +
 .../src/main/java/io/trino/client/Column.java | 23 +
 .../io/trino/client/FixJsonDataUtils.java | 37 +-
 .../main/java/io/trino/client/JsonCodec.java | 16 +
 .../java/io/trino/client/ProtocolHeaders.java | 7 +
 .../main/java/io/trino/client/QueryData.java | 3 +
 .../client/QueryDataClientJacksonModule.java | 66 +++
 .../io/trino/client/QueryDataDecoder.java | 36 ++
 .../java/io/trino/client/QueryResults.java | 84 ++--
 .../java/io/trino/client/RawQueryData.java | 53 +++
 .../trino/client/StatementClientFactory.java | 5 +
 .../io/trino/client/StatementClientV1.java | 62 ++-
 .../trino/client/spooling/DataAttribute.java | 109 +++++
 .../trino/client/spooling/DataAttributes.java | 120 +++++
 .../client/spooling/EncodedQueryData.java | 164 +++++++
 .../trino/client/spooling/InlineSegment.java | 51 ++
 .../io/trino/client/spooling/Segment.java | 93 ++++
 .../trino/client/spooling/SegmentLoader.java | 162 +++++++
 .../trino/client/spooling/SpooledSegment.java | 53 +++
 .../client/spooling/encoding/CipherUtils.java | 34 ++
 .../encoding/CompressedQueryDataDecoder.java | 56 +++
 .../encoding/DecryptingQueryDataDecoder.java | 106 +++++
 .../encoding/JsonQueryDataDecoder.java | 129 ++++++
 .../encoding/Lz4QueryDataDecoder.java | 50 ++
 .../spooling/encoding/QueryDataDecoders.java | 55 +++
 .../encoding/SnappyQueryDataDecoder.java | 50 ++
 .../encoding/ZstdQueryDataDecoder.java | 40 ++
 .../io/trino/client/TestQueryResults.java | 2 +-
 .../test/java/io/trino/client/TestRetry.java | 9 +-
 client/trino-jdbc/pom.xml | 6 +
 .../io/trino/jdbc/TestProgressMonitor.java | 11 +-
 core/trino-main/pom.xml | 13 +-
 .../src/main/java/io/trino/Session.java | 54 ++-
 .../java/io/trino/SessionRepresentation.java | 24 +-
 .../io/trino/cost/PlanNodeStatsEstimate.java | 2 +
 .../dispatcher/QueuedStatementResource.java | 3 +-
 .../io/trino/execution/SqlQueryExecution.java | 9 +
 .../java/io/trino/operator/OperatorInfo.java | 2 +
 .../operator/OutputSpoolingController.java | 184 ++++++++
 .../OutputSpoolingOperatorFactory.java | 435 ++++++++++++++++++
 .../HttpRequestSessionContextFactory.java | 14 +-
 .../java/io/trino/server/PluginManager.java | 12 +-
 .../io/trino/server/QuerySessionSupplier.java | 15 +-
 .../src/main/java/io/trino/server/Server.java | 2 +
 .../io/trino/server/ServerMainModule.java | 2 +
 .../java/io/trino/server/SessionContext.java | 13 +-
 .../protocol/ExecutingStatementResource.java | 11 +
 .../protocol/JsonArrayResultsIterator.java | 231 ++++++++++
 .../trino/server/protocol/OutputColumn.java | 31 ++
 .../java/io/trino/server/protocol/Query.java | 16 +-
 .../server/protocol/QueryResultRows.java | 292 ++----------
 .../spooling/DataAttributesSerialization.java | 41 ++
 .../PreferredQueryDataEncoderSelector.java | 62 +++
 .../protocol/spooling/QueryDataEncoder.java | 46 ++
 .../spooling/QueryDataEncoderSelector.java | 30 ++
 .../spooling/QueryDataJacksonModule.java | 114 +++++
 .../protocol/spooling/QueryDataProducer.java | 113 +++++
 .../protocol/spooling/SegmentResource.java | 116 +++++
 .../protocol/spooling/SpooledBlock.java | 79 ++++
 .../protocol/spooling/SpoolingConfig.java | 156 +++++++
 .../spooling/SpoolingManagerBridge.java | 163 +++++++
 .../spooling/SpoolingManagerRegistry.java | 136 ++++++
 .../spooling/SpoolingServerModule.java | 61 +++
 .../encoding/CompressedQueryDataEncoder.java | 71 +++
 .../encoding/EncryptingQueryDataEncoder.java | 121 +++++
 .../encoding/JsonQueryDataEncoder.java | 174 +++++++
 .../encoding/Lz4QueryDataEncoder.java | 48 ++
 .../encoding/QueryDataEncodingModule.java | 38 ++
 .../encoding/SnappyQueryDataEncoder.java | 48 ++
 .../encoding/ZstdQueryDataEncoder.java | 48 ++
 .../server/testing/TestingTrinoServer.java | 14 +
 .../io/trino/sql/analyzer/QueryExplainer.java | 5 +
 .../sql/planner/LocalExecutionPlanner.java | 27 +-
 .../io/trino/sql/planner/LogicalPlanner.java | 12 +-
 .../planner/planprinter/GraphvizPrinter.java | 4 +-
 .../sql/planner/planprinter/PlanPrinter.java | 16 +-
 .../sql/planner/planprinter/TextRenderer.java | 2 +
 .../sanity/ValidateDependenciesChecker.java | 4 +-
 .../java/io/trino/testing/PlanTester.java | 25 +-
 .../java/io/trino/testing/QueryRunner.java | 2 +
 .../trino/testing/StandaloneQueryRunner.java | 6 +
 .../io/trino/execution/TaskTestUtils.java | 10 +
 .../TestOutputSpoolingController.java | 183 ++++++++
 .../operator/spooling/TestSpooledBlock.java | 88 ++++
 .../TestHttpRequestSessionContextFactory.java | 3 +-
 .../io/trino/server/TestQueryResource.java | 47 +-
 .../server/TestQuerySessionSupplier.java | 7 +-
 .../server/TestQueryStateInfoResource.java | 28 +-
 .../protocol/TestQueryDataSerialization.java | 319 +++++++++++++
 .../server/protocol/TestQueryResultRows.java | 52 +--
 .../TestQueryResultsSerialization.java | 188 ++++++++
 .../server/security/TestResourceSecurity.java | 3 +-
 .../java/io/trino/server/ui/TestWebUi.java | 3 +-
 core/trino-server/src/main/provisio/trino.xml | 6 +
 .../src/main/java/io/trino/spi/Plugin.java | 6 +
 .../spi/protocol/SpooledSegmentHandle.java | 18 +
 .../trino/spi/protocol/SpoolingContext.java | 30 ++
 .../trino/spi/protocol/SpoolingManager.java | 53 +++
 .../spi/protocol/SpoolingManagerContext.java | 30 ++
 .../spi/protocol/SpoolingManagerFactory.java | 23 +
 plugin/trino-spooling-filesystem/pom.xml | 145 ++++++
 .../FileSystemSpooledSegmentHandle.java | 43 ++
 .../filesystem/FileSystemSpoolingConfig.java | 97 ++++
 .../filesystem/FileSystemSpoolingManager.java | 136 ++++++
 .../FileSystemSpoolingManagerFactory.java | 62 +++
 .../filesystem/FilesystemSpoolingModule.java | 70 +++
 .../filesystem/FilesystemSpoolingPlugin.java | 29 ++
 .../filesystem/SwitchingFileSystem.java | 152 ++++++
 .../TestFileSystemSpoolingManager.java | 98 ++++
 pom.xml | 14 +
 testing/trino-plugin-reader/pom.xml | 14 +
 .../io/trino/testing/containers/Minio.java | 7 +-
 .../io/trino/testing/minio/MinioClient.java | 6 +
 .../AbstractTestEngineOnlyQueries.java | 4 +-
 .../trino/testing/DistributedQueryRunner.java | 8 +
 testing/trino-tests/pom.xml | 31 ++
 ...actSpooledQueryDataDistributedQueries.java | 108 +++++
 .../TestJsonLz4SpooledDistributedQueries.java | 24 +
 ...stJsonSnappySpooledDistributedQueries.java | 24 +
 .../TestJsonSpooledDistributedQueries.java | 24 +
 ...TestJsonZstdSpooledDistributedQueries.java | 24 +
 .../test/java/io/trino/tests/TestServer.java | 34 +-
 124 files changed, 6703 insertions(+), 417 deletions(-)
 create mode 100644 client/trino-client/src/main/java/io/trino/client/QueryDataClientJacksonModule.java
 create mode 100644 client/trino-client/src/main/java/io/trino/client/QueryDataDecoder.java
 create mode 100644 client/trino-client/src/main/java/io/trino/client/RawQueryData.java
 create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/DataAttribute.java
 create mode 100644
client/trino-client/src/main/java/io/trino/client/spooling/DataAttributes.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/EncodedQueryData.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/InlineSegment.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/Segment.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/SegmentLoader.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/SpooledSegment.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/encoding/CipherUtils.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/encoding/CompressedQueryDataDecoder.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/encoding/DecryptingQueryDataDecoder.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/encoding/JsonQueryDataDecoder.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/encoding/Lz4QueryDataDecoder.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/encoding/QueryDataDecoders.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/encoding/SnappyQueryDataDecoder.java create mode 100644 client/trino-client/src/main/java/io/trino/client/spooling/encoding/ZstdQueryDataDecoder.java create mode 100644 core/trino-main/src/main/java/io/trino/operator/OutputSpoolingController.java create mode 100644 core/trino-main/src/main/java/io/trino/operator/OutputSpoolingOperatorFactory.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/JsonArrayResultsIterator.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/OutputColumn.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/DataAttributesSerialization.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/PreferredQueryDataEncoderSelector.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataEncoder.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataEncoderSelector.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataJacksonModule.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataProducer.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/SegmentResource.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpooledBlock.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingConfig.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingManagerBridge.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingManagerRegistry.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingServerModule.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/CompressedQueryDataEncoder.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/EncryptingQueryDataEncoder.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/JsonQueryDataEncoder.java create mode 
100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/Lz4QueryDataEncoder.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/QueryDataEncodingModule.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/SnappyQueryDataEncoder.java create mode 100644 core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/ZstdQueryDataEncoder.java create mode 100644 core/trino-main/src/test/java/io/trino/operator/TestOutputSpoolingController.java create mode 100644 core/trino-main/src/test/java/io/trino/operator/spooling/TestSpooledBlock.java create mode 100644 core/trino-main/src/test/java/io/trino/server/protocol/TestQueryDataSerialization.java create mode 100644 core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultsSerialization.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/protocol/SpooledSegmentHandle.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingContext.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManager.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManagerContext.java create mode 100644 core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManagerFactory.java create mode 100644 plugin/trino-spooling-filesystem/pom.xml create mode 100644 plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpooledSegmentHandle.java create mode 100644 plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingConfig.java create mode 100644 plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManager.java create mode 100644 plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManagerFactory.java create mode 100644 plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FilesystemSpoolingModule.java create mode 100644 plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FilesystemSpoolingPlugin.java create mode 100644 plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/SwitchingFileSystem.java create mode 100644 plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManager.java create mode 100644 testing/trino-tests/src/test/java/io/trino/server/protocol/AbstractSpooledQueryDataDistributedQueries.java create mode 100644 testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonLz4SpooledDistributedQueries.java create mode 100644 testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonSnappySpooledDistributedQueries.java create mode 100644 testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonSpooledDistributedQueries.java create mode 100644 testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonZstdSpooledDistributedQueries.java diff --git a/client/trino-cli/pom.xml b/client/trino-cli/pom.xml index 166551d4649d..6ebd77e34af7 100644 --- a/client/trino-cli/pom.xml +++ b/client/trino-cli/pom.xml @@ -109,6 +109,18 @@ ${dep.jline.version} + + com.fasterxml.jackson.core + jackson-databind + runtime + + + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + runtime + + org.jline jline-terminal-ffm diff --git a/client/trino-cli/src/test/java/io/trino/cli/TestQueryRunner.java b/client/trino-cli/src/test/java/io/trino/cli/TestQueryRunner.java 
index d492f8073b54..31764c244a6c 100644 --- a/client/trino-cli/src/test/java/io/trino/cli/TestQueryRunner.java +++ b/client/trino-cli/src/test/java/io/trino/cli/TestQueryRunner.java @@ -14,11 +14,12 @@ package io.trino.cli; import com.google.common.collect.ImmutableList; -import io.airlift.json.JsonCodec; import io.airlift.units.Duration; import io.trino.client.ClientSession; import io.trino.client.ClientTypeSignature; import io.trino.client.Column; +import io.trino.client.JsonCodec; +import io.trino.client.QueryDataClientJacksonModule; import io.trino.client.QueryResults; import io.trino.client.StatementStats; import io.trino.client.uri.PropertyName; @@ -42,10 +43,10 @@ import static com.google.common.net.HttpHeaders.CONTENT_TYPE; import static com.google.common.net.HttpHeaders.LOCATION; import static com.google.common.net.HttpHeaders.SET_COOKIE; -import static io.airlift.json.JsonCodec.jsonCodec; import static io.trino.cli.ClientOptions.OutputFormat.CSV; import static io.trino.cli.TerminalUtils.getTerminal; import static io.trino.client.ClientStandardTypes.BIGINT; +import static io.trino.client.JsonCodec.jsonCodec; import static io.trino.client.auth.external.ExternalRedirectStrategy.PRINT; import static java.util.concurrent.TimeUnit.MINUTES; import static org.assertj.core.api.Assertions.assertThat; @@ -54,8 +55,7 @@ @TestInstance(PER_METHOD) public class TestQueryRunner { - private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); - + private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class, new QueryDataClientJacksonModule()); private MockWebServer server; @BeforeEach diff --git a/client/trino-client/pom.xml b/client/trino-client/pom.xml index 3edef6c888cd..ea546d244572 100644 --- a/client/trino-client/pom.xml +++ b/client/trino-client/pom.xml @@ -63,11 +63,21 @@ okhttp-urlconnection + + com.squareup.okio + okio-jvm + + dev.failsafe failsafe + + io.airlift + aircompressor + + io.airlift units diff --git a/client/trino-client/src/main/java/io/trino/client/Column.java b/client/trino-client/src/main/java/io/trino/client/Column.java index 2aabb8033cce..e6eb539cd2cf 100644 --- a/client/trino-client/src/main/java/io/trino/client/Column.java +++ b/client/trino-client/src/main/java/io/trino/client/Column.java @@ -17,6 +17,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.errorprone.annotations.Immutable; +import java.util.Objects; + import static java.util.Objects.requireNonNull; @Immutable @@ -54,4 +56,25 @@ public ClientTypeSignature getTypeSignature() { return typeSignature; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Column column = (Column) o; + return Objects.equals(name, column.name) + && Objects.equals(type, column.type) + && Objects.equals(typeSignature, column.typeSignature); + } + + @Override + public int hashCode() + { + return Objects.hash(name, type, typeSignature); + } } diff --git a/client/trino-client/src/main/java/io/trino/client/FixJsonDataUtils.java b/client/trino-client/src/main/java/io/trino/client/FixJsonDataUtils.java index 88710aa45c7d..97440db497a9 100644 --- a/client/trino-client/src/main/java/io/trino/client/FixJsonDataUtils.java +++ b/client/trino-client/src/main/java/io/trino/client/FixJsonDataUtils.java @@ -14,6 +14,7 @@ package io.trino.client; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import 
com.google.common.collect.Maps; import io.trino.client.ClientTypeSignatureParameter.ParameterKind; @@ -53,33 +54,35 @@ import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; -final class FixJsonDataUtils +public final class FixJsonDataUtils { private FixJsonDataUtils() {} - public static Iterable> fixData(List columns, List> data) + public static Iterable> fixData(List columns, Iterable> data) { if (data == null) { return null; } + ColumnTypeHandler[] typeHandlers = createTypeHandlers(columns); - ImmutableList.Builder> rows = ImmutableList.builderWithExpectedSize(data.size()); - for (List row : data) { - if (row.size() != typeHandlers.length) { - throw new IllegalArgumentException("row/column size mismatch"); - } - ArrayList newRow = new ArrayList<>(typeHandlers.length); - int column = 0; - for (Object value : row) { - if (value != null) { - value = typeHandlers[column].fixValue(value); - } - newRow.add(value); - column++; + return Iterables.transform(data, row -> fixRow(typeHandlers, row)); + } + + private static List fixRow(ColumnTypeHandler[] typeHandlers, List row) + { + if (row.size() != typeHandlers.length) { + throw new IllegalArgumentException("row/column size mismatch"); + } + ArrayList newRow = new ArrayList<>(typeHandlers.length); + int column = 0; + for (Object value : row) { + if (value != null) { + value = typeHandlers[column].fixValue(value); } - rows.add(unmodifiableList(newRow)); // allow nulls in list + newRow.add(value); + column++; } - return rows.build(); + return unmodifiableList(newRow); // allow nulls in list } private static ColumnTypeHandler[] createTypeHandlers(List columns) diff --git a/client/trino-client/src/main/java/io/trino/client/JsonCodec.java b/client/trino-client/src/main/java/io/trino/client/JsonCodec.java index 5bde9542209c..de3c062169bc 100644 --- a/client/trino-client/src/main/java/io/trino/client/JsonCodec.java +++ b/client/trino-client/src/main/java/io/trino/client/JsonCodec.java @@ -20,6 +20,7 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.MapperFeature; +import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; @@ -65,6 +66,11 @@ public static JsonCodec jsonCodec(Class type) return new JsonCodec<>(OBJECT_MAPPER_SUPPLIER.get(), type); } + public static JsonCodec jsonCodec(Class type, Module... 
extraModules) + { + return new JsonCodec<>(OBJECT_MAPPER_SUPPLIER.get().registerModules(extraModules), type); + } + private final ObjectMapper mapper; private final Type type; private final JavaType javaType; @@ -106,4 +112,14 @@ public T fromJson(InputStream inputStream) return value; } } + + public String toJson(T instance) + { + try { + return mapper.writerFor(javaType).writeValueAsString(instance); + } + catch (IOException exception) { + throw new IllegalArgumentException(String.format("%s could not be converted to JSON", instance.getClass().getName()), exception); + } + } } diff --git a/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java b/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java index e09555d84755..2bbcca4ca8b9 100644 --- a/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java +++ b/client/trino-client/src/main/java/io/trino/client/ProtocolHeaders.java @@ -43,6 +43,7 @@ public final class ProtocolHeaders private final String requestClientCapabilities; private final String requestResourceEstimate; private final String requestExtraCredential; + private final String requestQueryDataEncoding; private final String responseSetCatalog; private final String responseSetSchema; private final String responseSetPath; @@ -89,6 +90,7 @@ private ProtocolHeaders(String name) requestClientCapabilities = prefix + "Client-Capabilities"; requestResourceEstimate = prefix + "Resource-Estimate"; requestExtraCredential = prefix + "Extra-Credential"; + requestQueryDataEncoding = prefix + "Query-Data-Encoding"; responseSetCatalog = prefix + "Set-Catalog"; responseSetSchema = prefix + "Set-Schema"; responseSetPath = prefix + "Set-Path"; @@ -198,6 +200,11 @@ public String requestExtraCredential() return requestExtraCredential; } + public String requestQueryDataEncoding() + { + return requestQueryDataEncoding; + } + public String responseSetCatalog() { return responseSetCatalog; diff --git a/client/trino-client/src/main/java/io/trino/client/QueryData.java b/client/trino-client/src/main/java/io/trino/client/QueryData.java index 9e16157b5182..72aae39584c8 100644 --- a/client/trino-client/src/main/java/io/trino/client/QueryData.java +++ b/client/trino-client/src/main/java/io/trino/client/QueryData.java @@ -13,9 +13,12 @@ */ package io.trino.client; +import jakarta.annotation.Nullable; + import java.util.List; public interface QueryData { + @Nullable Iterable> getData(); } diff --git a/client/trino-client/src/main/java/io/trino/client/QueryDataClientJacksonModule.java b/client/trino-client/src/main/java/io/trino/client/QueryDataClientJacksonModule.java new file mode 100644 index 000000000000..6724288e42b0 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/QueryDataClientJacksonModule.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.client; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.core.Version; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.deser.std.StdDeserializer; +import com.fasterxml.jackson.databind.module.SimpleModule; +import io.trino.client.spooling.EncodedQueryData; + +import java.io.IOException; +import java.util.List; + +/** + * Decodes the direct and encoded protocols. + * + * If the "data" fields starts with an array - this is the direct protocol which requires reading values and wrapping them with a class. + * + * Otherwise, this is an encoded protocol. + */ +public class QueryDataClientJacksonModule + extends SimpleModule +{ + private static final TypeReference>> DIRECT_FORMAT = new TypeReference>>(){}; + private static final TypeReference ENCODED_FORMAT = new TypeReference(){}; + + public QueryDataClientJacksonModule() + { + super(QueryDataClientJacksonModule.class.getSimpleName(), Version.unknownVersion()); + addDeserializer(QueryData.class, new Deserializer()); + } + + private static class Deserializer + extends StdDeserializer + { + public Deserializer() + { + super(QueryData.class); + } + + @Override + public QueryData deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) + throws IOException + { + // If this is not JSON_ARRAY we are dealing with direct data encoding + if (jsonParser.currentToken().equals(JsonToken.START_ARRAY)) { + return RawQueryData.of(jsonParser.readValueAs(DIRECT_FORMAT)); + } + return jsonParser.readValueAs(ENCODED_FORMAT); + } + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/QueryDataDecoder.java b/client/trino-client/src/main/java/io/trino/client/QueryDataDecoder.java new file mode 100644 index 000000000000..30b1ebb14a8c --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/QueryDataDecoder.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.client; + +import io.trino.client.spooling.DataAttributes; +import jakarta.annotation.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; + +public interface QueryDataDecoder +{ + interface Factory + { + QueryDataDecoder create(List columns, DataAttributes segmentAttributes); + + String encodingId(); + } + + @Nullable Iterable> decode(@Nullable InputStream input, DataAttributes queryAttributes) + throws IOException; + + String encodingId(); +} diff --git a/client/trino-client/src/main/java/io/trino/client/QueryResults.java b/client/trino-client/src/main/java/io/trino/client/QueryResults.java index 741d20b710fd..e7715cd8a3b5 100644 --- a/client/trino-client/src/main/java/io/trino/client/QueryResults.java +++ b/client/trino-client/src/main/java/io/trino/client/QueryResults.java @@ -14,6 +14,8 @@ package io.trino.client; import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; @@ -25,20 +27,18 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.unmodifiableIterable; -import static io.trino.client.FixJsonDataUtils.fixData; import static java.util.Objects.requireNonNull; @Immutable public class QueryResults - implements QueryStatusInfo, QueryData + implements QueryStatusInfo { private final String id; private final URI infoUri; private final URI partialCancelUri; private final URI nextUri; private final List columns; - private final Iterable> data; + private final QueryData data; private final StatementStats stats; private final QueryError error; private final List warnings; @@ -52,25 +52,25 @@ public QueryResults( @JsonProperty("partialCancelUri") URI partialCancelUri, @JsonProperty("nextUri") URI nextUri, @JsonProperty("columns") List columns, - @JsonProperty("data") List> data, + @JsonProperty("data") QueryData data, @JsonProperty("stats") StatementStats stats, @JsonProperty("error") QueryError error, @JsonProperty("warnings") List warnings, @JsonProperty("updateType") String updateType, @JsonProperty("updateCount") Long updateCount) { - this( - id, - infoUri, - partialCancelUri, - nextUri, - columns, - fixData(columns, data), - stats, - error, - firstNonNull(warnings, ImmutableList.of()), - updateType, - updateCount); + this.id = requireNonNull(id, "id is null"); + this.infoUri = requireNonNull(infoUri, "infoUri is null"); + this.partialCancelUri = partialCancelUri; + this.nextUri = nextUri; + this.columns = (columns != null) ? 
ImmutableList.copyOf(columns) : null; + this.data = data; + checkArgument(!hasData(data) || columns != null, "data present without columns"); + this.stats = requireNonNull(stats, "stats is null"); + this.error = error; + this.warnings = ImmutableList.copyOf(firstNonNull(warnings, ImmutableList.of())); + this.updateType = updateType; + this.updateCount = updateCount; } public QueryResults( @@ -86,18 +86,18 @@ public QueryResults( String updateType, Long updateCount) { - this.id = requireNonNull(id, "id is null"); - this.infoUri = requireNonNull(infoUri, "infoUri is null"); - this.partialCancelUri = partialCancelUri; - this.nextUri = nextUri; - this.columns = (columns != null) ? ImmutableList.copyOf(columns) : null; - this.data = (data != null) ? unmodifiableIterable(data) : null; - checkArgument(data == null || columns != null, "data present without columns"); - this.stats = requireNonNull(stats, "stats is null"); - this.error = error; - this.warnings = ImmutableList.copyOf(requireNonNull(warnings, "warnings is null")); - this.updateType = updateType; - this.updateCount = updateCount; + this( + id, + infoUri, + partialCancelUri, + nextUri, + columns, + RawQueryData.of(data), + stats, + error, + firstNonNull(warnings, ImmutableList.of()), + updateType, + updateCount); } @JsonProperty @@ -138,10 +138,15 @@ public List getColumns() return columns; } - @Nullable - @JsonProperty - @Override - public Iterable> getData() + @JsonIgnore + public QueryData getData() + { + return data; + } + + @JsonProperty("data") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public QueryData getRawData() { return data; } @@ -193,11 +198,22 @@ public String toString() .add("partialCancelUri", partialCancelUri) .add("nextUri", nextUri) .add("columns", columns) - .add("hasData", data != null) + .add("hasData", hasData(data)) .add("stats", stats) .add("error", error) .add("updateType", updateType) .add("updateCount", updateCount) .toString(); } + + private static boolean hasData(QueryData data) + { + if (data == null) { + return false; + } + if (data instanceof RawQueryData) { + return data.getData() != null; + } + return true; + } } diff --git a/client/trino-client/src/main/java/io/trino/client/RawQueryData.java b/client/trino-client/src/main/java/io/trino/client/RawQueryData.java new file mode 100644 index 000000000000..412d95ffb3c1 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/RawQueryData.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client; + +import jakarta.annotation.Nullable; + +import java.util.List; + +import static com.google.common.collect.Iterables.unmodifiableIterable; +import static io.trino.client.FixJsonDataUtils.fixData; + +/** + * Class represents QueryData serialized to JSON array of arrays of objects. + * It has custom handling and representation in the QueryDataJacksonModule. 
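+ * An illustrative example (values invented for this comment): a direct-protocol "data"
+ * field such as [["CANADA", 3], ["GERMANY", 7]] is deserialized into a RawQueryData of
+ * plain JSON values; fixTypes(columns) then reinterprets each value according to the
+ * corresponding column's client type via FixJsonDataUtils (e.g. BIGINT values become Long).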
+ */ +public class RawQueryData + implements QueryData +{ + private final Iterable> iterable; + + private RawQueryData(Iterable> values) + { + this.iterable = values == null ? null : unmodifiableIterable(values); + } + + @Override + public Iterable> getData() + { + return iterable; + } + + public static QueryData of(@Nullable Iterable> values) + { + return new RawQueryData(values); + } + + // JSON encoded looses type information. In order for it to be usable, we need to fix types + public QueryData fixTypes(List columns) + { + return RawQueryData.of(fixData(columns, iterable)); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java b/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java index cde74b30fa8c..c5342ff62337 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java @@ -32,4 +32,9 @@ public static StatementClient newStatementClient(OkHttpClient httpClient, Client { return new StatementClientV1((Call.Factory) httpClient, session, query, clientCapabilities); } + + public static StatementClient newStatementClient(OkHttpClient httpClient, QueryDataDecoder.Factory decoderFactory, ClientSession session, String query, Optional> clientCapabilities) + { + return new StatementClientV1((Call.Factory) httpClient, Optional.of(decoderFactory), session, query, clientCapabilities); + } } diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java index 99bb1a7b41cc..d46f564f9071 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java @@ -21,6 +21,9 @@ import com.google.common.collect.Sets; import com.google.errorprone.annotations.ThreadSafe; import io.airlift.units.Duration; +import io.trino.client.spooling.DataAttributes; +import io.trino.client.spooling.EncodedQueryData; +import io.trino.client.spooling.SegmentLoader; import jakarta.annotation.Nullable; import okhttp3.Call; import okhttp3.Headers; @@ -53,6 +56,7 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Throwables.getCausalChain; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.net.HttpHeaders.ACCEPT_ENCODING; import static com.google.common.net.HttpHeaders.USER_AGENT; @@ -71,7 +75,6 @@ class StatementClientV1 implements StatementClient { private static final MediaType MEDIA_TYPE_TEXT = MediaType.parse("text/plain; charset=utf-8"); - private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); private static final Splitter COLLECTION_HEADER_SPLITTER = Splitter.on('=').limit(2).trimResults(); private static final String USER_AGENT_VALUE = StatementClientV1.class.getSimpleName() + @@ -100,10 +103,21 @@ class StatementClientV1 private final Optional originalUser; private final String clientCapabilities; private final boolean compressionDisabled; + private final JsonCodec jsonCodec; private final AtomicReference state = new AtomicReference<>(State.RUNNING); + // Encoded data + private final Optional queryDataDecoderFactory; + private final SegmentLoader segmentDownloader; + private final AtomicReference decoder = 
new AtomicReference<>(); + public StatementClientV1(Call.Factory httpCallFactory, ClientSession session, String query, Optional> clientCapabilities) + { + this(httpCallFactory, Optional.empty(), session, query, clientCapabilities); + } + + public StatementClientV1(Call.Factory httpCallFactory, Optional queryDataDecoder, ClientSession session, String query, Optional> clientCapabilities) { requireNonNull(httpCallFactory, "httpCallFactory is null"); requireNonNull(session, "session is null"); @@ -125,15 +139,17 @@ public StatementClientV1(Call.Factory httpCallFactory, ClientSession session, St .map(Enum::name) .collect(toImmutableSet()))); this.compressionDisabled = session.isCompressionDisabled(); + this.jsonCodec = jsonCodec(QueryResults.class, new QueryDataClientJacksonModule()); + this.queryDataDecoderFactory = requireNonNull(queryDataDecoder, "queryDataDecoder is null"); + this.segmentDownloader = new SegmentLoader(); - Request request = buildQueryRequest(session, query); - + Request request = buildQueryRequest(session, query, queryDataDecoder.map(QueryDataDecoder.Factory::encodingId)); // Pass empty as materializedJsonSizeLimit to always materialize the first response // to avoid losing the response body if the initial response parsing fails executeRequest(request, "starting query", OptionalLong.empty(), this::isTransient); } - private Request buildQueryRequest(ClientSession session, String query) + private Request buildQueryRequest(ClientSession session, String query, Optional requestedEncoding) { HttpUrl url = HttpUrl.get(session.getServer()); if (url == null) { @@ -195,6 +211,8 @@ private Request buildQueryRequest(ClientSession session, String query) builder.addHeader(TRINO_HEADERS.requestClientCapabilities(), clientCapabilities); + requestedEncoding.ifPresent(encoding -> builder.addHeader(TRINO_HEADERS.requestQueryDataEncoding(), encoding)); + return builder.build(); } @@ -250,7 +268,20 @@ public QueryStatusInfo currentStatusInfo() public QueryData currentData() { checkState(isRunning(), "current position is not valid (cursor past end)"); - return currentResults.get(); + QueryResults queryResults = currentResults.get(); + + if (queryResults == null || queryResults.getData() == null) { + return RawQueryData.of(null); + } + + if (queryResults.getData() instanceof RawQueryData) { + // We need to reinterpret JSON values to have correct types + return ((RawQueryData) queryResults.getData()) + .fixTypes(queryResults.getColumns()); + } + + EncodedQueryData queryData = (EncodedQueryData) queryResults.getData(); + return queryData.toRawData(decoder.get(), segmentDownloader); } @Override @@ -399,7 +430,7 @@ private boolean executeRequest(Request request, String taskName, OptionalLong ma JsonResponse response; try { - response = JsonResponse.execute(QUERY_RESULTS_CODEC, httpCallFactory, request, materializedJsonSizeLimit); + response = JsonResponse.execute(jsonCodec, httpCallFactory, request, materializedJsonSizeLimit); } catch (RuntimeException e) { if (!isRetryable.apply(e)) { @@ -485,6 +516,20 @@ private void processResponse(Headers headers, QueryResults results) clearTransactionId.set(true); } + // Make sure that decoder and dataAttributes are set before currentResults + if (results.getData() instanceof EncodedQueryData) { + EncodedQueryData encodedData = (EncodedQueryData) results.getData(); + DataAttributes queryAttributed = encodedData.getMetadata(); + if (decoder.get() == null) { + QueryDataDecoder queryDataDecoder = queryDataDecoderFactory + .orElseThrow(() -> new 
IllegalStateException("Received encoded data format but there is no decoder")) + .create(results.getColumns(), queryAttributed); + decoder.set(queryDataDecoder); + } + + verify(decoder.get().encodingId().equals(encodedData.getEncodingId()), "Decoder has wrong encoding id, expected %s, got %s", encodedData.getEncodingId(), decoder.get().encodingId()); + } + currentResults.set(results); } @@ -532,6 +577,11 @@ public void close() if (uri != null) { httpDelete(uri); } + try { + segmentDownloader.close(); + } + catch (Exception ignored) { + } } } diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/DataAttribute.java b/client/trino-client/src/main/java/io/trino/client/spooling/DataAttribute.java new file mode 100644 index 000000000000..383ef77f341e --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/DataAttribute.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling; + +import static com.google.common.base.Verify.verify; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +public enum DataAttribute +{ + // Offset of the segment in relation to the whole result set + ROW_OFFSET("rowOffset", Long.class), + // Number of rows in the segment + ROWS_COUNT("rowsCount", Long.class), + // Actual size of the segment in bytes (uncompressed or compressed) + BYTE_SIZE("byteSize", Integer.class), + // Size of the segment in bytes after decompression, added only to compressed segments + UNCOMPRESSED_SIZE("uncompressedSize", Integer.class), + // Symmetric encryption key used to decrypt the segment, added only to encrypted segments + ENCRYPTION_KEY("encryptionKey", String.class), + // Encryption cipher name, added only to encrypted segments + ENCRYPTION_CIPHER_NAME("cipherName", String.class); + + private final String name; + private final Class javaClass; + + DataAttribute(String name, Class javaClass) + { + this.name = requireNonNull(name, "name is null"); + this.javaClass = requireNonNull(javaClass, "javaClass is null"); + } + + public String attributeName() + { + return name; + } + + public Class javaClass() + { + return javaClass; + } + + public static DataAttribute getByName(String name) + { + for (DataAttribute attributeName : DataAttribute.values()) { + if (attributeName.attributeName().equals(name)) { + return attributeName; + } + } + throw new IllegalArgumentException("Unknown attribute name: " + name); + } + + public T decode(Class clazz, Object value) + { + verify(clazz == javaClass, "Expected %s, but got %s", javaClass, clazz); + if (clazz == Long.class) { + if (value instanceof Long) { + return clazz.cast(value); + } + if (value instanceof Integer) { + return clazz.cast(Integer.class.cast(value).longValue()); + } + if (value instanceof String) { + return clazz.cast(Long.parseLong(String.class.cast(value))); + } + } + + if (clazz == Integer.class) { + if (value instanceof Long) { + return clazz.cast(toIntExact(Long.class.cast(value))); + } + if (value 
instanceof Integer) { + return clazz.cast(value); + } + if (value instanceof String) { + return clazz.cast(Integer.parseInt(String.class.cast(value))); + } + } + + if (clazz == String.class) { + if (value instanceof String) { + return clazz.cast(value); + } + } + + if (clazz == Boolean.class) { + if (value instanceof Boolean) { + return clazz.cast(value); + } + + if (value instanceof String) { + return clazz.cast(Boolean.parseBoolean(String.class.cast(value))); + } + } + + throw new IllegalArgumentException("Unsupported class: " + clazz); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/DataAttributes.java b/client/trino-client/src/main/java/io/trino/client/spooling/DataAttributes.java new file mode 100644 index 000000000000..2ce3f6bf79cd --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/DataAttributes.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling; + +import com.google.common.collect.ImmutableMap; + +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Verify.verify; +import static java.lang.String.format; + +public final class DataAttributes +{ + final Map attributes; + + DataAttributes(Map attributes) + { + this.attributes = ImmutableMap.copyOf(firstNonNull(attributes, ImmutableMap.of())); + } + + public static DataAttributes empty() + { + return new DataAttributes(ImmutableMap.of()); + } + + public T get(DataAttribute attribute, Class clazz) + { + return getOptional(attribute, clazz) + .orElseThrow(() -> new IllegalArgumentException(format("Required data attribute '%s' does not exist", attribute.name()))); + } + + public Optional getOptional(DataAttribute attribute, Class clazz) + { + return Optional.ofNullable(attributes.get(attribute.attributeName())) + .map(value -> attribute.decode(clazz, value)); + } + + public Map toMap() + { + return attributes; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DataAttributes that = (DataAttributes) o; + return Objects.equals(attributes, that.attributes); + } + + @Override + public int hashCode() + { + return Objects.hashCode(attributes); + } + + public Builder toBuilder() + { + return new Builder(this); + } + + public static Builder builder() + { + return new Builder(); + } + + public static Builder builder(DataAttributes dataAttributes) + { + return new Builder(dataAttributes); + } + + public static class Builder + { + private final ImmutableMap.Builder builder = ImmutableMap.builder(); + + private Builder() {} + + private Builder(DataAttributes attributes) + { + builder.putAll(attributes.attributes); + } + + public Builder set(DataAttribute attribute, T value) + { + verify(attribute.javaClass().isInstance(value), "Invalid value type: %s for attribute: %s", value.getClass(), 
attribute.attributeName()); + builder.put(attribute.attributeName(), value); + return this; + } + + public Builder set(String key, Object value) + { + DataAttribute attribute = DataAttribute.getByName(key); + return set(attribute, attribute.decode(attribute.javaClass(), value)); + } + + public DataAttributes build() + { + return new DataAttributes(builder.buildKeepingLast()); + } + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/EncodedQueryData.java b/client/trino-client/src/main/java/io/trino/client/spooling/EncodedQueryData.java new file mode 100644 index 000000000000..002f7b36abc6 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/EncodedQueryData.java @@ -0,0 +1,164 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import io.trino.client.QueryData; +import io.trino.client.QueryDataDecoder; +import io.trino.client.RawQueryData; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.collect.Iterables.concat; +import static com.google.common.collect.Iterables.transform; +import static com.google.common.collect.Iterables.unmodifiableIterable; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class EncodedQueryData + implements QueryData +{ + private final String encodingId; + private final DataAttributes metadata; + private final List segments; + + @JsonCreator + public EncodedQueryData(@JsonProperty("encodingId") String encodingId, @JsonProperty("metadata") Map metadata, @JsonProperty("segments") List segments) + { + this(encodingId, new DataAttributes(metadata), segments); + } + + public EncodedQueryData(String encodingId, DataAttributes metadata, List segments) + { + this.encodingId = requireNonNull(encodingId, "encodingId is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.segments = ImmutableList.copyOf(requireNonNull(segments, "segments is null")); + } + + @JsonProperty("segments") + public List getSegments() + { + return segments; + } + + @JsonProperty("encodingId") + public String getEncodingId() + { + return encodingId; + } + + @JsonInclude(JsonInclude.Include.NON_EMPTY) + @JsonProperty("metadata") + public Map getJsonMetadata() + { + return metadata.attributes; + } + + @JsonIgnore + public DataAttributes getMetadata() + { + return metadata; + } + + @Override + public Iterable> getData() + { + throw new UnsupportedOperationException("EncodedQueryData required decoding via matching QueryDataDecoder"); + } + + 
public QueryData toRawData(QueryDataDecoder decoder, SegmentLoader segmentLoader) + { + if (!decoder.encodingId().equals(encodingId)) { + throw new IllegalArgumentException(format("Invalid decoder supplied, expected %s, got %s", encodingId, decoder.encodingId())); + } + + return RawQueryData.of(unmodifiableIterable(concat(transform(segments, segment -> { + if (segment instanceof InlineSegment) { + InlineSegment inline = (InlineSegment) segment; + try { + return decoder.decode(new ByteArrayInputStream(inline.getData()), inline.getMetadata()); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + if (segment instanceof SpooledSegment) { + SpooledSegment spooled = (SpooledSegment) segment; + try (InputStream stream = segmentLoader.load(spooled)) { + return decoder.decode(stream, segment.getMetadata()); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + throw new IllegalArgumentException("Unexpected segment type: " + segment.getClass().getSimpleName()); + })))); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("encodingId", encodingId) + .add("segments", segments) + .add("metadata", metadata.attributes.keySet()) + .toString(); + } + + public static Builder builder(String format) + { + return new Builder(format); + } + + public static class Builder + { + private final String encodingId; + private final ImmutableList.Builder segments = ImmutableList.builder(); + private DataAttributes metadata = DataAttributes.empty(); + + private Builder(String encodingId) + { + this.encodingId = requireNonNull(encodingId, "encodingId is null"); + } + + public Builder withSegment(Segment segment) + { + this.segments.add(segment); + return this; + } + + public Builder withAttributes(DataAttributes attributes) + { + this.metadata = requireNonNull(attributes, "attributes is null"); + return this; + } + + public EncodedQueryData build() + { + return new EncodedQueryData(encodingId, metadata, segments.build()); + } + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/InlineSegment.java b/client/trino-client/src/main/java/io/trino/client/spooling/InlineSegment.java new file mode 100644 index 000000000000..0997953ebbd8 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/InlineSegment.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.client.spooling; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Map; + +import static java.lang.String.format; + +public class InlineSegment + extends Segment +{ + private final byte[] data; + + @JsonCreator + public InlineSegment(@JsonProperty("data") byte[] data, @JsonProperty("metadata") Map metadata) + { + this(data, new DataAttributes(metadata)); + } + + InlineSegment(byte[] data, DataAttributes metadata) + { + super(metadata); + this.data = data; + } + + @JsonProperty("data") + public byte[] getData() + { + return data; + } + + @Override + public String toString() + { + return format("InlineSegment{offset=%d, rows=%d, size=%d}", getOffset(), getRowsCount(), getDataSizeBytes()); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/Segment.java b/client/trino-client/src/main/java/io/trino/client/spooling/Segment.java new file mode 100644 index 000000000000..5b0fbf4d8840 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/Segment.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +import java.net.URI; +import java.util.Map; +import java.util.Optional; + +import static io.trino.client.spooling.DataAttribute.BYTE_SIZE; +import static io.trino.client.spooling.DataAttribute.ROWS_COUNT; +import static io.trino.client.spooling.DataAttribute.ROW_OFFSET; +import static java.util.Objects.requireNonNull; + +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = InlineSegment.class, name = "inline"), + @JsonSubTypes.Type(value = SpooledSegment.class, name = "spooled")}) +public abstract class Segment +{ + private final DataAttributes metadata; + + public Segment(DataAttributes metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @JsonProperty("metadata") + public Map getJsonMetadata() + { + return metadata.attributes; + } + + @JsonIgnore + public DataAttributes getMetadata() + { + return metadata; + } + + @JsonIgnore + public long getOffset() + { + return getRequiredAttribute(ROW_OFFSET, Long.class); + } + + @JsonIgnore + public long getRowsCount() + { + return getRequiredAttribute(ROWS_COUNT, Long.class); + } + + @JsonIgnore + public int getDataSizeBytes() + { + return getRequiredAttribute(BYTE_SIZE, Integer.class); + } + + public Optional getAttribute(DataAttribute name, Class clazz) + { + return Optional.ofNullable(metadata.get(name, clazz)); + } + + public T getRequiredAttribute(DataAttribute name, Class clazz) + { + return getAttribute(name, clazz) + .orElseThrow(() -> new IllegalArgumentException("Missing required attribute: " + name.attributeName())); + } + + public 
static Segment inlined(byte[] data, DataAttributes attributes) + { + return new InlineSegment(data, attributes); + } + + public static Segment spooled(URI segmentUri, DataAttributes attributes) + { + return new SpooledSegment(segmentUri, attributes); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/SegmentLoader.java b/client/trino-client/src/main/java/io/trino/client/spooling/SegmentLoader.java new file mode 100644 index 000000000000..20e375978d62 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/SegmentLoader.java @@ -0,0 +1,162 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling; + +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static java.util.Objects.requireNonNull; + +public class SegmentLoader + implements AutoCloseable +{ + private static final Logger logger = Logger.getLogger(SegmentLoader.class.getPackage().getName()); + private final OkHttpClient client; + + public SegmentLoader() + { + this.client = new OkHttpClient(); + } + + public InputStream load(SpooledSegment segment) + throws IOException + { + return loadFromURI(segment.getDataUri()); + } + + public InputStream loadFromURI(URI segmentUri) + throws IOException + { + Request request = new Request.Builder() + .url(segmentUri.toString()) + .build(); + + Response response = client.newCall(request).execute(); + ResponseBody body = response.body(); + + if (response.isSuccessful()) { + return delegatingInputStream(response, requireNonNull(body, "response body is null").source().inputStream(), segmentUri); + } + + throw new IOException("Could not open segment for streaming " + response.code() + " " + response.message()); + } + + private void delete(URI segmentUri) + { + Request deleteRequest = new Request.Builder() + .delete() + .url(segmentUri.toString()) + .build(); + + client.newCall(deleteRequest).enqueue(new Callback() + { + @Override + public void onFailure(Call call, IOException cause) + { + logger.log(Level.WARNING, "Could not acknowledge spooled segment", cause); + } + + @Override + public void onResponse(Call call, Response response) + { + } + }); + } + + private InputStream delegatingInputStream(Response response, InputStream delegate, URI segmentUri) + { + return new InputStream() + { + @Override + public int read(byte[] b) + throws IOException + { + return delegate.read(b); + } + + @Override + public int read(byte[] b, int off, int len) + throws IOException + { + return delegate.read(b, off, len); + } + + @Override + public long skip(long n) + throws IOException + { + return delegate.skip(n); + } + + @Override + public int available() + throws IOException + { + return delegate.available(); + } + + @Override + public void close() + throws 
IOException + { + response.close(); + delegate.close(); + delete(segmentUri); + } + + @Override + public void mark(int readlimit) + { + delegate.mark(readlimit); + } + + @Override + public void reset() + throws IOException + { + delegate.reset(); + } + + @Override + public boolean markSupported() + { + return delegate.markSupported(); + } + + @Override + public int read() + throws IOException + { + return delegate.read(); + } + }; + } + + @Override + public void close() + throws Exception + { + client.dispatcher().executorService().shutdown(); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/SpooledSegment.java b/client/trino-client/src/main/java/io/trino/client/spooling/SpooledSegment.java new file mode 100644 index 000000000000..06ef9ee04231 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/SpooledSegment.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.net.URI; +import java.util.Map; + +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class SpooledSegment + extends Segment +{ + private final URI dataUri; + + @JsonCreator + public SpooledSegment(@JsonProperty("dataUri") URI dataUri, @JsonProperty("metadata") Map metadata) + { + this(dataUri, new DataAttributes(metadata)); + } + + SpooledSegment(URI dataUri, DataAttributes metadata) + { + super(metadata); + this.dataUri = requireNonNull(dataUri, "dataUri is null"); + } + + @JsonProperty("dataUri") + public URI getDataUri() + { + return dataUri; + } + + @Override + public String toString() + { + return format("SpooledSegment{offset=%d, rows=%d, size=%d}", getOffset(), getRowsCount(), getDataSizeBytes()); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/encoding/CipherUtils.java b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/CipherUtils.java new file mode 100644 index 000000000000..ad4441c52277 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/CipherUtils.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
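The loader above streams a spooled segment over HTTP and acknowledges it with an asynchronous DELETE once the returned stream is closed. A minimal client-side sketch (not part of the patch), assuming `segment` is a SpooledSegment taken from an already decoded response and `decoder` is a QueryDataDecoder matching the negotiated encoding; checked exceptions and error handling are omitted:

    try (SegmentLoader loader = new SegmentLoader();
            InputStream stream = loader.load(segment)) {
        // ROWS_COUNT and BYTE_SIZE are surfaced through the segment metadata attributes
        long rows = segment.getRowsCount();
        int sizeInBytes = segment.getDataSizeBytes();
        Iterable<List<Object>> values = decoder.decode(stream, segment.getMetadata());
        // iterate values; closing the stream triggers the DELETE acknowledgement
    }
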
+ */ +package io.trino.client.spooling.encoding; + +import javax.crypto.spec.SecretKeySpec; + +import static java.util.Base64.getMimeDecoder; +import static java.util.Base64.getMimeEncoder; + +public class CipherUtils +{ + private CipherUtils() {} + + public static String serializeSecretKey(SecretKeySpec key) + { + return getMimeEncoder().encodeToString(key.getEncoded()); + } + + public static SecretKeySpec deserializeSecretKey(String serializedKey, String cipherName) + { + return new SecretKeySpec(getMimeDecoder().decode(serializedKey), cipherName); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/encoding/CompressedQueryDataDecoder.java b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/CompressedQueryDataDecoder.java new file mode 100644 index 000000000000..8522c7257d53 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/CompressedQueryDataDecoder.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling.encoding; + +import io.trino.client.QueryDataDecoder; +import io.trino.client.spooling.DataAttribute; +import io.trino.client.spooling.DataAttributes; +import jakarta.annotation.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public abstract class CompressedQueryDataDecoder + implements QueryDataDecoder +{ + protected final QueryDataDecoder delegate; + + public CompressedQueryDataDecoder(QueryDataDecoder delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + } + + abstract InputStream decompress(InputStream inputStream, int uncompressedSize) + throws IOException; + + @Override + public @Nullable Iterable> decode(@Nullable InputStream stream, DataAttributes metadata) + throws IOException + { + if (stream == null) { + return null; + } + + Optional uncompressedSize = metadata.getOptional(DataAttribute.UNCOMPRESSED_SIZE, Integer.class); + if (uncompressedSize.isPresent()) { + return delegate.decode(decompress(stream, uncompressedSize.get()), metadata); + } + // Data not compressed - below threshold + return delegate.decode(stream, metadata); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/encoding/DecryptingQueryDataDecoder.java b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/DecryptingQueryDataDecoder.java new file mode 100644 index 000000000000..8c8967b30848 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/DecryptingQueryDataDecoder.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
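CipherUtils is the shared helper for moving the query-level AES key through the protocol as plain text. A small round-trip sketch (illustrative, not part of the patch), assuming a 256-bit key freshly generated with the JDK KeyGenerator; checked exceptions are omitted:

    KeyGenerator generator = KeyGenerator.getInstance("AES");
    generator.init(256);
    SecretKeySpec key = new SecretKeySpec(generator.generateKey().getEncoded(), "AES");

    String wireValue = CipherUtils.serializeSecretKey(key);                   // MIME Base64 text
    SecretKeySpec restored = CipherUtils.deserializeSecretKey(wireValue, "AES");
    // restored.equals(key): both carry the same encoded key bytes and algorithm name
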
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling.encoding; + +import io.trino.client.Column; +import io.trino.client.QueryDataDecoder; +import io.trino.client.spooling.DataAttributes; +import jakarta.annotation.Nullable; + +import javax.crypto.Cipher; +import javax.crypto.CipherInputStream; +import javax.crypto.spec.SecretKeySpec; + +import java.io.IOException; +import java.io.InputStream; +import java.util.List; +import java.util.Optional; + +import static io.trino.client.spooling.DataAttribute.ENCRYPTION_CIPHER_NAME; +import static io.trino.client.spooling.DataAttribute.ENCRYPTION_KEY; +import static io.trino.client.spooling.encoding.CipherUtils.deserializeSecretKey; +import static java.util.Objects.requireNonNull; +import static javax.crypto.Cipher.DECRYPT_MODE; + +public class DecryptingQueryDataDecoder + implements QueryDataDecoder +{ + private final QueryDataDecoder delegate; + private final SecretKeySpec encryptionKey; + + public DecryptingQueryDataDecoder(QueryDataDecoder delegate, SecretKeySpec encryptionKey) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.encryptionKey = requireNonNull(encryptionKey, "encryptionKey is null"); + } + + @Override + public @Nullable Iterable> decode(@Nullable InputStream input, DataAttributes attributes) + throws IOException + { + try (CipherInputStream encryptedInput = new CipherInputStream(input, createCipher(encryptionKey))) { + return delegate.decode(encryptedInput, attributes); + } + } + + @Override + public String encodingId() + { + return delegate.encodingId(); + } + + private static Cipher createCipher(SecretKeySpec key) + { + try { + Cipher cipher = Cipher.getInstance("AES"); + cipher.init(DECRYPT_MODE, key); + return cipher; + } + catch (Exception e) { + throw new RuntimeException("Failed to initialize cipher", e); + } + } + + public static class Factory + implements QueryDataDecoder.Factory + { + private final QueryDataDecoder.Factory delegateFactory; + + public Factory(QueryDataDecoder.Factory delegateFactory) + { + this.delegateFactory = requireNonNull(delegateFactory, "delegateFactory is null"); + } + + @Override + public QueryDataDecoder create(List columns, DataAttributes queryAttributes) + { + QueryDataDecoder delegate = delegateFactory.create(columns, queryAttributes); + + Optional encryptionKey = queryAttributes.getOptional(ENCRYPTION_KEY, String.class) + .map(key -> deserializeSecretKey(key, queryAttributes.get(ENCRYPTION_CIPHER_NAME, String.class))); + + if (encryptionKey.isPresent()) { + return new DecryptingQueryDataDecoder(delegate, encryptionKey.get()); + } + + return delegate; + } + + @Override + public String encodingId() + { + return delegateFactory.encodingId(); + } + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/encoding/JsonQueryDataDecoder.java b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/JsonQueryDataDecoder.java new file mode 100644 index 000000000000..c0cf6c8687b4 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/JsonQueryDataDecoder.java @@ -0,0 +1,129 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling.encoding; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.trino.client.Column; +import io.trino.client.QueryDataDecoder; +import io.trino.client.spooling.DataAttributes; +import jakarta.annotation.Nullable; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.util.List; + +import static io.trino.client.FixJsonDataUtils.fixData; +import static java.util.Objects.requireNonNull; + +public class JsonQueryDataDecoder + implements QueryDataDecoder +{ + public static final String ENCODING_ID = "json-ext"; + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final TypeReference>> TYPE = new TypeReference>>() {}; + private final List columns; + + public JsonQueryDataDecoder(List columns) + { + this.columns = requireNonNull(columns, "columns is null"); + } + + @Override + public @Nullable Iterable> decode(InputStream stream, DataAttributes attributes) + { + if (stream == null) { + return null; + } + + try { + return fixData(columns, OBJECT_MAPPER.readValue(stream, TYPE)); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public String encodingId() + { + return ENCODING_ID; + } + + public static class Factory + implements QueryDataDecoder.Factory + { + @Override + public QueryDataDecoder create(List columns, DataAttributes queryAttributes) + { + return new JsonQueryDataDecoder(columns); + } + + @Override + public String encodingId() + { + return ENCODING_ID; + } + } + + public static class ZstdFactory + extends Factory + { + @Override + public QueryDataDecoder create(List columns, DataAttributes queryAttributes) + { + return new ZstdQueryDataDecoder(super.create(columns, queryAttributes)); + } + + @Override + public String encodingId() + { + return super.encodingId() + "+zstd"; + } + } + + public static class SnappyFactory + extends Factory + { + @Override + public QueryDataDecoder create(List columns, DataAttributes queryAttributes) + { + return new SnappyQueryDataDecoder(super.create(columns, queryAttributes)); + } + + @Override + public String encodingId() + { + return super.encodingId() + "+snappy"; + } + } + + public static class Lz4Factory + extends Factory + { + @Override + public QueryDataDecoder create(List columns, DataAttributes queryAttributes) + { + return new Lz4QueryDataDecoder(super.create(columns, queryAttributes)); + } + + @Override + public String encodingId() + { + return super.encodingId() + "+lz4"; + } + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/encoding/Lz4QueryDataDecoder.java b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/Lz4QueryDataDecoder.java new file mode 100644 index 000000000000..3eec73667d62 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/Lz4QueryDataDecoder.java @@ -0,0 +1,50 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.client.spooling.encoding; + +import com.google.common.io.ByteStreams; +import io.airlift.compress.lz4.Lz4Decompressor; +import io.trino.client.QueryDataDecoder; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +import static com.google.common.base.Verify.verify; + +public class Lz4QueryDataDecoder + extends CompressedQueryDataDecoder +{ + public Lz4QueryDataDecoder(QueryDataDecoder delegate) + { + super(delegate); + } + + @Override + InputStream decompress(InputStream stream, int uncompressedSize) + throws IOException + { + Lz4Decompressor decompressor = new Lz4Decompressor(); + byte[] bytes = ByteStreams.toByteArray(stream); + byte[] output = new byte[uncompressedSize]; + verify(decompressor.decompress(bytes, 0, bytes.length, output, 0, output.length) == uncompressedSize, "Decompressed size does not match expected size"); + return new ByteArrayInputStream(output); + } + + @Override + public String encodingId() + { + return delegate.encodingId() + "+lz4"; + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/encoding/QueryDataDecoders.java b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/QueryDataDecoders.java new file mode 100644 index 000000000000..fa58bea95547 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/QueryDataDecoders.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
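Each compression decoder only wraps another decoder and suffixes its encoding id, so the id spells out the whole decode chain. A short sketch (illustrative), assuming `columns` is the List<Column> describing the result schema:

    QueryDataDecoder json = new JsonQueryDataDecoder(columns);
    QueryDataDecoder lz4 = new Lz4QueryDataDecoder(json);

    // json.encodingId()                           -> "json-ext"
    // lz4.encodingId()                            -> "json-ext+lz4"
    // new ZstdQueryDataDecoder(json).encodingId() -> "json-ext+zstd"
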
+ */ +package io.trino.client.spooling.encoding; + +import io.trino.client.QueryDataDecoder; + +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Verify.verify; + +public class QueryDataDecoders +{ + private static Map factories = new HashMap<>(); + + private static void register(QueryDataDecoder.Factory factory) + { + if (factories.containsKey(factory.encodingId())) { + throw new IllegalStateException("Encoding " + factory.encodingId() + " already registered."); + } + factories.put(factory.encodingId(), factory); + } + + static + { + register(new JsonQueryDataDecoder.Factory()); + register(new JsonQueryDataDecoder.ZstdFactory()); + register(new JsonQueryDataDecoder.SnappyFactory()); + register(new JsonQueryDataDecoder.Lz4Factory()); + } + + private QueryDataDecoders() {} + + public static QueryDataDecoder.Factory get(String encodingId) + { + if (!factories.containsKey(encodingId)) { + throw new IllegalArgumentException("Unknown encoding id: " + encodingId); + } + + QueryDataDecoder.Factory factory = factories.get(encodingId); + verify(factory.encodingId().equals(encodingId), "Factory has wrong encoding id, expected %s, got %s", encodingId, factory.encodingId()); + return new DecryptingQueryDataDecoder.Factory(factory); + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/encoding/SnappyQueryDataDecoder.java b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/SnappyQueryDataDecoder.java new file mode 100644 index 000000000000..5a16aeef8ec0 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/SnappyQueryDataDecoder.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
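QueryDataDecoders is the client-side registry keyed by encoding id; get() additionally wraps the chosen factory so that decryption is applied transparently when the query attributes carry ENCRYPTION_KEY. A usage sketch, assuming `columns`, `queryAttributes` (query-level DataAttributes), `segmentStream` and `segmentAttributes` are provided by the surrounding protocol code:

    QueryDataDecoder.Factory factory = QueryDataDecoders.get("json-ext+zstd");
    QueryDataDecoder decoder = factory.create(columns, queryAttributes);
    Iterable<List<Object>> rows = decoder.decode(segmentStream, segmentAttributes);  // may throw IOException
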
+ */ +package io.trino.client.spooling.encoding; + +import com.google.common.io.ByteStreams; +import io.airlift.compress.snappy.SnappyDecompressor; +import io.trino.client.QueryDataDecoder; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; + +import static com.google.common.base.Verify.verify; + +public class SnappyQueryDataDecoder + extends CompressedQueryDataDecoder +{ + public SnappyQueryDataDecoder(QueryDataDecoder delegate) + { + super(delegate); + } + + @Override + InputStream decompress(InputStream stream, int uncompressedSize) + throws IOException + { + SnappyDecompressor decompressor = new SnappyDecompressor(); + byte[] bytes = ByteStreams.toByteArray(stream); + byte[] output = new byte[uncompressedSize]; + verify(decompressor.decompress(bytes, 0, bytes.length, output, 0, output.length) == uncompressedSize, "Decompressed size does not match expected size"); + return new ByteArrayInputStream(output); + } + + @Override + public String encodingId() + { + return delegate.encodingId() + "+snappy"; + } +} diff --git a/client/trino-client/src/main/java/io/trino/client/spooling/encoding/ZstdQueryDataDecoder.java b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/ZstdQueryDataDecoder.java new file mode 100644 index 000000000000..922251f62387 --- /dev/null +++ b/client/trino-client/src/main/java/io/trino/client/spooling/encoding/ZstdQueryDataDecoder.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
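For an encrypted query using the "json-ext+snappy" encoding, the factories effectively build a decrypt-then-decompress-then-parse chain. A sketch of the equivalent manual composition (illustrative), assuming `key` is the SecretKeySpec recovered with CipherUtils.deserializeSecretKey and `columns` is the result schema:

    QueryDataDecoder decoder = new DecryptingQueryDataDecoder(
            new SnappyQueryDataDecoder(new JsonQueryDataDecoder(columns)),
            key);
    // decoder.encodingId() is still "json-ext+snappy"; the encryption wrapper does not change the id
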
+ */ +package io.trino.client.spooling.encoding; + +import io.airlift.compress.zstd.ZstdInputStream; +import io.trino.client.QueryDataDecoder; + +import java.io.InputStream; + +public class ZstdQueryDataDecoder + extends CompressedQueryDataDecoder +{ + public ZstdQueryDataDecoder(QueryDataDecoder delegate) + { + super(delegate); + } + + @Override + InputStream decompress(InputStream inputStream, int uncompressedSize) + { + return new ZstdInputStream(inputStream); + } + + @Override + public String encodingId() + { + return delegate.encodingId() + "+zstd"; + } +} diff --git a/client/trino-client/src/test/java/io/trino/client/TestQueryResults.java b/client/trino-client/src/test/java/io/trino/client/TestQueryResults.java index 3dd478efcf4b..47a851ebe927 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestQueryResults.java +++ b/client/trino-client/src/test/java/io/trino/client/TestQueryResults.java @@ -24,7 +24,7 @@ public class TestQueryResults { - private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); + private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class, new QueryDataClientJacksonModule()); private static final String GOLDEN_VALUE = "{\n" + " \"id\" : \"20160128_214710_00012_rk68b\",\n" + diff --git a/client/trino-client/src/test/java/io/trino/client/TestRetry.java b/client/trino-client/src/test/java/io/trino/client/TestRetry.java index 40fde2e890e0..6ff2bd320382 100644 --- a/client/trino-client/src/test/java/io/trino/client/TestRetry.java +++ b/client/trino-client/src/test/java/io/trino/client/TestRetry.java @@ -14,7 +14,6 @@ package io.trino.client; import com.google.common.collect.ImmutableList; -import io.airlift.json.JsonCodec; import io.airlift.units.Duration; import okhttp3.OkHttpClient; import okhttp3.mockwebserver.MockResponse; @@ -35,7 +34,7 @@ import static com.google.common.net.HttpHeaders.CONTENT_TYPE; import static com.google.common.net.MediaType.JSON_UTF_8; -import static io.airlift.json.JsonCodec.jsonCodec; +import static io.trino.client.JsonCodec.jsonCodec; import static io.trino.client.StatementClientFactory.newStatementClient; import static io.trino.spi.type.StandardTypes.INTEGER; import static io.trino.spi.type.StandardTypes.VARCHAR; @@ -49,7 +48,7 @@ public class TestRetry { private MockWebServer server; - private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); + private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class, new QueryDataClientJacksonModule()); @BeforeEach public void setup() @@ -141,9 +140,9 @@ private String newQueryResults(String state) Stream.of(new Column("id", INTEGER, new ClientTypeSignature("integer")), new Column("name", VARCHAR, new ClientTypeSignature("varchar"))) .collect(toList()), - IntStream.range(0, numRecords) + RawQueryData.of(IntStream.range(0, numRecords) .mapToObj(index -> Stream.of((Object) index, "a").collect(toList())) - .collect(toList()), + .collect(toList())), new StatementStats(state, state.equals("QUEUED"), true, OptionalDouble.of(0), OptionalDouble.of(0), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, null), null, ImmutableList.of(), diff --git a/client/trino-jdbc/pom.xml b/client/trino-jdbc/pom.xml index ce409c43ac86..b27b5e88cda2 100644 --- a/client/trino-jdbc/pom.xml +++ b/client/trino-jdbc/pom.xml @@ -91,6 +91,12 @@ provided + + com.fasterxml.jackson.core + jackson-databind + runtime + + com.google.inject guice diff --git a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestProgressMonitor.java 
b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestProgressMonitor.java index 056728cb2dbe..588fe8f8fc56 100644 --- a/client/trino-jdbc/src/test/java/io/trino/jdbc/TestProgressMonitor.java +++ b/client/trino-jdbc/src/test/java/io/trino/jdbc/TestProgressMonitor.java @@ -14,11 +14,13 @@ package io.trino.jdbc; import com.google.common.collect.ImmutableList; -import io.airlift.json.JsonCodec; import io.trino.client.ClientTypeSignature; import io.trino.client.Column; +import io.trino.client.JsonCodec; import io.trino.client.QueryResults; +import io.trino.client.RawQueryData; import io.trino.client.StatementStats; +import io.trino.server.protocol.spooling.QueryDataJacksonModule; import io.trino.spi.type.StandardTypes; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -41,8 +43,8 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.net.HttpHeaders.CONTENT_TYPE; -import static io.airlift.json.JsonCodec.jsonCodec; import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; +import static io.trino.client.JsonCodec.jsonCodec; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD; @@ -52,7 +54,7 @@ @Execution(SAME_THREAD) public class TestProgressMonitor { - private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); + private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class, new QueryDataJacksonModule()); private MockWebServer server; @@ -87,14 +89,13 @@ private List createResults() private String newQueryResults(Integer partialCancelId, Integer nextUriId, List responseColumns, List> data, String state) { String queryId = "20160128_214710_00012_rk68b"; - QueryResults queryResults = new QueryResults( queryId, server.url("/query.html?" + queryId).uri(), partialCancelId == null ? null : server.url(format("/v1/statement/partialCancel/%s.%s", queryId, partialCancelId)).uri(), nextUriId == null ? 
null : server.url(format("/v1/statement/%s/%s", queryId, nextUriId)).uri(), responseColumns, - data, + RawQueryData.of(data), new StatementStats(state, state.equals("QUEUED"), true, OptionalDouble.of(0), OptionalDouble.of(0), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, null), null, ImmutableList.of(), diff --git a/core/trino-main/pom.xml b/core/trino-main/pom.xml index c11bc3de9581..d880daf08d66 100644 --- a/core/trino-main/pom.xml +++ b/core/trino-main/pom.xml @@ -18,6 +18,7 @@ + com.clearspring.analytics stream @@ -398,6 +399,12 @@ runtime + + com.squareup.okhttp3 + okhttp-urlconnection + runtime + + net.java.dev.jna jna-platform @@ -410,12 +417,6 @@ test - - com.squareup.okhttp3 - okhttp-urlconnection - test - - io.airlift jaxrs-testing diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index a40afafc0db1..e9d1246f8803 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -91,6 +91,8 @@ public final class Session private final Map preparedStatements; private final ProtocolHeaders protocolHeaders; private final Optional exchangeEncryptionKey; + private final Optional queryDataEncodingId; + private final Optional queryDataEncryptionKey; public Session( QueryId queryId, @@ -118,7 +120,9 @@ public Session( SessionPropertyManager sessionPropertyManager, Map preparedStatements, ProtocolHeaders protocolHeaders, - Optional exchangeEncryptionKey) + Optional exchangeEncryptionKey, + Optional queryDataEncodingId, + Optional queryDataEncryptionKey) { this.queryId = requireNonNull(queryId, "queryId is null"); this.querySpan = requireNonNull(querySpan, "querySpan is null"); @@ -145,6 +149,8 @@ public Session( this.preparedStatements = requireNonNull(preparedStatements, "preparedStatements is null"); this.protocolHeaders = requireNonNull(protocolHeaders, "protocolHeaders is null"); this.exchangeEncryptionKey = requireNonNull(exchangeEncryptionKey, "exchangeEncryptionKey is null"); + this.queryDataEncodingId = requireNonNull(queryDataEncodingId, "queryDataEncodingId is null"); + this.queryDataEncryptionKey = requireNonNull(queryDataEncryptionKey, "queryDataEncryptionKey is null"); requireNonNull(catalogProperties, "catalogProperties is null"); ImmutableMap.Builder> catalogPropertiesBuilder = ImmutableMap.builder(); @@ -384,7 +390,9 @@ public Session beginTransactionId(TransactionId transactionId, TransactionManage sessionPropertyManager, preparedStatements, protocolHeaders, - exchangeEncryptionKey); + exchangeEncryptionKey, + queryDataEncodingId, + queryDataEncryptionKey); } public Session withDefaultProperties(Map systemPropertyDefaults, Map> catalogPropertyDefaults, AccessControl accessControl) @@ -433,7 +441,9 @@ public Session withDefaultProperties(Map systemPropertyDefaults, sessionPropertyManager, preparedStatements, protocolHeaders, - exchangeEncryptionKey); + exchangeEncryptionKey, + queryDataEncodingId, + queryDataEncryptionKey); } public Session withExchangeEncryption(Slice encryptionKey) @@ -465,7 +475,9 @@ public Session withExchangeEncryption(Slice encryptionKey) sessionPropertyManager, preparedStatements, protocolHeaders, - Optional.of(encryptionKey)); + Optional.of(encryptionKey), + queryDataEncodingId, + queryDataEncryptionKey); } public ConnectorSession toConnectorSession() @@ -518,7 +530,9 @@ public SessionRepresentation toSessionRepresentation() catalogProperties, identity.getCatalogRoles(), preparedStatements, - 
protocolHeaders.getProtocolName()); + protocolHeaders.getProtocolName(), + queryDataEncodingId, + queryDataEncryptionKey); } @Override @@ -628,6 +642,16 @@ public SecurityContext toSecurityContext() return new SecurityContext(getRequiredTransactionId(), getIdentity(), queryId, start); } + public Optional getQueryDataEncodingId() + { + return queryDataEncodingId; + } + + public Optional getQueryDataEncryptionKey() + { + return queryDataEncryptionKey; + } + public static class SessionBuilder { private QueryId queryId; @@ -648,6 +672,8 @@ public static class SessionBuilder private String clientInfo; private Set clientTags = ImmutableSet.of(); private Set clientCapabilities = ImmutableSet.of(); + private Optional queryDataEncoding = Optional.empty(); + private Optional queryDataEncryptionKey = Optional.empty(); private ResourceEstimates resourceEstimates; private Instant start = Instant.now(); private final Map systemProperties = new HashMap<>(); @@ -682,6 +708,8 @@ private SessionBuilder(Session session) this.userAgent = session.userAgent.orElse(null); this.clientInfo = session.clientInfo.orElse(null); this.clientCapabilities = ImmutableSet.copyOf(session.clientCapabilities); + this.queryDataEncoding = session.queryDataEncodingId; + this.queryDataEncryptionKey = session.queryDataEncryptionKey; this.clientTags = ImmutableSet.copyOf(session.clientTags); this.start = session.start; this.systemProperties.putAll(session.systemProperties); @@ -930,6 +958,18 @@ public SessionBuilder setProtocolHeaders(ProtocolHeaders protocolHeaders) return this; } + public SessionBuilder setQueryDataEncoding(Optional value) + { + this.queryDataEncoding = value; + return this; + } + + public SessionBuilder setQueryDataEncryptionKey(Slice value) + { + this.queryDataEncryptionKey = Optional.ofNullable(value); + return this; + } + public Session build() { return new Session( @@ -958,7 +998,9 @@ public Session build() sessionPropertyManager, preparedStatements, protocolHeaders, - Optional.empty()); + Optional.empty(), + queryDataEncoding, + queryDataEncryptionKey); } } diff --git a/core/trino-main/src/main/java/io/trino/SessionRepresentation.java b/core/trino-main/src/main/java/io/trino/SessionRepresentation.java index 649b3cdbb3e8..eaec96d6ad57 100644 --- a/core/trino-main/src/main/java/io/trino/SessionRepresentation.java +++ b/core/trino-main/src/main/java/io/trino/SessionRepresentation.java @@ -71,6 +71,8 @@ public final class SessionRepresentation private final Map catalogRoles; private final Map preparedStatements; private final String protocolName; + private final Optional queryDataEncoding; + private final Optional queryDataEncryptionKey; @JsonCreator public SessionRepresentation( @@ -102,7 +104,9 @@ public SessionRepresentation( @JsonProperty("catalogProperties") Map> catalogProperties, @JsonProperty("catalogRoles") Map catalogRoles, @JsonProperty("preparedStatements") Map preparedStatements, - @JsonProperty("protocolName") String protocolName) + @JsonProperty("protocolName") String protocolName, + @JsonProperty("queryDataEncoding") Optional queryDataEncoding, + @JsonProperty("queryDataEncryptionKey") Optional queryDataEncryptionKey) { this.queryId = requireNonNull(queryId, "queryId is null"); this.querySpan = requireNonNull(querySpan, "querySpan is null"); @@ -132,6 +136,8 @@ public SessionRepresentation( this.catalogRoles = ImmutableMap.copyOf(catalogRoles); this.preparedStatements = ImmutableMap.copyOf(preparedStatements); this.protocolName = requireNonNull(protocolName, "protocolName is null"); + 
this.queryDataEncoding = requireNonNull(queryDataEncoding, "queryDataEncoding is null"); + this.queryDataEncryptionKey = requireNonNull(queryDataEncryptionKey, "queryDataEncryptionKey is null"); ImmutableMap.Builder> catalogPropertiesBuilder = ImmutableMap.builder(); for (Entry> entry : catalogProperties.entrySet()) { @@ -320,6 +326,18 @@ public String getTimeZone() return timeZoneKey.getId(); } + @JsonProperty + public Optional getQueryDataEncoding() + { + return queryDataEncoding; + } + + @JsonProperty + public Optional getQueryDataEncryptionKey() + { + return queryDataEncryptionKey; + } + public Identity toIdentity() { return toIdentity(emptyMap()); @@ -378,6 +396,8 @@ public Session toSession(SessionPropertyManager sessionPropertyManager, Map outputSymbols) requireNonNull(outputSymbols, "outputSymbols is null"); return outputSymbols.stream() + .filter(symbol -> !symbol.type().equals(SPOOLING_METADATA_TYPE)) // We don't extra metadata block stats to be accounted for .mapToDouble(symbol -> getOutputSizeForSymbol(getSymbolStatistics(symbol), symbol.type())) .sum(); } diff --git a/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java b/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java index 5e17f913eb5a..5b2e20bcd438 100644 --- a/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/dispatcher/QueuedStatementResource.java @@ -27,6 +27,7 @@ import io.opentelemetry.api.trace.Tracer; import io.trino.client.QueryError; import io.trino.client.QueryResults; +import io.trino.client.RawQueryData; import io.trino.client.StatementStats; import io.trino.execution.ExecutionFailureInfo; import io.trino.execution.QueryManagerConfig; @@ -279,7 +280,7 @@ private static QueryResults createQueryResults( null, nextUri, null, - null, + RawQueryData.of(null), StatementStats.builder() .setState(state.toString()) .setQueued(state == QUEUED) diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index 6182e08adef1..7caed294ecd0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -54,6 +54,7 @@ import io.trino.server.DynamicFilterService; import io.trino.server.ResultQueryInfo; import io.trino.server.protocol.Slug; +import io.trino.server.protocol.spooling.SpoolingManagerRegistry; import io.trino.spi.QueryId; import io.trino.spi.TrinoException; import io.trino.sql.PlannerContext; @@ -117,6 +118,7 @@ public class SqlQueryExecution private final SplitSourceFactory splitSourceFactory; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final SpoolingManagerRegistry spoolingManagerRegistry; private final NodeAllocatorService nodeAllocatorService; private final PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory; private final OutputStatsEstimatorFactory outputStatsEstimatorFactory; @@ -157,6 +159,7 @@ private SqlQueryExecution( SplitSourceFactory splitSourceFactory, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + SpoolingManagerRegistry spoolingManagerRegistry, NodeAllocatorService nodeAllocatorService, PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory, OutputStatsEstimatorFactory outputStatsEstimatorFactory, @@ -191,6 +194,7 @@ private SqlQueryExecution( 
this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.spoolingManagerRegistry = requireNonNull(spoolingManagerRegistry, "spoolingManagerRegistry is null"); this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); this.partitionMemoryEstimatorFactory = requireNonNull(partitionMemoryEstimatorFactory, "partitionMemoryEstimatorFactory is null"); this.outputStatsEstimatorFactory = requireNonNull(outputStatsEstimatorFactory, "outputDataSizeEstimatorFactory is null"); @@ -490,6 +494,7 @@ private PlanRoot doPlanQuery(CachingTableStatsProvider tableStatsProvider) planOptimizers, idAllocator, plannerContext, + spoolingManagerRegistry, statsCalculator, costCalculator, stateMachine.getWarningCollector(), @@ -781,6 +786,7 @@ public static class SqlQueryExecutionFactory private final SplitSourceFactory splitSourceFactory; private final NodePartitioningManager nodePartitioningManager; private final NodeScheduler nodeScheduler; + private final SpoolingManagerRegistry spoolingManagerRegistry; private final NodeAllocatorService nodeAllocatorService; private final PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory; private final OutputStatsEstimatorFactory outputStatsEstimatorFactory; @@ -813,6 +819,7 @@ public static class SqlQueryExecutionFactory SplitSourceFactory splitSourceFactory, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, + SpoolingManagerRegistry spoolingManagerRegistry, NodeAllocatorService nodeAllocatorService, PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory, OutputStatsEstimatorFactory outputStatsEstimatorFactory, @@ -844,6 +851,7 @@ public static class SqlQueryExecutionFactory this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.spoolingManagerRegistry = requireNonNull(spoolingManagerRegistry, "spoolingManagerRegistry is null"); this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); this.partitionMemoryEstimatorFactory = requireNonNull(partitionMemoryEstimatorFactory, "partitionMemoryEstimatorFactory is null"); this.outputStatsEstimatorFactory = requireNonNull(outputStatsEstimatorFactory, "outputDataSizeEstimatorFactory is null"); @@ -891,6 +899,7 @@ public QueryExecution createQueryExecution( splitSourceFactory, nodePartitioningManager, nodeScheduler, + spoolingManagerRegistry, nodeAllocatorService, partitionMemoryEstimatorFactory, outputStatsEstimatorFactory, diff --git a/core/trino-main/src/main/java/io/trino/operator/OperatorInfo.java b/core/trino-main/src/main/java/io/trino/operator/OperatorInfo.java index 55dcc8ef6549..c312b34ec07e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/OperatorInfo.java +++ b/core/trino-main/src/main/java/io/trino/operator/OperatorInfo.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import io.trino.operator.OutputSpoolingOperatorFactory.OutputSpoolingInfo; import io.trino.operator.TableWriterOperator.TableWriterInfo; import 
io.trino.operator.exchange.LocalExchangeBufferInfo; import io.trino.operator.join.JoinOperatorInfo; @@ -27,6 +28,7 @@ @JsonSubTypes.Type(value = DirectExchangeClientStatus.class, name = "directExchangeClientStatus"), @JsonSubTypes.Type(value = JoinOperatorInfo.class, name = "joinOperatorInfo"), @JsonSubTypes.Type(value = LocalExchangeBufferInfo.class, name = "localExchangeBuffer"), + @JsonSubTypes.Type(value = OutputSpoolingInfo.class, name = "outputSpooling"), @JsonSubTypes.Type(value = PartitionedOutputInfo.class, name = "partitionedOutput"), @JsonSubTypes.Type(value = SplitOperatorInfo.class, name = "splitOperator"), @JsonSubTypes.Type(value = TableFinishInfo.class, name = "tableFinish"), diff --git a/core/trino-main/src/main/java/io/trino/operator/OutputSpoolingController.java b/core/trino-main/src/main/java/io/trino/operator/OutputSpoolingController.java new file mode 100644 index 000000000000..6bd948b14758 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/OutputSpoolingController.java @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import io.trino.spi.Page; + +import static com.google.common.base.Verify.verify; +import static java.lang.Math.clamp; + +public class OutputSpoolingController +{ + public enum Mode + { + INLINE, + BUFFER, + SPOOL + } + + private long currentSpooledSegmentTarget; + private final long maximumSpooledSegmentTarget; + + private final long maximumInlinedPositions; + private final long maximumInlinedSize; + + private long spooledPositions; + private long spooledPages; + private long spooledRawBytes; + private long spooledEncodedBytes; + + private long inlinedPositions; + private long inlinedPages; + private long inlinedRawBytes; + private long bufferedRawSize; + private long bufferedPositions; + + private Mode mode; + + public OutputSpoolingController(boolean inlineFirstRows, long maximumInlinedPositions, long maximumInlinedSize, long initialSpooledSegmentTarget, long maximumSpooledSegmentTarget) + { + this.currentSpooledSegmentTarget = initialSpooledSegmentTarget; + this.maximumSpooledSegmentTarget = maximumSpooledSegmentTarget; + this.maximumInlinedPositions = maximumInlinedPositions; + this.maximumInlinedSize = maximumInlinedSize; + + mode = inlineFirstRows ? 
Mode.INLINE : Mode.SPOOL; + } + + public Mode getNextMode(Page page) + { + return getNextMode(page.getPositionCount(), page.getSizeInBytes()); + } + + public Mode getNextMode(int positionCount, long sizeInBytes) + { + return switch (mode) { + case INLINE -> { + // If we still didn't inline maximum number of positions + if (inlinedPositions + positionCount >= maximumInlinedPositions) { + mode = Mode.SPOOL; // switch to spooling mode + yield getNextMode(positionCount, sizeInBytes); // and now decide whether to buffer or spool this page + } + + // We don't want to many inlined segments + if (inlinedPages > 3) { // or better bound + mode = Mode.SPOOL; // switch to spooling mode + yield getNextMode(positionCount, sizeInBytes); // and now decide whether to buffer or spool this page + } + + // If we still didn't inline maximum number of bytes + if (inlinedRawBytes + sizeInBytes >= maximumInlinedSize) { + mode = Mode.SPOOL; // switch to spooling mode + yield getNextMode(positionCount, sizeInBytes); // and now decide whether to buffer or spool this page + } + + verify(bufferedRawSize == 0, "There should be no buffered pages when streaming"); + recordInlined(positionCount, sizeInBytes); + yield Mode.INLINE; // we are ok to stream this page + } + case SPOOL -> { + if (bufferedRawSize + sizeInBytes >= currentSpooledSegmentTarget) { + recordSpooled(bufferedPositions + positionCount, bufferedRawSize + sizeInBytes); + yield Mode.SPOOL; + } + + recordBuffered(positionCount, sizeInBytes); + yield Mode.BUFFER; + } + + case BUFFER -> throw new IllegalStateException("Current mode can be either STREAM or SPOOL"); + }; + } + + public void recordSpooled(long rows, long size) + { + bufferedRawSize = 0; + bufferedPositions = 0; // Buffer cleared when spooled + + spooledPositions += rows; + spooledRawBytes += size; + spooledPages++; + + // Double spool target size until we reach maximum + currentSpooledSegmentTarget = clamp(currentSpooledSegmentTarget * 2, currentSpooledSegmentTarget, maximumSpooledSegmentTarget); + } + + public void recordEncoded(long encodedSize) + { + spooledEncodedBytes += encodedSize; + } + + public void recordInlined(int positionCount, long sizeInBytes) + { + inlinedPositions += positionCount; + inlinedPages++; + inlinedRawBytes += sizeInBytes; + } + + public void recordBuffered(int positionCount, long sizeInBytes) + { + bufferedPositions += positionCount; + bufferedRawSize += sizeInBytes; + } + + public long getSpooledPositions() + { + return spooledPositions; + } + + public long getSpooledPages() + { + return spooledPages; + } + + public long getSpooledRawBytes() + { + return spooledRawBytes; + } + + public long getSpooledEncodedBytes() + { + return spooledEncodedBytes; + } + + public long getInlinedPositions() + { + return inlinedPositions; + } + + public long getInlinedPages() + { + return inlinedPages; + } + + public long getInlinedRawBytes() + { + return inlinedRawBytes; + } + + public long getBufferedRawSize() + { + return bufferedRawSize; + } + + public long getBufferedPositions() + { + return bufferedPositions; + } + + public long getCurrentSpooledSegmentTarget() + { + return currentSpooledSegmentTarget; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/OutputSpoolingOperatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/OutputSpoolingOperatorFactory.java new file mode 100644 index 000000000000..152e1483bef8 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/OutputSpoolingOperatorFactory.java @@ -0,0 +1,435 @@ +/* + * Licensed under 
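The controller is a small state machine: pages are inlined until one of the inline limits is hit, after which they are buffered and flushed as spooled segments whose target size doubles up to the configured maximum. A standalone sketch with illustrative thresholds (Mode is the nested enum above; the operator wires its actual limits in its own constructor):

    OutputSpoolingController controller = new OutputSpoolingController(
            true,       // inline the first rows
            20,         // maximumInlinedPositions
            1024,       // maximumInlinedSize, bytes
            1 << 20,    // initialSpooledSegmentTarget, 1 MB
            16 << 20);  // maximumSpooledSegmentTarget, 16 MB

    Mode first = controller.getNextMode(10, 512);       // INLINE: under both inline limits
    Mode second = controller.getNextMode(50, 4096);     // BUFFER: inline position limit exceeded, buffer not yet at segment target
    Mode third = controller.getNextMode(1, 2 << 20);    // SPOOL: buffered bytes reach the current segment target, which then doubles
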
the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; +import io.trino.client.spooling.DataAttributes; +import io.trino.memory.context.LocalMemoryContext; +import io.trino.operator.OperationTimer.OperationTiming; +import io.trino.server.protocol.OutputColumn; +import io.trino.server.protocol.spooling.QueryDataEncoder; +import io.trino.server.protocol.spooling.SpooledBlock; +import io.trino.server.protocol.spooling.SpoolingManagerBridge; +import io.trino.spi.Mergeable; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.protocol.SpooledSegmentHandle; +import io.trino.spi.protocol.SpoolingContext; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.OutputNode; +import io.trino.sql.planner.plan.PlanNodeId; + +import java.io.ByteArrayOutputStream; +import java.io.OutputStream; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Supplier; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; +import static io.trino.client.spooling.DataAttribute.BYTE_SIZE; +import static io.trino.client.spooling.DataAttribute.ROWS_COUNT; +import static io.trino.operator.OutputSpoolingOperatorFactory.OutputSpoolingOperator.State.FINISHED; +import static io.trino.operator.OutputSpoolingOperatorFactory.OutputSpoolingOperator.State.HAS_LAST_OUTPUT; +import static io.trino.operator.OutputSpoolingOperatorFactory.OutputSpoolingOperator.State.HAS_OUTPUT; +import static io.trino.operator.OutputSpoolingOperatorFactory.OutputSpoolingOperator.State.NEEDS_INPUT; +import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_SYMBOL; +import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_TYPE; +import static io.trino.server.protocol.spooling.SpooledBlock.createNonSpooledPage; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +public class OutputSpoolingOperatorFactory + implements OperatorFactory +{ + private final int operatorId; + private final PlanNodeId planNodeId; + private final Map operatorLayout; + private final SpoolingManagerBridge spoolingManager; + private final QueryDataEncoder queryDataEncoder; + private boolean closed; + + public OutputSpoolingOperatorFactory(int operatorId, PlanNodeId planNodeId, Map operatorLayout, QueryDataEncoder queryDataEncoder, SpoolingManagerBridge spoolingManager) + { + this.operatorId = operatorId; + this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); + this.operatorLayout = 
ImmutableMap.copyOf(requireNonNull(operatorLayout, "layout is null")); + this.queryDataEncoder = requireNonNull(queryDataEncoder, "queryDataEncoder is null"); + this.spoolingManager = requireNonNull(spoolingManager, "spoolingManager is null"); + } + + public static List spooledOutputLayout(OutputNode outputNode, Map layout) + { + List columnNames = outputNode.getColumnNames(); + List outputSymbols = outputNode.getOutputSymbols(); + + ImmutableList.Builder outputColumnBuilder = ImmutableList.builderWithExpectedSize(outputNode.getColumnNames().size()); + for (int i = 0; i < columnNames.size(); i++) { + if (outputSymbols.get(i).type().equals(SPOOLING_METADATA_TYPE)) { + continue; + } + outputColumnBuilder.add(new OutputColumn(layout.get(outputSymbols.get(i)), columnNames.get(i), outputSymbols.get(i).type())); + } + return outputColumnBuilder.build(); + } + + public static Map spooledLayout(Map layout) + { + int maxChannelId = layout.values().stream().max(Integer::compareTo).orElseThrow(); + verify(maxChannelId + 1 == layout.size(), "Max channel id %s is not equal to layout size: %s", maxChannelId, layout.size()); + return ImmutableMap.builderWithExpectedSize(layout.size() + 1) + .putAll(layout) + .put(SPOOLING_METADATA_SYMBOL, maxChannelId + 1) + .buildOrThrow(); + } + + @Override + public Operator createOperator(DriverContext driverContext) + { + checkState(!closed, "Factory is already closed"); + OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, OutputSpoolingOperator.class.getSimpleName()); + return new OutputSpoolingOperator(operatorContext, queryDataEncoder, spoolingManager, operatorLayout); + } + + @Override + public void noMoreOperators() + { + closed = true; + } + + @Override + public OperatorFactory duplicate() + { + return new OutputSpoolingOperatorFactory(operatorId, planNodeId, operatorLayout, queryDataEncoder, spoolingManager); + } + + static class OutputSpoolingOperator + implements Operator + { + private final OutputSpoolingController controller; + + enum State + { + NEEDS_INPUT, // output is not ready + HAS_OUTPUT, // output page ready + HAS_LAST_OUTPUT, // last output page ready + FINISHED // no more pages will be ever produced + } + + private OutputSpoolingOperator.State state = NEEDS_INPUT; + private final OperatorContext operatorContext; + private final LocalMemoryContext userMemoryContext; + private final QueryDataEncoder queryDataEncoder; + private final SpoolingManagerBridge spoolingManager; + private final Map layout; + private final PageBuffer buffer = PageBuffer.create(); + private final Block[] emptyBlocks; + + private final OperationTiming encodingTiming = new OperationTiming(); + private final OperationTiming spoolingTiming = new OperationTiming(); + private Page outputPage; + + public OutputSpoolingOperator(OperatorContext operatorContext, QueryDataEncoder queryDataEncoder, SpoolingManagerBridge spoolingManager, Map layout) + { + this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); + this.controller = new OutputSpoolingController( + spoolingManager.useInlineSegments(), + 20, + 1024, + spoolingManager.initialSegmentSize(), + spoolingManager.maximumSegmentSize()); + this.userMemoryContext = operatorContext.localUserMemoryContext(); + this.queryDataEncoder = requireNonNull(queryDataEncoder, "queryDataEncoder is null"); + this.spoolingManager = requireNonNull(spoolingManager, "spoolingManager is null"); + this.layout = requireNonNull(layout, "layout is null"); + this.emptyBlocks = 
emptyBlocks(layout); + + operatorContext.setInfoSupplier(new OutputSpoolingInfoSupplier(encodingTiming, spoolingTiming, controller)); + } + + @Override + public OperatorContext getOperatorContext() + { + return operatorContext; + } + + @Override + public boolean needsInput() + { + return state == NEEDS_INPUT; + } + + @Override + public void addInput(Page page) + { + checkState(needsInput(), "Operator is already finishing"); + requireNonNull(page, "page is null"); + + outputPage = switch (controller.getNextMode(page)) { + case SPOOL -> { + buffer.add(page); + yield outputBuffer(false); + } + case BUFFER -> { + buffer.add(page); + yield null; + } + case INLINE -> createNonSpooledPage(page); + }; + + if (outputPage != null) { + state = HAS_OUTPUT; + } + } + + @Override + public Page getOutput() + { + if (state != HAS_OUTPUT && state != HAS_LAST_OUTPUT) { + return null; + } + + Page toReturn = outputPage; + outputPage = null; + state = state == HAS_LAST_OUTPUT ? FINISHED : NEEDS_INPUT; + return toReturn; + } + + @Override + public void finish() + { + if (state == NEEDS_INPUT) { + outputPage = outputBuffer(true); + if (outputPage != null) { + state = HAS_LAST_OUTPUT; + } + else { + state = FINISHED; + } + } + } + + @Override + public boolean isFinished() + { + return state == FINISHED; + } + + private Page outputBuffer(boolean finished) + { + if (buffer.isEmpty()) { + return null; + } + + synchronized (buffer) { + return spool(buffer.removeAll(), finished); + } + } + + private Page spool(List pages, boolean finished) + { + long rows = reduce(pages, page -> (long) page.getPositionCount()); + if (finished) { + long size = reduce(pages, Page::getSizeInBytes); + controller.recordSpooled(rows, size); // final buffer + } + + SpoolingContext spoolingContext = new SpoolingContext(operatorContext.getDriverContext().getSession().getQueryId(), rows); + SpooledSegmentHandle segmentHandle = spoolingManager.create(spoolingContext); + + try (OutputStream output = spoolingManager.createOutputStream(segmentHandle); ByteArrayOutputStream bufferedOutput = new ByteArrayOutputStream(toIntExact(controller.getCurrentSpooledSegmentTarget()))) { + OperationTimer encodingTimer = new OperationTimer(true); + DataAttributes attributes = queryDataEncoder.encodeTo(bufferedOutput, pages) + .toBuilder() + .set(ROWS_COUNT, rows) + .build(); + encodingTimer.end(encodingTiming); + OperationTimer spoolingTimer = new OperationTimer(true); + output.write(bufferedOutput.toByteArray()); + spoolingTimer.end(spoolingTiming); + controller.recordEncoded(attributes.get(BYTE_SIZE, Integer.class)); + return emptySingleRowPage(layout, new SpooledBlock(spoolingManager.handleToUriIdentifier(segmentHandle), attributes).serialize()); + } + catch (Exception e) { + throw new RuntimeException(e); + } + finally { + pages = null; + userMemoryContext.setBytes(0); + } + } + + private Page emptySingleRowPage(Map layout, Block block) + { + Block[] blocks = emptyBlocks; + blocks[layout.get(SPOOLING_METADATA_SYMBOL)] = block; + return new Page(blocks); + } + + static long reduce(List page, Function reduce) + { + return page.stream().map(reduce).reduce(0L, Long::sum); + } + + private static Block[] emptyBlocks(Map layout) + { + Block[] blocks = new Block[layout.size()]; + for (Map.Entry entry : layout.entrySet()) { + if (!entry.getKey().type().equals(SPOOLING_METADATA_TYPE)) { + blocks[entry.getValue()] = entry.getKey().type().createBlockBuilder(null, 1).appendNull().build(); + } + } + + return blocks; + } + } + + private record PageBuffer(LinkedList buffer) 
+ { + private PageBuffer + { + requireNonNull(buffer, "buffer is null"); + } + + public static PageBuffer create() + { + return new PageBuffer(new LinkedList<>()); + } + + public void add(Page page) + { + buffer.add(page); + } + + public boolean isEmpty() + { + return buffer.isEmpty(); + } + + public synchronized List removeAll() + { + List pages = ImmutableList.copyOf(buffer()); + buffer.clear(); + return pages; + } + } + + private record OutputSpoolingInfoSupplier( + OperationTiming encodingTiming, + OperationTiming spoolingTiming, + OutputSpoolingController controller) + implements Supplier + { + private OutputSpoolingInfoSupplier + { + requireNonNull(encodingTiming, "encodingTiming is null"); + requireNonNull(spoolingTiming, "spoolingTiming is null"); + requireNonNull(controller, "controller is null"); + } + + @Override + public OutputSpoolingInfo get() + { + return new OutputSpoolingInfo( + new Duration(encodingTiming.getWallNanos(), NANOSECONDS).convertToMostSuccinctTimeUnit(), + new Duration(encodingTiming.getCpuNanos(), NANOSECONDS).convertToMostSuccinctTimeUnit(), + new Duration(spoolingTiming.getWallNanos(), NANOSECONDS).convertToMostSuccinctTimeUnit(), + new Duration(spoolingTiming.getCpuNanos(), NANOSECONDS).convertToMostSuccinctTimeUnit(), + controller.getInlinedPages(), + controller.getInlinedPositions(), + controller.getInlinedRawBytes(), + controller.getSpooledPages(), + controller.getSpooledPositions(), + controller.getSpooledRawBytes(), + controller.getSpooledEncodedBytes()); + } + } + + public record OutputSpoolingInfo( + Duration encodingWallTime, + Duration encodingCpuTime, + Duration spoolingWallTime, + Duration spoolingCpuTime, + long inlinedPages, + long inlinedPositions, + long inlinedRawBytes, + long spooledPages, + long spooledPositions, + long spooledRawBytes, + long spooledEncodedBytes) + implements Mergeable, OperatorInfo + { + public OutputSpoolingInfo + { + requireNonNull(encodingWallTime, "encodingWallTime is null"); + requireNonNull(encodingCpuTime, "encodingCpuTime is null"); + } + + @Override + public OutputSpoolingInfo mergeWith(OutputSpoolingInfo other) + { + return new OutputSpoolingInfo( + new Duration(encodingWallTime.toMillis() + other.encodingWallTime().toMillis(), MILLISECONDS).convertToMostSuccinctTimeUnit(), + new Duration(encodingCpuTime.toMillis() + other.encodingCpuTime().toMillis(), MILLISECONDS).convertToMostSuccinctTimeUnit(), + new Duration(spoolingWallTime.toMillis() + other.spoolingWallTime().toMillis(), MILLISECONDS).convertToMostSuccinctTimeUnit(), + new Duration(spoolingCpuTime.toMillis() + other.spoolingCpuTime().toMillis(), MILLISECONDS).convertToMostSuccinctTimeUnit(), + inlinedPages + other.inlinedPages(), + inlinedPositions + other.inlinedPositions, + inlinedRawBytes + other.inlinedRawBytes, + spooledPages + other.spooledPages, + spooledPositions + other.spooledPositions, + spooledRawBytes + other.spooledRawBytes, + spooledEncodedBytes + other.spooledEncodedBytes); + } + + @JsonProperty + public double getEncodedToRawBytesRatio() + { + return 1.0 * spooledEncodedBytes / spooledRawBytes; + } + + @Override + public boolean isFinal() + { + return true; + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("encodingWallTime", encodingWallTime) + .add("encodingCpuTime", encodingCpuTime) + .add("spoolingWallTime", spoolingWallTime) + .add("spoolingCpuTime", spoolingCpuTime) + .add("inlinedPages", inlinedPages) + .add("inlinedPositions", inlinedPositions) + .add("inlinedRawBytes", inlinedRawBytes) + 
.add("spooledPages", spooledPages) + .add("spooledPositions", spooledPositions) + .add("spooledRawBytes", spooledRawBytes) + .add("spooledEncodedBytes", spooledEncodedBytes) + .add("encodedToRawBytesRatio", getEncodedToRawBytesRatio()) + .toString(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java index f422f9bdb587..90b168cc90e4 100644 --- a/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java +++ b/core/trino-main/src/main/java/io/trino/server/HttpRequestSessionContextFactory.java @@ -27,6 +27,8 @@ import io.trino.metadata.Metadata; import io.trino.security.AccessControl; import io.trino.server.protocol.PreparedStatementEncoder; +import io.trino.server.protocol.spooling.QueryDataEncoder; +import io.trino.server.protocol.spooling.QueryDataEncoderSelector; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.GroupProvider; import io.trino.spi.security.Identity; @@ -71,6 +73,7 @@ public class HttpRequestSessionContextFactory private final GroupProvider groupProvider; private final AccessControl accessControl; private final Optional alternateHeaderName; + private final QueryDataEncoderSelector queryDataEncoderSelector; @Inject public HttpRequestSessionContextFactory( @@ -78,13 +81,15 @@ public HttpRequestSessionContextFactory( Metadata metadata, GroupProvider groupProvider, AccessControl accessControl, - ProtocolConfig protocolConfig) + ProtocolConfig protocolConfig, + QueryDataEncoderSelector queryDataEncoderSelector) { this.alternateHeaderName = protocolConfig.getAlternateHeaderName(); this.preparedStatementEncoder = requireNonNull(preparedStatementEncoder, "preparedStatementEncoder is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.groupProvider = requireNonNull(groupProvider, "groupProvider is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); + this.queryDataEncoderSelector = requireNonNull(queryDataEncoderSelector, "queryDataEncoderSelector is null"); } public SessionContext createSessionContext( @@ -116,6 +121,10 @@ public SessionContext createSessionContext( Optional timeZoneId = Optional.ofNullable(headers.getFirst(protocolHeaders.requestTimeZone())); Optional language = Optional.ofNullable(headers.getFirst(protocolHeaders.requestLanguage())); Optional clientInfo = Optional.ofNullable(headers.getFirst(protocolHeaders.requestClientInfo())); + Optional queryDataEncodingId = Optional.ofNullable(headers.getFirst(protocolHeaders.requestQueryDataEncoding())) + .flatMap(queryDataEncoderSelector::select) + .map(QueryDataEncoder.Factory::encodingId); + Set clientTags = parseClientTags(protocolHeaders, headers); Set clientCapabilities = parseClientCapabilities(protocolHeaders, headers); ResourceEstimates resourceEstimates = parseResourceEstimate(protocolHeaders, headers); @@ -176,7 +185,8 @@ case ParsedSessionPropertyName(Optional catalogName, String propertyName preparedStatements, transactionId, clientTransactionSupport, - clientInfo); + clientInfo, + queryDataEncodingId); } public Identity extractAuthorizedIdentity(HttpServletRequest servletRequest, HttpHeaders httpHeaders) diff --git a/core/trino-main/src/main/java/io/trino/server/PluginManager.java b/core/trino-main/src/main/java/io/trino/server/PluginManager.java index 06e364b813e3..20bd955f4868 100644 --- a/core/trino-main/src/main/java/io/trino/server/PluginManager.java +++ 
b/core/trino-main/src/main/java/io/trino/server/PluginManager.java @@ -30,6 +30,7 @@ import io.trino.metadata.TypeRegistry; import io.trino.security.AccessControlManager; import io.trino.security.GroupProviderManager; +import io.trino.server.protocol.spooling.SpoolingManagerRegistry; import io.trino.server.security.CertificateAuthenticatorManager; import io.trino.server.security.HeaderAuthenticatorManager; import io.trino.server.security.PasswordAuthenticatorManager; @@ -40,6 +41,7 @@ import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.eventlistener.EventListenerFactory; import io.trino.spi.exchange.ExchangeManagerFactory; +import io.trino.spi.protocol.SpoolingManagerFactory; import io.trino.spi.resourcegroups.ResourceGroupConfigurationManagerFactory; import io.trino.spi.security.CertificateAuthenticatorFactory; import io.trino.spi.security.GroupProviderFactory; @@ -89,6 +91,7 @@ public class PluginManager private final EventListenerManager eventListenerManager; private final GroupProviderManager groupProviderManager; private final ExchangeManagerRegistry exchangeManagerRegistry; + private final SpoolingManagerRegistry spoolingManagerRegistry; private final SessionPropertyDefaults sessionPropertyDefaults; private final TypeRegistry typeRegistry; private final BlockEncodingManager blockEncodingManager; @@ -112,7 +115,8 @@ public PluginManager( TypeRegistry typeRegistry, BlockEncodingManager blockEncodingManager, HandleResolver handleResolver, - ExchangeManagerRegistry exchangeManagerRegistry) + ExchangeManagerRegistry exchangeManagerRegistry, + SpoolingManagerRegistry spoolingManagerRegistry) { this.pluginsProvider = requireNonNull(pluginsProvider, "pluginsProvider is null"); this.catalogStoreManager = requireNonNull(catalogStoreManager, "catalogStoreManager is null"); @@ -130,6 +134,7 @@ public PluginManager( this.blockEncodingManager = requireNonNull(blockEncodingManager, "blockEncodingManager is null"); this.handleResolver = requireNonNull(handleResolver, "handleResolver is null"); this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); + this.spoolingManagerRegistry = requireNonNull(spoolingManagerRegistry, "spoolingManagerRegistry is null"); } @Override @@ -267,6 +272,11 @@ private void installPluginInternal(Plugin plugin) log.info("Registering exchange manager %s", exchangeManagerFactory.getName()); exchangeManagerRegistry.addExchangeManagerFactory(exchangeManagerFactory); } + + for (SpoolingManagerFactory spoolingManagerFactory : plugin.getSpoolingManagerFactories()) { + log.info("Registering spooling manager %s", spoolingManagerFactory.getName()); + spoolingManagerRegistry.addSpoolingManagerFactory(spoolingManagerFactory); + } } public static PluginClassLoader createClassLoader(String pluginName, List urls) diff --git a/core/trino-main/src/main/java/io/trino/server/QuerySessionSupplier.java b/core/trino-main/src/main/java/io/trino/server/QuerySessionSupplier.java index f62ef94b76e7..894149927db7 100644 --- a/core/trino-main/src/main/java/io/trino/server/QuerySessionSupplier.java +++ b/core/trino-main/src/main/java/io/trino/server/QuerySessionSupplier.java @@ -20,6 +20,7 @@ import io.trino.metadata.Metadata; import io.trino.metadata.SessionPropertyManager; import io.trino.security.AccessControl; +import io.trino.server.protocol.spooling.SpoolingConfig; import io.trino.spi.QueryId; import io.trino.spi.security.Identity; import io.trino.spi.type.TimeZoneKey; @@ -35,6 +36,8 @@ import static 
io.trino.SystemSessionProperties.TIME_ZONE_ID; import static io.trino.server.HttpRequestSessionContextFactory.addEnabledRoles; import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; +import static io.trino.util.Ciphers.createRandomAesEncryptionKey; +import static io.trino.util.Ciphers.serializeAesEncryptionKey; import static java.util.Map.Entry; import static java.util.Objects.requireNonNull; @@ -49,13 +52,15 @@ public class QuerySessionSupplier private final Optional forcedSessionTimeZone; private final Optional defaultCatalog; private final Optional defaultSchema; + private final boolean spoolingEncryptionEnabled; @Inject public QuerySessionSupplier( Metadata metadata, AccessControl accessControl, SessionPropertyManager sessionPropertyManager, - SqlEnvironmentConfig config) + SqlEnvironmentConfig config, + SpoolingConfig spoolingConfig) { this.metadata = requireNonNull(metadata, "metadata is null"); this.accessControl = requireNonNull(accessControl, "accessControl is null"); @@ -64,6 +69,7 @@ public QuerySessionSupplier( this.forcedSessionTimeZone = requireNonNull(config.getForcedSessionTimeZone(), "forcedSessionTimeZone is null"); this.defaultCatalog = requireNonNull(config.getDefaultCatalog(), "defaultCatalog is null"); this.defaultSchema = requireNonNull(config.getDefaultSchema(), "defaultSchema is null"); + this.spoolingEncryptionEnabled = requireNonNull(spoolingConfig, "spoolingConfig is null").isEncryptionEnabled(); checkArgument(defaultCatalog.isPresent() || defaultSchema.isEmpty(), "Default schema cannot be set if catalog is not set"); } @@ -112,7 +118,12 @@ public Session createSession(QueryId queryId, Span querySpan, SessionContext con .setClientCapabilities(context.getClientCapabilities()) .setTraceToken(context.getTraceToken()) .setResourceEstimates(context.getResourceEstimates()) - .setProtocolHeaders(context.getProtocolHeaders()); + .setProtocolHeaders(context.getProtocolHeaders()) + .setQueryDataEncoding(context.getQueryDataEncodingId()); + + if (context.getQueryDataEncodingId().isPresent() && spoolingEncryptionEnabled) { + sessionBuilder.setQueryDataEncryptionKey(serializeAesEncryptionKey(createRandomAesEncryptionKey())); + } if (context.getCatalog().isPresent()) { sessionBuilder.setCatalog(context.getCatalog()); diff --git a/core/trino-main/src/main/java/io/trino/server/Server.java b/core/trino-main/src/main/java/io/trino/server/Server.java index 73527d731bf4..91eb0fc6ab8a 100644 --- a/core/trino-main/src/main/java/io/trino/server/Server.java +++ b/core/trino-main/src/main/java/io/trino/server/Server.java @@ -57,6 +57,7 @@ import io.trino.security.AccessControlManager; import io.trino.security.AccessControlModule; import io.trino.security.GroupProviderManager; +import io.trino.server.protocol.spooling.SpoolingManagerRegistry; import io.trino.server.security.CertificateAuthenticatorManager; import io.trino.server.security.HeaderAuthenticatorManager; import io.trino.server.security.PasswordAuthenticatorManager; @@ -176,6 +177,7 @@ private void doStart(String trinoVersion) .ifPresent(PasswordAuthenticatorManager::loadPasswordAuthenticator); injector.getInstance(GroupProviderManager.class).loadConfiguredGroupProvider(); injector.getInstance(ExchangeManagerRegistry.class).loadExchangeManager(); + injector.getInstance(SpoolingManagerRegistry.class).loadSpoolingManager(); injector.getInstance(CertificateAuthenticatorManager.class).loadCertificateAuthenticator(); injector.getInstance(Key.get(new TypeLiteral>() {})) 
.ifPresent(HeaderAuthenticatorManager::loadHeaderAuthenticator); diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java index b89c5f3f1792..d359c24f0d30 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java @@ -101,6 +101,7 @@ import io.trino.server.SliceSerialization.SliceDeserializer; import io.trino.server.SliceSerialization.SliceSerializer; import io.trino.server.protocol.PreparedStatementEncoder; +import io.trino.server.protocol.spooling.SpoolingServerModule; import io.trino.server.remotetask.HttpLocationFactory; import io.trino.spi.PageIndexerFactory; import io.trino.spi.PageSorter; @@ -213,6 +214,7 @@ protected void setup(Binder binder) binder.bind(PreparedStatementEncoder.class).in(Scopes.SINGLETON); binder.bind(HttpRequestSessionContextFactory.class).in(Scopes.SINGLETON); install(new InternalCommunicationModule()); + install(new SpoolingServerModule()); QueryManagerConfig queryManagerConfig = buildConfigObject(QueryManagerConfig.class); RetryPolicy retryPolicy = queryManagerConfig.getRetryPolicy(); diff --git a/core/trino-main/src/main/java/io/trino/server/SessionContext.java b/core/trino-main/src/main/java/io/trino/server/SessionContext.java index 4106a2adfe97..074ac41605e4 100644 --- a/core/trino-main/src/main/java/io/trino/server/SessionContext.java +++ b/core/trino-main/src/main/java/io/trino/server/SessionContext.java @@ -62,6 +62,7 @@ public class SessionContext private final Optional transactionId; private final boolean clientTransactionSupport; private final Optional clientInfo; + private final Optional queryDataEncodingId; public SessionContext( ProtocolHeaders protocolHeaders, @@ -86,7 +87,8 @@ public SessionContext( Map preparedStatements, Optional transactionId, boolean clientTransactionSupport, - Optional clientInfo) + Optional clientInfo, + Optional queryDataEncodingId) { this.protocolHeaders = requireNonNull(protocolHeaders, "protocolHeaders is null"); this.catalog = requireNonNull(catalog, "catalog is null"); @@ -113,6 +115,7 @@ public SessionContext( this.transactionId = requireNonNull(transactionId, "transactionId is null"); this.clientTransactionSupport = clientTransactionSupport; this.clientInfo = requireNonNull(clientInfo, "clientInfo is null"); + this.queryDataEncodingId = requireNonNull(queryDataEncodingId, "queryDataEncodingId is null"); } public ProtocolHeaders getProtocolHeaders() @@ -230,6 +233,11 @@ public Optional getTraceToken() return traceToken; } + public Optional getQueryDataEncodingId() + { + return queryDataEncodingId; + } + @VisibleForTesting public static SessionContext fromSession(Session session) { @@ -270,6 +278,7 @@ else if (enabledRoles.size() == 1) { session.getPreparedStatements(), session.getTransactionId(), session.isClientTransactionSupport(), - session.getClientInfo()); + session.getClientInfo(), + session.getQueryDataEncodingId()); } } diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java index df2da65ba037..da2288040d79 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/ExecutingStatementResource.java @@ -30,6 +30,9 @@ import io.trino.server.ExternalUriInfo; import 
io.trino.server.ForStatementResource; import io.trino.server.ServerConfig; +import io.trino.server.protocol.spooling.QueryDataEncoder; +import io.trino.server.protocol.spooling.QueryDataEncoderSelector; +import io.trino.server.protocol.spooling.QueryDataProducer; import io.trino.server.security.ResourceSecurity; import io.trino.spi.QueryId; import io.trino.spi.block.BlockEncodingSerde; @@ -50,6 +53,7 @@ import java.net.URLEncoder; import java.util.Map.Entry; import java.util.NoSuchElementException; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; @@ -77,6 +81,7 @@ public class ExecutingStatementResource private static final DataSize MAX_TARGET_RESULT_SIZE = DataSize.of(128, MEGABYTE); private final QueryManager queryManager; + private final QueryDataEncoderSelector queryDataEncoderSelector; private final DirectExchangeClientSupplier directExchangeClientSupplier; private final ExchangeManagerRegistry exchangeManagerRegistry; private final BlockEncodingSerde blockEncodingSerde; @@ -92,6 +97,7 @@ public class ExecutingStatementResource @Inject public ExecutingStatementResource( QueryManager queryManager, + QueryDataEncoderSelector queryDataEncoderSelector, DirectExchangeClientSupplier directExchangeClientSupplier, ExchangeManagerRegistry exchangeManagerRegistry, BlockEncodingSerde blockEncodingSerde, @@ -102,6 +108,7 @@ public ExecutingStatementResource( ServerConfig serverConfig) { this.queryManager = requireNonNull(queryManager, "queryManager is null"); + this.queryDataEncoderSelector = requireNonNull(queryDataEncoderSelector, "queryDataEncoderSelector is null"); this.directExchangeClientSupplier = requireNonNull(directExchangeClientSupplier, "directExchangeClientSupplier is null"); this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); @@ -189,10 +196,14 @@ protected Query getQuery(QueryId queryId, String slug, long token) throw new NotFoundException("Query not found"); } + Optional encoderFactory = session.getQueryDataEncodingId() + .flatMap(queryDataEncoderSelector::select); + query = queries.computeIfAbsent(queryId, id -> Query.create( session, querySlug, queryManager, + new QueryDataProducer(encoderFactory), queryInfoUrlFactory.getQueryInfoUrl(queryId), directExchangeClientSupplier, exchangeManagerRegistry, diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/JsonArrayResultsIterator.java b/core/trino-main/src/main/java/io/trino/server/protocol/JsonArrayResultsIterator.java new file mode 100644 index 000000000000..5934ed8cff7f --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/JsonArrayResultsIterator.java @@ -0,0 +1,231 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
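The JsonArrayResultsIterator introduced in this new file streams rows lazily instead of materializing them up front: pages are drained from a queue and one row is produced per position. As a rough, self-contained sketch of that cursor pattern (illustrative only, not part of the patch; plain Java lists stand in for Trino Pages and no per-type value conversion is attempted):

    import com.google.common.collect.AbstractIterator;

    import java.util.ArrayDeque;
    import java.util.Deque;
    import java.util.List;

    class RowCursor
            extends AbstractIterator<List<Object>>
    {
        private final Deque<List<List<Object>>> queue; // queue of "pages", each page is a list of rows
        private List<List<Object>> currentPage;
        private int inPageIndex = -1;

        RowCursor(List<List<List<Object>>> pages)
        {
            this.queue = new ArrayDeque<>(pages);
            this.currentPage = queue.pollFirst();
        }

        @Override
        protected List<Object> computeNext()
        {
            while (currentPage != null) {
                inPageIndex++;
                if (inPageIndex < currentPage.size()) {
                    return currentPage.get(inPageIndex); // next row of the current page
                }
                currentPage = queue.pollFirst(); // page exhausted, move on to the next one
                inPageIndex = -1;
            }
            return endOfData();
        }
    }

Guava's AbstractIterator keeps the traversal state in one place and lets the serializer pull rows on demand rather than buffering the whole result.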
+ */ +package io.trino.server.protocol; + +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import io.trino.Session; +import io.trino.client.ClientCapabilities; +import io.trino.spi.Page; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.SqlTime; +import io.trino.spi.type.SqlTimeWithTimeZone; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.SqlTimestampWithTimeZone; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimeWithTimeZoneType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; +import io.trino.spi.type.Type; +import jakarta.annotation.Nullable; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import static io.trino.spi.StandardErrorCode.SERIALIZATION_ERROR; +import static java.lang.String.format; +import static java.util.Collections.emptyList; +import static java.util.Collections.unmodifiableList; +import static java.util.Collections.unmodifiableMap; +import static java.util.Objects.requireNonNull; + +public class JsonArrayResultsIterator + extends AbstractIterator> + implements Iterable> +{ + private final Deque queue; + private final Session session; + private final ImmutableList pages; + private final List columns; + private final boolean supportsParametricDateTime; + private final Consumer exceptionConsumer; + + private Page currentPage; + private int rowPosition = -1; + private int inPageIndex = -1; + + public JsonArrayResultsIterator(Session session, List pages, List columns, Consumer exceptionConsumer) + { + this.pages = ImmutableList.copyOf(pages); + this.queue = new ArrayDeque<>(pages); + this.session = requireNonNull(session, "session is null"); + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + this.supportsParametricDateTime = session.getClientCapabilities().contains(ClientCapabilities.PARAMETRIC_DATETIME.toString()); + this.exceptionConsumer = requireNonNull(exceptionConsumer, "exceptionConsumer is null"); + this.currentPage = queue.pollFirst(); + } + + @Override + protected List computeNext() + { + while (true) { + if (currentPage == null) { + return endOfData(); + } + + inPageIndex++; + + if (inPageIndex >= currentPage.getPositionCount()) { + currentPage = queue.pollFirst(); + + if (currentPage == null) { + return endOfData(); + } + + inPageIndex = 0; + } + + rowPosition++; + + List row = getRowValues(); + if (row != null) { + // row is not skipped, return it + return row; + } + } + } + + @Nullable + private List getRowValues() + { + // types are present if data is present + List row = new ArrayList<>(columns.size()); + for (OutputColumn outputColumn : columns) { + Type type = outputColumn.type(); + Block block = currentPage.getBlock(outputColumn.sourcePageChannel()); + + try { + Object value = type.getObjectValue(session.toConnectorSession(), block, inPageIndex); + if (!supportsParametricDateTime) { + value = getLegacyValue(value, type); + } + row.add(value); + } + catch (Throwable throwable) { + propagateException(rowPosition, outputColumn.sourcePageChannel(), outputColumn.columnName(), outputColumn.type(), throwable); + // skip row as it contains non-serializable 
value + return null; + } + } + return unmodifiableList(row); + } + + private Object getLegacyValue(Object value, Type type) + { + if (value == null) { + return null; + } + + if (!supportsParametricDateTime) { + // for legacy clients we need to round timestamp and timestamp with timezone to default precision (3) + + if (type instanceof TimestampType) { + return ((SqlTimestamp) value).roundTo(3); + } + + if (type instanceof TimestampWithTimeZoneType) { + return ((SqlTimestampWithTimeZone) value).roundTo(3); + } + + if (type instanceof TimeType) { + return ((SqlTime) value).roundTo(3); + } + + if (type instanceof TimeWithTimeZoneType) { + return ((SqlTimeWithTimeZone) value).roundTo(3); + } + } + + if (type instanceof ArrayType) { + Type elementType = ((ArrayType) type).getElementType(); + + if (!(elementType instanceof TimestampType || elementType instanceof TimestampWithTimeZoneType)) { + return value; + } + + List listValue = (List) value; + List legacyValues = new ArrayList<>(listValue.size()); + for (Object element : listValue) { + legacyValues.add(getLegacyValue(element, elementType)); + } + + return unmodifiableList(legacyValues); + } + + if (type instanceof MapType) { + Type keyType = ((MapType) type).getKeyType(); + Type valueType = ((MapType) type).getValueType(); + + Map mapValue = (Map) value; + Map result = Maps.newHashMapWithExpectedSize(mapValue.size()); + mapValue.forEach((key, val) -> result.put(getLegacyValue(key, keyType), getLegacyValue(val, valueType))); + return unmodifiableMap(result); + } + + if (type instanceof RowType) { + List fields = ((RowType) type).getFields(); + List values = (List) value; + + List result = new ArrayList<>(values.size()); + for (int i = 0; i < values.size(); i++) { + result.add(getLegacyValue(values.get(i), fields.get(i).getType())); + } + return unmodifiableList(result); + } + + return value; + } + + private void propagateException(int row, int channel, String name, Type type, Throwable cause) + { + // columns and rows are 0-indexed + String message = format("Could not serialize column '%s' of type '%s' at position %d:%d", + name, + type, + row + 1, + channel + 1); + + exceptionConsumer.accept(new TrinoException(SERIALIZATION_ERROR, message, cause)); + } + + @Override + public Iterator> iterator() + { + return new JsonArrayResultsIterator(session, pages, columns, exceptionConsumer); + } + + public static Iterable> toIterableList(Session session, QueryResultRows rows, Consumer serializationExceptionHandler) + { + if (rows.getOutputColumns().isEmpty()) { + return emptyList(); + } + + List columnAndTypes = rows.getOutputColumns().orElseThrow(); + return new JsonArrayResultsIterator( + session, + rows.getPages(), + columnAndTypes, + serializationExceptionHandler); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/OutputColumn.java b/core/trino-main/src/main/java/io/trino/server/protocol/OutputColumn.java new file mode 100644 index 000000000000..b9b95110bcb4 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/OutputColumn.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
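The getLegacyValue conversion above recursively rounds temporal values to millisecond precision for clients that do not advertise the PARAMETRIC_DATETIME capability, descending into arrays, maps and rows. A simplified analogue on plain Java values (java.time and standard collections replace Trino's Sql* values and Type hierarchy; illustrative only):

    import java.time.LocalDateTime;
    import java.time.temporal.ChronoUnit;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;

    final class LegacyPrecision
    {
        private LegacyPrecision() {}

        // roundToMillis(List.of(LocalDateTime.parse("2024-02-12T13:12:39.123456789")))
        // yields [2024-02-12T13:12:39.123]
        static Object roundToMillis(Object value)
        {
            return switch (value) {
                case null -> null;
                case LocalDateTime timestamp -> timestamp.truncatedTo(ChronoUnit.MILLIS);
                case List<?> list -> list.stream()
                        .map(LegacyPrecision::roundToMillis)
                        .collect(Collectors.toList());
                case Map<?, ?> map -> {
                    Map<Object, Object> result = new LinkedHashMap<>();
                    map.forEach((key, val) -> result.put(roundToMillis(key), roundToMillis(val)));
                    yield result;
                }
                default -> value; // non-temporal scalars pass through unchanged
            };
        }
    }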
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol; + +import io.trino.spi.type.Type; + +import static java.util.Objects.requireNonNull; + +public record OutputColumn(int sourcePageChannel, String columnName, Type type) +{ + public OutputColumn + { + requireNonNull(columnName, "columnName is null"); + requireNonNull(type, "type is null"); + + if (sourcePageChannel < 0) { + throw new IllegalArgumentException("sourcePageChannel is negative"); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java index eb9063c04612..c8f245488d43 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/Query.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/Query.java @@ -47,6 +47,7 @@ import io.trino.server.ExternalUriInfo; import io.trino.server.GoneException; import io.trino.server.ResultQueryInfo; +import io.trino.server.protocol.spooling.QueryDataProducer; import io.trino.spi.ErrorCode; import io.trino.spi.Page; import io.trino.spi.QueryId; @@ -103,6 +104,10 @@ class Query @GuardedBy("this") private final ExchangeDataSource exchangeDataSource; + + @GuardedBy("this") + private final QueryDataProducer queryDataProducer; + @GuardedBy("this") private ListenableFuture exchangeDataSourceBlocked; @@ -176,6 +181,7 @@ public static Query create( Session session, Slug slug, QueryManager queryManager, + QueryDataProducer queryDataProducer, Optional queryInfoUrl, DirectExchangeClientSupplier directExchangeClientSupplier, ExchangeManagerRegistry exchangeManagerRegistry, @@ -193,7 +199,7 @@ public static Query create( getRetryPolicy(session), exchangeManagerRegistry); - Query result = new Query(session, slug, queryManager, queryInfoUrl, exchangeDataSource, dataProcessorExecutor, timeoutExecutor, blockEncodingSerde); + Query result = new Query(session, slug, queryManager, queryDataProducer, queryInfoUrl, exchangeDataSource, dataProcessorExecutor, timeoutExecutor, blockEncodingSerde); result.queryManager.setOutputInfoListener(result.getQueryId(), result::setQueryOutputInfo); @@ -213,6 +219,7 @@ private Query( Session session, Slug slug, QueryManager queryManager, + QueryDataProducer queryDataProducer, Optional queryInfoUrl, ExchangeDataSource exchangeDataSource, Executor resultsProcessorExecutor, @@ -222,6 +229,7 @@ private Query( requireNonNull(session, "session is null"); requireNonNull(slug, "slug is null"); requireNonNull(queryManager, "queryManager is null"); + requireNonNull(queryDataProducer, "queryDataProducer is null"); requireNonNull(queryInfoUrl, "queryInfoUrl is null"); requireNonNull(exchangeDataSource, "exchangeDataSource is null"); requireNonNull(resultsProcessorExecutor, "resultsProcessorExecutor is null"); @@ -229,6 +237,7 @@ private Query( requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.queryManager = queryManager; + this.queryDataProducer = queryDataProducer; this.queryId = session.getQueryId(); this.session = session; this.slug = slug; @@ -496,7 +505,7 @@ private synchronized QueryResultsResponse getNextResult(long token, ExternalUriI partialCancelUri, nextResultsUri, resultRows.getColumns().orElse(null), - resultRows.isEmpty() ? 
null : resultRows, // client excepts null that indicates "no data" + queryDataProducer.produce(externalUriInfo, session, resultRows, this::handleSerializationException), toStatementStats(queryInfo), toQueryError(queryInfo, typeSerializationException), mappedCopy(queryInfo.warnings(), ProtocolUtil::toClientWarning), @@ -542,9 +551,6 @@ private synchronized QueryResultRows removePagesFromExchange(ResultQueryInfo que // last page is removed. If another thread observes this state before the response is cached // the pages will be lost. QueryResultRows.Builder resultBuilder = queryResultRowsBuilder(session) - // Intercept serialization exceptions and fail query if it's still possible. - // Put serialization exception aside to return failed query result. - .withExceptionConsumer(this::handleSerializationException) .withColumnsAndTypes(columns, types); try { diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultRows.java b/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultRows.java index 2fcaab646aa2..c9ada5766ad1 100644 --- a/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultRows.java +++ b/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultRows.java @@ -14,88 +14,75 @@ package io.trino.server.protocol; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.AbstractIterator; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Maps; import io.trino.Session; -import io.trino.client.ClientCapabilities; import io.trino.client.Column; import io.trino.spi.Page; -import io.trino.spi.TrinoException; -import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.connector.ConnectorSession; -import io.trino.spi.type.ArrayType; -import io.trino.spi.type.MapType; -import io.trino.spi.type.RowType; -import io.trino.spi.type.SqlTime; -import io.trino.spi.type.SqlTimeWithTimeZone; -import io.trino.spi.type.SqlTimestamp; -import io.trino.spi.type.SqlTimestampWithTimeZone; -import io.trino.spi.type.TimeType; -import io.trino.spi.type.TimeWithTimeZoneType; -import io.trino.spi.type.TimestampType; -import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; import jakarta.annotation.Nullable; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Deque; -import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.function.Consumer; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; -import static io.trino.spi.StandardErrorCode.SERIALIZATION_ERROR; +import static io.trino.server.protocol.ProtocolUtil.createColumn; +import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_TYPE; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; -import static java.lang.String.format; -import static java.util.Collections.unmodifiableList; -import static java.util.Collections.unmodifiableMap; import static java.util.Objects.requireNonNull; public class QueryResultRows - implements Iterable> { - private final ConnectorSession session; - private final Optional> columns; + private final Session session; + 
private final Optional> columns; private final List pages; - private final Optional> exceptionConsumer; private final long totalRows; - private final boolean supportsParametricDateTime; - private QueryResultRows(Session session, Optional> columns, List pages, Consumer exceptionConsumer) + private QueryResultRows(Session session, Optional> columns, List pages) { - this.session = session.toConnectorSession(); - this.columns = requireNonNull(columns, "columns is null"); + this.session = requireNonNull(session, "session is null"); + this.columns = requireNonNull(columns, "columns is null").map(values -> values.stream() + .filter(column -> !isSpooledMetadataColumn(column)) + .collect(toImmutableList())); this.pages = ImmutableList.copyOf(pages); - this.exceptionConsumer = Optional.ofNullable(exceptionConsumer); this.totalRows = countRows(pages); - this.supportsParametricDateTime = session.getClientCapabilities().contains(ClientCapabilities.PARAMETRIC_DATETIME.toString()); verify(totalRows == 0 || (totalRows > 0 && columns.isPresent()), "data present without columns and types"); } + private boolean isSpooledMetadataColumn(OutputColumn column) + { + return column.type().equals(SPOOLING_METADATA_TYPE); + } + public boolean isEmpty() { return totalRows == 0; } + public Optional> getOutputColumns() + { + return this.columns; + } + public Optional> getColumns() { - return columns.map(columns -> columns.stream() - .map(ColumnAndType::getColumn) + return columns + .map(columns -> columns.stream() + .map(value -> createColumn(value.columnName(), value.type(), true)) .collect(toImmutableList())); } + public List getPages() + { + return this.pages; + } + /** * Returns expected row count (we don't know yet if every row is serializable). */ @@ -112,24 +99,18 @@ public Optional getUpdateCount() return Optional.empty(); } - List columns = this.columns.get(); + List columns = this.columns.get(); - if (columns.size() != 1 || !columns.get(0).getType().equals(BIGINT)) { + if (columns.size() != 1 || !columns.get(0).type().equals(BIGINT)) { return Optional.empty(); } checkState(!pages.isEmpty(), "no data pages available"); - Number value = (Number) columns.get(0).getType().getObjectValue(session, pages.get(0).getBlock(0), 0); + Number value = (Number) columns.get(0).type().getObjectValue(session.toConnectorSession(), pages.getFirst().getBlock(0), 0); return Optional.ofNullable(value).map(Number::longValue); } - @Override - public Iterator> iterator() - { - return new ResultsIterator(this); - } - private static long countRows(List pages) { long rows = 0; @@ -151,7 +132,7 @@ public String toString() public static QueryResultRows empty(Session session) { - return new QueryResultRows(session, Optional.empty(), ImmutableList.of(), null); + return new QueryResultRows(session, Optional.empty(), ImmutableList.of()); } public static Builder queryResultRowsBuilder(Session session) @@ -163,8 +144,7 @@ public static class Builder { private final Session session; private ImmutableList.Builder pages = ImmutableList.builder(); - private Optional> columns = Optional.empty(); - private Consumer exceptionConsumer; + private Optional> columns = Optional.empty(); public Builder(Session session) { @@ -202,226 +182,26 @@ public Builder withSingleBooleanValue(Column column, boolean value) return this; } - public Builder withExceptionConsumer(Consumer exceptionConsumer) - { - this.exceptionConsumer = exceptionConsumer; - return this; - } - public QueryResultRows build() { return new QueryResultRows( session, columns, - pages.build(), - 
exceptionConsumer); + pages.build()); } - private static List combine(@Nullable List columns, @Nullable List types) + private static List combine(@Nullable List columns, @Nullable List types) { checkArgument(columns != null && types != null, "columns and types must be present at the same time"); checkArgument(columns.size() == types.size(), "columns and types size mismatch"); - ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(columns.size()); + ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(columns.size()); for (int i = 0; i < columns.size(); i++) { - builder.add(new ColumnAndType(i, columns.get(i), types.get(i))); + builder.add(new OutputColumn(i, columns.get(i).getName(), types.get(i))); } return builder.build(); } } - - private static class ColumnAndType - { - private final int position; - private final Column column; - private final Type type; - - private ColumnAndType(int position, Column column, Type type) - { - this.position = position; - this.column = column; - this.type = type; - } - - public Column getColumn() - { - return column; - } - - public Type getType() - { - return type; - } - - public int getPosition() - { - return position; - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("column", column) - .add("type", type) - .add("position", position) - .toString(); - } - } - - private static class ResultsIterator - extends AbstractIterator> - { - private final Deque queue; - private final QueryResultRows results; - private Page currentPage; - private int rowPosition = -1; - private int inPageIndex = -1; - - public ResultsIterator(QueryResultRows results) - { - this.queue = new ArrayDeque<>(results.pages); - this.results = results; - this.currentPage = queue.pollFirst(); - } - - @Override - protected List computeNext() - { - while (true) { - if (currentPage == null) { - return endOfData(); - } - - inPageIndex++; - - if (inPageIndex >= currentPage.getPositionCount()) { - currentPage = queue.pollFirst(); - - if (currentPage == null) { - return endOfData(); - } - - inPageIndex = 0; - } - - rowPosition++; - - List row = getRowValues(); - if (row != null) { - // row is not skipped, return it - return row; - } - } - } - - @Nullable - private List getRowValues() - { - // types are present if data is present - List columns = results.columns.orElseThrow(); - Object[] row = new Object[currentPage.getChannelCount()]; - - for (int channel = 0; channel < currentPage.getChannelCount(); channel++) { - ColumnAndType column = columns.get(channel); - Type type = column.getType(); - Block block = currentPage.getBlock(channel); - - try { - Object value = type.getObjectValue(results.session, block, inPageIndex); - if (!results.supportsParametricDateTime) { - value = getLegacyValue(value, type); - } - row[channel] = value; - } - catch (Throwable throwable) { - propagateException(rowPosition, column, throwable); - // skip row as it contains non-serializable value - return null; - } - } - - return unmodifiableList(Arrays.asList(row)); - } - - private Object getLegacyValue(Object value, Type type) - { - if (value == null) { - return null; - } - - if (!results.supportsParametricDateTime) { - // for legacy clients we need to round timestamp and timestamp with timezone to default precision (3) - - if (type instanceof TimestampType) { - return ((SqlTimestamp) value).roundTo(3); - } - - if (type instanceof TimestampWithTimeZoneType) { - return ((SqlTimestampWithTimeZone) value).roundTo(3); - } - - if (type instanceof TimeType) { - 
return ((SqlTime) value).roundTo(3); - } - - if (type instanceof TimeWithTimeZoneType) { - return ((SqlTimeWithTimeZone) value).roundTo(3); - } - } - - if (type instanceof ArrayType) { - Type elementType = ((ArrayType) type).getElementType(); - - if (!(elementType instanceof TimestampType || elementType instanceof TimestampWithTimeZoneType)) { - return value; - } - - List listValue = (List) value; - List legacyValues = new ArrayList<>(listValue.size()); - for (Object element : listValue) { - legacyValues.add(getLegacyValue(element, elementType)); - } - - return unmodifiableList(legacyValues); - } - - if (type instanceof MapType) { - Type keyType = ((MapType) type).getKeyType(); - Type valueType = ((MapType) type).getValueType(); - - Map mapValue = (Map) value; - Map result = Maps.newHashMapWithExpectedSize(mapValue.size()); - mapValue.forEach((key, val) -> result.put(getLegacyValue(key, keyType), getLegacyValue(val, valueType))); - return unmodifiableMap(result); - } - - if (type instanceof RowType) { - List fields = ((RowType) type).getFields(); - List values = (List) value; - - List result = new ArrayList<>(values.size()); - for (int i = 0; i < values.size(); i++) { - result.add(getLegacyValue(values.get(i), fields.get(i).getType())); - } - return unmodifiableList(result); - } - - return value; - } - - private void propagateException(int row, ColumnAndType column, Throwable cause) - { - // columns and rows are 0-indexed - String message = format("Could not serialize column '%s' of type '%s' at position %d:%d", - column.getColumn().getName(), - column.getType(), - row + 1, - column.getPosition() + 1); - - results.exceptionConsumer.ifPresent(consumer -> consumer.accept(new TrinoException(SERIALIZATION_ERROR, message, cause))); - } - } } diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/DataAttributesSerialization.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/DataAttributesSerialization.java new file mode 100644 index 000000000000..a969f6850446 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/DataAttributesSerialization.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
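The DataAttributesSerialization helper introduced in the next hunk round-trips segment attributes through a compact comma/colon key-value format built on Guava's Joiner and Splitter. A usage sketch with hypothetical attribute keys (the real keys come from DataAttributes.toMap()):

    import com.google.common.base.Joiner;
    import com.google.common.base.Splitter;

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class AttributeRoundTrip
    {
        public static void main(String[] args)
        {
            Map<String, String> attributes = new LinkedHashMap<>();
            attributes.put("rowsCount", "100");   // hypothetical key names
            attributes.put("byteSize", "2048");

            // serialize: "rowsCount:100,byteSize:2048"
            String serialized = Joiner.on(",").withKeyValueSeparator(":").join(attributes);

            // deserialize back into a map
            Map<String, String> parsed = Splitter.on(",")
                    .trimResults()
                    .omitEmptyStrings()
                    .withKeyValueSeparator(":")
                    .split(serialized);

            System.out.println(serialized);
            System.out.println(parsed.equals(attributes)); // true
        }
    }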
+ */ +package io.trino.server.protocol.spooling; + +import com.google.common.base.Joiner; +import com.google.common.base.Splitter; +import io.trino.client.spooling.DataAttributes; + +public class DataAttributesSerialization +{ + private static final Joiner.MapJoiner JOINER = Joiner.on(",").withKeyValueSeparator(":"); + private static final Splitter.MapSplitter SPLITTER = Splitter.on(",") + .trimResults() + .omitEmptyStrings() + .withKeyValueSeparator(":"); + + private DataAttributesSerialization() {} + + public static String serialize(DataAttributes attributes) + { + return JOINER.join(attributes.toMap()); + } + + public static DataAttributes deserialize(String values) + { + DataAttributes.Builder builder = DataAttributes.builder(); + SPLITTER.split(values).forEach(builder::set); + return builder.build(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/PreferredQueryDataEncoderSelector.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/PreferredQueryDataEncoderSelector.java new file mode 100644 index 000000000000..15fa7a492231 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/PreferredQueryDataEncoderSelector.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
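The PreferredQueryDataEncoderSelector that follows treats the client-supplied encoding header as an ordered, comma-separated preference list and picks the first encoding the server knows about. A minimal sketch of that negotiation (the encoding identifiers used here are illustrative, not taken from the patch):

    import java.util.Optional;
    import java.util.Set;

    public class EncodingNegotiation
    {
        static Optional<String> negotiate(String preferenceHeader, Set<String> supported)
        {
            // First supported entry in the client's preference order wins
            for (String encodingId : preferenceHeader.split(",")) {
                String candidate = encodingId.trim();
                if (supported.contains(candidate)) {
                    return Optional.of(candidate);
                }
            }
            // Nothing matched: the server falls back to the direct protocol
            return Optional.empty();
        }

        public static void main(String[] args)
        {
            Set<String> supported = Set.of("json", "json+zstd");
            // Client prefers json+lz4, then json+zstd, then plain json
            System.out.println(negotiate("json+lz4,json+zstd,json", supported)); // Optional[json+zstd]
        }
    }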
+ */ +package io.trino.server.protocol.spooling; + +import com.google.inject.Inject; +import io.airlift.log.Logger; +import io.trino.server.protocol.spooling.encoding.EncryptingQueryDataEncoder; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public class PreferredQueryDataEncoderSelector + implements QueryDataEncoderSelector +{ + private final Logger log = Logger.get(PreferredQueryDataEncoderSelector.class); + + private final Map encoder; + private final SpoolingManagerRegistry spoolingManagerRegistry; + + @Inject + public PreferredQueryDataEncoderSelector(Set factories, SpoolingManagerRegistry spoolingManagerRegistry) + { + this.encoder = requireNonNull(factories, "factories is null").stream() + .map(EncryptingQueryDataEncoder.Factory::new) + .collect(toImmutableMap(QueryDataEncoder.Factory::encodingId, identity())); + this.spoolingManagerRegistry = requireNonNull(spoolingManagerRegistry, "spoolingManagerRegistry is null"); + } + + @Override + public Optional select(String encodingHeader) + { + if (spoolingManagerRegistry.getSpoolingManager().isEmpty()) { + log.debug("Client requested spooled encoding '%s' but spooling is disabled", encodingHeader); + return Optional.empty(); + } + + for (String encodingId : encodingHeader.split(",")) { + QueryDataEncoder.Factory factory = encoder.get(encodingId); + if (factory != null) { + return Optional.of(factory); + } + } + log.debug("None of the preferred spooled encodings `%s` are known and supported by the server", encodingHeader); + return Optional.empty(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataEncoder.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataEncoder.java new file mode 100644 index 000000000000..2fa73872ed57 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataEncoder.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling; + +import io.trino.Session; +import io.trino.client.spooling.DataAttributes; +import io.trino.server.protocol.OutputColumn; +import io.trino.spi.Page; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; + +public interface QueryDataEncoder +{ + interface Factory + { + QueryDataEncoder create(Session session, List columns); + + String encodingId(); + } + + DataAttributes encodeTo(OutputStream output, List pages) + throws IOException; + + String encodingId(); + + /** + * Returns additional attributes that are passed to the QueryDataDecoder.Factory.create method. 
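The QueryDataEncoder contract above boils down to: write a batch of pages to an output stream and report metadata (such as row and byte counts) for the resulting segment. A simplified stand-in, not the Trino SPI, showing that shape on plain rows; all names here are invented for illustration:

    import java.io.IOException;
    import java.io.OutputStream;
    import java.nio.charset.StandardCharsets;
    import java.util.List;

    interface RowBatchEncoder
    {
        record SegmentStats(long rows, long encodedBytes) {}

        // Writes the rows to the stream and returns metadata describing the segment
        SegmentStats encodeTo(OutputStream output, List<List<Object>> rows)
                throws IOException;

        String encodingId();
    }

    class ToStringLineEncoder
            implements RowBatchEncoder
    {
        // Usage: new ToStringLineEncoder().encodeTo(System.out, List.of(List.of(1, "a")))
        @Override
        public SegmentStats encodeTo(OutputStream output, List<List<Object>> rows)
                throws IOException
        {
            long bytes = 0;
            for (List<Object> row : rows) {
                byte[] line = (row.toString() + "\n").getBytes(StandardCharsets.UTF_8);
                output.write(line);
                bytes += line.length;
            }
            return new SegmentStats(rows.size(), bytes);
        }

        @Override
        public String encodingId()
        {
            return "text-lines"; // hypothetical identifier
        }
    }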
+ */ + default DataAttributes attributes() + { + return DataAttributes.empty(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataEncoderSelector.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataEncoderSelector.java new file mode 100644 index 000000000000..9b541885bcf1 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataEncoderSelector.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling; + +import java.util.Optional; + +/** + * Responsible for choosing a QueryDataEncoder implementation based on a query Session. + * + * Important: needs to be stable across runs as it will be executed on multiple nodes participating in query execution + * and output generation. + * + * Returning Optional.empty() falls back to the direct protocol. + */ +@FunctionalInterface +public interface QueryDataEncoderSelector +{ + Optional select(String encodingId); +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataJacksonModule.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataJacksonModule.java new file mode 100644 index 000000000000..fd3b384c12ac --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataJacksonModule.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling; + +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.Version; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.jsontype.TypeSerializer; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.fasterxml.jackson.databind.ser.BeanSerializerFactory; +import com.fasterxml.jackson.databind.ser.std.StdSerializer; +import io.trino.client.QueryData; +import io.trino.client.RawQueryData; +import io.trino.client.spooling.EncodedQueryData; +import io.trino.client.spooling.InlineSegment; +import io.trino.client.spooling.Segment; +import io.trino.client.spooling.SpooledSegment; + +import java.io.IOException; + +/** + * Encodes the QueryData for existing raw and encoded protocols. +

+ * + * If the passed QueryData is raw - serialize its' data as a materialized array of array of objects. + * Otherwise, this is a protocol extension and serialize it directly as an object. + */ +public class QueryDataJacksonModule + extends SimpleModule +{ + public QueryDataJacksonModule() + { + super(QueryDataJacksonModule.class.getSimpleName(), Version.unknownVersion()); + addSerializer(QueryData.class, new Serializer()); + addSerializer(Segment.class, new SegmentSerializer()); + registerSubtypes(InlineSegment.class); + registerSubtypes(SpooledSegment.class); + } + + private static class Serializer + extends StdSerializer + { + public Serializer() + { + super(QueryData.class); + } + + @Override + public void serialize(QueryData value, JsonGenerator generator, SerializerProvider provider) + throws IOException + { + switch (value) { + case null -> provider.defaultSerializeNull(generator); + case RawQueryData ignored -> provider.defaultSerializeValue(value.getData(), generator); + case EncodedQueryData encoded -> createSerializer(provider, provider.constructType(EncodedQueryData.class)).serialize(encoded, generator, provider); + default -> throw new IllegalArgumentException("Unsupported QueryData implementation: " + value.getClass().getSimpleName()); + } + } + + @Override + public boolean isEmpty(SerializerProvider provider, QueryData value) + { + // Important for compatibility with some clients that assume absent data field if data is null + return value == null || (value instanceof RawQueryData && value.getData() == null); + } + } + + private static class SegmentSerializer + extends StdSerializer + { + protected SegmentSerializer() + { + super(Segment.class); + } + + @Override + public void serialize(Segment value, JsonGenerator gen, SerializerProvider provider) + throws IOException + { + createSerializer(provider, provider.constructSpecializedType(provider.constructType(Segment.class), value.getClass())).serializeWithType(value, gen, provider, segmentSerializer(provider)); + } + + private static TypeSerializer segmentSerializer(SerializerProvider provider) + { + try { + return provider.findTypeSerializer(provider.constructType(Segment.class)); + } + catch (JsonMappingException e) { + throw new RuntimeException(e); + } + } + } + + @SuppressWarnings("unchecked") + private static JsonSerializer createSerializer(SerializerProvider provider, JavaType javaType) + throws JsonMappingException + { + return (JsonSerializer) BeanSerializerFactory.instance.createSerializer(provider, javaType); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataProducer.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataProducer.java new file mode 100644 index 000000000000..b2adafb0bf0d --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/QueryDataProducer.java @@ -0,0 +1,113 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
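The QueryDataProducer in the next hunk tags every segment with the absolute row offset at which it starts, accumulating ROWS_COUNT for spooled segments and position counts for inlined ones. A minimal sketch of that bookkeeping, assuming only per-segment row counts are known:

    import java.util.ArrayList;
    import java.util.List;

    public class RowOffsets
    {
        record SegmentInfo(long rowOffset, long rowCount) {}

        static List<SegmentInfo> assignOffsets(List<Long> segmentRowCounts)
        {
            List<SegmentInfo> segments = new ArrayList<>();
            long currentOffset = 0;
            for (long rows : segmentRowCounts) {
                segments.add(new SegmentInfo(currentOffset, rows));
                currentOffset += rows; // the next segment starts where this one ends
            }
            return segments;
        }

        public static void main(String[] args)
        {
            // Three segments of 100, 250 and 40 rows start at offsets 0, 100 and 350
            System.out.println(assignOffsets(List.of(100L, 250L, 40L)));
        }
    }

Keeping the offsets on the segments lets a client reassemble the result in order, and skip or retry individual segments, without the server having to replay anything.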
+ */ +package io.trino.server.protocol.spooling; + +import io.trino.Session; +import io.trino.client.QueryData; +import io.trino.client.RawQueryData; +import io.trino.client.spooling.DataAttributes; +import io.trino.client.spooling.EncodedQueryData; +import io.trino.client.spooling.Segment; +import io.trino.server.ExternalUriInfo; +import io.trino.server.protocol.OutputColumn; +import io.trino.server.protocol.QueryResultRows; +import io.trino.spi.Page; +import jakarta.ws.rs.core.UriBuilder; + +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; + +import static io.trino.client.spooling.DataAttribute.ROWS_COUNT; +import static io.trino.client.spooling.DataAttribute.ROW_OFFSET; +import static io.trino.server.protocol.JsonArrayResultsIterator.toIterableList; +import static io.trino.server.protocol.spooling.SegmentResource.spooledSegmentUriBuilder; +import static java.util.Objects.requireNonNull; + +public class QueryDataProducer +{ + private final Optional encoderFactory; + private long currentOffset; + + private AtomicBoolean metadataWritten = new AtomicBoolean(false); + + public QueryDataProducer(Optional encoderFactory) + { + this.encoderFactory = requireNonNull(encoderFactory, "encoderFactory is null"); + } + + public QueryData produce(ExternalUriInfo uriInfo, Session session, QueryResultRows rows, Consumer throwableConsumer) + { + if (rows.isEmpty()) { + return null; + } + + if (encoderFactory.isEmpty()) { + return RawQueryData.of(toIterableList(session, rows, throwableConsumer)); + } + + UriBuilder uriBuilder = spooledSegmentUriBuilder(uriInfo); + QueryDataEncoder encoder = encoderFactory.get().create(session, rows.getOutputColumns().orElseThrow()); + EncodedQueryData.Builder builder = EncodedQueryData.builder(encoder.encodingId()); + List outputColumns = rows.getOutputColumns().orElseThrow(); + + if (metadataWritten.compareAndSet(false, true)) { + // Attributes are emitted only once for the first segment + builder.withAttributes(encoder.attributes()); + } + + try { + for (Page page : rows.getPages()) { + if (hasSpoolingMetadata(page, outputColumns)) { + SpooledBlock metadata = SpooledBlock.deserialize(page); + DataAttributes attributes = metadata.attributes().toBuilder() + .set(ROW_OFFSET, currentOffset) + .build(); + + builder.withSegment(Segment.spooled(buildSegmentURI(uriBuilder, metadata.segmentHandle()), attributes)); + currentOffset += attributes.get(ROWS_COUNT, Long.class); + } + else { + try (ByteArrayOutputStream output = new ByteArrayOutputStream()) { + DataAttributes attributes = encoder.encodeTo(output, List.of(page)) + .toBuilder() + .set(ROW_OFFSET, currentOffset) + .build(); + builder.withSegment(Segment.inlined(output.toByteArray(), attributes)); + } + currentOffset += page.getPositionCount(); + } + } + } + catch (Exception e) { + throwableConsumer.accept(e); + return null; + } + + return builder.build(); + } + + private URI buildSegmentURI(UriBuilder builder, String segmentHandle) + { + return builder.clone().build(segmentHandle); + } + + private boolean hasSpoolingMetadata(Page page, List outputColumns) + { + return page.getChannelCount() == outputColumns.size() + 1 && page.getPositionCount() == 1 && !page.getBlock(outputColumns.size()).isNull(0); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SegmentResource.java 
b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SegmentResource.java new file mode 100644 index 000000000000..25c5249df315 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SegmentResource.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Inject; +import io.trino.metadata.InternalNode; +import io.trino.metadata.InternalNodeManager; +import io.trino.server.ExternalUriInfo; +import io.trino.server.security.ResourceSecurity; +import io.trino.spi.HostAddress; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriBuilder; +import jakarta.ws.rs.core.UriInfo; + +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; + +import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +@Path("/v1/spooled/segments/{segmentHandle}") +@ResourceSecurity(PUBLIC) +public class SegmentResource +{ + private final SpoolingManagerBridge spoolingManager; + private final boolean useWorkers; + private final InternalNodeManager nodeManager; + private final AtomicLong nextWorkerIndex = new AtomicLong(); + private final boolean directStorageAccess; + + @Inject + public SegmentResource(SpoolingManagerBridge spoolingManager, SpoolingConfig config, InternalNodeManager nodeManager) + { + this.spoolingManager = requireNonNull(spoolingManager, "spoolingManager is null"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.useWorkers = config.isUseWorkers() && nodeManager.getCurrentNode().isCoordinator(); + this.directStorageAccess = config.isDirectStorageAccess(); + } + + @GET + @Produces(MediaType.APPLICATION_OCTET_STREAM) + @ResourceSecurity(PUBLIC) + @Path("") + public Response download(@Context UriInfo uriInfo, @PathParam("segmentHandle") String segmentHandle) + throws IOException + { + if (directStorageAccess) { + Optional location = spoolingManager.directLocation(segmentHandle); + if (location.isPresent()) { + return Response.seeOther(location.get()).build(); + } + } + + if (useWorkers) { + HostAddress hostAddress = nextActiveNode(); + return Response.seeOther(uriInfo + .getRequestUriBuilder() + .host(hostAddress.getHostText()) + .port(hostAddress.getPort()) + .build()) + .build(); + } + return Response.ok(spoolingManager.openInputStream(segmentHandle)).build(); + } + + @DELETE + @ResourceSecurity(PUBLIC) + @Path("") + public Response acknowledge(@PathParam("segmentHandle") String segmentHandle) + { + try { + spoolingManager.drop(segmentHandle); + return 
Response.ok().build(); + } + catch (IOException e) { + return Response.serverError().build(); + } + } + + public static UriBuilder spooledSegmentUriBuilder(ExternalUriInfo info) + { + return UriBuilder.fromUri(info.baseUriBuilder().build()) + .path(SegmentResource.class) + .path(SegmentResource.class, "download"); + } + + public HostAddress nextActiveNode() + { + List internalNodes = ImmutableList.copyOf(nodeManager.getActiveNodesSnapshot().getAllNodes()); + return internalNodes.get(toIntExact(nextWorkerIndex.incrementAndGet() % internalNodes.size())) + .getHostAndPort(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpooledBlock.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpooledBlock.java new file mode 100644 index 000000000000..e488a7ea49b9 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpooledBlock.java @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling; + +import io.trino.client.spooling.DataAttributes; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RowBlockBuilder; +import io.trino.spi.block.SqlRow; +import io.trino.spi.type.RowType; +import io.trino.sql.planner.Symbol; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.spi.type.VarcharType.VARCHAR; + +public record SpooledBlock(String segmentHandle, DataAttributes attributes) +{ + public static final RowType SPOOLING_METADATA_TYPE = RowType.from(List.of( + new RowType.Field(Optional.of("identifier"), VARCHAR), + new RowType.Field(Optional.of("metadata"), VARCHAR))); + + public static final String SPOOLING_METADATA_COLUMN_NAME = "$spooling:metadata$"; + public static final Symbol SPOOLING_METADATA_SYMBOL = new Symbol(SPOOLING_METADATA_TYPE, SPOOLING_METADATA_COLUMN_NAME); + + public static SpooledBlock deserialize(Page page) + { + verify(page.getPositionCount() == 1, "Spooling metadata block must have a single position"); + verify(hasMetadataBlock(page), "Spooling metadata block must have all but last channels null"); + SqlRow row = SPOOLING_METADATA_TYPE.getObject(page.getBlock(page.getChannelCount() - 1), 0); + + return new SpooledBlock( + VARCHAR.getSlice(row.getRawFieldBlock(0), 0).toStringUtf8(), + DataAttributesSerialization.deserialize(VARCHAR.getSlice(row.getRawFieldBlock(1), 0).toStringUtf8())); + } + + public Block serialize() + { + RowBlockBuilder rowBlockBuilder = SPOOLING_METADATA_TYPE.createBlockBuilder(null, 1); + rowBlockBuilder.buildEntry(rowEntryBuilder -> { + VARCHAR.writeSlice(rowEntryBuilder.get(0), utf8Slice(segmentHandle)); + VARCHAR.writeSlice(rowEntryBuilder.get(1), utf8Slice(DataAttributesSerialization.serialize(attributes))); + }); + return rowBlockBuilder.build(); + } + + public static Page createNonSpooledPage(Page page) + { + RowBlockBuilder rowBlockBuilder = 
SPOOLING_METADATA_TYPE.createBlockBuilder(null, page.getPositionCount()); + for (int i = 0; i < page.getPositionCount(); i++) { + rowBlockBuilder.appendNull(); + } + return page.appendColumn(rowBlockBuilder.build()); + } + + private static boolean hasMetadataBlock(Page page) + { + for (int channel = 0; channel < page.getChannelCount() - 1; channel++) { + if (!page.getBlock(channel).isNull(0)) { + return false; + } + } + return true; + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingConfig.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingConfig.java new file mode 100644 index 000000000000..c305433abef2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingConfig.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.airlift.configuration.ConfigSecuritySensitive; +import io.airlift.units.DataSize; +import jakarta.validation.constraints.AssertTrue; + +import javax.crypto.spec.SecretKeySpec; + +import static io.airlift.units.DataSize.Unit.MEGABYTE; +import static io.trino.util.Ciphers.is256BitSecretKeySpec; +import static java.util.Base64.getDecoder; + +public class SpoolingConfig +{ + private boolean enabled; + private boolean useWorkers; + private boolean directStorageAccess; + + private boolean inlineSegments = true; + private boolean encryptionEnabled = true; + + private DataSize initialSegmentSize = DataSize.of(8, MEGABYTE); + private DataSize maximumSegmentSize = DataSize.of(16, MEGABYTE); + + private SecretKeySpec encryptionKey; + + public boolean isEnabled() + { + return enabled; + } + + @Config("protocol.spooling.enabled") + @ConfigDescription("Enable spooling client protocol server-side support") + public SpoolingConfig setEnabled(boolean enabled) + { + this.enabled = enabled; + return this; + } + + public boolean isUseWorkers() + { + return useWorkers; + } + + @Config("protocol.spooling.worker-access") + @ConfigDescription("Use worker nodes to retrieve data from spooling location") + public SpoolingConfig setUseWorkers(boolean useWorkers) + { + this.useWorkers = useWorkers; + return this; + } + + public boolean isDirectStorageAccess() + { + return directStorageAccess; + } + + @Config("protocol.spooling.direct-storage-access") + @ConfigDescription("Allow clients to directly access spooled segments (if supported by spooling manager)") + public SpoolingConfig setDirectStorageAccess(boolean directStorageAccess) + { + this.directStorageAccess = directStorageAccess; + return this; + } + + public DataSize getInitialSegmentSize() + { + return initialSegmentSize; + } + + @Config("protocol.spooling.initial-segment-size") + @ConfigDescription("Initial size of the spooled segments in bytes") + public void setInitialSegmentSize(DataSize initialSegmentSize) + { + this.initialSegmentSize = initialSegmentSize; + } 
+ + public DataSize getMaximumSegmentSize() + { + return maximumSegmentSize; + } + + @Config("protocol.spooling.maximum-segment-size") + @ConfigDescription("Maximum size of the spooled segments in bytes") + public void setMaximumSegmentSize(DataSize maximumSegmentSize) + { + this.maximumSegmentSize = maximumSegmentSize; + } + + public boolean isInlineSegments() + { + return inlineSegments; + } + + @ConfigDescription("Allow protocol to inline data") + @Config("protocol.spooling.use-inline-segments") + public void setInlineSegments(boolean inlineSegments) + { + this.inlineSegments = inlineSegments; + } + + public boolean isEncryptionEnabled() + { + return encryptionEnabled; + } + + @ConfigDescription("Encrypt spooled segments using random, ephemeral keys generated for the duration of the query") + @Config("protocol.spooling.encryption") + public void setEncryptionEnabled(boolean encryptionEnabled) + { + this.encryptionEnabled = encryptionEnabled; + } + + public SecretKeySpec getEncryptionKey() + { + return encryptionKey; + } + + @ConfigDescription("256 bit, base64-encoded secret key used to secure segment identifiers") + @Config("protocol.spooling.encryption-key") + @ConfigSecuritySensitive + public void setEncryptionKey(String encryptionKey) + { + this.encryptionKey = new SecretKeySpec(getDecoder().decode(encryptionKey), "AES"); + } + + @AssertTrue(message = "protocol.spooling.encryption-key must be 256 bits long") + public boolean isEncryptionKeyAes256() + { + return encryptionKey == null || is256BitSecretKeySpec(encryptionKey); + } + + @AssertTrue(message = "protocol.spooling.encryption-key must be set if spooling is enabled") + public boolean isEncryptionKeySet() + { + if (!enabled) { + return true; + } + return encryptionKey != null; + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingManagerBridge.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingManagerBridge.java new file mode 100644 index 000000000000..d171fb5cc934 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingManagerBridge.java @@ -0,0 +1,163 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.server.protocol.spooling; + +import com.google.inject.Inject; +import io.airlift.slice.Slice; +import io.airlift.units.DataSize; +import io.trino.spi.protocol.SpooledSegmentHandle; +import io.trino.spi.protocol.SpoolingContext; +import io.trino.spi.protocol.SpoolingManager; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.spec.SecretKeySpec; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Optional; + +import static io.airlift.slice.Slices.wrappedBuffer; +import static java.util.Base64.getUrlDecoder; +import static java.util.Base64.getUrlEncoder; +import static java.util.Objects.requireNonNull; +import static javax.crypto.Cipher.DECRYPT_MODE; +import static javax.crypto.Cipher.ENCRYPT_MODE; + +public class SpoolingManagerBridge +{ + private final SpoolingManagerRegistry registry; + private final DataSize initialSegmentSize; + private final DataSize maximumSegmentSize; + private final boolean inlineSegments; + private final SecretKeySpec secretKey; + + @Inject + public SpoolingManagerBridge(SpoolingConfig spoolingConfig, SpoolingManagerRegistry registry) + { + this.registry = requireNonNull(registry, "registry is null"); + requireNonNull(spoolingConfig, "spoolingConfig is null"); + this.initialSegmentSize = spoolingConfig.getInitialSegmentSize(); + this.maximumSegmentSize = spoolingConfig.getMaximumSegmentSize(); + this.inlineSegments = spoolingConfig.isInlineSegments(); + this.secretKey = spoolingConfig.getEncryptionKey(); + } + + public boolean isLoaded() + { + return registry + .getSpoolingManager() + .isPresent(); + } + + public long maximumSegmentSize() + { + return maximumSegmentSize.toBytes(); + } + + public long initialSegmentSize() + { + return initialSegmentSize.toBytes(); + } + + public boolean useInlineSegments() + { + return inlineSegments; + } + + public SpooledSegmentHandle create(SpoolingContext context) + { + return delegate().create(context); + } + + public OutputStream createOutputStream(Object handle) + throws Exception + { + return delegate().createOutputStream(decodeHandle(handle)); + } + + public Optional directLocation(Object handle) + { + return delegate().directLocation(decodeHandle(handle)); + } + + public InputStream openInputStream(Object handle) + throws IOException + { + return delegate().openInputStream(decodeHandle(handle)); + } + + public void drop(Object segmentId) + throws IOException + { + delegate().acknowledge(decodeHandle(segmentId)); + } + + private SpooledSegmentHandle decodeHandle(Object handle) + { + if (handle instanceof SpooledSegmentHandle spooledHandle) { + return spooledHandle; + } + + if (handle instanceof String stringValue) { + return delegate().deserialize(decrypt(wrappedBuffer(getUrlDecoder().decode(stringValue)))); + } + + throw new IllegalArgumentException("Unsupported segment id format: " + handle.getClass().getSimpleName()); + } + + public String handleToUriIdentifier(SpooledSegmentHandle handle) + { + return getUrlEncoder().encodeToString(encrypt(delegate().serialize(handle)).getBytes()); + } + + private SpoolingManager delegate() + { + return registry + .getSpoolingManager() + .orElseThrow(() -> new IllegalStateException("Spooling manager is not loaded")); + } + + private Slice encrypt(Slice input) + { + 
try { + Cipher cipher = Cipher.getInstance("AES"); + cipher.init(ENCRYPT_MODE, secretKey); + return wrappedBuffer(cipher.doFinal(input.getBytes())); + } + catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | IllegalBlockSizeException | + BadPaddingException e) { + throw new RuntimeException("Could not encrypt segment handle", e); + } + } + + private Slice decrypt(Slice input) + { + try { + Cipher cipher = Cipher.getInstance("AES"); + cipher.init(DECRYPT_MODE, secretKey); + return wrappedBuffer(cipher.doFinal(input.getBytes())); + } + catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | IllegalBlockSizeException | + BadPaddingException e) { + throw new RuntimeException("Could not decrypt segment handle", e); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingManagerRegistry.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingManagerRegistry.java new file mode 100644 index 000000000000..b64f545fd7b8 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingManagerRegistry.java @@ -0,0 +1,136 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling; + +import com.google.inject.Inject; +import io.airlift.log.Logger; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Tracer; +import io.trino.spi.classloader.ThreadContextClassLoader; +import io.trino.spi.protocol.SpoolingManager; +import io.trino.spi.protocol.SpoolingManagerContext; +import io.trino.spi.protocol.SpoolingManagerFactory; + +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; +import static io.airlift.configuration.ConfigurationLoader.loadPropertiesFrom; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class SpoolingManagerRegistry +{ + private final Map spoolingManagerFactories = new ConcurrentHashMap<>(); + + private static final Logger log = Logger.get(SpoolingManagerRegistry.class); + + static final File CONFIG_FILE = new File("etc/spooling-manager.properties"); + private static final String SPOOLING_MANAGER_NAME_PROPERTY = "spooling-manager.name"; + + private final boolean enabled; + private final OpenTelemetry openTelemetry; + private final Tracer tracer; + private volatile SpoolingManager spoolingManager; + + @Inject + public SpoolingManagerRegistry(SpoolingConfig config, OpenTelemetry openTelemetry, Tracer tracer) + { + this.enabled = config.isEnabled(); + this.openTelemetry = requireNonNull(openTelemetry, "openTelemetry is null"); + this.tracer = requireNonNull(tracer, "tracer is null"); + 
} + + public void addSpoolingManagerFactory(SpoolingManagerFactory factory) + { + requireNonNull(factory, "factory is null"); + if (spoolingManagerFactories.putIfAbsent(factory.getName(), factory) != null) { + throw new IllegalArgumentException(format("Spooling manager factory '%s' is already registered", factory.getName())); + } + } + + public void loadSpoolingManager() + { + if (!enabled) { + // don't load SpoolingManager when spooling is not enabled + return; + } + + if (!CONFIG_FILE.exists()) { + return; + } + + Map properties = loadProperties(); + String name = properties.remove(SPOOLING_MANAGER_NAME_PROPERTY); + checkArgument(!isNullOrEmpty(name), "Spooling manager configuration %s does not contain %s", CONFIG_FILE, SPOOLING_MANAGER_NAME_PROPERTY); + loadSpoolingManager(name, properties); + } + + public synchronized void loadSpoolingManager(String name, Map properties) + { + SpoolingManagerFactory factory = spoolingManagerFactories.get(name); + checkArgument(factory != null, "Spooling manager factory '%s' is not registered. Available factories: %s", name, spoolingManagerFactories.keySet()); + loadSpoolingManager(factory, properties); + } + + public synchronized void loadSpoolingManager(SpoolingManagerFactory factory, Map properties) + { + requireNonNull(factory, "factory is null"); + log.info("-- Loading spooling manager %s --", factory.getName()); + checkState(spoolingManager == null, "spoolingManager is already loaded"); + SpoolingManagerContext context = new SpoolingManagerContext() + { + @Override + public OpenTelemetry getOpenTelemetry() + { + return openTelemetry; + } + + @Override + public Tracer getTracer() + { + return tracer; + } + }; + + SpoolingManager spoolingManager; + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(factory.getClass().getClassLoader())) { + spoolingManager = factory.create(properties, context); + } + this.spoolingManager = spoolingManager; + log.info("-- Loaded spooling manager %s --", factory.getName()); + } + + public Optional getSpoolingManager() + { + return Optional.ofNullable(spoolingManager); + } + + private static Map loadProperties() + { + try { + return new HashMap<>(loadPropertiesFrom(CONFIG_FILE.getPath())); + } + catch (IOException e) { + throw new UncheckedIOException("Failed to read spooling manager configuration file: " + CONFIG_FILE, e); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingServerModule.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingServerModule.java new file mode 100644 index 000000000000..5b27a7a2bc9c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/SpoolingServerModule.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.server.protocol.spooling; + +import com.google.inject.Binder; +import com.google.inject.Scopes; +import com.google.inject.Singleton; +import com.google.inject.multibindings.ProvidesIntoSet; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.server.ServerConfig; +import io.trino.server.protocol.spooling.encoding.QueryDataEncodingModule; + +import java.util.Optional; + +import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder; + +public class SpoolingServerModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + SpoolingConfig spoolingConfig = buildConfigObject(SpoolingConfig.class); + boolean isCoordinator = buildConfigObject(ServerConfig.class).isCoordinator(); + if (spoolingConfig.isEnabled()) { + configBinder(binder).bindConfig(SpoolingConfig.class); + binder.bind(QueryDataEncoderSelector.class).to(PreferredQueryDataEncoderSelector.class).in(Scopes.SINGLETON); + + install(new QueryDataEncodingModule()); + if (spoolingConfig.isUseWorkers() || isCoordinator) { + jaxrsBinder(binder).bind(SegmentResource.class); + } + } + else { + binder.bind(QueryDataEncoderSelector.class).toInstance(_ -> Optional.empty()); + } + + binder.bind(SpoolingManagerRegistry.class).in(Scopes.SINGLETON); + binder.bind(SpoolingManagerBridge.class).in(Scopes.SINGLETON); + } + + @ProvidesIntoSet + @Singleton + // Fully qualified so as not to confuse it with Guice's Module + public static com.fasterxml.jackson.databind.Module queryDataJacksonModule() + { + return new QueryDataJacksonModule(); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/CompressedQueryDataEncoder.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/CompressedQueryDataEncoder.java new file mode 100644 index 000000000000..15ead9360f25 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/CompressedQueryDataEncoder.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.server.protocol.spooling.encoding; + +import io.trino.client.spooling.DataAttributes; +import io.trino.server.protocol.spooling.QueryDataEncoder; +import io.trino.spi.Page; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; + +import static com.google.common.primitives.Ints.saturatedCast; +import static io.trino.client.spooling.DataAttribute.BYTE_SIZE; +import static io.trino.client.spooling.DataAttribute.UNCOMPRESSED_SIZE; + +public abstract class CompressedQueryDataEncoder + implements QueryDataEncoder +{ + protected final QueryDataEncoder delegate; + private final int compressionThreshold; + + protected CompressedQueryDataEncoder(QueryDataEncoder delegate, int compressionThreshold) + { + this.delegate = delegate; + this.compressionThreshold = compressionThreshold; + } + + @Override + public DataAttributes encodeTo(OutputStream output, List pages) + throws IOException + { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(pagesSize(pages)); + DataAttributes attributes = delegate.encodeTo(buffer, pages); + int uncompressedSize = attributes.get(BYTE_SIZE, Integer.class); + + // Do not compress data if below threshold + if (uncompressedSize < compressionThreshold) { + buffer.writeTo(output); + return attributes; + } + + return attributes + .toBuilder() + .set(BYTE_SIZE, compress(buffer.toByteArray(), uncompressedSize, output)) // actual size of compressed data + .set(UNCOMPRESSED_SIZE, uncompressedSize) // expected by the decoder if the data is compressed + .build(); + } + + protected abstract int compress(byte[] buffer, int uncompressedSize, OutputStream output) + throws IOException; + + protected static int pagesSize(List pages) + { + return saturatedCast(pages.stream() + .map(Page::getSizeInBytes) + .reduce(0L, Long::sum)); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/EncryptingQueryDataEncoder.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/EncryptingQueryDataEncoder.java new file mode 100644 index 000000000000..f96651db04c3 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/EncryptingQueryDataEncoder.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.server.protocol.spooling.encoding; + +import io.airlift.slice.Slice; +import io.trino.Session; +import io.trino.client.spooling.DataAttributes; +import io.trino.server.protocol.OutputColumn; +import io.trino.server.protocol.spooling.QueryDataEncoder; +import io.trino.spi.Page; + +import javax.crypto.Cipher; +import javax.crypto.CipherOutputStream; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.spec.SecretKeySpec; + +import java.io.IOException; +import java.io.OutputStream; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.client.spooling.DataAttribute.ENCRYPTION_CIPHER_NAME; +import static io.trino.client.spooling.DataAttribute.ENCRYPTION_KEY; +import static io.trino.client.spooling.encoding.CipherUtils.serializeSecretKey; +import static io.trino.util.Ciphers.deserializeAesEncryptionKey; +import static io.trino.util.Ciphers.is256BitSecretKeySpec; +import static java.util.Objects.requireNonNull; +import static javax.crypto.Cipher.ENCRYPT_MODE; + +public class EncryptingQueryDataEncoder + implements QueryDataEncoder +{ + private static final String CIPHER_NAME = "AES"; + private final QueryDataEncoder delegate; + private final SecretKeySpec key; + + public EncryptingQueryDataEncoder(QueryDataEncoder delegate, Slice encryptionKey) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + this.key = deserializeAesEncryptionKey(requireNonNull(encryptionKey, "encryptionKey is null")); + checkArgument(is256BitSecretKeySpec(key), "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit AES key"); + } + + @Override + public DataAttributes encodeTo(OutputStream output, List pages) + throws IOException + { + try (CipherOutputStream encryptedOutput = new CipherOutputStream(output, createCipher(key))) { + return delegate.encodeTo(encryptedOutput, pages); + } + } + + @Override + public DataAttributes attributes() + { + return delegate + .attributes() + .toBuilder() + .set(ENCRYPTION_KEY, serializeSecretKey(key)) + .set(ENCRYPTION_CIPHER_NAME, CIPHER_NAME) + .build(); + } + + @Override + public String encodingId() + { + return delegate.encodingId(); + } + + private static Cipher createCipher(SecretKeySpec privateKey) + { + try { + Cipher cipher = Cipher.getInstance(CIPHER_NAME); + cipher.init(ENCRYPT_MODE, privateKey); + return cipher; + } + catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) { + throw new RuntimeException(e); + } + } + + public static class Factory + implements QueryDataEncoder.Factory + { + private final QueryDataEncoder.Factory delegate; + + public Factory(QueryDataEncoder.Factory delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + } + + @Override + public QueryDataEncoder create(Session session, List columns) + { + QueryDataEncoder encoder = delegate.create(session, columns); + if (session.getQueryDataEncryptionKey().isEmpty()) { + return encoder; + } + return new EncryptingQueryDataEncoder(encoder, session.getQueryDataEncryptionKey().orElseThrow()); + } + + @Override + public String encodingId() + { + return delegate.encodingId(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/JsonQueryDataEncoder.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/JsonQueryDataEncoder.java new file mode 100644 index 
000000000000..c1f6cb6c7634 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/JsonQueryDataEncoder.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling.encoding; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.io.CountingOutputStream; +import com.google.inject.Inject; +import io.trino.Session; +import io.trino.client.spooling.DataAttributes; +import io.trino.server.protocol.JsonArrayResultsIterator; +import io.trino.server.protocol.OutputColumn; +import io.trino.server.protocol.spooling.QueryDataEncoder; +import io.trino.spi.Page; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; + +import static io.trino.client.spooling.DataAttribute.BYTE_SIZE; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + +public class JsonQueryDataEncoder + implements QueryDataEncoder +{ + private static final String ENCODING_ID = "json-ext"; + private final ObjectMapper mapper; + private final Session session; + private final List columns; + + public JsonQueryDataEncoder(ObjectMapper mapper, Session session, List columns) + { + this.mapper = requireNonNull(mapper, "mapper is null"); + this.session = requireNonNull(session, "session is null"); + this.columns = requireNonNull(columns, "columns is null"); + } + + @Override + public DataAttributes encodeTo(OutputStream output, List pages) + throws IOException + { + ImmutableList.Builder serializationExceptions = ImmutableList.builder(); + JsonArrayResultsIterator values = new JsonArrayResultsIterator( + session, + pages, + columns, + serializationExceptions::add); + + try { + CountingOutputStream wrapper = new CountingOutputStream(output); + mapper.writeValue(wrapper, values); + List exceptions = serializationExceptions.build(); + if (!exceptions.isEmpty()) { + throw new RuntimeException("Could not serialize to JSON", exceptions.getFirst()); + } + + return DataAttributes.builder() + .set(BYTE_SIZE, toIntExact(wrapper.getCount())) + .build(); + } + catch (JsonProcessingException e) { + throw new IOException("Could not serialize to JSON", e); + } + } + + @Override + public String encodingId() + { + return ENCODING_ID; + } + + public static class Factory + implements QueryDataEncoder.Factory + { + protected final ObjectMapper mapper; + + @Inject + public Factory(ObjectMapper mapper) + { + this.mapper = requireNonNull(mapper, "mapper is null"); + } + + @Override + public QueryDataEncoder create(Session session, List columns) + { + return new JsonQueryDataEncoder(mapper, session, columns); + } + + @Override + public String encodingId() + { + return ENCODING_ID; + } + } + + public static class ZstdFactory + extends Factory + { + @Inject + public ZstdFactory(ObjectMapper mapper) + { + super(mapper); + } + + @Override + public QueryDataEncoder 
create(Session session, List columns) + { + return new ZstdQueryDataEncoder(super.create(session, columns)); + } + + @Override + public String encodingId() + { + return super.encodingId() + "+zstd"; + } + } + + public static class SnappyFactory + extends Factory + { + @Inject + public SnappyFactory(ObjectMapper mapper) + { + super(mapper); + } + + @Override + public QueryDataEncoder create(Session session, List columns) + { + return new SnappyQueryDataEncoder(super.create(session, columns)); + } + + @Override + public String encodingId() + { + return super.encodingId() + "+snappy"; + } + } + + public static class Lz4Factory + extends Factory + { + @Inject + public Lz4Factory(ObjectMapper mapper) + { + super(mapper); + } + + @Override + public QueryDataEncoder create(Session session, List columns) + { + return new Lz4QueryDataEncoder(super.create(session, columns)); + } + + @Override + public String encodingId() + { + return super.encodingId() + "+lz4"; + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/Lz4QueryDataEncoder.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/Lz4QueryDataEncoder.java new file mode 100644 index 000000000000..4dfa99e423fa --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/Lz4QueryDataEncoder.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling.encoding; + +import io.airlift.compress.lz4.Lz4Compressor; +import io.trino.server.protocol.spooling.QueryDataEncoder; + +import java.io.IOException; +import java.io.OutputStream; + +public class Lz4QueryDataEncoder + extends CompressedQueryDataEncoder +{ + private static final int COMPRESSION_THRESHOLD = 2_048; + + public Lz4QueryDataEncoder(QueryDataEncoder delegate) + { + super(delegate, COMPRESSION_THRESHOLD); + } + + @Override + protected int compress(byte[] buffer, int uncompressedSize, OutputStream output) + throws IOException + { + Lz4Compressor compressor = new Lz4Compressor(); + byte[] compressed = new byte[compressor.maxCompressedLength(uncompressedSize)]; + int compressedSize = compressor.compress(buffer, 0, uncompressedSize, compressed, 0, compressed.length); + output.write(compressed, 0, compressedSize); + return compressedSize; + } + + @Override + public String encodingId() + { + return delegate.encodingId() + "+lz4"; + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/QueryDataEncodingModule.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/QueryDataEncodingModule.java new file mode 100644 index 000000000000..8daa5b70be3c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/QueryDataEncodingModule.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol.spooling.encoding; + +import com.google.inject.Binder; +import com.google.inject.Scopes; +import com.google.inject.multibindings.Multibinder; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.server.protocol.spooling.QueryDataEncoder; + +import static com.google.inject.multibindings.Multibinder.newSetBinder; + +public class QueryDataEncodingModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + Multibinder encoderFactories = newSetBinder(binder, QueryDataEncoder.Factory.class); + + // json + compressed variants + encoderFactories.addBinding().to(JsonQueryDataEncoder.Factory.class).in(Scopes.SINGLETON); + encoderFactories.addBinding().to(JsonQueryDataEncoder.ZstdFactory.class).in(Scopes.SINGLETON); + encoderFactories.addBinding().to(JsonQueryDataEncoder.SnappyFactory.class).in(Scopes.SINGLETON); + encoderFactories.addBinding().to(JsonQueryDataEncoder.Lz4Factory.class).in(Scopes.SINGLETON); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/SnappyQueryDataEncoder.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/SnappyQueryDataEncoder.java new file mode 100644 index 000000000000..de68ee634839 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/SnappyQueryDataEncoder.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.server.protocol.spooling.encoding; + +import io.airlift.compress.snappy.SnappyCompressor; +import io.trino.server.protocol.spooling.QueryDataEncoder; + +import java.io.IOException; +import java.io.OutputStream; + +public class SnappyQueryDataEncoder + extends CompressedQueryDataEncoder +{ + private static final int COMPRESSION_THRESHOLD = 2_048; + + public SnappyQueryDataEncoder(QueryDataEncoder delegate) + { + super(delegate, COMPRESSION_THRESHOLD); + } + + @Override + protected int compress(byte[] buffer, int uncompressedSize, OutputStream output) + throws IOException + { + SnappyCompressor compressor = new SnappyCompressor(); + byte[] compressed = new byte[compressor.maxCompressedLength(uncompressedSize)]; + int compressedSize = compressor.compress(buffer, 0, uncompressedSize, compressed, 0, compressed.length); + output.write(compressed, 0, compressedSize); + return compressedSize; + } + + @Override + public String encodingId() + { + return delegate.encodingId() + "+snappy"; + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/ZstdQueryDataEncoder.java b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/ZstdQueryDataEncoder.java new file mode 100644 index 000000000000..787da15523c5 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/protocol/spooling/encoding/ZstdQueryDataEncoder.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.server.protocol.spooling.encoding; + +import io.airlift.compress.zstd.ZstdCompressor; +import io.trino.server.protocol.spooling.QueryDataEncoder; + +import java.io.IOException; +import java.io.OutputStream; + +public class ZstdQueryDataEncoder + extends CompressedQueryDataEncoder +{ + private static final int COMPRESSION_THRESHOLD = 2_048; + + public ZstdQueryDataEncoder(QueryDataEncoder delegate) + { + super(delegate, COMPRESSION_THRESHOLD); + } + + @Override + protected int compress(byte[] buffer, int uncompressedSize, OutputStream output) + throws IOException + { + ZstdCompressor compressor = new ZstdCompressor(); + byte[] compressed = new byte[compressor.maxCompressedLength(uncompressedSize)]; + int compressedSize = compressor.compress(buffer, 0, uncompressedSize, compressed, 0, compressed.length); + output.write(compressed, 0, compressedSize); + return compressedSize; + } + + @Override + public String encodingId() + { + return delegate.encodingId() + "+zstd"; + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java index 5f87c6825aae..d58216ce8c5a 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java @@ -85,6 +85,7 @@ import io.trino.server.SessionSupplier; import io.trino.server.ShutdownAction; import io.trino.server.StartupStatus; +import io.trino.server.protocol.spooling.SpoolingManagerRegistry; import io.trino.server.security.CertificateAuthenticatorManager; import io.trino.server.security.ServerSecurityModule; import io.trino.spi.ErrorType; @@ -211,6 +212,7 @@ public static Builder builder() private final boolean coordinator; private final FailureInjector failureInjector; private final ExchangeManagerRegistry exchangeManagerRegistry; + private final SpoolingManagerRegistry spoolingManagerRegistry; public static class TestShutdownAction implements ShutdownAction @@ -248,6 +250,7 @@ private TestingTrinoServer( Optional baseDataDir, Optional spanProcessor, Optional systemAccessControlConfiguration, + Optional spoolingConfiguration, Optional> systemAccessControls, List eventListeners, Consumer additionalConfiguration, @@ -419,6 +422,7 @@ private TestingTrinoServer( mBeanServer = injector.getInstance(MBeanServer.class); failureInjector = injector.getInstance(FailureInjector.class); exchangeManagerRegistry = injector.getInstance(ExchangeManagerRegistry.class); + spoolingManagerRegistry = injector.getInstance(SpoolingManagerRegistry.class); systemAccessControlConfiguration.ifPresentOrElse( configuration -> { @@ -427,6 +431,9 @@ private TestingTrinoServer( }, () -> accessControl.setSystemAccessControls(systemAccessControls.orElseThrow())); + spoolingConfiguration.ifPresent(config -> + spoolingManagerRegistry.loadSpoolingManager(config.factoryName(), config.configuration())); + EventListenerManager eventListenerManager = injector.getInstance(EventListenerManager.class); eventListeners.forEach(eventListenerManager::addEventListener); @@ -507,6 +514,11 @@ public void loadExchangeManager(String name, Map properties) exchangeManagerRegistry.loadExchangeManager(name, properties); } + public void loadSpoolingManager(String name, Map properties) + { + spoolingManagerRegistry.loadSpoolingManager(name, properties); + } + /** * Add the event listeners from connectors. 
Connector event listeners are * only supported for statically loaded catalogs, and this doesn't match up @@ -724,6 +736,7 @@ public static class Builder private Optional baseDataDir = Optional.empty(); private Optional spanProcessor = Optional.empty(); private Optional systemAccessControlConfiguration = Optional.empty(); + private Optional spoolingConfiguration = Optional.empty(); private Optional> systemAccessControls = Optional.of(ImmutableList.of()); private List eventListeners = ImmutableList.of(); private Consumer additionalConfiguration = _ -> {}; @@ -826,6 +839,7 @@ public TestingTrinoServer build() baseDataDir, spanProcessor, systemAccessControlConfiguration, + spoolingConfiguration, systemAccessControls, eventListeners, additionalConfiguration, diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java index 3eb58e731582..4fb2d3073f70 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/QueryExplainer.java @@ -20,6 +20,8 @@ import io.trino.cost.StatsCalculator; import io.trino.execution.querystats.PlanOptimizersStatsCollector; import io.trino.execution.warnings.WarningCollector; +import io.trino.server.protocol.spooling.SpoolingConfig; +import io.trino.server.protocol.spooling.SpoolingManagerRegistry; import io.trino.spi.TrinoException; import io.trino.sql.PlannerContext; import io.trino.sql.SqlFormatter; @@ -46,6 +48,8 @@ import java.util.List; import java.util.Optional; +import static io.airlift.tracing.Tracing.noopTracer; +import static io.opentelemetry.api.OpenTelemetry.noop; import static io.trino.execution.ParameterExtractor.bindParameters; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.sql.analyzer.QueryType.EXPLAIN; @@ -168,6 +172,7 @@ public Plan getLogicalPlan(Session session, Statement statement, List encoderFactory = session + .getQueryDataEncodingId() + .flatMap(encoderSelector::select); + PhysicalOperation operation = node.getSource().accept(this, context); + + if (encoderFactory.isEmpty() || !spoolingManager.isLoaded()) { + // There is no spooling manager or no supported spooled encoding factory + return operation; + } + + Map spooledLayout = OutputSpoolingOperatorFactory.spooledLayout(operation.layout); + List outputLayout = spooledOutputLayout(node, spooledLayout); + return new PhysicalOperation(new OutputSpoolingOperatorFactory(context.getNextOperatorId(), node.getId(), spooledLayout, encoderFactory.orElseThrow().create(session, outputLayout), spoolingManager), spooledLayout, operation); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index efce7fb8c1e8..6a5916612af6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -43,6 +43,7 @@ import io.trino.metadata.TableLayout; import io.trino.metadata.TableMetadata; import io.trino.operator.RetryPolicy; +import io.trino.server.protocol.spooling.SpoolingManagerRegistry; import io.trino.spi.ErrorCodeSupplier; import io.trino.spi.RefreshType; import io.trino.spi.TrinoException; @@ -133,6 +134,7 @@ import static io.trino.SystemSessionProperties.isCollectPlanStatisticsForAllQueries; import static io.trino.SystemSessionProperties.isUsePreferredWritePartitioning; import static 
io.trino.metadata.MetadataUtil.createQualifiedObjectName; +import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_SYMBOL; import static io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND; import static io.trino.spi.StandardErrorCode.CONSTRAINT_VIOLATION; import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; @@ -181,6 +183,7 @@ public enum Stage private final SymbolAllocator symbolAllocator = new SymbolAllocator(); private final Metadata metadata; private final PlannerContext plannerContext; + private final SpoolingManagerRegistry spoolingManagerRegistry; private final StatisticsAggregationPlanner statisticsAggregationPlanner; private final StatsCalculator statsCalculator; private final CostCalculator costCalculator; @@ -193,13 +196,14 @@ public LogicalPlanner( List planOptimizers, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, + SpoolingManagerRegistry spoolingManagerRegistry, StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector, CachingTableStatsProvider tableStatsProvider) { - this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, plannerContext, statsCalculator, costCalculator, warningCollector, planOptimizersStatsCollector, tableStatsProvider); + this(session, planOptimizers, DISTRIBUTED_PLAN_SANITY_CHECKER, idAllocator, plannerContext, spoolingManagerRegistry, statsCalculator, costCalculator, warningCollector, planOptimizersStatsCollector, tableStatsProvider); } public LogicalPlanner( @@ -208,6 +212,7 @@ public LogicalPlanner( PlanSanityChecker planSanityChecker, PlanNodeIdAllocator idAllocator, PlannerContext plannerContext, + SpoolingManagerRegistry spoolingManagerRegistry, StatsCalculator statsCalculator, CostCalculator costCalculator, WarningCollector warningCollector, @@ -219,6 +224,7 @@ public LogicalPlanner( this.planSanityChecker = requireNonNull(planSanityChecker, "planSanityChecker is null"); this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.plannerContext = requireNonNull(plannerContext, "plannerContext is null"); + this.spoolingManagerRegistry = requireNonNull(spoolingManagerRegistry, "spoolingManagerRegistry is null"); this.metadata = plannerContext.getMetadata(); this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(symbolAllocator, plannerContext, session); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); @@ -903,6 +909,10 @@ private PlanNode createOutputPlan(RelationPlan plan, Analysis analysis) columnNumber++; } + if (session.getQueryDataEncodingId().isPresent() && spoolingManagerRegistry.getSpoolingManager().isPresent()) { + names.add(SPOOLING_METADATA_SYMBOL.name()); + outputs.add(SPOOLING_METADATA_SYMBOL); + } return new OutputNode(idAllocator.getNextId(), plan.getRoot(), names.build(), outputs.build()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/GraphvizPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/GraphvizPrinter.java index 9a5baf405c1d..a6c12d5326ec 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/GraphvizPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/GraphvizPrinter.java @@ -16,6 +16,7 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import io.trino.server.protocol.spooling.SpooledBlock; import 
io.trino.sql.ir.Comparison; import io.trino.sql.ir.Expression; import io.trino.sql.ir.Reference; @@ -70,6 +71,7 @@ import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Collections2.filter; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Maps.immutableEnumMap; import static io.trino.sql.ir.Booleans.TRUE; @@ -618,7 +620,7 @@ private void printNode(PlanNode node, String label, String details, String color private static String getColumns(OutputNode node) { - Iterator columnNames = node.getColumnNames().iterator(); + Iterator columnNames = filter(node.getColumnNames(), value -> !value.equals(SpooledBlock.SPOOLING_METADATA_COLUMN_NAME)).iterator(); String columns = ""; int nameWidth = 0; while (columnNames.hasNext()) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index d086c2648906..7e905fb679b6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -16,6 +16,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.CaseFormat; import com.google.common.base.Joiner; +import com.google.common.collect.Collections2; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -159,6 +160,8 @@ import static io.trino.metadata.GlobalFunctionCatalog.isBuiltinFunctionName; import static io.trino.metadata.LanguageFunctionManager.isInlineFunction; import static io.trino.server.DynamicFilterService.DynamicFilterDomainStats; +import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_COLUMN_NAME; +import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_TYPE; import static io.trino.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.ir.Booleans.TRUE; @@ -533,6 +536,7 @@ private static String formatFragment( PartitioningScheme partitioningScheme = fragment.getOutputPartitioningScheme(); List layout = partitioningScheme.getOutputLayout().stream() .map(anonymizer::anonymize) + .filter(value -> !value.equals(SPOOLING_METADATA_COLUMN_NAME)) .collect(toImmutableList()); builder.append(indentString(1)) .append(format("Output layout: [%s]\n", @@ -1388,11 +1392,16 @@ public Void visitOutput(OutputNode node, Context context) NodeRepresentation nodeOutput = addNode( node, "Output", - ImmutableMap.of("columnNames", formatCollection(node.getColumnNames(), anonymizer::anonymizeColumn)), + ImmutableMap.of("columnNames", formatCollection(Collections2.filter(node.getColumnNames(), this::isNonSpooledColumn), anonymizer::anonymizeColumn)), context); for (int i = 0; i < node.getColumnNames().size(); i++) { String name = node.getColumnNames().get(i); Symbol symbol = node.getOutputSymbols().get(i); + + if (symbol.type().equals(SPOOLING_METADATA_TYPE)) { + continue; + } + if (!name.equals(symbol.name())) { nodeOutput.appendDetails("%s := %s", anonymizer.anonymizeColumn(name), anonymizer.anonymize(symbol)); } @@ -1400,6 +1409,11 @@ public Void visitOutput(OutputNode node, Context context) return processChildren(node, new Context(context.isInitialPlan())); } + private boolean 
isNonSpooledColumn(String columnName) + { + return !columnName.equals(SPOOLING_METADATA_COLUMN_NAME); + } + @Override public Void visitTopN(TopNNode node, Context context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java index d06afa9d8fa7..8e72b9060dfc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/TextRenderer.java @@ -34,6 +34,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_COLUMN_NAME; import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.POSITIVE_INFINITY; import static java.lang.Double.isFinite; @@ -75,6 +76,7 @@ private String writeTextOutput(StringBuilder output, PlanRepresentation plan, In .append("\n"); String columns = node.getOutputs().stream() + .filter(s -> !s.name().equals(SPOOLING_METADATA_COLUMN_NAME)) .map(s -> s.name() + ":" + s.type()) .collect(joining(", ")); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 720ca7fe12ab..e78e9b88611f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -88,7 +88,9 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.Collections2.filter; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.server.protocol.spooling.SpooledBlock.SPOOLING_METADATA_SYMBOL; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.optimizations.IndexJoinOptimizer.IndexKeyTracer; @@ -497,7 +499,7 @@ public Void visitOutput(OutputNode node, Set boundSymbols) PlanNode source = node.getSource(); source.accept(this, boundSymbols); // visit child - checkDependencies(source.getOutputSymbols(), node.getOutputSymbols(), "Invalid node. Output column dependencies (%s) not in source plan output (%s)", node.getOutputSymbols(), source.getOutputSymbols()); + checkDependencies(source.getOutputSymbols(), filter(node.getOutputSymbols(), symbol -> !symbol.equals(SPOOLING_METADATA_SYMBOL)), "Invalid node. 
Output column dependencies (%s) not in source plan output (%s)", node.getOutputSymbols(), source.getOutputSymbols()); return null; } diff --git a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java index 11bdd2b8ec42..cdc0332fe784 100644 --- a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java +++ b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java @@ -20,7 +20,6 @@ import io.airlift.configuration.secrets.SecretsResolver; import io.airlift.node.NodeInfo; import io.airlift.units.Duration; -import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.Tracer; import io.trino.FeaturesConfig; @@ -126,6 +125,10 @@ import io.trino.security.GroupProviderManager; import io.trino.server.PluginManager; import io.trino.server.SessionPropertyDefaults; +import io.trino.server.protocol.spooling.PreferredQueryDataEncoderSelector; +import io.trino.server.protocol.spooling.SpoolingConfig; +import io.trino.server.protocol.spooling.SpoolingManagerBridge; +import io.trino.server.protocol.spooling.SpoolingManagerRegistry; import io.trino.server.security.CertificateAuthenticatorManager; import io.trino.server.security.HeaderAuthenticatorConfig; import io.trino.server.security.HeaderAuthenticatorManager; @@ -207,6 +210,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -218,6 +222,7 @@ import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.tracing.Tracing.noopTracer; +import static io.opentelemetry.api.OpenTelemetry.noop; import static io.trino.connector.CatalogServiceProviderModule.createAccessControlProvider; import static io.trino.connector.CatalogServiceProviderModule.createAnalyzePropertyManager; import static io.trino.connector.CatalogServiceProviderModule.createColumnPropertyManager; @@ -296,7 +301,7 @@ public class PlanTester private final CoordinatorDynamicCatalogManager catalogManager; private final PluginManager pluginManager; private final ExchangeManagerRegistry exchangeManagerRegistry; - + private final SpoolingManagerRegistry spoolingManagerRegistry; private final TaskManagerConfig taskManagerConfig; private final OptimizerConfig optimizerConfig; private final StatementAnalyzerFactory statementAnalyzerFactory; @@ -375,7 +380,7 @@ private PlanTester(Session defaultSession, int nodeCountForStats) pageIndexerFactory, nodeInfo, testingVersionEmbedder(), - OpenTelemetry.noop(), + noop(), transactionManager, typeManager, nodeSchedulerConfig, @@ -447,7 +452,8 @@ private PlanTester(Session defaultSession, int nodeCountForStats) ImmutableSet.of(), ImmutableSet.of(new ExcludeColumnsFunction())); - exchangeManagerRegistry = new ExchangeManagerRegistry(OpenTelemetry.noop(), noopTracer(), secretsResolver); + exchangeManagerRegistry = new ExchangeManagerRegistry(noop(), noopTracer(), secretsResolver); + spoolingManagerRegistry = new SpoolingManagerRegistry(new SpoolingConfig(), noop(), noopTracer()); this.pluginManager = new PluginManager( (loader, createClassLoader) -> {}, Optional.empty(), @@ -464,7 +470,8 @@ private PlanTester(Session defaultSession, int nodeCountForStats) typeRegistry, blockEncodingManager, new HandleResolver(), - exchangeManagerRegistry); + 
exchangeManagerRegistry, + spoolingManagerRegistry); catalogManager.registerGlobalSystemConnector(globalSystemConnector); languageFunctionManager.setPlannerContext(plannerContext); @@ -496,7 +503,9 @@ private PlanTester(Session defaultSession, int nodeCountForStats) sessionPropertyManager, defaultSession.getPreparedStatements(), defaultSession.getProtocolHeaders(), - defaultSession.getExchangeEncryptionKey()); + defaultSession.getExchangeEncryptionKey(), + defaultSession.getQueryDataEncodingId(), + defaultSession.getQueryDataEncryptionKey()); } private static SessionPropertyManager createSessionPropertyManager( @@ -716,6 +725,7 @@ private List createDrivers(Session session, @Language("SQL") String sql) throw new AssertionError("Expected sub-plan to have no children"); } + SpoolingManagerRegistry spoolingManagerRegistry = new SpoolingManagerRegistry(new SpoolingConfig(), noop(), noopTracer()); TaskContext taskContext = createTaskContext(notificationExecutor, yieldExecutor, session); TableExecuteContextManager tableExecuteContextManager = new TableExecuteContextManager(); tableExecuteContextManager.registerTableExecuteContextForQuery(taskContext.getQueryContext().getQueryId()); @@ -733,6 +743,8 @@ private List createDrivers(Session session, @Language("SQL") String sql) new IndexJoinLookupStats(), this.taskManagerConfig, new GenericSpillerFactory(unsupportedSingleStreamSpillerFactory()), + new PreferredQueryDataEncoderSelector(Set.of(), spoolingManagerRegistry), + new SpoolingManagerBridge(new SpoolingConfig(), spoolingManagerRegistry), unsupportedSingleStreamSpillerFactory(), unsupportedPartitioningSpillerFactory(), new PagesIndex.TestingFactory(false), @@ -873,6 +885,7 @@ public Plan createPlan(Session session, @Language("SQL") String sql, List properties); + void loadSpoolingManager(String name, Map properties); + record MaterializedResultWithPlan(QueryId queryId, Optional queryPlan, MaterializedResult result) {}} diff --git a/core/trino-main/src/main/java/io/trino/testing/StandaloneQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/StandaloneQueryRunner.java index 3bbc43651515..d62041920b7f 100644 --- a/core/trino-main/src/main/java/io/trino/testing/StandaloneQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/StandaloneQueryRunner.java @@ -316,4 +316,10 @@ public void loadExchangeManager(String name, Map properties) { server.loadExchangeManager(name, properties); } + + @Override + public void loadSpoolingManager(String name, Map properties) + { + server.loadSpoolingManager(name, properties); + } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index ada597f53fc6..0ae564e950f6 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -39,6 +39,10 @@ import io.trino.operator.PagesIndex; import io.trino.operator.index.IndexJoinLookupStats; import io.trino.operator.index.IndexManager; +import io.trino.server.protocol.spooling.PreferredQueryDataEncoderSelector; +import io.trino.server.protocol.spooling.SpoolingConfig; +import io.trino.server.protocol.spooling.SpoolingManagerBridge; +import io.trino.server.protocol.spooling.SpoolingManagerRegistry; import io.trino.spi.connector.CatalogHandle; import io.trino.spiller.GenericSpillerFactory; import io.trino.split.PageSinkManager; @@ -68,7 +72,10 @@ import java.util.List; import java.util.Optional; +import 
java.util.Set; +import static io.airlift.tracing.Tracing.noopTracer; +import static io.opentelemetry.api.OpenTelemetry.noop; import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -154,6 +161,7 @@ public static LocalExecutionPlanner createTestingPlanner() PageFunctionCompiler pageFunctionCompiler = new PageFunctionCompiler(PLANNER_CONTEXT.getFunctionManager(), 0); ColumnarFilterCompiler columnarFilterCompiler = new ColumnarFilterCompiler(PLANNER_CONTEXT.getFunctionManager(), 0); + SpoolingManagerRegistry spoolingManagerRegistry = new SpoolingManagerRegistry(new SpoolingConfig(), noop(), noopTracer()); return new LocalExecutionPlanner( PLANNER_CONTEXT, Optional.empty(), @@ -170,6 +178,8 @@ public static LocalExecutionPlanner createTestingPlanner() new GenericSpillerFactory((types, spillContext, memoryContext) -> { throw new UnsupportedOperationException(); }), + new PreferredQueryDataEncoderSelector(Set.of(), spoolingManagerRegistry), + new SpoolingManagerBridge(new SpoolingConfig(), spoolingManagerRegistry), (types, spillContext, memoryContext) -> { throw new UnsupportedOperationException(); }, diff --git a/core/trino-main/src/test/java/io/trino/operator/TestOutputSpoolingController.java b/core/trino-main/src/test/java/io/trino/operator/TestOutputSpoolingController.java new file mode 100644 index 000000000000..772cad871a6e --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/TestOutputSpoolingController.java @@ -0,0 +1,183 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator; + +import org.junit.jupiter.api.Test; + +import static io.trino.operator.OutputSpoolingController.Mode.BUFFER; +import static io.trino.operator.OutputSpoolingController.Mode.INLINE; +import static io.trino.operator.OutputSpoolingController.Mode.SPOOL; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +class TestOutputSpoolingController +{ + @Test + public void testInlineFirstRowsUntilThresholdThenSpooling() + { + var assertion = new OutputSpoolingControllerAssertions( + new OutputSpoolingController(true, 100, 1000, 900, 16000)); + + assertion + .verifyNextMode(10, 100, INLINE) + .verifyInlined(1, 10, 100) + .verifyNextMode(10, 100, INLINE) + .verifyInlined(2, 20, 200) + .verifyNextMode(50, 400, INLINE) + .verifyInlined(3, 70, 600) + .verifyNextMode(50, 400, BUFFER) + .verifyBuffered(50, 400) + .verifyNextMode(50, 400, BUFFER) + .verifyBuffered(100, 800) + .verifyNextMode(50, 400, SPOOL) + .verifySpooled(1, 150, 1200) + .verifyEmptyBuffer() + .verifyNextMode(39, 399, BUFFER) + .verifyBuffered(39, 399); + } + + @Test + public void testSpoolingTargetSize() + { + var assertion = new OutputSpoolingControllerAssertions( + new OutputSpoolingController(false, 0, 0, 512, 2048)); + + assertion + .verifyNextMode(100, 511, BUFFER) // still under the initial segment target + .verifySpooledSegmentTarget(512) + .verifyBuffered(100, 511) + .verifyNextMode(100, 1, SPOOL) + .verifySpooled(1, 200, 512) + .verifySpooledSegmentTarget(1024) // target doubles + .verifyEmptyBuffer() + .verifyNextMode(1, 333, BUFFER) + .verifyNextMode(1, 333, BUFFER) + .verifyNextMode(1, 333, BUFFER) + .verifyNextMode(1, 333, SPOOL) + .verifySpooled(2, 204, 512 + 333 * 4) + .verifyEmptyBuffer() + .verifySpooledSegmentTarget(2048) // target doubled again + .verifyNextMode(100, 2047, BUFFER) + .verifyNextMode(100, 2047, SPOOL) + .verifyEmptyBuffer() + .verifySpooledSegmentTarget(2048) // target clamped at max + .verifySpooled(3, 204 + 200, 512 + 333 * 4 + 2047 * 2); + } + + @Test + public void testSpoolingEncoderEfficiency() + { + var assertion = new OutputSpoolingControllerAssertions( + new OutputSpoolingController(false, 0, 0, 32, 100)); + + assertion + .verifyNextMode(1000, 31, BUFFER) + .verifyBuffered(1000, 31) + .verifyNextMode(1000, 31, SPOOL) + .verifySpooled(1, 2000, 62) + .recordEncodedSize(31) + .verifyEmptyBuffer() + .verifySpooledSegmentTarget(64) + .verifyNextMode(100, 80, SPOOL) + .verifyNextMode(100, 47, BUFFER) // over segment size + .verifySpooled(2, 2100, 142) + .verifyBuffered(100, 47) + .verifyNextMode(54, 1, BUFFER) + .recordEncodedSize(121) + .verifySpooledSegmentTarget(100) + .verifyNextMode(100, 80, SPOOL) + .verifyNextMode(100, 43, BUFFER) + .verifyNextMode(1, 1, BUFFER) + .verifyNextMode(100, 1, BUFFER) + .verifyNextMode(100, 80, SPOOL) + .verifyEmptyBuffer() + .verifySpooled(4, 2655, 395); + } + + private record OutputSpoolingControllerAssertions(OutputSpoolingController controller) + { + public OutputSpoolingControllerAssertions verifyNextMode(int positionCount, int rawSizeInBytes, OutputSpoolingController.Mode expected) + { + assertThat(controller.getNextMode(positionCount, rawSizeInBytes)) + .isEqualTo(expected); + + return this; + } + + private OutputSpoolingControllerAssertions verifyInlined(int inlinedPages, int inlinedPositions, int inlinedRawBytes) + { + assertThat(controller.getInlinedPages()) + .describedAs("Inlined pages") + .isEqualTo(inlinedPages); + + assertThat(controller.getInlinedPositions()) + .describedAs("Inlined positions") + 
.isEqualTo(inlinedPositions); + + assertThat(controller.getInlinedRawBytes()) + .describedAs("Inlined raw bytes") + .isEqualTo(inlinedRawBytes); + + return this; + } + + private OutputSpoolingControllerAssertions verifySpooled(int spooledPages, int spooledPositions, int spooledRawBytes) + { + assertThat(controller.getSpooledPages()) + .describedAs("Spooled pages") + .isEqualTo(spooledPages); + + assertThat(controller.getSpooledPositions()) + .describedAs("Spooled spooledPositions") + .isEqualTo(spooledPositions); + + assertThat(controller.getSpooledRawBytes()) + .describedAs("Spooled raw bytes") + .isEqualTo(spooledRawBytes); + + return this; + } + + private OutputSpoolingControllerAssertions verifyBuffered(int bufferedPositions, int bufferSize) + { + assertThat(controller.getBufferedPositions()) + .describedAs("Buffered positions") + .isEqualTo(bufferedPositions); + assertThat(controller.getBufferedRawSize()) + .describedAs("Buffered size") + .isEqualTo(bufferSize); + + return this; + } + + private OutputSpoolingControllerAssertions verifySpooledSegmentTarget(long size) + { + assertThat(controller.getCurrentSpooledSegmentTarget()) + .describedAs("Spooled segment target") + .isEqualTo(size); + + return this; + } + + private OutputSpoolingControllerAssertions recordEncodedSize(long encodedSize) + { + controller.recordEncoded(encodedSize); + return this; + } + + private OutputSpoolingControllerAssertions verifyEmptyBuffer() + { + return this.verifyBuffered(0, 0); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/spooling/TestSpooledBlock.java b/core/trino-main/src/test/java/io/trino/operator/spooling/TestSpooledBlock.java new file mode 100644 index 000000000000..83137664469a --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/spooling/TestSpooledBlock.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.operator.spooling; + +import io.trino.client.spooling.DataAttributes; +import io.trino.server.protocol.spooling.SpooledBlock; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import org.junit.jupiter.api.Test; + +import static io.trino.client.spooling.DataAttribute.BYTE_SIZE; +import static io.trino.client.spooling.DataAttribute.ROWS_COUNT; +import static io.trino.spi.type.BigintType.BIGINT; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; + +class TestSpooledBlock +{ + @Test + public void testSerializationRoundTrip() + { + SpooledBlock metadata = new SpooledBlock("segmentId", createDataAttributes(10, 1200)); + Page page = new Page(metadata.serialize()); + SpooledBlock retrieved = SpooledBlock.deserialize(page); + assertThat(metadata).isEqualTo(retrieved); + } + + @Test + public void testSerializationRoundTripWithNonEmptyPage() + { + SpooledBlock metadata = new SpooledBlock("segmentId", createDataAttributes(10, 1100)); + Page page = new Page(blockWithPositions(1, true), metadata.serialize()); + SpooledBlock retrieved = SpooledBlock.deserialize(page); + assertThat(metadata).isEqualTo(retrieved); + } + + @Test + public void testThrowsErrorOnNonNullPositions() + { + SpooledBlock metadata = new SpooledBlock("segmentId", createDataAttributes(20, 1200)); + + assertThatThrownBy(() -> SpooledBlock.deserialize(new Page(blockWithPositions(1, false), metadata.serialize()))) + .hasMessage("Spooling metadata block must have all but last channels null"); + } + + @Test + public void testThrowsErrorOnMultiplePositions() + { + SpooledBlock metadata = new SpooledBlock("segmentId", createDataAttributes(30, 1300)); + + assertThatThrownBy(() -> SpooledBlock.deserialize(new Page(blockWithPositions(2, false), metadata.serialize()))) + .hasMessage("Spooling metadata block must have a single position"); + } + + public static Block blockWithPositions(int count, boolean isNull) + { + BlockBuilder blockBuilder = BIGINT.createBlockBuilder(null, count); + for (int i = 0; i < count; i++) { + if (isNull) { + blockBuilder.appendNull(); + } + else { + BIGINT.writeLong(blockBuilder, 0); + } + } + return blockBuilder.build(); + } + + private static DataAttributes createDataAttributes(long rows, int dataSizeBytes) + { + return DataAttributes.builder() + .set(ROWS_COUNT, rows) + .set(BYTE_SIZE, dataSizeBytes) + .build(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java b/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java index b50c783df8f5..a5acd02c14f0 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java +++ b/core/trino-main/src/test/java/io/trino/server/TestHttpRequestSessionContextFactory.java @@ -181,6 +181,7 @@ private static HttpRequestSessionContextFactory sessionContextFactory(ProtocolHe ImmutableSet::of, new AllowAllAccessControl(), new ProtocolConfig() - .setAlternateHeaderName(headers.getProtocolName())); + .setAlternateHeaderName(headers.getProtocolName()), + _ -> Optional.empty()); } } diff --git a/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java b/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java index a5ac11504e01..5e173583cdbf 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java +++ 
b/core/trino-main/src/test/java/io/trino/server/TestQueryResource.java @@ -13,6 +13,7 @@ */ package io.trino.server; +import com.google.common.collect.ImmutableList; import com.google.inject.Key; import io.airlift.http.client.HttpClient; import io.airlift.http.client.HttpUriBuilder; @@ -24,6 +25,8 @@ import io.airlift.json.ObjectMapperProvider; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.trace.Span; +import io.trino.client.QueryData; +import io.trino.client.QueryDataClientJacksonModule; import io.trino.client.QueryResults; import io.trino.execution.QueryInfo; import io.trino.plugin.tpch.TpchPlugin; @@ -37,6 +40,7 @@ import java.net.URI; import java.util.List; import java.util.Map; +import java.util.Set; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler; @@ -46,7 +50,6 @@ import static io.airlift.http.client.Request.Builder.preparePut; import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; -import static io.airlift.json.JsonCodec.jsonCodec; import static io.airlift.testing.Closeables.closeAll; import static io.airlift.tracing.SpanSerialization.SpanDeserializer; import static io.airlift.tracing.SpanSerialization.SpanSerializer; @@ -72,10 +75,16 @@ public class TestQueryResource { static final JsonCodec> BASIC_QUERY_INFO_CODEC = new JsonCodecFactory( new ObjectMapperProvider() + .withModules(Set.of(new QueryDataClientJacksonModule())) .withJsonSerializers(Map.of(Span.class, new SpanSerializer(OpenTelemetry.noop()))) .withJsonDeserializers(Map.of(Span.class, new SpanDeserializer(OpenTelemetry.noop())))) .listJsonCodec(BasicQueryInfo.class); + static final JsonCodec QUERY_RESULTS_JSON_CODEC = new JsonCodecFactory( + new ObjectMapperProvider() + .withModules(Set.of(new QueryDataClientJacksonModule()))) + .jsonCodec(QueryResults.class); + private HttpClient client; private TestingTrinoServer server; @@ -108,7 +117,7 @@ public void testIdempotentResults() .setBodyGenerator(createStaticBodyGenerator(sql, UTF_8)) .build(); - QueryResults queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + QueryResults queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); URI uri = queryResults.getNextUri(); while (uri != null) { QueryResults attempt1 = client.execute( @@ -116,17 +125,16 @@ public void testIdempotentResults() .setHeader(TRINO_HEADERS.requestUser(), "user") .setUri(uri) .build(), - createJsonResponseHandler(jsonCodec(QueryResults.class))); + createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); QueryResults attempt2 = client.execute( prepareGet() .setHeader(TRINO_HEADERS.requestUser(), "user") .setUri(uri) .build(), - createJsonResponseHandler(jsonCodec(QueryResults.class))); - - assertThat(attempt2.getData()).isEqualTo(attempt1.getData()); + createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); + assertDataEquals(attempt2.getData(), attempt1.getData()); uri = attempt1.getNextUri(); } } @@ -249,6 +257,25 @@ public void testPreempted() testKilled("preempted"); } + private void assertDataEquals(QueryData left, QueryData right) + { + if (left == null) { + assertThat(right).isNull(); + return; + } + + if (left.getData() == null) { + assertThat(right.getData()).isNull(); + return; + } + + if (right.getData() == null) { + throw new AssertionError("Expected right data to be 
non-null"); + } + + assertThat(ImmutableList.copyOf(left.getData())).isEqualTo(ImmutableList.copyOf(right.getData())); + } + private void testKilled(String killType) { String queryId = startQuery("SELECT * FROM tpch.sf100.lineitem"); @@ -281,13 +308,13 @@ private String runToCompletion(String sql) .setUri(uri) .setBodyGenerator(createStaticBodyGenerator(sql, UTF_8)) .build(); - QueryResults queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + QueryResults queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); while (queryResults.getNextUri() != null) { request = prepareGet() .setHeader(TRINO_HEADERS.requestUser(), "user") .setUri(queryResults.getNextUri()) .build(); - queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); } return queryResults.getId(); } @@ -300,13 +327,13 @@ private String startQuery(String sql) .setBodyGenerator(createStaticBodyGenerator(sql, UTF_8)) .setHeader(TRINO_HEADERS.requestUser(), "user") .build(); - QueryResults queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + QueryResults queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); while (queryResults.getNextUri() != null && !queryResults.getStats().getState().equals(RUNNING.toString())) { request = prepareGet() .setHeader(TRINO_HEADERS.requestUser(), "user") .setUri(queryResults.getNextUri()) .build(); - queryResults = client.execute(request, createJsonResponseHandler(jsonCodec(QueryResults.class))); + queryResults = client.execute(request, createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); } return queryResults.getId(); } diff --git a/core/trino-main/src/test/java/io/trino/server/TestQuerySessionSupplier.java b/core/trino-main/src/test/java/io/trino/server/TestQuerySessionSupplier.java index f75523215f74..1b70cf09358e 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQuerySessionSupplier.java +++ b/core/trino-main/src/test/java/io/trino/server/TestQuerySessionSupplier.java @@ -26,6 +26,7 @@ import io.trino.metadata.SessionPropertyManager; import io.trino.security.AllowAllAccessControl; import io.trino.server.protocol.PreparedStatementEncoder; +import io.trino.server.protocol.spooling.SpoolingConfig; import io.trino.spi.QueryId; import io.trino.spi.TrinoException; import io.trino.spi.connector.CatalogSchemaName; @@ -72,7 +73,8 @@ public class TestQuerySessionSupplier createTestMetadataManager(), ImmutableSet::of, new AllowAllAccessControl(), - new ProtocolConfig()); + new ProtocolConfig(), + _ -> Optional.empty()); @Test public void testCreateSession() @@ -246,6 +248,7 @@ private static QuerySessionSupplier createSessionSupplier(SqlEnvironmentConfig c metadata, new AllowAllAccessControl(), new SessionPropertyManager(), - config); + config, + new SpoolingConfig()); } } diff --git a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfoResource.java b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfoResource.java index 41bb35e62103..9520deede9c0 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfoResource.java +++ b/core/trino-main/src/test/java/io/trino/server/TestQueryStateInfoResource.java @@ -14,12 +14,19 @@ package io.trino.server; import com.google.common.io.Closer; +import com.google.inject.Injector; +import com.google.inject.Key; +import 
com.google.inject.TypeLiteral; +import io.airlift.bootstrap.Bootstrap; import io.airlift.http.client.HttpClient; import io.airlift.http.client.Request; import io.airlift.http.client.UnexpectedResponseException; import io.airlift.http.client.jetty.JettyHttpClient; import io.airlift.json.JsonCodec; +import io.airlift.json.JsonModule; import io.airlift.units.Duration; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Tracer; import io.trino.client.QueryResults; import io.trino.plugin.tpch.TpchPlugin; import io.trino.server.testing.TestingTrinoServer; @@ -41,6 +48,9 @@ import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; import static io.airlift.json.JsonCodec.jsonCodec; import static io.airlift.json.JsonCodec.listJsonCodec; +import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static io.airlift.tracing.Tracing.noopTracer; +import static io.opentelemetry.api.OpenTelemetry.noop; import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; import static io.trino.execution.QueryState.FAILED; import static io.trino.execution.QueryState.RUNNING; @@ -60,7 +70,7 @@ public class TestQueryStateInfoResource { private static final String LONG_LASTING_QUERY = "SELECT * FROM tpch.sf1.lineitem"; - private static final JsonCodec QUERY_RESULTS_JSON_CODEC = jsonCodec(QueryResults.class); + private static final JsonCodec QUERY_RESULTS_JSON_CODEC = queryResultsCodec(); private TestingTrinoServer server; private HttpClient client; @@ -87,7 +97,7 @@ public void setUp() .setBodyGenerator(createStaticBodyGenerator(LONG_LASTING_QUERY, UTF_8)) .setHeader(TRINO_HEADERS.requestUser(), "user2") .build(); - QueryResults queryResults2 = client.execute(request2, createJsonResponseHandler(jsonCodec(QueryResults.class))); + QueryResults queryResults2 = client.execute(request2, createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); client.execute(prepareGet().setUri(queryResults2.getNextUri()).build(), createJsonResponseHandler(QUERY_RESULTS_JSON_CODEC)); // queries are started in the background, so they may not all be immediately visible @@ -243,4 +253,18 @@ public void testGetQueryStateInfoNo() .isInstanceOf(UnexpectedResponseException.class) .hasMessageMatching("Expected response code .*, but was 404"); } + + public static JsonCodec queryResultsCodec() + { + Injector injector = new Bootstrap( + new JsonModule(), + binder -> { + jsonCodecBinder(binder).bindJsonCodec(QueryResults.class); + binder.bind(OpenTelemetry.class).toInstance(noop()); + binder.bind(Tracer.class).toInstance(noopTracer()); + } + ).initialize(); + + return injector.getInstance(Key.get(new TypeLiteral<>(){})); + } } diff --git a/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryDataSerialization.java b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryDataSerialization.java new file mode 100644 index 000000000000..ded0ed50a777 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryDataSerialization.java @@ -0,0 +1,319 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.client.ClientTypeSignature; +import io.trino.client.Column; +import io.trino.client.JsonCodec; +import io.trino.client.QueryData; +import io.trino.client.QueryDataClientJacksonModule; +import io.trino.client.QueryResults; +import io.trino.client.RawQueryData; +import io.trino.client.StatementStats; +import io.trino.client.spooling.DataAttributes; +import io.trino.client.spooling.EncodedQueryData; +import io.trino.client.spooling.InlineSegment; +import io.trino.client.spooling.Segment; +import io.trino.client.spooling.encoding.JsonQueryDataDecoder; +import io.trino.server.protocol.spooling.QueryDataJacksonModule; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.util.List; +import java.util.OptionalDouble; + +import static io.trino.client.ClientStandardTypes.BIGINT; +import static io.trino.client.FixJsonDataUtils.fixData; +import static io.trino.client.JsonCodec.jsonCodec; +import static io.trino.client.spooling.DataAttribute.BYTE_SIZE; +import static io.trino.client.spooling.DataAttribute.ENCRYPTION_KEY; +import static io.trino.client.spooling.DataAttribute.ROWS_COUNT; +import static io.trino.client.spooling.DataAttribute.ROW_OFFSET; +import static io.trino.client.spooling.DataAttributes.empty; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestQueryDataSerialization +{ + private static final List COLUMNS_LIST = ImmutableList.of(new Column("_col0", "bigint", new ClientTypeSignature("bigint"))); + private static final JsonCodec CODEC = jsonCodec(QueryResults.class, new QueryDataClientJacksonModule(), new QueryDataJacksonModule()); + + @Test + public void testNullDataSerialization() + { + assertThat(serialize(null)).doesNotContain("data"); + assertThat(serialize(RawQueryData.of(null))).doesNotContain("data"); + } + + @Test + public void testEmptyArraySerialization() + { + testRoundTrip(COLUMNS_LIST, RawQueryData.of(ImmutableList.of()), "[]"); + + assertThatThrownBy(() -> testRoundTrip(COLUMNS_LIST, RawQueryData.of(ImmutableList.of(ImmutableList.of())), "[[]]")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("row/column size mismatch"); + } + + @Test + public void testQueryDataSerialization() + { + Iterable> values = ImmutableList.of(ImmutableList.of(1L), ImmutableList.of(5L)); + testRoundTrip(COLUMNS_LIST, RawQueryData.of(values), "[[1],[5]]"); + } + + @Test + public void testEncodedQueryDataSerialization() + { + EncodedQueryData queryData = new EncodedQueryData("json-ext", ImmutableMap.of(), ImmutableList.of(Segment.inlined("[[10], [20]]".getBytes(UTF_8), dataAttributes(10, 2, 12)))); + testRoundTrip(COLUMNS_LIST, queryData, """ + { + "encodingId": "json-ext", + "segments": [ + { + "type": "inline", + "data": "W1sxMF0sIFsyMF1d", + "metadata": { + "rowOffset": 10, + "rowsCount": 2, + "byteSize": 12 + } + } + ] + }"""); + } + + @Test + public void testEncodedQueryDataSerializationWithExtraMetadata() + { + EncodedQueryData queryData = new 
EncodedQueryData("json-ext", ImmutableMap.of("decryptionKey", "secret"), ImmutableList.of(Segment.inlined("[[10], [20]]".getBytes(UTF_8), dataAttributes(10, 2, 12)))); + testRoundTrip(COLUMNS_LIST, queryData, """ + { + "encodingId": "json-ext", + "metadata": { + "decryptionKey": "secret" + }, + "segments": [ + { + "type": "inline", + "data": "W1sxMF0sIFsyMF1d", + "metadata": { + "rowOffset": 10, + "rowsCount": 2, + "byteSize": 12 + } + } + ] + }"""); + } + + @Test + public void testSpooledQueryDataSerialization() + { + DataAttributes attributes = DataAttributes.builder() + .set(ENCRYPTION_KEY, "super secret key") + .build(); + + EncodedQueryData queryData = EncodedQueryData.builder("json-ext") + .withAttributes(attributes) + .withSegment(Segment.inlined("super".getBytes(UTF_8), dataAttributes(0, 100, 5))) + .withSegment(Segment.spooled(URI.create("http://localhost:8080/v1/download/20160128_214710_00012_rk68b/segments/1"), dataAttributes(100, 100, 1024))) + .withSegment(Segment.spooled(URI.create("http://localhost:8080/v1/download/20160128_214710_00012_rk68b/segments/2"), dataAttributes(200, 100, 1024))) + .build(); + testSerializationRoundTrip(queryData, """ + { + "encodingId": "json-ext", + "metadata": { + "encryptionKey": "super secret key" + }, + "segments": [ + { + "type": "inline", + "data": "c3VwZXI=", + "metadata": { + "rowOffset": 0, + "rowsCount": 100, + "byteSize": 5 + } + }, + { + "type": "spooled", + "dataUri": "http://localhost:8080/v1/download/20160128_214710_00012_rk68b/segments/1", + "metadata": { + "rowOffset": 100, + "rowsCount": 100, + "byteSize": 1024 + } + }, + { + "type": "spooled", + "dataUri": "http://localhost:8080/v1/download/20160128_214710_00012_rk68b/segments/2", + "metadata": { + "rowOffset": 200, + "rowsCount": 100, + "byteSize": 1024 + } + } + ] + }"""); + } + + @Test + public void testEncodedQueryDataToString() + { + EncodedQueryData inlineQueryData = new EncodedQueryData("json-ext", ImmutableMap.of("decryption_key", "secret"), ImmutableList.of(Segment.inlined("[[10], [20]]".getBytes(UTF_8), dataAttributes(10, 2, 12)))); + assertThat(inlineQueryData.toString()).isEqualTo("EncodedQueryData{encodingId=json-ext, segments=[InlineSegment{offset=10, rows=2, size=12}], metadata=[decryption_key]}"); + + EncodedQueryData spooledQueryData = new EncodedQueryData("json-ext+zstd", ImmutableMap.of("decryption_key", "secret"), ImmutableList.of(Segment.spooled(URI.create("http://coordinator:8080/v1/segments/uuid"), dataAttributes(10, 2, 1256)))); + assertThat(spooledQueryData.toString()).isEqualTo("EncodedQueryData{encodingId=json-ext+zstd, segments=[SpooledSegment{offset=10, rows=2, size=1256}], metadata=[decryption_key]}"); + } + + private void testRoundTrip(List columns, QueryData queryData, String expectedDataRepresentation) + { + testSerializationRoundTrip(queryData, expectedDataRepresentation); + assertEquals(columns, deserialize(serialize(queryData)), queryData); + } + + private void testSerializationRoundTrip(QueryData queryData, String expectedDataRepresentation) + { + assertThat(serialize(queryData)) + .isEqualToIgnoringWhitespace(queryResultsJson(expectedDataRepresentation)); + } + + private String queryResultsJson(String expectedDataField) + { + return format(""" + { + "id": "20160128_214710_00012_rk68b", + "infoUri": "http://coordinator/query.html?20160128_214710_00012_rk68b", + "partialCancelUri": null, + "nextUri": null, + "columns": [ + { + "name": "_col0", + "type": "bigint", + "typeSignature": { + "rawType": "bigint", + "arguments": [] + } + } + ], + 
"data": %s, + "stats": { + "state": "FINISHED", + "queued": false, + "scheduled": false, + "progressPercentage": null, + "runningPercentage": null, + "nodes": 0, + "totalSplits": 0, + "queuedSplits": 0, + "runningSplits": 0, + "completedSplits": 0, + "cpuTimeMillis": 0, + "wallTimeMillis": 0, + "queuedTimeMillis": 0, + "elapsedTimeMillis": 0, + "processedRows": 0, + "processedBytes": 0, + "physicalInputBytes": 0, + "physicalWrittenBytes": 0, + "peakMemoryBytes": 0, + "spilledBytes": 0, + "rootStage": null + }, + "error": null, + "warnings": [], + "updateType": null, + "updateCount": null + }""", expectedDataField); + } + + private static void assertEquals(List columns, QueryData left, QueryData right) + { + Iterable> leftValues = decodeData(left, columns); + Iterable> rightValues = decodeData(right, columns); + + if (leftValues == null) { + assertThat(rightValues).isNull(); + return; + } + + assertThat(leftValues).hasSameElementsAs(rightValues); + } + + private static Iterable> decodeData(QueryData data, List columns) + { + if (data instanceof RawQueryData) { + return fixData(columns, data.getData()); + } + else if (data instanceof EncodedQueryData queryDataV2 && queryDataV2.getSegments().getFirst() instanceof InlineSegment inlineSegment) { + try { + return new JsonQueryDataDecoder.Factory().create(columns, empty()).decode(new ByteArrayInputStream(inlineSegment.getData()), inlineSegment.getMetadata()); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + throw new AssertionError("Unexpected data type: " + data.getClass().getSimpleName()); + } + + private static QueryData deserialize(String serialized) + { + try { + return CODEC.fromJson(serialized).getData(); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private static String serialize(QueryData data) + { + return CODEC.toJson(new QueryResults( + "20160128_214710_00012_rk68b", + URI.create("http://coordinator/query.html?20160128_214710_00012_rk68b"), + null, + null, + ImmutableList.of(new Column("_col0", BIGINT, new ClientTypeSignature(BIGINT))), + data, + StatementStats.builder() + .setState("FINISHED") + .setProgressPercentage(OptionalDouble.empty()) + .setRunningPercentage(OptionalDouble.empty()) + .build(), + null, + ImmutableList.of(), + null, + null)); + } + + private DataAttributes dataAttributes(long currentOffset, long rowCount, int byteSize) + { + return DataAttributes.builder() + .set(ROW_OFFSET, currentOffset) + .set(ROWS_COUNT, rowCount) + .set(BYTE_SIZE, byteSize) + .build(); + } +} diff --git a/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultRows.java b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultRows.java index 7580e1293419..358b5be5f3f3 100644 --- a/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultRows.java +++ b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultRows.java @@ -49,6 +49,8 @@ import static io.trino.client.ClientStandardTypes.ROW; import static io.trino.client.ClientStandardTypes.TIMESTAMP; import static io.trino.client.ClientStandardTypes.TIMESTAMP_WITH_TIME_ZONE; +import static io.trino.server.protocol.JsonArrayResultsIterator.toIterableList; +import static io.trino.server.protocol.ProtocolUtil.createColumn; import static io.trino.server.protocol.QueryResultRows.queryResultRowsBuilder; import static io.trino.spi.type.TypeSignature.mapType; import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; @@ -68,10 +70,10 @@ public void 
shouldNotReturnValues() { QueryResultRows rows = QueryResultRows.empty(getSession()); - assertThat((Iterable>) rows).as("rows").isEmpty(); - assertThat(getAllValues(rows)).hasSize(0); + assertThat((Iterable>) toIterableList(getSession(), rows, _ -> {})).as("rows").isEmpty(); + assertThat(getAllValues(rows, ignoredException -> {})).hasSize(0); assertThat(rows.getColumns()).isEmpty(); - assertThat(rows.iterator().hasNext()).isFalse(); + assertThat(toIterableList(getSession(), rows, _ -> {}).iterator().hasNext()).isFalse(); } @Test @@ -83,8 +85,8 @@ public void shouldReturnSingleValue() .withSingleBooleanValue(column, true) .build(); - assertThat((Iterable>) rows).as("rows").isNotEmpty(); - assertThat(getAllValues(rows)).hasSize(1).containsOnly(ImmutableList.of(true)); + assertThat((Iterable>) toIterableList(getSession(), rows, _ -> {})).as("rows").isNotEmpty(); + assertThat(getAllValues(rows, ignoredException -> {})).hasSize(1).containsOnly(ImmutableList.of(true)); assertThat(rows.getColumns().orElseThrow()).containsOnly(column); } @@ -99,11 +101,11 @@ public void shouldReturnUpdateCount() .addPages(rowPagesBuilder(BigintType.BIGINT).row(value).build()) .build(); - assertThat((Iterable>) rows).as("rows").isNotEmpty(); + assertThat((Iterable>) toIterableList(getSession(), rows, _ -> {})).as("rows").isNotEmpty(); assertThat(rows.getUpdateCount()).isPresent(); assertThat(rows.getUpdateCount().get()).isEqualTo(value); - assertThat(getAllValues(rows)).containsExactly(ImmutableList.of(value)); + assertThat(getAllValues(rows, ignoredException -> {})).containsExactly(ImmutableList.of(value)); assertThat(rows.getColumns().orElseThrow()).containsOnly(column); } @@ -116,9 +118,9 @@ public void shouldNotHaveUpdateCount() .withSingleBooleanValue(column, false) .build(); - assertThat((Iterable>) rows).as("rows").isNotEmpty(); + assertThat((Iterable>) toIterableList(getSession(), rows, _ -> {})).as("rows").isNotEmpty(); assertThat(rows.getUpdateCount()).isEmpty(); - assertThat(rows.iterator()).hasNext(); + assertThat(toIterableList(getSession(), rows, _ -> {}).iterator().hasNext()).isTrue(); } @Test @@ -145,15 +147,14 @@ public void shouldReadAllValuesFromMultiplePages() QueryResultRows rows = queryResultRowsBuilder(getSession()) .withColumnsAndTypes(columns, types) .addPages(pages) - .withExceptionConsumer(exceptionConsumer) .build(); - assertThat((Iterable>) rows).as("rows").isNotEmpty(); + assertThat((Iterable>) toIterableList(getSession(), rows, _ -> {})).as("rows").isNotEmpty(); assertThat(rows.getTotalRowsCount()).isEqualTo(10); assertThat(rows.getColumns()).isEqualTo(Optional.of(columns)); assertThat(rows.getUpdateCount()).isEmpty(); - assertThat(getAllValues(rows)).containsExactly( + assertThat(getAllValues(rows, exceptionConsumer)).containsExactly( ImmutableList.of(0, 10L), ImmutableList.of(1, 11L), ImmutableList.of(2, 12L), @@ -171,7 +172,9 @@ public void shouldReadAllValuesFromMultiplePages() @Test public void shouldOmitBadRows() { - List columns = ImmutableList.of(BOOLEAN_COLUMN.apply("_col0"), BOOLEAN_COLUMN.apply("_col1")); + List columns = ImmutableList.of( + createColumn("_col0", BogusType.BOGUS, true), + createColumn("_col1", BogusType.BOGUS, true)); List types = ImmutableList.of(BogusType.BOGUS, BogusType.BOGUS); List pages = rowPagesBuilder(types) @@ -185,7 +188,6 @@ public void shouldOmitBadRows() TestExceptionConsumer exceptionConsumer = new TestExceptionConsumer(); QueryResultRows rows = queryResultRowsBuilder(getSession()) .withColumnsAndTypes(columns, types) - 
.withExceptionConsumer(exceptionConsumer) .addPages(pages) .build(); @@ -196,7 +198,7 @@ public void shouldOmitBadRows() assertThat(rows.getColumns()).isEqualTo(Optional.of(columns)); assertThat(rows.getUpdateCount().isEmpty()).isTrue(); - assertThat(getAllValues(rows)) + assertThat(getAllValues(rows, exceptionConsumer)) .containsExactly(ImmutableList.of(0, 0)); List exceptions = exceptionConsumer.getExceptions(); @@ -245,7 +247,6 @@ public void shouldHandleNullValues() TestExceptionConsumer exceptionConsumer = new TestExceptionConsumer(); QueryResultRows rows = queryResultRowsBuilder(getSession()) .withColumnsAndTypes(columns, types) - .withExceptionConsumer(exceptionConsumer) .addPages(pages) .build(); @@ -254,7 +255,7 @@ public void shouldHandleNullValues() .isFalse(); assertThat(rows.getTotalRowsCount()).isEqualTo(3); - assertThat(getAllValues(rows)) + assertThat(getAllValues(rows, exceptionConsumer)) .hasSize(3) .containsExactly(newArrayList(0, null), newArrayList(1, null), newArrayList(2, true)); } @@ -274,7 +275,6 @@ public void shouldHandleNullTimestamps() TestExceptionConsumer exceptionConsumer = new TestExceptionConsumer(); QueryResultRows rows = queryResultRowsBuilder(getSession()) .withColumnsAndTypes(columns, types) - .withExceptionConsumer(exceptionConsumer) .addPages(pages) .build(); @@ -284,7 +284,7 @@ public void shouldHandleNullTimestamps() .isFalse(); assertThat(rows.getTotalRowsCount()).isEqualTo(1); - assertThat(getAllValues(rows)) + assertThat(getAllValues(rows, exceptionConsumer)) .hasSize(1) .containsExactly(newArrayList(null, null)); } @@ -302,7 +302,6 @@ public void shouldHandleNullValuesInArray() TestExceptionConsumer exceptionConsumer = new TestExceptionConsumer(); QueryResultRows rows = queryResultRowsBuilder(getSession()) .withColumnsAndTypes(columns, types) - .withExceptionConsumer(exceptionConsumer) .addPages(pages) .build(); @@ -312,7 +311,7 @@ public void shouldHandleNullValuesInArray() .isFalse(); assertThat(rows.getTotalRowsCount()).isEqualTo(1); - assertThat(getAllValues(rows)) + assertThat(getAllValues(rows, exceptionConsumer)) .hasSize(1) .containsOnly(singletonList(singletonList(null))); @@ -332,7 +331,6 @@ public void shouldHandleNullValuesInMap() TestExceptionConsumer exceptionConsumer = new TestExceptionConsumer(); QueryResultRows rows = queryResultRowsBuilder(getSession()) .withColumnsAndTypes(columns, types) - .withExceptionConsumer(exceptionConsumer) .addPages(pages) .build(); @@ -342,7 +340,7 @@ public void shouldHandleNullValuesInMap() .isFalse(); assertThat(rows.getTotalRowsCount()).isEqualTo(1); - assertThat(getAllValues(rows)) + assertThat(getAllValues(rows, exceptionConsumer)) .hasSize(1) .containsOnly(singletonList(singletonMap(10L, null))); @@ -366,7 +364,6 @@ public void shouldHandleNullValuesInRow() TestExceptionConsumer exceptionConsumer = new TestExceptionConsumer(); QueryResultRows rows = queryResultRowsBuilder(getSession()) .withColumnsAndTypes(columns, types) - .withExceptionConsumer(exceptionConsumer) .addPages(pages) .build(); @@ -376,7 +373,7 @@ public void shouldHandleNullValuesInRow() .isFalse(); assertThat(rows.getTotalRowsCount()).isEqualTo(1); - List> allValues = getAllValues(rows); + List> allValues = getAllValues(rows, exceptionConsumer); assertThat(allValues) .hasSize(1) @@ -455,11 +452,10 @@ public void shouldThrowWhenDataIsPresentWithoutColumns() .hasMessage("data present without columns and types"); } - private static List> getAllValues(QueryResultRows rows) + private static List> getAllValues(QueryResultRows rows, 
Consumer throwableConsumer) { ImmutableList.Builder> builder = ImmutableList.builder(); - - for (List values : rows) { + for (List values : toIterableList(getSession(), rows, throwableConsumer)) { builder.add(values); } diff --git a/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultsSerialization.java b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultsSerialization.java new file mode 100644 index 000000000000..09a02d7f2fbe --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/server/protocol/TestQueryResultsSerialization.java @@ -0,0 +1,188 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import io.airlift.json.ObjectMapperProvider; +import io.trino.client.ClientTypeSignature; +import io.trino.client.Column; +import io.trino.client.QueryData; +import io.trino.client.QueryDataClientJacksonModule; +import io.trino.client.QueryResults; +import io.trino.client.RawQueryData; +import io.trino.client.StatementStats; +import io.trino.server.protocol.spooling.QueryDataJacksonModule; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.List; +import java.util.OptionalDouble; + +import static io.trino.client.ClientStandardTypes.BIGINT; +import static io.trino.client.FixJsonDataUtils.fixData; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestQueryResultsSerialization +{ + private static final List COLUMNS = ImmutableList.of(new Column("_col0", BIGINT, new ClientTypeSignature("bigint"))); + + // As close as possible to the server mapper (client mapper differs) + private static final ObjectMapper MAPPER = new ObjectMapperProvider().get() + .registerModules(new QueryDataJacksonModule(), new QueryDataClientJacksonModule()); + + @Test + public void testNullDataSerialization() + { + // data field should not be serialized + assertThat(serialize(null)).isEqualToIgnoringWhitespace(""" + { + "id" : "20160128_214710_00012_rk68b", + "infoUri" : "http://coordinator/query.html?20160128_214710_00012_rk68b", + "columns" : [ { + "name" : "_col0", + "type" : "bigint", + "typeSignature" : { + "rawType" : "bigint", + "arguments" : [ ] + } + } ], + "stats" : { + "state" : "FINISHED", + "queued" : false, + "scheduled" : false, + "nodes" : 0, + "totalSplits" : 0, + "queuedSplits" : 0, + "runningSplits" : 0, + "completedSplits" : 0, + "cpuTimeMillis" : 0, + "wallTimeMillis" : 0, + "queuedTimeMillis" : 0, + "elapsedTimeMillis" : 0, + "processedRows" : 0, + "processedBytes" : 0, + "physicalInputBytes" : 0, + "physicalWrittenBytes" : 0, + "peakMemoryBytes" : 0, + "spilledBytes" : 0 + }, + "warnings" : [ ] + } + """); + } + + @Test + public void testEmptyArraySerialization() + { + 
testRoundTrip(RawQueryData.of(ImmutableList.of()), "[]"); + + assertThatThrownBy(() -> testRoundTrip(RawQueryData.of(ImmutableList.of(ImmutableList.of())), "[[]]")) + .isInstanceOf(RuntimeException.class) + .hasMessage("row/column size mismatch"); + } + + @Test + public void testSerialization() + { + QueryData values = RawQueryData.of(ImmutableList.of(ImmutableList.of(1L), ImmutableList.of(5L))); + testRoundTrip(values, "[[1],[5]]"); + } + + private void testRoundTrip(QueryData results, String expectedDataRepresentation) + { + assertThat(serialize(results)) + .isEqualToIgnoringWhitespace(queryResultsJson(expectedDataRepresentation)); + + String serialized = removeNewLines(serialize(results)); + try { + assertThat(fixData(COLUMNS, MAPPER.readValue(serialized, QueryResults.class).getData().getData())).hasSameElementsAs(results.getData()); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private String queryResultsJson(String expectedDataField) + { + return removeNewLines(format(""" + { + "id" : "20160128_214710_00012_rk68b", + "infoUri" : "http://coordinator/query.html?20160128_214710_00012_rk68b", + "columns" : [ { + "name" : "_col0", + "type" : "bigint", + "typeSignature" : { + "rawType" : "bigint", + "arguments" : [ ] + } + } ], + "data" : %s, + "stats" : { + "state" : "FINISHED", + "queued" : false, + "scheduled" : false, + "nodes" : 0, + "totalSplits" : 0, + "queuedSplits" : 0, + "runningSplits" : 0, + "completedSplits" : 0, + "cpuTimeMillis" : 0, + "wallTimeMillis" : 0, + "queuedTimeMillis" : 0, + "elapsedTimeMillis" : 0, + "processedRows" : 0, + "processedBytes" : 0, + "physicalInputBytes" : 0, + "physicalWrittenBytes" : 0, + "peakMemoryBytes" : 0, + "spilledBytes" : 0 + }, + "warnings" : [ ] + }""", expectedDataField)); + } + + private static String serialize(QueryData data) + { + try { + return removeNewLines(MAPPER.writeValueAsString(new QueryResults( + "20160128_214710_00012_rk68b", + URI.create("http://coordinator/query.html?20160128_214710_00012_rk68b"), + null, + null, + ImmutableList.of(new Column("_col0", BIGINT, new ClientTypeSignature(BIGINT))), + data, + StatementStats.builder() + .setState("FINISHED") + .setProgressPercentage(OptionalDouble.empty()) + .setRunningPercentage(OptionalDouble.empty()) + .build(), + null, + ImmutableList.of(), + null, + null))); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private static String removeNewLines(String value) + { + return value.replaceAll("\n", "").replaceAll("\r", "").replaceAll("\\s+", ""); + } +} diff --git a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java index 17796985d129..e858ae2ae329 100644 --- a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java +++ b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java @@ -1211,7 +1211,8 @@ public TestResource(AccessControl accessControl) createTestMetadataManager(), user -> ImmutableSet.of(), accessControl, - new ProtocolConfig()); + new ProtocolConfig(), + _ -> Optional.empty()); } @ResourceSecurity(AUTHENTICATED_USER) diff --git a/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java b/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java index 52eed2b0e1de..ad498b2cb058 100644 --- a/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java +++ b/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java 
@@ -424,7 +424,8 @@ public TestResource(AccessControl accessControl) createTestMetadataManager(), ImmutableSet::of, accessControl, - new ProtocolConfig()); + new ProtocolConfig(), + _ -> Optional.empty()); } @ResourceSecurity(WEB_UI) diff --git a/core/trino-server/src/main/provisio/trino.xml b/core/trino-server/src/main/provisio/trino.xml index 6e86a8fb0d77..0a1dcce9e94f 100644 --- a/core/trino-server/src/main/provisio/trino.xml +++ b/core/trino-server/src/main/provisio/trino.xml @@ -327,6 +327,12 @@ + + + + + + diff --git a/core/trino-spi/src/main/java/io/trino/spi/Plugin.java b/core/trino-spi/src/main/java/io/trino/spi/Plugin.java index ec3791bb1937..f812f21671f6 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/Plugin.java +++ b/core/trino-spi/src/main/java/io/trino/spi/Plugin.java @@ -18,6 +18,7 @@ import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.eventlistener.EventListenerFactory; import io.trino.spi.exchange.ExchangeManagerFactory; +import io.trino.spi.protocol.SpoolingManagerFactory; import io.trino.spi.resourcegroups.ResourceGroupConfigurationManagerFactory; import io.trino.spi.security.CertificateAuthenticatorFactory; import io.trino.spi.security.GroupProviderFactory; @@ -109,4 +110,9 @@ default Iterable getExchangeManagerFactories() { return emptyList(); } + + default Iterable getSpoolingManagerFactories() + { + return emptyList(); + } } diff --git a/core/trino-spi/src/main/java/io/trino/spi/protocol/SpooledSegmentHandle.java b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpooledSegmentHandle.java new file mode 100644 index 000000000000..0872be4dcf75 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpooledSegmentHandle.java @@ -0,0 +1,18 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.protocol; + +public interface SpooledSegmentHandle +{ +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingContext.java b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingContext.java new file mode 100644 index 000000000000..2c5850b7e35b --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingContext.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spi.protocol; + +import io.trino.spi.QueryId; + +public record SpoolingContext(QueryId queryId, long rowCount) +{ + public SpoolingContext + { + if (queryId == null) { + throw new IllegalArgumentException("queryId is null"); + } + + if (rowCount < 0) { + throw new IllegalArgumentException("rowCount is negative"); + } + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManager.java b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManager.java new file mode 100644 index 000000000000..8fb9b5aca612 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManager.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.protocol; + +import io.airlift.slice.Slice; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.URI; +import java.util.Optional; + +public interface SpoolingManager +{ + SpooledSegmentHandle create(SpoolingContext context); + + OutputStream createOutputStream(SpooledSegmentHandle handle) + throws Exception; + + InputStream openInputStream(SpooledSegmentHandle handle) + throws IOException; + + default void acknowledge(SpooledSegmentHandle handle) + throws IOException + { + } + + default Optional directLocation(SpooledSegmentHandle handle) + { + return Optional.empty(); + } + + default Slice serialize(SpooledSegmentHandle handle) + { + throw new UnsupportedOperationException(); + } + + default SpooledSegmentHandle deserialize(Slice slice) + { + throw new UnsupportedOperationException(); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManagerContext.java b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManagerContext.java new file mode 100644 index 000000000000..ae98d2aa1392 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManagerContext.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
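Taken together, the SpoolingManager contract is: create a handle for a spooling context, write the segment through an output stream, read it back, and acknowledge it once the client has consumed it; serialize and deserialize let the engine round-trip handles through the protocol. A minimal in-memory implementation, sketched here only to illustrate that contract (class name and storage choice are hypothetical, not part of this change):

import io.trino.spi.protocol.SpooledSegmentHandle;
import io.trino.spi.protocol.SpoolingContext;
import io.trino.spi.protocol.SpoolingManager;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

public class InMemorySpoolingManager
        implements SpoolingManager
{
    private record InMemoryHandle(String name)
            implements SpooledSegmentHandle {}

    private final Map<String, ByteArrayOutputStream> segments = new ConcurrentHashMap<>();

    @Override
    public SpooledSegmentHandle create(SpoolingContext context)
    {
        // Mirror the filesystem implementation: derive the segment name from the query id
        return new InMemoryHandle(context.queryId().getId() + "/" + UUID.randomUUID());
    }

    @Override
    public OutputStream createOutputStream(SpooledSegmentHandle handle)
    {
        ByteArrayOutputStream output = new ByteArrayOutputStream();
        segments.put(((InMemoryHandle) handle).name(), output);
        return output;
    }

    @Override
    public InputStream openInputStream(SpooledSegmentHandle handle)
    {
        ByteArrayOutputStream output = segments.get(((InMemoryHandle) handle).name());
        if (output == null) {
            throw new IllegalStateException("Segment was never written or was already acknowledged");
        }
        return new ByteArrayInputStream(output.toByteArray());
    }

    @Override
    public void acknowledge(SpooledSegmentHandle handle)
    {
        // The client has fully read the segment, so it can be dropped eagerly
        segments.remove(((InMemoryHandle) handle).name());
    }
}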
+ */ +package io.trino.spi.protocol; + +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Tracer; + +public interface SpoolingManagerContext +{ + default OpenTelemetry getOpenTelemetry() + { + throw new UnsupportedOperationException(); + } + + default Tracer getTracer() + { + throw new UnsupportedOperationException(); + } +} diff --git a/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManagerFactory.java b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManagerFactory.java new file mode 100644 index 000000000000..0c151015f107 --- /dev/null +++ b/core/trino-spi/src/main/java/io/trino/spi/protocol/SpoolingManagerFactory.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spi.protocol; + +import java.util.Map; + +public interface SpoolingManagerFactory +{ + String getName(); + + SpoolingManager create(Map config, SpoolingManagerContext context); +} diff --git a/plugin/trino-spooling-filesystem/pom.xml b/plugin/trino-spooling-filesystem/pom.xml new file mode 100644 index 000000000000..cbeb8c751d9d --- /dev/null +++ b/plugin/trino-spooling-filesystem/pom.xml @@ -0,0 +1,145 @@ + + + 4.0.0 + + io.trino + trino-root + 454-SNAPSHOT + ../../pom.xml + + + trino-spooling-filesystem + trino-plugin + Trino - Spooling filesystem + + + ${project.parent.basedir} + + + + + com.google.guava + guava + + + + com.google.inject + guice + + + + io.airlift + bootstrap + + + + io.airlift + configuration + + + + io.airlift + log + + + + io.airlift + units + + + + io.trino + trino-filesystem + + + + io.trino + trino-filesystem-azure + + + + io.trino + trino-filesystem-gcs + + + + io.trino + trino-filesystem-s3 + + + + io.trino + trino-plugin-toolkit + + + + jakarta.validation + jakarta.validation-api + + + + org.weakref + jmxutils + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + io.airlift + slice + provided + + + + io.opentelemetry + opentelemetry-api + provided + + + + io.opentelemetry + opentelemetry-context + provided + + + + io.trino + trino-spi + provided + + + + org.openjdk.jol + jol-core + provided + + + + io.airlift + junit-extensions + test + + + + io.trino + trino-testing-containers + test + + + + org.assertj + assertj-core + test + + + + org.junit.jupiter + junit-jupiter-api + test + + + diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpooledSegmentHandle.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpooledSegmentHandle.java new file mode 100644 index 000000000000..a5979639fb22 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpooledSegmentHandle.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
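A spooling backend reaches the engine the same way connectors do: a Plugin contributes SpoolingManagerFactory instances, and the factory name is what configuration (or QueryRunner.loadSpoolingManager, further below) selects. A hedged sketch pairing the new Plugin hook with the in-memory manager sketched earlier (the "memory" name and classes are hypothetical):

import io.trino.spi.Plugin;
import io.trino.spi.protocol.SpoolingManager;
import io.trino.spi.protocol.SpoolingManagerContext;
import io.trino.spi.protocol.SpoolingManagerFactory;

import java.util.List;
import java.util.Map;

public class InMemorySpoolingPlugin
        implements Plugin
{
    @Override
    public Iterable<SpoolingManagerFactory> getSpoolingManagerFactories()
    {
        return List.of(new SpoolingManagerFactory()
        {
            @Override
            public String getName()
            {
                // Referenced by name when the spooling manager is loaded
                return "memory";
            }

            @Override
            public SpoolingManager create(Map<String, String> config, SpoolingManagerContext context)
            {
                // No configuration needed here; a real factory would bootstrap its own modules
                return new InMemorySpoolingManager();
            }
        });
    }
}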
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem; + +import io.airlift.units.Duration; +import io.trino.spi.protocol.SpooledSegmentHandle; +import io.trino.spi.protocol.SpoolingContext; + +import java.util.Date; + +import static java.util.Objects.requireNonNull; +import static java.util.UUID.randomUUID; + +public record FileSystemSpooledSegmentHandle(String name, Date validUntil) + implements SpooledSegmentHandle +{ + public FileSystemSpooledSegmentHandle + { + requireNonNull(name, "name is null"); + requireNonNull(validUntil, "validUntil is null"); + } + + public static FileSystemSpooledSegmentHandle random(SpoolingContext context, Duration ttl) + { + return new FileSystemSpooledSegmentHandle(context.queryId().getId() + "/" + randomObjectName(), new Date(System.currentTimeMillis() + ttl.toMillis())); + } + + private static String randomObjectName() + { + return randomUUID().toString().replace("-", ""); + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingConfig.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingConfig.java new file mode 100644 index 000000000000..08c111aba5f1 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingConfig.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
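A handle produced by random() prefixes the object name with the query id and carries an absolute expiration derived from the configured TTL. A short illustration (values are hypothetical):

import io.airlift.units.Duration;
import io.trino.spi.QueryId;
import io.trino.spi.protocol.SpoolingContext;
import io.trino.spooling.filesystem.FileSystemSpooledSegmentHandle;

import java.util.Date;

import static java.util.concurrent.TimeUnit.HOURS;

class SpooledSegmentHandleExample
{
    public static void main(String[] args)
    {
        SpoolingContext context = new SpoolingContext(QueryId.valueOf("20240212_131239_00001_abcde"), 1000);
        FileSystemSpooledSegmentHandle handle = FileSystemSpooledSegmentHandle.random(context, new Duration(2, HOURS));

        // name() is "<queryId>/<32 hex characters>" and is later resolved against the configured location;
        // validUntil() is roughly now + 2h and is checked before the segment is located or read
        System.out.println(handle.name());
        System.out.println(handle.validUntil().after(new Date()));
    }
}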
+ */ +package io.trino.spooling.filesystem; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import io.airlift.units.Duration; +import jakarta.validation.constraints.AssertTrue; + +import static java.util.concurrent.TimeUnit.HOURS; + +public class FileSystemSpoolingConfig +{ + private boolean nativeAzureEnabled; + private boolean nativeS3Enabled; + private boolean nativeGcsEnabled; + private String location; + + private Duration ttl = new Duration(2, HOURS); + + public boolean isNativeAzureEnabled() + { + return nativeAzureEnabled; + } + + @Config("fs.native-azure.enabled") + public FileSystemSpoolingConfig setNativeAzureEnabled(boolean nativeAzureEnabled) + { + this.nativeAzureEnabled = nativeAzureEnabled; + return this; + } + + public boolean isNativeS3Enabled() + { + return nativeS3Enabled; + } + + @Config("fs.native-s3.enabled") + public FileSystemSpoolingConfig setNativeS3Enabled(boolean nativeS3Enabled) + { + this.nativeS3Enabled = nativeS3Enabled; + return this; + } + + public boolean isNativeGcsEnabled() + { + return nativeGcsEnabled; + } + + @Config("fs.native-gcs.enabled") + public FileSystemSpoolingConfig setNativeGcsEnabled(boolean nativeGcsEnabled) + { + this.nativeGcsEnabled = nativeGcsEnabled; + return this; + } + + public String getLocation() + { + return location; + } + + @Config("location") + public FileSystemSpoolingConfig setLocation(String location) + { + this.location = location; + return this; + } + + public Duration getTtl() + { + return ttl; + } + + @ConfigDescription("Maximum duration for the client to retrieve spooled segment") + @Config("ttl") + public void setTtl(Duration ttl) + { + this.ttl = ttl; + } + + @AssertTrue(message = "At least one native file system must be enabled") + public boolean isEitherNativeFileSystemEnabled() + { + return nativeAzureEnabled || nativeS3Enabled || nativeGcsEnabled; + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManager.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManager.java new file mode 100644 index 000000000000..9f332df09abb --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManager.java @@ -0,0 +1,136 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
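These setters bind to the fs.native-azure.enabled, fs.native-s3.enabled, fs.native-gcs.enabled, location and ttl properties. A conventional airlift config-mapping check, shown only to illustrate the property names (it is not part of this patch):

import com.google.common.collect.ImmutableMap;
import io.airlift.units.Duration;
import io.trino.spooling.filesystem.FileSystemSpoolingConfig;

import java.util.Map;

import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping;
import static java.util.concurrent.TimeUnit.MINUTES;

class FileSystemSpoolingConfigMappingExample
{
    void verifyMapping()
    {
        Map<String, String> properties = ImmutableMap.<String, String>builder()
                .put("fs.native-azure.enabled", "true")
                .put("fs.native-s3.enabled", "true")
                .put("fs.native-gcs.enabled", "true")
                .put("location", "s3://spooling-bucket/segments")
                .put("ttl", "30m")
                .buildOrThrow();

        FileSystemSpoolingConfig expected = new FileSystemSpoolingConfig()
                .setNativeAzureEnabled(true)
                .setNativeS3Enabled(true)
                .setNativeGcsEnabled(true)
                .setLocation("s3://spooling-bucket/segments");
        expected.setTtl(new Duration(30, MINUTES));

        assertFullMapping(properties, expected);
    }
}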
+ */ +package io.trino.spooling.filesystem; + +import com.google.inject.Inject; +import io.airlift.log.Logger; +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.airlift.units.Duration; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoFileSystemFactory; +import io.trino.spi.protocol.SpooledSegmentHandle; +import io.trino.spi.protocol.SpoolingContext; +import io.trino.spi.protocol.SpoolingManager; +import io.trino.spi.security.ConnectorIdentity; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UncheckedIOException; +import java.util.Date; + +import static io.airlift.slice.Slices.utf8Slice; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; + +public class FileSystemSpoolingManager + implements SpoolingManager +{ + private static final Logger log = Logger.get(FileSystemSpoolingManager.class); + + private final String location; + private final TrinoFileSystem fileSystem; + private final Duration ttl; + + @Inject + public FileSystemSpoolingManager(FileSystemSpoolingConfig config, TrinoFileSystemFactory fileSystemFactory) + { + requireNonNull(config, "config is null"); + this.location = config.getLocation(); + this.fileSystem = requireNonNull(fileSystemFactory, "fileSystemFactory is null") + .create(ConnectorIdentity.ofUser("ignored")); + this.ttl = config.getTtl(); + } + + @Override + public OutputStream createOutputStream(SpooledSegmentHandle handle) + throws Exception + { + return fileSystem.newOutputFile(segmentLocation((FileSystemSpooledSegmentHandle) handle)).create(); + } + + @Override + public FileSystemSpooledSegmentHandle create(SpoolingContext context) + { + return FileSystemSpooledSegmentHandle.random(context, ttl); + } + + @Override + public InputStream openInputStream(SpooledSegmentHandle handle) + throws IOException + { + FileSystemSpooledSegmentHandle segmentHandle = (FileSystemSpooledSegmentHandle) handle; + checkExpiration(segmentHandle); + return fileSystem.newInputFile(segmentLocation((FileSystemSpooledSegmentHandle) handle)).newStream(); + } + + @Override + public void acknowledge(SpooledSegmentHandle handle) + { + try { + fileSystem.deleteFile(segmentLocation((FileSystemSpooledSegmentHandle) handle)); + } + catch (IOException e) { + log.warn(e, "Failed to delete segment"); + } + } + + private static String safeString(String value) + { + return value.replaceAll("[^a-zA-Z0-9-_/]", "-"); + } + + private Location segmentLocation(FileSystemSpooledSegmentHandle handle) + { + checkExpiration(handle); + return Location.of(location + "/" + safeString(handle.name())); + } + + private void checkExpiration(FileSystemSpooledSegmentHandle handle) + { + if (handle.validUntil().before(new Date())) { + throw new RuntimeException("Segment has expired"); + } + } + + @Override + public Slice serialize(SpooledSegmentHandle handle) + { + FileSystemSpooledSegmentHandle fileHandle = (FileSystemSpooledSegmentHandle) handle; + try (DynamicSliceOutput output = new DynamicSliceOutput(64)) { + output.writeLong(fileHandle.validUntil().getTime()); + output.writeInt(fileHandle.name().length()); + output.writeBytes(utf8Slice(fileHandle.name())); + return output.slice(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public SpooledSegmentHandle deserialize(Slice slice) + { + try (SliceInput input = slice.getInput()) 
{ + Date validUntil = new Date(input.readLong()); + int nameLength = input.readInt(); + byte[] name = new byte[nameLength]; + input.readBytes(name); + return new FileSystemSpooledSegmentHandle(new String(name, UTF_8), validUntil); + } + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManagerFactory.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManagerFactory.java new file mode 100644 index 000000000000..408d0f7defbf --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManagerFactory.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem; + +import com.google.inject.Injector; +import io.airlift.bootstrap.Bootstrap; +import io.opentelemetry.api.OpenTelemetry; +import io.trino.plugin.base.jmx.MBeanServerModule; +import io.trino.spi.protocol.SpoolingManager; +import io.trino.spi.protocol.SpoolingManagerContext; +import io.trino.spi.protocol.SpoolingManagerFactory; +import org.weakref.jmx.guice.MBeanModule; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class FileSystemSpoolingManagerFactory + implements SpoolingManagerFactory +{ + public static final String NAME = "filesystem"; + + @Override + public String getName() + { + return NAME; + } + + @Override + public SpoolingManager create(Map config, SpoolingManagerContext context) + { + requireNonNull(config, "requiredConfig is null"); + + // A plugin is not required to use Guice; it is just very convenient + Bootstrap app = new Bootstrap( + new FilesystemSpoolingModule(), + new MBeanModule(), + new MBeanServerModule(), + binder -> { + binder.bind(SpoolingManagerContext.class).toInstance(context); + binder.bind(OpenTelemetry.class).toInstance(context.getOpenTelemetry()); + }); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + + return injector.getInstance(SpoolingManager.class); + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FilesystemSpoolingModule.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FilesystemSpoolingModule.java new file mode 100644 index 000000000000..ef82e65ec508 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FilesystemSpoolingModule.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
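With spooling enabled, the engine resolves this factory by its "filesystem" name and hands it the raw spooling-manager properties; the fs.* switches select the delegate file system and the remaining keys configure it. Roughly, with placeholder values, and assuming the context supplies the OpenTelemetry instance and Tracer that the bootstrap binds:

import io.trino.spi.protocol.SpoolingManager;
import io.trino.spi.protocol.SpoolingManagerContext;
import io.trino.spooling.filesystem.FileSystemSpoolingManagerFactory;

import java.util.Map;

class FileSystemSpoolingManagerFactoryExample
{
    SpoolingManager create(SpoolingManagerContext context)
    {
        return new FileSystemSpoolingManagerFactory().create(
                Map.of(
                        "fs.native-s3.enabled", "true",
                        "location", "s3://spooling-bucket/segments",
                        "s3.endpoint", "http://localhost:9000",
                        "s3.region", "us-east-1",
                        "s3.aws-access-key", "accesskey",
                        "s3.aws-secret-key", "secretkey",
                        "s3.path-style-access", "true"),
                context);
    }
}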
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import com.google.inject.Scopes; +import com.google.inject.Singleton; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.filesystem.TrinoFileSystemFactory; +import io.trino.filesystem.azure.AzureFileSystemFactory; +import io.trino.filesystem.azure.AzureFileSystemModule; +import io.trino.filesystem.gcs.GcsFileSystemFactory; +import io.trino.filesystem.gcs.GcsFileSystemModule; +import io.trino.filesystem.s3.S3FileSystemFactory; +import io.trino.filesystem.s3.S3FileSystemModule; +import io.trino.filesystem.tracing.TracingFileSystemFactory; +import io.trino.spi.protocol.SpoolingManager; +import io.trino.spi.protocol.SpoolingManagerContext; + +import java.util.Map; + +import static com.google.inject.multibindings.MapBinder.newMapBinder; + +public class FilesystemSpoolingModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + FileSystemSpoolingConfig config = buildConfigObject(FileSystemSpoolingConfig.class); + var factories = newMapBinder(binder, String.class, TrinoFileSystemFactory.class); + if (config.isNativeAzureEnabled()) { + install(new AzureFileSystemModule()); + factories.addBinding("abfs").to(AzureFileSystemFactory.class); + factories.addBinding("abfss").to(AzureFileSystemFactory.class); + } + if (config.isNativeS3Enabled()) { + install(new S3FileSystemModule()); + factories.addBinding("s3").to(S3FileSystemFactory.class); + factories.addBinding("s3a").to(S3FileSystemFactory.class); + factories.addBinding("s3n").to(S3FileSystemFactory.class); + } + if (config.isNativeGcsEnabled()) { + install(new GcsFileSystemModule()); + factories.addBinding("gs").to(GcsFileSystemFactory.class); + } + binder.bind(SpoolingManager.class).to(FileSystemSpoolingManager.class).in(Scopes.SINGLETON); + } + + @Provides + @Singleton + public TrinoFileSystemFactory createFileSystemFactory( + Map factories, + SpoolingManagerContext context) + { + return new TracingFileSystemFactory(context.getTracer(), _ -> new SwitchingFileSystem(factories)); + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FilesystemSpoolingPlugin.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FilesystemSpoolingPlugin.java new file mode 100644 index 000000000000..3eb77ac39095 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FilesystemSpoolingPlugin.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.spooling.filesystem; + +import io.trino.spi.Plugin; +import io.trino.spi.protocol.SpoolingManagerFactory; + +import java.util.List; + +public class FilesystemSpoolingPlugin + implements Plugin +{ + @Override + public Iterable getSpoolingManagerFactories() + { + return List.of(new FileSystemSpoolingManagerFactory()); + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/SwitchingFileSystem.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/SwitchingFileSystem.java new file mode 100644 index 000000000000..93167f122fa6 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/SwitchingFileSystem.java @@ -0,0 +1,152 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem; + +import com.google.common.collect.ImmutableMap; +import io.trino.filesystem.FileIterator; +import io.trino.filesystem.Location; +import io.trino.filesystem.TrinoFileSystem; +import io.trino.filesystem.TrinoFileSystemFactory; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.TrinoOutputFile; +import io.trino.spi.security.ConnectorIdentity; + +import java.io.IOException; +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.groupingBy; + +final class SwitchingFileSystem + implements TrinoFileSystem +{ + private final Map factories; + + public SwitchingFileSystem( + Map factories) + { + this.factories = ImmutableMap.copyOf(requireNonNull(factories, "factories is null")); + } + + @Override + public TrinoInputFile newInputFile(Location location) + { + return fileSystem(location).newInputFile(location); + } + + @Override + public TrinoInputFile newInputFile(Location location, long length) + { + return fileSystem(location).newInputFile(location, length); + } + + @Override + public TrinoOutputFile newOutputFile(Location location) + { + return fileSystem(location).newOutputFile(location); + } + + @Override + public void deleteFile(Location location) + throws IOException + { + fileSystem(location).deleteFile(location); + } + + @Override + public void deleteFiles(Collection locations) + throws IOException + { + var groups = locations.stream().collect(groupingBy(this::determineFactory)); + for (var entry : groups.entrySet()) { + createFileSystem(entry.getKey()).deleteFiles(entry.getValue()); + } + } + + @Override + public void deleteDirectory(Location location) + throws IOException + { + fileSystem(location).deleteDirectory(location); + } + + @Override + public void renameFile(Location source, Location target) + throws IOException + { + fileSystem(source).renameFile(source, target); + } + + @Override + public FileIterator listFiles(Location location) + throws IOException + { + return fileSystem(location).listFiles(location); + } + + @Override + public Optional directoryExists(Location location) 
+ throws IOException + { + return fileSystem(location).directoryExists(location); + } + + @Override + public void createDirectory(Location location) + throws IOException + { + fileSystem(location).createDirectory(location); + } + + @Override + public void renameDirectory(Location source, Location target) + throws IOException + { + fileSystem(source).renameDirectory(source, target); + } + + @Override + public Set listDirectories(Location location) + throws IOException + { + return fileSystem(location).listDirectories(location); + } + + @Override + public Optional createTemporaryDirectory(Location targetPath, String temporaryPrefix, String relativePrefix) + throws IOException + { + return fileSystem(targetPath).createTemporaryDirectory(targetPath, temporaryPrefix, relativePrefix); + } + + private TrinoFileSystem fileSystem(Location location) + { + return createFileSystem(determineFactory(location)); + } + + private TrinoFileSystemFactory determineFactory(Location location) + { + return location.scheme() + .map(factories::get) + .orElseThrow(() -> new IllegalArgumentException("No factory for location: " + location)); + } + + private TrinoFileSystem createFileSystem(TrinoFileSystemFactory factory) + { + return factory.create(ConnectorIdentity.ofUser("system")); + } +} diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManager.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManager.java new file mode 100644 index 000000000000..9dead8381fb4 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManager.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
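SwitchingFileSystem dispatches every operation to a delegate chosen purely by the location scheme, so the configured spooling location decides which installed factory actually stores segments. For illustration (locations are hypothetical):

import io.trino.filesystem.Location;

class SchemeDispatchExample
{
    public static void main(String[] args)
    {
        // scheme() drives determineFactory(): s3/s3a/s3n -> S3, abfs/abfss -> Azure, gs -> GCS
        System.out.println(Location.of("s3://spooling-bucket/segments/abc").scheme());       // Optional[s3]
        System.out.println(Location.of("abfss://container@account/segments/abc").scheme()); // Optional[abfss]
        System.out.println(Location.of("gs://spooling-bucket/segments/abc").scheme());      // Optional[gs]
        // A location whose scheme has no registered factory fails with "No factory for location: ..."
    }
}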
+ */ +package io.trino.spooling.filesystem; + +import io.airlift.units.DataSize; +import io.trino.filesystem.s3.S3FileSystemConfig; +import io.trino.filesystem.s3.S3FileSystemFactory; +import io.trino.spi.QueryId; +import io.trino.spi.protocol.SpooledSegmentHandle; +import io.trino.spi.protocol.SpoolingContext; +import io.trino.spi.protocol.SpoolingManager; +import io.trino.testing.containers.Minio; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.io.InputStream; +import java.io.OutputStream; +import java.time.Instant; +import java.util.Date; + +import static io.opentelemetry.api.OpenTelemetry.noop; +import static io.trino.testing.containers.Minio.MINIO_REGION; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@TestInstance(PER_CLASS) +public class TestFileSystemSpoolingManager +{ + private Minio minio; + + @BeforeAll + public void setup() + { + minio = Minio.builder().build(); + minio.start(); + minio.createBucket("test"); + } + + @AfterAll + public void teardown() + { + minio.stop(); + } + + @Test + public void testManager() + throws Exception + { + SpoolingManager manager = getSpoolingManager(); + SpoolingContext context = new SpoolingContext(QueryId.valueOf("a"), 0); + SpooledSegmentHandle spooledSegmentHandle = manager.create(context); + try (OutputStream segment = manager.createOutputStream(spooledSegmentHandle)) { + segment.write("data".getBytes(UTF_8)); + } + + try (InputStream output = manager.openInputStream(spooledSegmentHandle)) { + byte[] buffer = new byte[4]; + assertThat(output.read(buffer)).isEqualTo(buffer.length); + assertThat(buffer).isEqualTo("data".getBytes(UTF_8)); + } + } + + @Test + public void testHandleRoundTrip() + { + FileSystemSpooledSegmentHandle handle = new FileSystemSpooledSegmentHandle("test", Date.from(Instant.now())); + assertThat(getSpoolingManager().deserialize(getSpoolingManager().serialize(handle))).isEqualTo(handle); + } + + private SpoolingManager getSpoolingManager() + { + FileSystemSpoolingConfig spoolingConfig = new FileSystemSpoolingConfig(); + spoolingConfig.setNativeS3Enabled(true); + spoolingConfig.setLocation("s3://test"); + S3FileSystemConfig filesystemConfig = new S3FileSystemConfig() + .setEndpoint(minio.getMinioAddress()) + .setRegion(MINIO_REGION) + .setPathStyleAccess(true) + .setAwsAccessKey(Minio.MINIO_ACCESS_KEY) + .setAwsSecretKey(Minio.MINIO_SECRET_KEY) + .setStreamingPartSize(DataSize.valueOf("5.5MB")); + return new FileSystemSpoolingManager(spoolingConfig, new S3FileSystemFactory(noop(), filesystemConfig)); + } +} diff --git a/pom.xml b/pom.xml index 19c4c42807df..fcc9c1ce49a0 100644 --- a/pom.xml +++ b/pom.xml @@ -109,6 +109,7 @@ plugin/trino-session-property-managers plugin/trino-singlestore plugin/trino-snowflake + plugin/trino-spooling-filesystem plugin/trino-sqlserver plugin/trino-teradata-functions plugin/trino-thrift @@ -1191,6 +1192,13 @@ ${project.version} + + io.trino + trino-jdbc + ${project.version} + nonshaded + + io.trino trino-jmx @@ -1442,6 +1450,12 @@ test-jar + + io.trino + trino-spooling-filesystem + ${project.version} + + io.trino trino-sqlserver diff --git a/testing/trino-plugin-reader/pom.xml b/testing/trino-plugin-reader/pom.xml index 34b89eed7d96..840b210a8e5b 100644 --- a/testing/trino-plugin-reader/pom.xml +++ 
b/testing/trino-plugin-reader/pom.xml @@ -35,6 +35,20 @@ io.trino trino-main + + + com.squareup.okhttp3 + okhttp + + + com.squareup.okhttp3 + okhttp-urlconnection + + + org.bouncycastle + bcprov-jdk18on + + diff --git a/testing/trino-testing-containers/src/main/java/io/trino/testing/containers/Minio.java b/testing/trino-testing-containers/src/main/java/io/trino/testing/containers/Minio.java index d87c327b9ec7..a4dadd89dffc 100644 --- a/testing/trino-testing-containers/src/main/java/io/trino/testing/containers/Minio.java +++ b/testing/trino-testing-containers/src/main/java/io/trino/testing/containers/Minio.java @@ -111,6 +111,11 @@ public HostAndPort getMinioConsoleEndpoint() } public void createBucket(String bucketName) + { + createBucket(bucketName, false); + } + + public void createBucket(String bucketName, boolean objectLock) { try (MinioClient minioClient = createMinioClient()) { // use retry loop for minioClient.makeBucket as minio container tends to return "Server not initialized, please try again" error @@ -120,7 +125,7 @@ public void createBucket(String bucketName) .withMaxAttempts(Integer.MAX_VALUE) // limited by MaxDuration .withDelay(Duration.of(10, SECONDS)) .build(); - Failsafe.with(retryPolicy).run(() -> minioClient.makeBucket(bucketName)); + Failsafe.with(retryPolicy).run(() -> minioClient.makeBucket(bucketName, objectLock)); } } diff --git a/testing/trino-testing-containers/src/main/java/io/trino/testing/minio/MinioClient.java b/testing/trino-testing-containers/src/main/java/io/trino/testing/minio/MinioClient.java index 706a174e08c0..d12dfc1e9eab 100644 --- a/testing/trino-testing-containers/src/main/java/io/trino/testing/minio/MinioClient.java +++ b/testing/trino-testing-containers/src/main/java/io/trino/testing/minio/MinioClient.java @@ -170,6 +170,11 @@ public List listObjects(String bucket, String path) } public void makeBucket(String bucketName) + { + makeBucket(bucketName, false); + } + + public void makeBucket(String bucketName, boolean objectLock) { if (!createdBuckets.add(bucketName)) { // Forbid to create a bucket with given name more than once per class loader. 
@@ -182,6 +187,7 @@ public void makeBucket(String bucketName) client.makeBucket( MakeBucketArgs.builder() .bucket(bucketName) + .objectLock(objectLock) .build()); } catch (Exception e) { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java index 08b92e6e46df..2e6e809e91fd 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestEngineOnlyQueries.java @@ -5374,7 +5374,9 @@ public void testShowSession() getQueryRunner().getSessionPropertyManager(), getSession().getPreparedStatements(), getSession().getProtocolHeaders(), - getSession().getExchangeEncryptionKey()); + getSession().getExchangeEncryptionKey(), + getSession().getQueryDataEncodingId(), + getSession().getQueryDataEncryptionKey()); MaterializedResult result = computeActual(session, "SHOW SESSION"); ImmutableMap properties = Maps.uniqueIndex(result.getMaterializedRows(), input -> { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java index 06def8f45607..dc365700b279 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java @@ -630,6 +630,14 @@ public void loadExchangeManager(String name, Map properties) } } + @Override + public void loadSpoolingManager(String name, Map properties) + { + for (TestingTrinoServer server : servers) { + server.loadSpoolingManager(name, properties); + } + } + @Override public final void close() { diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index 9852685c7e45..8a675170d8cb 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -19,6 +19,12 @@ provided + + com.fasterxml.jackson.core + jackson-databind + runtime + + com.google.errorprone error_prone_annotations @@ -127,6 +133,12 @@ test + + io.airlift + tracing + test + + io.opentelemetry opentelemetry-sdk-trace @@ -209,12 +221,31 @@ test + + io.trino + trino-spi + test-jar + test + + + + io.trino + trino-spooling-filesystem + test + + io.trino trino-testing test + + io.trino + trino-testing-containers + test + + io.trino trino-testing-services diff --git a/testing/trino-tests/src/test/java/io/trino/server/protocol/AbstractSpooledQueryDataDistributedQueries.java b/testing/trino-tests/src/test/java/io/trino/server/protocol/AbstractSpooledQueryDataDistributedQueries.java new file mode 100644 index 000000000000..f58bd40fbab3 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/server/protocol/AbstractSpooledQueryDataDistributedQueries.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.server.protocol; + +import io.trino.Session; +import io.trino.client.ClientSession; +import io.trino.client.StatementClient; +import io.trino.client.spooling.encoding.QueryDataDecoders; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorPlugin; +import io.trino.plugin.memory.MemoryQueryRunner; +import io.trino.server.testing.TestingTrinoServer; +import io.trino.spooling.filesystem.FilesystemSpoolingPlugin; +import io.trino.testing.AbstractTestEngineOnlyQueries; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingStatementClientFactory; +import io.trino.testing.TestingTrinoClient; +import io.trino.testing.containers.Minio; +import io.trino.tpch.TpchTable; +import okhttp3.OkHttpClient; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; + +import static io.airlift.testing.Closeables.closeAllSuppress; +import static io.trino.client.StatementClientFactory.newStatementClient; +import static io.trino.testing.containers.Minio.MINIO_ACCESS_KEY; +import static io.trino.testing.containers.Minio.MINIO_REGION; +import static io.trino.testing.containers.Minio.MINIO_SECRET_KEY; +import static io.trino.util.Ciphers.createRandomAesEncryptionKey; +import static java.util.Base64.getEncoder; + +public abstract class AbstractSpooledQueryDataDistributedQueries + extends AbstractTestEngineOnlyQueries +{ + private Minio minio; + + protected abstract String encodingId(); + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + minio = closeAfterClass(Minio.builder().build()); + minio.start(); + + String bucketName = "segments" + UUID.randomUUID(); + minio.createBucket(bucketName, true); + + DistributedQueryRunner queryRunner = MemoryQueryRunner.builder() + .setInitialTables(TpchTable.getTables()) + .setTestingTrinoClientFactory((trinoServer, session) -> createClient(trinoServer, session, encodingId())) + .addExtraProperty("protocol.spooling.enabled", "true") + .addExtraProperty("protocol.spooling.encryption-key", randomAES256Key()) + .setAdditionalSetup(runner -> { + runner.installPlugin(new FilesystemSpoolingPlugin()); + runner.loadSpoolingManager("filesystem", Map.of( + "fs.native-s3.enabled", "true", + "location", "s3://" + bucketName, + "s3.endpoint", minio.getMinioAddress(), + "s3.region", MINIO_REGION, + "s3.aws-access-key", MINIO_ACCESS_KEY, + "s3.aws-secret-key", MINIO_SECRET_KEY, + "s3.path-style-access", "true")); + }) + .build(); + queryRunner.getCoordinator().getSessionPropertyManager().addSystemSessionProperties(TEST_SYSTEM_PROPERTIES); + try { + queryRunner.installPlugin(new MockConnectorPlugin(MockConnectorFactory.builder() + .withSessionProperties(TEST_CATALOG_PROPERTIES) + .build())); + queryRunner.createCatalog(TESTING_CATALOG, "mock"); + } + catch (RuntimeException e) { + throw closeAllSuppress(e, queryRunner); + } + return queryRunner; + } + + private static TestingTrinoClient createClient(TestingTrinoServer testingTrinoServer, Session session, String encodingId) + { + return new TestingTrinoClient(testingTrinoServer, new TestingStatementClientFactory() { + @Override + public StatementClient create(OkHttpClient httpClient, Session session, ClientSession clientSession, String query) + { + return newStatementClient(httpClient, QueryDataDecoders.get(encodingId), clientSession, query, Optional.empty()); + } + }, session, new OkHttpClient()); + } + + private static String randomAES256Key() + { + return 
getEncoder().encodeToString(createRandomAesEncryptionKey().getEncoded()); + } +} diff --git a/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonLz4SpooledDistributedQueries.java b/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonLz4SpooledDistributedQueries.java new file mode 100644 index 000000000000..a1604468b411 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonLz4SpooledDistributedQueries.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol; + +public class TestJsonLz4SpooledDistributedQueries + extends AbstractSpooledQueryDataDistributedQueries +{ + @Override + protected String encodingId() + { + return "json-ext+lz4"; + } +} diff --git a/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonSnappySpooledDistributedQueries.java b/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonSnappySpooledDistributedQueries.java new file mode 100644 index 000000000000..68b4596d027e --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonSnappySpooledDistributedQueries.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol; + +public class TestJsonSnappySpooledDistributedQueries + extends AbstractSpooledQueryDataDistributedQueries +{ + @Override + protected String encodingId() + { + return "json-ext+snappy"; + } +} diff --git a/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonSpooledDistributedQueries.java b/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonSpooledDistributedQueries.java new file mode 100644 index 000000000000..161a1cb0d3f5 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonSpooledDistributedQueries.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.server.protocol; + +public class TestJsonSpooledDistributedQueries + extends AbstractSpooledQueryDataDistributedQueries +{ + @Override + protected String encodingId() + { + return "json-ext"; + } +} diff --git a/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonZstdSpooledDistributedQueries.java b/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonZstdSpooledDistributedQueries.java new file mode 100644 index 000000000000..c3540ba00ed5 --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/server/protocol/TestJsonZstdSpooledDistributedQueries.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server.protocol; + +public class TestJsonZstdSpooledDistributedQueries + extends AbstractSpooledQueryDataDistributedQueries +{ + @Override + protected String encodingId() + { + return "json-ext+zstd"; + } +} diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java index 644b26def621..53254e8b4ddc 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java @@ -13,11 +13,16 @@ */ package io.trino.tests; +import com.fasterxml.jackson.databind.Module; import com.google.common.base.Splitter; import com.google.common.collect.AbstractSequentialIterator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Streams; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.TypeLiteral; +import io.airlift.bootstrap.Bootstrap; import io.airlift.http.client.FullJsonResponseHandler.JsonResponse; import io.airlift.http.client.HttpClient; import io.airlift.http.client.HttpUriBuilder; @@ -25,6 +30,10 @@ import io.airlift.http.client.StatusResponseHandler.StatusResponse; import io.airlift.http.client.jetty.JettyHttpClient; import io.airlift.json.JsonCodec; +import io.airlift.json.JsonModule; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Tracer; +import io.trino.client.QueryDataClientJacksonModule; import io.trino.client.QueryError; import io.trino.client.QueryResults; import io.trino.plugin.memory.MemoryPlugin; @@ -59,14 +68,16 @@ import static com.google.common.net.HttpHeaders.X_FORWARDED_HOST; import static com.google.common.net.HttpHeaders.X_FORWARDED_PORT; import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO; +import static com.google.inject.multibindings.Multibinder.newSetBinder; import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; import static io.airlift.http.client.Request.Builder.prepareGet; import static io.airlift.http.client.Request.Builder.prepareHead; import static io.airlift.http.client.Request.Builder.preparePost; import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; import static 
io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; -import static io.airlift.json.JsonCodec.jsonCodec; +import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; import static io.airlift.testing.Closeables.closeAll; +import static io.airlift.tracing.Tracing.noopTracer; import static io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static io.trino.SystemSessionProperties.MAX_HASH_PARTITION_COUNT; import static io.trino.SystemSessionProperties.QUERY_MAX_MEMORY; @@ -92,7 +103,7 @@ @Execution(CONCURRENT) public class TestServer { - private static final JsonCodec QUERY_RESULTS_CODEC = jsonCodec(QueryResults.class); + private static final JsonCodec QUERY_RESULTS_CODEC = queryResultsCodec(); private TestingTrinoServer server; private HttpClient client; @@ -169,7 +180,7 @@ public void testFirstResponseColumns() assertThat(data).isPresent(); QueryResults results = data.orElseThrow(); - assertThat(results.getData()).containsOnly(ImmutableList.of("memory"), ImmutableList.of("system")); + assertThat(results.getData().getData()).containsOnly(ImmutableList.of("memory"), ImmutableList.of("system")); } @Test @@ -199,7 +210,7 @@ public void testQuery() .peek(result -> assertThat(result.getError()).isNull()) .peek(results -> { if (results.getData() != null) { - data.addAll(results.getData()); + data.addAll(results.getData().getData()); } }) .collect(last()); @@ -439,4 +450,19 @@ protected JsonResponse computeNext(JsonResponse prev return client.execute(prepareGet().setUri(previous.getValue().getNextUri()).build(), createFullJsonResponseHandler(QUERY_RESULTS_CODEC)); } } + + public static JsonCodec queryResultsCodec() + { + Injector injector = new Bootstrap( + new JsonModule(), + binder -> { + jsonCodecBinder(binder).bindJsonCodec(QueryResults.class); + newSetBinder(binder, Module.class).addBinding().to(QueryDataClientJacksonModule.class); + binder.bind(OpenTelemetry.class).toInstance(OpenTelemetry.noop()); + binder.bind(Tracer.class).toInstance(noopTracer()); + } + ).initialize(); + + return injector.getInstance(Key.get(new TypeLiteral<>(){})); + } }
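For reference, draining a query through this codec follows nextUri and concatenates the per-page data, exactly as testQuery() does above. The sketch below is illustrative only and leans on the statics already imported in TestServer; with spooling enabled the data field carries encoded segment references rather than inline rows, which clients materialize through a QueryDataDecoder.

    private static List<List<Object>> drain(HttpClient client, QueryResults first)
    {
        ImmutableList.Builder<List<Object>> rows = ImmutableList.builder();
        QueryResults current = first;
        while (true) {
            if (current.getData() != null) {
                rows.addAll(current.getData().getData());
            }
            if (current.getNextUri() == null) {
                return rows.build();
            }
            current = client.execute(
                    prepareGet().setUri(current.getNextUri()).build(),
                    createFullJsonResponseHandler(QUERY_RESULTS_CODEC)).getValue();
        }
    }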