diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 34a321910c30..0bc2d39481e3 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -314,119 +314,32 @@ pub fn sort_to_indices( } }, DataType::Dictionary(_, _) => { + let value_null_first = if options.descending { + // When sorting dictionary in descending order, we take inverse of null ordering + // when sorting the values. Because if `nulls_first` is true, null must be in front + // of non-null value. As we take the sorted order of value array to sort dictionary + // keys, these null values will be treated as smallest ones and be sorted to the end + // of sorted result. So we set `nulls_first` to false when sorting dictionary value + // array to make them as largest ones, then null values will be put at the beginning + // of sorted dictionary result. + !options.nulls_first + } else { + options.nulls_first + }; + let value_options = Some(SortOptions { + descending: false, + nulls_first: value_null_first, + }); downcast_dictionary_array!( values => match values.values().data_type() { - DataType::Int8 => { - let dict_values = values.values(); - let value_null_first = if options.descending { - // When sorting dictionary in descending order, we take inverse of of null ordering - // when sorting the values. Because if `nulls_first` is true, null must be in front - // of non-null value. As we take the sorted order of value array to sort dictionary - // keys, these null values will be treated as smallest ones and be sorted to the end - // of sorted result. So we set `nulls_first` to false when sorting dictionary value - // array to make them as largest ones, then null values will be put at the beginning - // of sorted dictionary result. 
- !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = prepare_indices_map(&sorted_value_indices); - sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, limit, cmp) - }, - DataType::Int16 => { + dt if DataType::is_primitive(dt) => { let dict_values = values.values(); - let value_null_first = if options.descending { - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = prepare_indices_map(&sorted_value_indices); - sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, limit, cmp) - }, - DataType::Int32 => { - let dict_values = values.values(); - let value_null_first = if options.descending { - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = prepare_indices_map(&sorted_value_indices); - sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, limit, cmp) - }, - DataType::Int64 => { - let dict_values = values.values(); - let value_null_first = if options.descending { - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = prepare_indices_map(&sorted_value_indices); - sort_primitive_dictionary::<_, _>(values, &value_indices_map,v, n, options, limit, cmp) - }, - DataType::UInt8 => { - 
let dict_values = values.values(); - let value_null_first = if options.descending { - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = prepare_indices_map(&sorted_value_indices); - sort_primitive_dictionary::<_, _>(values, &value_indices_map,v, n, options, limit, cmp) - }, - DataType::UInt16 => { - let dict_values = values.values(); - let value_null_first = if options.descending { - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = prepare_indices_map(&sorted_value_indices); - sort_primitive_dictionary::<_, _>(values, &value_indices_map,v, n, options, limit, cmp) - }, - DataType::UInt32 => { - let dict_values = values.values(); - let value_null_first = if options.descending { - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); - let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; - let value_indices_map = prepare_indices_map(&sorted_value_indices); - sort_primitive_dictionary::<_, _>(values, &value_indices_map,v, n, options, limit, cmp) - }, - DataType::UInt64 => { - let dict_values = values.values(); - let value_null_first = if options.descending { - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; let value_indices_map = prepare_indices_map(&sorted_value_indices); sort_primitive_dictionary::<_, _>(values, &value_indices_map, v, n, options, 
limit, cmp) }, DataType::Utf8 => { let dict_values = values.values(); - let value_null_first = if options.descending { - !options.nulls_first - } else { - options.nulls_first - }; - let value_options = Some(SortOptions { descending: false, nulls_first: value_null_first }); let sorted_value_indices = sort_to_indices(dict_values, value_options, None)?; let value_indices_map = prepare_indices_map(&sorted_value_indices); sort_string_dictionary::<_>(values, &value_indices_map, v, n, &options, limit) @@ -3552,4 +3465,142 @@ mod tests { vec![None, None, None, Some(5), Some(5), Some(3), Some(1)], ); } + + #[test] + fn test_sort_f32_dicts() { + let keys = + Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); + let values = Float32Array::from(vec![1.2, 3.0, 5.1]); + test_sort_primitive_dict_arrays::<Int8Type, Float32Type>( + keys, + values, + None, + None, + vec![None, None, Some(1.2), Some(3.0), Some(5.1), Some(5.1)], + ); + + let keys = + Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); + let values = Float32Array::from(vec![1.2, 3.0, 5.1]); + test_sort_primitive_dict_arrays::<Int8Type, Float32Type>( + keys, + values, + Some(SortOptions { + descending: true, + nulls_first: false, + }), + None, + vec![Some(5.1), Some(5.1), Some(3.0), Some(1.2), None, None], + ); + + let keys = + Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); + let values = Float32Array::from(vec![1.2, 3.0, 5.1]); + test_sort_primitive_dict_arrays::<Int8Type, Float32Type>( + keys, + values, + Some(SortOptions { + descending: false, + nulls_first: false, + }), + None, + vec![Some(1.2), Some(3.0), Some(5.1), Some(5.1), None, None], + ); + + let keys = + Int8Array::from(vec![Some(1_i8), None, Some(2), None, Some(2), Some(0)]); + let values = Float32Array::from(vec![1.2, 3.0, 5.1]); + test_sort_primitive_dict_arrays::<Int8Type, Float32Type>( + keys, + values, + Some(SortOptions { + descending: true, + nulls_first: true, + }), + Some(3), + vec![None, None, Some(5.1)], + ); + + // Values have `None`. 
+ let keys = Int8Array::from(vec![ + Some(1_i8), + None, + Some(3), + None, + Some(2), + Some(3), + Some(0), + ]); + let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, Some(5.1)]); + test_sort_primitive_dict_arrays::<Int8Type, Float32Type>( + keys, + values, + None, + None, + vec![None, None, None, Some(1.2), Some(3.0), Some(5.1), Some(5.1)], + ); + + let keys = Int8Array::from(vec![ + Some(1_i8), + None, + Some(3), + None, + Some(2), + Some(3), + Some(0), + ]); + let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, Some(5.1)]); + test_sort_primitive_dict_arrays::<Int8Type, Float32Type>( + keys, + values, + Some(SortOptions { + descending: false, + nulls_first: false, + }), + None, + vec![Some(1.2), Some(3.0), Some(5.1), Some(5.1), None, None, None], + ); + + let keys = Int8Array::from(vec![ + Some(1_i8), + None, + Some(3), + None, + Some(2), + Some(3), + Some(0), + ]); + let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, Some(5.1)]); + test_sort_primitive_dict_arrays::<Int8Type, Float32Type>( + keys, + values, + Some(SortOptions { + descending: true, + nulls_first: false, + }), + None, + vec![Some(5.1), Some(5.1), Some(3.0), Some(1.2), None, None, None], + ); + + let keys = Int8Array::from(vec![ + Some(1_i8), + None, + Some(3), + None, + Some(2), + Some(3), + Some(0), + ]); + let values = Float32Array::from(vec![Some(1.2), Some(3.0), None, Some(5.1)]); + test_sort_primitive_dict_arrays::<Int8Type, Float32Type>( + keys, + values, + Some(SortOptions { + descending: true, + nulls_first: true, + }), + None, + vec![None, None, None, Some(5.1), Some(5.1), Some(3.0), Some(1.2)], + ); + } } diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs index 2ca71ef77725..d3189b8b18cc 100644 --- a/arrow/src/datatypes/datatype.rs +++ b/arrow/src/datatypes/datatype.rs @@ -1070,6 +1070,30 @@ impl DataType { ) } + /// Returns true if the type is primitive: (numeric, temporal). 
+ pub fn is_primitive(t: &DataType) -> bool { + use DataType::*; + matches!( + t, + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Float64 + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + | Interval(_) + | Duration(_) + ) + } + /// Returns true if this type is temporal: (Date*, Time*, Duration, or Interval). pub fn is_temporal(t: &DataType) -> bool { use DataType::*;