diff --git a/fe/fe-core/src/main/java/com/starrocks/analysis/FloatLiteral.java b/fe/fe-core/src/main/java/com/starrocks/analysis/FloatLiteral.java index 1d53161f467c0..45faf371bdc02 100644 --- a/fe/fe-core/src/main/java/com/starrocks/analysis/FloatLiteral.java +++ b/fe/fe-core/src/main/java/com/starrocks/analysis/FloatLiteral.java @@ -76,6 +76,11 @@ public FloatLiteral(String value) throws AnalysisException { this(value, NodePosition.ZERO); } + public FloatLiteral(String value, Type type) throws AnalysisException { + this(value, NodePosition.ZERO); + this.type = type; + } + public FloatLiteral(String value, NodePosition pos) throws AnalysisException { super(pos); Double floatValue = null; diff --git a/fe/fe-core/src/main/java/com/starrocks/analysis/LiteralExpr.java b/fe/fe-core/src/main/java/com/starrocks/analysis/LiteralExpr.java index b61ebb9b7f0db..69d5c04c80e87 100644 --- a/fe/fe-core/src/main/java/com/starrocks/analysis/LiteralExpr.java +++ b/fe/fe-core/src/main/java/com/starrocks/analysis/LiteralExpr.java @@ -87,7 +87,7 @@ public static LiteralExpr create(String value, Type type) throws AnalysisExcepti break; case FLOAT: case DOUBLE: - literalExpr = new FloatLiteral(value); + literalExpr = new FloatLiteral(value, type); break; case DECIMALV2: case DECIMAL32: diff --git a/fe/fe-core/src/test/java/com/starrocks/analysis/LiteralExprCompareTest.java b/fe/fe-core/src/test/java/com/starrocks/analysis/LiteralExprCompareTest.java index 373f8e283c2bc..c3cb02492fd6b 100644 --- a/fe/fe-core/src/test/java/com/starrocks/analysis/LiteralExprCompareTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/analysis/LiteralExprCompareTest.java @@ -19,6 +19,7 @@ import com.starrocks.catalog.PrimitiveType; import com.starrocks.catalog.ScalarType; +import com.starrocks.catalog.Type; import com.starrocks.common.AnalysisException; import org.junit.Assert; import org.junit.BeforeClass; @@ -181,6 +182,14 @@ public void floatAndDoubleExpr() throws AnalysisException { Assert.assertTrue(0 == double1.compareLiteral(double2)); // self equal Assert.assertTrue(0 == double1.compareLiteral(double1)); + + LiteralExpr floatType = LiteralExpr.create("3.14", Type.FLOAT); + Assert.assertEquals(PrimitiveType.FLOAT, floatType.getType().getPrimitiveType()); + Assert.assertEquals(true, floatType.equals(new FloatLiteral(3.14, Type.FLOAT))); + + LiteralExpr doubleType = LiteralExpr.create("3.14", Type.DOUBLE); + Assert.assertEquals(PrimitiveType.DOUBLE, doubleType.getType().getPrimitiveType()); + Assert.assertEquals(true, doubleType.equals(new FloatLiteral(3.14, Type.DOUBLE))); } private void intTestInternal(ScalarType type) throws AnalysisException {