Add in more generalized support for casting nested types (#3162)
Signed-off-by: Robert (Bobby) Evans <[email protected]>
revans2 authored Aug 6, 2021
1 parent 96f8c0c commit aae8875
Showing 6 changed files with 365 additions and 322 deletions.
12 changes: 6 additions & 6 deletions docs/supported_ops.md
@@ -20587,7 +20587,7 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
-<td><em>PS<br/>missing nested BOOLEAN, BYTE, SHORT, LONG, DATE, TIMESTAMP, STRING, DECIMAL, NULL, BINARY, CALENDAR, MAP, STRUCT, UDT</em></td>
+<td><em>PS<br/>The array's child type must also support being cast to the desired child type;<br/>max nested DECIMAL precision of 18;<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
@@ -20609,7 +20609,7 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
-<td><b>NS</b></td>
+<td><em>PS<br/>the map's key and value must also support being cast to the desired child types;<br/>max nested DECIMAL precision of 18;<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
@@ -20631,7 +20631,7 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
-<td><b>NS</b></td>
+<td><em>PS<br/>the struct's children must also support being cast to the desired child type(s);<br/>max nested DECIMAL precision of 18;<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested CALENDAR, UDT</em></td>
<td> </td>
</tr>
<tr>
@@ -20991,7 +20991,7 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
-<td><em>PS<br/>missing nested BOOLEAN, BYTE, SHORT, LONG, DATE, TIMESTAMP, STRING, DECIMAL, NULL, BINARY, CALENDAR, MAP, STRUCT, UDT</em></td>
+<td><em>PS<br/>The array's child type must also support being cast to the desired child type;<br/>max nested DECIMAL precision of 18;<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
<td> </td>
@@ -21013,7 +21013,7 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
-<td><b>NS</b></td>
+<td><em>PS<br/>the map's key and value must also support being cast to the desired child types;<br/>max nested DECIMAL precision of 18;<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested CALENDAR, UDT</em></td>
<td> </td>
<td> </td>
</tr>
@@ -21035,7 +21035,7 @@ and the accelerator produces the same result.
<td> </td>
<td> </td>
<td> </td>
-<td><b>NS</b></td>
+<td><em>PS<br/>the struct's children must also support being cast to the desired child type(s);<br/>max nested DECIMAL precision of 18;<br/>UTC is only supported TZ for nested TIMESTAMP;<br/>missing nested CALENDAR, UDT</em></td>
<td> </td>
</tr>
<tr>
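The cells above broaden Cast's array entries and flip its map and struct entries from NS (not supported) to PS (partially supported), provided every nested child type is itself castable. A minimal PySpark sketch of the kind of nested cast these entries now cover, assuming a SparkSession named spark running with the RAPIDS plugin enabled (the data and column names are illustrative):

# Each child cast (tinyint -> int, tinyint -> string) is itself supported,
# so the whole array/map cast can stay on the GPU.
from pyspark.sql.functions import col
from pyspark.sql.types import ArrayType, IntegerType, MapType, StringType

df = spark.createDataFrame(
    [([1, 2, 3], {"a": 1}), ([4, 5], {"b": 2})],
    "arr array<tinyint>, m map<string, tinyint>")

df.select(
    col("arr").cast(ArrayType(IntegerType())),
    col("m").cast(MapType(StringType(), StringType()))).show()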
50 changes: 0 additions & 50 deletions integration_tests/src/main/python/array_test.py
@@ -165,53 +165,3 @@ def test_array_element_at_all_null_ansi_not_fail(data_gen):
conf={'spark.sql.ansi.enabled':True,
'spark.sql.legacy.allowNegativeScaleOfDecimal': True})


-@pytest.mark.parametrize('child_gen', [
-    float_gen,
-    double_gen,
-    int_gen
-], ids=idfn)
-@pytest.mark.parametrize('child_to_type', [
-    FloatType(),
-    DoubleType(),
-    IntegerType(),
-], ids=idfn)
-@pytest.mark.parametrize('depth', [1, 2, 3], ids=idfn)
-def test_array_cast_recursive(child_gen, child_to_type, depth):
-    def cast_func(spark):
-        depth_rng = range(0, depth)
-        nested_gen = reduce(lambda dg, i: ArrayGen(dg, max_length=int(max(1, 16 / (2 ** i)))),
-                            depth_rng, child_gen)
-        nested_type = reduce(lambda t, _: ArrayType(t), depth_rng, child_to_type)
-        df = two_col_df(spark, int_gen, nested_gen)
-        res = df.select(df.b.cast(nested_type))
-        return res
-    assert_gpu_and_cpu_are_equal_collect(cast_func)
-
-
-@allow_non_gpu('ProjectExec', 'Alias', 'Cast')
-def test_array_cast_fallback():
-    def cast_float_to_double(spark):
-        df = two_col_df(spark, int_gen, ArrayGen(int_gen))
-        res = df.select(df.b.cast(ArrayType(StringType())))
-        return res
-    assert_gpu_and_cpu_are_equal_collect(cast_float_to_double)
-
-
-@pytest.mark.parametrize('child_gen', [
-    byte_gen,
-    string_gen,
-    decimal_gen_default,
-], ids=idfn)
-@pytest.mark.parametrize('child_to_type', [
-    FloatType(),
-    DoubleType(),
-    IntegerType(),
-], ids=idfn)
-@allow_non_gpu('ProjectExec', 'Alias', 'Cast')
-def test_array_cast_bad_from_good_to_fallback(child_gen, child_to_type):
-    def cast_array(spark):
-        df = two_col_df(spark, int_gen, ArrayGen(child_gen))
-        res = df.select(df.b.cast(ArrayType(child_to_type)))
-        return res
-    assert_gpu_and_cpu_are_equal_collect(cast_array)
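The removed recursive test above built arbitrarily deep generators and target types by folding with functools.reduce; the replacement tests in cast_test.py (next file) enumerate their cases explicitly instead. The same fold works on plain Spark SQL types, as this small self-contained sketch shows:

# Fold ArrayType over a depth range to build array<array<array<int>>>,
# the same reduce pattern the removed test used for its target types.
from functools import reduce
from pyspark.sql.types import ArrayType, IntegerType

depth = 3
nested_type = reduce(lambda t, _: ArrayType(t), range(depth), IntegerType())
print(nested_type.simpleString())  # array<array<array<int>>>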
23 changes: 23 additions & 0 deletions integration_tests/src/main/python/cast_test.py
@@ -30,3 +30,26 @@ def test_cast_empty_string_to_int():
'CAST(a as INTEGER)',
'CAST(a as LONG)'))

+# These tests are not intended to be exhaustive. The Scala test CastOpSuite should cover
+# just about everything for non-nested values. This is intended to check that the
+# recursive code in nested type checks, like arrays, is working properly. So we
+# pick child types that are simple to cast: upcasting integer values and casting them to strings.
+@pytest.mark.parametrize('data_gen,to_type', [
+    (ArrayGen(byte_gen), ArrayType(IntegerType())),
+    (ArrayGen(StringGen('[0-9]{1,5}')), ArrayType(IntegerType())),
+    (ArrayGen(byte_gen), ArrayType(StringType())),
+    (ArrayGen(byte_gen), ArrayType(DecimalType(6, 2))),
+    (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(IntegerType()))),
+    (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(StringType()))),
+    (ArrayGen(ArrayGen(byte_gen)), ArrayType(ArrayType(DecimalType(6, 2)))),
+    (StructGen([('a', byte_gen)]), StructType([StructField('a', IntegerType())])),
+    (StructGen([('a', byte_gen), ('c', short_gen)]), StructType([StructField('b', IntegerType()), StructField('c', ShortType())])),
+    (StructGen([('a', ArrayGen(byte_gen)), ('c', short_gen)]), StructType([StructField('a', ArrayType(IntegerType())), StructField('c', LongType())])),
+    (ArrayGen(StructGen([('a', byte_gen), ('b', byte_gen)])), ArrayType(StringType())),
+    (MapGen(ByteGen(nullable=False), byte_gen), MapType(StringType(), StringType())),
+    (MapGen(ShortGen(nullable=False), ArrayGen(byte_gen)), MapType(IntegerType(), ArrayType(ShortType()))),
+    (MapGen(ShortGen(nullable=False), ArrayGen(StructGen([('a', byte_gen)]))), MapType(IntegerType(), ArrayType(StructType([StructField('b', ShortType())]))))
+], ids=idfn)
+def test_cast_nested(data_gen, to_type):
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen).select(f.col('a').cast(to_type)))
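For readers without the integration-test harness, here is a standalone PySpark sketch of one parametrized case above: the struct case whose target type renames field a to b. Struct casts match fields by position, so a single cast can rename and retype fields at once. It assumes only an existing SparkSession named spark:

# Struct casts are positional: fields a/c (tinyint/smallint) become
# b/c (int/smallint) in one cast.
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType, ShortType, StructField, StructType

df = spark.createDataFrame([((1, 2),), ((3, 4),)],
                           "s struct<a: tinyint, c: smallint>")
to_type = StructType([StructField("b", IntegerType()),
                      StructField("c", ShortType())])
# Resulting column type is struct<b: int, c: smallint>.
df.select(col("s").cast(to_type).alias("s")).printSchema()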
@@ -147,13 +147,22 @@ class Spark311Shims extends Spark301Shims {

// calendarChecks are the same

-  override val arrayChecks: TypeSig = none
+  override val arrayChecks: TypeSig =
+    ARRAY.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) +
+        psNote(TypeEnum.ARRAY, "The array's child type must also support being cast to " +
+            "the desired child type")
override val sparkArraySig: TypeSig = ARRAY.nested(all)

-  override val mapChecks: TypeSig = none
+  override val mapChecks: TypeSig =
+    MAP.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT + MAP) +
+        psNote(TypeEnum.MAP, "the map's key and value must also support being cast to the " +
+            "desired child types")
override val sparkMapSig: TypeSig = MAP.nested(all)

-  override val structChecks: TypeSig = none
+  override val structChecks: TypeSig =
+    STRUCT.nested(commonCudfTypes + DECIMAL_64 + NULL + ARRAY + BINARY + STRUCT) +
+        psNote(TypeEnum.STRUCT, "the struct's children must also support being cast to the " +
+            "desired child type(s)")
override val sparkStructSig: TypeSig = STRUCT.nested(all)

override val udtChecks: TypeSig = none
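These TypeSig definitions are what the plugin compares against Spark's full signatures (sparkArraySig and friends) when deciding whether a given Cast stays on the GPU; anything outside the declared set makes the expression fall back to the CPU. A hedged sketch for observing that decision from PySpark, using the plugin's spark.rapids.sql.explain setting; whether this particular cast actually falls back depends on the plugin version and other cast-related configs:

# Ask the plugin to log why expressions are kept off the GPU.
# Assumes spark runs with the RAPIDS plugin loaded.
from pyspark.sql.types import ArrayType, StringType

spark.conf.set("spark.rapids.sql.explain", "NOT_ON_GPU")

df = spark.createDataFrame([([1.5, 2.5],)], "arr array<float>")
# If the child cast (float -> string here) is not on the GPU, the nested
# Cast and its enclosing ProjectExec fall back, and the reason is logged.
df.select(df.arr.cast(ArrayType(StringType()))).collect()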