Skip to content

Commit

Permalink
3307 add tinyint and unsigned int variants (#3359)
Browse files Browse the repository at this point in the history
  • Loading branch information
kmitchener authored Sep 6, 2022
1 parent 827cab9 commit 751cbc8
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 122 deletions.
69 changes: 69 additions & 0 deletions datafusion/core/tests/sql/cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::sql::execute_to_batches;
use arrow::datatypes::DataType;
use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

/// Runs `sql` against a newly created [`SessionContext`] and returns the
/// record batches produced by `execute_to_batches`.
async fn execute_sql(sql: &str) -> Vec<RecordBatch> {
    // Each call gets its own context, so no state is shared between queries.
    let session = SessionContext::new();
    let batches = execute_to_batches(&session, sql).await;
    batches
}

/// Asserts that `cast(10 as tinyint)` yields a first column typed `Int8`.
#[tokio::test]
async fn cast_tinyint() -> Result<()> {
    let batches = execute_sql("SELECT cast(10 as tinyint)").await;
    // Only the schema of the first batch is inspected; the value itself is not checked.
    let schema = batches[0].schema();
    let column_type = schema.field(0).data_type();
    assert_eq!(&DataType::Int8, column_type);
    Ok(())
}

/// Asserts that the `::tinyint` cast operator yields a first column typed `Int8`.
#[tokio::test]
async fn cast_tinyint_operator() -> Result<()> {
    let batches = execute_sql("SELECT 10::tinyint").await;
    // Only the schema of the first batch is inspected; the value itself is not checked.
    let schema = batches[0].schema();
    let column_type = schema.field(0).data_type();
    assert_eq!(&DataType::Int8, column_type);
    Ok(())
}

/// Asserts that `::tinyint unsigned` yields a first column typed `UInt8`.
#[tokio::test]
async fn cast_unsigned_tinyint() -> Result<()> {
    let batches = execute_sql("SELECT 10::tinyint unsigned").await;
    // Only the schema of the first batch is inspected; the value itself is not checked.
    let schema = batches[0].schema();
    let column_type = schema.field(0).data_type();
    assert_eq!(&DataType::UInt8, column_type);
    Ok(())
}

/// Asserts that `::smallint unsigned` yields a first column typed `UInt16`.
#[tokio::test]
async fn cast_unsigned_smallint() -> Result<()> {
    let batches = execute_sql("SELECT 10::smallint unsigned").await;
    // Only the schema of the first batch is inspected; the value itself is not checked.
    let schema = batches[0].schema();
    let column_type = schema.field(0).data_type();
    assert_eq!(&DataType::UInt16, column_type);
    Ok(())
}

/// Asserts that `::integer unsigned` yields a first column typed `UInt32`.
#[tokio::test]
async fn cast_unsigned_int() -> Result<()> {
    let batches = execute_sql("SELECT 10::integer unsigned").await;
    // Only the schema of the first batch is inspected; the value itself is not checked.
    let schema = batches[0].schema();
    let column_type = schema.field(0).data_type();
    assert_eq!(&DataType::UInt32, column_type);
    Ok(())
}

/// Asserts that `::bigint unsigned` yields a first column typed `UInt64`.
#[tokio::test]
async fn cast_unsigned_bigint() -> Result<()> {
    let batches = execute_sql("SELECT 10::bigint unsigned").await;
    // Only the schema of the first batch is inspected; the value itself is not checked.
    let schema = batches[0].schema();
    let column_type = schema.field(0).data_type();
    assert_eq!(&DataType::UInt64, column_type);
    Ok(())
}
62 changes: 31 additions & 31 deletions datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
" TableScan: aggregate_test_100 [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]",
" TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down Expand Up @@ -243,9 +243,9 @@ async fn csv_explain_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
" 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
" 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
Expand All @@ -271,8 +271,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
" Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
" Filter: #aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -286,8 +286,8 @@ async fn csv_explain_plans() {
let expected = vec![
"Explain",
" Projection: #aggregate_test_100.c1",
" Filter: #aggregate_test_100.c2 > Int32(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
" Filter: #aggregate_test_100.c2 > Int8(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]",
];
let formatted = plan.display_indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -307,9 +307,9 @@ async fn csv_explain_plans() {
" 2[shape=box label=\"Explain\"]",
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int8(10)\"]",
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]\"]",
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
" subgraph cluster_6",
Expand All @@ -318,9 +318,9 @@ async fn csv_explain_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
Expand Down Expand Up @@ -349,7 +349,7 @@ async fn csv_explain_plans() {
// Since the plan contains paths that are environment-dependent (e.g. the full path of the test file), only verify the important content
assert_contains!(&actual, "logical_plan");
assert_contains!(&actual, "Projection: #aggregate_test_100.c1");
assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int32(10)");
assert_contains!(actual, "Filter: #aggregate_test_100.c2 > Int8(10)");
}

#[tokio::test]
Expand Down Expand Up @@ -381,8 +381,8 @@ async fn csv_explain_inlist_verbose() {
let actual = execute(&ctx, sql).await;

// Optimized by PreCastLitInComparisonExpressions rule
// the data type of c2 is INT32, the type of `1,2,3,4` is INT64.
// the value of `1,2,4` will be casted to INT32 and pre-calculated
// the data type of c2 is INT8, the type of `1,2,4` is INT64.
// the values `1,2,4` will be cast to INT8 and pre-calculated

// flatten to a single string
let actual = actual.into_iter().map(|r| r.join("\t")).collect::<String>();
Expand All @@ -392,10 +392,10 @@ async fn csv_explain_inlist_verbose() {
&actual,
"#aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4)])"
);
// after optimization (casted to Int32)
// after optimization (casted to Int8)
assert_contains!(
&actual,
"#aggregate_test_100.c2 IN ([Int32(1), Int32(2), Int32(4)])"
"#aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4)])"
);
}

Expand All @@ -420,8 +420,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
" TableScan: aggregate_test_100 [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]",
" Filter: #aggregate_test_100.c2 > Int64(10) [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]",
" TableScan: aggregate_test_100 [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down Expand Up @@ -467,9 +467,9 @@ async fn csv_explain_verbose_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int64(10)\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
" 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int32, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:Int64, c10:Utf8, c11:Float32, c12:Float64, c13:Utf8]\"]",
" 10[shape=box label=\"TableScan: aggregate_test_100\\nSchema: [c1:Utf8, c2:Int8, c3:Int16, c4:Int16, c5:Int32, c6:Int64, c7:Int16, c8:Int32, c9:UInt32, c10:UInt64, c11:Float32, c12:Float64, c13:Utf8]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
Expand All @@ -495,8 +495,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #aggregate_test_100.c1 [c1:Utf8]",
" Filter: #aggregate_test_100.c2 > Int32(10) [c1:Utf8, c2:Int32]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)] [c1:Utf8, c2:Int32]",
" Filter: #aggregate_test_100.c2 > Int8(10) [c1:Utf8, c2:Int8]",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)] [c1:Utf8, c2:Int8]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -510,8 +510,8 @@ async fn csv_explain_verbose_plans() {
let expected = vec![
"Explain",
" Projection: #aggregate_test_100.c1",
" Filter: #aggregate_test_100.c2 > Int32(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]",
" Filter: #aggregate_test_100.c2 > Int8(10)",
" TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]",
];
let formatted = plan.display_indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand All @@ -531,9 +531,9 @@ async fn csv_explain_verbose_plans() {
" 2[shape=box label=\"Explain\"]",
" 3[shape=box label=\"Projection: #aggregate_test_100.c1\"]",
" 2 -> 3 [arrowhead=none, arrowtail=normal, dir=back]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\"]",
" 4[shape=box label=\"Filter: #aggregate_test_100.c2 > Int8(10)\"]",
" 3 -> 4 [arrowhead=none, arrowtail=normal, dir=back]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\"]",
" 5[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]\"]",
" 4 -> 5 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
" subgraph cluster_6",
Expand All @@ -542,9 +542,9 @@ async fn csv_explain_verbose_plans() {
" 7[shape=box label=\"Explain\\nSchema: [plan_type:Utf8, plan:Utf8]\"]",
" 8[shape=box label=\"Projection: #aggregate_test_100.c1\\nSchema: [c1:Utf8]\"]",
" 7 -> 8 [arrowhead=none, arrowtail=normal, dir=back]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int32(10)\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 9[shape=box label=\"Filter: #aggregate_test_100.c2 > Int8(10)\\nSchema: [c1:Utf8, c2:Int8]\"]",
" 8 -> 9 [arrowhead=none, arrowtail=normal, dir=back]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]\\nSchema: [c1:Utf8, c2:Int32]\"]",
" 10[shape=box label=\"TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]\\nSchema: [c1:Utf8, c2:Int8]\"]",
" 9 -> 10 [arrowhead=none, arrowtail=normal, dir=back]",
" }",
"}",
Expand Down Expand Up @@ -781,8 +781,8 @@ async fn csv_explain() {
vec![
"logical_plan",
"Projection: #aggregate_test_100.c1\
\n Filter: #aggregate_test_100.c2 > Int32(10)\
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int32(10)]"
\n Filter: #aggregate_test_100.c2 > Int8(10)\
\n TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[#aggregate_test_100.c2 > Int8(10)]"
],
vec!["physical_plan",
"ProjectionExec: expr=[c1@0 as c1]\
Expand Down
9 changes: 4 additions & 5 deletions datafusion/core/tests/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ macro_rules! test_expression {
pub mod aggregates;
#[cfg(feature = "avro")]
pub mod avro;
pub mod cast;
pub mod create_drop;
pub mod errors;
pub mod explain_analyze;
Expand Down Expand Up @@ -621,22 +622,20 @@ async fn register_tpch_csv_data(
async fn register_aggregate_csv_by_sql(ctx: &SessionContext) {
let testdata = datafusion::test_util::arrow_test_data();

// TODO: The following c9 should be migrated to UInt32 and c10 should be UInt64 once
// unsigned is supported.
let df = ctx
.sql(&format!(
"
CREATE EXTERNAL TABLE aggregate_test_100 (
c1 VARCHAR NOT NULL,
c2 INT NOT NULL,
c2 TINYINT NOT NULL,
c3 SMALLINT NOT NULL,
c4 SMALLINT NOT NULL,
c5 INTEGER NOT NULL,
c6 BIGINT NOT NULL,
c7 SMALLINT NOT NULL,
c8 INT NOT NULL,
c9 BIGINT NOT NULL,
c10 VARCHAR NOT NULL,
c9 INT UNSIGNED NOT NULL,
c10 BIGINT UNSIGNED NOT NULL,
c11 FLOAT NOT NULL,
c12 DOUBLE NOT NULL,
c13 VARCHAR NOT NULL
Expand Down
Loading

0 comments on commit 751cbc8

Please sign in to comment.