-
Notifications
You must be signed in to change notification settings - Fork 171
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3394a66
commit 17f65ec
Showing
3 changed files
with
108 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
use daft_logical_plan::LogicalPlanBuilder; | ||
use eyre::{bail, WrapErr}; | ||
use spark_connect::join::JoinType; | ||
use tracing::warn; | ||
|
||
use crate::translation::{to_daft_expr, to_logical_plan}; | ||
|
||
pub fn join(join: spark_connect::Join) -> eyre::Result<LogicalPlanBuilder> { | ||
let spark_connect::Join { | ||
left, | ||
right, | ||
join_condition, | ||
join_type, | ||
using_columns, | ||
join_data_type, | ||
} = join; | ||
|
||
let Some(left) = left else { | ||
bail!("Left side of join is required"); | ||
}; | ||
|
||
let Some(right) = right else { | ||
bail!("Right side of join is required"); | ||
}; | ||
|
||
if let Some(join_condition) = join_condition { | ||
bail!("Join conditions are not yet supported; use using_columns (join keys) instead; got {join_condition:?}"); | ||
} | ||
|
||
let join_type = JoinType::try_from(join_type) | ||
.wrap_err_with(|| format!("Invalid join type: {join_type:?}"))?; | ||
|
||
let join_type = to_daft_join_type(join_type)?; | ||
|
||
let using_columns_exprs: Vec<_> = using_columns | ||
.iter() | ||
.map(|s| daft_dsl::col(s.as_str())) | ||
.collect(); | ||
|
||
if let Some(join_data_type) = join_data_type { | ||
warn!("Ignoring join data type {join_data_type:?} for join; not yet implemented"); | ||
} | ||
|
||
let left = to_logical_plan(*left)?; | ||
let right = to_logical_plan(*right)?; | ||
|
||
Ok(left.join( | ||
&right, | ||
// join_conditions.clone(), // todo(correctness): is this correct? | ||
// join_conditions, // todo(correctness): is this correct? | ||
using_columns_exprs.clone(), | ||
using_columns_exprs, | ||
join_type, | ||
None, | ||
None, | ||
None, | ||
false, // todo(correctness): we want join keys or not | ||
)?) | ||
} | ||
|
||
fn to_daft_join_type(join_type: JoinType) -> eyre::Result<daft_core::join::JoinType> { | ||
match join_type { | ||
JoinType::Unspecified => { | ||
bail!("Join type must be specified; got Unspecified") | ||
} | ||
JoinType::Inner => Ok(daft_core::join::JoinType::Inner), | ||
JoinType::FullOuter => { | ||
bail!("Full outer joins not yet supported") // todo(completeness): add support for full outer joins if it is not already implemented | ||
} | ||
JoinType::LeftOuter => Ok(daft_core::join::JoinType::Left), // todo(correctness): is this correct? | ||
JoinType::RightOuter => Ok(daft_core::join::JoinType::Right), | ||
JoinType::LeftAnti => Ok(daft_core::join::JoinType::Anti), // todo(correctness): is this correct? | ||
JoinType::LeftSemi => bail!("Left semi joins not yet supported"), // todo(completeness): add support for left semi joins if it is not already implemented | ||
JoinType::Cross => bail!("Cross joins not yet supported"), // todo(completeness): add support for cross joins if it is not already implemented | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from __future__ import annotations | ||
|
||
from pyspark.sql.functions import col | ||
|
||
|
||
def test_join(spark_session): | ||
# Create two DataFrames with overlapping IDs | ||
df1 = spark_session.range(5) | ||
df2 = spark_session.range(3, 7) | ||
|
||
# Perform inner join on 'id' column | ||
joined_df = df1.join(df2, "id", "inner") | ||
|
||
# Verify join results | ||
# assert joined_df.count() == 2, "Inner join should only return matching rows" | ||
|
||
# Convert to pandas to verify exact results | ||
joined_pandas = joined_df.toPandas() | ||
assert set(joined_pandas["id"].tolist()) == {3, 4}, "Inner join should only contain IDs 3 and 4" | ||
|
||
# Test left outer join | ||
left_joined_pandas = df1.join(df2, "id", "left").toPandas() | ||
assert set(left_joined_pandas["id"].tolist()) == {0, 1, 2, 3, 4}, "Left join should keep all rows from left DataFrame" | ||
|
||
# Test right outer join | ||
right_joined_pandas = df1.join(df2, "id", "right").toPandas() | ||
assert set(right_joined_pandas["id"].tolist()) == {3, 4, 5, 6}, "Right join should keep all rows from right DataFrame" |