diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py
index 0fc4bde51de..c939138e856 100644
--- a/integration_tests/src/main/python/parquet_test.py
+++ b/integration_tests/src/main/python/parquet_test.py
@@ -19,6 +19,8 @@
     assert_gpu_fallback_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_and_cpu_error, assert_py4j_exception
 from data_gen import *
 from marks import *
+import pyarrow as pa
+import pyarrow.parquet as pa_pq
 from pyspark.sql.types import *
 from pyspark.sql.functions import *
 from spark_session import with_cpu_session, with_gpu_session, is_before_spark_320, is_before_spark_330, is_spark_321cdh
@@ -125,6 +127,34 @@ def test_parquet_fallback(spark_tmp_path, read_func, disable_conf):
             conf={disable_conf: 'false', "spark.sql.sources.useV1SourceList": "parquet"})
 
 
+@pytest.mark.parametrize('read_func', [read_parquet_df, read_parquet_sql])
+@pytest.mark.parametrize('reader_confs', reader_opt_confs)
+@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
+def test_parquet_read_round_trip_binary_as_string(spark_tmp_path, read_func, reader_confs, v1_enabled_list):
+    gen_list = [("a", string_gen), ("b", int_gen), ("c", string_gen)]
+    data_path = spark_tmp_path + '/binary_as_string.parquet'
+    # cast to binary to read back as a string
+    # NOTE: using pyarrow to write the parquet file because spark doesn't
+    # produce a parquet file where the binary values are read back as strings,
+    # ultimately this simulates reading a parquet file produced outside of spark
+    def create_parquet_file(spark):
+        df = gen_df(spark, gen_list).select(
+            f.col('a').cast("BINARY").alias('a'),
+            f.col('b'), f.col('c'))
+        pa_pq.write_table(pa.Table.from_pandas(df.toPandas()), data_path)
+
+    with_cpu_session(create_parquet_file, conf=rebase_write_corrected_conf)
+    all_confs = copy_and_update(reader_confs, {
+        'spark.sql.sources.useV1SourceList': v1_enabled_list,
+        'spark.sql.parquet.binaryAsString': 'true',
+        # set the int96 rebase mode values because it's LEGACY in Databricks, which would preclude this op from running on GPU
+        'spark.sql.legacy.parquet.int96RebaseModeInRead': 'CORRECTED',
+        'spark.sql.legacy.parquet.datetimeRebaseModeInRead': 'CORRECTED'})
+    # once https://github.com/NVIDIA/spark-rapids/issues/1126 is in we can remove the spark.sql.legacy.parquet.datetimeRebaseModeInRead config, which is a workaround
+    # for nested timestamp/date support
+    assert_gpu_and_cpu_are_equal_collect(read_func(data_path),
+            conf=all_confs)
+
 parquet_compress_options = ['none', 'uncompressed', 'snappy', 'gzip']
 # The following need extra jars 'lzo', 'lz4', 'brotli', 'zstd'
 # https://github.com/NVIDIA/spark-rapids/issues/143
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index 4efefb915c6..a827e474370 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -193,16 +193,6 @@
 
     FileFormatChecks.tag(meta, readSchema, ParquetFormatType, ReadFileOp)
 
-    val schemaHasStrings = readSchema.exists { field =>
-      TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[StringType])
-    }
-
-    if (sqlConf.get(SQLConf.PARQUET_BINARY_AS_STRING.key,
-      SQLConf.PARQUET_BINARY_AS_STRING.defaultValueString).toBoolean && schemaHasStrings) {
-      meta.willNotWorkOnGpu(s"GpuParquetScan does not support" +
-        s" ${SQLConf.PARQUET_BINARY_AS_STRING.key}")
-    }
-
     val schemaHasTimestamps = readSchema.exists { field =>
       TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
     }
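Note on the behavior the new test exercises: spark.sql.parquet.binaryAsString only changes how Spark interprets plain BINARY columns in files that do not carry Spark's own schema metadata, which is why the test writes the file with pyarrow rather than Spark. The standalone sketch below is not part of the patch; the file path and the local SparkSession setup are illustrative assumptions, but it shows the same effect with a hand-built Arrow table.

import pyarrow as pa
import pyarrow.parquet as pa_pq
from pyspark.sql import SparkSession

# Column 'a' is written as plain BINARY with no UTF8/string annotation.
table = pa.table({'a': pa.array([b'foo', b'bar'], type=pa.binary()),
                  'b': pa.array([1, 2], type=pa.int32())})
pa_pq.write_table(table, '/tmp/binary_as_string.parquet')

spark = (SparkSession.builder
         .config('spark.sql.parquet.binaryAsString', 'true')
         .getOrCreate())
# With binaryAsString=true Spark maps the un-annotated BINARY column to StringType,
# which is the read path GpuParquetScan no longer falls back to the CPU for.
spark.read.parquet('/tmp/binary_as_string.parquet').printSchema()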