Skip to content

Commit

Permalink
Fix SQL planner bug when resolving columns with same name as a relati…
Browse files Browse the repository at this point in the history
…on (#3003)

* repro unit test

* improve test

* fix

* make logic more robust

* add a simple test case

Co-authored-by: Wei-Ting Kuo <[email protected]>
  • Loading branch information
andygrove and waitingkuo authored Aug 2, 2022
1 parent 55a1286 commit e23925c
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 22 deletions.
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/intersection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async fn intersect_with_null_not_equal() {
INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2";

let expected = vec!["++", "++"];
let ctx = create_join_context_qualified().unwrap();
let ctx = create_join_context_qualified("t1", "t2").unwrap();
let actual = execute_to_batches(&ctx, sql).await;
assert_batches_eq!(expected, &actual);
}
Expand All @@ -41,7 +41,7 @@ async fn intersect_with_null_equal() {
"+-----+-----+",
];

let ctx = create_join_context_qualified().unwrap();
let ctx = create_join_context_qualified("t1", "t2").unwrap();
let actual = execute_to_batches(&ctx, sql).await;

assert_batches_eq!(expected, &actual);
Expand Down
17 changes: 14 additions & 3 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async fn equijoin() -> Result<()> {
assert_batches_eq!(expected, &actual);
}

let ctx = create_join_context_qualified()?;
let ctx = create_join_context_qualified("t1", "t2")?;
let equivalent_sql = [
"SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a",
"SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a",
Expand Down Expand Up @@ -890,13 +890,24 @@ async fn inner_join_qualified_names() -> Result<()> {
];

for sql in equivalent_sql.iter() {
let ctx = create_join_context_qualified()?;
let ctx = create_join_context_qualified("t1", "t2")?;
let actual = execute_to_batches(&ctx, sql).await;
assert_batches_eq!(expected, &actual);
}
Ok(())
}

#[tokio::test]
async fn issue_3002() -> Result<()> {
// repro case for https://github.com/apache/arrow-datafusion/issues/3002
let sql = "select a.a, b.b from a join b on a.a = b.b";
let expected = vec!["++", "++"];
let ctx = create_join_context_qualified("a", "b")?;
let actual = execute_to_batches(&ctx, sql).await;
assert_batches_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn inner_join_nulls() {
let sql = "SELECT * FROM (SELECT null AS id1) t1
Expand All @@ -908,7 +919,7 @@ async fn inner_join_nulls() {
"++",
];

let ctx = create_join_context_qualified().unwrap();
let ctx = create_join_context_qualified("t1", "t2").unwrap();
let actual = execute_to_batches(&ctx, sql).await;

// left and right shouldn't match anything
Expand Down
9 changes: 6 additions & 3 deletions datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ fn create_join_context(column_left: &str, column_right: &str) -> Result<SessionC
Ok(ctx)
}

fn create_join_context_qualified() -> Result<SessionContext> {
fn create_join_context_qualified(
left_name: &str,
right_name: &str,
) -> Result<SessionContext> {
let ctx = SessionContext::new();

let t1_schema = Arc::new(Schema::new(vec![
Expand All @@ -245,7 +248,7 @@ fn create_join_context_qualified() -> Result<SessionContext> {
],
)?;
let t1_table = MemTable::try_new(t1_schema, vec![vec![t1_data]])?;
ctx.register_table("t1", Arc::new(t1_table))?;
ctx.register_table(left_name, Arc::new(t1_table))?;

let t2_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, true),
Expand All @@ -261,7 +264,7 @@ fn create_join_context_qualified() -> Result<SessionContext> {
],
)?;
let t2_table = MemTable::try_new(t2_schema, vec![vec![t2_data]])?;
ctx.register_table("t2", Arc::new(t2_table))?;
ctx.register_table(right_name, Arc::new(t2_table))?;

Ok(ctx)
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/sql/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ async fn except_with_null_not_equal() {
"+-----+-----+",
];

let ctx = create_join_context_qualified().unwrap();
let ctx = create_join_context_qualified("t1", "t2").unwrap();
let actual = execute_to_batches(&ctx, sql).await;

assert_batches_eq!(expected, &actual);
Expand All @@ -325,7 +325,7 @@ async fn except_with_null_equal() {
EXCEPT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2";

let expected = vec!["++", "++"];
let ctx = create_join_context_qualified().unwrap();
let ctx = create_join_context_qualified("t1", "t2").unwrap();
let actual = execute_to_batches(&ctx, sql).await;

assert_batches_eq!(expected, &actual);
Expand Down
13 changes: 13 additions & 0 deletions datafusion/core/tests/sql/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,16 @@ fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) {
.collect();
assert_eq!(actual, expected);
}

#[tokio::test]
async fn paralleproject_column_with_same_name_as_relationl() -> Result<()> {
let ctx = SessionContext::new();

let sql = "select a.a from (select 1 as a) as a;";
let actual = execute_to_batches(&ctx, sql).await;

let expected = vec!["+---+", "| a |", "+---+", "| 1 |", "+---+"];
assert_batches_sorted_eq!(expected, &actual);

Ok(())
}
40 changes: 28 additions & 12 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1718,18 +1718,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
} else {
match (var_names.pop(), var_names.pop()) {
(Some(name), Some(relation)) if var_names.is_empty() => {
if let Some(field) = schema.fields().iter().find(|f| f.name().eq(&relation)) {
// Access to a field of a column which is a structure, example: SELECT my_struct.key
Ok(Expr::GetIndexedField {
expr: Box::new(Expr::Column(field.qualified_column())),
key: ScalarValue::Utf8(Some(name)),
})
} else {
// table.column identifier
Ok(Expr::Column(Column {
relation: Some(relation),
name,
}))
match schema.field_with_qualified_name(&relation, &name) {
Ok(_) => {
// found an exact match on a qualified name so this is a table.column identifier
Ok(Expr::Column(Column {
relation: Some(relation),
name,
}))
},
Err(e) => {
let search_term = format!(".{}.{}", relation, name);
if schema.field_names().iter().any(|name| name.as_str().ends_with(&search_term)) {
// this could probably be improved but here we handle the case
// where the qualifier is only a partial qualifier such as when
// referencing "t1.foo" when the available field is "public.t1.foo"
Ok(Expr::Column(Column {
relation: Some(relation),
name,
}))
} else if let Some(field) = schema.fields().iter().find(|f| f.name().eq(&relation)) {
// Access to a field of a column which is a structure, example: SELECT my_struct.key
Ok(Expr::GetIndexedField {
expr: Box::new(Expr::Column(field.qualified_column())),
key: ScalarValue::Utf8(Some(name)),
})
} else {
Err(e)
}
}
}
}
_ => Err(DataFusionError::NotImplemented(format!(
Expand Down

0 comments on commit e23925c

Please sign in to comment.