Skip to content

Commit

Permalink
Added binary read support for Parquet [Databricks] (#6161)
Browse files Browse the repository at this point in the history
Signed-off-by: Raza Jafri <[email protected]>
  • Loading branch information
razajafri authored Aug 5, 2022
1 parent 03e7fdd commit 8d14f8c
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 41 deletions.
16 changes: 8 additions & 8 deletions docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,11 @@ Accelerator supports are described below.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down Expand Up @@ -18816,11 +18816,11 @@ dates or timestamps, or for a lack of type coercion support.
<td>S</td>
<td>S</td>
<td> </td>
<td><b>NS</b></td>
<td>S</td>
<td> </td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
Expand Down
28 changes: 28 additions & 0 deletions integration_tests/src/main/python/parquet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,34 @@ def test_parquet_fallback(spark_tmp_path, read_func, disable_conf):
conf={disable_conf: 'false',
"spark.sql.sources.useV1SourceList": "parquet"})

@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@pytest.mark.parametrize('binary_as_string', [True, False])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
def test_parquet_read_round_trip_binary(std_input_path, read_func, binary_as_string, reader_confs):
data_path = std_input_path + '/binary_as_string.parquet'

all_confs = copy_and_update(reader_confs, {
'spark.sql.parquet.binaryAsString': binary_as_string,
# set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU
'spark.sql.legacy.parquet.int96RebaseModeInRead' : 'CORRECTED',
'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'})
# once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove spark.sql.legacy.parquet.datetimeRebaseModeInRead config which is a workaround
# for nested timestamp/date support
assert_gpu_and_cpu_are_equal_collect(read_func(data_path),
conf=all_confs)

@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@pytest.mark.parametrize('binary_as_string', [True, False])
def test_binary_df_read(spark_tmp_path, binary_as_string, read_func):
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(lambda spark: unary_op_df(spark, StringGen()).selectExpr("cast(a as binary)").write.parquet(data_path))
all_confs = {
'spark.sql.parquet.binaryAsString': binary_as_string,
# set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU
'spark.sql.legacy.parquet.int96RebaseModeInRead': 'CORRECTED',
'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'}
assert_gpu_and_cpu_are_equal_collect(read_func(data_path), conf=all_confs)

@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ abstract class Spark31XShims extends SparkShims with Spark31Xuntil33XShims with
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
TypeSig.ARRAY + TypeSig.BINARY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ abstract class Spark31XdbShims extends Spark31XdbShimsBase with Logging {
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
TypeSig.ARRAY + TypeSig.BINARY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ trait Spark320PlusShims extends SparkShims with RebaseShims with Logging {
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
TypeSig.ARRAY + TypeSig.BINARY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(fsse, conf, p, r) => new FileSourceScanExecMeta(fsse, conf, p, r)),
GpuOverrides.exec[BatchScanExec](
"The backend for most file input",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ object SparkShimImpl extends Spark321PlusShims with Spark320until340Shims {
GpuOverrides.exec[FileSourceScanExec](
"Reading data from files, often from Hive tables",
ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP +
TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
TypeSig.ARRAY + TypeSig.BINARY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {

// Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU counterpart
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,8 @@ object GpuOverrides extends Logging {
sparkSig = TypeSig.cpuAtomics)),
(ParquetFormatType, FileFormatChecks(
cudfRead = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT +
TypeSig.ARRAY + TypeSig.MAP + GpuTypeShims.additionalParquetSupportedTypes).nested(),
TypeSig.ARRAY + TypeSig.MAP + TypeSig.BINARY +
GpuTypeShims.additionalParquetSupportedTypes).nested(),
cudfWrite = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT +
TypeSig.ARRAY + TypeSig.MAP + GpuTypeShims.additionalParquetSupportedTypes).nested(),
sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, PartitioningAwareFileIndex, SchemaColumnConvertNotSupportedException}
import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -1329,13 +1330,14 @@ trait ParquetPartitionReaderBase extends Logging with Arm with ScanWithMetrics
*
* @param isCaseSensitive if it is case sensitive
* @param useFieldId if enabled `spark.sql.parquet.fieldId.read.enabled`
* @return a sequence of column names following the order of readDataSchema
* @return a sequence of tuple of (column names, column dataType) following the order of
* readDataSchema
*/
protected def toCudfColumnNames(
protected def toCudfColumnNamesAndDataTypes(
readDataSchema: StructType,
fileSchema: MessageType,
isCaseSensitive: Boolean,
useFieldId: Boolean): Seq[String] = {
useFieldId: Boolean): Seq[(String, StructField)] = {

// map from field ID to the parquet column name
val fieldIdToNameMap = ParquetSchemaClipShims.fieldIdToNameMap(useFieldId, fileSchema)
Expand All @@ -1348,8 +1350,8 @@ trait ParquetPartitionReaderBase extends Logging with Arm with ScanWithMetrics
// StructField("c3", IntegerType))
// File schema is:
// message spark_schema {
// optional int32 c1 = 1 (field is is 1),
// optional int32 c2 = 2 (field is is 2),
// optional int32 c1 = 1 (field ID is 1),
// optional int32 c2 = 2 (field ID is 2),
// optional int32 c3,
// }
// ID = 55 not matched, returns ["c1", "c3"]
Expand All @@ -1371,21 +1373,35 @@ trait ParquetPartitionReaderBase extends Logging with Arm with ScanWithMetrics
clippedReadFields.map { f =>
if (useFieldId && ParquetSchemaClipShims.hasFieldId(f)) {
// find the parquet column name
fieldIdToNameMap(ParquetSchemaClipShims.getFieldId(f))
(fieldIdToNameMap(ParquetSchemaClipShims.getFieldId(f)), f)
} else {
m.get(f.name).getOrElse(f.name)
(m.get(f.name).getOrElse(f.name), f)
}
}
} else {
clippedReadFields.map { f =>
if (useFieldId && ParquetSchemaClipShims.hasFieldId(f)) {
fieldIdToNameMap(ParquetSchemaClipShims.getFieldId(f))
(fieldIdToNameMap(ParquetSchemaClipShims.getFieldId(f)), f)
} else {
f.name
(f.name, f)
}
}
}
}

val sparkToParquetSchema = new SparkToParquetSchemaConverter()
def getParquetOptions(clippedSchema: MessageType, useFieldId: Boolean): ParquetOptions = {
val includeColumns = toCudfColumnNamesAndDataTypes(readDataSchema, clippedSchema,
isSchemaCaseSensitive, useFieldId)
val builder = ParquetOptions.builder().withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
includeColumns.foreach { t =>
builder.includeColumn(t._1,
t._2.dataType == BinaryType &&
sparkToParquetSchema.convertField(t._2).asPrimitiveType().getPrimitiveTypeName
== PrimitiveTypeName.BINARY)
}
builder.build()
}
}

// Parquet schema wrapper
Expand Down Expand Up @@ -1565,11 +1581,7 @@ class MultiFileParquetPartitionReader(
// Dump parquet data into a file
dumpDataToFile(dataBuffer, dataSize, splits, Option(debugDumpPrefix), Some("parquet"))

val includeColumns = toCudfColumnNames(readDataSchema, clippedSchema,
isSchemaCaseSensitive, useFieldId)
val parseOpts = ParquetOptions.builder()
.withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
.includeColumn(includeColumns: _*).build()
val parseOpts = getParquetOptions(clippedSchema, useFieldId)

// About to start using the GPU
GpuSemaphore.acquireIfNecessary(TaskContext.get(), metrics(SEMAPHORE_WAIT_TIME))
Expand Down Expand Up @@ -1868,12 +1880,7 @@ class MultiFileCloudParquetPartitionReader(

// Dump parquet data into a file
dumpDataToFile(hostBuffer, dataSize, files, Option(debugDumpPrefix), Some("parquet"))

val includeColumns = toCudfColumnNames(readDataSchema, clippedSchema,
isSchemaCaseSensitive, useFieldId)
val parseOpts = ParquetOptions.builder()
.withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
.includeColumn(includeColumns: _*).build()
val parseOpts = getParquetOptions(clippedSchema, useFieldId)

// about to start using the GPU
GpuSemaphore.acquireIfNecessary(TaskContext.get(), metrics(SEMAPHORE_WAIT_TIME))
Expand Down Expand Up @@ -2012,12 +2019,7 @@ class ParquetPartitionReader(

// Dump parquet data into a file
dumpDataToFile(dataBuffer, dataSize, Array(split), Option(debugDumpPrefix), Some("parquet"))

val includeColumns = toCudfColumnNames(readDataSchema, clippedParquetSchema,
isSchemaCaseSensitive, useFieldId)
val parseOpts = ParquetOptions.builder()
.withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
.includeColumn(includeColumns: _*).build()
val parseOpts = getParquetOptions(clippedParquetSchema, useFieldId)

// about to start using the GPU
GpuSemaphore.acquireIfNecessary(TaskContext.get(), metrics(SEMAPHORE_WAIT_TIME))
Expand Down
2 changes: 1 addition & 1 deletion tools/src/main/resources/supportedDataSource.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ Iceberg,read,S,S,S,S,S,S,S,S,PS,S,S,NA,NS,NA,PS,PS,PS,NS
JSON,read,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO,CO
ORC,read,S,S,S,S,S,S,S,S,PS,S,S,NA,NS,NA,PS,PS,PS,NS
ORC,write,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
Parquet,read,S,S,S,S,S,S,S,S,PS,S,S,NA,NS,NA,PS,PS,PS,NS
Parquet,read,S,S,S,S,S,S,S,S,PS,S,S,NA,S,NA,PS,PS,PS,NS
Parquet,write,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
2 changes: 1 addition & 1 deletion tools/src/main/resources/supportedExecs.csv
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Exec,Supported,Notes,Params,BOOLEAN,BYTE,SHORT,INT,LONG,FLOAT,DOUBLE,DATE,TIMEST
CoalesceExec,S,None,Input/Output,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS
CollectLimitExec,NS,This is disabled by default because Collect Limit replacement can be slower on the GPU; if huge number of rows in a batch it could help by limiting the number of rows transferred from GPU to CPU,Input/Output,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS
ExpandExec,S,None,Input/Output,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS
FileSourceScanExec,S,None,Input/Output,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS
FileSourceScanExec,S,None,Input/Output,S,S,S,S,S,S,S,S,PS,S,S,S,S,NS,PS,PS,PS,NS
FilterExec,S,None,Input/Output,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS
GenerateExec,S,None,Input/Output,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS
GlobalLimitExec,S,None,Input/Output,S,S,S,S,S,S,S,S,PS,S,S,S,NS,NS,PS,PS,PS,NS
Expand Down

0 comments on commit 8d14f8c

Please sign in to comment.