-
Notifications
You must be signed in to change notification settings - Fork 171
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bb9a07f
commit 2c4cb41
Showing
3 changed files
with
156 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
use daft_logical_plan::LogicalPlanBuilder; | ||
use eyre::{bail, WrapErr}; | ||
use spark_connect::join::JoinType; | ||
use tracing::warn; | ||
|
||
use crate::translation::to_logical_plan; | ||
|
||
pub fn join(join: spark_connect::Join) -> eyre::Result<LogicalPlanBuilder> { | ||
let spark_connect::Join { | ||
left, | ||
right, | ||
join_condition, | ||
join_type, | ||
using_columns, | ||
join_data_type, | ||
} = join; | ||
|
||
let Some(left) = left else { | ||
bail!("Left side of join is required"); | ||
}; | ||
|
||
let Some(right) = right else { | ||
bail!("Right side of join is required"); | ||
}; | ||
|
||
if let Some(join_condition) = join_condition { | ||
bail!("Join conditions are not yet supported; use using_columns (join keys) instead; got {join_condition:?}"); | ||
} | ||
|
||
let join_type = JoinType::try_from(join_type) | ||
.wrap_err_with(|| format!("Invalid join type: {join_type:?}"))?; | ||
|
||
let join_type = to_daft_join_type(join_type)?; | ||
|
||
let using_columns_exprs: Vec<_> = using_columns | ||
.iter() | ||
.map(|s| daft_dsl::col(s.as_str())) | ||
.collect(); | ||
|
||
if let Some(join_data_type) = join_data_type { | ||
warn!("Ignoring join data type {join_data_type:?} for join; not yet implemented"); | ||
} | ||
|
||
let left = to_logical_plan(*left)?; | ||
let right = to_logical_plan(*right)?; | ||
|
||
let result = match join_type { | ||
JoinTypeInfo::Cross => { | ||
left.cross_join(&right, None, None)? // todo(correctness): is this correct? | ||
} | ||
JoinTypeInfo::Regular(join_type) => { | ||
left.join( | ||
&right, | ||
// join_conditions.clone(), // todo(correctness): is this correct? | ||
// join_conditions, // todo(correctness): is this correct? | ||
using_columns_exprs.clone(), | ||
using_columns_exprs, | ||
join_type, | ||
None, | ||
None, | ||
None, | ||
false, // todo(correctness): we want join keys or not | ||
)? | ||
} | ||
}; | ||
|
||
Ok(result) | ||
} | ||
|
||
enum JoinTypeInfo { | ||
Regular(daft_core::join::JoinType), | ||
Cross, | ||
} | ||
|
||
impl From<daft_logical_plan::JoinType> for JoinTypeInfo { | ||
fn from(join_type: daft_logical_plan::JoinType) -> Self { | ||
JoinTypeInfo::Regular(join_type) | ||
} | ||
} | ||
|
||
fn to_daft_join_type(join_type: JoinType) -> eyre::Result<JoinTypeInfo> { | ||
match join_type { | ||
JoinType::Unspecified => { | ||
bail!("Join type must be specified; got Unspecified") | ||
} | ||
JoinType::Inner => Ok(daft_core::join::JoinType::Inner.into()), | ||
JoinType::FullOuter => { | ||
bail!("Full outer joins not yet supported") // todo(completeness): add support for full outer joins if it is not already implemented | ||
} | ||
JoinType::LeftOuter => Ok(daft_core::join::JoinType::Left.into()), // todo(correctness): is this correct? | ||
JoinType::RightOuter => Ok(daft_core::join::JoinType::Right.into()), | ||
JoinType::LeftAnti => Ok(daft_core::join::JoinType::Anti.into()), // todo(correctness): is this correct? | ||
JoinType::LeftSemi => bail!("Left semi joins not yet supported"), // todo(completeness): add support for left semi joins if it is not already implemented | ||
JoinType::Cross => Ok(JoinTypeInfo::Cross), | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
|
||
from pyspark.sql.functions import col | ||
|
||
|
||
def test_join(spark_session): | ||
# Create two DataFrames with overlapping IDs | ||
df1 = spark_session.range(5) | ||
df2 = spark_session.range(3, 7) | ||
|
||
# Perform inner join on 'id' column | ||
joined_df = df1.join(df2, "id", "inner") | ||
|
||
# Verify join results using collect() | ||
joined_ids = {row.id for row in joined_df.select("id").collect()} | ||
assert joined_ids == {3, 4}, "Inner join should only contain IDs 3 and 4" | ||
|
||
# Test left outer join | ||
left_joined_df = df1.join(df2, "id", "left") | ||
left_joined_ids = {row.id for row in left_joined_df.select("id").collect()} | ||
assert left_joined_ids == {0, 1, 2, 3, 4}, "Left join should keep all rows from left DataFrame" | ||
|
||
# Test right outer join | ||
right_joined_df = df1.join(df2, "id", "right") | ||
right_joined_ids = {row.id for row in right_joined_df.select("id").collect()} | ||
assert right_joined_ids == {3, 4, 5, 6}, "Right join should keep all rows from right DataFrame" | ||
|
||
|
||
|
||
def test_cross_join(spark_session): | ||
# Create two small DataFrames to demonstrate cross join | ||
# df_left: [0, 1] | ||
# df_right: [10, 11] | ||
# Expected result will be all combinations: | ||
# id1 id2 | ||
# 0 10 | ||
# 0 11 | ||
# 1 10 | ||
# 1 11 | ||
df_left = spark_session.range(2) | ||
df_right = spark_session.range(10, 12).withColumnRenamed("id", "id2") | ||
|
||
# Perform cross join - this creates cartesian product of both DataFrames | ||
cross_joined_df = df_left.crossJoin(df_right) | ||
|
||
# Convert to pandas for easier verification | ||
result_df = cross_joined_df.toPandas() | ||
|
||
# Verify we get all 4 combinations (2 x 2 = 4 rows) | ||
assert len(result_df) == 4, "Cross join should produce 4 rows (2x2 cartesian product)" | ||
|
||
# Verify all expected combinations exist | ||
expected_combinations = {(0, 10), (0, 11), (1, 10), (1, 11)} | ||
actual_combinations = {(row["id"], row["id2"]) for _, row in result_df.iterrows()} | ||
assert actual_combinations == expected_combinations, "Cross join should contain all possible combinations" | ||
|
||
|