Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Field::Index to Field::Name expr transform #1894

Merged
merged 26 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions vortex-dtype/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ use std::fmt::{Display, Formatter};
use std::sync::Arc;

use itertools::Itertools;
use vortex_error::{vortex_err, VortexResult};

use crate::FieldNames;

/// A selector for a field in a struct
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -46,6 +49,54 @@ impl Display for Field {
}
}

impl Field {
    /// Reports whether this selector identifies its field by name.
    pub fn is_named(&self) -> bool {
        match self {
            Field::Name(_) => true,
            Field::Index(_) => false,
        }
    }

    /// Reports whether this selector identifies its field by position.
    pub fn is_indexed(&self) -> bool {
        match self {
            Field::Index(_) => true,
            Field::Name(_) => false,
        }
    }

    /// Turns this selector into a [`Field::Name`], resolving an index against `field_names`.
    ///
    /// A selector that is already named is returned unchanged.
    ///
    /// # Errors
    ///
    /// Fails when the index does not address an entry of `field_names`.
    pub fn into_named_field(self, field_names: &FieldNames) -> VortexResult<Self> {
        let idx = match self {
            Field::Index(idx) => idx,
            already_named @ Field::Name(_) => return Ok(already_named),
        };

        match field_names.get(idx) {
            Some(name) => Ok(Field::Name(name.clone())),
            None => Err(vortex_err!(
                "Field index {} out of bounds, it has names {:?}",
                idx,
                field_names
            )),
        }
    }

    /// Turns this selector into a [`Field::Index`], looking a name up in `field_names`.
    ///
    /// A selector that is already indexed is returned unchanged. When the name occurs more
    /// than once, the position of its first occurrence is used.
    ///
    /// # Errors
    ///
    /// Fails when the name does not occur in `field_names`.
    pub fn into_index_field(self, field_names: &FieldNames) -> VortexResult<Self> {
        let name = match self {
            Field::Name(name) => name,
            already_indexed @ Field::Index(_) => return Ok(already_indexed),
        };

        let found = field_names.iter().position(|candidate| *candidate == name);
        match found {
            Some(idx) => Ok(Field::Index(idx)),
            None => Err(vortex_err!(
                "Field name {} not found, it has names {:?}",
                name,
                field_names
            )),
        }
    }
}

/// A path through a (possibly nested) struct, composed of a sequence of field selectors
// TODO(ngates): wrap `Vec<Field>` in Option for cheaper "root" path.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down
4 changes: 2 additions & 2 deletions vortex-expr/src/forms/nnf.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use vortex_error::VortexResult;

use crate::traversal::{FoldChildren, FoldDown, FoldUp, Folder, Node as _};
use crate::traversal::{FoldChildren, FoldDown, FoldUp, FolderMut, Node as _};
use crate::{not, BinaryExpr, ExprRef, Not, Operator};

/// Return an equivalent expression in Negative Normal Form (NNF).
Expand Down Expand Up @@ -63,7 +63,7 @@ pub fn nnf(expr: ExprRef) -> VortexResult<ExprRef> {
#[derive(Default)]
struct NNFVisitor {}

impl Folder for NNFVisitor {
impl FolderMut for NNFVisitor {
type NodeTy = ExprRef;
type Out = ExprRef;
type Context = bool;
Expand Down
1 change: 1 addition & 0 deletions vortex-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod project;
pub mod pruning;
mod row_filter;
mod select;
mod transform;
#[allow(dead_code)]
mod traversal;

Expand Down
31 changes: 29 additions & 2 deletions vortex-expr/src/pack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use itertools::Itertools as _;
use vortex_array::array::StructArray;
use vortex_array::validity::Validity;
use vortex_array::{ArrayData, IntoArrayData};
use vortex_dtype::FieldNames;
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult};
use vortex_dtype::{Field, FieldNames};
use vortex_error::{vortex_bail, vortex_err, VortexExpect as _, VortexResult};

use crate::{ExprRef, VortexExpr};

Expand Down Expand Up @@ -51,6 +51,33 @@ impl Pack {
}
Ok(Arc::new(Pack { names, values }))
}

/// Returns a reference to the field names stored in this `Pack`.
pub fn names(&self) -> &FieldNames {
&self.names
}

/// Looks up the value expression that this `Pack` stores for the field `f`.
///
/// A [`Field::Name`] is first resolved to the position of the first matching entry in
/// `self.names`; a [`Field::Index`] is used as the position directly. The expression at
/// that position in `self.values` is returned as a clone.
///
/// # Errors
///
/// Fails when a name is not present in `self.names`, or when the resolved position has no
/// entry in `self.values`.
pub fn field(&self, f: &Field) -> VortexResult<ExprRef> {
    let position = match f {
        Field::Index(position) => *position,
        Field::Name(wanted) => {
            let found = self.names.iter().position(|candidate| candidate == wanted);
            found.ok_or_else(|| {
                vortex_err!(
                    "Cannot find field {} in pack fields {:?}",
                    wanted,
                    self.names
                )
            })?
        }
    };

    match self.values.get(position) {
        Some(value) => Ok(value.clone()),
        None => Err(vortex_err!("field index out of bounds: {}", position)),
    }
}
}

/// Convenience constructor for a `Pack` expression.
///
/// Converts `names` into [`FieldNames`], forwards to [`Pack::try_new_expr`], and unwraps the
/// result with `vortex_expect`; the expectation stated there is that `names` and `values`
/// have the same length.
pub fn pack(names: impl Into<FieldNames>, values: Vec<ExprRef>) -> ExprRef {
    let field_names: FieldNames = names.into();
    let packed = Pack::try_new_expr(field_names, values);
    packed.vortex_expect("pack names and values have the same length")
}

impl PartialEq<dyn Any> for Pack {
Expand Down
8 changes: 8 additions & 0 deletions vortex-expr/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ pub struct Select {
child: ExprRef,
}

/// Shorthand for [`Select::include_expr`]: builds a select expression over `child` from
/// `columns` (presumably keeping only those fields; see `Select::include_expr`).
pub fn select(columns: Vec<Field>, child: ExprRef) -> ExprRef {
Select::include_expr(columns, child)
}

/// Shorthand for [`Select::exclude_expr`]: builds a select expression over `child` from
/// `columns` (presumably dropping those fields; see `Select::exclude_expr`).
pub fn select_exclude(columns: Vec<Field>, child: ExprRef) -> ExprRef {
Select::exclude_expr(columns, child)
}

impl Select {
pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
Arc::new(Self { fields, child })
Expand Down
37 changes: 37 additions & 0 deletions vortex-expr/src/transform/expr_simplify.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use vortex_error::VortexResult;

use crate::traversal::{FoldChildren, FoldUp, FolderMut, Node};
use crate::{ExprRef, GetItem, Pack};

/// Stateless expression folder: its `FolderMut` implementation replaces a `GetItem` applied
/// directly to a `Pack` with the packed expression that the `GetItem` selects.
pub struct ExprSimplify();

impl ExprSimplify {
    /// Runs the simplification fold over `e` and returns the resulting expression.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while traversing or rewriting the expression tree.
    pub fn simplify(e: ExprRef) -> VortexResult<ExprRef> {
        let mut simplifier = ExprSimplify();
        let folded = e.transform_with_context(&mut simplifier, ())?;
        Ok(folded.result())
    }
}

impl FolderMut for ExprSimplify {
    type NodeTy = ExprRef;
    type Out = ExprRef;
    type Context = ();

    /// Rebuilds `node` from its folded children, then collapses a `GetItem` whose child is a
    /// `Pack` into the packed expression selected by the `GetItem`'s field.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Pack::field`] when the selected field cannot be resolved
    /// in the pack.
    fn visit_up(
        &mut self,
        node: Self::NodeTy,
        _context: Self::Context,
        children: FoldChildren<Self::Out>,
    ) -> VortexResult<FoldUp<Self::Out>> {
        // Match on the rebuilt node, not the incoming one: `children` carries the results
        // of folding the subtrees (assumption: that is what `contained_children` yields, as
        // the original fall-through rebuild implies). Looking at the original child would
        // miss a `Pack` that only surfaces after an inner `GetItem` was collapsed, and would
        // hand back an un-simplified pack entry.
        let rebuilt = node.replacing_children(children.contained_children());

        if let Some(get_item) = rebuilt.as_any().downcast_ref::<GetItem>() {
            if let Some(pack) = get_item.child().as_any().downcast_ref::<Pack>() {
                let expr = pack.field(get_item.field())?;
                return Ok(FoldUp::Continue(expr));
            }
        }

        Ok(FoldUp::Continue(rebuilt))
    }
}
97 changes: 97 additions & 0 deletions vortex-expr/src/transform/field_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use vortex_array::{ArrayDType, Canonical, IntoArrayData};
use vortex_dtype::DType;
use vortex_error::{vortex_bail, VortexResult};

use crate::traversal::{MutNodeVisitor, Node, TransformResult};
use crate::{ExprRef, GetItem};

/// Expression visitor that rewrites `GetItem` nodes addressing their field by index into
/// `GetItem` nodes addressing the same field by name.
pub struct FieldToNameTransform {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you write a doc string please. And also one for new since it's not obvious to me what ident_dt means? It's too abbreviated

/// DType used to build the empty array against which a `GetItem`'s child is evaluated in
/// order to discover the child's struct field names (presumably the dtype of the
/// `ident()` root the expression is applied to; confirm against callers).
ident_dt: DType,
}

impl FieldToNameTransform {
fn new(ident_dt: DType) -> Self {
Self { ident_dt }
}

pub fn transform(expr: ExprRef, ident_dt: DType) -> VortexResult<ExprRef> {
let mut visitor = FieldToNameTransform::new(ident_dt);
expr.transform(&mut visitor).map(|node| node.result)
}
}

impl MutNodeVisitor for FieldToNameTransform {
    type NodeTy = ExprRef;

    /// Replaces an index-based `GetItem` with the equivalent name-based one.
    ///
    /// Nodes that are not a `GetItem`, or whose field is already named, are reported as
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Fails when the child cannot be evaluated, when the child's dtype is not a struct, or
    /// when the index cannot be resolved to one of the struct's field names.
    fn visit_up(&mut self, node: Self::NodeTy) -> VortexResult<TransformResult<Self::NodeTy>> {
        let Some(get_item) = node.as_any().downcast_ref::<GetItem>() else {
            return Ok(TransformResult::no(node));
        };
        if get_item.field().is_named() {
            return Ok(TransformResult::no(node));
        }

        // TODO(joe) expr::dtype
        // Until expressions can report their dtype, evaluate the child on an empty array of
        // the identity dtype and read the dtype off the result.
        let empty_input = Canonical::empty(&self.ident_dt)?.into_array();
        let child_dtype = get_item.child().evaluate(&empty_input)?.dtype().clone();

        let DType::Struct(s_dtype, _) = child_dtype else {
            vortex_bail!(
                "get_item requires child to have struct dtype, however it was {}",
                child_dtype
            );
        };

        let named_field = get_item.field().clone().into_named_field(s_dtype.names())?;
        let rewritten = GetItem::new_expr(named_field, get_item.child().clone());
        Ok(TransformResult::yes(rewritten))
    }
}

#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, StructDType};

    use crate::transform::field_type::FieldToNameTransform;
    use crate::{get_item, ident};

    /// Index-based `get_item` nodes are rewritten to name-based ones, at both nesting levels.
    #[test]
    fn test_idx_to_name_expr() {
        // Struct stored under "a": fields "c" and "d".
        let a_dtype = DType::Struct(
            StructDType::new(
                vec!["c".into(), "d".into()].into(),
                vec![I32.into(), I32.into()],
            ),
            NonNullable,
        );
        // Struct stored under "b": fields "e" and "f".
        let b_dtype = DType::Struct(
            StructDType::new(
                vec!["e".into(), "f".into()].into(),
                vec![I32.into(), I32.into()],
            ),
            NonNullable,
        );
        let dtype = DType::Struct(
            StructDType::new(vec!["a".into(), "b".into()].into(), vec![a_dtype, b_dtype]),
            NonNullable,
        );

        // Only the outer selector is an index.
        let outer_indexed = get_item(1, get_item("a", ident()));
        let rewritten = FieldToNameTransform::transform(outer_indexed, dtype.clone()).unwrap();
        assert_eq!(&rewritten, &get_item("d", get_item("a", ident())));

        // Both selectors are indices.
        let both_indexed = get_item(0, get_item(1, ident()));
        let rewritten = FieldToNameTransform::transform(both_indexed, dtype).unwrap();
        assert_eq!(&rewritten, &get_item("e", get_item("b", ident())));
    }
}
5 changes: 5 additions & 0 deletions vortex-expr/src/transform/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#[allow(dead_code)]
mod expr_simplify;
mod field_type;
#[allow(dead_code)]
mod project_expr;
Loading
Loading