Allow standard tpch types and names in HiveQueryRunner#main
To allow running unmodified tpch queries when debugging
arhimondr authored and findepi committed Feb 26, 2022
1 parent a69b396 commit e9f8fa7
Showing 1 changed file with 46 additions and 13 deletions.
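Annotation: with the two options this commit adds (standard column naming plus decimal types), spec TPC-H queries run against the catalog without renaming columns or loosening numeric checks. A minimal sketch of what "unmodified" means here, using TPC-H Q6 verbatim; the wrapper class and method are illustrative only, and the runner is assumed to be built as in main() below:

import io.trino.testing.QueryRunner;

final class UnmodifiedTpchQueryExample
{
    // Q6 straight from the TPC-H spec: with STANDARD naming the l_-prefixed
    // columns resolve as-is; with DECIMAL mapping the money arithmetic stays exact.
    static Object runQ6(QueryRunner queryRunner)
    {
        return queryRunner.execute(
                "SELECT sum(l_extendedprice * l_discount) AS revenue " +
                "FROM lineitem " +
                "WHERE l_shipdate >= DATE '1994-01-01' " +
                "AND l_shipdate < DATE '1994-01-01' + INTERVAL '1' YEAR " +
                "AND l_discount BETWEEN 0.05 AND 0.07 " +
                "AND l_quantity < 24")
                .getOnlyValue();
    }
}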
@@ -27,6 +27,8 @@
 import io.trino.plugin.hive.metastore.file.FileHiveMetastore;
 import io.trino.plugin.hive.metastore.file.FileHiveMetastoreConfig;
 import io.trino.plugin.tpcds.TpcdsPlugin;
+import io.trino.plugin.tpch.ColumnNaming;
+import io.trino.plugin.tpch.DecimalTypeMapping;
 import io.trino.plugin.tpch.TpchPlugin;
 import io.trino.spi.security.Identity;
 import io.trino.spi.security.PrincipalType;
@@ -52,6 +54,8 @@
 import static io.trino.plugin.hive.HiveTestUtils.HDFS_ENVIRONMENT;
 import static io.trino.plugin.hive.security.HiveSecurityModule.ALLOW_ALL;
 import static io.trino.plugin.hive.security.HiveSecurityModule.SQL_STANDARD;
+import static io.trino.plugin.tpch.ColumnNaming.SIMPLIFIED;
+import static io.trino.plugin.tpch.DecimalTypeMapping.DOUBLE;
 import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME;
 import static io.trino.spi.security.SelectedRole.Type.ROLE;
 import static io.trino.testing.QueryAssertions.copyTpchTables;
@@ -114,6 +118,8 @@ public static class Builder<SELF extends Builder<?>>
         private Optional<CachingDirectoryLister> cachingDirectoryLister = Optional.empty();
         private boolean tpcdsCatalogEnabled;
         private String security = SQL_STANDARD;
+        private ColumnNaming tpchColumnNaming = SIMPLIFIED;
+        private DecimalTypeMapping tpchDecimalTypeMapping = DOUBLE;

         protected Builder()
         {
@@ -198,6 +204,18 @@ public SELF setSecurity(String security)
             return self();
         }

+        public SELF setTpchColumnNaming(ColumnNaming tpchColumnNaming)
+        {
+            this.tpchColumnNaming = requireNonNull(tpchColumnNaming, "tpchColumnNaming is null");
+            return self();
+        }
+
+        public SELF setTpchDecimalTypeMapping(DecimalTypeMapping tpchDecimalTypeMapping)
+        {
+            this.tpchDecimalTypeMapping = requireNonNull(tpchDecimalTypeMapping, "tpchDecimalTypeMapping is null");
+            return self();
+        }
+
         @Override
         public DistributedQueryRunner build()
                 throws Exception
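Annotation: the new setters return self() so that subclasses of Builder keep their own type through the fluent chain. A minimal standalone sketch of that self-typed builder idiom; all names below are invented for illustration and are not part of this commit:

abstract class BaseBuilder<SELF extends BaseBuilder<SELF>>
{
    protected String security = "sql-standard";

    @SuppressWarnings("unchecked")
    protected SELF self()
    {
        // Safe by convention: each concrete subclass binds SELF to itself.
        return (SELF) this;
    }

    public SELF setSecurity(String security)
    {
        this.security = security;
        return self(); // returns the subclass type, not BaseBuilder
    }
}

final class HiveBuilder extends BaseBuilder<HiveBuilder>
{
    private boolean bucketed;

    public HiveBuilder setBucketed(boolean bucketed)
    {
        this.bucketed = bucketed;
        return this;
    }
}

// Usage: setSecurity(...) already yields HiveBuilder, so subclass setters still chain:
// new HiveBuilder().setSecurity("allow-all").setBucketed(true);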
@@ -208,7 +226,11 @@ public DistributedQueryRunner build()

             try {
                 queryRunner.installPlugin(new TpchPlugin());
-                queryRunner.createCatalog("tpch", "tpch");
+                Map<String, String> tpchCatalogProperties = ImmutableMap.<String, String>builder()
+                        .put("tpch.column-naming", tpchColumnNaming.name())
+                        .put("tpch.double-type-mapping", tpchDecimalTypeMapping.name())
+                        .buildOrThrow();
+                queryRunner.createCatalog("tpch", "tpch", tpchCatalogProperties);

                 if (tpcdsCatalogEnabled) {
                     queryRunner.installPlugin(new TpcdsPlugin());
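Annotation: Guava's ImmutableMap.Builder rejects duplicate keys at build time, and buildOrThrow() spells that behavior out at the call site. A standalone sketch; the property values here are hypothetical, only the key names come from the diff above:

import com.google.common.collect.ImmutableMap;

import java.util.Map;

public class CatalogPropertiesDemo
{
    public static void main(String[] args)
    {
        // Same shape as the map built in the diff; values are hypothetical.
        Map<String, String> properties = ImmutableMap.<String, String>builder()
                .put("tpch.column-naming", "STANDARD")
                .put("tpch.double-type-mapping", "DECIMAL")
                .buildOrThrow();
        System.out.println(properties);

        // Adding the same key twice fails at buildOrThrow(), not silently last-wins:
        try {
            ImmutableMap.<String, String>builder()
                    .put("tpch.column-naming", "STANDARD")
                    .put("tpch.column-naming", "SIMPLIFIED")
                    .buildOrThrow();
        }
        catch (IllegalArgumentException expected) {
            System.out.println("duplicate key rejected: " + expected.getMessage());
        }
    }
}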
@@ -264,7 +286,7 @@ private void populateData(DistributedQueryRunner queryRunner, HiveMetastore metastore)
             if (metastore.getDatabase(TPCH_BUCKETED_SCHEMA).isEmpty()) {
                 metastore.createDatabase(createDatabaseMetastoreObject(TPCH_BUCKETED_SCHEMA, initialSchemasLocationBase));
                 Session session = initialTablesSessionMutator.apply(createBucketedSession(Optional.empty()));
-                copyTpchTablesBucketed(queryRunner, "tpch", TINY_SCHEMA_NAME, session, initialTables);
+                copyTpchTablesBucketed(queryRunner, "tpch", TINY_SCHEMA_NAME, session, initialTables, tpchColumnNaming);
             }
         }
     }
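Annotation: once populateData has run, the bucketing metadata can be double-checked from SQL. A hedged sketch; the hive catalog and tpch_bucketed schema names are assumed from this file's constants (HIVE_CATALOG, TPCH_BUCKETED_SCHEMA), and the wrapper is illustrative:

import io.trino.testing.QueryRunner;

final class BucketingCheck
{
    // Expect bucketed_by = ARRAY['custkey'] (or ARRAY['o_custkey'] under STANDARD
    // naming) and bucket_count = 11 somewhere in the returned DDL.
    static String showCreateOrders(QueryRunner queryRunner)
    {
        return (String) queryRunner.execute("SHOW CREATE TABLE hive.tpch_bucketed.orders")
                .getOnlyValue();
    }
}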
@@ -318,43 +340,50 @@ private static void copyTpchTablesBucketed(
             String sourceCatalog,
             String sourceSchema,
             Session session,
-            Iterable<TpchTable<?>> tables)
+            Iterable<TpchTable<?>> tables,
+            ColumnNaming columnNaming)
     {
         log.info("Loading data from %s.%s...", sourceCatalog, sourceSchema);
         long startTime = System.nanoTime();
         for (TpchTable<?> table : tables) {
-            copyTableBucketed(queryRunner, new QualifiedObjectName(sourceCatalog, sourceSchema, table.getTableName().toLowerCase(ENGLISH)), session);
+            copyTableBucketed(queryRunner, new QualifiedObjectName(sourceCatalog, sourceSchema, table.getTableName().toLowerCase(ENGLISH)), table, session, columnNaming);
         }
         log.info("Loading from %s.%s complete in %s", sourceCatalog, sourceSchema, nanosSince(startTime).toString(SECONDS));
     }

-    private static void copyTableBucketed(QueryRunner queryRunner, QualifiedObjectName table, Session session)
+    private static void copyTableBucketed(QueryRunner queryRunner, QualifiedObjectName tableName, TpchTable<?> table, Session session, ColumnNaming columnNaming)
     {
         long start = System.nanoTime();
-        log.info("Running import for %s", table.getObjectName());
+        log.info("Running import for %s", tableName.getObjectName());
         @Language("SQL") String sql;
-        switch (table.getObjectName()) {
+        switch (tableName.getObjectName()) {
             case "part":
             case "partsupp":
             case "supplier":
             case "nation":
             case "region":
-                sql = format("CREATE TABLE %s AS SELECT * FROM %s", table.getObjectName(), table);
+                sql = format("CREATE TABLE %s AS SELECT * FROM %s", tableName.getObjectName(), tableName);
                 break;
             case "lineitem":
-                sql = format("CREATE TABLE %s WITH (bucketed_by=array['orderkey'], bucket_count=11) AS SELECT * FROM %s", table.getObjectName(), table);
+                sql = format(
+                        "CREATE TABLE %s WITH (bucketed_by=array['%s'], bucket_count=11) AS SELECT * FROM %s",
+                        tableName.getObjectName(),
+                        columnNaming.getName(table.getColumn("orderkey")),
+                        tableName);
                 break;
             case "customer":
-                sql = format("CREATE TABLE %s WITH (bucketed_by=array['custkey'], bucket_count=11) AS SELECT * FROM %s", table.getObjectName(), table);
-                break;
             case "orders":
-                sql = format("CREATE TABLE %s WITH (bucketed_by=array['custkey'], bucket_count=11) AS SELECT * FROM %s", table.getObjectName(), table);
+                sql = format(
+                        "CREATE TABLE %s WITH (bucketed_by=array['%s'], bucket_count=11) AS SELECT * FROM %s",
+                        tableName.getObjectName(),
+                        columnNaming.getName(table.getColumn("custkey")),
+                        tableName);
                 break;
             default:
                 throw new UnsupportedOperationException();
         }
         long rows = (Long) queryRunner.execute(session, sql).getMaterializedRows().get(0).getField(0);
-        log.info("Imported %s rows for %s in %s", rows, table.getObjectName(), nanosSince(start).convertToMostSuccinctTimeUnit());
+        log.info("Imported %s rows for %s in %s", rows, tableName.getObjectName(), nanosSince(start).convertToMostSuccinctTimeUnit());
     }

     public static void main(String[] args)
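Annotation: the bucketing change above is the subtle part of this commit. The bucketed_by column must follow the active naming scheme, because under STANDARD naming orders exposes o_custkey rather than custkey, so the old hard-coded 'custkey' would not resolve. A small sketch of the name resolution; the API usage mirrors the diff, and the printed values are assumed from the commit's own description (o_orderkey vs orderkey):

import io.trino.plugin.tpch.ColumnNaming;
import io.trino.tpch.TpchTable;

public class BucketColumnNameDemo
{
    public static void main(String[] args)
    {
        TpchTable<?> orders = TpchTable.getTable("orders");
        // SIMPLIFIED drops the per-table prefix; STANDARD keeps the spec name.
        System.out.println(ColumnNaming.SIMPLIFIED.getName(orders.getColumn("custkey"))); // custkey
        System.out.println(ColumnNaming.STANDARD.getName(orders.getColumn("custkey")));   // o_custkey
    }
}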
@@ -380,6 +409,10 @@ public static void main(String[] args)
                 .setBaseDataDir(baseDataDir)
                 .setTpcdsCatalogEnabled(true)
                 .setSecurity(ALLOW_ALL)
+                // Uncomment to enable standard column naming (column names to be prefixed with the first letter of the table name, e.g.: o_orderkey vs orderkey)
+                // and standard column types (decimals vs double for some columns). This will allow running unmodified tpch queries on the cluster.
+                // .setTpchColumnNaming(STANDARD)
+                // .setTpchDecimalTypeMapping(DECIMAL)
                 .build();
         Thread.sleep(10);
         log.info("======== SERVER STARTED ========");
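Annotation: for reference, uncommenting those options yields the following builder chain, a minimal sketch assuming the static imports io.trino.plugin.tpch.ColumnNaming.STANDARD and io.trino.plugin.tpch.DecimalTypeMapping.DECIMAL are added alongside the existing ones:

// Hypothetical variant of the main() above with the spec-compliant schema enabled:
DistributedQueryRunner queryRunner = HiveQueryRunner.builder()
        .setBaseDataDir(baseDataDir)
        .setTpcdsCatalogEnabled(true)
        .setSecurity(ALLOW_ALL)
        .setTpchColumnNaming(STANDARD)       // o_orderkey, l_extendedprice, ...
        .setTpchDecimalTypeMapping(DECIMAL)  // exact decimals instead of double
        .build();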