[FEAT] connect: explain (WIP) #3379

Draft
wants to merge 6 commits into base: andrew/connect-intersect-union
34 changes: 32 additions & 2 deletions src/daft-connect/src/lib.rs
@@ -24,7 +24,7 @@ use spark_connect::{
use tonic::{transport::Server, Request, Response, Status};
use tracing::{debug, info};
use uuid::Uuid;

use spark_connect::analyze_plan_request::explain::ExplainMode;
use crate::session::Session;

mod config;
@@ -285,6 +285,8 @@ impl SparkConnectService for DaftSparkConnectService {
        use spark_connect::analyze_plan_request::*;
        let request = request.into_inner();

        let mut session = self.get_session(&request.session_id)?;

        let AnalyzePlanRequest {
            session_id,
            analyze,
@@ -328,7 +330,35 @@ impl SparkConnectService for DaftSparkConnectService {

                Ok(Response::new(response))
            }
            _ => unimplemented_err!("Analyze plan operation is not yet implemented"),
            Analyze::Explain(explain) => {
                let Explain { plan, explain_mode } = explain;

                let explain_mode = ExplainMode::try_from(explain_mode)
                    .map_err(|_| invalid_argument_err!("Invalid Explain Mode"))?;

                let Some(plan) = plan else {
                    return invalid_argument_err!("Plan is required");
                };

                let Some(plan) = plan.op_type else {
                    return invalid_argument_err!("Op Type is required");
                };

                let OpType::Root(relation) = plan else {
                    return invalid_argument_err!("Plan operation is required");
                };

                let result = match session.handle_explain_command(relation, explain_mode).await {
                    Ok(result) => result,
                    Err(e) => return Err(Status::internal(format!("Error in Daft server: {e:?}"))),
                };

                Ok(Response::new(result))
            }
            op => {
                println!("{op:#?}");
                unimplemented_err!("Analyze plan operation is not yet implemented")
            }
        }
    }

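For orientation, a sketch of the request this new `Analyze::Explain` arm destructures. The nesting follows the code above, but the prost message layouts, the `spark_connect::plan::OpType` path, and the `ExplainMode::Simple` variant name are assumptions; the session id and `Relation::default()` are illustrative:

```rust
use spark_connect::{
    analyze_plan_request::{explain::ExplainMode, Analyze, Explain},
    plan::OpType,
    AnalyzePlanRequest, Plan, Relation,
};

// Hypothetical explain request, shaped the way the handler above expects:
// analyze -> Explain -> plan -> op_type -> Root(relation).
fn example_explain_request() -> AnalyzePlanRequest {
    AnalyzePlanRequest {
        session_id: "00000000-0000-0000-0000-000000000000".to_string(), // illustrative
        analyze: Some(Analyze::Explain(Explain {
            plan: Some(Plan {
                op_type: Some(OpType::Root(Relation::default())),
            }),
            explain_mode: ExplainMode::Simple as i32, // assumed variant name
        })),
        ..Default::default()
    }
}
```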
1 change: 1 addition & 0 deletions src/daft-connect/src/op.rs
@@ -1 +1,2 @@
pub mod execute;
pub mod analyze;
52 changes: 52 additions & 0 deletions src/daft-connect/src/op/analyze.rs
@@ -0,0 +1,52 @@
use std::pin::Pin;

use spark_connect::{
    analyze_plan_request::explain::ExplainMode, analyze_plan_response, AnalyzePlanResponse,
    Relation,
};
use tonic::Status;

use crate::{session::Session, translation};

pub type AnalyzeStream =
    Pin<Box<dyn futures::Stream<Item = Result<AnalyzePlanResponse, Status>> + Send + Sync>>;

pub struct PlanIds {
    session: String,
    server_side_session: String,
}

impl PlanIds {
    pub fn response(&self, result: analyze_plan_response::Result) -> AnalyzePlanResponse {
        AnalyzePlanResponse {
            session_id: self.session.to_string(),
            server_side_session_id: self.server_side_session.to_string(),
            result: Some(result),
        }
    }
}

impl Session {
    pub async fn handle_explain_command(
        &self,
        command: Relation,
        _mode: ExplainMode,
    ) -> eyre::Result<AnalyzePlanResponse> {
        let context = PlanIds {
            session: self.client_side_session_id().to_string(),
            server_side_session: self.server_side_session_id().to_string(),
        };

        let plan = translation::to_logical_plan(command)?;
        let optimized_plan = plan.optimize()?;
        let optimized_plan = optimized_plan.build();

        // todo: decide what we want this to display
        let explain_string = format!("{optimized_plan}");

        let explain = analyze_plan_response::Explain { explain_string };

        let response = context.response(analyze_plan_response::Result::Explain(explain));
        Ok(response)
    }
}
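A minimal sketch of how this could be unit-tested within the module (`PlanIds` fields are private, so the test lives here); the ids and the explain string are illustrative:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn explain_response_carries_session_ids() {
        let context = PlanIds {
            session: "client-session-id".to_string(),             // illustrative
            server_side_session: "server-session-id".to_string(), // illustrative
        };

        let explain = analyze_plan_response::Explain {
            explain_string: "* Project\n* Source".to_string(), // stand-in plan text
        };

        let response = context.response(analyze_plan_response::Result::Explain(explain));
        assert_eq!(response.session_id, "client-session-id");
        assert_eq!(response.server_side_session_id, "server-session-id");
        assert!(response.result.is_some());
    }
}
```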
2 changes: 1 addition & 1 deletion src/daft-connect/src/op/execute.rs
@@ -14,7 +14,7 @@ mod root;

pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;

pub struct PlanIds {
struct PlanIds {
    session: String,
    server_side_session: String,
    operation: String,
2 changes: 1 addition & 1 deletion src/daft-connect/src/translation.rs
@@ -6,7 +6,7 @@ mod literal;
mod logical_plan;
mod schema;

pub use datatype::to_spark_datatype;
pub use datatype::{to_daft_datatype, to_spark_datatype};
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
154 changes: 153 additions & 1 deletion src/daft-connect/src/translation/datatype.rs
@@ -1,4 +1,5 @@
use daft_schema::dtype::DataType;
use daft_schema::{dtype::DataType, field::Field, time_unit::TimeUnit};
use eyre::{bail, ensure, WrapErr};
use spark_connect::data_type::Kind;
use tracing::warn;

@@ -112,3 +113,154 @@ pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
        _ => unimplemented!("Unsupported datatype: {datatype:?}"),
    }
}

// todo(test): add tests for this esp in Python
pub fn to_daft_datatype(datatype: &spark_connect::DataType) -> eyre::Result<DataType> {
    let Some(kind) = &datatype.kind else {
        bail!("Datatype is required");
    };

    let type_variation_err = "Custom type variation reference not supported";

    match kind {
        Kind::Null(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Null)
        }
        Kind::Binary(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Binary)
        }
        Kind::Boolean(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Boolean)
        }
        Kind::Byte(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Int8)
        }
        Kind::Short(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Int16)
        }
        Kind::Integer(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Int32)
        }
        Kind::Long(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Int64)
        }
        Kind::Float(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Float32)
        }
        Kind::Double(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Float64)
        }
        Kind::Decimal(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);

            let Some(precision) = value.precision else {
                bail!("Decimal precision is required");
            };

            let Some(scale) = value.scale else {
                bail!("Decimal scale is required");
            };

            let precision = usize::try_from(precision)
                .wrap_err("Decimal precision must be a non-negative integer")?;

            let scale =
                usize::try_from(scale).wrap_err("Decimal scale must be a non-negative integer")?;

            Ok(DataType::Decimal128(precision, scale))
        }
        Kind::String(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Utf8)
        }
        Kind::Char(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Utf8)
        }
        Kind::VarChar(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Utf8)
        }
        Kind::Date(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            Ok(DataType::Date)
        }
        Kind::Timestamp(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);

            // todo(?): is this correct?
            Ok(DataType::Timestamp(TimeUnit::Microseconds, None))
        }
        Kind::TimestampNtz(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);

            // todo(?): is this correct?
            Ok(DataType::Timestamp(TimeUnit::Microseconds, None))
        }
        Kind::CalendarInterval(_) => bail!("Calendar interval type not supported"),
        Kind::YearMonthInterval(_) => bail!("Year-month interval type not supported"),
        Kind::DayTimeInterval(_) => bail!("Day-time interval type not supported"),
        Kind::Array(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            let element_type = to_daft_datatype(
                value
                    .element_type
                    .as_ref()
                    .ok_or_else(|| eyre::eyre!("Array element type is required"))?,
            )?;
            Ok(DataType::List(Box::new(element_type)))
        }
        Kind::Struct(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            let fields = value
                .fields
                .iter()
                .map(|f| {
                    let field_type = to_daft_datatype(
                        f.data_type
                            .as_ref()
                            .ok_or_else(|| eyre::eyre!("Struct field type is required"))?,
                    )?;
                    Ok(Field::new(&f.name, field_type))
                })
                .collect::<eyre::Result<Vec<_>>>()?;
            Ok(DataType::Struct(fields))
        }
        Kind::Map(value) => {
            ensure!(value.type_variation_reference == 0, type_variation_err);
            let key_type = to_daft_datatype(
                value
                    .key_type
                    .as_ref()
                    .ok_or_else(|| eyre::eyre!("Map key type is required"))?,
            )?;
            let value_type = to_daft_datatype(
                value
                    .value_type
                    .as_ref()
                    .ok_or_else(|| eyre::eyre!("Map value type is required"))?,
            )?;

            let map = DataType::Map {
                key: Box::new(key_type),
                value: Box::new(value_type),
            };

            Ok(map)
        }
        Kind::Variant(_) => bail!("Variant type not supported"),
        Kind::Udt(_) => bail!("User-defined type not supported"),
        Kind::Unparsed(_) => bail!("Unparsed type not supported"),
    }
}
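Per the `todo(test)` note above, a sketch of what one such test could look like. The prost layout of the `String` kind (with a `collation` field) follows the debug output quoted in `expr.rs` below; everything else comes from this function:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn string_kind_maps_to_utf8() -> eyre::Result<()> {
        // Spark's STRING type, with the default (no-op) type variation.
        let spark_string = spark_connect::DataType {
            kind: Some(Kind::String(spark_connect::data_type::String {
                type_variation_reference: 0,
                collation: String::new(),
            })),
        };
        assert_eq!(to_daft_datatype(&spark_string)?, DataType::Utf8);
        Ok(())
    }
}
```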
68 changes: 64 additions & 4 deletions src/daft-connect/src/translation/expr.rs
@@ -1,11 +1,18 @@
use std::sync::Arc;

use eyre::{bail, Context};
use spark_connect::{expression as spark_expr, Expression};
use spark_connect::{
    expression as spark_expr,
    expression::{
        cast::{CastToType, EvalMode},
        sort_order::{NullOrdering, SortDirection},
    },
    Expression,
};
use tracing::warn;
use unresolved_function::unresolved_to_daft_expr;

use crate::translation::to_daft_literal;
use crate::translation::{to_daft_datatype, to_daft_literal};

mod unresolved_function;

@@ -69,11 +76,64 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result<daft_dsl::ExprRef>

            Ok(child.alias(name))
        }
        spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"),
        spark_expr::ExprType::Cast(c) => {
            // Example payload observed from the Spark client:
            // Cast { expr: Some(Expression { common: None, expr_type: Some(UnresolvedAttribute(UnresolvedAttribute { unparsed_identifier: "id", plan_id: None, is_metadata_column: None })) }), eval_mode: Unspecified, cast_to_type: Some(Type(DataType { kind: Some(String(String { type_variation_reference: 0, collation: "" })) })) }
            println!("got cast {c:?}");
            let spark_expr::Cast {
                expr,
                eval_mode,
                cast_to_type,
            } = &**c;

            let Some(expr) = expr else {
                bail!("Cast expression is required");
            };

            let expr = to_daft_expr(expr)?;

            let Some(cast_to_type) = cast_to_type else {
                bail!("Cast to type is required");
            };

            let data_type = match cast_to_type {
                CastToType::Type(kind) => to_daft_datatype(kind).wrap_err_with(|| {
                    format!("Failed to convert spark datatype to daft datatype: {kind:?}")
                })?,
                CastToType::TypeStr(s) => {
                    bail!("Cast to type string not yet supported; tried to cast to {s}");
                }
            };

            let eval_mode = EvalMode::try_from(*eval_mode)
                .wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?;

            warn!("Ignoring cast eval mode: {eval_mode:?}");

            Ok(expr.cast(&data_type))
        }
        spark_expr::ExprType::UnresolvedRegex(_) => {
            bail!("Unresolved regex expressions not yet supported")
        }
        spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"),
        spark_expr::ExprType::SortOrder(s) => {
            let spark_expr::SortOrder {
                child,
                direction,
                null_ordering,
            } = &**s;

            let Some(_child) = child else {
                bail!("Sort order child is required");
            };

            let _sort_direction = SortDirection::try_from(*direction)
                .wrap_err_with(|| format!("Invalid sort direction: {direction}"))?;

            let _sort_nulls = NullOrdering::try_from(*null_ordering)
                .wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?;

            bail!("Sort order expressions not yet supported");
        }
        spark_expr::ExprType::LambdaFunction(_) => {
            bail!("Lambda function expressions not yet supported")
        }
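Finally, a sketch exercising the new Cast arm, reconstructing the payload from the comment above; the prost layouts (the boxed `Cast`, enum fields carried as `i32`) are assumptions:

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cast_unresolved_attribute_to_string() -> eyre::Result<()> {
        // col("id"), as the Spark client encodes it.
        let column = Expression {
            common: None,
            expr_type: Some(spark_expr::ExprType::UnresolvedAttribute(
                spark_expr::UnresolvedAttribute {
                    unparsed_identifier: "id".to_string(),
                    plan_id: None,
                    is_metadata_column: None,
                },
            )),
        };

        // CAST(id AS STRING), mirroring the debug-printed payload above.
        let cast = Expression {
            common: None,
            expr_type: Some(spark_expr::ExprType::Cast(Box::new(spark_expr::Cast {
                expr: Some(column),
                eval_mode: 0, // EvalMode::Unspecified
                cast_to_type: Some(CastToType::Type(spark_connect::DataType {
                    kind: Some(spark_connect::data_type::Kind::String(
                        spark_connect::data_type::String {
                            type_variation_reference: 0,
                            collation: String::new(),
                        },
                    )),
                })),
            }))),
        };

        // Should lower to daft's `col("id").cast(DataType::Utf8)`.
        let _expr = to_daft_expr(&cast)?;
        Ok(())
    }
}
```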