Skip to content

Commit

Permalink
Add in support for FIXED_LEN_BYTE_ARRAY as binary (#8404)
Browse files Browse the repository at this point in the history
* Add in support for FIXED_LEN_BYTE_ARRAY as binary

Signed-off-by: Robert (Bobby) Evans <[email protected]>

* Update tests to run on GPU

---------

Signed-off-by: Robert (Bobby) Evans <[email protected]>
  • Loading branch information
revans2 authored Sep 12, 2023
1 parent 8399836 commit 7e1e9d3
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1035,8 +1035,12 @@ private case class GpuParquetFileFilterHandler(
case PrimitiveTypeName.INT96 if dt == DataTypes.TimestampType =>
return

case PrimitiveTypeName.BINARY if dt == DataTypes.StringType ||
dt == DataTypes.BinaryType || canReadAsBinaryDecimal(pt, dt) =>
case PrimitiveTypeName.BINARY if dt == DataTypes.StringType =>
// PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY for StringType is not supported by parquet
return

case PrimitiveTypeName.BINARY | PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
if dt == DataTypes.BinaryType =>
return

case PrimitiveTypeName.BINARY | PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ object ParquetSchemaClipShims {
if (typeAnnotation == null) s"$typeName" else s"$typeName ($typeAnnotation)"

def typeNotImplemented() =
TrampolineUtil.throwAnalysisException(s"Parquet type not yet supported: ${typeString}")
TrampolineUtil.throwAnalysisException(s"Parquet type not yet supported: $typeString")

def illegalType() =
TrampolineUtil.throwAnalysisException(s"Illegal Parquet type: $parquetType")
Expand Down Expand Up @@ -185,7 +185,7 @@ object ParquetSchemaClipShims {
if (!SQLConf.get.isParquetINT96AsTimestamp) {
TrampolineUtil.throwAnalysisException(
"INT96 is not supported unless it's interpreted as timestamp. " +
s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
}
TimestampType

Expand All @@ -204,6 +204,7 @@ object ParquetSchemaClipShims {
typeAnnotation match {
case _: DecimalLogicalTypeAnnotation =>
makeDecimalType(Decimal.maxPrecisionForBytes(parquetType.getTypeLength))
case null => BinaryType
case _: IntervalLogicalTypeAnnotation => typeNotImplemented()
case _ => illegalType()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ import org.apache.parquet.io.api.{Binary, RecordConsumer}
import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.scalatest.concurrent.Eventually

import org.apache.spark.SparkConf
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -388,6 +389,116 @@ class ParquetFormatScanSuite extends SparkQueryCompareTestSuite with Eventually
}, conf = conf)
}

// Writes a parquet file with a single un-annotated fixed_len_byte_array(16) column and checks
// that reading it back yields the same 16 bytes per row as binary data.
test(s"FIXED_LEN_BYTE_ARRAY(16) BINARY $parserType") {
  // The test is canceled unless isSpark340OrLater holds.
  assume(isSpark340OrLater)
  withGpuSparkSession(spark => {
    val schema =
      """message spark {
        | required fixed_len_byte_array(16) test;
        |}
        |""".stripMargin

    withTempDir(spark) { dir =>
      val testPath = dir + "/FIXED_BIN16_TEST.parquet"
      // Two writer callbacks, each emitting one message with a single 16-byte value.
      writeDirect(testPath, schema, { rc =>
        rc.message {
          rc.field("test", 0) {
            rc.addBinary(Binary.fromString("1234567890123456"))
          }
        }
      }, { rc =>
        rc.message {
          rc.field("test", 0) {
            rc.addBinary(Binary.fromString("ABCDEFGHIJKLMNOP"))
          }
        }
      })

      val data = spark.read.parquet(testPath).collect()
      // Binary.fromString encodes its argument as UTF-8, so build the expected bytes with the
      // same charset instead of relying on the platform default.
      val utf8 = java.nio.charset.StandardCharsets.UTF_8
      val expected = Seq(
        Row("1234567890123456".getBytes(utf8)),
        Row("ABCDEFGHIJKLMNOP".getBytes(utf8)))
      sameRows(expected, data)
    }
  }, conf = conf)
}

// Checks that an un-annotated fixed_len_byte_array(16) column is still returned as raw bytes
// when spark.sql.parquet.binaryAsString is enabled.
test(s"FIXED_LEN_BYTE_ARRAY(16) binaryAsString $parserType") {
  // The test is canceled unless isSpark340OrLater holds.
  assume(isSpark340OrLater)
  // Parquet does not let us tag a FIXED_LEN_BYTE_ARRAY with utf8 to make it a string.
  // Spark ignores binaryAsString for it, so we need to test that we do the same.
  val conf = new SparkConf()
    .set("spark.rapids.sql.format.parquet.reader.footer.type", parserType)
    .set("spark.sql.parquet.binaryAsString", "true")
  withGpuSparkSession(spark => {
    val schema =
      """message spark {
        | required fixed_len_byte_array(16) test;
        |}
        |""".stripMargin

    withTempDir(spark) { dir =>
      val testPath = dir + "/FIXED_BIN16_TEST.parquet"
      // Two writer callbacks, each emitting one message with a single 16-byte value.
      writeDirect(testPath, schema, { rc =>
        rc.message {
          rc.field("test", 0) {
            rc.addBinary(Binary.fromString("1234567890123456"))
          }
        }
      }, { rc =>
        rc.message {
          rc.field("test", 0) {
            rc.addBinary(Binary.fromString("ABCDEFGHIJKLMNOP"))
          }
        }
      })

      val data = spark.read.parquet(testPath).collect()
      // Binary.fromString encodes its argument as UTF-8, so build the expected bytes with the
      // same charset instead of relying on the platform default. The expected rows are byte
      // arrays, not strings, because binaryAsString must not apply to this column.
      val utf8 = java.nio.charset.StandardCharsets.UTF_8
      val expected = Seq(
        Row("1234567890123456".getBytes(utf8)),
        Row("ABCDEFGHIJKLMNOP".getBytes(utf8)))
      sameRows(expected, data)
    }
  }, conf = conf)
}

// Checks that reading an un-annotated fixed_len_byte_array(16) column with an explicit
// StringType read schema fails with a SparkException instead of returning data.
test(s"FIXED_LEN_BYTE_ARRAY(16) String in Schema $parserType") {
  // The test is canceled unless isSpark340OrLater holds.
  assume(isSpark340OrLater)
  // Parquet does not let us tag a FIXED_LEN_BYTE_ARRAY with utf8 to make it a string.
  // Spark also fails the task if we try to read it as a String, so we should verify that
  // we also throw an exception.
  withGpuSparkSession(spark => {
    val schema =
      """message spark {
        | required fixed_len_byte_array(16) test;
        |}
        |""".stripMargin

    withTempDir(spark) { dir =>
      val testPath = dir + "/FIXED_BIN16_TEST.parquet"
      // Two writer callbacks, each emitting one message with a single 16-byte value.
      writeDirect(testPath, schema, { rc =>
        rc.message {
          rc.field("test", 0) {
            rc.addBinary(Binary.fromString("1234567890123456"))
          }
        }
      }, { rc =>
        rc.message {
          rc.field("test", 0) {
            rc.addBinary(Binary.fromString("ABCDEFGHIJKLMNOP"))
          }
        }
      })

      val stringSchema = StructType(Seq(StructField("test", StringType)))
      // It would be nice to verify that the exception is exactly what we expect, but this is
      // not a CPU vs GPU comparison, so we only check that the GPU read fails with a
      // SparkException. The clue keeps the failure message readable if the read succeeds.
      withClue("We read back in some data, but we expected an exception...") {
        intercept[SparkException] {
          spark.read.schema(stringSchema).parquet(testPath).collect()
        }
      }
    }
  }, conf = conf)
}

test(s"BSON $parserType") {
withGpuSparkSession(spark => {
val schema =
Expand Down

0 comments on commit 7e1e9d3

Please sign in to comment.