Skip to content

Commit

Permalink
fix: allow duplicate field names in table join, fix output with dupli…
Browse files Browse the repository at this point in the history
…cated names (#1023)

* fix: allow duplicate field names in table join

* move join related code into join_utils.rs
  • Loading branch information
QP Hou authored Sep 20, 2021
1 parent 843cd93 commit 65483d3
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 212 deletions.
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::record_batch::RecordBatch;
use futures::{Stream, TryStreamExt};

use super::{
coalesce_partitions::CoalescePartitionsExec, hash_utils::check_join_is_valid,
coalesce_partitions::CoalescePartitionsExec, join_utils::check_join_is_valid,
ColumnStatistics, Statistics,
};
use crate::{
Expand Down
142 changes: 75 additions & 67 deletions datafusion/src/physical_plan/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use hashbrown::raw::RawTable;

use super::{
coalesce_partitions::CoalescePartitionsExec,
hash_utils::{build_join_schema, check_join_is_valid, JoinOn},
join_utils::{build_join_schema, check_join_is_valid, ColumnIndex, JoinOn, JoinSide},
};
use super::{
expressions::Column,
Expand Down Expand Up @@ -115,6 +115,8 @@ pub struct HashJoinExec {
mode: PartitionMode,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
}

/// Metrics for HashJoinExec
Expand Down Expand Up @@ -165,14 +167,6 @@ pub enum PartitionMode {
CollectLeft,
}

/// Describes where one output column of the join is taken from: which input
/// side (left or right) and the position of the column within that side.
struct ColumnIndex {
/// Position of the column within the batch of the side it is taken from
index: usize,
/// `true` when the column is taken from the left input, `false` when it is taken from the right input
is_left: bool,
}

impl HashJoinExec {
/// Tries to create a new [HashJoinExec].
/// # Error
Expand All @@ -188,7 +182,8 @@ impl HashJoinExec {
let right_schema = right.schema();
check_join_is_valid(&left_schema, &right_schema, &on)?;

let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type));
let (schema, column_indices) =
build_join_schema(&left_schema, &right_schema, join_type);

let random_state = RandomState::with_seeds(0, 0, 0, 0);

Expand All @@ -197,11 +192,12 @@ impl HashJoinExec {
right,
on,
join_type: *join_type,
schema,
schema: Arc::new(schema),
build_side: Arc::new(Mutex::new(None)),
random_state,
mode: partition_mode,
metrics: ExecutionPlanMetricsSet::new(),
column_indices,
})
}

Expand Down Expand Up @@ -229,38 +225,6 @@ impl HashJoinExec {
pub fn partition_mode(&self) -> &PartitionMode {
&self.mode
}

/// Resolves, for every field of the join's output schema, which input side
/// the column is read from and at which position within that side.
///
/// Fields are looked up by name: first in the "primary" input schema (the
/// right input for `JoinType::Right`, the left input for every other join
/// type) and, if absent there, in the other input schema.
///
/// # Errors
/// Returns an internal DataFusion error (converted into an Arrow external
/// error) when an output field name is found in neither input schema.
fn column_indices_from_schema(&self) -> ArrowResult<Vec<ColumnIndex>> {
    let (primary_is_left, primary_schema, secondary_schema) = match self.join_type {
        JoinType::Right => (false, self.right.schema(), self.left.schema()),
        JoinType::Inner
        | JoinType::Left
        | JoinType::Full
        | JoinType::Semi
        | JoinType::Anti => (true, self.left.schema(), self.right.schema()),
    };

    self.schema
        .fields()
        .iter()
        .map(|field| {
            let name = field.name();
            // The primary side wins when both inputs carry the same name.
            let (is_primary, index) = if let Ok(i) = primary_schema.index_of(name) {
                (true, i)
            } else if let Ok(i) = secondary_schema.index_of(name) {
                (false, i)
            } else {
                return Err(DataFusionError::Internal(format!(
                    "During execution, the column {} was not found in neither the left or right side of the join",
                    name
                ))
                .into_arrow_external_error());
            };

            // The column is on the left exactly when "found in primary" and
            // "primary is the left input" agree.
            Ok(ColumnIndex {
                index,
                is_left: is_primary == primary_is_left,
            })
        })
        .collect()
}
}

#[async_trait]
Expand Down Expand Up @@ -421,7 +385,6 @@ impl ExecutionPlan for HashJoinExec {
let right_stream = self.right.execute(partition).await?;
let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();

let column_indices = self.column_indices_from_schema()?;
let num_rows = left_data.1.num_rows();
let visited_left_side = match self.join_type {
JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
Expand All @@ -436,7 +399,7 @@ impl ExecutionPlan for HashJoinExec {
self.join_type,
left_data,
right_stream,
column_indices,
self.column_indices.clone(),
self.random_state.clone(),
visited_left_side,
HashJoinMetrics::new(partition, &self.metrics),
Expand Down Expand Up @@ -522,8 +485,6 @@ struct HashJoinStream {
left_data: JoinLeftData,
/// right
right: SendableRecordBatchStream,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
/// Random state used for hashing initialization
random_state: RandomState,
/// Keeps track of the left side rows whether they are visited
Expand All @@ -532,6 +493,8 @@ struct HashJoinStream {
is_exhausted: bool,
/// Metrics
join_metrics: HashJoinMetrics,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -589,12 +552,15 @@ fn build_batch_from_indices(
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());

for column_index in column_indices {
let array = if column_index.is_left {
let array = left.column(column_index.index);
compute::take(array.as_ref(), &left_indices, None)?
} else {
let array = right.column(column_index.index);
compute::take(array.as_ref(), &right_indices, None)?
let array = match column_index.side {
JoinSide::Left => {
let array = left.column(column_index.index);
compute::take(array.as_ref(), &left_indices, None)?
}
JoinSide::Right => {
let array = right.column(column_index.index);
compute::take(array.as_ref(), &right_indices, None)?
}
};
columns.push(array);
}
Expand Down Expand Up @@ -861,12 +827,15 @@ fn produce_from_matched(
let num_rows = indices.len();
let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
for (idx, column_index) in column_indices.iter().enumerate() {
let array = if column_index.is_left {
let array = left_data.1.column(column_index.index);
compute::take(array.as_ref(), &indices, None).unwrap()
} else {
let datatype = schema.field(idx).data_type();
arrow::array::new_null_array(datatype, num_rows)
let array = match column_index.side {
JoinSide::Left => {
let array = left_data.1.column(column_index.index);
compute::take(array.as_ref(), &indices, None).unwrap()
}
JoinSide::Right => {
let datatype = schema.field(idx).data_type();
arrow::array::new_null_array(datatype, num_rows)
}
};

columns.push(array);
Expand Down Expand Up @@ -1375,7 +1344,7 @@ mod tests {
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"| 3 | 7 | 9 | | 7 | |",
"| 3 | 7 | 9 | | | |",
"+----+----+----+----+----+----+",
];

Expand Down Expand Up @@ -1451,9 +1420,9 @@ mod tests {
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b1 | c2 |",
"+----+----+----+----+----+----+",
"| 1 | 4 | 7 | | 4 | |",
"| 2 | 5 | 8 | | 5 | |",
"| 3 | 7 | 9 | | 7 | |",
"| 1 | 4 | 7 | | | |",
"| 2 | 5 | 8 | | | |",
"| 3 | 7 | 9 | | | |",
"+----+----+----+----+----+----+",
];

Expand Down Expand Up @@ -1523,7 +1492,7 @@ mod tests {
"+----+----+----+----+----+----+",
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"| 3 | 7 | 9 | | 7 | |",
"| 3 | 7 | 9 | | | |",
"+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
Expand Down Expand Up @@ -1563,7 +1532,7 @@ mod tests {
"+----+----+----+----+----+----+",
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"| 3 | 7 | 9 | | 7 | |",
"| 3 | 7 | 9 | | | |",
"+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
Expand Down Expand Up @@ -1672,7 +1641,7 @@ mod tests {
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b1 | c2 |",
"+----+----+----+----+----+----+",
"| | 6 | | 30 | 6 | 90 |",
"| | | | 30 | 6 | 90 |",
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"+----+----+----+----+----+----+",
Expand Down Expand Up @@ -1709,7 +1678,7 @@ mod tests {
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b1 | c2 |",
"+----+----+----+----+----+----+",
"| | 6 | | 30 | 6 | 90 |",
"| | | | 30 | 6 | 90 |",
"| 1 | 4 | 7 | 10 | 4 | 70 |",
"| 2 | 5 | 8 | 20 | 5 | 80 |",
"+----+----+----+----+----+----+",
Expand Down Expand Up @@ -1808,4 +1777,43 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn join_with_duplicated_column_names() -> Result<()> {
    // Both inputs use the same three column names, so the joined output
    // contains every name twice.
    let left_input = build_table(
        ("a", &vec![1, 2, 3]),
        ("b", &vec![4, 5, 7]),
        ("c", &vec![7, 8, 9]),
    );
    let right_input = build_table(
        ("a", &vec![10, 20, 30]),
        ("b", &vec![1, 2, 7]),
        ("c", &vec![70, 80, 90]),
    );

    // join on a=b so there are duplicate column names on unjoined columns
    let left_key = Column::new_with_schema("a", &left_input.schema()).unwrap();
    let right_key = Column::new_with_schema("b", &right_input.schema()).unwrap();
    let plan = join(
        left_input,
        right_input,
        vec![(left_key, right_key)],
        &JoinType::Inner,
    )?;

    // The output schema keeps all left columns followed by all right columns.
    assert_eq!(
        columns(&plan.schema()),
        vec!["a", "b", "c", "a", "b", "c"]
    );

    let batches = common::collect(plan.execute(0).await?).await?;

    // Only left.a = 1 and left.a = 2 have a matching right.b value.
    let expected = vec![
        "+---+---+---+----+---+----+",
        "| a | b | c | a  | b | c  |",
        "+---+---+---+----+---+----+",
        "| 1 | 4 | 7 | 10 | 1 | 70 |",
        "| 2 | 5 | 8 | 20 | 2 | 80 |",
        "+---+---+---+----+---+----+",
    ];
    assert_batches_sorted_eq!(expected, &batches);

    Ok(())
}
}
Loading

0 comments on commit 65483d3

Please sign in to comment.