Properly support sliced unions #91

Merged · 2 commits · Jan 15, 2023
2 changes: 1 addition & 1 deletion arrow2_convert/tests/test_enum.rs
@@ -87,7 +87,7 @@ fn test_nested_unit_variant() {
}

// TODO: reenable this test once slices for enums is fixed.
//#[test]
#[test]
#[allow(unused)]
fn test_slice() {
#[derive(Debug, PartialEq, ArrowField, ArrowSerialize, ArrowDeserialize)]
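This change re-enables `test_slice`, which round-trips enum values through a sliced union array. Below is a minimal sketch of that style of check — the `Shape` enum and test body are illustrative, not copied from the test file — assuming arrow2_convert's `TryIntoArrow`/`TryIntoCollection` helpers and an arrow2 release where `Array::slice` returns a new boxed array:

```rust
// Illustrative round-trip check in the spirit of test_slice (hypothetical
// Shape enum; the real test defines its own types). Assumes the TryIntoArrow /
// TryIntoCollection helpers from arrow2_convert and an arrow2 version where
// Array::slice returns a new boxed array.
use arrow2::array::Array;
use arrow2_convert::deserialize::TryIntoCollection;
use arrow2_convert::serialize::TryIntoArrow;
use arrow2_convert::{ArrowDeserialize, ArrowField, ArrowSerialize};

#[derive(Debug, PartialEq, ArrowField, ArrowSerialize, ArrowDeserialize)]
enum Shape {
    Circle(f64),
    Square(f64),
}

#[test]
fn round_trips_sliced_unions() {
    let original = vec![Shape::Circle(1.0), Shape::Square(2.0), Shape::Circle(3.0)];
    // Serialize the whole collection into a single union array.
    let array: Box<dyn Array> = original.try_into_arrow().unwrap();
    for offset in 0..original.len() {
        // Slice the arrow array and check it deserializes to the same
        // elements as the corresponding Rust slice.
        let sliced = array.slice(offset, original.len() - offset);
        let round_tripped: Vec<Shape> = sliced.try_into_collection().unwrap();
        assert_eq!(round_tripped, &original[offset..]);
    }
}
```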
79 changes: 17 additions & 62 deletions arrow2_convert_derive/src/derive_enum.rs
@@ -410,9 +410,9 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
original_name,
original_name_str,
visibility,
is_dense,
is_dense: _,
variants,
variant_names,
variant_names: _,
variant_indices,
variant_types,
..
@@ -426,7 +426,7 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
// for dense unions.
// - For sparse unions, return the value of the variant that corresponds to the matched arm, and
// consume the iterators of the rest of the variants.
let iter_next_match_block = if is_dense {
let iter_next_match_block = {
let candidates = variants.iter()
.zip(&variant_indices)
.zip(&variant_types)
@@ -435,65 +435,20 @@
if v.is_unit {
quote! {
#lit_idx => {
let v = self.#name.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
assert!(v.unwrap());
Some(Some(#original_name::#name))
}
}
}
else {
quote! {
#lit_idx => {
let v = self.#name.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
Some(<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(|v| #original_name::#name(v)))
}
}
}
})
.collect::<Vec<TokenStream>>();
quote! { #(#candidates)* }
} else {
let candidates = variants.iter()
.enumerate()
.zip(variant_indices.iter())
.zip(&variant_types)
.map(|(((i, v), lit_idx), variant_type)| {
let consume = variants.iter()
.enumerate()
.map(|(n, v)| {
let name = &v.syn.ident;
if i != n {
quote! {
let _ = self.#name.next();
}
}
else {
quote! {}
}
})
.collect::<Vec<TokenStream>>();
let consume = quote! { #(#consume)* };

let name = &v.syn.ident;
if v.is_unit {
quote! {
#lit_idx => {
#consume
let v = self.#name.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
assert!(v.unwrap());
Some(Some(#original_name::#name))
}
}
}
else {
quote! {
#lit_idx => {
#consume
let v = self.#name.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
let (_, offset) = self.arr.index(next_index);
let slice = self.arr.fields()[#lit_idx].slice(offset, 1);
let mut slice_iter = <<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(slice.deref());
let v = slice_iter
.next()
.unwrap_or_else(|| panic!("Invalid offset for {}", #original_name_str));
Some(<#variant_type as arrow2_convert::deserialize::ArrowDeserialize>::arrow_deserialize(v).map(|v| #original_name::#name(v)))
}
}
@@ -516,15 +471,12 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
#[inline]
fn iter_from_array_ref<'a>(b: &'a dyn arrow2::array::Array) -> <&'a Self as IntoIterator>::IntoIter
{
use core::ops::Deref;
let arr = b.as_any().downcast_ref::<arrow2::array::UnionArray>().unwrap();
let fields = arr.fields();

#iterator_name {
#(
#variant_names: <<#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(fields[#variant_indices].deref()),
)*
arr,
types_iter: arr.types().iter(),
index_iter: 0..arr.len(),
}
}
}
@@ -545,10 +497,9 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {
let array_iterator_decl = quote! {
#[allow(non_snake_case)]
#visibility struct #iterator_name<'a> {
#(
#variant_names: <&'a <#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as IntoIterator>::IntoIter,
)*
arr: &'a arrow2::array::UnionArray,
types_iter: std::slice::Iter<'a, i8>,
index_iter: std::ops::Range<usize>,
}
};

@@ -558,6 +509,10 @@ pub fn expand_deserialize(input: DeriveEnum) -> TokenStream {

#[inline]
fn next(&mut self) -> Option<Self::Item> {
use core::ops::Deref;
let Some(next_index) = self.index_iter.next() else {
return None;
};
match self.types_iter.next() {
Some(type_idx) => {
match type_idx {
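Taken together, the new generated iterator drops the per-variant child iterators: it keeps the `UnionArray` itself, an iterator over its type ids, and a plain index range. For each row it matches on the type id, asks the array for that row's offset, slices the matching child array down to a single element, and deserializes just that value — which stays correct when the union array being read is itself a slice. A standalone sketch of that per-row lookup (under the assumptions noted in the comments; the `i32` child type is purely illustrative, not what the macro emits):

```rust
// A standalone sketch (not the macro output) of the per-row lookup the
// generated Iterator::next now performs. Assumes an arrow2 version where
// Array::slice returns a boxed array and UnionArray::index resolves a row to
// its (child field, offset) pair, as the generated code above relies on.
use arrow2::array::{Array, PrimitiveArray, UnionArray};

fn value_at(arr: &UnionArray, row: usize) -> Option<i32> {
    // `index` accounts for any slicing of `arr`, so `offset` points at the
    // right element of the child even when the union itself is a slice.
    let (field, offset) = arr.index(row);
    // Narrow the matching child array down to the single element for this row.
    let child = arr.fields()[field].slice(offset, 1);
    // Downcast and read that one value (hypothetical i32 child).
    let values = child.as_any().downcast_ref::<PrimitiveArray<i32>>()?;
    if values.is_valid(0) {
        Some(values.value(0))
    } else {
        None
    }
}
```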