diff --git a/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/QueryReader.java b/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/QueryReader.java
index c6f3253fc424..7bd5b0bca62c 100644
--- a/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/QueryReader.java
+++ b/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/QueryReader.java
@@ -17,6 +17,10 @@
  */
 package org.apache.beam.sdk.tpcds;
 
+import java.util.Set;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlNode;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParseException;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParser;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources;
 
@@ -37,4 +41,20 @@ public static String readQuery(String queryFileName) throws Exception {
     String path = "queries/" + queryFileName + ".sql";
     return Resources.toString(Resources.getResource(path), Charsets.UTF_8);
   }
+
+  /**
+   * Parse query and get all its identifiers.
+   *
+   * @param queryString the SQL query text to parse
+   * @return Set of SQL query identifiers as strings.
+   * @throws SqlParseException if the query cannot be parsed
+   */
+  public static Set<String> getQueryIdentifiers(String queryString) throws SqlParseException {
+    SqlParser parser = SqlParser.create(queryString);
+    SqlNode parsedQuery = parser.parseQuery();
+    SqlTransformRunner.SqlIdentifierVisitor sqlVisitor =
+        new SqlTransformRunner.SqlIdentifierVisitor();
+    parsedQuery.accept(sqlVisitor);
+    return sqlVisitor.getIdentifiers();
+  }
 }
diff --git a/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/SqlTransformRunner.java b/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/SqlTransformRunner.java
index ad1714fe0f0d..c3b63e52af3f 100644
--- a/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/SqlTransformRunner.java
+++ b/sdks/java/testing/tpcds/src/main/java/org/apache/beam/sdk/tpcds/SqlTransformRunner.java
@@ -20,8 +20,10 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.CompletionService;
 import java.util.concurrent.ExecutorCompletionService;
 import java.util.concurrent.ExecutorService;
@@ -41,6 +43,8 @@
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlIdentifier;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.util.SqlBasicVisitor;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources;
 import org.apache.commons.csv.CSVFormat;
@@ -50,6 +54,8 @@
 /**
  * This class executes jobs using PCollection and SqlTransform, it uses SqlTransform.query to run
  * queries.
+ *
+ * <p>TODO: Add tests.
  */
 public class SqlTransformRunner {
   private static final String SUMMARY_START = "\n" + "TPC-DS Query Execution Summary:";
@@ -66,6 +72,21 @@ public class SqlTransformRunner {
 
   private static final Logger LOG = LoggerFactory.getLogger(SqlTransformRunner.class);
 
+  /** This class is used to extract all SQL query identifiers. */
+  static class SqlIdentifierVisitor extends SqlBasicVisitor<Void> {
+    private final Set<String> identifiers = new HashSet<>();
+
+    public Set<String> getIdentifiers() {
+      return identifiers;
+    }
+
+    @Override
+    public Void visit(SqlIdentifier id) {
+      identifiers.addAll(id.names);
+      return null;
+    }
+  }
+
   /**
    * Get all tables (in the form of TextTable) needed for a specific query execution.
    *
@@ -82,17 +103,17 @@ private static PCollectionTuple getTables(
     Map<String, Schema> schemaMap = TpcdsSchemas.getTpcdsSchemas();
     TpcdsOptions tpcdsOptions = pipeline.getOptions().as(TpcdsOptions.class);
     String dataSize = TpcdsParametersReader.getAndCheckDataSize(tpcdsOptions);
-    String queryString = QueryReader.readQuery(queryName);
+    Set<String> identifiers = QueryReader.getQueryIdentifiers(QueryReader.readQuery(queryName));
 
     PCollectionTuple tables = PCollectionTuple.empty(pipeline);
     for (Map.Entry<String, Schema> tableSchema : schemaMap.entrySet()) {
       String tableName = tableSchema.getKey();
 
-      // Only when queryString contains tableName, the table is relevant to this query and will be
-      // added. This can avoid reading unnecessary data files.
-      // TODO: Simple but not reliable way since table name can be any substring in a query and can
-      // give false positives
-      if (queryString.contains(tableName)) {
+      // Only when query identifiers contain tableName, the table is relevant to this query and will
+      // be added. This can avoid reading unnecessary data files.
+      if (identifiers.contains(tableName.toUpperCase())) {
+        Set<String> tableColumns = getTableColumns(identifiers, tableSchema);
+
         switch (tpcdsOptions.getSourceType()) {
           case CSV:
             {
@@ -104,7 +125,7 @@ private static PCollectionTuple getTables(
           case PARQUET:
             {
               PCollection<GenericRecord> table =
-                  getTableParquet(pipeline, tpcdsOptions, dataSize, tableName);
+                  getTableParquet(pipeline, tpcdsOptions, dataSize, tableName, tableColumns);
               tables = tables.and(new TupleTag<>(tableName), table);
               break;
             }
@@ -117,10 +138,28 @@ private static PCollectionTuple getTables(
     return tables;
   }
 
+  private static Set<String> getTableColumns(
+      Set<String> identifiers, Map.Entry<String, Schema> tableSchema) {
+    Set<String> tableColumns = new HashSet<>();
+    List<Schema.Field> fields = tableSchema.getValue().getFields();
+    for (Schema.Field field : fields) {
+      String fieldName = field.getName();
+      if (identifiers.contains(fieldName.toUpperCase())) {
+        tableColumns.add(fieldName);
+      }
+    }
+    return tableColumns;
+  }
+
   private static PCollection<GenericRecord> getTableParquet(
-      Pipeline pipeline, TpcdsOptions tpcdsOptions, String dataSize, String tableName)
+      Pipeline pipeline,
+      TpcdsOptions tpcdsOptions,
+      String dataSize,
+      String tableName,
+      Set<String> tableColumns)
       throws IOException {
     org.apache.avro.Schema schema = getAvroSchema(tableName);
+    org.apache.avro.Schema schemaProjected = getProjectedSchema(tableColumns, schema);
 
     String filepattern =
         tpcdsOptions.getDataDirectory() + "/" + dataSize + "/" + tableName + "/*.parquet";
@@ -130,7 +169,7 @@ private static PCollection<GenericRecord> getTableParquet(
         ParquetIO.read(schema)
             .from(filepattern)
            .withSplit()
-            // TODO: add .withProjection()
+            .withProjection(schemaProjected, schemaProjected)
             .withBeamSchemas(true));
   }
 
@@ -161,6 +200,21 @@ private static org.apache.avro.Schema getAvroSchema(String tableName) throws IOException {
         .parse(Resources.toString(Resources.getResource(path), Charsets.UTF_8));
   }
 
+  static org.apache.avro.Schema getProjectedSchema(
+      Set<String> projectedFieldNames, org.apache.avro.Schema schema) {
+    List<org.apache.avro.Schema.Field> projectedFields = new ArrayList<>();
+    for (org.apache.avro.Schema.Field f : schema.getFields()) {
+      if (projectedFieldNames.contains(f.name())) {
+        projectedFields.add(
+            new org.apache.avro.Schema.Field(f.name(), f.schema(), f.doc(), f.defaultVal()));
+      }
+    }
+    org.apache.avro.Schema schemaProjected =
+        org.apache.avro.Schema.createRecord(schema.getName() + "_projected", null, null, false);
+    schemaProjected.setFields(projectedFields);
+    return schemaProjected;
+  }
+
   /**
    * Print the summary table after all jobs are finished.
    *
diff --git a/sdks/java/testing/tpcds/src/test/java/org/apache/beam/sdk/tpcds/QueryReaderTest.java b/sdks/java/testing/tpcds/src/test/java/org/apache/beam/sdk/tpcds/QueryReaderTest.java
index 42f7d5b5abb3..b21cdfaefb32 100644
--- a/sdks/java/testing/tpcds/src/test/java/org/apache/beam/sdk/tpcds/QueryReaderTest.java
+++ b/sdks/java/testing/tpcds/src/test/java/org/apache/beam/sdk/tpcds/QueryReaderTest.java
@@ -19,6 +19,8 @@
 
 import static org.junit.Assert.assertEquals;
 
+import java.util.HashSet;
+import java.util.Set;
 import org.junit.Test;
 
 public class QueryReaderTest {
@@ -66,6 +68,32 @@ public void testQuery3String() throws Exception {
     assertEquals(expectedNoSpaces, query3StringNoSpaces);
   }
 
+  @Test
+  public void testQuery3Identifiers() throws Exception {
+    Set<String> expected = new HashSet<>();
+    expected.add("BRAND");
+    expected.add("BRAND_ID");
+    expected.add("D_DATE_SK");
+    expected.add("D_MOY");
+    expected.add("D_YEAR");
+    expected.add("DATE_DIM");
+    expected.add("DT");
+    expected.add("I_BRAND");
+    expected.add("I_BRAND_ID");
+    expected.add("I_ITEM_SK");
+    expected.add("I_MANUFACT_ID");
+    expected.add("ITEM");
+    expected.add("SS_EXT_SALES_PRICE");
+    expected.add("SS_ITEM_SK");
+    expected.add("SS_SOLD_DATE_SK");
+    expected.add("STORE_SALES");
+    expected.add("SUM_AGG");
+
+    String query3String = QueryReader.readQuery("query3");
+    Set<String> identifiers = QueryReader.getQueryIdentifiers(query3String);
+    assertEquals(expected, identifiers);
+  }
+
   @Test
   public void testQuery4String() throws Exception {
     String query4String = QueryReader.readQuery("query4");