diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 70171ad0d4..02923937a3 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -24,7 +24,7 @@ use spark_connect::{ use tonic::{transport::Server, Request, Response, Status}; use tracing::{debug, info}; use uuid::Uuid; - +use spark_connect::analyze_plan_request::explain::ExplainMode; use crate::session::Session; mod config; @@ -285,6 +285,8 @@ impl SparkConnectService for DaftSparkConnectService { use spark_connect::analyze_plan_request::*; let request = request.into_inner(); + let mut session = self.get_session(&request.session_id)?; + let AnalyzePlanRequest { session_id, analyze, @@ -328,7 +330,35 @@ impl SparkConnectService for DaftSparkConnectService { Ok(Response::new(response)) } - _ => unimplemented_err!("Analyze plan operation is not yet implemented"), + Analyze::Explain(explain) => { + let Explain { plan, explain_mode } = explain; + + let explain_mode = ExplainMode::try_from(explain_mode) + .map_err(|_| invalid_argument_err!("Invalid Explain Mode"))?; + + let Some(plan) = plan else { + return invalid_argument_err!("Plan is required"); + }; + + let Some(plan) = plan.op_type else { + return invalid_argument_err!("Op Type is required"); + }; + + let OpType::Root(relation) = plan else { + return invalid_argument_err!("Plan operation is required"); + }; + + let result = match session.handle_explain_command(relation, explain_mode).await { + Ok(result) => result, + Err(e) => return Err(Status::internal(format!("Error in Daft server: {e:?}"))), + }; + + Ok(Response::new(result)) + } + op => { + println!("{op:#?}"); + unimplemented_err!("Analyze plan operation is not yet implemented") + } } } diff --git a/src/daft-connect/src/op.rs b/src/daft-connect/src/op.rs index 2e8bdddf98..4e012a6c30 100644 --- a/src/daft-connect/src/op.rs +++ b/src/daft-connect/src/op.rs @@ -1 +1,2 @@ pub mod execute; +pub mod analyze; diff --git a/src/daft-connect/src/op/analyze.rs b/src/daft-connect/src/op/analyze.rs new file mode 100644 index 0000000000..5bb85bd194 --- /dev/null +++ b/src/daft-connect/src/op/analyze.rs @@ -0,0 +1,52 @@ +use std::pin::Pin; + +use spark_connect::{analyze_plan_response, AnalyzePlanResponse}; + +pub type AnalyzeStream = + Pin> + Send + Sync>>; + +use spark_connect::{analyze_plan_request::explain::ExplainMode, Relation}; +use tonic::Status; + +use crate::{session::Session, translation}; + +pub struct PlanIds { + session: String, + server_side_session: String, +} + +impl PlanIds { + pub fn response(&self, result: analyze_plan_response::Result) -> AnalyzePlanResponse { + AnalyzePlanResponse { + session_id: self.session.to_string(), + server_side_session_id: self.server_side_session.to_string(), + result: Some(result), + } + } +} + +impl Session { + pub async fn handle_explain_command( + &self, + command: Relation, + _mode: ExplainMode, + ) -> eyre::Result { + let context = PlanIds { + session: self.client_side_session_id().to_string(), + server_side_session: self.server_side_session_id().to_string(), + }; + + let plan = translation::to_logical_plan(command)?; + let optimized_plan = plan.optimize()?; + + let optimized_plan = optimized_plan.build(); + + // todo: what do we want this to display + let explain_string = format!("{optimized_plan}"); + + let schema = analyze_plan_response::Explain { explain_string }; + + let response = context.response(analyze_plan_response::Result::Explain(schema)); + Ok(response) + } +} diff --git a/src/daft-connect/src/op/execute.rs 
b/src/daft-connect/src/op/execute.rs index fba3cc850d..7ec6a44324 100644 --- a/src/daft-connect/src/op/execute.rs +++ b/src/daft-connect/src/op/execute.rs @@ -14,7 +14,7 @@ mod root; pub type ExecuteStream = ::ExecutePlanStream; -pub struct PlanIds { +struct PlanIds { session: String, server_side_session: String, operation: String, diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs index bb2d73b507..a03fe113a7 100644 --- a/src/daft-connect/src/translation.rs +++ b/src/daft-connect/src/translation.rs @@ -6,7 +6,7 @@ mod literal; mod logical_plan; mod schema; -pub use datatype::to_spark_datatype; +pub use datatype::{to_daft_datatype, to_spark_datatype}; pub use expr::to_daft_expr; pub use literal::to_daft_literal; pub use logical_plan::to_logical_plan; diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/translation/datatype.rs index 9a40844464..d5f186c659 100644 --- a/src/daft-connect/src/translation/datatype.rs +++ b/src/daft-connect/src/translation/datatype.rs @@ -1,4 +1,5 @@ -use daft_schema::dtype::DataType; +use daft_schema::{dtype::DataType, field::Field, time_unit::TimeUnit}; +use eyre::{bail, ensure, WrapErr}; use spark_connect::data_type::Kind; use tracing::warn; @@ -112,3 +113,154 @@ pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { _ => unimplemented!("Unsupported datatype: {datatype:?}"), } } + +// todo(test): add tests for this esp in Python +pub fn to_daft_datatype(datatype: &spark_connect::DataType) -> eyre::Result { + let Some(kind) = &datatype.kind else { + bail!("Datatype is required"); + }; + + let type_variation_err = "Custom type variation reference not supported"; + + match kind { + Kind::Null(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Null) + } + Kind::Binary(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Binary) + } + Kind::Boolean(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Boolean) + } + Kind::Byte(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Int8) + } + Kind::Short(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Int16) + } + Kind::Integer(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Int32) + } + Kind::Long(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Int64) + } + Kind::Float(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Float32) + } + Kind::Double(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Float64) + } + Kind::Decimal(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + + let Some(precision) = value.precision else { + bail!("Decimal precision is required"); + }; + + let Some(scale) = value.scale else { + bail!("Decimal scale is required"); + }; + + let precision = usize::try_from(precision) + .wrap_err("Decimal precision must be a non-negative integer")?; + + let scale = + usize::try_from(scale).wrap_err("Decimal scale must be a non-negative integer")?; + + Ok(DataType::Decimal128(precision, scale)) + } + Kind::String(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Utf8) + } + Kind::Char(value) => { + ensure!(value.type_variation_reference == 0, 
type_variation_err); + Ok(DataType::Utf8) + } + Kind::VarChar(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Utf8) + } + Kind::Date(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + Ok(DataType::Date) + } + Kind::Timestamp(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + + // todo(?): is this correct? + + Ok(DataType::Timestamp(TimeUnit::Microseconds, None)) + } + Kind::TimestampNtz(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + + // todo(?): is this correct? + + Ok(DataType::Timestamp(TimeUnit::Microseconds, None)) + } + Kind::CalendarInterval(_) => bail!("Calendar interval type not supported"), + Kind::YearMonthInterval(_) => bail!("Year-month interval type not supported"), + Kind::DayTimeInterval(_) => bail!("Day-time interval type not supported"), + Kind::Array(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + let element_type = to_daft_datatype( + value + .element_type + .as_ref() + .ok_or_else(|| eyre::eyre!("Array element type is required"))?, + )?; + Ok(DataType::List(Box::new(element_type))) + } + Kind::Struct(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + let fields = value + .fields + .iter() + .map(|f| { + let field_type = to_daft_datatype( + f.data_type + .as_ref() + .ok_or_else(|| eyre::eyre!("Struct field type is required"))?, + )?; + Ok(Field::new(&f.name, field_type)) + }) + .collect::>>()?; + Ok(DataType::Struct(fields)) + } + Kind::Map(value) => { + ensure!(value.type_variation_reference == 0, type_variation_err); + let key_type = to_daft_datatype( + value + .key_type + .as_ref() + .ok_or_else(|| eyre::eyre!("Map key type is required"))?, + )?; + let value_type = to_daft_datatype( + value + .value_type + .as_ref() + .ok_or_else(|| eyre::eyre!("Map value type is required"))?, + )?; + + let map = DataType::Map { + key: Box::new(key_type), + value: Box::new(value_type), + }; + + Ok(map) + } + Kind::Variant(_) => bail!("Variant type not supported"), + Kind::Udt(_) => bail!("User-defined type not supported"), + Kind::Unparsed(_) => bail!("Unparsed type not supported"), + } +} diff --git a/src/daft-connect/src/translation/expr.rs b/src/daft-connect/src/translation/expr.rs index bcbadf9737..f5307fae9d 100644 --- a/src/daft-connect/src/translation/expr.rs +++ b/src/daft-connect/src/translation/expr.rs @@ -1,11 +1,18 @@ use std::sync::Arc; use eyre::{bail, Context}; -use spark_connect::{expression as spark_expr, Expression}; +use spark_connect::{ + expression as spark_expr, + expression::{ + cast::{CastToType, EvalMode}, + sort_order::{NullOrdering, SortDirection}, + }, + Expression, +}; use tracing::warn; use unresolved_function::unresolved_to_daft_expr; -use crate::translation::to_daft_literal; +use crate::translation::{to_daft_datatype, to_daft_literal}; mod unresolved_function; @@ -69,11 +76,64 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result Ok(child.alias(name)) } - spark_expr::ExprType::Cast(_) => bail!("Cast expressions not yet supported"), + spark_expr::ExprType::Cast(c) => { + // Cast { expr: Some(Expression { common: None, expr_type: Some(UnresolvedAttribute(UnresolvedAttribute { unparsed_identifier: "id", plan_id: None, is_metadata_column: None })) }), eval_mode: Unspecified, cast_to_type: Some(Type(DataType { kind: Some(String(String { type_variation_reference: 0, collation: "" })) })) } + // thread 'tokio-runtime-worker' panicked at 
src/daft-connect/src/trans + println!("got cast {c:?}"); + let spark_expr::Cast { + expr, + eval_mode, + cast_to_type, + } = &**c; + + let Some(expr) = expr else { + bail!("Cast expression is required"); + }; + + let expr = to_daft_expr(expr)?; + + let Some(cast_to_type) = cast_to_type else { + bail!("Cast to type is required"); + }; + + let data_type = match cast_to_type { + CastToType::Type(kind) => to_daft_datatype(kind).wrap_err_with(|| { + format!("Failed to convert spark datatype to daft datatype: {kind:?}") + })?, + CastToType::TypeStr(s) => { + bail!("Cast to type string not yet supported; tried to cast to {s}"); + } + }; + + let eval_mode = EvalMode::try_from(*eval_mode) + .wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?; + + warn!("Ignoring cast eval mode: {eval_mode:?}"); + + Ok(expr.cast(&data_type)) + } spark_expr::ExprType::UnresolvedRegex(_) => { bail!("Unresolved regex expressions not yet supported") } - spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"), + spark_expr::ExprType::SortOrder(s) => { + let spark_expr::SortOrder { + child, + direction, + null_ordering, + } = &**s; + + let Some(_child) = child else { + bail!("Sort order child is required"); + }; + + let _sort_direction = SortDirection::try_from(*direction) + .wrap_err_with(|| format!("Invalid sort direction: {direction}"))?; + + let _sort_nulls = NullOrdering::try_from(*null_ordering) + .wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?; + + bail!("Sort order expressions not yet supported"); + } spark_expr::ExprType::LambdaFunction(_) => { bail!("Lambda function expressions not yet supported") } diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs index ffb8c802ce..af70a1e8f2 100644 --- a/src/daft-connect/src/translation/expr/unresolved_function.rs +++ b/src/daft-connect/src/translation/expr/unresolved_function.rs @@ -24,10 +24,51 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result handle_count(arguments).wrap_err("Failed to handle count function"), + "<" => handle_binary_op(arguments, daft_dsl::Operator::Lt) + .wrap_err("Failed to handle < function"), + ">" => handle_binary_op(arguments, daft_dsl::Operator::Gt) + .wrap_err("Failed to handle > function"), + "<=" => handle_binary_op(arguments, daft_dsl::Operator::LtEq) + .wrap_err("Failed to handle <= function"), + ">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq) + .wrap_err("Failed to handle >= function"), + "%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus) + .wrap_err("Failed to handle % function"), + "sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"), + "isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"), + "isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"), n => bail!("Unresolved function {n} not yet supported"), } } +pub fn handle_sum(arguments: Vec) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + Ok(arg.sum()) +} + +pub fn handle_binary_op( + arguments: Vec, + op: daft_dsl::Operator, +) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 2] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly two arguments; got 
{arguments:?}"); + } + }; + + let [left, right] = arguments; + + Ok(daft_dsl::binary_op(op, left, right)) +} + pub fn handle_count(arguments: Vec) -> eyre::Result { let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { Ok(arguments) => arguments, @@ -42,3 +83,29 @@ pub fn handle_count(arguments: Vec) -> eyre::Result) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + + Ok(arg.is_null()) +} + +pub fn handle_isnotnull(arguments: Vec) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + + Ok(arg.not_null()) +} diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 93c9e9bd4a..53a0cfc923 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -3,11 +3,16 @@ use eyre::{bail, Context}; use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; -use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; +use crate::translation::logical_plan::{ + aggregate::aggregate, project::project, range::range, set_op::set_op, + with_columns::with_columns, +}; mod aggregate; mod project; mod range; +mod set_op; +mod with_columns; pub fn to_logical_plan(relation: Relation) -> eyre::Result { if let Some(common) = relation.common { @@ -25,6 +30,10 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Aggregate(a) => { aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") } + RelType::WithColumns(w) => { + with_columns(*w).wrap_err("Failed to apply with_columns to logical plan") + } + RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/set_op.rs b/src/daft-connect/src/translation/logical_plan/set_op.rs new file mode 100644 index 0000000000..7dfeff9650 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/set_op.rs @@ -0,0 +1,57 @@ +use eyre::{bail, Context}; +use spark_connect::set_operation::SetOpType; +use tracing::warn; + +use crate::translation::to_logical_plan; + +pub fn set_op( + set_op: spark_connect::SetOperation, +) -> eyre::Result { + let spark_connect::SetOperation { + left_input, + right_input, + set_op_type, + is_all, + by_name, + allow_missing_columns, + } = set_op; + + let Some(left_input) = left_input else { + bail!("Left input is required"); + }; + + let Some(right_input) = right_input else { + bail!("Right input is required"); + }; + + let set_op = SetOpType::try_from(set_op_type) + .wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?; + + if let Some(by_name) = by_name { + warn!("Ignoring by_name: {by_name}"); + } + + if let Some(allow_missing_columns) = allow_missing_columns { + warn!("Ignoring allow_missing_columns: {allow_missing_columns}"); + } + + let left = to_logical_plan(*left_input)?; + let right = to_logical_plan(*right_input)?; + + let is_all = is_all.unwrap_or(false); + + match set_op { + SetOpType::Unspecified => { + bail!("Unspecified set operation is not supported"); + } + SetOpType::Intersect => left + 
.intersect(&right, is_all) + .wrap_err("Failed to apply intersect to logical plan"), + SetOpType::Union => left + .union(&right, is_all) + .wrap_err("Failed to apply union to logical plan"), + SetOpType::Except => { + bail!("Except set operation is not supported"); + } + } +} diff --git a/src/daft-connect/src/translation/logical_plan/with_columns.rs b/src/daft-connect/src/translation/logical_plan/with_columns.rs new file mode 100644 index 0000000000..2ab9424a72 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/with_columns.rs @@ -0,0 +1,32 @@ +use eyre::bail; +use spark_connect::{expression::ExprType, Expression}; + +use crate::translation::{to_daft_expr, to_logical_plan}; + +pub fn with_columns( + with_columns: spark_connect::WithColumns, +) -> eyre::Result { + let spark_connect::WithColumns { input, aliases } = with_columns; + + let Some(input) = input else { + bail!("input is required"); + }; + + let plan = to_logical_plan(*input)?; + + let daft_exprs: Vec<_> = aliases + .into_iter() + .map(|alias| { + let expression = Expression { + common: None, + expr_type: Some(ExprType::Alias(Box::new(alias))), + }; + + to_daft_expr(&expression) + }) + .try_collect()?; + + let plan = plan.with_columns(daft_exprs)?; + + Ok(plan) +} diff --git a/tests/connect/test_basic_column.py b/tests/connect/test_basic_column.py new file mode 100644 index 0000000000..fefb41eb98 --- /dev/null +++ b/tests/connect/test_basic_column.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from pyspark.sql.functions import col +from pyspark.sql.types import StringType + + +def test_column_operations(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Test __getattr__ + df_attr = df.select(col("id").desc()) # Fix: call desc() as method + assert df_attr.toPandas()["id"].iloc[0] == 9, "desc should sort in descending order" + + # Test __getitem__ + # df_item = df.select(col("id")[0]) + # assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element" + + # Test alias + df_alias = df.select(col("id").alias("my_number")) + assert "my_number" in df_alias.columns, "alias should rename column" + assert df_alias.toPandas()["my_number"].equals(df.toPandas()["id"]), "data should be unchanged" + + # Test cast + df_cast = df.select(col("id").cast(StringType())) + assert df_cast.schema.fields[0].dataType == StringType(), "cast should change data type" + + # Test isNotNull/isNull + df_null = df.select(col("id").isNotNull().alias("not_null"), col("id").isNull().alias("is_null")) + assert df_null.toPandas()["not_null"].iloc[0] == True, "isNotNull should be True for non-null values" + assert df_null.toPandas()["is_null"].iloc[0] == False, "isNull should be False for non-null values" + + # Test name + df_name = df.select(col("id").name("renamed_id")) + assert "renamed_id" in df_name.columns, "name should rename column" + assert df_name.toPandas()["renamed_id"].equals(df.toPandas()["id"]), "data should be unchanged" diff --git a/tests/connect/test_explain.py b/tests/connect/test_explain.py new file mode 100644 index 0000000000..3b5574dec3 --- /dev/null +++ b/tests/connect/test_explain.py @@ -0,0 +1,16 @@ +from __future__ import annotations + + +def test_explain(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Union the two ranges + unioned = range1.union(range2) + + # Get the explain plan + 
unioned.explain(extended=True)
+
+    # explain() prints the plan and returns None, so this just smoke-tests that the
+    # Explain analyze request round-trips through the Daft Connect server
diff --git a/tests/connect/test_group_by.py b/tests/connect/test_group_by.py
new file mode 100644
index 0000000000..40efbb20c6
--- /dev/null
+++ b/tests/connect/test_group_by.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from pyspark.sql.functions import col
+
+
+def test_group_by(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Add a column that will have repeated values for grouping
+    df = df.withColumn("group", col("id") % 3)
+
+    # Group by the new column and sum the ids in each group
+    df_grouped = df.groupBy("group").sum("id")
+
+    # Convert to pandas to verify the sums
+    df_grouped_pandas = df_grouped.toPandas()
+
+    print(df_grouped_pandas)
+
+    # Sort by group to ensure consistent order for comparison
+    df_grouped_pandas = df_grouped_pandas.sort_values("group").reset_index(drop=True)
+
+    # Verify the expected sums for each group:
+    #   group 0 -> 0 + 3 + 6 + 9 = 18
+    #   group 1 -> 1 + 4 + 7     = 12
+    #   group 2 -> 2 + 5 + 8     = 15
+    # (the unsorted pandas output printed above may list the groups in a different order)
+    expected = {
+        "group": [0, 1, 2],
+        "id": [18, 12, 15],  # todo(correctness): should this be "id" for value here?
+    }
+
+    assert df_grouped_pandas["group"].tolist() == expected["group"]
+    assert df_grouped_pandas["id"].tolist() == expected["id"]
diff --git a/tests/connect/test_intersection.py b/tests/connect/test_intersection.py
new file mode 100644
index 0000000000..7944de5cae
--- /dev/null
+++ b/tests/connect/test_intersection.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+
+def test_intersection(spark_session):
+    # Create ranges using Spark - with overlap
+    range1 = spark_session.range(7)  # Creates DataFrame with numbers 0 to 6
+    range2 = spark_session.range(3, 10)  # Creates DataFrame with numbers 3 to 9
+
+    # Intersect the two ranges
+    intersected = range1.intersect(range2)
+
+    # Collect results
+    results = intersected.collect()
+
+    # Verify the DataFrame has expected values
+    # Intersection should only include overlapping values once
+    assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)"
+
+    # Check that all expected values are present
+    values = [row.id for row in results]
+    assert sorted(values) == [3, 4, 5, 6], "Values should match expected overlapping sequence"
diff --git a/tests/connect/test_union.py b/tests/connect/test_union.py
new file mode 100644
index 0000000000..9ac235d9e5
--- /dev/null
+++ b/tests/connect/test_union.py
@@ -0,0 +1,21 @@
+from __future__ import annotations
+
+
+def test_union(spark_session):
+    # Create ranges using Spark - with overlap
+    range1 = spark_session.range(7)  # Creates DataFrame with numbers 0 to 6
+    range2 = spark_session.range(3, 10)  # Creates DataFrame with numbers 3 to 9
+
+    # Union the two ranges
+    unioned = range1.union(range2)
+
+    # Collect results
+    results = unioned.collect()
+
+    # Verify the DataFrame has expected values
+    # Union keeps duplicates, so the length should be the sum of both range sizes
+    assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)"
+
+    # Check that all expected values are present, including duplicates
+    values = [row.id for row in results]
+    assert sorted(values) == [0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 8, 9], "Values should match expected sequence with duplicates"
diff --git a/tests/connect/test_with_column.py b/tests/connect/test_with_column.py
new file mode 100644
index 0000000000..ad237339b2
--- /dev/null
+++ b/tests/connect/test_with_column.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from pyspark.sql.functions import col
+
+
+def test_with_column(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Add a new boolean column indicating whether id > 2
+    df_with_col = df.withColumn("double_id", col("id") > 2)
+
+    # Verify the schema has both columns
+    assert "id" in df_with_col.schema.names, "Original column should still exist"
+    assert "double_id" in df_with_col.schema.names, "New column should be added"
+
+    # Verify the data is correct
+    df_pandas = df_with_col.toPandas()
+    assert (df_pandas["double_id"] == (df_pandas["id"] > 2)).all(), "New column should match the id > 2 comparison"
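
Taken together, the new Explain analyze handling, the expression translations (cast, comparison operators, isnull/isnotnull, sum) and the WithColumns/SetOperation relations cover client sessions roughly like the sketch below. This is an illustrative PySpark walkthrough that mirrors the tests above, not part of the diff: the `spark_session` fixture, the `demo_connect_features` name, and end-to-end behaviour against the Daft Connect server are assumptions, and the sketch deliberately avoids sort-order expressions, which the translator above still rejects.

from pyspark.sql.functions import col
from pyspark.sql.types import StringType


def demo_connect_features(spark_session):  # hypothetical walkthrough, not a test in this diff
    # WithColumns relation plus the '%' binary operator
    df = spark_session.range(10).withColumn("bucket", col("id") % 3)

    # Cast, null checks and comparisons go through the new expression translation
    projected = df.select(
        col("id").cast(StringType()).alias("id_str"),
        col("id").isNotNull().alias("present"),
        (col("id") > 2).alias("gt_two"),
    )
    print(projected.toPandas())

    # Aggregate relation: groupBy + sum maps onto the new handle_sum path
    print(df.groupBy("bucket").sum("id").toPandas())

    # SetOperation relation: union keeps duplicates, intersect deduplicates the overlap
    left, right = spark_session.range(7), spark_session.range(3, 10)
    print(len(left.union(right).collect()), len(left.intersect(right).collect()))

    # AnalyzePlanRequest Explain: prints Daft's optimized logical plan
    left.union(right).explain(extended=True)

If any of these paths regress, the corresponding test above (test_basic_column, test_group_by, test_union, test_intersection, test_with_column, test_explain) is the place to look first.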