Skip to content

Commit

Permalink
[minor]: use arrow take_record_batch instead of get_record_batch_at_indices (apache#13084)
Browse files Browse the repository at this point in the history

* Initial commit

* Fix linter errors

* Minor changes

* Fix error
  • Loading branch information
akurmustafa authored Oct 24, 2024
1 parent 3f3a0cf commit 8adbc23
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 35 deletions.
24 changes: 4 additions & 20 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@ pub mod proxy;
pub mod string_utils;

use crate::error::{_internal_datafusion_err, _internal_err};
use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
use arrow::array::{ArrayRef, PrimitiveArray};
use crate::{DataFusionError, Result, ScalarValue};
use arrow::array::ArrayRef;
use arrow::buffer::OffsetBuffer;
use arrow::compute::{partition, take_arrays, SortColumn, SortOptions};
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow::compute::{partition, SortColumn, SortOptions};
use arrow::datatypes::{Field, SchemaRef};
use arrow_array::cast::AsArray;
use arrow_array::{
Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait,
RecordBatchOptions,
};
use arrow_schema::DataType;
use sqlparser::ast::Ident;
Expand Down Expand Up @@ -92,20 +90,6 @@ pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result<Vec<ScalarValu
.collect()
}

/// Builds a new [`RecordBatch`] holding the rows of `record_batch` selected by `indices`.
///
/// Every column of `record_batch` is gathered with `take_arrays` using the same
/// `indices`, and the result is assembled under the schema of the input batch.
///
/// # Errors
///
/// Returns an error if gathering the columns fails, or if the new batch cannot be
/// constructed (the latter is wrapped via `arrow_datafusion_err!`).
pub fn get_record_batch_at_indices(
    record_batch: &RecordBatch,
    indices: &PrimitiveArray<UInt32Type>,
) -> Result<RecordBatch> {
    // The row count is pinned to the number of requested indices rather than being
    // derived from the columns (presumably so a batch without columns still reports
    // the right number of rows — TODO confirm).
    let batch_options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
    let selected_columns = take_arrays(record_batch.columns(), indices, None)?;
    let rebuilt = RecordBatch::try_new_with_options(
        record_batch.schema(),
        selected_columns,
        &batch_options,
    );
    rebuilt.map_err(|e| arrow_datafusion_err!(e))
}

/// This function compares two tuples depending on the given sort options.
pub fn compare_rows(
x: &[ScalarValue],
Expand Down
17 changes: 5 additions & 12 deletions datafusion/core/tests/fuzz_cases/equivalence/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,11 @@ use std::any::Any;
use std::cmp::Ordering;
use std::sync::Arc;

use arrow::compute::{lexsort_to_indices, SortColumn};
use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn};
use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::{
ArrayRef, Float32Array, Float64Array, PrimitiveArray, RecordBatch, UInt32Array,
};
use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array};
use arrow_schema::{SchemaRef, SortOptions};
use datafusion_common::utils::{
compare_rows, get_record_batch_at_indices, get_row_at_idx,
};
use datafusion_common::utils::{compare_rows, get_row_at_idx};
use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand Down Expand Up @@ -465,7 +461,7 @@ pub fn generate_table_for_orderings(
// Sort batch according to first ordering expression
let sort_columns = get_sort_columns(&batch, &orderings[0])?;
let sort_indices = lexsort_to_indices(&sort_columns, None)?;
let mut batch = get_record_batch_at_indices(&batch, &sort_indices)?;
let mut batch = take_record_batch(&batch, &sort_indices)?;

// Prune out rows that are invalid according to the remaining orderings.
for ordering in orderings.iter().skip(1) {
Expand All @@ -490,10 +486,7 @@ pub fn generate_table_for_orderings(
}
}
// Only keep valid rows that satisfy the given ordering relation.
batch = get_record_batch_at_indices(
&batch,
&PrimitiveArray::from_iter_values(keep_indices),
)?;
batch = take_record_batch(&batch, &UInt32Array::from_iter_values(keep_indices))?;
}

Ok(batch)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use crate::{
SendableRecordBatchStream, Statistics, WindowExpr,
};
use ahash::RandomState;
use arrow::compute::take_record_batch;
use arrow::{
array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder},
compute::{concat, concat_batches, sort_to_indices, take_arrays},
Expand All @@ -49,8 +50,7 @@ use arrow::{
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::stats::Precision;
use datafusion_common::utils::{
evaluate_partition_ranges, get_at_indices, get_record_batch_at_indices,
get_row_at_idx,
evaluate_partition_ranges, get_at_indices, get_row_at_idx,
};
use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result};
use datafusion_execution::TaskContext;
Expand Down Expand Up @@ -558,7 +558,7 @@ impl PartitionSearcher for LinearSearch {
let mut new_indices = UInt32Builder::with_capacity(indices.len());
new_indices.append_slice(&indices);
let indices = new_indices.finish();
Ok((row, get_record_batch_at_indices(record_batch, &indices)?))
Ok((row, take_record_batch(record_batch, &indices)?))
})
.collect()
}
Expand Down

0 comments on commit 8adbc23

Please sign in to comment.