Skip to content

Commit

Permalink
chore: introduce ExprRef, teach expressions new_ref (#1258)
Browse files Browse the repository at this point in the history
I find this makes the pruner a bit easier to read.
  • Loading branch information
danking authored Nov 14, 2024
1 parent ff00dec commit 437392b
Show file tree
Hide file tree
Showing 18 changed files with 316 additions and 344 deletions.
19 changes: 6 additions & 13 deletions pyvortex/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pyo3::types::*;
use vortex::dtype::field::Field;
use vortex::dtype::half::f16;
use vortex::dtype::{DType, Nullability, PType};
use vortex::expr::{BinaryExpr, Column, Literal, Operator, VortexExpr};
use vortex::expr::{BinaryExpr, Column, ExprRef, Literal, Operator};
use vortex::scalar::{PValue, Scalar, ScalarValue};

use crate::dtype::PyDType;
Expand Down Expand Up @@ -119,11 +119,11 @@ use crate::dtype::PyDType;
/// ]
#[pyclass(name = "Expr", module = "vortex")]
pub struct PyExpr {
inner: Arc<dyn VortexExpr>,
inner: ExprRef,
}

impl PyExpr {
pub fn unwrap(&self) -> &Arc<dyn VortexExpr> {
pub fn unwrap(&self) -> &ExprRef {
&self.inner
}
}
Expand All @@ -136,11 +136,7 @@ fn py_binary_opeartor<'py>(
Bound::new(
left.py(),
PyExpr {
inner: Arc::new(BinaryExpr::new(
left.inner.clone(),
operator,
right.borrow().inner.clone(),
)),
inner: BinaryExpr::new_expr(left.inner.clone(), operator, right.borrow().inner.clone()),
},
)
}
Expand Down Expand Up @@ -252,7 +248,7 @@ pub fn column<'py>(name: &Bound<'py, PyString>) -> PyResult<Bound<'py, PyExpr>>
Bound::new(
py,
PyExpr {
inner: Arc::new(Column::new(Field::Name(name))),
inner: Column::new_expr(Field::Name(name)),
},
)
}
Expand All @@ -270,10 +266,7 @@ pub fn scalar<'py>(dtype: DType, value: &Bound<'py, PyAny>) -> PyResult<Bound<'p
Bound::new(
py,
PyExpr {
inner: Arc::new(Literal::new(Scalar::new(
dtype.clone(),
scalar_value(dtype, value)?,
))),
inner: Literal::new_expr(Scalar::new(dtype.clone(), scalar_value(dtype, value)?)),
},
)
}
Expand Down
4 changes: 2 additions & 2 deletions vortex-datafusion/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use vortex_array::arrow::infer_schema;
use vortex_array::{Array, ArrayDType as _};
use vortex_error::{VortexError, VortexExpect as _};
use vortex_expr::datafusion::convert_expr_to_vortex;
use vortex_expr::VortexExpr;
use vortex_expr::ExprRef;

use crate::plans::{RowSelectorExec, TakeRowsExec};
use crate::{can_be_pushed_down, VortexScanExec};
Expand Down Expand Up @@ -190,7 +190,7 @@ impl VortexMemTableOptions {
/// columns.
fn make_filter_then_take_plan(
schema: SchemaRef,
filter_expr: Arc<dyn VortexExpr>,
filter_expr: ExprRef,
chunked_array: ChunkedArray,
output_projection: Vec<usize>,
_session_state: &dyn Session,
Expand Down
11 changes: 4 additions & 7 deletions vortex-datafusion/src/plans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ use vortex_array::compute::take;
use vortex_array::{Array, IntoArrayVariant, IntoCanonical};
use vortex_dtype::field::Field;
use vortex_error::{vortex_err, vortex_panic, VortexError};
use vortex_expr::VortexExpr;
use vortex_expr::ExprRef;

/// Physical plan operator that applies a set of [filters][Expr] against the input, producing a
/// row mask that can be used downstream to force a take against the corresponding struct array
/// chunks but for different columns.
pub(crate) struct RowSelectorExec {
filter_expr: Arc<dyn VortexExpr>,
filter_expr: ExprRef,
/// cached PlanProperties object. We do not make use of this.
cached_plan_props: PlanProperties,
/// Full array. We only access partitions of this data.
Expand All @@ -46,10 +46,7 @@ static ROW_SELECTOR_SCHEMA_REF: LazyLock<SchemaRef> = LazyLock::new(|| {
});

impl RowSelectorExec {
pub(crate) fn try_new(
filter_expr: Arc<dyn VortexExpr>,
chunked_array: &ChunkedArray,
) -> DFResult<Self> {
pub(crate) fn try_new(filter_expr: ExprRef, chunked_array: &ChunkedArray) -> DFResult<Self> {
let cached_plan_props = PlanProperties::new(
EquivalenceProperties::new(ROW_SELECTOR_SCHEMA_REF.clone()),
Partitioning::UnknownPartitioning(1),
Expand Down Expand Up @@ -134,7 +131,7 @@ impl ExecutionPlan for RowSelectorExec {
pub(crate) struct RowIndicesStream {
chunked_array: ChunkedArray,
chunk_idx: usize,
conjunction_expr: Arc<dyn VortexExpr>,
conjunction_expr: ExprRef,
filter_projection: Vec<Field>,
}

Expand Down
14 changes: 7 additions & 7 deletions vortex-expr/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,25 @@ use vortex_array::Array;
use vortex_dtype::field::Field;
use vortex_error::VortexResult;

use crate::{unbox_any, Operator, VortexExpr};
use crate::{unbox_any, ExprRef, Operator, VortexExpr};

#[derive(Debug, Clone)]
pub struct BinaryExpr {
lhs: Arc<dyn VortexExpr>,
lhs: ExprRef,
operator: Operator,
rhs: Arc<dyn VortexExpr>,
rhs: ExprRef,
}

impl BinaryExpr {
pub fn new(lhs: Arc<dyn VortexExpr>, operator: Operator, rhs: Arc<dyn VortexExpr>) -> Self {
Self { lhs, operator, rhs }
pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
Arc::new(Self { lhs, operator, rhs })
}

pub fn lhs(&self) -> &Arc<dyn VortexExpr> {
pub fn lhs(&self) -> &ExprRef {
&self.lhs
}

pub fn rhs(&self) -> &Arc<dyn VortexExpr> {
pub fn rhs(&self) -> &ExprRef {
&self.rhs
}

Expand Down
15 changes: 10 additions & 5 deletions vortex-expr/src/column.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::any::Any;
use std::fmt::Display;
use std::sync::Arc;

use vortex_array::aliases::hash_set::HashSet;
use vortex_array::array::StructArray;
Expand All @@ -8,16 +9,16 @@ use vortex_array::Array;
use vortex_dtype::field::Field;
use vortex_error::{vortex_err, VortexResult};

use crate::{unbox_any, VortexExpr};
use crate::{unbox_any, ExprRef, VortexExpr};

#[derive(Debug, PartialEq, Hash, Clone, Eq)]
pub struct Column {
field: Field,
}

impl Column {
pub fn new(field: Field) -> Self {
Self { field }
pub fn new_expr(field: Field) -> ExprRef {
Arc::new(Self { field })
}

pub fn field(&self) -> &Field {
Expand All @@ -27,13 +28,17 @@ impl Column {

impl From<String> for Column {
fn from(value: String) -> Self {
Column::new(value.into())
Column {
field: value.into(),
}
}
}

impl From<usize> for Column {
fn from(value: usize) -> Self {
Column::new(value.into())
Column {
field: value.into(),
}
}
}

Expand Down
11 changes: 4 additions & 7 deletions vortex-expr/src/datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ use datafusion_physical_expr::{expressions, PhysicalExpr};
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_scalar::Scalar;

use crate::{BinaryExpr, Column, Literal, Operator, VortexExpr};

pub fn convert_expr_to_vortex(
physical_expr: Arc<dyn PhysicalExpr>,
) -> VortexResult<Arc<dyn VortexExpr>> {
use crate::{BinaryExpr, Column, ExprRef, Literal, Operator};
pub fn convert_expr_to_vortex(physical_expr: Arc<dyn PhysicalExpr>) -> VortexResult<ExprRef> {
if let Some(binary_expr) = physical_expr
.as_any()
.downcast_ref::<expressions::BinaryExpr>()
Expand All @@ -20,7 +17,7 @@ pub fn convert_expr_to_vortex(
let right = convert_expr_to_vortex(binary_expr.right().clone())?;
let operator = *binary_expr.op();

return Ok(Arc::new(BinaryExpr::new(left, operator.try_into()?, right)) as _);
return Ok(BinaryExpr::new_expr(left, operator.try_into()?, right));
}

if let Some(col_expr) = physical_expr.as_any().downcast_ref::<expressions::Column>() {
Expand All @@ -34,7 +31,7 @@ pub fn convert_expr_to_vortex(
.downcast_ref::<expressions::Literal>()
{
let value = Scalar::from(lit.value().clone());
return Ok(Arc::new(Literal::new(value)) as _);
return Ok(Literal::new_expr(value));
}

vortex_bail!("Couldn't convert DataFusion physical expression to a vortex expression")
Expand Down
74 changes: 41 additions & 33 deletions vortex-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use vortex_array::Array;
use vortex_dtype::field::Field;
use vortex_error::{VortexExpect, VortexResult};

pub type ExprRef = Arc<dyn VortexExpr>;

/// Represents logical operation on [`Array`]s
pub trait VortexExpr: Debug + Send + Sync + PartialEq<dyn Any> + Display {
/// Convert expression reference to reference of [`Any`] type
Expand All @@ -44,13 +46,13 @@ pub trait VortexExpr: Debug + Send + Sync + PartialEq<dyn Any> + Display {
}

/// Splits top-level `and` operations (conjunctions) into separate expressions
pub fn split_conjunction(expr: &Arc<dyn VortexExpr>) -> Vec<Arc<dyn VortexExpr>> {
pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
let mut conjunctions = vec![];
split_inner(expr, &mut conjunctions);
conjunctions
}

fn split_inner(expr: &Arc<dyn VortexExpr>, exprs: &mut Vec<Arc<dyn VortexExpr>>) {
fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
match expr.as_any().downcast_ref::<BinaryExpr>() {
Some(bexp) if bexp.op() == Operator::And => {
split_inner(bexp.lhs(), exprs);
Expand All @@ -64,9 +66,9 @@ fn split_inner(expr: &Arc<dyn VortexExpr>, exprs: &mut Vec<Arc<dyn VortexExpr>>)

// Taken from apache-datafusion; necessary since you can't require VortexExpr to implement PartialEq<dyn VortexExpr>
pub fn unbox_any(any: &dyn Any) -> &dyn Any {
if any.is::<Arc<dyn VortexExpr>>() {
any.downcast_ref::<Arc<dyn VortexExpr>>()
.vortex_expect("any.is::<Arc<dyn VortexExpr>> returned true but downcast_ref failed")
if any.is::<ExprRef>() {
any.downcast_ref::<ExprRef>()
.vortex_expect("any.is::<ExprRef> returned true but downcast_ref failed")
.as_any()
} else if any.is::<Box<dyn VortexExpr>>() {
any.downcast_ref::<Box<dyn VortexExpr>>()
Expand All @@ -87,75 +89,78 @@ mod tests {

#[test]
fn basic_expr_split_test() {
let lhs = Arc::new(Column::new(Field::Name("a".to_string()))) as _;
let rhs = Arc::new(Literal::new(1.into())) as _;
let expr = Arc::new(BinaryExpr::new(lhs, Operator::Eq, rhs)) as _;
let lhs = Column::new_expr(Field::Name("a".to_string()));
let rhs = Literal::new_expr(1.into());
let expr = BinaryExpr::new_expr(lhs, Operator::Eq, rhs);
let conjunction = split_conjunction(&expr);
assert_eq!(conjunction.len(), 1);
}

#[test]
fn basic_conjunction_split_test() {
let lhs = Arc::new(Column::new(Field::Name("a".to_string()))) as _;
let rhs = Arc::new(Literal::new(1.into())) as _;
let expr = Arc::new(BinaryExpr::new(lhs, Operator::And, rhs)) as _;
let lhs = Column::new_expr(Field::Name("a".to_string()));
let rhs = Literal::new_expr(1.into());
let expr = BinaryExpr::new_expr(lhs, Operator::And, rhs);
let conjunction = split_conjunction(&expr);
assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
}

#[test]
fn expr_display() {
assert_eq!(Column::new(Field::Name("a".to_string())).to_string(), "$a");
assert_eq!(Column::new(Field::Index(1)).to_string(), "[1]");
assert_eq!(
Column::new_expr(Field::Name("a".to_string())).to_string(),
"$a"
);
assert_eq!(Column::new_expr(Field::Index(1)).to_string(), "[1]");
assert_eq!(Identity.to_string(), "[]");
assert_eq!(Identity.to_string(), "[]");

let col1: Arc<dyn VortexExpr> = Arc::new(Column::new(Field::Name("col1".to_string())));
let col2: Arc<dyn VortexExpr> = Arc::new(Column::new(Field::Name("col2".to_string())));
let col1: Arc<dyn VortexExpr> = Column::new_expr(Field::Name("col1".to_string()));
let col2: Arc<dyn VortexExpr> = Column::new_expr(Field::Name("col2".to_string()));
assert_eq!(
BinaryExpr::new(col1.clone(), Operator::And, col2.clone()).to_string(),
BinaryExpr::new_expr(col1.clone(), Operator::And, col2.clone()).to_string(),
"($col1 and $col2)"
);
assert_eq!(
BinaryExpr::new(col1.clone(), Operator::Or, col2.clone()).to_string(),
BinaryExpr::new_expr(col1.clone(), Operator::Or, col2.clone()).to_string(),
"($col1 or $col2)"
);
assert_eq!(
BinaryExpr::new(col1.clone(), Operator::Eq, col2.clone()).to_string(),
BinaryExpr::new_expr(col1.clone(), Operator::Eq, col2.clone()).to_string(),
"($col1 = $col2)"
);
assert_eq!(
BinaryExpr::new(col1.clone(), Operator::NotEq, col2.clone()).to_string(),
BinaryExpr::new_expr(col1.clone(), Operator::NotEq, col2.clone()).to_string(),
"($col1 != $col2)"
);
assert_eq!(
BinaryExpr::new(col1.clone(), Operator::Gt, col2.clone()).to_string(),
BinaryExpr::new_expr(col1.clone(), Operator::Gt, col2.clone()).to_string(),
"($col1 > $col2)"
);
assert_eq!(
BinaryExpr::new(col1.clone(), Operator::Gte, col2.clone()).to_string(),
BinaryExpr::new_expr(col1.clone(), Operator::Gte, col2.clone()).to_string(),
"($col1 >= $col2)"
);
assert_eq!(
BinaryExpr::new(col1.clone(), Operator::Lt, col2.clone()).to_string(),
BinaryExpr::new_expr(col1.clone(), Operator::Lt, col2.clone()).to_string(),
"($col1 < $col2)"
);
assert_eq!(
BinaryExpr::new(col1.clone(), Operator::Lte, col2.clone()).to_string(),
BinaryExpr::new_expr(col1.clone(), Operator::Lte, col2.clone()).to_string(),
"($col1 <= $col2)"
);

assert_eq!(
BinaryExpr::new(
Arc::new(BinaryExpr::new(col1.clone(), Operator::Lt, col2.clone())),
BinaryExpr::new_expr(
BinaryExpr::new_expr(col1.clone(), Operator::Lt, col2.clone()),
Operator::Or,
Arc::new(BinaryExpr::new(col1.clone(), Operator::NotEq, col2.clone()))
BinaryExpr::new_expr(col1.clone(), Operator::NotEq, col2.clone())
)
.to_string(),
"(($col1 < $col2) or ($col1 != $col2))"
);

assert_eq!(Not::new(col1.clone()).to_string(), "!$col1");
assert_eq!(Not::new_expr(col1.clone()).to_string(), "!$col1");

assert_eq!(
Select::include(vec![Field::Name("col1".to_string())]).to_string(),
Expand All @@ -179,20 +184,23 @@ mod tests {
"Exclude($col1,$col2,[1])"
);

assert_eq!(Literal::new(Scalar::from(0_u8)).to_string(), "0_u8");
assert_eq!(Literal::new(Scalar::from(0.0_f32)).to_string(), "0_f32");
assert_eq!(Literal::new_expr(Scalar::from(0_u8)).to_string(), "0_u8");
assert_eq!(
Literal::new_expr(Scalar::from(0.0_f32)).to_string(),
"0_f32"
);
assert_eq!(
Literal::new(Scalar::from(i64::MAX)).to_string(),
Literal::new_expr(Scalar::from(i64::MAX)).to_string(),
"9223372036854775807_i64"
);
assert_eq!(Literal::new(Scalar::from(true)).to_string(), "true");
assert_eq!(Literal::new_expr(Scalar::from(true)).to_string(), "true");
assert_eq!(
Literal::new(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
Literal::new_expr(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
"null"
);

assert_eq!(
Literal::new(Scalar::new(
Literal::new_expr(Scalar::new(
DType::Struct(
StructDType::new(
Arc::from([Arc::from("dog"), Arc::from("cat")]),
Expand Down
Loading

0 comments on commit 437392b

Please sign in to comment.