Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass down precision/scale when casting to DECIMAL #191

Merged
merged 3 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,16 @@ protected SqlNode visitBinary(ASTNode node, ParseContext ctx) {

@Override
protected SqlNode visitDecimal(ASTNode node, ParseContext ctx) {
if (node.getChildCount() == 2) {
try {
final SqlTypeNameSpec typeNameSpec = new SqlBasicTypeNameSpec(SqlTypeName.DECIMAL,
Integer.parseInt(((ASTNode) node.getChildren().get(0)).getText()),
Integer.parseInt(((ASTNode) node.getChildren().get(1)).getText()), ZERO);
return new SqlDataTypeSpec(typeNameSpec, ZERO);
} catch (NumberFormatException e) {
return createBasicTypeSpec(SqlTypeName.DECIMAL);
}
}
return createBasicTypeSpec(SqlTypeName.DECIMAL);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,24 @@ public void testConcat() {
assertEquals(generated, expected);
}

@Test
public void testCastToDecimal() {
final String expected =
"LogicalProject(EXPR$0=[CAST($0):DECIMAL(6, 2)])\n" + " LogicalTableScan(table=[[hive, default, foo]])\n";
final String sql = "SELECT CAST(a AS DECIMAL(6, 2)) FROM foo";
String generated = relToString(sql);
assertEquals(generated, expected);
}

@Test
public void testCastToDecimalDefault() {
final String expected =
"LogicalProject(EXPR$0=[CAST($0):DECIMAL(10, 0)])\n" + " LogicalTableScan(table=[[hive, default, foo]])\n";
final String sql = "SELECT CAST(a AS DECIMAL) FROM foo";
String generated = relToString(sql);
assertEquals(generated, expected);
}

private String relToString(String sql) {
return RelOptUtil.toString(converter.convertSql(sql));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,4 +597,18 @@ public void testAliasHaving() {
+ "HAVING SUBSTRING(b, 1, 1)\n" + "IN ('dummy_value')";
assertEquals(CoralSpark.create(relNode).getSparkSql(), targetSql);
}

@Test
public void testCastDecimal() {
RelNode relNode = TestUtils.toRelNode("SELECT CAST(a as DECIMAL(6, 2)) as casted_decimal FROM default.foo");
String targetSql = "SELECT CAST(a AS DECIMAL(6, 2)) casted_decimal\n" + "FROM default.foo";
assertEquals(CoralSpark.create(relNode).getSparkSql(), targetSql);
}

@Test
public void testCastDecimalDefault() {
RelNode relNode = TestUtils.toRelNode("SELECT CAST(a as DECIMAL) as casted_decimal FROM default.foo");
String targetSql = "SELECT CAST(a AS DECIMAL(10, 0)) casted_decimal\n" + "FROM default.foo";
assertEquals(CoralSpark.create(relNode).getSparkSql(), targetSql);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ public Object[][] viewTestCasesProvider() {
+ "FROM \"test\".\"table_ints_strings\"" },

{ "test", "least_view", "SELECT \"least\"(\"a\", \"b\") AS \"g_int\", \"least\"(\"c\", \"d\") AS \"g_string\"\n"
+ "FROM \"test\".\"table_ints_strings\"" },

{ "test", "cast_decimal_view", "SELECT CAST(\"a\" AS DECIMAL(6, 2)) AS \"casted_decimal\"\n"
+ "FROM \"test\".\"table_ints_strings\"" } };
}

Expand Down Expand Up @@ -465,4 +468,28 @@ public void testAliasHaving() {
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
}

@Test
public void testCastDecimal() {
RelToTrinoConverter relToTrinoConverter = new RelToTrinoConverter();

RelNode relNode = hiveToRelConverter
.convertSql("SELECT CAST(t.a as DECIMAL(6, 2)) as casted_decimal FROM test.table_ints_strings t");
String targetSql =
"SELECT CAST(\"a\" AS DECIMAL(6, 2)) AS \"casted_decimal\"\n" + "FROM \"test\".\"table_ints_strings\"";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
}

@Test
public void testCastDecimalDefault() {
RelToTrinoConverter relToTrinoConverter = new RelToTrinoConverter();

RelNode relNode =
hiveToRelConverter.convertSql("SELECT CAST(t.a as DECIMAL) as casted_decimal FROM test.table_ints_strings t");
String targetSql =
"SELECT CAST(\"a\" AS DECIMAL(10, 0)) AS \"casted_decimal\"\n" + "FROM \"test\".\"table_ints_strings\"";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,9 @@ public static void initializeViews(Path metastoreDbDirectory) throws HiveExcepti
run(driver, "CREATE VIEW IF NOT EXISTS test.least_view AS \n"
+ "SELECT least(t.a, t.b) as g_int, least(t.c, t.d) as g_string FROM test.table_ints_strings t");

run(driver, "CREATE VIEW IF NOT EXISTS test.cast_decimal_view AS \n"
+ "SELECT CAST(t.a as DECIMAL(6,2)) as casted_decimal FROM test.table_ints_strings t");

run(driver, "CREATE TABLE IF NOT EXISTS test.tableS (structCol struct<a:int>)");
run(driver, "CREATE TABLE IF NOT EXISTS test.tableT (structCol struct<a:int>)");
run(driver, "CREATE VIEW IF NOT EXISTS test.viewA AS SELECT structCol as struct_col FROM test.tableS");
Expand Down