diff --git a/docs/supported_ops.md b/docs/supported_ops.md
index ed4ec7a95d1..c8cb78caf1c 100644
--- a/docs/supported_ops.md
+++ b/docs/supported_ops.md
@@ -21517,9 +21517,9 @@ dates or timestamps, or for a lack of type coercion support.
 NS
-PS max nested DECIMAL precision of 18; UTC is only supported TZ for nested TIMESTAMP; missing nested BINARY, MAP, STRUCT, UDT
-NS
+PS max nested DECIMAL precision of 18; UTC is only supported TZ for nested TIMESTAMP; missing nested BINARY, MAP, UDT
 NS
+PS max nested DECIMAL precision of 18; UTC is only supported TZ for nested TIMESTAMP; missing nested BINARY, MAP, UDT
 NS
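The two cells that change above are the ORC read entries for ARRAY (nested STRUCT is no longer listed as missing) and STRUCT (NS to PS). As a rough illustration of what that enables, here is a minimal PySpark sketch of an ORC read with a struct column, assuming the RAPIDS Accelerator jar is on the classpath; the app name, path, and column names are made up for the example:

```python
# Illustrative only: an ORC round trip with a STRUCT column that the new PS support covers.
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("orc-nested-struct-read")
         # Assumes the RAPIDS Accelerator plugin jar is already on the classpath.
         .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
         .config("spark.rapids.sql.enabled", "true")
         .getOrCreate())

spark.createDataFrame(
    [(1, ("a", 1.0)), (2, ("b", 2.0))],
    "id INT, s STRUCT<name: STRING, score: DOUBLE>"
).write.mode("overwrite").orc("/tmp/orc_struct_demo")

# Reading the STRUCT column (and STRUCTs nested in ARRAYs) can now stay on the GPU,
# subject to the limits listed above (DECIMAL precision, UTC timestamps, etc.).
spark.read.orc("/tmp/orc_struct_demo").select("s.name", "s.score").show()
```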
diff --git a/integration_tests/src/main/python/orc_test.py b/integration_tests/src/main/python/orc_test.py
index 2e6be158be9..0f5aa5575ff 100644
--- a/integration_tests/src/main/python/orc_test.py
+++ b/integration_tests/src/main/python/orc_test.py
@@ -20,6 +20,7 @@
 from marks import *
 from pyspark.sql.types import *
 from spark_session import with_cpu_session, with_spark_session
+from parquet_test import _nested_pruning_schemas
 
 def read_orc_df(data_path):
     return lambda spark : spark.read.orc(data_path)
@@ -58,14 +59,29 @@ def test_basic_read(std_input_path, name, read_func, v1_enabled_list, orc_impl,
     string_gen, boolean_gen, DateGen(start=date(1590, 1, 1)),
     TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))] + decimal_gens_no_neg
 
+orc_basic_struct_gen = StructGen([['child'+str(ind), sub_gen] for ind, sub_gen in enumerate(orc_basic_gens)])
+
 # Some array gens, but not all because of nesting
 orc_array_gens_sample = [ArrayGen(sub_gen) for sub_gen in orc_basic_gens] + [
     ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),
     ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10),
-    ArrayGen(ArrayGen(decimal_gen_default, max_length=10), max_length=10)]
+    ArrayGen(ArrayGen(decimal_gen_default, max_length=10), max_length=10),
+    ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))]
+
+# Some struct gens, but not all because of nesting.
+# No empty struct gen because it leads to an error as below.
+# '''
+# E   pyspark.sql.utils.AnalysisException:
+# E   Datasource does not support writing empty or nested empty schemas.
+# E   Please make sure the data schema has at least one or more column(s).
+# '''
+orc_struct_gens_sample = [orc_basic_struct_gen,
+    StructGen([['child0', byte_gen], ['child1', orc_basic_struct_gen]]),
+    StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]
 
 orc_gens_list = [orc_basic_gens,
     orc_array_gens_sample,
+    orc_struct_gens_sample,
     pytest.param([date_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/131')),
     pytest.param([timestamp_gen], marks=pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/131'))]
 
@@ -119,7 +135,10 @@ def test_read_round_trip(spark_tmp_path, orc_gens, read_func, reader_confs, v1_e
 @pytest.mark.parametrize('reader_confs', reader_opt_confs, ids=idfn)
 def test_pred_push_round_trip(spark_tmp_path, orc_gen, read_func, v1_enabled_list, reader_confs):
     data_path = spark_tmp_path + '/ORC_DATA'
-    gen_list = [('a', RepeatSeqGen(orc_gen, 100)), ('b', orc_gen)]
+    # Append two struct columns to verify nested predicate pushdown.
+    gen_list = [('a', RepeatSeqGen(orc_gen, 100)), ('b', orc_gen),
+                ('s1', StructGen([['sa', orc_gen]])),
+                ('s2', StructGen([['sa', StructGen([['ssa', orc_gen]])]]))]
     s0 = gen_scalar(orc_gen, force_no_nulls=True)
     with_cpu_session(
         lambda spark : gen_df(spark, gen_list).orderBy('a').write.orc(data_path))
@@ -127,7 +146,7 @@ def test_pred_push_round_trip(spark_tmp_path, orc_gen, read_func, v1_enabled_lis
     all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list})
     rf = read_func(data_path)
     assert_gpu_and_cpu_are_equal_collect(
-        lambda spark: rf(spark).select(f.col('a') >= s0),
+        lambda spark: rf(spark).select(f.col('a') >= s0, f.col('s1.sa') >= s0, f.col('s2.sa.ssa') >= s0),
         conf=all_confs)
 
 orc_compress_options = ['none', 'uncompressed', 'snappy', 'zlib']
@@ -323,3 +342,34 @@ def test_missing_column_names_filter(spark_tmp_table_factory, reader_confs):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : spark.sql("SELECT _col3,_col2 FROM {} WHERE _col2 = '155'".format(table_name)),
         all_confs)
+
+
+@pytest.mark.parametrize('data_gen,read_schema', _nested_pruning_schemas, ids=idfn)
+@pytest.mark.parametrize('reader_confs', reader_opt_confs, ids=idfn)
+@pytest.mark.parametrize('v1_enabled_list', ["", "orc"])
+@pytest.mark.parametrize('nested_enabled', ["true", "false"])
+def test_read_nested_pruning(spark_tmp_path, data_gen, read_schema, reader_confs, v1_enabled_list, nested_enabled):
+    data_path = spark_tmp_path + '/ORC_DATA'
+    with_cpu_session(
+        lambda spark : gen_df(spark, data_gen).write.orc(data_path))
+    all_confs = reader_confs.copy()
+    all_confs.update({'spark.sql.sources.useV1SourceList': v1_enabled_list,
+        'spark.sql.optimizer.nestedSchemaPruning.enabled': nested_enabled})
+    # This is a hack to get the type in a slightly less verbose way
+    rs = StructGen(read_schema, nullable=False).data_type
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : spark.read.schema(rs).orc(data_path),
+        conf=all_confs)
+
+
+# This is for the corner case of reading only a struct column that has no nulls.
+# Then there will be no streams in a stripe connecting to this column (its ROW_INDEX
+# streams have been pruned by the plugin), and CUDF throws an exception in such a case.
+@pytest.mark.xfail(reason='https://github.com/rapidsai/cudf/issues/8878')
+def test_read_struct_without_stream(spark_tmp_path, reader_confs):
+    data_gen = StructGen([['c_byte', ByteGen(nullable=False)]], nullable=False)
+    data_path = spark_tmp_path + '/ORC_DATA'
+    with_cpu_session(
+        lambda spark : unary_op_df(spark, data_gen, 10).write.orc(data_path))
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : spark.read.orc(data_path))
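Note that `_nested_pruning_schemas` itself is defined in `parquet_test.py` and does not appear in this diff. Purely as an illustration of the shape `test_read_nested_pruning` consumes, an entry pairs the schema used to generate the file with the pruned schema used on read; the field names and generators below are hypothetical, not the real list:

```python
# Hypothetical (data_gen, read_schema) pair in the style the test consumes;
# the real entries live in parquet_test.py and may differ.
from data_gen import StructGen, StringGen, LongGen, ShortGen

_example_nested_pruning_entry = (
    # Schema used to generate and write the ORC data: a struct with three children.
    [["struct", StructGen([["c_1", StringGen()], ["c_2", LongGen()], ["c_3", ShortGen()]])]],
    # Pruned read schema: only struct.c_1 is requested back.
    [["struct", StructGen([["c_1", StringGen()]])]])
```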
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
index 44da53af13b..157d1483be4 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScan.scala
@@ -393,6 +393,9 @@ trait OrcCommonFunctions extends OrcCodecWritingHelper {
 
   /**
    * Cast columns with precision that can be stored in an int to DECIMAL32, to save space.
+   * Besides, the plugin makes the assumption that if the precision is small enough to fit
+   * in a DECIMAL32, then CUDF has it stored as a DECIMAL32. Getting this wrong may lead
+   * to a number of problems later on.
    *
    * @param table the input table, will be closed after returning.
    * @param schema the schema of the table
@@ -1170,7 +1173,7 @@ private case class GpuOrcFileFilterHandler(
      * @param fileSchema input file's ORC schema
      * @param readSchema ORC schema for what will be read
      * @param isCaseAware true if field names are case-sensitive
-     * @return read schema mapped to the file's field names
+     * @return the read schema if the check passes.
      */
     private def checkSchemaCompatibility(
         fileSchema: TypeDescription,
         readSchema: TypeDescription,
@@ -1187,19 +1190,48 @@ private case class GpuOrcFileFilterHandler(
       val readerFieldNames = readSchema.getFieldNames.asScala
       val readerChildren = readSchema.getChildren.asScala
-      val newReadSchema = TypeDescription.createStruct()
 
       readerFieldNames.zip(readerChildren).foreach { case (readField, readType) =>
         val (fileType, fileFieldName) = fileTypesMap.getOrElse(readField, (null, null))
-        if (readType != fileType) {
+        // When column pruning is enabled, the readType is not always equal to the fileType;
+        // it may be just a part of the fileType, e.g.
+        //     read type: struct<c1: string>
+        //     file type: struct<c1: string, c2: bigint>
+        if (!isSchemaCompatible(fileType, readType)) {
           throw new QueryExecutionException("Incompatible schemas for ORC file" +
             s" at ${partFile.filePath}\n" +
             s" file schema: $fileSchema\n" +
             s" read schema: $readSchema")
         }
-        newReadSchema.addField(fileFieldName, fileType)
       }
+      // To support nested column pruning, the original (pruned) read schema should be
+      // returned, instead of building a new schema from the children of the file schema,
+      // which may contain more nested columns than the read schema and cause a mismatch
+      // between the pruned data and the pruned schema.
+      readSchema
+    }
 
-      newReadSchema
+    /**
+     * The read schema is compatible with the file schema only when
+     *   1) they are equal to each other, or
+     *   2) the read schema is a part of the file schema, for struct types.
+     *
+     * @param fileSchema input file's ORC schema
+     * @param readSchema ORC schema for what will be read
+     * @return true if they are compatible, otherwise false
+     */
+    private def isSchemaCompatible(
+        fileSchema: TypeDescription,
+        readSchema: TypeDescription): Boolean = {
+      fileSchema == readSchema ||
+        fileSchema != null && readSchema != null &&
+        fileSchema.getCategory == readSchema.getCategory && {
+          if (readSchema.getChildren != null) {
+            readSchema.getChildren.asScala.forall(rc =>
+              fileSchema.getChildren.asScala.exists(fc => isSchemaCompatible(fc, rc)))
+          } else {
+            false
+          }
+        }
     }
 
     /**
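The effect of the relaxed check is easiest to see from the Spark side: with nested schema pruning, the reader can be handed a schema that is only a part of the file's struct schema, and that schema must now pass checkSchemaCompatibility instead of being rejected. A small PySpark sketch, with made-up path and field names:

```python
# Illustrative nested-column-pruning read: the file holds s1: struct<sa, sb>,
# but only s1.sa is requested back, so the read type is a strict subset of the file type.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("orc-nested-pruning-demo").getOrCreate()

spark.createDataFrame(
    [((1, "x"),), ((2, "y"),)],
    "s1 STRUCT<sa: INT, sb: STRING>"
).write.mode("overwrite").orc("/tmp/orc_pruning_demo")

# With nested schema pruning enabled, Spark hands the reader this narrower schema;
# the new isSchemaCompatible check treats it as compatible with the wider file schema.
spark.conf.set("spark.sql.optimizer.nestedSchemaPruning.enabled", "true")
spark.read.schema("s1 STRUCT<sa: INT>").orc("/tmp/orc_pruning_demo").show()
```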
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 1da18a6f237..da37c402685 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -769,7 +769,8 @@ object GpuOverrides {
       sparkSig = (TypeSig.atomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
         TypeSig.UDT).nested())),
     (OrcFormatType, FileFormatChecks(
-      cudfRead = (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.DECIMAL_64).nested(),
+      cudfRead = (TypeSig.commonCudfTypes + TypeSig.ARRAY + TypeSig.DECIMAL_64 +
+        TypeSig.STRUCT).nested(),
       cudfWrite = TypeSig.commonCudfTypes,
       sparkSig = (TypeSig.atomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP +
         TypeSig.UDT).nested())))
diff --git a/tools/src/main/resources/supportedDataSource.csv b/tools/src/main/resources/supportedDataSource.csv
index 78c907e3e77..f5d90fb9a52 100644
--- a/tools/src/main/resources/supportedDataSource.csv
+++ b/tools/src/main/resources/supportedDataSource.csv
@@ -1,4 +1,4 @@
 Format,Direction,BOOLEAN,BYTE,SHORT,INT,LONG,FLOAT,DOUBLE,DATE,TIMESTAMP,STRING,DECIMAL,NULL,BINARY,CALENDAR,ARRAY,MAP,STRUCT,UDT
 CSV,read,CO,CO,CO,CO,CO,CO,CO,CO,CO,S,CO,NA,NS,NA,NA,NA,NA,NA
-ORC,read,S,S,S,S,S,S,S,S,PS,S,CO,NA,NS,NA,PS,NS,NS,NS
+ORC,read,S,S,S,S,S,S,S,S,PS,S,CO,NA,NS,NA,PS,NS,PS,NS
 Parquet,read,S,S,S,S,S,S,S,S,PS,S,CO,NA,NS,NA,PS,PS,PS,NS
diff --git a/tools/src/test/resources/QualificationExpectations/complex_dec_expectation.csv b/tools/src/test/resources/QualificationExpectations/complex_dec_expectation.csv
index 4a59efccf04..dc35a000774 100644
--- a/tools/src/test/resources/QualificationExpectations/complex_dec_expectation.csv
+++ b/tools/src/test/resources/QualificationExpectations/complex_dec_expectation.csv
@@ -1,2 +1,2 @@
 App Name,App ID,Score,Potential Problems,SQL DF Duration,SQL Dataframe Task Duration,App Duration,Executor CPU Time Percent,App Duration Estimated,SQL Duration with Potential Problems,SQL Ids with Failures,Read Score Percent,Read File Format Score,Unsupported Read File Formats and Types
-Spark shell,local-1626104300434,1322.1,DECIMAL,2429,1469,131104,88.35,false,160,"",20,50.0,Parquet[decimal];ORC[map:struct:decimal]
+Spark shell,local-1626104300434,1322.1,DECIMAL,2429,1469,131104,88.35,false,160,"",20,50.0,Parquet[decimal];ORC[map:decimal]