Skip to content

Commit

Permalink
[FEAT] connect: df.join
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 21, 2024
1 parent bb9a07f commit 2c4cb41
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, project::project, range::range, set_op::set_op,
aggregate::aggregate, join::join, project::project, range::range, set_op::set_op,
with_columns::with_columns, with_columns_renamed::with_columns_renamed,
};

mod aggregate;
mod join;
mod project;
mod range;
mod set_op;
Expand Down Expand Up @@ -37,6 +38,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
RelType::WithColumnsRenamed(w) => with_columns_renamed(*w)
.wrap_err("Failed to apply with_columns_renamed to logical plan"),
RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"),
RelType::Join(j) => join(*j).wrap_err("Failed to apply join to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
96 changes: 96 additions & 0 deletions src/daft-connect/src/translation/logical_plan/join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use daft_logical_plan::LogicalPlanBuilder;
use eyre::{bail, WrapErr};
use spark_connect::join::JoinType;
use tracing::warn;

use crate::translation::to_logical_plan;

pub fn join(join: spark_connect::Join) -> eyre::Result<LogicalPlanBuilder> {
let spark_connect::Join {
left,
right,
join_condition,
join_type,
using_columns,
join_data_type,
} = join;

let Some(left) = left else {
bail!("Left side of join is required");
};

let Some(right) = right else {
bail!("Right side of join is required");
};

if let Some(join_condition) = join_condition {
bail!("Join conditions are not yet supported; use using_columns (join keys) instead; got {join_condition:?}");
}

let join_type = JoinType::try_from(join_type)
.wrap_err_with(|| format!("Invalid join type: {join_type:?}"))?;

let join_type = to_daft_join_type(join_type)?;

let using_columns_exprs: Vec<_> = using_columns
.iter()
.map(|s| daft_dsl::col(s.as_str()))
.collect();

if let Some(join_data_type) = join_data_type {
warn!("Ignoring join data type {join_data_type:?} for join; not yet implemented");
}

let left = to_logical_plan(*left)?;
let right = to_logical_plan(*right)?;

let result = match join_type {
JoinTypeInfo::Cross => {
left.cross_join(&right, None, None)? // todo(correctness): is this correct?
}
JoinTypeInfo::Regular(join_type) => {
left.join(
&right,
// join_conditions.clone(), // todo(correctness): is this correct?
// join_conditions, // todo(correctness): is this correct?
using_columns_exprs.clone(),
using_columns_exprs,
join_type,
None,
None,
None,
false, // todo(correctness): we want join keys or not
)?
}
};

Ok(result)
}

/// Classification of a decoded Spark join type for plan construction:
/// cross joins take a dedicated `cross_join` code path, while all other
/// supported join types go through the regular key-based `join` builder.
enum JoinTypeInfo {
    // A key-based join carrying the Daft join type to use.
    Regular(daft_core::join::JoinType),
    // A cartesian-product join; has no join keys.
    Cross,
}

// NOTE(review): this impl takes `daft_logical_plan::JoinType` while the
// `Regular` variant stores `daft_core::join::JoinType` — presumably the former
// is a re-export of the latter; verify against the daft crates.
impl From<daft_logical_plan::JoinType> for JoinTypeInfo {
    fn from(join_type: daft_logical_plan::JoinType) -> Self {
        JoinTypeInfo::Regular(join_type)
    }
}

/// Maps a decoded Spark Connect join type onto Daft's join classification.
///
/// # Errors
/// Fails for `Unspecified` and for join types Daft does not yet support
/// (full outer, left semi).
fn to_daft_join_type(join_type: JoinType) -> eyre::Result<JoinTypeInfo> {
    use daft_core::join::JoinType as DaftJoinType;

    let info = match join_type {
        JoinType::Unspecified => {
            bail!("Join type must be specified; got Unspecified")
        }
        JoinType::Inner => DaftJoinType::Inner.into(),
        // todo(completeness): add support for full outer joins if it is not already implemented
        JoinType::FullOuter => bail!("Full outer joins not yet supported"),
        JoinType::LeftOuter => DaftJoinType::Left.into(), // todo(correctness): is this correct?
        JoinType::RightOuter => DaftJoinType::Right.into(),
        JoinType::LeftAnti => DaftJoinType::Anti.into(), // todo(correctness): is this correct?
        // todo(completeness): add support for left semi joins if it is not already implemented
        JoinType::LeftSemi => bail!("Left semi joins not yet supported"),
        JoinType::Cross => JoinTypeInfo::Cross,
    };

    Ok(info)
}
57 changes: 57 additions & 0 deletions tests/connect/test_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_join(spark_session):
    # Two ranges whose IDs overlap on {3, 4}:
    # left:  [0, 1, 2, 3, 4]
    # right: [3, 4, 5, 6]
    left_df = spark_session.range(5)
    right_df = spark_session.range(3, 7)

    # Inner join keeps only the overlapping IDs.
    inner_result = left_df.join(right_df, "id", "inner")
    inner_ids = {row.id for row in inner_result.select("id").collect()}
    assert inner_ids == {3, 4}, "Inner join should only contain IDs 3 and 4"

    # Left outer join preserves every row from the left side.
    left_result = left_df.join(right_df, "id", "left")
    left_ids = {row.id for row in left_result.select("id").collect()}
    assert left_ids == {0, 1, 2, 3, 4}, "Left join should keep all rows from left DataFrame"

    # Right outer join preserves every row from the right side.
    right_result = left_df.join(right_df, "id", "right")
    right_ids = {row.id for row in right_result.select("id").collect()}
    assert right_ids == {3, 4, 5, 6}, "Right join should keep all rows from right DataFrame"



def test_cross_join(spark_session):
    # Cross join of two 2-row frames yields the full 2x2 cartesian product:
    #   (0, 10), (0, 11), (1, 10), (1, 11)
    # The right frame's column is renamed to "id2" to avoid a name clash.
    lhs = spark_session.range(2)
    rhs = spark_session.range(10, 12).withColumnRenamed("id", "id2")

    # Materialize the cartesian product as pandas for verification.
    product = lhs.crossJoin(rhs).toPandas()

    # 2 rows x 2 rows -> 4 combinations.
    assert len(product) == 4, "Cross join should produce 4 rows (2x2 cartesian product)"

    # Every (id, id2) pairing must be present exactly as expected.
    expected_combinations = {(0, 10), (0, 11), (1, 10), (1, 11)}
    actual_combinations = {(row["id"], row["id2"]) for _, row in product.iterrows()}
    assert actual_combinations == expected_combinations, "Cross join should contain all possible combinations"


0 comments on commit 2c4cb41

Please sign in to comment.