-
Notifications
You must be signed in to change notification settings - Fork 842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RecordBatch
normalization (flattening)
#6758
base: main
Are you sure you want to change the base?
Changes from 4 commits
bbd7c8b
8abcd25
6bba7d3
55eb953
30d6294
0ed979d
d9d08cd
d1b3260
a12082c
7adda58
9c9c699
4422add
d0dc5a7
1e40c98
3c424d1
6d6b026
71380b6
af7946b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -19,8 +19,11 @@ | |||||||||||||
//! [schema](arrow_schema::Schema). | ||||||||||||||
|
||||||||||||||
use crate::{new_empty_array, Array, ArrayRef, StructArray}; | ||||||||||||||
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef}; | ||||||||||||||
use std::ops::Index; | ||||||||||||||
use arrow_schema::{ | ||||||||||||||
ArrowError, DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef, | ||||||||||||||
}; | ||||||||||||||
use std::collections::VecDeque; | ||||||||||||||
use std::ops::{Deref, Index}; | ||||||||||||||
use std::sync::Arc; | ||||||||||||||
|
||||||||||||||
/// Trait for types that can read `RecordBatch`'s. | ||||||||||||||
|
@@ -403,6 +406,68 @@ impl RecordBatch { | |||||||||||||
) | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
/// Normalize a semi-structured RecordBatch into a flat table | ||||||||||||||
/// If max_level is 0, normalizes all levels. | ||||||||||||||
pub fn normalize(&self, separator: &str, mut max_level: usize) -> Result<Self, ArrowError> { | ||||||||||||||
if max_level == 0 { | ||||||||||||||
max_level = usize::MAX; | ||||||||||||||
} | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
imo this seems the more Rusty way, making use of Option instead of a sentinel value (though I'm not sure if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay I've been working on this a bit, I found a few possible solutions that might fit. I think Option might not be the best choice, since personally, the case of Some(0) feels weird to me, and would mean you're doing an annoying copy for no reason (because of that, I would want to add in an if statement to catch it, but then we end up in the same place). For max_level.is_zero().then(|| max_level = usize::MAX); Another option is to use something like NonZeroUsize::new(1) This makes the normalize call potentially longer and more annoying, but it means there wouldn't be another import. Any thoughts on these/if you disagree? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Personally I find this okay. I'm less concerned with requiring an if check inside the code (its pretty simple anyway) compared to presenting a more Rust-like interface to users.
I don't follow this, the I agree with But yeah I'm curious to see what others might think for this too. |
||||||||||||||
if self.num_rows() == 0 { | ||||||||||||||
ngli-me marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
// No data, only need to normalize the schema | ||||||||||||||
return Ok(Self::new_empty(Arc::new( | ||||||||||||||
self.schema.normalize(separator, max_level)?, | ||||||||||||||
))); | ||||||||||||||
} | ||||||||||||||
let mut queue: VecDeque<(usize, &Arc<dyn Array>, &FieldRef)> = VecDeque::new(); | ||||||||||||||
|
||||||||||||||
// push fields | ||||||||||||||
for (c, f) in self.columns.iter().zip(self.schema.fields()) { | ||||||||||||||
queue.push_front((0, c, f)); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
while !queue.is_empty() { | ||||||||||||||
match queue.pop_front() { | ||||||||||||||
Some((depth, c, f)) => { | ||||||||||||||
|
||||||||||||||
if depth < max_level { | ||||||||||||||
match (c.data_type(), f.data_type()) { | ||||||||||||||
//DataType::List(f) => field, | ||||||||||||||
ngli-me marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
//DataType::ListView(_) => field, | ||||||||||||||
//DataType::FixedSizeList(_, _) => field, | ||||||||||||||
//DataType::LargeList(_) => field, | ||||||||||||||
//DataType::LargeListView(_) => field, | ||||||||||||||
(DataType::Struct(cf), DataType::Struct(ff)) => { | ||||||||||||||
let field_name = f.name().as_str(); | ||||||||||||||
let new_key = format!("{key_string}{separator}{field_name}"); | ||||||||||||||
ff.iter().rev().zip(cf.iter().rev()).map(|(field, ())| { | ||||||||||||||
let updated_field = Field::new( | ||||||||||||||
format!("{key_string}{separator}{}", field.name()), | ||||||||||||||
field.data_type().clone(), | ||||||||||||||
field.is_nullable(), | ||||||||||||||
); | ||||||||||||||
queue.push_front(( | ||||||||||||||
depth + 1, | ||||||||||||||
c, // TODO: need to modify c -- if it's a StructArray, it needs to have the fields modified. | ||||||||||||||
&Arc::new(updated_field), | ||||||||||||||
)) | ||||||||||||||
}); | ||||||||||||||
} | ||||||||||||||
//DataType::Union(_, _) => field, | ||||||||||||||
//DataType::Dictionary(_, _) => field, | ||||||||||||||
//DataType::Map(_, _) => field, | ||||||||||||||
//DataType::RunEndEncoded(_, _) => field, // not sure how to support this field | ||||||||||||||
_ => queue.push_front((depth, c, f)), | ||||||||||||||
} | ||||||||||||||
} else { | ||||||||||||||
queue.push_front((depth, c, f)); | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
None => break, | ||||||||||||||
ngli-me marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
}; | ||||||||||||||
} | ||||||||||||||
todo!() | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
/// Returns the number of columns in the record batch. | ||||||||||||||
/// | ||||||||||||||
/// # Example | ||||||||||||||
|
@@ -1206,6 +1271,44 @@ mod tests { | |||||||||||||
assert_ne!(batch1, batch2); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
#[test] | ||||||||||||||
fn normalize() { | ||||||||||||||
let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""])); | ||||||||||||||
let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)])); | ||||||||||||||
let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)])); | ||||||||||||||
|
||||||||||||||
let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true)); | ||||||||||||||
let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true)); | ||||||||||||||
let year_field = Arc::new(Field::new("year", DataType::Int64, true)); | ||||||||||||||
|
||||||||||||||
let a = Arc::new(StructArray::from(vec![ | ||||||||||||||
(animals_field.clone(), Arc::new(animals) as ArrayRef), | ||||||||||||||
(n_legs_field.clone(), Arc::new(n_legs) as ArrayRef), | ||||||||||||||
(year_field.clone(), Arc::new(year) as ArrayRef), | ||||||||||||||
])); | ||||||||||||||
|
||||||||||||||
let month = Arc::new(Int64Array::from(vec![Some(4), Some(6)])); | ||||||||||||||
|
||||||||||||||
let schema = Schema::new(vec![ | ||||||||||||||
Field::new( | ||||||||||||||
"a", | ||||||||||||||
DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])), | ||||||||||||||
false, | ||||||||||||||
), | ||||||||||||||
Field::new("month", DataType::Int64, true), | ||||||||||||||
]); | ||||||||||||||
let normalized = schema.clone().normalize(".", 0).unwrap(); | ||||||||||||||
println!("{:?}", normalized); | ||||||||||||||
|
||||||||||||||
let record_batch = | ||||||||||||||
RecordBatch::try_new(Arc::new(schema), vec![a, month]).expect("valid conversion"); | ||||||||||||||
|
||||||||||||||
println!("Fields: {:?}", record_batch.schema().fields()); | ||||||||||||||
println!("Metadata{:?}", record_batch.columns()); | ||||||||||||||
|
||||||||||||||
//println!("{:?}", record_batch); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
#[test] | ||||||||||||||
fn project() { | ||||||||||||||
let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); | ||||||||||||||
|
@@ -1318,7 +1421,9 @@ mod tests { | |||||||||||||
let metadata = vec![("foo".to_string(), "bar".to_string())] | ||||||||||||||
.into_iter() | ||||||||||||||
.collect(); | ||||||||||||||
println!("Metadata: {:?}", metadata); | ||||||||||||||
let metadata_schema = nullable_schema.as_ref().clone().with_metadata(metadata); | ||||||||||||||
println!("Metadata schema: {:?}", metadata_schema); | ||||||||||||||
let batch = batch.with_schema(Arc::new(metadata_schema)).unwrap(); | ||||||||||||||
|
||||||||||||||
// Cannot remove metadata | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ use std::sync::Arc; | |
|
||
use crate::error::ArrowError; | ||
use crate::field::Field; | ||
use crate::{FieldRef, Fields}; | ||
use crate::{DataType, FieldRef, Fields}; | ||
|
||
/// A builder to facilitate building a [`Schema`] from iteratively from [`FieldRef`] | ||
#[derive(Debug, Default)] | ||
|
@@ -413,6 +413,93 @@ impl Schema { | |
&self.metadata | ||
} | ||
|
||
/// Returns a new schema, normalized based on the max_level | ||
/// This carries metadata from the parent schema over as well | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise, please document the parametrs to this function and add a documentation example There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, thanks! |
||
pub fn normalize(&self, separator: &str, mut max_level: usize) -> Result<Self, ArrowError> { | ||
ngli-me marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if max_level == 0 { | ||
max_level = usize::MAX; | ||
} | ||
let mut new_fields: Vec<Field> = vec![]; | ||
for field in self.fields() { | ||
match field.data_type() { | ||
//DataType::List(f) => field, | ||
//DataType::ListView(_) => field, | ||
//DataType::FixedSizeList(_, _) => field, | ||
//DataType::LargeList(_) => field, | ||
//DataType::LargeListView(_) => field, | ||
DataType::Struct(nested_fields) => { | ||
let field_name = field.name().as_str(); | ||
new_fields = [ | ||
new_fields, | ||
Self::normalizer( | ||
ngli-me marked this conversation as resolved.
Show resolved
Hide resolved
|
||
nested_fields.to_vec(), | ||
field_name, | ||
separator, | ||
max_level - 1, | ||
), | ||
] | ||
.concat(); | ||
} | ||
//DataType::Union(_, _) => field, | ||
//DataType::Dictionary(_, _) => field, | ||
//DataType::Map(_, _) => field, | ||
//DataType::RunEndEncoded(_, _) => field, // not sure how to support this field | ||
_ => new_fields.push(Field::new( | ||
field.name(), | ||
field.data_type().clone(), | ||
field.is_nullable(), | ||
)), | ||
}; | ||
} | ||
Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) | ||
} | ||
|
||
fn normalizer( | ||
fields: Vec<FieldRef>, | ||
key_string: &str, | ||
separator: &str, | ||
max_level: usize, | ||
) -> Vec<Field> { | ||
if max_level > 0 { | ||
let mut new_fields: Vec<Field> = vec![]; | ||
for field in fields { | ||
match field.data_type() { | ||
//DataType::List(f) => , | ||
//DataType::ListView(_) => , | ||
//DataType::FixedSizeList(_, _) => , | ||
//DataType::LargeList(_) => , | ||
//DataType::LargeListView(_) => , | ||
DataType::Struct(nested_fields) => { | ||
let field_name = field.name().as_str(); | ||
let new_key = format!("{key_string}{separator}{field_name}"); | ||
new_fields = [ | ||
new_fields, | ||
Self::normalizer( | ||
nested_fields.to_vec(), | ||
new_key.as_str(), | ||
separator, | ||
max_level - 1, | ||
), | ||
] | ||
.concat(); | ||
} | ||
//DataType::Union(_, _) => field, | ||
//DataType::Dictionary(_, _) => field, | ||
//DataType::Map(_, _) => field, | ||
//DataType::RunEndEncoded(_, _) => field, // not sure how to support this field | ||
_ => new_fields.push(Field::new( | ||
format!("{key_string}{separator}{}", field.name()), | ||
field.data_type().clone(), | ||
field.is_nullable(), | ||
)), | ||
}; | ||
} | ||
new_fields | ||
} else { | ||
todo!() | ||
} | ||
} | ||
|
||
/// Look up a column by name and return a immutable reference to the column along with | ||
/// its index. | ||
pub fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please improve this documentation (maybe copy from the pyarrow version)?
max_level
means (in addition to that 0)separator
doesFor example like https://docs.rs/arrow/latest/arrow/index.html#columnar-format
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, missed doing this, will do!