Add MariaDB statistics support #19408

Merged
merged 2 commits on Oct 16, 2023
16 changes: 10 additions & 6 deletions plugin/trino-mariadb/pom.xml
@@ -33,6 +33,11 @@
<artifactId>configuration</artifactId>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>log</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-base-jdbc</artifactId>
@@ -48,6 +53,11 @@
<artifactId>jakarta.validation-api</artifactId>
</dependency>

<dependency>
<groupId>org.jdbi</groupId>
<artifactId>jdbi3-core</artifactId>
</dependency>

<dependency>
<groupId>org.mariadb.jdbc</groupId>
<artifactId>mariadb-java-client</artifactId>
@@ -89,12 +99,6 @@
<scope>provided</scope>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>log</artifactId>
<scope>runtime</scope>
</dependency>

<dependency>
<groupId>io.airlift</groupId>
<artifactId>log-manager</artifactId>
plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClient.java
@@ -13,8 +13,10 @@
*/
package io.trino.plugin.mariadb;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.trino.plugin.base.aggregation.AggregateFunctionRewriter;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
@@ -27,6 +29,7 @@
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.plugin.jdbc.JdbcSortItem;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.LongWriteFunction;
@@ -56,13 +59,18 @@
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.Estimate;
import io.trino.spi.statistics.TableStatistics;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.TimeType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.Jdbi;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
@@ -75,13 +83,18 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Stream;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.ALLOW_OVERFLOW;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalDefaultScale;
import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRounding;
@@ -137,12 +150,15 @@
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.lang.String.join;
import static java.util.Map.entry;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;

public class MariaDbClient
extends BaseJdbcClient
{
private static final Logger log = Logger.get(MariaDbClient.class);

private static final int MAX_SUPPORTED_DATE_TIME_PRECISION = 6;
// MariaDB driver returns width of time types instead of precision.
private static final int ZERO_PRECISION_TIME_COLUMN_SIZE = 10;
@@ -156,17 +172,25 @@ public class MariaDbClient
// MariaDB Error Codes https://mariadb.com/kb/en/mariadb-error-codes/
private static final int PARSE_ERROR = 1064;

private final boolean statisticsEnabled;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;

@Inject
public MariaDbClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier)
public MariaDbClient(
BaseJdbcConfig config,
JdbcStatisticsConfig statisticsConfig,
ConnectionFactory connectionFactory,
QueryBuilder queryBuilder,
IdentifierMapping identifierMapping,
RemoteQueryModifier queryModifier)
{
super("`", connectionFactory, queryBuilder, config.getJdbcTypesMappedToVarchar(), identifierMapping, queryModifier, false);

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
.build();
this.statisticsEnabled = statisticsConfig.isEnabled();
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
connectorExpressionRewriter,
ImmutableSet.<AggregateFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
@@ -623,6 +647,102 @@ protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCon
.noneMatch(type -> type instanceof CharType || type instanceof VarcharType);
}

@Override
public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle)
{
if (!statisticsEnabled) {
return TableStatistics.empty();
}
if (!handle.isNamedRelation()) {
return TableStatistics.empty();
}
try {
return readTableStatistics(session, handle);
}
catch (SQLException | RuntimeException e) {
throwIfInstanceOf(e, TrinoException.class);
throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e);
}
}

private TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table)
throws SQLException
{
checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table);

log.debug("Reading statistics for %s", table);
try (Connection connection = connectionFactory.openConnection(session);
Handle handle = Jdbi.open(connection)) {
StatisticsDao statisticsDao = new StatisticsDao(handle);

Long rowCount = statisticsDao.getTableRowCount(table);
Long indexMaxCardinality = statisticsDao.getTableMaxColumnIndexCardinality(table);
log.debug("Estimated row count of table %s is %s, and max index cardinality is %s", table, rowCount, indexMaxCardinality);

if (rowCount != null && rowCount == 0) {
// MariaDB may report 0 row count until a table is analyzed for the first time.
rowCount = null;
}

if (rowCount == null && indexMaxCardinality == null) {
// Table not found, or is a view, or has no usable statistics
return TableStatistics.empty();
}
rowCount = max(firstNonNull(rowCount, 0L), firstNonNull(indexMaxCardinality, 0L));

TableStatistics.Builder tableStatistics = TableStatistics.builder();
tableStatistics.setRowCount(Estimate.of(rowCount));

// TODO statistics from ANALYZE TABLE (https://mariadb.com/kb/en/engine-independent-table-statistics/)
// Map<String, AnalyzeColumnStatistics> columnStatistics = statisticsDao.getColumnStatistics(table);
Map<String, AnalyzeColumnStatistics> columnStatistics = ImmutableMap.of();

// TODO add support for histograms https://mariadb.com/kb/en/histogram-based-statistics/

// statistics based on existing indexes
Map<String, ColumnIndexStatistics> columnStatisticsFromIndexes = statisticsDao.getColumnIndexStatistics(table);

if (columnStatistics.isEmpty() && columnStatisticsFromIndexes.isEmpty()) {
log.debug("No column and index statistics read");
// No more information to work on
return tableStatistics.build();
}

for (JdbcColumnHandle column : getColumns(session, table)) {
ColumnStatistics.Builder columnStatisticsBuilder = ColumnStatistics.builder();

String columnName = column.getColumnName();
AnalyzeColumnStatistics analyzeColumnStatistics = columnStatistics.get(columnName);
if (analyzeColumnStatistics != null) {
log.debug("Reading column statistics for %s, %s from analayze's column statistics: %s", table, columnName, analyzeColumnStatistics);
columnStatisticsBuilder.setNullsFraction(Estimate.of(analyzeColumnStatistics.nullsRatio()));
}

ColumnIndexStatistics columnIndexStatistics = columnStatisticsFromIndexes.get(columnName);
if (columnIndexStatistics != null) {
log.debug("Reading column statistics for %s, %s from index statistics: %s", table, columnName, columnIndexStatistics);
columnStatisticsBuilder.setDistinctValuesCount(Estimate.of(columnIndexStatistics.cardinality()));

if (!columnIndexStatistics.nullable()) {
double knownNullFraction = columnStatisticsBuilder.build().getNullsFraction().getValue();
if (knownNullFraction > 0) {
log.warn("Inconsistent statistics, null fraction for a column %s, %s, that is not nullable according to index statistics: %s", table, columnName, knownNullFraction);
Member: Did you mean "column %s.%s"? Otherwise it would log "foo, bar", where foo is the table and bar is the column.

Member (author): I admit I copied this from MySQL's io.trino.plugin.mysql.MySqlClient#updateColumnStatisticsFromIndexStatistics. Yes, we can improve both places.

Member: I'll take care of this.

}
columnStatisticsBuilder.setNullsFraction(Estimate.zero());
}

// row count from INFORMATION_SCHEMA.TABLES may be very inaccurate
rowCount = max(rowCount, columnIndexStatistics.cardinality());
Comment on lines +734 to +735

Member: The code seems to assume the inaccuracy is always towards lower values. Is that a fair assumption to make, i.e. is rowCount always lower than the column index cardinality?

Member (author): I don't think it is; the estimate may well be under-estimated and also over-estimated. However, row count and NDV are related, so we need to do something either way. I copied the current approach from the MySQL connector stats code, which is probably where I implemented it earlier.

Member: Yeah, looking at a few other places, the logic is that overestimating stats is better than underestimating, because underestimation can lead to objectively bad plans. For example, broadcasting a table that is too large to broadcast is probably worse than using a hash join on a table that could have been broadcast. Marking as resolved.

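As an illustrative aside on the two estimates being reconciled above (a hedged sketch, run directly against MariaDB rather than through Trino; the schema and table names are assumptions), TABLE_ROWS can differ substantially from the cardinality of an indexed column, which is why the larger of the two values is kept:

    -- Illustrative only: 'tiny' and 'orders' are assumed names.
    SELECT t.TABLE_ROWS, MAX(s.CARDINALITY) AS max_index_cardinality
    FROM INFORMATION_SCHEMA.TABLES t
    LEFT JOIN INFORMATION_SCHEMA.STATISTICS s
      ON s.TABLE_SCHEMA = t.TABLE_SCHEMA AND s.TABLE_NAME = t.TABLE_NAME
    WHERE t.TABLE_SCHEMA = 'tiny' AND t.TABLE_NAME = 'orders'
    GROUP BY t.TABLE_ROWS;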
}

tableStatistics.setColumnStatistics(column, columnStatisticsBuilder.build());
}

tableStatistics.setRowCount(Estimate.of(rowCount));
return tableStatistics.build();
}
}

private static LongWriteFunction dateWriteFunction()
{
return (statement, index, day) -> statement.setString(index, DATE_FORMATTER.format(LocalDate.ofEpochDay(day)));
@@ -650,4 +770,101 @@ private static Optional<ColumnMapping> getUnsignedMapping(JdbcTypeHandle typeHan

return Optional.empty();
}

private static class StatisticsDao
{
private final Handle handle;

public StatisticsDao(Handle handle)
{
this.handle = requireNonNull(handle, "handle is null");
}

Long getTableRowCount(JdbcTableHandle table)
{
RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
return handle.createQuery("""
SELECT TABLE_ROWS FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name
AND TABLE_TYPE = 'BASE TABLE'
""")
.bind("schema", remoteTableName.getCatalogName().orElse(null))
.bind("table_name", remoteTableName.getTableName())
.mapTo(Long.class)
.findOne()
.orElse(null);
}

Long getTableMaxColumnIndexCardinality(JdbcTableHandle table)
{
RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
return handle.createQuery("""
SELECT max(CARDINALITY) AS row_count FROM INFORMATION_SCHEMA.STATISTICS
WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name
""")
.bind("schema", remoteTableName.getCatalogName().orElse(null))
.bind("table_name", remoteTableName.getTableName())
.mapTo(Long.class)
.findOne()
.orElse(null);
}

Map<String, AnalyzeColumnStatistics> getColumnStatistics(JdbcTableHandle table)
{
RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
return handle.createQuery("""
SELECT
column_name,
-- TODO min_value, max_value,
nulls_ratio
FROM mysql.column_stats
WHERE db_name = :database AND TABLE_NAME = :table_name
AND nulls_ratio IS NOT NULL
""")
.bind("database", remoteTableName.getCatalogName().orElse(null))
.bind("table_name", remoteTableName.getTableName())
.map((rs, ctx) -> {
String columnName = rs.getString("column_name");
double nullsRatio = rs.getDouble("nulls_ratio");
return entry(columnName, new AnalyzeColumnStatistics(nullsRatio));
})
.collect(toImmutableMap(Entry::getKey, Entry::getValue));
}

Map<String, ColumnIndexStatistics> getColumnIndexStatistics(JdbcTableHandle table)
{
RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName();
return handle.createQuery("""
SELECT
COLUMN_NAME,
MAX(NULLABLE) AS NULLABLE,
MAX(CARDINALITY) AS CARDINALITY
FROM INFORMATION_SCHEMA.STATISTICS
WHERE TABLE_SCHEMA = :schema AND TABLE_NAME = :table_name
AND SEQ_IN_INDEX = 1 -- first column in the index
AND SUB_PART IS NULL -- ignore cases where only a column prefix is indexed
AND CARDINALITY IS NOT NULL -- CARDINALITY might be null (https://stackoverflow.com/a/42242729/65458)
AND CARDINALITY != 0 -- CARDINALITY is initially 0 until analyzed
GROUP BY COLUMN_NAME -- there might be multiple indexes on a column
Comment on lines +845 to +848

Member: Do we need to port this to MySQL as well?

Member (author): This was copied from MySQL and slightly adapted. I think the AND CARDINALITY != 0 line was added here, based on observation with MariaDB. I suspect MySQL benefits from the AND CARDINALITY IS NOT NULL line, and maybe it's redundant for MariaDB; this is very hard to test, so I retained the old condition and just added the new one. Someone would need to observe what MySQL is doing to determine whether AND CARDINALITY != 0 would make sense there.

Member: I'll take care of this.

""")
.bind("schema", remoteTableName.getCatalogName().orElse(null))
.bind("table_name", remoteTableName.getTableName())
.map((rs, ctx) -> {
String columnName = rs.getString("COLUMN_NAME");

boolean nullable = rs.getString("NULLABLE").equalsIgnoreCase("YES");
checkState(!rs.wasNull(), "NULLABLE is null");

long cardinality = rs.getLong("CARDINALITY");
checkState(!rs.wasNull(), "CARDINALITY is null");

return entry(columnName, new ColumnIndexStatistics(nullable, cardinality));
})
.collect(toImmutableMap(Entry::getKey, Entry::getValue));
}
}

private record AnalyzeColumnStatistics(double nullsRatio) {}

private record ColumnIndexStatistics(boolean nullable, long cardinality) {}
}
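The statistics read by StatisticsDao are only as fresh as what MariaDB has collected for the table. As a hedged sketch (the table name is an assumption, and the statements run on the MariaDB server, not through Trino), the index cardinalities exposed through INFORMATION_SCHEMA.STATISTICS and the engine-independent statistics in mysql.column_stats mentioned in the TODO can be populated like this:

    -- Illustrative only: 'orders' is an assumed table name.
    -- Refreshes index cardinalities reported in INFORMATION_SCHEMA.STATISTICS:
    ANALYZE TABLE orders;
    -- Collects engine-independent statistics into mysql.column_stats (see the TODO above):
    ANALYZE TABLE orders PERSISTENT FOR ALL;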
plugin/trino-mariadb/src/main/java/io/trino/plugin/mariadb/MariaDbClientModule.java
@@ -25,6 +25,7 @@
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.credential.CredentialProvider;
import io.trino.plugin.jdbc.ptf.Query;
import io.trino.spi.function.table.ConnectorTableFunction;
@@ -43,6 +44,7 @@ public void configure(Binder binder)
{
binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(MariaDbClient.class).in(Scopes.SINGLETON);
configBinder(binder).bindConfig(MariaDbJdbcConfig.class);
configBinder(binder).bindConfig(JdbcStatisticsConfig.class);
binder.install(new DecimalModule());
newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON);
}
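With JdbcStatisticsConfig bound here, statistics support can be toggled per catalog through base-jdbc's statistics.enabled property, and the resulting estimates can be inspected from Trino. A minimal usage sketch, assuming a catalog named mariadb and an already analyzed tiny.orders table:

    -- Illustrative only: catalog, schema, and table names are assumptions.
    -- Assumes statistics.enabled=true in the mariadb catalog properties file.
    SHOW STATS FOR mariadb.tiny.orders;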