Skip to content

Commit

Permalink
Add Redshift statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Dec 12, 2022
1 parent 1e8887e commit e65585d
Show file tree
Hide file tree
Showing 6 changed files with 642 additions and 7 deletions.
12 changes: 6 additions & 6 deletions plugin/trino-redshift/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
<artifactId>javax.inject</artifactId>
</dependency>

<dependency>
<groupId>org.jdbi</groupId>
<artifactId>jdbi3-core</artifactId>
</dependency>

<!-- used by tests but also needed transitively -->
<dependency>
<groupId>io.airlift</groupId>
Expand All @@ -68,12 +73,6 @@
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>org.jdbi</groupId>
<artifactId>jdbi3-core</artifactId>
<scope>runtime</scope>
</dependency>

<!-- Trino SPI -->
<dependency>
<groupId>io.trino</groupId>
Expand Down Expand Up @@ -177,6 +176,7 @@
<configuration>
<excludes>
<exclude>**/TestRedshiftConnectorTest.java</exclude>
<exclude>**/TestRedshiftTableStatisticsReader.java</exclude>
<exclude>**/TestRedshiftTypeMapping.java</exclude>
</excludes>
</configuration>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.plugin.jdbc.ColumnMapping;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.LongWriteFunction;
Expand All @@ -35,7 +36,10 @@
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.plugin.jdbc.mapping.IdentifierMapping;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.CharType;
import io.trino.spi.type.Chars;
import io.trino.spi.type.DecimalType;
Expand Down Expand Up @@ -73,6 +77,7 @@
import java.util.function.BiFunction;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.base.Verify.verify;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR;
Expand Down Expand Up @@ -194,10 +199,21 @@ public class RedshiftClient
.toFormatter();
private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC);

private final boolean statisticsEnabled;
private final RedshiftTableStatisticsReader statisticsReader;

/**
 * Creates the Redshift JDBC client.
 *
 * Besides delegating to the base client constructor, this captures whether statistics are
 * enabled (from {@code statisticsConfig}) and creates the {@link RedshiftTableStatisticsReader}
 * that will query statistics through the same {@code connectionFactory}.
 */
@Inject
public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier)
public RedshiftClient(
BaseJdbcConfig config,
ConnectionFactory connectionFactory,
JdbcStatisticsConfig statisticsConfig,
QueryBuilder queryBuilder,
IdentifierMapping identifierMapping,
RemoteQueryModifier queryModifier)
{
super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier);
// Only the enabled flag is retained from the statistics configuration
this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled();
// The reader opens its own connections from the shared factory each time statistics are requested
this.statisticsReader = new RedshiftTableStatisticsReader(connectionFactory);
}

@Override
Expand All @@ -207,6 +223,24 @@ public Optional<String> getTableComment(ResultSet resultSet)
return Optional.empty();
}

/**
 * Returns statistics for the given table handle.
 *
 * Empty statistics are returned when statistics are disabled or when the handle is not a
 * named relation. Otherwise the work is delegated to the statistics reader; any failure that
 * is not already a {@link TrinoException} is wrapped in one with the {@code JDBC_ERROR} code.
 */
@Override
public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain<ColumnHandle> tupleDomain)
{
    // Short-circuit: the handle is only inspected when statistics are enabled
    if (!statisticsEnabled || !handle.isNamedRelation()) {
        return TableStatistics.empty();
    }

    try {
        // Columns are resolved lazily; the reader asks for them only if it needs them
        return statisticsReader.readTableStatistics(session, handle, () -> getColumns(session, handle));
    }
    catch (SQLException | RuntimeException failure) {
        // Let Trino's own exceptions through untouched, wrap everything else
        if (failure instanceof TrinoException) {
            throw (TrinoException) failure;
        }
        throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, failure);
    }
}

@Override
protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName)
throws SQLException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.credential.CredentialProvider;
import io.trino.plugin.jdbc.ptf.Query;
import io.trino.spi.ptf.ConnectorTableFunction;
Expand All @@ -32,6 +33,7 @@

import static com.google.inject.Scopes.SINGLETON;
import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static io.airlift.configuration.ConfigBinder.configBinder;

public class RedshiftClientModule
extends AbstractConfigurationAwareModule
Expand All @@ -41,6 +43,7 @@ public void setup(Binder binder)
{
binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(SINGLETON);
newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(SINGLETON);
configBinder(binder).bindConfig(JdbcStatisticsConfig.class);

install(new DecimalModule());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.redshift;

import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.RemoteTableName;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.Estimate;
import io.trino.spi.statistics.TableStatistics;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.Jdbi;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

/**
 * Reads table and column statistics for a Redshift table by querying the
 * {@code pg_class}, {@code pg_stat_all_tables} and {@code pg_stats} catalog relations
 * over a connection obtained from the supplied {@link ConnectionFactory}.
 */
public class RedshiftTableStatisticsReader
{
    private final ConnectionFactory connectionFactory;

    public RedshiftTableStatisticsReader(ConnectionFactory connectionFactory)
    {
        this.connectionFactory = requireNonNull(connectionFactory, "connectionFactory is null");
    }

    /**
     * Reads the row count and per-column statistics of a named table.
     *
     * @param session session used to open the connection
     * @param table handle of the table; must be a named relation
     * @param columnSupplier supplies the table's columns; invoked only when the row count is positive
     * @return the collected statistics, or {@link TableStatistics#empty()} when no usable row count is available
     * @throws SQLException if opening or closing the connection fails
     * @throws IllegalArgumentException if {@code table} is not a named relation
     */
    public TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table, Supplier<List<JdbcColumnHandle>> columnSupplier)
            throws SQLException
    {
        checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table);

        try (Connection connection = connectionFactory.openConnection(session);
                Handle handle = Jdbi.open(connection)) {
            StatisticsDao statisticsDao = new StatisticsDao(handle);

            RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
            // A missing schema is bound as SQL NULL, which matches no catalog row
            String schemaName = remoteTableName.getSchemaName().orElse(null);
            String tableName = remoteTableName.getTableName();

            Optional<Long> optionalRowCount = readRowCountTableStat(statisticsDao, schemaName, tableName);
            if (optionalRowCount.isEmpty()) {
                // Table not found, or the catalogs did not yield a usable row count
                return TableStatistics.empty();
            }
            long rowCount = optionalRowCount.get();

            TableStatistics.Builder tableStatistics = TableStatistics.builder()
                    .setRowCount(Estimate.of(rowCount));

            if (rowCount == 0) {
                // Column statistics are not read for an empty table
                return tableStatistics.build();
            }

            Map<String, ColumnStatisticsResult> columnStatistics = statisticsDao.getColumnStatistics(schemaName, tableName).stream()
                    .collect(toImmutableMap(ColumnStatisticsResult::columnName, identity()));

            for (JdbcColumnHandle column : columnSupplier.get()) {
                ColumnStatisticsResult result = columnStatistics.get(column.getColumnName());
                if (result == null) {
                    // No pg_stats row for this column; leave its statistics unset
                    continue;
                }
                tableStatistics.setColumnStatistics(column, toColumnStatistics(result, rowCount));
            }

            return tableStatistics.build();
        }
    }

    /**
     * Converts one {@code pg_stats} row into Trino column statistics, scaling by the table row count.
     */
    private static ColumnStatistics toColumnStatistics(ColumnStatisticsResult result, long rowCount)
    {
        return ColumnStatistics.builder()
                .setNullsFraction(result.nullsFraction()
                        .map(Estimate::of)
                        .orElseGet(Estimate::unknown))
                .setDistinctValuesCount(result.distinctValuesIndicator()
                        .map(distinctValuesIndicator -> {
                            // If the distinct value count is an estimate, Redshift uses "the negative of the number of distinct values divided by the number of rows.
                            // For example, -1 indicates a unique column in which the number of distinct values is the same as the number of rows."
                            // https://www.postgresql.org/docs/9.3/view-pg-stats.html
                            if (distinctValuesIndicator < 0.0) {
                                // Cap at the row count: a column cannot have more distinct values than rows
                                return Math.min(-distinctValuesIndicator * rowCount, rowCount);
                            }
                            return distinctValuesIndicator;
                        })
                        .map(Estimate::of)
                        .orElseGet(Estimate::unknown))
                // Data size needs both the average width and the null fraction; unknown if either is missing
                .setDataSize(result.averageColumnLength()
                        .flatMap(averageColumnLength ->
                                result.nullsFraction()
                                        .map(nullsFraction -> 1.0 * averageColumnLength * rowCount * (1 - nullsFraction))
                                        .map(Estimate::of))
                        .orElseGet(Estimate::unknown))
                .build();
    }

    /**
     * Determines the table row count, preferring {@code pg_class.reltuples} and falling back
     * to {@code pg_stat_all_tables} when {@code reltuples} is zero.
     *
     * @return the row count, or empty when the table is not found or the value is negative
     */
    private static Optional<Long> readRowCountTableStat(StatisticsDao statisticsDao, String schemaName, String tableName)
    {
        Optional<Long> rowCount = statisticsDao.getRowCountFromPgClass(schemaName, tableName);
        if (rowCount.isEmpty()) {
            // Table not found
            return Optional.empty();
        }

        if (rowCount.get() == 0) {
            // `pg_class.reltuples = 0` may mean an empty table or a recently populated table (CTAS, LOAD or INSERT)
            // The `pg_stat_all_tables` view can be way off, so we use it only as a fallback
            rowCount = statisticsDao.getRowCountFromPgStat(schemaName, tableName);
        }

        // `inserts - deletes` is only an approximation and may come out negative.
        // A negative value is not a meaningful row count, so report it as unknown rather than
        // passing it on to the statistics builder (presumably rejected there - TODO confirm).
        return rowCount.filter(count -> count >= 0);
    }

    /**
     * Thin wrapper around the catalog queries, bound to a single Jdbi handle.
     */
    private static class StatisticsDao
    {
        private final Handle handle;

        public StatisticsDao(Handle handle)
        {
            this.handle = requireNonNull(handle, "handle is null");
        }

        Optional<Long> getRowCountFromPgClass(String schema, String tableName)
        {
            return handle.createQuery("SELECT reltuples FROM pg_class WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) AND relname = :table_name")
                    .bind("schema", schema)
                    .bind("table_name", tableName)
                    .mapTo(Long.class)
                    .findOne();
        }

        Optional<Long> getRowCountFromPgStat(String schema, String tableName)
        {
            // Redshift does not have the Postgres `n_live_tup`, so estimate from `inserts - deletes`
            return handle.createQuery("SELECT n_tup_ins - n_tup_del FROM pg_stat_all_tables WHERE schemaname = :schema AND relname = :table_name")
                    .bind("schema", schema)
                    .bind("table_name", tableName)
                    .mapTo(Long.class)
                    .findOne();
        }

        List<ColumnStatisticsResult> getColumnStatistics(String schema, String tableName)
        {
            return handle.createQuery("SELECT attname, null_frac, n_distinct, avg_width FROM pg_stats WHERE schemaname = :schema AND tablename = :table_name")
                    .bind("schema", schema)
                    .bind("table_name", tableName)
                    .map((rs, ctx) -> {
                        String columnName = requireNonNull(rs.getString("attname"), "attname is null");

                        // JDBC primitive getters return 0 for SQL NULL, so wasNull() must be
                        // checked immediately after each read to tell "0" apart from "absent"
                        float nullFraction = rs.getFloat("null_frac");
                        Optional<Float> optionalNullFraction = rs.wasNull() ? Optional.empty() : Optional.of(nullFraction);

                        float distinctIndicator = rs.getFloat("n_distinct");
                        Optional<Float> optionalDistinctIndicator = rs.wasNull() ? Optional.empty() : Optional.of(distinctIndicator);

                        int averageWidth = rs.getInt("avg_width");
                        Optional<Integer> optionalAverageWidth = rs.wasNull() ? Optional.empty() : Optional.of(averageWidth);

                        return new ColumnStatisticsResult(columnName, optionalNullFraction, optionalDistinctIndicator, optionalAverageWidth);
                    })
                    .list();
        }
    }

    // TODO remove when error prone is updated for Java 17 records
    @SuppressWarnings("unused")
    private record ColumnStatisticsResult(String columnName, Optional<Float> nullsFraction, Optional<Float> distinctValuesIndicator, Optional<Integer> averageColumnLength) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@

import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA;
import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner;
import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift;
import static io.trino.testing.TestingNames.randomNameSuffix;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

Expand Down Expand Up @@ -177,6 +179,77 @@ public void testDelete()
}
}

/**
 * Verifies that SHOW STATS reports column statistics under lower-cased column names,
 * regardless of how the column and table names were cased or quoted at creation time.
 */
@Test(dataProvider = "testCaseColumnNamesDataProvider")
public void testCaseColumnNames(String tableName)
{
    // Use the same schema-qualified name for create, stats and cleanup
    String qualifiedTableName = TEST_SCHEMA + "." + tableName;
    try {
        assertUpdate(
                "CREATE TABLE " + qualifiedTableName +
                        " AS SELECT " +
                        " custkey AS CASE_UNQUOTED_UPPER, " +
                        " name AS case_unquoted_lower, " +
                        " address AS cASe_uNQuoTeD_miXED, " +
                        " nationkey AS \"CASE_QUOTED_UPPER\", " +
                        " phone AS \"case_quoted_lower\"," +
                        " acctbal AS \"CasE_QuoTeD_miXED\" " +
                        "FROM customer",
                1500);
        // Statistics are only visible to the connector after Redshift has analyzed the table
        gatherStats(tableName);
        assertQuery(
                "SHOW STATS FOR " + qualifiedTableName,
                "VALUES " +
                        "('case_unquoted_upper', NULL, 1485, 0, null, null, null)," +
                        "('case_unquoted_lower', 33000, 1470, 0, null, null, null)," +
                        "('case_unquoted_mixed', 42000, 1500, 0, null, null, null)," +
                        "('case_quoted_upper', NULL, 25, 0, null, null, null)," +
                        "('case_quoted_lower', 28500, 1483, 0, null, null, null)," +
                        "('case_quoted_mixed', NULL, 1483, 0, null, null, null)," +
                        "(null, null, null, null, 1500, null, null)");
    }
    finally {
        // Drop the table from the schema it was created in rather than relying on the session default schema
        assertUpdate("DROP TABLE IF EXISTS " + qualifiedTableName);
    }
}

/**
 * Runs ANALYZE on the given test table until {@code pg_class.reltuples} matches the actual
 * row count, checking up to five times.
 *
 * @throws IllegalStateException if the estimate still differs after the last check
 */
private static void gatherStats(String tableName)
{
    String analyzeSql = "ANALYZE VERBOSE " + TEST_SCHEMA + "." + tableName;
    String countSql = "SELECT count(*) FROM " + TEST_SCHEMA + "." + tableName;
    // pg_class is looked up by the lower-cased name with any double quotes removed
    String catalogTableName = tableName.toLowerCase(ENGLISH).replace("\"", "");

    executeInRedshift(handle -> {
        handle.execute(analyzeSql);
        int attempt = 0;
        while (attempt < 5) {
            long actualCount = handle.createQuery(countSql)
                    .mapTo(Long.class)
                    .one();
            long estimatedCount = handle.createQuery("""
                    SELECT reltuples FROM pg_class
                    WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema)
                    AND relname = :table_name
                    """)
                    .bind("schema", TEST_SCHEMA)
                    .bind("table_name", catalogTableName)
                    .mapTo(Long.class)
                    .one();
            if (estimatedCount == actualCount) {
                return;
            }
            // Estimate is still off; analyze again before the next check
            handle.execute(analyzeSql);
            attempt++;
        }
        throw new IllegalStateException("Stats not gathered"); // for small test tables reltuples should be exact
    });
}

/**
 * Supplies table names covering upper, lower and mixed case, each in an unquoted and a
 * double-quoted variant, with a fresh random suffix per name.
 */
@DataProvider
public Object[][] testCaseColumnNamesDataProvider()
{
    String unquotedUpper = "TEST_STATS_MIXED_UNQUOTED_UPPER_" + randomNameSuffix();
    String unquotedLower = "test_stats_mixed_unquoted_lower_" + randomNameSuffix();
    String unquotedMixed = "test_stats_mixed_uNQuoTeD_miXED_" + randomNameSuffix();
    String quotedUpper = "\"TEST_STATS_MIXED_QUOTED_UPPER_" + randomNameSuffix() + "\"";
    String quotedLower = "\"test_stats_mixed_quoted_lower_" + randomNameSuffix() + "\"";
    String quotedMixed = "\"test_stats_mixed_QuoTeD_miXED_" + randomNameSuffix() + "\"";

    return new Object[][] {
            {unquotedUpper},
            {unquotedLower},
            {unquotedMixed},
            {quotedUpper},
            {quotedLower},
            {quotedMixed},
    };
}

@Override
@Test
public void testReadMetadataWithRelationsConcurrentModifications()
Expand Down
Loading

0 comments on commit e65585d

Please sign in to comment.