diff --git a/docs/supported_ops.md b/docs/supported_ops.md index 035856bee79..3874b53a420 100644 --- a/docs/supported_ops.md +++ b/docs/supported_ops.md @@ -10442,7 +10442,7 @@ Accelerator support is described below. NS NS NS -NS +PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, UDT) NS @@ -19925,7 +19925,7 @@ as `a` don't show up in the table. They are controlled by the rules for NS NS NS -NS +PS* (missing nested BINARY, CALENDAR, ARRAY, MAP, UDT) NS diff --git a/integration_tests/src/main/python/repart_test.py b/integration_tests/src/main/python/repart_test.py index 1141b9e19a8..2a4723c3a39 100644 --- a/integration_tests/src/main/python/repart_test.py +++ b/integration_tests/src/main/python/repart_test.py @@ -92,6 +92,7 @@ def test_repartion_df(num_parts, length): ([('a', decimal_gen_64bit)], ['a']), ([('a', string_gen)], ['a']), ([('a', null_gen)], ['a']), + ([('a', StructGen([('c0', boolean_gen), ('c1', StructGen([('cc0', boolean_gen), ('cc1', string_gen)]))]))], ['a']), ([('a', long_gen), ('b', StructGen([('b1', long_gen)]))], ['a']), ([('a', long_gen), ('b', ArrayGen(long_gen, max_length=2))], ['a']), ([('a', byte_gen)], [f.col('a') - 5]), diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 2b5e1af87fd..f797fe82aad 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -2343,7 +2343,7 @@ object GpuOverrides { "Murmur3 hash operator", ExprChecks.projectNotLambda(TypeSig.INT, TypeSig.INT, repeatingParamCheck = Some(RepeatingParamCheck("input", - TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.STRUCT).nested(), TypeSig.all))), (a, conf, p, r) => new ExprMeta[Murmur3Hash](a, conf, p, r) { override val childExprs: Seq[BaseExprMeta[_]] = a.children @@ -2518,7 +2518,7 @@ object GpuOverrides { "Hash based partitioning", // This needs to match what murmur3 supports. PartChecks(RepeatingParamCheck("hash_key", - TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL, + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL + TypeSig.STRUCT).nested(), TypeSig.all)), (hp, conf, p, r) => new PartMeta[HashPartitioning](hp, conf, p, r) { override val childExprs: Seq[BaseExprMeta[_]] =