Skip to content

Commit

Permalink
[REFACTOR] connect: to_daft_* use ref instead of value
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 20, 2024
1 parent 1c7d6f5 commit 863ac08
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 27 deletions.
12 changes: 6 additions & 6 deletions src/daft-connect/src/translation/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use crate::translation::to_daft_literal;

mod unresolved_function;

pub fn to_daft_expr(expression: Expression) -> eyre::Result<daft_dsl::ExprRef> {
if let Some(common) = expression.common {
pub fn to_daft_expr(expression: &Expression) -> eyre::Result<daft_dsl::ExprRef> {
if let Some(common) = &expression.common {
warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
};

let Some(expr) = expression.expr_type else {
let Some(expr) = &expression.expr_type else {
bail!("Expression is required");
};

Expand All @@ -35,7 +35,7 @@ pub fn to_daft_expr(expression: Expression) -> eyre::Result<daft_dsl::ExprRef> {
warn!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented");
}

Ok(daft_dsl::col(unparsed_identifier))
Ok(daft_dsl::col(unparsed_identifier.as_str()))
}
spark_expr::ExprType::UnresolvedFunction(f) => {
unresolved_to_daft_expr(f).wrap_err("Failed to handle unresolved function")
Expand All @@ -49,7 +49,7 @@ pub fn to_daft_expr(expression: Expression) -> eyre::Result<daft_dsl::ExprRef> {
expr,
name,
metadata,
} = *alias;
} = &**alias;

let Some(expr) = expr else {
bail!("Alias expr is required");
Expand All @@ -63,7 +63,7 @@ pub fn to_daft_expr(expression: Expression) -> eyre::Result<daft_dsl::ExprRef> {
bail!("Alias metadata is not yet supported; got {metadata:?}");
}

let child = to_daft_expr(*expr)?;
let child = to_daft_expr(expr)?;

let name = Arc::from(name.as_str());

Expand Down
9 changes: 4 additions & 5 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
use daft_core::count_mode::CountMode;
use daft_schema::dtype::DataType;
use eyre::{bail, Context};
use spark_connect::expression::UnresolvedFunction;

use crate::translation::to_daft_expr;

pub fn unresolved_to_daft_expr(f: UnresolvedFunction) -> eyre::Result<daft_dsl::ExprRef> {
pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl::ExprRef> {
let UnresolvedFunction {
function_name,
arguments,
is_distinct,
is_user_defined_function,
} = f;

let arguments: Vec<_> = arguments.into_iter().map(to_daft_expr).try_collect()?;
let arguments: Vec<_> = arguments.iter().map(to_daft_expr).try_collect()?;

if is_distinct {
if *is_distinct {
bail!("Distinct not yet supported");
}

if is_user_defined_function {
if *is_user_defined_function {
bail!("User-defined functions not yet supported");
}

Expand Down
26 changes: 13 additions & 13 deletions src/daft-connect/src/translation/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,57 @@ use eyre::{bail, WrapErr};
use spark_connect::expression::{literal::LiteralType, Literal};

// todo(test): add tests for this esp in Python
pub fn to_daft_literal(literal: Literal) -> eyre::Result<daft_dsl::ExprRef> {
let Some(literal) = literal.literal_type else {
pub fn to_daft_literal(literal: &Literal) -> eyre::Result<daft_dsl::ExprRef> {
let Some(literal) = &literal.literal_type else {
bail!("Literal is required");
};

match literal {
LiteralType::Array(_) => bail!("Array literals not yet supported"),
LiteralType::Binary(bytes) => Ok(daft_dsl::lit(bytes.as_slice())),
LiteralType::Boolean(b) => Ok(daft_dsl::lit(b)),
LiteralType::Boolean(b) => Ok(daft_dsl::lit(*b)),
LiteralType::Byte(b) => {
// todo(correctness): is this signed or unsigned?
let b = i8::try_from(b).wrap_err_with(|| format!("Byte value {b} is out of range"))?;
let b = i8::try_from(*b).wrap_err_with(|| format!("Byte value {b} is out of range"))?;
let b = i32::from(b);

Ok(daft_dsl::lit::<i32>(b))
}
LiteralType::CalendarInterval(_) => {
bail!("Calendar interval literals not yet supported")
}
LiteralType::Date(d) => Ok(daft_dsl::lit(d)),
LiteralType::Date(d) => Ok(daft_dsl::lit(*d)),
LiteralType::DayTimeInterval(_) => {
bail!("Day-time interval literals not yet supported")
}
LiteralType::Decimal(_) => bail!("Decimal literals not yet supported"),
LiteralType::Double(d) => Ok(daft_dsl::lit(d)),
LiteralType::Double(d) => Ok(daft_dsl::lit(*d)),
LiteralType::Float(f) => {
let f = f64::from(f);
let f = f64::from(*f);
Ok(daft_dsl::lit(f))
}
LiteralType::Integer(i) => Ok(daft_dsl::lit(i)),
LiteralType::Long(l) => Ok(daft_dsl::lit(l)),
LiteralType::Integer(i) => Ok(daft_dsl::lit(*i)),
LiteralType::Long(l) => Ok(daft_dsl::lit(*l)),
LiteralType::Map(_) => bail!("Map literals not yet supported"),
LiteralType::Null(_) => {
// todo(correctness): is it ok to assume type is i32 here?
Ok(daft_dsl::lit(None::<i32>))
}
LiteralType::Short(s) => {
let short =
i16::try_from(s).wrap_err_with(|| format!("Short value {s} is out of range"))?;
i16::try_from(*s).wrap_err_with(|| format!("Short value {s} is out of range"))?;

Ok(daft_dsl::lit(i32::from(short)))
}
LiteralType::String(s) => Ok(daft_dsl::lit(s)),
LiteralType::String(s) => Ok(daft_dsl::lit(s.as_str())),
LiteralType::Struct(_) => bail!("Struct literals not yet supported"),
LiteralType::Timestamp(ts) => {
// todo(correctness): is it ok that the type is different logically?
Ok(daft_dsl::lit(ts))
Ok(daft_dsl::lit(*ts))
}
LiteralType::TimestampNtz(ts) => {
// todo(correctness): is it ok that the type is different logically?
Ok(daft_dsl::lit(ts))
Ok(daft_dsl::lit(*ts))
}
LiteralType::YearMonthInterval(_) => {
bail!("Year-month interval literals not yet supported")
Expand Down
4 changes: 2 additions & 2 deletions src/daft-connect/src/translation/logical_plan/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<LogicalPla
.wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?;

let grouping_expressions: Vec<_> = grouping_expressions
.into_iter()
.iter()
.map(to_daft_expr)
.try_collect()?;

let aggregate_expressions: Vec<_> = aggregate_expressions
.into_iter()
.iter()
.map(to_daft_expr)
.try_collect()?;

Expand Down
2 changes: 1 addition & 1 deletion src/daft-connect/src/translation/logical_plan/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub fn project(project: Project) -> eyre::Result<LogicalPlanBuilder> {

let plan = to_logical_plan(*input)?;

let daft_exprs: Vec<_> = expressions.into_iter().map(to_daft_expr).try_collect()?;
let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?;

let plan = plan.select(daft_exprs)?;

Expand Down

0 comments on commit 863ac08

Please sign in to comment.