Fix parquet binary reads to do the transformation in the plugin [databricks] (#6292)

Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Aug 11, 2022
1 parent f3f6bab commit 8c0e81e
Showing 4 changed files with 36 additions and 22 deletions.
14 changes: 14 additions & 0 deletions integration_tests/src/main/python/data_gen.py
@@ -651,6 +651,19 @@ def _gen_random(self, rand):
    def start(self, rand):
        self._start(rand, lambda: self._gen_random(rand))

class BinaryGen(DataGen):
    """Generate BinaryType values"""
    def __init__(self, min_length=0, max_length=20, nullable=True):
        super().__init__(BinaryType(), nullable=nullable)
        self._min_length = min_length
        self._max_length = max_length

    def start(self, rand):
        def gen_bytes():
            length = rand.randint(self._min_length, self._max_length)
            return bytes([ rand.randint(0, 255) for _ in range(length) ])
        self._start(rand, gen_bytes)

def skip_if_not_utc():
    if (not is_tz_utc()):
        skip_unless_precommit_tests('The java system time zone is not set to UTC')
@@ -883,6 +896,7 @@ def gen_scalars_for_sql(data_gen, count, seed=0, force_no_nulls=False):
boolean_gen = BooleanGen()
date_gen = DateGen()
timestamp_gen = TimestampGen()
binary_gen = BinaryGen()
null_gen = NullGen()

numeric_gens = [byte_gen, short_gen, int_gen, long_gen, float_gen, double_gen]
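For reference, here is a minimal standalone sketch, outside the DataGen framework and not part of the commit, of the byte-generation logic the new BinaryGen uses; the helper name gen_bytes and the seed are illustrative only.

import random

def gen_bytes(rand, min_length=0, max_length=20):
    # Same idea as BinaryGen.start: pick a random length, then that many random bytes.
    length = rand.randint(min_length, max_length)
    return bytes([rand.randint(0, 255) for _ in range(length)])

rand = random.Random(0)
print([gen_bytes(rand) for _ in range(3)])  # three byte strings, each 0-20 bytes long

In the integration tests the generator is exercised through the usual data_gen helpers (for example unary_op_df in the parquet test below) rather than called directly.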
9 changes: 7 additions & 2 deletions integration_tests/src/main/python/parquet_test.py
@@ -145,9 +145,14 @@ def test_parquet_read_round_trip_binary(std_input_path, read_func, binary_as_str

@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
@pytest.mark.parametrize('binary_as_string', [True, False])
def test_binary_df_read(spark_tmp_path, binary_as_string, read_func):
@pytest.mark.parametrize('data_gen', [binary_gen,
                                      ArrayGen(binary_gen),
                                      StructGen([('a_1', binary_gen), ('a_2', string_gen)]),
                                      StructGen([('a_1', ArrayGen(binary_gen))]),
                                      MapGen(ByteGen(nullable=False), binary_gen)], ids=idfn)
def test_binary_df_read(spark_tmp_path, binary_as_string, read_func, data_gen):
    data_path = spark_tmp_path + '/PARQUET_DATA'
    with_cpu_session(lambda spark: unary_op_df(spark, StringGen()).selectExpr("cast(a as binary)").write.parquet(data_path))
    with_cpu_session(lambda spark: unary_op_df(spark, data_gen).write.parquet(data_path))
    all_confs = {
        'spark.sql.parquet.binaryAsString': binary_as_string,
        # set the int96 rebase mode values because its LEGACY in databricks which will preclude this op from running on GPU
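A rough pyspark sketch of the round trip the widened test now covers (this is not the test harness itself; it assumes a local SparkSession and a writable /tmp path): binary data, including nested binary, is written by the CPU and read back with spark.sql.parquet.binaryAsString toggled both ways.

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, BinaryType, StructField, StructType

spark = SparkSession.builder.master("local[1]").getOrCreate()
data_path = "/tmp/PARQUET_BINARY_DEMO"

schema = StructType([
    StructField("b", BinaryType()),
    StructField("arr", ArrayType(BinaryType())),
])
rows = [(bytearray(b"\x00\x01"), [bytearray(b"a"), bytearray(b"")]), (None, None)]
spark.createDataFrame(rows, schema).write.mode("overwrite").parquet(data_path)

for as_string in [True, False]:
    spark.conf.set("spark.sql.parquet.binaryAsString", str(as_string).lower())
    # Binary-backed columns may surface as string when the flag is true, binary otherwise.
    spark.read.parquet(data_path).printSchema()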
@@ -63,7 +63,6 @@ import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, PartitioningAwareFileIndex, SchemaColumnConvertNotSupportedException}
import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.internal.SQLConf
@@ -1330,14 +1329,13 @@ trait ParquetPartitionReaderBase extends Logging with Arm with ScanWithMetrics
*
* @param isCaseSensitive if it is case sensitive
* @param useFieldId if enabled `spark.sql.parquet.fieldId.read.enabled`
* @return a sequence of tuple of (column names, column dataType) following the order of
* readDataSchema
* @return a sequence of tuple of column names following the order of readDataSchema
*/
protected def toCudfColumnNamesAndDataTypes(
protected def toCudfColumnNames(
readDataSchema: StructType,
fileSchema: MessageType,
isCaseSensitive: Boolean,
useFieldId: Boolean): Seq[(String, StructField)] = {
useFieldId: Boolean): Seq[String] = {

// map from field ID to the parquet column name
val fieldIdToNameMap = ParquetSchemaClipShims.fieldIdToNameMap(useFieldId, fileSchema)
@@ -1373,34 +1371,29 @@ trait ParquetPartitionReaderBase extends Logging with Arm with ScanWithMetrics
clippedReadFields.map { f =>
if (useFieldId && ParquetSchemaClipShims.hasFieldId(f)) {
// find the parquet column name
(fieldIdToNameMap(ParquetSchemaClipShims.getFieldId(f)), f)
fieldIdToNameMap(ParquetSchemaClipShims.getFieldId(f))
} else {
(m.get(f.name).getOrElse(f.name), f)
m.get(f.name).getOrElse(f.name)
}
}
} else {
clippedReadFields.map { f =>
if (useFieldId && ParquetSchemaClipShims.hasFieldId(f)) {
(fieldIdToNameMap(ParquetSchemaClipShims.getFieldId(f)), f)
fieldIdToNameMap(ParquetSchemaClipShims.getFieldId(f))
} else {
(f.name, f)
f.name
}
}
}
}

val sparkToParquetSchema = new SparkToParquetSchemaConverter()
def getParquetOptions(clippedSchema: MessageType, useFieldId: Boolean): ParquetOptions = {
val includeColumns = toCudfColumnNamesAndDataTypes(readDataSchema, clippedSchema,
val includeColumns = toCudfColumnNames(readDataSchema, clippedSchema,
isSchemaCaseSensitive, useFieldId)
val builder = ParquetOptions.builder().withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
includeColumns.foreach { t =>
builder.includeColumn(t._1,
t._2.dataType == BinaryType &&
sparkToParquetSchema.convertField(t._2).asPrimitiveType().getPrimitiveTypeName
== PrimitiveTypeName.BINARY)
}
builder.build()
ParquetOptions.builder()
.withTimeUnit(DType.TIMESTAMP_MICROSECONDS)
.includeColumn(includeColumns : _*)
.build()
}
}

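One reason the read path has to decide how a BINARY column comes back at all: in the Parquet format both Spark's StringType and BinaryType are stored with the same physical type (BYTE_ARRAY, which parquet-mr exposes as PrimitiveTypeName.BINARY); only the UTF8/STRING logical annotation, or the binaryAsString fallback, tells them apart. A small pyarrow sketch, separate from the commit, makes that visible in the file footer.

import pyarrow as pa
import pyarrow.parquet as pq

# Write one annotated (string) and one unannotated (binary) column.
table = pa.table({
    "s": pa.array(["abc"], type=pa.string()),        # BYTE_ARRAY annotated as String
    "b": pa.array([b"\x00\x01"], type=pa.binary()),  # plain BYTE_ARRAY, no annotation
})
pq.write_table(table, "/tmp/binary_vs_string.parquet")

# The Parquet footer shows the same physical type for both columns.
print(pq.ParquetFile("/tmp/binary_vs_string.parquet").schema)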
@@ -25,6 +25,7 @@ import com.nvidia.spark.rapids.shims.ParquetSchemaClipShims
import org.apache.parquet.schema._
import org.apache.parquet.schema.Type.Repetition

import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.rapids.shims.RapidsErrorUtils
import org.apache.spark.sql.types._

@@ -510,7 +511,8 @@ object ParquetSchemaUtils extends Arm {
}
SchemaUtils.evolveSchemaIfNeededAndClose(table, fileSparkSchema, sparkSchema,
caseSensitive, Some(evolveSchemaCasts),
existsUnsignedType(fileSchema.asGroupType()))
existsUnsignedType(fileSchema.asGroupType()) ||
TrampolineUtil.dataTypeExistsRecursively(sparkSchema, _.isInstanceOf[BinaryType]))
}

/**
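The extra condition uses TrampolineUtil.dataTypeExistsRecursively to force the post-read schema-evolution step whenever a BinaryType appears anywhere in the Spark schema, flat or nested. A rough Python analog of that recursive check, for illustration only and not the plugin's code, might look like this.

from pyspark.sql.types import ArrayType, BinaryType, MapType, StructType

def exists_recursively(dt, pred):
    # True if the predicate holds for this type or any type nested inside it.
    if pred(dt):
        return True
    if isinstance(dt, StructType):
        return any(exists_recursively(f.dataType, pred) for f in dt.fields)
    if isinstance(dt, ArrayType):
        return exists_recursively(dt.elementType, pred)
    if isinstance(dt, MapType):
        return exists_recursively(dt.keyType, pred) or exists_recursively(dt.valueType, pred)
    return False

schema = StructType().add("a", ArrayType(BinaryType()))
print(exists_recursively(schema, lambda t: isinstance(t, BinaryType)))  # True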
