diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 70171ad0d4..ff7a19cde2 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -158,7 +158,6 @@ impl SparkConnectService for DaftSparkConnectService { request: Request, ) -> Result, Status> { let request = request.into_inner(); - let session = self.get_session(&request.session_id)?; let Some(operation) = request.operation_id else { diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs index 1e1fac147b..ddfd401846 100644 --- a/src/daft-connect/src/op/execute/root.rs +++ b/src/daft-connect/src/op/execute/root.rs @@ -1,7 +1,10 @@ use std::{collections::HashMap, future::ready}; use common_daft_config::DaftExecutionConfig; +use daft_core::series::Series; use daft_local_execution::NativeExecutor; +use daft_schema::{field::Field, schema::Schema}; +use daft_table::Table; use futures::stream; use spark_connect::{ExecutePlanResponse, Relation}; use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; @@ -10,6 +13,7 @@ use crate::{ op::execute::{ExecuteStream, PlanIds}, session::Session, translation, + translation::to_spark_compatible_datatype, }; impl Session { @@ -38,17 +42,46 @@ impl Session { let mut result_stream = native_executor .run(HashMap::new(), cfg.into(), None)? .into_stream(); - while let Some(result) = result_stream.next().await { let result = result?; let tables = result.get_tables()?; + for table in tables.as_slice() { - let response = context.gen_response(table)?; - if tx.send(Ok(response)).await.is_err() { - return Ok(()); + // Inside the for loop over tables + let mut arrow_arrays = Vec::with_capacity(table.num_columns()); + let mut column_names = Vec::with_capacity(table.num_columns()); + let mut field_types = Vec::with_capacity(table.num_columns()); + + for i in 0..table.num_columns() { + let s = table.get_column_by_index(i)?; + + let daft_data_type = to_spark_compatible_datatype(s.data_type()); + let s = s.cast(&daft_data_type)?; + + // Store the actual type after potential casting + field_types.push(Field::new(s.name(), daft_data_type)); + column_names.push(s.name().to_string()); + arrow_arrays.push(s.to_arrow()); } + + // Create new schema with actual types after casting + let new_schema = Schema::new(field_types)?; + + // Convert arrays back to series + let series = arrow_arrays + .into_iter() + .zip(column_names) + .map(|(array, name)| Series::try_from((name.as_str(), array))) + .try_collect()?; + + // Create table from series + let new_table = Table::new_with_size(new_schema, series, table.len())?; + + let response = context.gen_response(&new_table)?; + tx.send(Ok(response)).await.unwrap(); } } + Ok(()) }; diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs index bb2d73b507..f7013601c5 100644 --- a/src/daft-connect/src/translation.rs +++ b/src/daft-connect/src/translation.rs @@ -6,7 +6,7 @@ mod literal; mod logical_plan; mod schema; -pub use datatype::to_spark_datatype; +pub use datatype::{to_spark_compatible_datatype, to_spark_datatype}; pub use expr::to_daft_expr; pub use literal::to_daft_literal; pub use logical_plan::to_logical_plan; diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/translation/datatype.rs index 9a40844464..25d7de44eb 100644 --- a/src/daft-connect/src/translation/datatype.rs +++ b/src/daft-connect/src/translation/datatype.rs @@ -1,7 +1,27 @@ -use daft_schema::dtype::DataType; +use daft_schema::{dtype::DataType, field::Field}; use spark_connect::data_type::Kind; use tracing::warn; +// todo: still a WIP; by no means complete +pub fn to_spark_compatible_datatype(datatype: &DataType) -> DataType { + // TL;DR unsigned integers are not supported by Spark + match datatype { + DataType::UInt8 => DataType::Int8, + DataType::UInt16 => DataType::Int16, + DataType::UInt32 => DataType::Int32, + DataType::UInt64 => DataType::Int64, + DataType::Struct(fields) => { + let fields = fields + .iter() + .map(|f| Field::new(f.name.clone(), to_spark_compatible_datatype(&f.dtype))) + .collect(); + + DataType::Struct(fields) + } + _ => datatype.clone(), + } +} + pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { match datatype { DataType::Null => spark_connect::DataType { diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs index ffb8c802ce..b33eeeaf60 100644 --- a/src/daft-connect/src/translation/expr/unresolved_function.rs +++ b/src/daft-connect/src/translation/expr/unresolved_function.rs @@ -1,6 +1,7 @@ use daft_core::count_mode::CountMode; use eyre::{bail, Context}; use spark_connect::expression::UnresolvedFunction; +use tracing::debug; use crate::translation::to_daft_expr; @@ -38,6 +39,16 @@ pub fn handle_count(arguments: Vec) -> eyre::Result