[BUG] improve Spark Connect compatibility for types and count behavior #3352

Open · wants to merge 1 commit into base: main
1 change: 0 additions & 1 deletion src/daft-connect/src/lib.rs
@@ -158,7 +158,6 @@ impl SparkConnectService for DaftSparkConnectService {
request: Request<ExecutePlanRequest>,
) -> Result<Response<Self::ExecutePlanStream>, Status> {
let request = request.into_inner();

let session = self.get_session(&request.session_id)?;

let Some(operation) = request.operation_id else {
41 changes: 37 additions & 4 deletions src/daft-connect/src/op/execute/root.rs
@@ -1,7 +1,10 @@
use std::{collections::HashMap, future::ready};

use common_daft_config::DaftExecutionConfig;
use daft_core::series::Series;
use daft_local_execution::NativeExecutor;
use daft_schema::{field::Field, schema::Schema};
use daft_table::Table;
use futures::stream;
use spark_connect::{ExecutePlanResponse, Relation};
use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
@@ -10,6 +13,7 @@ use crate::{
op::execute::{ExecuteStream, PlanIds},
session::Session,
translation,
translation::to_spark_compatible_datatype,
};

impl Session {
@@ -38,17 +42,46 @@ impl Session {
let mut result_stream = native_executor
.run(HashMap::new(), cfg.into(), None)?
.into_stream();

while let Some(result) = result_stream.next().await {
let result = result?;
let tables = result.get_tables()?;

for table in tables.as_slice() {
let response = context.gen_response(table)?;
if tx.send(Ok(response)).await.is_err() {
return Ok(());
// Cast each column of this table to a Spark-compatible type before building the response
let mut arrow_arrays = Vec::with_capacity(table.num_columns());
let mut column_names = Vec::with_capacity(table.num_columns());
let mut field_types = Vec::with_capacity(table.num_columns());

for i in 0..table.num_columns() {
let s = table.get_column_by_index(i)?;

let daft_data_type = to_spark_compatible_datatype(s.data_type());
let s = s.cast(&daft_data_type)?;

// Store the actual type after potential casting
field_types.push(Field::new(s.name(), daft_data_type));
column_names.push(s.name().to_string());
arrow_arrays.push(s.to_arrow());
}

// Create new schema with actual types after casting
let new_schema = Schema::new(field_types)?;

// Convert arrays back to series
let series = arrow_arrays
.into_iter()
.zip(column_names)
.map(|(array, name)| Series::try_from((name.as_str(), array)))
.try_collect()?;

// Create table from series
let new_table = Table::new_with_size(new_schema, series, table.len())?;

let response = context.gen_response(&new_table)?;
if tx.send(Ok(response)).await.is_err() {
    // Receiver dropped (client disconnected); stop streaming gracefully.
    return Ok(());
}
}
}

Ok(())
};

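Taken together, the new loop casts each column to a Spark-representable type, rebuilds the schema from the post-cast fields, and reassembles the table before generating the response. Below is a condensed sketch of that data flow, as an illustration only: the helper name is hypothetical, the Arrow round-trip from the diff is skipped in favor of keeping the cast Series directly, and eyre::Result is assumed as the error type.

use daft_schema::{field::Field, schema::Schema};
use daft_table::Table;

use crate::translation::to_spark_compatible_datatype;

// Hypothetical helper (not in the PR): cast every column to a
// Spark-compatible dtype, then rebuild the table under the new schema.
fn to_spark_compatible_table(table: &Table) -> eyre::Result<Table> {
    let mut fields = Vec::with_capacity(table.num_columns());
    let mut columns = Vec::with_capacity(table.num_columns());

    for i in 0..table.num_columns() {
        let s = table.get_column_by_index(i)?;
        let dtype = to_spark_compatible_datatype(s.data_type());
        // Casting is effectively a no-op when the column is already compatible.
        let s = s.cast(&dtype)?;
        fields.push(Field::new(s.name(), dtype));
        columns.push(s);
    }

    Ok(Table::new_with_size(Schema::new(fields)?, columns, table.len())?)
}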
2 changes: 1 addition & 1 deletion src/daft-connect/src/translation.rs
@@ -6,7 +6,7 @@ mod literal;
mod logical_plan;
mod schema;

pub use datatype::to_spark_datatype;
pub use datatype::{to_spark_compatible_datatype, to_spark_datatype};
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
pub use logical_plan::to_logical_plan;
22 changes: 21 additions & 1 deletion src/daft-connect/src/translation/datatype.rs
@@ -1,7 +1,27 @@
use daft_schema::dtype::DataType;
use daft_schema::{dtype::DataType, field::Field};
use spark_connect::data_type::Kind;
use tracing::warn;

// todo: still a WIP; by no means complete
pub fn to_spark_compatible_datatype(datatype: &DataType) -> DataType {
// TL;DR unsigned integers are not supported by Spark
match datatype {
DataType::UInt8 => DataType::Int8,
DataType::UInt16 => DataType::Int16,
DataType::UInt32 => DataType::Int32,
DataType::UInt64 => DataType::Int64,
DataType::Struct(fields) => {
let fields = fields
.iter()
.map(|f| Field::new(f.name.clone(), to_spark_compatible_datatype(&f.dtype)))
.collect();

DataType::Struct(fields)
}
_ => datatype.clone(),
}
}

pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType {
match datatype {
DataType::Null => spark_connect::DataType {
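To make the mapping concrete, here is an illustrative expectation (not part of the diff) for a nested type: conversion recurses through struct fields, and anything without an unsigned component passes through untouched. Note that a UInt64 value above i64::MAX cannot be represented after the cast to Int64.

use daft_schema::{dtype::DataType, field::Field};

// Illustration only: unsigned integers map to their signed counterparts,
// struct fields are converted recursively, other types are unchanged.
let nested = DataType::Struct(vec![
    Field::new("count", DataType::UInt64),
    Field::new("label", DataType::Utf8),
]);
assert_eq!(
    to_spark_compatible_datatype(&nested),
    DataType::Struct(vec![
        Field::new("count", DataType::Int64), // unsigned -> signed
        Field::new("label", DataType::Utf8),  // unchanged
    ])
);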
11 changes: 11 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
@@ -1,6 +1,7 @@
use daft_core::count_mode::CountMode;
use eyre::{bail, Context};
use spark_connect::expression::UnresolvedFunction;
use tracing::debug;

use crate::translation::to_daft_expr;

@@ -38,6 +39,16 @@

let [arg] = arguments;

// special case to be consistent with how Spark handles counting literals
// see https://github.com/Eventual-Inc/Daft/issues/3421
let count_special_case = *arg == daft_dsl::Expr::Literal(daft_dsl::LiteralValue::Int32(1));

if count_special_case {
debug!("special case for count");
let result = daft_dsl::col("*").count(CountMode::All);
return Ok(result);
}

let count = arg.count(CountMode::All);

Ok(count)
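For context, Spark clients usually express a row count as count(1) or count(*), and in Spark count(1) equals the total row count because the literal is never null; translating it as a plain count of the literal expression did not match that behavior (see issue #3421 linked above). A sketch of the rewrite, mirroring the diff and assuming daft_dsl::ExprRef is the expression handle used by the surrounding code:

use daft_core::count_mode::CountMode;
use daft_dsl::{Expr, ExprRef, LiteralValue};

// Hypothetical helper: `count(1)` becomes a count over all rows;
// any other argument is counted directly, as before.
fn translate_count(arg: ExprRef) -> ExprRef {
    if *arg == Expr::Literal(LiteralValue::Int32(1)) {
        daft_dsl::col("*").count(CountMode::All)
    } else {
        arg.count(CountMode::All)
    }
}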
13 changes: 13 additions & 0 deletions tests/connect/test_count.py
@@ -0,0 +1,13 @@
from __future__ import annotations


def test_count(spark_session):
    # Create a DataFrame with the numbers 0 through 9
    spark_range = spark_session.range(10)

    # Count the rows over Spark Connect
    count = spark_range.count()

    # Verify the count matches the expected number of rows
    assert count == 10, "DataFrame should have 10 rows"