diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 0824235718..9ba1a14ad3 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -608,6 +608,14 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_SCAN_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] =
+    conf("spark.comet.scan.allowIncompatible")
+      .doc(
+        "Comet is not currently fully compatible with Spark for all datatypes. " +
+          s"Set this config to true to allow them anyway. $COMPAT_GUIDE.")
+      .booleanConf
+      .createWithDefault(true)
+
   val COMET_EXPR_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] =
     conf("spark.comet.expression.allowIncompatible")
       .doc(
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index 8245e7b76b..70097cef48 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -75,6 +75,7 @@ Comet provides the following configuration settings.
 | spark.comet.parquet.read.parallel.io.enabled | Whether to enable Comet's parallel reader for Parquet files. The parallel reader reads ranges of consecutive data in a file in parallel. It is faster for large files and row groups but uses more resources. | true |
 | spark.comet.parquet.read.parallel.io.thread-pool.size | The maximum number of parallel threads the parallel reader will use in a single executor. For executors configured with a smaller number of cores, use a smaller number. | 16 |
 | spark.comet.regexp.allowIncompatible | Comet is not currently fully compatible with Spark for all regular expressions. Set this config to true to allow them anyway. For more information, refer to the Comet Compatibility Guide (https://datafusion.apache.org/comet/user-guide/compatibility.html). | false |
+| spark.comet.scan.allowIncompatible | Comet is not currently fully compatible with Spark for all datatypes. Set this config to true to allow them anyway. For more information, refer to the Comet Compatibility Guide (https://datafusion.apache.org/comet/user-guide/compatibility.html). | true |
 | spark.comet.scan.enabled | Whether to enable native scans. When this is turned on, Spark will use Comet to read supported data sources (currently only Parquet is supported natively). Note that to enable native vectorized execution, both this config and 'spark.comet.exec.enabled' need to be enabled. | true |
 | spark.comet.scan.preFetch.enabled | Whether to enable pre-fetching feature of CometScan. | false |
 | spark.comet.scan.preFetch.threadNum | The number of threads running pre-fetching for CometScan. Effective if spark.comet.scan.preFetch.enabled is enabled. Note that more pre-fetching threads means more memory requirement to store pre-fetched row groups. | 2 |
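For context on the new entry above: the flag can be toggled per session like any other Comet option. A minimal sketch, assuming an active SparkSession named `spark` (not part of this patch):

    // Sketch only: opt out of known-incompatible scans for this session.
    // "spark.comet.scan.allowIncompatible" is the key added in CometConf above.
    spark.conf.set("spark.comet.scan.allowIncompatible", "false")
    // or, from code that can reference CometConf directly:
    spark.conf.set(CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key, "false")
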
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 1cfd869e63..8a79511157 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -1352,6 +1352,15 @@ object CometSparkSessionExtensions extends Logging {
     org.apache.spark.SPARK_VERSION >= "4.0"
   }
 
+  def isComplexTypeReaderEnabled(conf: SQLConf): Boolean = {
+    CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) == CometConf.SCAN_NATIVE_ICEBERG_COMPAT ||
+    CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) == CometConf.SCAN_NATIVE_DATAFUSION
+  }
+
+  def usingDataFusionParquetReader(conf: SQLConf): Boolean = {
+    isComplexTypeReaderEnabled(conf) && !CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get(conf)
+  }
+
   /** Calculates required memory overhead in MB per executor process for Comet. */
   def getCometMemoryOverheadInMiB(sparkConf: SparkConf): Long = {
     // `spark.executor.memory` default value is 1g
diff --git a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
index 5c765003cf..f90916cb20 100644
--- a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
+++ b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
@@ -19,6 +19,7 @@
 
 package org.apache.comet
 
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 trait DataTypeSupport {
@@ -35,12 +36,15 @@ trait DataTypeSupport {
   def isAdditionallySupported(dt: DataType): Boolean = false
 
   private def isGloballySupported(dt: DataType): Boolean = dt match {
+    case ByteType | ShortType
+        if CometSparkSessionExtensions.isComplexTypeReaderEnabled(SQLConf.get) &&
+          !CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get() =>
+      false
     case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
         BinaryType | StringType | _: DecimalType | DateType | TimestampType =>
       true
     case t: DataType if t.typeName == "timestamp_ntz" =>
       true
-      true
     case _ =>
       false
   }
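To make the effect of the two helpers and the DataTypeSupport change concrete, here is an illustrative sketch (not part of this patch). It assumes a SparkSession named `spark` and a Parquet file at the hypothetical path /tmp/bytes.parquet containing a TINYINT column named `b`:

    // Sketch only: with a DataFusion-based reader selected and the new flag off,
    // ByteType is no longer globally supported, so the scan is expected to fall
    // back to Spark rather than run as a Comet scan.
    spark.conf.set(CometConf.COMET_NATIVE_SCAN_IMPL.key, CometConf.SCAN_NATIVE_DATAFUSION)
    spark.conf.set(CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key, "false")
    spark.read.parquet("/tmp/bytes.parquet").select("b").explain()
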
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 27d8e2357f..21a2dc8807 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -59,6 +59,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   private val timestampPattern = "0123456789/:T" + whitespaceChars
 
+  lazy val usingDataFusionParquetReader: Boolean =
+    CometSparkSessionExtensions.usingDataFusionParquetReader(conf)
+
   test("all valid cast combinations covered") {
     val names = testNames
 
@@ -145,88 +148,148 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   // CAST from ByteType
 
   test("cast ByteType to BooleanType") {
-    castTest(generateBytes(), DataTypes.BooleanType)
+    castTest(
+      generateBytes(),
+      DataTypes.BooleanType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ByteType to ShortType") {
-    castTest(generateBytes(), DataTypes.ShortType)
+    castTest(
+      generateBytes(),
+      DataTypes.ShortType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ByteType to IntegerType") {
-    castTest(generateBytes(), DataTypes.IntegerType)
+    castTest(
+      generateBytes(),
+      DataTypes.IntegerType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ByteType to LongType") {
-    castTest(generateBytes(), DataTypes.LongType)
+    castTest(
+      generateBytes(),
+      DataTypes.LongType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ByteType to FloatType") {
-    castTest(generateBytes(), DataTypes.FloatType)
+    castTest(
+      generateBytes(),
+      DataTypes.FloatType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ByteType to DoubleType") {
-    castTest(generateBytes(), DataTypes.DoubleType)
+    castTest(
+      generateBytes(),
+      DataTypes.DoubleType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ByteType to DecimalType(10,2)") {
-    castTest(generateBytes(), DataTypes.createDecimalType(10, 2))
+    castTest(
+      generateBytes(),
+      DataTypes.createDecimalType(10, 2),
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ByteType to StringType") {
-    castTest(generateBytes(), DataTypes.StringType)
+    castTest(
+      generateBytes(),
+      DataTypes.StringType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   ignore("cast ByteType to BinaryType") {
-    castTest(generateBytes(), DataTypes.BinaryType)
+    castTest(
+      generateBytes(),
+      DataTypes.BinaryType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   ignore("cast ByteType to TimestampType") {
     // input: -1, expected: 1969-12-31 15:59:59.0, actual: 1969-12-31 15:59:59.999999
-    castTest(generateBytes(), DataTypes.TimestampType)
+    castTest(
+      generateBytes(),
+      DataTypes.TimestampType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   // CAST from ShortType
 
   test("cast ShortType to BooleanType") {
-    castTest(generateShorts(), DataTypes.BooleanType)
+    castTest(
+      generateShorts(),
+      DataTypes.BooleanType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ShortType to ByteType") {
     // https://github.com/apache/datafusion-comet/issues/311
-    castTest(generateShorts(), DataTypes.ByteType)
+    castTest(
+      generateShorts(),
+      DataTypes.ByteType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ShortType to IntegerType") {
-    castTest(generateShorts(), DataTypes.IntegerType)
+    castTest(
+      generateShorts(),
+      DataTypes.IntegerType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ShortType to LongType") {
-    castTest(generateShorts(), DataTypes.LongType)
+    castTest(
+      generateShorts(),
+      DataTypes.LongType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ShortType to FloatType") {
-    castTest(generateShorts(), DataTypes.FloatType)
+    castTest(
+      generateShorts(),
+      DataTypes.FloatType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ShortType to DoubleType") {
-    castTest(generateShorts(), DataTypes.DoubleType)
+    castTest(
+      generateShorts(),
+      DataTypes.DoubleType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ShortType to DecimalType(10,2)") {
-    castTest(generateShorts(), DataTypes.createDecimalType(10, 2))
+    castTest(
+      generateShorts(),
+      DataTypes.createDecimalType(10, 2),
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   test("cast ShortType to StringType") {
-    castTest(generateShorts(), DataTypes.StringType)
+    castTest(
+      generateShorts(),
+      DataTypes.StringType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   ignore("cast ShortType to BinaryType") {
-    castTest(generateShorts(), DataTypes.BinaryType)
+    castTest(
+      generateShorts(),
+      DataTypes.BinaryType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   ignore("cast ShortType to TimestampType") {
     // input: -1003, expected: 1969-12-31 15:43:17.0, actual: 1969-12-31 15:59:59.998997
-    castTest(generateShorts(), DataTypes.TimestampType)
+    castTest(
+      generateShorts(),
+      DataTypes.TimestampType,
+      hasIncompatibleType = usingDataFusionParquetReader)
   }
 
   // CAST from integer
 
@@ -1069,7 +1132,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  private def castTest(input: DataFrame, toType: DataType, testAnsi: Boolean = true): Unit = {
+  private def castTest(
+      input: DataFrame,
+      toType: DataType,
+      hasIncompatibleType: Boolean = false,
+      testAnsi: Boolean = true): Unit = {
     // we now support the TryCast expression in Spark 3.3
 
     withTempPath { dir =>
@@ -1079,12 +1146,20 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
       withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
         // cast() should return null for invalid inputs when ansi mode is disabled
         val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a")
-        checkSparkAnswerAndOperator(df)
+        if (hasIncompatibleType) {
+          checkSparkAnswer(df)
+        } else {
+          checkSparkAnswerAndOperator(df)
+        }
 
         // try_cast() should always return null for invalid inputs
         val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-        checkSparkAnswerAndOperator(df2)
+        if (hasIncompatibleType) {
+          checkSparkAnswer(df2)
+        } else {
+          checkSparkAnswerAndOperator(df2)
+        }
       }
 
       if (testAnsi) {
@@ -1140,7 +1215,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
           // try_cast() should always return null for invalid inputs
           val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
-          checkSparkAnswerAndOperator(df2)
+          if (hasIncompatibleType) {
+            checkSparkAnswer(df2)
+          } else {
+            checkSparkAnswerAndOperator(df2)
+          }
         }
       }
 
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index dd3845a7cb..9e8b14ea6f 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -125,6 +125,45 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("uint data type support") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      // TODO: Once the question of what to get back from uint_8, uint_16 types is resolved,
+      // we can also update this test to check for COMET_SCAN_ALLOW_INCOMPATIBLE=true
+      Seq(false).foreach { allowIncompatible =>
+        {
+          withSQLConf(CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> allowIncompatible.toString) {
+            withTempDir { dir =>
+              val path = new Path(dir.toURI.toString, "testuint.parquet")
+              makeParquetFileAllTypes(
+                path,
+                dictionaryEnabled = dictionaryEnabled,
+                Byte.MinValue,
+                Byte.MaxValue)
+              withParquetTable(path.toString, "tbl") {
+                val qry = "select _9 from tbl order by _11"
+                if (CometSparkSessionExtensions.isComplexTypeReaderEnabled(conf)) {
+                  if (!allowIncompatible) {
+                    checkSparkAnswer(qry)
+                  } else {
+                    // need to convert the values to unsigned values
+                    val expected = (Byte.MinValue to Byte.MaxValue)
+                      .map(v => {
+                        if (v < 0) Byte.MaxValue.toShort - v else v
+                      })
+                      .toDF("a")
+                    checkAnswer(sql(qry), expected)
+                  }
+                } else {
+                  checkSparkAnswerAndOperator(qry)
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   test("null literals") {
     val batchSize = 1000
     Seq(true, false).foreach { dictionaryEnabled =>
@@ -142,6 +181,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         checkSparkAnswerAndOperator(sqlString)
       }
     }
+
   }
 }
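The uint test above relies on a mapping from signed byte values to their unsigned interpretation, and the TODO leaves the final expected values open. For reference only, the conventional two's-complement to unsigned mapping would look like the following sketch (not part of this patch):

    // Illustration: standard signed-byte -> unsigned interpretation,
    // e.g. -1 -> 255, -128 -> 128, 127 -> 127.
    val unsigned = (Byte.MinValue to Byte.MaxValue).map(v => v & 0xFF)
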
diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
index 56b3ba9d2c..9ba0726e2a 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
@@ -45,7 +45,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 import com.google.common.primitives.UnsignedLong
 
-import org.apache.comet.CometConf
+import org.apache.comet.{CometConf, CometSparkSessionExtensions}
 import org.apache.comet.CometSparkSessionExtensions.{isSpark34Plus, isSpark40Plus}
 
 abstract class ParquetReadSuite extends CometTestBase {
@@ -139,7 +139,10 @@ abstract class ParquetReadSuite extends CometTestBase {
           i.toDouble,
           DateTimeUtils.toJavaDate(i))
       }
-      checkParquetScan(data)
+      if (!CometSparkSessionExtensions.isComplexTypeReaderEnabled(
+          conf) || CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get()) {
+        checkParquetScan(data)
+      }
       checkParquetFile(data)
     }
   }
@@ -159,7 +162,10 @@ abstract class ParquetReadSuite extends CometTestBase {
           i.toDouble,
           DateTimeUtils.toJavaDate(i))
       }
-      checkParquetScan(data)
+      if (!CometSparkSessionExtensions.isComplexTypeReaderEnabled(
+          conf) || CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get()) {
+        checkParquetScan(data)
+      }
       checkParquetFile(data)
     }
   }
@@ -178,7 +184,10 @@ abstract class ParquetReadSuite extends CometTestBase {
           DateTimeUtils.toJavaDate(i))
       }
       val filter = (row: Row) => row.getBoolean(0)
-      checkParquetScan(data, filter)
+      if (!CometSparkSessionExtensions.isComplexTypeReaderEnabled(
+          conf) || CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.get()) {
+        checkParquetScan(data, filter)
+      }
       checkParquetFile(data, filter)
     }
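The ParquetReadSuite changes above only exercise checkParquetScan when the reader/flag combination supports all of the generated columns. A sketch of how a test might exercise both settings explicitly, assuming CometTestBase's withSQLConf helper, a Parquet `path`, and the constants referenced in CometSparkSessionExtensions earlier (illustration only, not part of this patch):

    // Sketch: run the same check under both DataFusion-based scan impls,
    // with and without spark.comet.scan.allowIncompatible.
    Seq(CometConf.SCAN_NATIVE_DATAFUSION, CometConf.SCAN_NATIVE_ICEBERG_COMPAT).foreach { impl =>
      Seq("true", "false").foreach { allow =>
        withSQLConf(
          CometConf.COMET_NATIVE_SCAN_IMPL.key -> impl,
          CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> allow) {
          makeParquetFileAllTypes(path, dictionaryEnabled = false, 0, 100)
          // ... read the file back and compare Comet vs Spark results here ...
        }
      }
    }
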
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 215f0c876c..63763aa3b8 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -429,65 +429,130 @@ abstract class CometTestBase
     makeParquetFileAllTypes(path, dictionaryEnabled, 0, n)
   }
 
-  def makeParquetFileAllTypes(
-      path: Path,
-      dictionaryEnabled: Boolean,
-      begin: Int,
-      end: Int,
-      pageSize: Int = 128,
-      randomSize: Int = 0): Unit = {
-    val schemaStr =
+  def getPrimitiveTypesParquetSchema: String = {
+    if (CometSparkSessionExtensions.usingDataFusionParquetReader(conf)) {
+      // Comet complex type reader has different behavior for uint_8, uint_16 types.
+      // The issue stems from undefined behavior in the parquet spec and is tracked
+      // here: https://github.com/apache/parquet-java/issues/3142
+      // here: https://github.com/apache/arrow-rs/issues/7040
+      // and here: https://github.com/apache/datafusion-comet/issues/1348
       if (isSpark34Plus) {
         """
-        |message root {
-        | optional boolean _1;
-        | optional int32 _2(INT_8);
-        | optional int32 _3(INT_16);
-        | optional int32 _4;
-        | optional int64 _5;
-        | optional float _6;
-        | optional double _7;
-        | optional binary _8(UTF8);
-        | optional int32 _9(UINT_8);
-        | optional int32 _10(UINT_16);
-        | optional int32 _11(UINT_32);
-        | optional int64 _12(UINT_64);
-        | optional binary _13(ENUM);
-        | optional FIXED_LEN_BYTE_ARRAY(3) _14;
-        | optional int32 _15(DECIMAL(5, 2));
-        | optional int64 _16(DECIMAL(18, 10));
-        | optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
-        | optional INT64 _18(TIMESTAMP(MILLIS,true));
-        | optional INT64 _19(TIMESTAMP(MICROS,true));
-        | optional INT32 _20(DATE);
-        |}
+          |message root {
+          | optional boolean _1;
+          | optional int32 _2(INT_8);
+          | optional int32 _3(INT_16);
+          | optional int32 _4;
+          | optional int64 _5;
+          | optional float _6;
+          | optional double _7;
+          | optional binary _8(UTF8);
+          | optional int32 _9(UINT_32);
+          | optional int32 _10(UINT_32);
+          | optional int32 _11(UINT_32);
+          | optional int64 _12(UINT_64);
+          | optional binary _13(ENUM);
+          | optional FIXED_LEN_BYTE_ARRAY(3) _14;
+          | optional int32 _15(DECIMAL(5, 2));
+          | optional int64 _16(DECIMAL(18, 10));
+          | optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
+          | optional INT64 _18(TIMESTAMP(MILLIS,true));
+          | optional INT64 _19(TIMESTAMP(MICROS,true));
+          | optional INT32 _20(DATE);
+          |}
         """.stripMargin
       } else {
         """
-        |message root {
-        | optional boolean _1;
-        | optional int32 _2(INT_8);
-        | optional int32 _3(INT_16);
-        | optional int32 _4;
-        | optional int64 _5;
-        | optional float _6;
-        | optional double _7;
-        | optional binary _8(UTF8);
-        | optional int32 _9(UINT_8);
-        | optional int32 _10(UINT_16);
-        | optional int32 _11(UINT_32);
-        | optional int64 _12(UINT_64);
-        | optional binary _13(ENUM);
-        | optional binary _14(UTF8);
-        | optional int32 _15(DECIMAL(5, 2));
-        | optional int64 _16(DECIMAL(18, 10));
-        | optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
-        | optional INT64 _18(TIMESTAMP(MILLIS,true));
-        | optional INT64 _19(TIMESTAMP(MICROS,true));
-        | optional INT32 _20(DATE);
-        |}
+          |message root {
+          | optional boolean _1;
+          | optional int32 _2(INT_8);
+          | optional int32 _3(INT_16);
+          | optional int32 _4;
+          | optional int64 _5;
+          | optional float _6;
+          | optional double _7;
+          | optional binary _8(UTF8);
+          | optional int32 _9(UINT_32);
+          | optional int32 _10(UINT_32);
+          | optional int32 _11(UINT_32);
+          | optional int64 _12(UINT_64);
+          | optional binary _13(ENUM);
+          | optional binary _14(UTF8);
+          | optional int32 _15(DECIMAL(5, 2));
+          | optional int64 _16(DECIMAL(18, 10));
+          | optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
+          | optional INT64 _18(TIMESTAMP(MILLIS,true));
+          | optional INT64 _19(TIMESTAMP(MICROS,true));
+          | optional INT32 _20(DATE);
+          |}
         """.stripMargin
       }
+    } else {
+
+      if (isSpark34Plus) {
+        """
+          |message root {
+          | optional boolean _1;
+          | optional int32 _2(INT_8);
+          | optional int32 _3(INT_16);
+          | optional int32 _4;
+          | optional int64 _5;
+          | optional float _6;
+          | optional double _7;
+          | optional binary _8(UTF8);
+          | optional int32 _9(UINT_8);
+          | optional int32 _10(UINT_16);
+          | optional int32 _11(UINT_32);
+          | optional int64 _12(UINT_64);
+          | optional binary _13(ENUM);
+          | optional FIXED_LEN_BYTE_ARRAY(3) _14;
+          | optional int32 _15(DECIMAL(5, 2));
+          | optional int64 _16(DECIMAL(18, 10));
+          | optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
+          | optional INT64 _18(TIMESTAMP(MILLIS,true));
+          | optional INT64 _19(TIMESTAMP(MICROS,true));
+          | optional INT32 _20(DATE);
+          |}
+        """.stripMargin
+      } else {
+        """
+          |message root {
+          | optional boolean _1;
+          | optional int32 _2(INT_8);
+          | optional int32 _3(INT_16);
+          | optional int32 _4;
+          | optional int64 _5;
+          | optional float _6;
+          | optional double _7;
+          | optional binary _8(UTF8);
+          | optional int32 _9(UINT_8);
+          | optional int32 _10(UINT_16);
+          | optional int32 _11(UINT_32);
+          | optional int64 _12(UINT_64);
+          | optional binary _13(ENUM);
+          | optional binary _14(UTF8);
+          | optional int32 _15(DECIMAL(5, 2));
+          | optional int64 _16(DECIMAL(18, 10));
+          | optional FIXED_LEN_BYTE_ARRAY(16) _17(DECIMAL(38, 37));
+          | optional INT64 _18(TIMESTAMP(MILLIS,true));
+          | optional INT64 _19(TIMESTAMP(MICROS,true));
+          | optional INT32 _20(DATE);
+          |}
+        """.stripMargin
+      }
+    }
+  }
+
+  def makeParquetFileAllTypes(
+      path: Path,
+      dictionaryEnabled: Boolean,
+      begin: Int,
+      end: Int,
+      pageSize: Int = 128,
+      randomSize: Int = 0): Unit = {
+    // The unsigned int columns are always included in the test schema;
+    // getPrimitiveTypesParquetSchema swaps UINT_8/UINT_16 for UINT_32 when the
+    // DataFusion-based readers are in use, since they treat those types differently
+    val schemaStr = getPrimitiveTypesParquetSchema
     val schema = MessageTypeParser.parseMessageType(schemaStr)
     val writer = createParquetWriter(