diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index 13588f13ca5..52ecee7d1d9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -359,25 +359,35 @@ case class GpuCast( withResource(input.getBase.nansToNulls()) { inputWithNansToNull => withResource(FloatUtils.infinityToNulls(inputWithNansToNull)) { inputWithoutNanAndInfinity => - withResource(inputWithoutNanAndInfinity.mul(microsPerSec)) { inputTimesMicrosCv => - GpuColumnVector.from(inputTimesMicrosCv.castTo(DType.TIMESTAMP_MICROSECONDS)) + withResource(inputWithoutNanAndInfinity.mul(microsPerSec, DType.INT64)) { + inputTimesMicrosCv => + GpuColumnVector.from(inputTimesMicrosCv.castTo(DType.TIMESTAMP_MICROSECONDS)) } } } } + case (BooleanType, TimestampType) => + // cudf requires casting to a long first. + withResource(input.getBase.castTo(DType.INT64)) { longs => + GpuColumnVector.from(longs.castTo(cudfType)) + } + case (BooleanType | ByteType | ShortType | IntegerType, TimestampType) => + // cudf requires casting to a long first + withResource(input.getBase.castTo(DType.INT64)) { longs => + withResource(longs.castTo(DType.TIMESTAMP_SECONDS)) { timestampSecs => + GpuColumnVector.from(timestampSecs.castTo(cudfType)) + } + } case (_: NumericType, TimestampType) => // Spark casting to timestamp assumes value is in seconds, but timestamps // are tracked in microseconds. - val timestampSecs = input.getBase.castTo(DType.TIMESTAMP_SECONDS) - try { + withResource(input.getBase.castTo(DType.TIMESTAMP_SECONDS)) { timestampSecs => GpuColumnVector.from(timestampSecs.castTo(cudfType)) - } finally { - timestampSecs.close(); } + case (FloatType, LongType) | (DoubleType, IntegerType | LongType) => // Float.NaN => Int is casted to a zero but float.NaN => Long returns a small negative // number Double.NaN => Int | Long, returns a small negative number so Nans have to be // converted to zero first - case (FloatType, LongType) | (DoubleType, IntegerType | LongType) => withResource(FloatUtils.nanToZero(input.getBase)) { inputWithNansToZero => GpuColumnVector.from(inputWithNansToZero.castTo(cudfType)) }