[FEAT] connect: add alias support #3342
Merged
```rust
//! Translation between Spark Connect and Daft

mod datatype;
mod expr;
mod literal;
mod logical_plan;
mod schema;

pub use datatype::to_spark_datatype;
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
pub use schema::relation_to_schema;
```
```rust
use daft_schema::dtype::DataType;
use spark_connect::data_type::Kind;
use tracing::warn;

pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
    match datatype {
        DataType::Null => spark_connect::DataType {
            kind: Some(Kind::Null(spark_connect::data_type::Null {
                type_variation_reference: 0,
            })),
        },
        DataType::Boolean => spark_connect::DataType {
            kind: Some(Kind::Boolean(spark_connect::data_type::Boolean {
                type_variation_reference: 0,
            })),
        },
        DataType::Int8 => spark_connect::DataType {
            kind: Some(Kind::Byte(spark_connect::data_type::Byte {
                type_variation_reference: 0,
            })),
        },
        DataType::Int16 => spark_connect::DataType {
            kind: Some(Kind::Short(spark_connect::data_type::Short {
                type_variation_reference: 0,
            })),
        },
        DataType::Int32 => spark_connect::DataType {
            kind: Some(Kind::Integer(spark_connect::data_type::Integer {
                type_variation_reference: 0,
            })),
        },
        DataType::Int64 => spark_connect::DataType {
            kind: Some(Kind::Long(spark_connect::data_type::Long {
                type_variation_reference: 0,
            })),
        },
        DataType::UInt8 => spark_connect::DataType {
            kind: Some(Kind::Byte(spark_connect::data_type::Byte {
                type_variation_reference: 0,
            })),
        },
        DataType::UInt16 => spark_connect::DataType {
            kind: Some(Kind::Short(spark_connect::data_type::Short {
                type_variation_reference: 0,
            })),
        },
        DataType::UInt32 => spark_connect::DataType {
            kind: Some(Kind::Integer(spark_connect::data_type::Integer {
                type_variation_reference: 0,
            })),
        },
        DataType::UInt64 => spark_connect::DataType {
            kind: Some(Kind::Long(spark_connect::data_type::Long {
                type_variation_reference: 0,
            })),
        },
        DataType::Float32 => spark_connect::DataType {
            kind: Some(Kind::Float(spark_connect::data_type::Float {
                type_variation_reference: 0,
            })),
        },
        DataType::Float64 => spark_connect::DataType {
            kind: Some(Kind::Double(spark_connect::data_type::Double {
                type_variation_reference: 0,
            })),
        },
        DataType::Decimal128(precision, scale) => spark_connect::DataType {
            kind: Some(Kind::Decimal(spark_connect::data_type::Decimal {
                scale: Some(*scale as i32),
                precision: Some(*precision as i32),
                type_variation_reference: 0,
            })),
        },
        DataType::Timestamp(unit, _) => {
            warn!("Ignoring time unit {unit:?} for timestamp type");
            spark_connect::DataType {
                kind: Some(Kind::Timestamp(spark_connect::data_type::Timestamp {
                    type_variation_reference: 0,
                })),
            }
        }
        DataType::Date => spark_connect::DataType {
            kind: Some(Kind::Date(spark_connect::data_type::Date {
                type_variation_reference: 0,
            })),
        },
        DataType::Binary => spark_connect::DataType {
            kind: Some(Kind::Binary(spark_connect::data_type::Binary {
                type_variation_reference: 0,
            })),
        },
        DataType::Utf8 => spark_connect::DataType {
            kind: Some(Kind::String(spark_connect::data_type::String {
                type_variation_reference: 0,
                collation: String::new(), // todo(correctness): is this correct?
            })),
        },
        DataType::Struct(fields) => spark_connect::DataType {
            kind: Some(Kind::Struct(spark_connect::data_type::Struct {
                fields: fields
                    .iter()
                    .map(|f| spark_connect::data_type::StructField {
                        name: f.name.clone(),
                        data_type: Some(to_spark_datatype(&f.dtype)),
                        nullable: true, // todo(correctness): is this correct?
                        metadata: None,
                    })
                    .collect(),
                type_variation_reference: 0,
            })),
        },
        _ => unimplemented!("Unsupported datatype: {datatype:?}"),
    }
}
```
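Note that the mapping above folds Daft's unsigned integer types into Spark's signed types of the same width (e.g. `UInt8` becomes `Byte`), which can overflow for large unsigned values. A minimal standalone sketch of the same exhaustive-match mapping pattern, using simplified stand-in enums (hypothetical; the real `spark_connect` types are generated protobuf structs):

```rust
// Simplified stand-in types; the real daft_schema / spark_connect types
// carry much more detail. This only illustrates the mapping pattern.
#[derive(Debug, PartialEq)]
enum DaftType {
    Int8,
    UInt8,
    Float64,
    Utf8,
}

#[derive(Debug, PartialEq)]
enum SparkKind {
    Byte,
    Double,
    String,
}

fn to_spark_kind(t: &DaftType) -> SparkKind {
    match t {
        // Unsigned types map to signed Spark types of the same width,
        // so a UInt8 value above i8::MAX would not round-trip.
        DaftType::Int8 | DaftType::UInt8 => SparkKind::Byte,
        DaftType::Float64 => SparkKind::Double,
        DaftType::Utf8 => SparkKind::String,
    }
}

fn main() {
    assert_eq!(to_spark_kind(&DaftType::UInt8), SparkKind::Byte);
    assert_eq!(to_spark_kind(&DaftType::Utf8), SparkKind::String);
    println!("ok");
}
```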
```rust
use std::sync::Arc;

use eyre::{bail, Context};
use spark_connect::{expression as spark_expr, Expression};
use tracing::warn;
use unresolved_function::unresolved_to_daft_expr;

use crate::translation::to_daft_literal;

mod unresolved_function;

pub fn to_daft_expr(expression: Expression) -> eyre::Result<daft_dsl::ExprRef> {
    if let Some(common) = expression.common {
        warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
    };

    let Some(expr) = expression.expr_type else {
        bail!("Expression is required");
    };

    match expr {
        spark_expr::ExprType::Literal(l) => to_daft_literal(l),
        spark_expr::ExprType::UnresolvedAttribute(attr) => {
            let spark_expr::UnresolvedAttribute {
                unparsed_identifier,
                plan_id,
                is_metadata_column,
            } = attr;

            if let Some(plan_id) = plan_id {
                warn!("Ignoring plan_id {plan_id} for attribute expressions; not yet implemented");
            }

            if let Some(is_metadata_column) = is_metadata_column {
                warn!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented");
            }

            Ok(daft_dsl::col(unparsed_identifier))
        }
        spark_expr::ExprType::UnresolvedFunction(f) => {
            unresolved_to_daft_expr(f).wrap_err("Failed to handle unresolved function")
        }
        spark_expr::ExprType::ExpressionString(_) => bail!("Expression string not yet supported"),
        spark_expr::ExprType::UnresolvedStar(_) => {
            bail!("Unresolved star expressions not yet supported")
        }
        spark_expr::ExprType::Alias(alias) => {
            let spark_expr::Alias {
                expr,
                name,
                metadata,
            } = *alias;

            let Some(expr) = expr else {
                bail!("Alias expr is required");
            };

            let [name] = name.as_slice() else {
                bail!("Alias name is required and currently only works with a single string; got {name:?}");
            };

            if let Some(metadata) = metadata {
                bail!("Alias metadata is not yet supported; got {metadata:?}");
            }

            let child = to_daft_expr(*expr)?;

            let name = Arc::from(name.as_str());

            Ok(child.alias(name))
        }
        spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"),
        spark_expr::ExprType::UnresolvedRegex(_) => {
            bail!("Unresolved regex expressions not yet supported")
        }
        spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"),
        spark_expr::ExprType::LambdaFunction(_) => {
            bail!("Lambda function expressions not yet supported")
        }
        spark_expr::ExprType::Window(_) => bail!("Window expressions not yet supported"),
        spark_expr::ExprType::UnresolvedExtractValue(_) => {
            bail!("Unresolved extract value expressions not yet supported")
        }
        spark_expr::ExprType::UpdateFields(_) => {
            bail!("Update fields expressions not yet supported")
        }
        spark_expr::ExprType::UnresolvedNamedLambdaVariable(_) => {
            bail!("Unresolved named lambda variable expressions not yet supported")
        }
        spark_expr::ExprType::CommonInlineUserDefinedFunction(_) => {
            bail!("Common inline user defined function expressions not yet supported")
        }
        spark_expr::ExprType::CallFunction(_) => {
            bail!("Call function expressions not yet supported")
        }
        spark_expr::ExprType::NamedArgumentExpression(_) => {
            bail!("Named argument expressions not yet supported")
        }
        spark_expr::ExprType::MergeAction(_) => bail!("Merge action expressions not yet supported"),
        spark_expr::ExprType::TypedAggregateExpression(_) => {
            bail!("Typed aggregate expressions not yet supported")
        }
        spark_expr::ExprType::Extension(_) => bail!("Extension expressions not yet supported"),
    }
}
```
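The `Alias` arm enforces a single-element name list with a `let`-`else` slice pattern before recursing into the child expression. A self-contained sketch of that pattern, with hypothetical names and no Daft types:

```rust
// Extract exactly one name from a Vec, as the Alias arm does; the else
// branch still sees the original Vec for the error message.
fn single_name(name: Vec<String>) -> Result<String, String> {
    let [name] = name.as_slice() else {
        return Err(format!("expected exactly one name; got {name:?}"));
    };
    Ok(name.clone())
}

fn main() {
    assert_eq!(single_name(vec!["x".into()]).unwrap(), "x");
    assert!(single_name(vec!["a".into(), "b".into()]).is_err());
    assert!(single_name(vec![]).is_err());
    println!("ok");
}
```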
src/daft-connect/src/translation/expr/unresolved_function.rs (44 additions, 0 deletions)
```rust
use daft_core::count_mode::CountMode;
use eyre::{bail, Context};
use spark_connect::expression::UnresolvedFunction;

use crate::translation::to_daft_expr;

pub fn unresolved_to_daft_expr(f: UnresolvedFunction) -> eyre::Result<daft_dsl::ExprRef> {
    let UnresolvedFunction {
        function_name,
        arguments,
        is_distinct,
        is_user_defined_function,
    } = f;

    let arguments: Vec<_> = arguments.into_iter().map(to_daft_expr).try_collect()?;

    if is_distinct {
        bail!("Distinct not yet supported");
    }

    if is_user_defined_function {
        bail!("User-defined functions not yet supported");
    }

    match function_name.as_str() {
        "count" => handle_count(arguments).wrap_err("Failed to handle count function"),
        n => bail!("Unresolved function {n} not yet supported"),
    }
}

pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
        Ok(arguments) => arguments,
        Err(arguments) => {
            bail!("requires exactly one argument; got {arguments:?}");
        }
    };

    let [arg] = arguments;

    let count = arg.count(CountMode::All);

    Ok(count)
}
```
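`handle_count` checks arity by converting the argument `Vec` into a fixed-size array with `TryInto`, which hands the original `Vec` back on failure so it can go into the error message. A standalone sketch of that idiom, with plain `i64` arguments standing in for `daft_dsl::ExprRef`:

```rust
// Vec<T> -> [T; 1] succeeds only when len() == 1; on mismatch the
// original Vec is returned in the Err variant, ready for the message.
fn exactly_one(args: Vec<i64>) -> Result<i64, String> {
    let [arg]: [i64; 1] = match args.try_into() {
        Ok(arr) => arr,
        Err(args) => {
            return Err(format!("requires exactly one argument; got {args:?}"));
        }
    };
    Ok(arg)
}

fn main() {
    assert_eq!(exactly_one(vec![7]), Ok(7));
    assert!(exactly_one(vec![]).is_err());
    assert!(exactly_one(vec![1, 2]).is_err());
    println!("ok");
}
```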
The `handle_count` function needs to support both `count(*)` and `count(expr)` cases. Currently it only handles the single-argument case. For zero arguments (i.e. `count(*)`), it should return a count of all rows. For one argument (i.e. `count(expr)`), it should count non-null values of that expression.

Spotted by Graphite Reviewer
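The zero-or-one-argument dispatch the reviewer describes could be sketched like this (hypothetical and independent of Daft's actual API; `Count::AllRows` and `Count::NonNullOf` are stand-ins for the real count expressions):

```rust
#[derive(Debug, PartialEq)]
enum Count {
    AllRows,          // stand-in for count(*): count every row
    NonNullOf(String), // stand-in for count(expr): count non-null values
}

// Dispatch on argument count: count(*) takes zero arguments,
// count(expr) takes exactly one.
fn handle_count(args: Vec<String>) -> Result<Count, String> {
    match args.as_slice() {
        [] => Ok(Count::AllRows),
        [expr] => Ok(Count::NonNullOf(expr.clone())),
        more => Err(format!("expected 0 or 1 arguments; got {}", more.len())),
    }
}

fn main() {
    assert_eq!(handle_count(vec![]), Ok(Count::AllRows));
    assert_eq!(
        handle_count(vec!["x".into()]),
        Ok(Count::NonNullOf("x".into()))
    );
    assert!(handle_count(vec!["a".into(), "b".into()]).is_err());
    println!("ok");
}
```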