Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: column_name based index access for RecordBatch and StructArray #3458

Merged
merged 2 commits into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion arrow-array/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use arrow_buffer::buffer::buffer_bin_or;
use arrow_buffer::Buffer;
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field};
use std::any::Any;
use std::{any::Any, ops::Index};

/// A nested array type where each child (called *field*) is represented by a separate
/// array.
Expand Down Expand Up @@ -296,6 +296,23 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
}
}

impl Index<&str> for StructArray {
    type Output = ArrayRef;

    /// Get a reference to a column's array by name.
    ///
    /// Note: A schema can currently have duplicate field names, in which case
    /// the first field will always be selected.
    /// This issue will be addressed in [ARROW-11178](https://issues.apache.org/jira/browse/ARROW-11178)
    ///
    /// # Panics
    ///
    /// Panics if the name is not in the schema. The panic message includes
    /// the requested column name.
    fn index(&self, name: &str) -> &Self::Output {
        // `unwrap_or_else` keeps the message formatting off the success path, and
        // explicit positional arguments make the message render correctly
        // regardless of the crate's edition.
        self.column_by_name(name)
            .unwrap_or_else(|| panic!("column `{}` not found in StructArray", name))
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -352,6 +369,26 @@ mod tests {
assert_eq!(0, struct_array.offset());
}

/// Checks that a child column of a `StructArray` can be fetched by its field name
/// through indexing, i.e. `struct_array["column_name"]`.
#[test]
fn test_struct_array_index_access() {
    let bool_col = Arc::new(BooleanArray::from(vec![false, false, true, true]));
    let int_col = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

    // Pair each field definition with its backing column before building the struct.
    let bool_entry = (
        Field::new("b", DataType::Boolean, false),
        bool_col.clone() as ArrayRef,
    );
    let int_entry = (
        Field::new("c", DataType::Int32, false),
        int_col.clone() as ArrayRef,
    );
    let struct_array = StructArray::from(vec![bool_entry, int_entry]);

    assert_eq!(struct_array["b"].as_ref(), bool_col.as_ref());
    assert_eq!(struct_array["c"].as_ref(), int_col.as_ref());
}

/// validates that the in-memory representation follows [the spec](https://arrow.apache.org/docs/format/Columnar.html#struct-layout)
#[test]
fn test_struct_array_from_vec() {
Expand Down
40 changes: 40 additions & 0 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

use crate::{new_empty_array, Array, ArrayRef, StructArray};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use std::ops::Index;
use std::sync::Arc;

/// Trait for types that can read `RecordBatch`'s.
Expand Down Expand Up @@ -288,6 +289,13 @@ impl RecordBatch {
&self.columns[index]
}

/// Look up a column's array by its field name.
///
/// Returns `None` when the schema has no column called `name`.
pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
    let schema = self.schema();
    // Only the position is needed; the matching field itself is ignored.
    let (position, _field) = schema.column_with_name(name)?;
    Some(&self.columns[position])
}

/// Get a reference to all columns in the record batch.
pub fn columns(&self) -> &[ArrayRef] {
&self.columns[..]
Expand Down Expand Up @@ -473,6 +481,19 @@ impl From<RecordBatch> for StructArray {
}
}

impl Index<&str> for RecordBatch {
    type Output = ArrayRef;

    /// Get a reference to a column's array by name.
    ///
    /// # Panics
    ///
    /// Panics if the name is not in the schema. The panic message includes
    /// the requested column name.
    fn index(&self, name: &str) -> &Self::Output {
        // `unwrap_or_else` keeps the message formatting off the success path, and
        // explicit positional arguments make the message render correctly
        // regardless of the crate's edition.
        self.column_by_name(name)
            .unwrap_or_else(|| panic!("column `{}` not found in RecordBatch", name))
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -746,6 +767,25 @@ mod tests {
assert_eq!(batch1, batch2);
}

/// Checks that a column of a `RecordBatch` can be fetched by its name through
/// indexing, i.e. `record_batch["column_name"]`.
#[test]
fn record_batch_index_access() {
    let ids = Arc::new(Int32Array::from(vec![1, 2, 3, 4]));
    let vals = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));

    let fields = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("val", DataType::Int32, false),
    ];
    let schema = Arc::new(Schema::new(fields));
    let columns: Vec<ArrayRef> = vec![ids.clone(), vals.clone()];
    let batch = RecordBatch::try_new(schema, columns).unwrap();

    assert_eq!(batch["id"].as_ref(), ids.as_ref());
    assert_eq!(batch["val"].as_ref(), vals.as_ref());
}

#[test]
fn record_batch_vals_ne() {
let id_arr1 = Int32Array::from(vec![1, 2, 3, 4]);
Expand Down