From e9f8fa75b2038d7b94f68cb27cf98364eb85861b Mon Sep 17 00:00:00 2001
From: Andrii Rosa
Date: Tue, 22 Feb 2022 17:59:56 -0500
Subject: [PATCH] Allow standard tpch types and names in HiveQueryRunner#main

To allow running unmodified tpch queries when debugging
---
 .../io/trino/plugin/hive/HiveQueryRunner.java | 59 +++++++++++++++----
 1 file changed, 46 insertions(+), 13 deletions(-)

diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java
index 1784c909b507..65edc9264e2c 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/HiveQueryRunner.java
@@ -27,6 +27,8 @@
 import io.trino.plugin.hive.metastore.file.FileHiveMetastore;
 import io.trino.plugin.hive.metastore.file.FileHiveMetastoreConfig;
 import io.trino.plugin.tpcds.TpcdsPlugin;
+import io.trino.plugin.tpch.ColumnNaming;
+import io.trino.plugin.tpch.DecimalTypeMapping;
 import io.trino.plugin.tpch.TpchPlugin;
 import io.trino.spi.security.Identity;
 import io.trino.spi.security.PrincipalType;
@@ -52,6 +54,8 @@
 import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT;
 import static io.trino.plugin.hive.security.HiveSecurityModule.ALLOW_ALL;
 import static io.trino.plugin.hive.security.HiveSecurityModule.SQL_STANDARD;
+import static io.trino.plugin.tpch.ColumnNaming.SIMPLIFIED;
+import static io.trino.plugin.tpch.DecimalTypeMapping.DOUBLE;
 import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME;
 import static io.trino.spi.security.SelectedRole.Type.ROLE;
 import static io.trino.testing.QueryAssertions.copyTpchTables;
@@ -114,6 +118,8 @@ public static class Builder<SELF extends Builder<?>>
         private Optional<CachingDirectoryLister> cachingDirectoryLister = Optional.empty();
         private boolean tpcdsCatalogEnabled;
         private String security = SQL_STANDARD;
+        private ColumnNaming tpchColumnNaming = SIMPLIFIED;
+        private DecimalTypeMapping tpchDecimalTypeMapping = DOUBLE;
 
         protected Builder()
         {
@@ -198,6 +204,18 @@ public SELF setSecurity(String security)
             return self();
         }
 
+        public SELF setTpchColumnNaming(ColumnNaming tpchColumnNaming)
+        {
+            this.tpchColumnNaming = requireNonNull(tpchColumnNaming, "tpchColumnNaming is null");
+            return self();
+        }
+
+        public SELF setTpchDecimalTypeMapping(DecimalTypeMapping tpchDecimalTypeMapping)
+        {
+            this.tpchDecimalTypeMapping = requireNonNull(tpchDecimalTypeMapping, "tpchDecimalTypeMapping is null");
+            return self();
+        }
+
         @Override
         public DistributedQueryRunner build()
                 throws Exception
@@ -208,7 +226,11 @@ public DistributedQueryRunner build()
 
             try {
                 queryRunner.installPlugin(new TpchPlugin());
-                queryRunner.createCatalog("tpch", "tpch");
+                Map<String, String> tpchCatalogProperties = ImmutableMap.<String, String>builder()
+                        .put("tpch.column-naming", tpchColumnNaming.name())
+                        .put("tpch.double-type-mapping", tpchDecimalTypeMapping.name())
+                        .buildOrThrow();
+                queryRunner.createCatalog("tpch", "tpch", tpchCatalogProperties);
 
                 if (tpcdsCatalogEnabled) {
                     queryRunner.installPlugin(new TpcdsPlugin());
@@ -264,7 +286,7 @@ private void populateData(DistributedQueryRunner queryRunner, HiveMetastore meta
             if (metastore.getDatabase(TPCH_BUCKETED_SCHEMA).isEmpty()) {
                 metastore.createDatabase(createDatabaseMetastoreObject(TPCH_BUCKETED_SCHEMA, initialSchemasLocationBase));
                 Session session = initialTablesSessionMutator.apply(createBucketedSession(Optional.empty()));
-                copyTpchTablesBucketed(queryRunner, "tpch", TINY_SCHEMA_NAME, session, initialTables);
+                copyTpchTablesBucketed(queryRunner, "tpch", TINY_SCHEMA_NAME, session, initialTables, tpchColumnNaming);
             }
         }
     }
@@ -318,43 +340,50 @@ private static void copyTpchTablesBucketed(
             String sourceCatalog,
             String sourceSchema,
             Session session,
-            Iterable<TpchTable<?>> tables)
+            Iterable<TpchTable<?>> tables,
+            ColumnNaming columnNaming)
     {
         log.info("Loading data from %s.%s...", sourceCatalog, sourceSchema);
         long startTime = System.nanoTime();
         for (TpchTable<?> table : tables) {
-            copyTableBucketed(queryRunner, new QualifiedObjectName(sourceCatalog, sourceSchema, table.getTableName().toLowerCase(ENGLISH)), session);
+            copyTableBucketed(queryRunner, new QualifiedObjectName(sourceCatalog, sourceSchema, table.getTableName().toLowerCase(ENGLISH)), table, session, columnNaming);
         }
         log.info("Loading from %s.%s complete in %s", sourceCatalog, sourceSchema, nanosSince(startTime).toString(SECONDS));
     }
 
-    private static void copyTableBucketed(QueryRunner queryRunner, QualifiedObjectName table, Session session)
+    private static void copyTableBucketed(QueryRunner queryRunner, QualifiedObjectName tableName, TpchTable<?> table, Session session, ColumnNaming columnNaming)
     {
         long start = System.nanoTime();
-        log.info("Running import for %s", table.getObjectName());
+        log.info("Running import for %s", tableName.getObjectName());
         @Language("SQL") String sql;
-        switch (table.getObjectName()) {
+        switch (tableName.getObjectName()) {
             case "part":
             case "partsupp":
             case "supplier":
             case "nation":
             case "region":
-                sql = format("CREATE TABLE %s AS SELECT * FROM %s", table.getObjectName(), table);
+                sql = format("CREATE TABLE %s AS SELECT * FROM %s", tableName.getObjectName(), tableName);
                 break;
             case "lineitem":
-                sql = format("CREATE TABLE %s WITH (bucketed_by=array['orderkey'], bucket_count=11) AS SELECT * FROM %s", table.getObjectName(), table);
+                sql = format(
+                        "CREATE TABLE %s WITH (bucketed_by=array['%s'], bucket_count=11) AS SELECT * FROM %s",
+                        tableName.getObjectName(),
+                        columnNaming.getName(table.getColumn("orderkey")),
+                        tableName);
                 break;
             case "customer":
-                sql = format("CREATE TABLE %s WITH (bucketed_by=array['custkey'], bucket_count=11) AS SELECT * FROM %s", table.getObjectName(), table);
-                break;
             case "orders":
-                sql = format("CREATE TABLE %s WITH (bucketed_by=array['custkey'], bucket_count=11) AS SELECT * FROM %s", table.getObjectName(), table);
+                sql = format(
+                        "CREATE TABLE %s WITH (bucketed_by=array['%s'], bucket_count=11) AS SELECT * FROM %s",
+                        tableName.getObjectName(),
+                        columnNaming.getName(table.getColumn("custkey")),
+                        tableName);
                 break;
             default:
                 throw new UnsupportedOperationException();
         }
         long rows = (Long) queryRunner.execute(session, sql).getMaterializedRows().get(0).getField(0);
-        log.info("Imported %s rows for %s in %s", rows, table.getObjectName(), nanosSince(start).convertToMostSuccinctTimeUnit());
+        log.info("Imported %s rows for %s in %s", rows, tableName.getObjectName(), nanosSince(start).convertToMostSuccinctTimeUnit());
     }
 
     public static void main(String[] args)
@@ -380,6 +409,10 @@ public static void main(String[] args)
                 .setBaseDataDir(baseDataDir)
                 .setTpcdsCatalogEnabled(true)
                 .setSecurity(ALLOW_ALL)
+                // Uncomment to enable standard column naming (column names to be prefixed with the first letter of the table name, e.g.: o_orderkey vs orderkey)
+                // and standard column types (decimals vs double for some columns). This will allow running unmodified tpch queries on the cluster.
+                // .setTpchColumnNaming(STANDARD)
+                // .setTpchDecimalTypeMapping(DECIMAL)
                 .build();
         Thread.sleep(10);
         log.info("======== SERVER STARTED ========");