Skip to content

Commit

Permalink
[BEAM-12270] TPC-DS: Add schema projection for Parquet source
Browse files Browse the repository at this point in the history
  • Loading branch information
aromanenko-dev committed Aug 20, 2021
1 parent 4e05416 commit 41d515d
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
*/
package org.apache.beam.sdk.tpcds;

import java.util.Set;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParseException;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.parser.SqlParser;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources;

Expand All @@ -37,4 +41,20 @@ public static String readQuery(String queryFileName) throws Exception {
String path = "queries/" + queryFileName + ".sql";
return Resources.toString(Resources.getResource(path), Charsets.UTF_8);
}

/**
* Parse query and get all its identifiers.
*
* @param queryString
* @return Set of SQL query identifiers as strings.
* @throws SqlParseException
*/
public static Set<String> getQueryIdentifiers(String queryString) throws SqlParseException {
SqlParser parser = SqlParser.create(queryString);
SqlNode parsedQuery = parser.parseQuery();
SqlTransformRunner.SqlIdentifierVisitor sqlVisitor =
new SqlTransformRunner.SqlIdentifierVisitor();
parsedQuery.accept(sqlVisitor);
return sqlVisitor.getIdentifiers();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
Expand All @@ -41,6 +43,8 @@
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlIdentifier;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.util.SqlBasicVisitor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources;
import org.apache.commons.csv.CSVFormat;
Expand All @@ -50,6 +54,8 @@
/**
* This class executes jobs using PCollection and SqlTransform, it uses SqlTransform.query to run
* queries.
*
* <p>TODO: Add tests.
*/
public class SqlTransformRunner {
private static final String SUMMARY_START = "\n" + "TPC-DS Query Execution Summary:";
Expand All @@ -66,6 +72,21 @@ public class SqlTransformRunner {

private static final Logger LOG = LoggerFactory.getLogger(SqlTransformRunner.class);

/** This class is used to extract all SQL query identifiers. */
static class SqlIdentifierVisitor extends SqlBasicVisitor<Void> {
private final Set<String> identifiers = new HashSet<>();

public Set<String> getIdentifiers() {
return identifiers;
}

@Override
public Void visit(SqlIdentifier id) {
identifiers.addAll(id.names);
return null;
}
}

/**
* Get all tables (in the form of TextTable) needed for a specific query execution.
*
Expand All @@ -82,17 +103,17 @@ private static PCollectionTuple getTables(
Map<String, Schema> schemaMap = TpcdsSchemas.getTpcdsSchemas();
TpcdsOptions tpcdsOptions = pipeline.getOptions().as(TpcdsOptions.class);
String dataSize = TpcdsParametersReader.getAndCheckDataSize(tpcdsOptions);
String queryString = QueryReader.readQuery(queryName);
Set<String> identifiers = QueryReader.getQueryIdentifiers(QueryReader.readQuery(queryName));

PCollectionTuple tables = PCollectionTuple.empty(pipeline);
for (Map.Entry<String, Schema> tableSchema : schemaMap.entrySet()) {
String tableName = tableSchema.getKey();

// Only when queryString contains tableName, the table is relevant to this query and will be
// added. This can avoid reading unnecessary data files.
// TODO: Simple but not reliable way since table name can be any substring in a query and can
// give false positives
if (queryString.contains(tableName)) {
// Only when query identifiers contain tableName, the table is relevant to this query and will
// be added. This can avoid reading unnecessary data files.
if (identifiers.contains(tableName.toUpperCase())) {
Set<String> tableColumns = getTableColumns(identifiers, tableSchema);

switch (tpcdsOptions.getSourceType()) {
case CSV:
{
Expand All @@ -104,7 +125,7 @@ private static PCollectionTuple getTables(
case PARQUET:
{
PCollection<GenericRecord> table =
getTableParquet(pipeline, tpcdsOptions, dataSize, tableName);
getTableParquet(pipeline, tpcdsOptions, dataSize, tableName, tableColumns);
tables = tables.and(new TupleTag<>(tableName), table);
break;
}
Expand All @@ -117,10 +138,28 @@ private static PCollectionTuple getTables(
return tables;
}

private static Set<String> getTableColumns(
Set<String> identifiers, Map.Entry<String, Schema> tableSchema) {
Set<String> tableColumns = new HashSet<>();
List<Schema.Field> fields = tableSchema.getValue().getFields();
for (Schema.Field field : fields) {
String fieldName = field.getName();
if (identifiers.contains(fieldName.toUpperCase())) {
tableColumns.add(fieldName);
}
}
return tableColumns;
}

private static PCollection<GenericRecord> getTableParquet(
Pipeline pipeline, TpcdsOptions tpcdsOptions, String dataSize, String tableName)
Pipeline pipeline,
TpcdsOptions tpcdsOptions,
String dataSize,
String tableName,
Set<String> tableColumns)
throws IOException {
org.apache.avro.Schema schema = getAvroSchema(tableName);
org.apache.avro.Schema schemaProjected = getProjectedSchema(tableColumns, schema);

String filepattern =
tpcdsOptions.getDataDirectory() + "/" + dataSize + "/" + tableName + "/*.parquet";
Expand All @@ -130,7 +169,7 @@ private static PCollection<GenericRecord> getTableParquet(
ParquetIO.read(schema)
.from(filepattern)
.withSplit()
// TODO: add .withProjection()
.withProjection(schemaProjected, schemaProjected)
.withBeamSchemas(true));
}

Expand Down Expand Up @@ -161,6 +200,21 @@ private static org.apache.avro.Schema getAvroSchema(String tableName) throws IOE
.parse(Resources.toString(Resources.getResource(path), Charsets.UTF_8));
}

static org.apache.avro.Schema getProjectedSchema(
Set<String> projectedFieldNames, org.apache.avro.Schema schema) {
List<org.apache.avro.Schema.Field> projectedFields = new ArrayList<>();
for (org.apache.avro.Schema.Field f : schema.getFields()) {
if (projectedFieldNames.contains(f.name())) {
projectedFields.add(
new org.apache.avro.Schema.Field(f.name(), f.schema(), f.doc(), f.defaultVal()));
}
}
org.apache.avro.Schema schemaProjected =
org.apache.avro.Schema.createRecord(schema.getName() + "_projected", null, null, false);
schemaProjected.setFields(projectedFields);
return schemaProjected;
}

/**
* Print the summary table after all jobs are finished.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import static org.junit.Assert.assertEquals;

import java.util.HashSet;
import java.util.Set;
import org.junit.Test;

public class QueryReaderTest {
Expand Down Expand Up @@ -66,6 +68,32 @@ public void testQuery3String() throws Exception {
assertEquals(expectedNoSpaces, query3StringNoSpaces);
}

@Test
public void testQuery3Identifiers() throws Exception {
Set<String> expected = new HashSet<>();
expected.add("BRAND");
expected.add("BRAND_ID");
expected.add("D_DATE_SK");
expected.add("D_MOY");
expected.add("D_YEAR");
expected.add("DATE_DIM");
expected.add("DT");
expected.add("I_BRAND");
expected.add("I_BRAND_ID");
expected.add("I_ITEM_SK");
expected.add("I_MANUFACT_ID");
expected.add("ITEM");
expected.add("SS_EXT_SALES_PRICE");
expected.add("SS_ITEM_SK");
expected.add("SS_SOLD_DATE_SK");
expected.add("STORE_SALES");
expected.add("SUM_AGG");

String query3String = QueryReader.readQuery("query3");
Set<String> identifiers = QueryReader.getQueryIdentifiers(query3String);
assertEquals(expected, identifiers);
}

@Test
public void testQuery4String() throws Exception {
String query4String = QueryReader.readQuery("query4");
Expand Down

0 comments on commit 41d515d

Please sign in to comment.