Skip to content

Commit

Permalink
Merge two RecordBatch (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Jan 22, 2023
1 parent d5302fd commit 1b057a6
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 45 deletions.
129 changes: 92 additions & 37 deletions rust/src/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use arrow_schema::{DataType, Field, Schema};

mod kernels;
mod record_batch;
use crate::error::Result;
use crate::error::{Error, Result};
pub use kernels::*;
pub use record_batch::*;

Expand Down Expand Up @@ -69,33 +69,34 @@ impl DataTypeExt for DataType {
}

fn is_struct(&self) -> bool {
matches!(self, DataType::Struct(_))
matches!(self, Self::Struct(_))
}

fn is_fixed_stride(&self) -> bool {
match self {
DataType::Boolean
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
| DataType::FixedSizeList(_, _)
| DataType::FixedSizeBinary(_) => true,
_ => false,
}
use DataType::*;
matches!(
self,
Boolean
| UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
| Decimal128(_, _)
| Decimal256(_, _)
| FixedSizeList(_, _)
| FixedSizeBinary(_)
)
}

fn is_dictionary(&self) -> bool {
matches!(self, DataType::Dictionary(_, _))
matches!(self, Self::Dictionary(_, _))
}
}

Expand Down Expand Up @@ -129,7 +130,7 @@ impl ListArrayExt for ListArray {
))))
.len(offsets.len() - 1)
.add_buffer(offsets.into_data().buffers()[0].clone())
.add_child_data(values.into_data().clone())
.add_child_data(values.into_data())
.build()?;

Ok(Self::from(data))
Expand All @@ -150,7 +151,7 @@ impl LargeListArrayExt for LargeListArray {
))))
.len(offsets.len() - 1)
.add_buffer(offsets.into_data().buffers()[0].clone())
.add_child_data(values.into_data().clone())
.add_child_data(values.into_data())
.build()?;

Ok(Self::from(data))
Expand Down Expand Up @@ -235,11 +236,6 @@ impl FixedSizeBinaryArrayExt for FixedSizeBinaryArray {

/// Extends Arrow's [RecordBatch].
pub trait RecordBatchExt {
/// Get a column by its name.
///
/// Returns None if the column does not exist.
fn column_with_name(&self, name: &str) -> Option<&ArrayRef>;

/// Append a new column to this [`RecordBatch`] and returns a new RecordBatch.
///
/// ```
Expand Down Expand Up @@ -270,16 +266,44 @@ pub trait RecordBatchExt {
/// )
/// ```
fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<RecordBatch>;

/// Merge with another [`RecordBatch`] and returns a new one.
///
/// ```
/// use std::sync::Arc;
/// use arrow_array::*;
/// use arrow_schema::{Schema, Field, DataType};
/// use lance::arrow::*;
///
/// let left_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
/// let int_arr = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
/// let left = RecordBatch::try_new(left_schema, vec![int_arr.clone()]).unwrap();
///
/// let right_schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
/// let str_arr = Arc::new(StringArray::from(vec!["a", "b", "c", "d"]));
/// let right = RecordBatch::try_new(right_schema, vec![str_arr.clone()]).unwrap();
///
/// let new_record_batch = left.merge(&right).unwrap();
///
/// assert_eq!(
/// new_record_batch,
/// RecordBatch::try_new(
/// Arc::new(Schema::new(
/// vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("s", DataType::Utf8, true)
/// ])
/// ),
/// vec![int_arr, str_arr],
/// ).unwrap()
/// )
/// ```
///
/// TODO: add merge nested fields support.
fn merge(&self, other: &RecordBatch) -> Result<RecordBatch>;
}

impl RecordBatchExt for RecordBatch {
/// Look up a column by its field name.
///
/// Returns `None` when no field with `name` exists in this batch's schema.
fn column_with_name(&self, name: &str) -> Option<&ArrayRef> {
    // Resolve the name to a column index via the schema; a missing
    // name makes `index_of` return Err, which `.ok()?` turns into None.
    let idx = self.schema().index_of(name).ok()?;
    Some(self.column(idx))
}

fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<Self> {
let mut new_fields = self.schema().fields.clone();
new_fields.push(field);
Expand All @@ -289,6 +313,37 @@ impl RecordBatchExt for RecordBatch {
));
let mut new_columns = self.columns().to_vec();
new_columns.push(arr);
Ok(RecordBatch::try_new(new_schema, new_columns)?)
Ok(Self::try_new(new_schema, new_columns)?)
}

/// Merge the columns of `other` into `self`, producing a new [`RecordBatch`].
///
/// Columns of `other` whose field name already exists in `self` are skipped;
/// all remaining columns are appended after `self`'s columns.
///
/// # Errors
///
/// Returns [`Error::Arrow`] when the two batches have a different number of
/// rows, or when `other`'s schema lists a field whose column is missing.
fn merge(&self, other: &RecordBatch) -> Result<RecordBatch> {
    // Row-aligned merge only: both sides must have the same length.
    let (left_rows, right_rows) = (self.num_rows(), other.num_rows());
    if left_rows != right_rows {
        return Err(Error::Arrow(format!(
            "Attempt to merge two RecordBatch with different sizes: {} != {}",
            left_rows, right_rows
        )));
    }

    // Start from self's fields/columns and append the new ones from `other`.
    let mut merged_fields = self.schema().fields.clone();
    let mut merged_columns = self.columns().to_vec();
    for other_field in other.schema().fields.as_slice() {
        // Duplicate names on the right-hand side are silently skipped.
        let already_present = merged_fields
            .iter()
            .any(|f| f.name() == other_field.name());
        if already_present {
            continue;
        }
        let column = other
            .column_by_name(other_field.name())
            .ok_or_else(|| {
                // Schema/column mismatch inside `other` — should not happen
                // for a well-formed RecordBatch, but surface it explicitly.
                Error::Arrow(format!(
                    "Column {} does not exist: schema={}",
                    other_field.name(),
                    other.schema()
                ))
            })?
            .clone();
        merged_fields.push(other_field.clone());
        merged_columns.push(column);
    }
    Ok(Self::try_new(Arc::new(Schema::new(merged_fields)), merged_columns)?)
}
}
11 changes: 5 additions & 6 deletions rust/src/index/vector/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,17 @@ impl VectorIndex for FlatIndex<'_> {
let k = params.key.clone();
let batch = batch.clone();
let vectors = batch
.column_with_name(&self.column)
.ok_or(Error::Schema(format!(
"column {} does not exist in dataset",
self.column,
)))?
.column_by_name(&params.column)
.ok_or_else(|| {
Error::Schema(format!("column {} does not exist in dataset", self.column,))
})?
.clone();
let scores = tokio::task::spawn_blocking(move || {
l2_distance(&k, as_fixed_size_list_array(&vectors)).unwrap()
})
.await?;
// TODO: only pick top-k in each batch first.
let row_id_array = batch.column_with_name("_rowid").unwrap().clone();
let row_id_array = batch["_rowid"].clone();
Ok((scores as ArrayRef, row_id_array))
})
.try_collect::<Vec<_>>()
Expand Down
3 changes: 1 addition & 2 deletions rust/src/io/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,7 @@ mod tests {

for b in 0..10 {
let batch = reader.read_batch(b, ..).await.unwrap();
assert!(batch.column_with_name("_rowid").is_some());
let row_ids_col = batch.column_with_name("_rowid").unwrap();
let row_ids_col = &batch["_rowid"];
// Do the same computation as `compute_row_id`.
let start_pos = (fragment << 32) as u64 + 10 * b as u64;

Expand Down

0 comments on commit 1b057a6

Please sign in to comment.