Skip to content

Commit

Permalink
Add casting of count to Int64 in array_repeat function to ensur…
Browse files Browse the repository at this point in the history
…e consistent integer type handling (#14236)

* Add casting of `count` to `Int64` in `array_repeat` function to ensure consistent integer type handling

* Updated the `array_repeat` function signature to automatically cast a `count` argument of any signed or unsigned integer type

* fixed clippy issue
  • Loading branch information
jatin510 authored Jan 24, 2025
1 parent 54e95b0 commit 633eef6
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 11 deletions.
51 changes: 40 additions & 11 deletions datafusion/functions-nested/src/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
use crate::utils::make_scalar_function;
use arrow::array::{Capacities, MutableArrayData};
use arrow::compute;
use arrow::compute::cast;
use arrow_array::{
new_null_array, Array, ArrayRef, GenericListArray, Int64Array, ListArray,
OffsetSizeTrait,
new_null_array, Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait,
UInt64Array,
};
use arrow_buffer::OffsetBuffer;
use arrow_schema::DataType::{LargeList, List};
use arrow_schema::{DataType, Field};
use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
use datafusion_common::cast::{as_large_list_array, as_list_array, as_uint64_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -86,7 +87,7 @@ impl Default for ArrayRepeat {
impl ArrayRepeat {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![String::from("list_repeat")],
}
}
Expand Down Expand Up @@ -124,19 +125,47 @@ impl ScalarUDFImpl for ArrayRepeat {
&self.aliases
}

/// Computes the argument types `array_repeat` should be called with.
///
/// The first argument (the element to repeat) keeps its type unchanged.
/// The second argument (the repeat count) is widened to `Int64` when it is
/// a signed integer and to `UInt64` when it is an unsigned integer.
///
/// # Errors
///
/// Returns an execution error when the number of arguments is not exactly
/// two, or when the count argument is not an integer type.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
    // Destructure the slice so the arity check and the bindings happen together.
    let (element_type, count_type) = match arg_types {
        [element, count] => (element, count),
        _ => return exec_err!("array_repeat expects two arguments"),
    };

    // Widen the count to the 64-bit integer type of matching signedness.
    let is_signed = matches!(
        count_type,
        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
    );
    let is_unsigned = matches!(
        count_type,
        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
    );

    let coerced_count = if is_signed {
        DataType::Int64
    } else if is_unsigned {
        DataType::UInt64
    } else {
        return exec_err!("count must be an integer type");
    };

    Ok(vec![element_type.clone(), coerced_count])
}

/// Returns this function's documentation, if any, by delegating to `self.doc()`.
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

/// `array_repeat` SQL function: repeats the first argument the number of times given by the second (`count`) argument.
pub fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 2 {
return exec_err!("array_repeat expects two arguments");
}

let element = &args[0];
let count_array = as_int64_array(&args[1])?;
let count_array = &args[1];

let count_array = match count_array.data_type() {
DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
DataType::UInt64 => count_array,
_ => return exec_err!("count must be an integer type"),
};

let count_array = as_uint64_array(count_array)?;

match element.data_type() {
List(_) => {
Expand Down Expand Up @@ -165,7 +194,7 @@ pub fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
/// ```
fn general_repeat<O: OffsetSizeTrait>(
array: &ArrayRef,
count_array: &Int64Array,
count_array: &UInt64Array,
) -> Result<ArrayRef> {
let data_type = array.data_type();
let mut new_values = vec![];
Expand Down Expand Up @@ -219,7 +248,7 @@ fn general_repeat<O: OffsetSizeTrait>(
/// ```
fn general_list_repeat<O: OffsetSizeTrait>(
list_array: &GenericListArray<O>,
count_array: &Int64Array,
count_array: &UInt64Array,
) -> Result<ArrayRef> {
let data_type = list_array.data_type();
let value_type = list_array.value_type();
Expand Down
24 changes: 24 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2760,6 +2760,30 @@ select
----
[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[NULL, NULL], [NULL, NULL], [NULL, NULL]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]

# array_repeat scalar function with count of different integer types
query ????????
Select
array_repeat(1, arrow_cast(2,'Int8')),
array_repeat(2, arrow_cast(2,'Int16')),
array_repeat(3, arrow_cast(2,'Int32')),
array_repeat(4, arrow_cast(2,'Int64')),
array_repeat(1, arrow_cast(2,'UInt8')),
array_repeat(2, arrow_cast(2,'UInt16')),
array_repeat(3, arrow_cast(2,'UInt32')),
array_repeat(4, arrow_cast(2,'UInt64'));
----
[1, 1] [2, 2] [3, 3] [4, 4] [1, 1] [2, 2] [3, 3] [4, 4]

# array_repeat scalar function with count of negative integer types
query ????
Select
array_repeat(1, arrow_cast(-2,'Int8')),
array_repeat(2, arrow_cast(-2,'Int16')),
array_repeat(3, arrow_cast(-2,'Int32')),
array_repeat(4, arrow_cast(-2,'Int64'));
----
[] [] [] []

# array_repeat with columns #1

statement ok
Expand Down

0 comments on commit 633eef6

Please sign in to comment.