Skip to content

Commit

Permalink
Fix ORC read of Decimal128 inside MapType (#4192)
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx authored Nov 23, 2021
1 parent 13af8fb commit c455b8c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
10 changes: 7 additions & 3 deletions integration_tests/src/main/python/orc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect
from datetime import date, datetime, timezone
from data_gen import *
from marks import *
from pyspark.sql.types import *
from spark_session import with_cpu_session, with_spark_session
from spark_session import with_cpu_session
from parquet_test import _nested_pruning_schemas

def read_orc_df(data_path):
Expand Down Expand Up @@ -83,11 +82,16 @@ def test_basic_read(std_input_path, name, read_func, v1_enabled_list, orc_impl,
orc_basic_map_gens = [simple_string_to_string_map_gen] + [MapGen(f(nullable=False), f()) for f in [
BooleanGen, ByteGen, ShortGen, IntegerGen, LongGen, FloatGen, DoubleGen,
lambda nullable=True: TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc), nullable=nullable),
lambda nullable=True: DateGen(start=date(1590, 1, 1), nullable=nullable)]]
lambda nullable=True: DateGen(start=date(1590, 1, 1), nullable=nullable),
lambda nullable=True: DecimalGen(precision=15, scale=1, nullable=nullable),
lambda nullable=True: DecimalGen(precision=36, scale=5, nullable=nullable)]]

# A sample of the map gens rather than all of them, because of nesting
orc_map_gens_sample = orc_basic_map_gens + [
MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(string_gen), max_length=10),
MapGen(StringGen(pattern='key_[0-9]', nullable=False), ArrayGen(decimal_gen_36_5), max_length=10),
MapGen(StringGen(pattern='key_[0-9]', nullable=False),
ArrayGen(StructGen([["c0", decimal_gen_18_3], ["c1", decimal_gen_20_2]])), max_length=10),
MapGen(RepeatSeqGen(IntegerGen(nullable=False), 10), long_gen, max_length=10),
MapGen(StringGen(pattern='key_[0-9]', nullable=False), simple_string_to_string_map_gen),
MapGen(StructGen([['child0', byte_gen], ['child1', long_gen]], nullable=False),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.OrcFilters
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, MapType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SerializableConfiguration

Expand Down Expand Up @@ -416,6 +416,9 @@ trait OrcCommonFunctions extends OrcCodecWritingHelper {
dt.fields.foreach(f => findImpl(prefix + fieldName + ".", f.name, f.dataType))
case dt: ArrayType =>
findImpl(prefix + fieldName + ".", "1", dt.elementType)
case MapType(kt: DataType, vt: DataType, _) =>
findImpl(prefix + fieldName + ".", "0", kt)
findImpl(prefix + fieldName + ".", "1", vt)
case _ =>
}

Expand Down

0 comments on commit c455b8c

Please sign in to comment.