From c6f1ba29c52f06423ec0888680b4aaa2387bf8d5 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 20 Nov 2024 00:26:03 -0800
Subject: [PATCH] [FIX] (WIP) casting of arrays from daft to arrow with
 unsigned integers

---
 src/daft-connect/src/lib.rs                  |  1 -
 src/daft-connect/src/op/execute/root.rs      | 42 ++++++++++++++++---
 src/daft-connect/src/translation.rs          |  4 +-
 src/daft-connect/src/translation/datatype.rs | 20 +++++++++
 .../translation/expr/unresolved_function.rs  | 11 +++++
 tests/connect/test_count.py                  | 13 ++++++
 6 files changed, 84 insertions(+), 7 deletions(-)
 create mode 100644 tests/connect/test_count.py

diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index efc861b986..539587c2bd 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -158,7 +158,6 @@ impl SparkConnectService for DaftSparkConnectService {
         request: Request,
     ) -> Result, Status> {
         let request = request.into_inner();
-
         let session = self.get_session(&request.session_id)?;
 
         let Some(operation) = request.operation_id else {
diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs
index fcd1f41bb9..6be3c17382 100644
--- a/src/daft-connect/src/op/execute/root.rs
+++ b/src/daft-connect/src/op/execute/root.rs
@@ -1,7 +1,10 @@
 use std::{future::ready, sync::Arc};
 
 use common_daft_config::DaftExecutionConfig;
+use daft_core::series::Series;
 use daft_local_execution::NativeExecutor;
+use daft_schema::{field::Field, schema::Schema};
+use daft_table::Table;
 use futures::stream;
 use spark_connect::{ExecutePlanResponse, Relation};
 use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
@@ -10,7 +13,7 @@ use crate::{
     op::execute::{ExecuteStream, PlanIds},
     session::Session,
     translation,
-    translation::Plan,
+    translation::{to_spark_compatible_datatype, Plan},
 };
 
 impl Session {
@@ -37,17 +40,46 @@ impl Session {
         let cfg = Arc::new(DaftExecutionConfig::default());
         let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
         let mut result_stream = native_executor.run(psets, cfg, None)?.into_stream();
-
         while let Some(result) = result_stream.next().await {
             let result = result?;
             let tables = result.get_tables()?;
+
             for table in tables.as_slice() {
-                let response = context.gen_response(table)?;
-                if tx.send(Ok(response)).await.is_err() {
-                    return Ok(());
+                // Cast each column to a Spark-compatible datatype before converting to Arrow
+                let mut arrow_arrays = Vec::with_capacity(table.num_columns());
+                let mut column_names = Vec::with_capacity(table.num_columns());
+                let mut field_types = Vec::with_capacity(table.num_columns());
+
+                for i in 0..table.num_columns() {
+                    let s = table.get_column_by_index(i)?;
+
+                    let daft_data_type = to_spark_compatible_datatype(s.data_type());
+                    let s = s.cast(&daft_data_type)?;
+
+                    // Store the actual type after potential casting
+                    field_types.push(Field::new(s.name(), daft_data_type));
+                    column_names.push(s.name().to_string());
+                    arrow_arrays.push(s.to_arrow());
                 }
+
+                // Create new schema with actual types after casting
+                let new_schema = Schema::new(field_types)?;
+
+                // Convert arrays back to series
+                let series = arrow_arrays
+                    .into_iter()
+                    .zip(column_names)
+                    .map(|(array, name)| Series::try_from((name.as_str(), array)))
+                    .try_collect()?;
+
+                // Create table from series
+                let new_table = Table::new_with_size(new_schema, series, table.len())?;
+
+                let response = context.gen_response(&new_table)?;
+                tx.send(Ok(response)).await.unwrap();
             }
         }
+
         Ok(())
     };
diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs
index 8b61b93f98..07b2d36a18 100644
--- a/src/daft-connect/src/translation.rs
+++ b/src/daft-connect/src/translation.rs
@@ -6,7 +6,9 @@ mod literal;
 mod logical_plan;
 mod schema;
 
-pub use datatype::{deser_spark_datatype, to_daft_datatype, to_spark_datatype};
+pub use datatype::{
+    deser_spark_datatype, to_daft_datatype, to_spark_compatible_datatype, to_spark_datatype,
+};
 pub use expr::to_daft_expr;
 pub use literal::to_daft_literal;
 pub use logical_plan::{to_logical_plan, Plan};
diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/translation/datatype.rs
index 722e66c3b3..9891411b95 100644
--- a/src/daft-connect/src/translation/datatype.rs
+++ b/src/daft-connect/src/translation/datatype.rs
@@ -6,6 +6,26 @@ use tracing::warn;
 mod codec;
 pub use codec::deser as deser_spark_datatype;
 
+// todo: still a WIP; by no means complete
+pub fn to_spark_compatible_datatype(datatype: &DataType) -> DataType {
+    // TL;DR unsigned integers are not supported by Spark
+    match datatype {
+        DataType::UInt8 => DataType::Int8,
+        DataType::UInt16 => DataType::Int16,
+        DataType::UInt32 => DataType::Int32,
+        DataType::UInt64 => DataType::Int64,
+        DataType::Struct(fields) => {
+            let fields = fields
+                .iter()
+                .map(|f| Field::new(f.name.clone(), to_spark_compatible_datatype(&f.dtype)))
+                .collect();
+
+            DataType::Struct(fields)
+        }
+        _ => datatype.clone(),
+    }
+}
+
 pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
     match datatype {
         DataType::Null => spark_connect::DataType {
diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs
index af70a1e8f2..2efeb5dc24 100644
--- a/src/daft-connect/src/translation/expr/unresolved_function.rs
+++ b/src/daft-connect/src/translation/expr/unresolved_function.rs
@@ -1,6 +1,7 @@
 use daft_core::count_mode::CountMode;
 use eyre::{bail, Context};
 use spark_connect::expression::UnresolvedFunction;
+use tracing::debug;
 
 use crate::translation::to_daft_expr;
 
@@ -79,6 +80,16 @@ pub fn handle_count(arguments: Vec) -> eyre::Result
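
Note on the failure mode this patch addresses: Daft represents aggregations
such as count() with unsigned integer columns, but Spark has no unsigned
integer types, so the Arrow batches streamed back over Spark Connect could not
be decoded by clients. The sketch below is illustrative only; it is not the
patch's own test (the tests/connect/test_count.py hunk is truncated above),
and the server address is an assumption.

    # Hedged, illustrative repro sketch; assumes a Daft Connect server is
    # reachable at sc://localhost:50051 (the address is hypothetical).
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:50051").getOrCreate()
    df = spark.range(10)

    # Before this patch, the count result surfaced as a Daft UInt64 column,
    # which has no Spark equivalent; with to_spark_compatible_datatype the
    # server casts it to Int64 before the Arrow conversion, so the client
    # can decode the response.
    assert df.count() == 10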