From 7e1e9d385783069791021ead1035c671a2ef1921 Mon Sep 17 00:00:00 2001
From: "Robert (Bobby) Evans"
Date: Tue, 12 Sep 2023 12:51:05 -0500
Subject: [PATCH] Add in support for FIXED_LEN_BYTE_ARRAY as binary (#8404)

* Add in support for FIXED_LEN_BYTE_ARRAY as binary

Signed-off-by: Robert (Bobby) Evans

* Update tests to run on GPU

---------

Signed-off-by: Robert (Bobby) Evans
---
 .../nvidia/spark/rapids/GpuParquetScan.scala  |   8 +-
 .../rapids/shims/ParquetSchemaClipShims.scala |   5 +-
 .../sql/rapids/ParquetFormatScanSuite.scala   | 113 +++++++++++++++++-
 3 files changed, 121 insertions(+), 5 deletions(-)

diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index f524c9f55db..d13b9617ae9 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -1035,8 +1035,12 @@ private case class GpuParquetFileFilterHandler(
       case PrimitiveTypeName.INT96 if dt == DataTypes.TimestampType =>
         return
 
-      case PrimitiveTypeName.BINARY if dt == DataTypes.StringType ||
-          dt == DataTypes.BinaryType || canReadAsBinaryDecimal(pt, dt) =>
+      case PrimitiveTypeName.BINARY if dt == DataTypes.StringType =>
+        // PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY for StringType is not supported by parquet
+        return
+
+      case PrimitiveTypeName.BINARY | PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
+          if dt == DataTypes.BinaryType =>
         return
 
       case PrimitiveTypeName.BINARY | PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/ParquetSchemaClipShims.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/ParquetSchemaClipShims.scala
index 9518e19de8f..8fc974f1cc1 100644
--- a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/ParquetSchemaClipShims.scala
+++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/ParquetSchemaClipShims.scala
@@ -102,7 +102,7 @@ object ParquetSchemaClipShims {
     if (typeAnnotation == null) s"$typeName" else s"$typeName ($typeAnnotation)"
 
   def typeNotImplemented() =
-    TrampolineUtil.throwAnalysisException(s"Parquet type not yet supported: ${typeString}")
+    TrampolineUtil.throwAnalysisException(s"Parquet type not yet supported: $typeString")
 
   def illegalType() =
     TrampolineUtil.throwAnalysisException(s"Illegal Parquet type: $parquetType")
@@ -185,7 +185,7 @@ object ParquetSchemaClipShims {
       if (!SQLConf.get.isParquetINT96AsTimestamp) {
         TrampolineUtil.throwAnalysisException(
           "INT96 is not supported unless it's interpreted as timestamp. " +
" + - s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") } TimestampType @@ -204,6 +204,7 @@ object ParquetSchemaClipShims { typeAnnotation match { case _: DecimalLogicalTypeAnnotation => makeDecimalType(Decimal.maxPrecisionForBytes(parquetType.getTypeLength)) + case null => BinaryType case _: IntervalLogicalTypeAnnotation => typeNotImplemented() case _ => illegalType() } diff --git a/tests/src/test/scala/org/apache/spark/sql/rapids/ParquetFormatScanSuite.scala b/tests/src/test/scala/org/apache/spark/sql/rapids/ParquetFormatScanSuite.scala index e974c145347..7df894d003b 100644 --- a/tests/src/test/scala/org/apache/spark/sql/rapids/ParquetFormatScanSuite.scala +++ b/tests/src/test/scala/org/apache/spark/sql/rapids/ParquetFormatScanSuite.scala @@ -33,8 +33,9 @@ import org.apache.parquet.io.api.{Binary, RecordConsumer} import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** @@ -388,6 +389,116 @@ class ParquetFormatScanSuite extends SparkQueryCompareTestSuite with Eventually }, conf = conf) } + test(s"FIXED_LEN_BYTE_ARRAY(16) BINARY $parserType") { + assume(isSpark340OrLater) + withGpuSparkSession(spark => { + val schema = + """message spark { + | required fixed_len_byte_array(16) test; + |} + """.stripMargin + + withTempDir(spark) { dir => + val testPath = dir + "/FIXED_BIN16_TEST.parquet" + writeDirect(testPath, schema, { rc => + rc.message { + rc.field("test", 0) { + rc.addBinary(Binary.fromString("1234567890123456")) + } + } + }, { rc => + rc.message { + rc.field("test", 0) { + rc.addBinary(Binary.fromString("ABCDEFGHIJKLMNOP")) + } + } + }) + + val data = spark.read.parquet(testPath).collect() + sameRows(Seq(Row("1234567890123456".getBytes), + Row("ABCDEFGHIJKLMNOP".getBytes)), data) + } + }, conf = conf) + } + + test(s"FIXED_LEN_BYTE_ARRAY(16) binaryAsString $parserType") { + assume(isSpark340OrLater) + // Parquet does not let us tag a FIXED_LEN_BYTE_ARRAY with utf8 to make it a string + // Spark ignores binaryAsString for it, so we need to test that we do the same + val conf = new SparkConf() + .set("spark.rapids.sql.format.parquet.reader.footer.type", parserType) + .set("spark.sql.parquet.binaryAsString", "true") + withGpuSparkSession(spark => { + val schema = + """message spark { + | required fixed_len_byte_array(16) test; + |} + """.stripMargin + + withTempDir(spark) { dir => + val testPath = dir + "/FIXED_BIN16_TEST.parquet" + writeDirect(testPath, schema, { rc => + rc.message { + rc.field("test", 0) { + rc.addBinary(Binary.fromString("1234567890123456")) + } + } + }, { rc => + rc.message { + rc.field("test", 0) { + rc.addBinary(Binary.fromString("ABCDEFGHIJKLMNOP")) + } + } + }) + + val data = spark.read.parquet(testPath).collect() + sameRows(Seq(Row("1234567890123456".getBytes), + Row("ABCDEFGHIJKLMNOP".getBytes)), data) + } + }, conf = conf) + } + + test(s"FIXED_LEN_BYTE_ARRAY(16) String in Schema $parserType") { + assume(isSpark340OrLater) + // Parquet does not let us tag a FIXED_LEN_BYTE_ARRAY with utf8 to make it a string + // Spark also fails the task if we try to read it as a String so we should verify that + // We also throw an exception. 
+    withGpuSparkSession(spark => {
+      val schema =
+        """message spark {
+          |  required fixed_len_byte_array(16) test;
+          |}
+        """.stripMargin
+
+      withTempDir(spark) { dir =>
+        val testPath = dir + "/FIXED_BIN16_TEST.parquet"
+        writeDirect(testPath, schema, { rc =>
+          rc.message {
+            rc.field("test", 0) {
+              rc.addBinary(Binary.fromString("1234567890123456"))
+            }
+          }
+        }, { rc =>
+          rc.message {
+            rc.field("test", 0) {
+              rc.addBinary(Binary.fromString("ABCDEFGHIJKLMNOP"))
+            }
+          }
+        })
+
+        try {
+          spark.read.schema(StructType(Seq(StructField("test", StringType))))
+            .parquet(testPath).collect()
+          fail("We read back in some data, but we expected an exception...")
+        } catch {
+          case _: SparkException =>
+            // It would be nice to verify that the exception is what we expect, but we are
+            // not doing CPU vs GPU here, so we will just do a GPU pass.
+        }
+      }
+    }, conf = conf)
+  }
+
   test(s"BSON $parserType") {
     withGpuSparkSession(spark => {
       val schema =