diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 82c9c90eacd8..fa8749faf4e3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -181,6 +181,17 @@ final class TypeSig private( new TypeSig(it, maxAllowedDecimalPrecision, childTypes, lt, notes) } + /** + * Add a literal restriction to the signature + * @param dataTypes the types that have to be literal. Will be added if they do not already exist. + * @return the new signature. + */ + def withLit(dataTypes: TypeEnum.ValueSet): TypeSig = { + val it = initialTypes ++ dataTypes + val lt = litOnlyTypes ++ dataTypes + new TypeSig(it, maxAllowedDecimalPrecision, childTypes, lt, notes) + } + /** * All currently supported types can only be literal values. * @return the new signature. @@ -531,7 +542,7 @@ object TypeSig { * Create a TypeSig that only supports literals of certain given types. */ def lit(dataTypes: TypeEnum.ValueSet): TypeSig = - new TypeSig(dataTypes) + TypeSig.none.withLit(dataTypes) /** * Create a TypeSig that supports only literals of common primitive CUDF types. diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 591d0d393fcf..eeb453762e48 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -254,10 +254,10 @@ case class GpuGetMapValue(child: Expression, key: Expression, failOnError: Boole } override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = - throw new IllegalStateException("This is not supported yet") + throw new IllegalStateException("Map lookup keys must be scalar values") override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = - throw new IllegalStateException("This is not supported yet") + throw new IllegalStateException("Map lookup keys must be scalar values") override def left: Expression = child