[FEAT] connect: explain (WIP)

andrewgazelka committed Nov 21, 2024
1 parent 9711e2c commit c5a042d
Showing 5 changed files with 128 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/daft-connect/src/lib.rs
@@ -328,7 +328,10 @@ impl SparkConnectService for DaftSparkConnectService {

                Ok(Response::new(response))
            }
-            _ => unimplemented_err!("Analyze plan operation is not yet implemented"),
+            op => {
+                println!("{op:#?}");
+                unimplemented_err!("Analyze plan operation is not yet implemented")
+            },
        }
    }

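The new arm only logs unhandled analyze ops; the explain path below is not wired up to it in this commit. For context, a hedged sketch of a helper that could extract the root relation from an Explain analyze op before handing it to handle_explain_command. The analyze_plan_request::Explain and plan::OpType names follow the Spark Connect protobuf; the helper itself is hypothetical:

use spark_connect::{analyze_plan_request, plan, Relation};
use tonic::Status;

// Hypothetical helper (not part of this commit): unwrap the plan carried
// by an Explain analyze op into the root Relation the handler expects.
fn explain_relation(explain: analyze_plan_request::Explain) -> Result<Relation, Status> {
    let plan = explain
        .plan
        .ok_or_else(|| Status::invalid_argument("explain requires a plan"))?;
    match plan.op_type {
        Some(plan::OpType::Root(relation)) => Ok(relation),
        _ => Err(Status::invalid_argument("explain expects a root relation")),
    }
}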
1 change: 1 addition & 0 deletions src/daft-connect/src/op.rs
@@ -1 +1,2 @@
pub mod execute;
pub mod analyze;
101 changes: 101 additions & 0 deletions src/daft-connect/src/op/analyze.rs
@@ -0,0 +1,101 @@
use std::{future::ready, pin::Pin};

use futures::stream;
use spark_connect::{analyze_plan_response, AnalyzePlanResponse, Relation};
use tonic::Status;

use crate::{session::Session, translation};

pub type AnalyzeStream =
    Pin<Box<dyn futures::Stream<Item = Result<AnalyzePlanResponse, Status>> + Send + Sync>>;

pub struct PlanIds {
    session: String,
    server_side_session: String,
    operation: String,
}

impl PlanIds {
    pub fn response(&self, result: analyze_plan_response::Result) -> AnalyzePlanResponse {
        AnalyzePlanResponse {
            session_id: self.session.to_string(),
            server_side_session_id: self.server_side_session.to_string(),
            result: Some(result),
        }
    }
}

impl Session {
    pub async fn handle_explain_command(
        &self,
        command: Relation,
        operation_id: String,
    ) -> Result<AnalyzeStream, Status> {
        let context = PlanIds {
            session: self.client_side_session_id().to_string(),
            server_side_session: self.server_side_session_id().to_string(),
            operation: operation_id,
        };

        // Translate the Spark Connect relation into a Daft logical plan.
        // Explain only inspects the plan, so nothing is executed here.
        let plan = translation::to_logical_plan(command)
            .map_err(|e| Status::internal(format!("Error in Daft server: {e:?}")))?;
        let logical_plan = plan.build();

        // WIP: Debug-format the plan as a stand-in for a proper explain
        // renderer (assumes `LogicalPlan` implements `Debug`).
        let explain_string = format!("{logical_plan:#?}");

        // `analyze_plan_response::Explain` per the Spark Connect protobuf.
        let result = analyze_plan_response::Result::Explain(analyze_plan_response::Explain {
            explain_string,
        });
        let response = context.response(result);

        // Explain yields exactly one response, so the stream has one item.
        Ok(Box::pin(stream::once(ready(Ok(response)))))
    }
}
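
On the Spark Connect side, AnalyzePlan is a unary RPC, so a caller will most likely collapse this one-item stream back into a single message. A minimal consumer sketch, assuming only futures::StreamExt; the helper name first_response is hypothetical:

use futures::StreamExt;

// Pull the single explain response out of an AnalyzeStream.
async fn first_response(mut stream: AnalyzeStream) -> Result<AnalyzePlanResponse, Status> {
    stream
        .next()
        .await
        .unwrap_or_else(|| Err(Status::internal("explain produced no response")))
}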
2 changes: 1 addition & 1 deletion src/daft-connect/src/op/execute.rs
@@ -14,7 +14,7 @@ mod root;

pub type ExecuteStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;

-pub struct PlanIds {
+struct PlanIds {
     session: String,
     server_side_session: String,
     operation: String,
21 changes: 21 additions & 0 deletions tests/connect/test_explain.py
@@ -0,0 +1,21 @@
from __future__ import annotations

import io
from contextlib import redirect_stdout


def test_explain(spark_session):
    # Create ranges using Spark - with overlap
    range1 = spark_session.range(7)  # Creates DataFrame with numbers 0 to 6
    range2 = spark_session.range(3, 10)  # Creates DataFrame with numbers 3 to 9

    # Union the two ranges
    unioned = range1.union(range2)

    # DataFrame.explain() prints the plan and returns None, so capture
    # stdout to get the explain text.
    buf = io.StringIO()
    with redirect_stdout(buf):
        unioned.explain(extended=True)
    explain_str = buf.getvalue()

    # Verify explain output contains expected elements
    assert "Union" in explain_str, "Explain plan should contain Union operation"
    assert "Range" in explain_str, "Explain plan should contain Range operations"

    # Check that both range operations are present
    assert "(0, 7" in explain_str, "First range parameters should be in explain plan"
    assert "(3, 10" in explain_str, "Second range parameters should be in explain plan"
