From 339f6b7d25a3712a24f3774e9d1a7ce944498092 Mon Sep 17 00:00:00 2001 From: MithunR Date: Fri, 8 Apr 2022 12:20:48 -0700 Subject: [PATCH] Fix CPU fallback for Map lookup. Fixes #5180. Map lookup is currently supported only in cases where the keys are scalar values. In case the keys are specified as a vector (e.g. expressions), the plugin should fall back to CPU. #4944 introduced a bug in how literal signatures are specified for multiple data types. This breaks CPU fallback. This commit fixes the specification of literals-only `TypeSig`. --- .../scala/com/nvidia/spark/rapids/TypeChecks.scala | 13 ++++++++++++- .../spark/sql/rapids/complexTypeExtractors.scala | 4 ++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 82c9c90eacd8..fa8749faf4e3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -181,6 +181,17 @@ final class TypeSig private( new TypeSig(it, maxAllowedDecimalPrecision, childTypes, lt, notes) } + /** + * Add a literal restriction to the signature + * @param dataTypes the types that have to be literal. Will be added if they do not already exist. + * @return the new signature. + */ + def withLit(dataTypes: TypeEnum.ValueSet): TypeSig = { + val it = initialTypes ++ dataTypes + val lt = litOnlyTypes ++ dataTypes + new TypeSig(it, maxAllowedDecimalPrecision, childTypes, lt, notes) + } + /** * All currently supported types can only be literal values. * @return the new signature. @@ -531,7 +542,7 @@ object TypeSig { * Create a TypeSig that only supports literals of certain given types. */ def lit(dataTypes: TypeEnum.ValueSet): TypeSig = - new TypeSig(dataTypes) + TypeSig.none.withLit(dataTypes) /** * Create a TypeSig that supports only literals of common primitive CUDF types. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 591d0d393fcf..eeb453762e48 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -254,10 +254,10 @@ case class GpuGetMapValue(child: Expression, key: Expression, failOnError: Boole } override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = - throw new IllegalStateException("This is not supported yet") + throw new IllegalStateException("Map lookup keys must be scalar values") override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = - throw new IllegalStateException("This is not supported yet") + throw new IllegalStateException("Map lookup keys must be scalar values") override def left: Expression = child