Skip to content

Commit

Permalink
refactor: Move Bitwise aggregations to FunctionExpr (#20193)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Dec 6, 2024
1 parent 579d8fb commit 19939ae
Show file tree
Hide file tree
Showing 18 changed files with 51 additions and 152 deletions.
32 changes: 1 addition & 31 deletions crates/polars-core/src/frame/group_by/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -875,22 +875,10 @@ pub enum GroupByMethod {
Groups,
NUnique,
Quantile(f64, QuantileMethod),
Count {
include_nulls: bool,
},
Count { include_nulls: bool },
Implode,
Std(u8),
Var(u8),
#[cfg(feature = "bitwise")]
Bitwise(GroupByBitwiseMethod),
}

/// Bitwise reduction carried by the `Bitwise` variant of `GroupByMethod`.
///
/// Only compiled when the `bitwise` feature is enabled.
#[cfg(feature = "bitwise")]
#[derive(Copy, Clone, Debug)]
pub enum GroupByBitwiseMethod {
/// Bitwise AND reduction.
And,
/// Bitwise OR reduction.
Or,
/// Bitwise XOR reduction.
Xor,
}

impl Display for GroupByMethod {
Expand All @@ -913,27 +901,11 @@ impl Display for GroupByMethod {
Implode => "list",
Std(_) => "std",
Var(_) => "var",
#[cfg(feature = "bitwise")]
Bitwise(t) => {
f.write_str("bitwise_")?;
return Display::fmt(t, f);
},
};
write!(f, "{s}")
}
}

/// Renders the method as its short lowercase name: `and`, `or` or `xor`.
#[cfg(feature = "bitwise")]
impl Display for GroupByBitwiseMethod {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Pick the label first so there is a single write at the end.
        let label = match *self {
            Self::And => "and",
            Self::Or => "or",
            Self::Xor => "xor",
        };
        f.write_str(label)
    }
}

// Formatting functions used in eager and lazy code for renaming grouped columns
pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
use GroupByMethod::*;
Expand All @@ -954,8 +926,6 @@ pub fn fmt_group_by_column(name: &str, method: GroupByMethod) -> PlSmallStr {
Quantile(quantile, _interpol) => format_pl_smallstr!("{name}_quantile_{quantile:.2}"),
Std(_) => format_pl_smallstr!("{name}_agg_std"),
Var(_) => format_pl_smallstr!("{name}_agg_var"),
#[cfg(feature = "bitwise")]
Bitwise(f) => format_pl_smallstr!("{name}_agg_bitwise_{f}"),
}
}

Expand Down
28 changes: 0 additions & 28 deletions crates/polars-expr/src/expressions/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,6 @@ impl PhysicalExpr for AggregationExpr {
.var_reduce(ddof)
.map(|sc| sc.into_column(s.name().clone())),
GroupByMethod::Quantile(_, _) => unimplemented!(),
#[cfg(feature = "bitwise")]
GroupByMethod::Bitwise(f) => match f {
GroupByBitwiseMethod::And => parallel_op_columns(
|s| s.and_reduce().map(|sc| sc.into_column(s.name().clone())),
s,
allow_threading,
),
GroupByBitwiseMethod::Or => parallel_op_columns(
|s| s.or_reduce().map(|sc| sc.into_column(s.name().clone())),
s,
allow_threading,
),
GroupByBitwiseMethod::Xor => parallel_op_columns(
|s| s.xor_reduce().map(|sc| sc.into_column(s.name().clone())),
s,
allow_threading,
),
},
}
}
#[allow(clippy::ptr_arg)]
Expand Down Expand Up @@ -429,16 +411,6 @@ impl PhysicalExpr for AggregationExpr {
// implemented explicitly in AggQuantile struct
unimplemented!()
},
#[cfg(feature = "bitwise")]
GroupByMethod::Bitwise(f) => {
let (c, groups) = ac.get_final_aggregation();
let agg_c = match f {
GroupByBitwiseMethod::And => c.agg_and(&groups),
GroupByBitwiseMethod::Or => c.agg_or(&groups),
GroupByBitwiseMethod::Xor => c.agg_xor(&groups),
};
AggregatedScalar(agg_c.with_name(keep_name))
},
GroupByMethod::NanMin => {
#[cfg(feature = "propagate_nans")]
{
Expand Down
2 changes: 0 additions & 2 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,6 @@ fn create_physical_expr_inner(
},
I::Std(_, ddof) => GBM::Std(*ddof),
I::Var(_, ddof) => GBM::Var(*ddof),
#[cfg(feature = "bitwise")]
I::Bitwise(_, f) => GBM::Bitwise((*f).into()),
I::AggGroups(_) => {
polars_bail!(InvalidOperation: "agg groups expression only supported in aggregation context")
},
Expand Down
22 changes: 16 additions & 6 deletions crates/polars-plan/src/dsl/bitwise.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::sync::Arc;

use super::{AggExpr, BitwiseAggFunction, BitwiseFunction, Expr, FunctionExpr};
use super::{BitwiseFunction, Expr, FunctionExpr, FunctionFlags};

impl Expr {
/// Evaluate the number of set bits.
Expand Down Expand Up @@ -35,16 +33,28 @@ impl Expr {

/// Perform an aggregation of bitwise ANDs
pub fn bitwise_and(self) -> Self {
Expr::Agg(AggExpr::Bitwise(Arc::new(self), BitwiseAggFunction::And))
self.apply_private(FunctionExpr::Bitwise(BitwiseFunction::And))
.with_function_options(|mut options| {
options.flags |= FunctionFlags::RETURNS_SCALAR;
options
})
}

/// Perform an aggregation of bitwise ORs
pub fn bitwise_or(self) -> Self {
Expr::Agg(AggExpr::Bitwise(Arc::new(self), BitwiseAggFunction::Or))
self.apply_private(FunctionExpr::Bitwise(BitwiseFunction::Or))
.with_function_options(|mut options| {
options.flags |= FunctionFlags::RETURNS_SCALAR;
options
})
}

/// Perform an aggregation of bitwise XORs
pub fn bitwise_xor(self) -> Self {
Expr::Agg(AggExpr::Bitwise(Arc::new(self), BitwiseAggFunction::Xor))
self.apply_private(FunctionExpr::Bitwise(BitwiseFunction::Xor))
.with_function_options(|mut options| {
options.flags |= FunctionFlags::RETURNS_SCALAR;
options
})
}
}
4 changes: 0 additions & 4 deletions crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ pub enum AggExpr {
AggGroups(Arc<Expr>),
Std(Arc<Expr>, u8),
Var(Arc<Expr>, u8),
#[cfg(feature = "bitwise")]
Bitwise(Arc<Expr>, super::function_expr::BitwiseAggFunction),
}

impl AsRef<Expr> for AggExpr {
Expand All @@ -61,8 +59,6 @@ impl AsRef<Expr> for AggExpr {
AggGroups(e) => e,
Std(e, _) => e,
Var(e, _) => e,
#[cfg(feature = "bitwise")]
Bitwise(e, _) => e,
}
}
}
Expand Down
49 changes: 33 additions & 16 deletions crates/polars-plan/src/dsl/function_expr/bitwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use crate::dsl::FieldsMapper;
use crate::map;

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)]
#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash, IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
pub enum BitwiseFunction {
CountOnes,
CountZeros,
Expand All @@ -19,12 +20,8 @@ pub enum BitwiseFunction {

TrailingOnes,
TrailingZeros,
}

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash, IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
pub enum BitwiseAggFunction {
// Bitwise Aggregations
And,
Or,
Xor,
Expand All @@ -41,6 +38,10 @@ impl fmt::Display for BitwiseFunction {
B::LeadingZeros => "leading_zeros",
B::TrailingOnes => "trailing_ones",
B::TrailingZeros => "trailing_zeros",

B::And => "and",
B::Or => "or",
B::Xor => "xor",
};

f.write_str(s)
Expand All @@ -58,16 +59,10 @@ impl From<BitwiseFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
B::LeadingZeros => map!(leading_zeros),
B::TrailingOnes => map!(trailing_ones),
B::TrailingZeros => map!(trailing_zeros),
}
}
}

impl From<BitwiseAggFunction> for GroupByBitwiseMethod {
fn from(value: BitwiseAggFunction) -> Self {
match value {
BitwiseAggFunction::And => Self::And,
BitwiseAggFunction::Or => Self::Or,
BitwiseAggFunction::Xor => Self::Xor,
B::And => map!(reduce_and),
B::Or => map!(reduce_or),
B::Xor => map!(reduce_xor),
}
}
}
Expand All @@ -86,7 +81,17 @@ impl BitwiseFunction {
polars_bail!(InvalidOperation: "dtype {} not supported in '{}' operation", dtype, self);
}

Ok(DataType::UInt32)
match self {
Self::CountOnes |
Self::CountZeros |
Self::LeadingOnes |
Self::LeadingZeros |
Self::TrailingOnes |
Self::TrailingZeros => Ok(DataType::UInt32),
Self::And |
Self::Or |
Self::Xor => Ok(dtype.clone()),
}
})
}
}
Expand Down Expand Up @@ -114,3 +119,15 @@ fn trailing_ones(c: &Column) -> PolarsResult<Column> {
/// Applies `polars_ops::series::trailing_zeros` to the column through
/// `try_apply_unary_elementwise` and passes that call's result back unchanged.
fn trailing_zeros(c: &Column) -> PolarsResult<Column> {
c.try_apply_unary_elementwise(polars_ops::series::trailing_zeros)
}

/// Reduces the column with `and_reduce` and wraps the reduced value in a new
/// column that keeps the input column's name.
///
/// Any error from `and_reduce` is propagated to the caller.
fn reduce_and(c: &Column) -> PolarsResult<Column> {
    let reduced = c.and_reduce()?;
    Ok(reduced.into_column(c.name().clone()))
}

/// Reduces the column with `or_reduce` and wraps the reduced value in a new
/// column that keeps the input column's name.
///
/// Any error from `or_reduce` is propagated to the caller.
fn reduce_or(c: &Column) -> PolarsResult<Column> {
    let reduced = c.or_reduce()?;
    Ok(reduced.into_column(c.name().clone()))
}

/// Reduces the column with `xor_reduce` and wraps the reduced value in a new
/// column that keeps the input column's name.
///
/// Any error from `xor_reduce` is propagated to the caller.
fn reduce_xor(c: &Column) -> PolarsResult<Column> {
    let reduced = c.xor_reduce()?;
    Ok(reduced.into_column(c.name().clone()))
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ use serde::{Deserialize, Serialize};

pub(crate) use self::binary::BinaryFunction;
#[cfg(feature = "bitwise")]
pub use self::bitwise::{BitwiseAggFunction, BitwiseFunction};
pub use self::bitwise::BitwiseFunction;
pub use self::boolean::BooleanFunction;
#[cfg(feature = "business")]
pub(super) use self::business::BusinessFunction;
Expand Down
8 changes: 0 additions & 8 deletions crates/polars-plan/src/plans/aexpr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ pub enum IRAggExpr {
Count(Node, bool),
Std(Node, u8),
Var(Node, u8),
#[cfg(feature = "bitwise")]
Bitwise(Node, BitwiseAggFunction),
AggGroups(Node),
}

Expand All @@ -67,8 +65,6 @@ impl Hash for IRAggExpr {
method: interpol, ..
} => interpol.hash(state),
Self::Std(_, v) | Self::Var(_, v) => v.hash(state),
#[cfg(feature = "bitwise")]
Self::Bitwise(_, f) => f.hash(state),
_ => {},
}
}
Expand Down Expand Up @@ -98,8 +94,6 @@ impl IRAggExpr {
(Quantile { method: l, .. }, Quantile { method: r, .. }) => l == r,
(Std(_, l), Std(_, r)) => l == r,
(Var(_, l), Var(_, r)) => l == r,
#[cfg(feature = "bitwise")]
(Bitwise(_, l), Bitwise(_, r)) => l == r,
_ => std::mem::discriminant(self) == std::mem::discriminant(other),
}
}
Expand Down Expand Up @@ -133,8 +127,6 @@ impl From<IRAggExpr> for GroupByMethod {
Count(_, include_nulls) => GroupByMethod::Count { include_nulls },
Std(_, ddof) => GroupByMethod::Std(ddof),
Var(_, ddof) => GroupByMethod::Var(ddof),
#[cfg(feature = "bitwise")]
Bitwise(_, f) => GroupByMethod::Bitwise(f.into()),
AggGroups(_) => GroupByMethod::Groups,
Quantile { .. } => unreachable!(),
}
Expand Down
7 changes: 0 additions & 7 deletions crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,6 @@ impl AExpr {
float_type(&mut field);
Ok(field)
},
#[cfg(feature = "bitwise")]
Bitwise(expr, _) => {
*agg_list = false;
let field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
// @Q? Do we need to coerce here?
Ok(field)
},
}
},
Cast { expr, dtype, .. } => {
Expand Down
4 changes: 0 additions & 4 deletions crates/polars-plan/src/plans/aexpr/traverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,6 @@ impl IRAggExpr {
Std(input, _) => Single(*input),
Var(input, _) => Single(*input),
AggGroups(input) => Single(*input),
#[cfg(feature = "bitwise")]
Bitwise(input, _) => Single(*input),
}
}
pub fn set_input(&mut self, input: Node) {
Expand All @@ -205,8 +203,6 @@ impl IRAggExpr {
Std(input, _) => input,
Var(input, _) => input,
AggGroups(input) => input,
#[cfg(feature = "bitwise")]
Bitwise(input, _) => input,
};
*node = input;
}
Expand Down
5 changes: 0 additions & 5 deletions crates/polars-plan/src/plans/conversion/expr_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,6 @@ pub(super) fn to_aexpr_impl(
AggExpr::AggGroups(expr) => {
IRAggExpr::AggGroups(to_aexpr_impl_materialized_lit(owned(expr), arena, state)?)
},
#[cfg(feature = "bitwise")]
AggExpr::Bitwise(expr, f) => IRAggExpr::Bitwise(
to_aexpr_impl_materialized_lit(owned(expr), arena, state)?,
f,
),
};
AExpr::Agg(a_agg)
},
Expand Down
5 changes: 0 additions & 5 deletions crates/polars-plan/src/plans/conversion/ir_to_dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,6 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
let expr = node_to_expr(expr, expr_arena);
AggExpr::Count(Arc::new(expr), include_nulls).into()
},
#[cfg(feature = "bitwise")]
IRAggExpr::Bitwise(expr, f) => {
let expr = node_to_expr(expr, expr_arena);
AggExpr::Bitwise(Arc::new(expr), f).into()
},
},
AExpr::Ternary {
predicate,
Expand Down
10 changes: 0 additions & 10 deletions crates/polars-plan/src/plans/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,6 @@ impl fmt::Debug for Expr {
Var(expr, _) => write!(f, "{expr:?}.var()"),
Std(expr, _) => write!(f, "{expr:?}.std()"),
Quantile { expr, .. } => write!(f, "{expr:?}.quantile()"),
#[cfg(feature = "bitwise")]
Bitwise(expr, t) => {
let t = match t {
BitwiseAggFunction::And => "and",
BitwiseAggFunction::Or => "or",
BitwiseAggFunction::Xor => "xor",
};

write!(f, "{expr:?}.bitwise.{t}()")
},
}
},
Cast {
Expand Down
10 changes: 0 additions & 10 deletions crates/polars-plan/src/plans/ir/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,16 +591,6 @@ impl Display for ExprIRDisplay<'_> {
Var(expr, _) => write!(f, "{}.var()", self.with_root(expr)),
Std(expr, _) => write!(f, "{}.std()", self.with_root(expr)),
Quantile { expr, .. } => write!(f, "{}.quantile()", self.with_root(expr)),
#[cfg(feature = "bitwise")]
Bitwise(expr, t) => {
let t = match t {
BitwiseAggFunction::And => "and",
BitwiseAggFunction::Or => "or",
BitwiseAggFunction::Xor => "xor",
};

write!(f, "{}.bitwise.{t}()", self.with_root(expr))
},
}
},
Cast {
Expand Down
Loading

0 comments on commit 19939ae

Please sign in to comment.