Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for parameterized decimal cast #1483

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ import org.partiql.value.PartiQLValueType.BOOL
import org.partiql.value.PartiQLValueType.CHAR
import org.partiql.value.PartiQLValueType.DATE
import org.partiql.value.PartiQLValueType.DECIMAL
import org.partiql.value.PartiQLValueType.DECIMAL_ARBITRARY
import org.partiql.value.PartiQLValueType.FLOAT32
import org.partiql.value.PartiQLValueType.FLOAT64
import org.partiql.value.PartiQLValueType.INT
import org.partiql.value.PartiQLValueType.INT16
import org.partiql.value.PartiQLValueType.INT32
import org.partiql.value.PartiQLValueType.INT64
import org.partiql.value.PartiQLValueType.INT8
import org.partiql.value.PartiQLValueType.MISSING
import org.partiql.value.PartiQLValueType.NULL
import org.partiql.value.PartiQLValueType.STRING
Expand Down Expand Up @@ -63,6 +68,7 @@ internal object PartiQLHeader : Header() {
mod(),
concat(),
bitwiseAnd(),
castAsParameterizedDecimal(), // explicit casts (aka NOT coercions from TypeLattice).
).flatten()

/**
Expand Down Expand Up @@ -460,6 +466,32 @@ internal object PartiQLHeader : Header() {
)
}

private fun castAsParameterizedDecimal(): List<FunctionSignature.Scalar> = listOf(
BOOL,
INT8,
INT16,
INT32,
INT64,
INT,
DECIMAL,
DECIMAL_ARBITRARY,
FLOAT32,
FLOAT64,
STRING,
).map { value ->
FunctionSignature.Scalar(
name = "cast_decimal",
returns = DECIMAL,
parameters = listOf(
FunctionParameter("value", value),
FunctionParameter("precision", INT32),
FunctionParameter("scale", INT32),
),
isNullable = false,
isNullCall = true,
)
}

// SUBSTRING (expression, start[, length]?)
// SUBSTRINGG(expression from start [FOR length]? )
private fun substring(): List<FunctionSignature.Scalar> = types.text.map { t ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ internal object RexConverter {
* @param ctx
* @return
*/
private fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex {
private fun visitExprCoerce(
node: Expr,
ctx: Env,
coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR,
): Rex {
val rex = super.visitExpr(node, ctx)
return when (rex.op is Rex.Op.Select) {
true -> rex(StaticType.ANY, rexOpSubquery(rex.op, coercion))
Expand Down Expand Up @@ -188,7 +192,10 @@ internal object RexConverter {
when (identifierSteps.size) {
0 -> root to node.steps
else -> {
val newRoot = rex(StaticType.ANY, rexOpVarUnresolved(mergeIdentifiers(op.identifier, identifierSteps), op.scope))
val newRoot = rex(
StaticType.ANY,
rexOpVarUnresolved(mergeIdentifiers(op.identifier, identifierSteps), op.scope)
)
val newSteps = node.steps.subList(identifierSteps.size, node.steps.size)
newRoot to newSteps
}
Expand Down Expand Up @@ -219,7 +226,10 @@ internal object RexConverter {
is Expr.Path.Step.Symbol -> {
val identifier = AstToPlan.convert(step.symbol)
when (identifier.caseSensitivity) {
Identifier.CaseSensitivity.SENSITIVE -> rexOpPathKey(current, rexString(identifier.symbol))
Identifier.CaseSensitivity.SENSITIVE -> rexOpPathKey(
current,
rexString(identifier.symbol)
)
Identifier.CaseSensitivity.INSENSITIVE -> rexOpPathSymbol(current, identifier.symbol)
}
}
Expand Down Expand Up @@ -516,7 +526,7 @@ internal object RexConverter {
TODO("SQL Special Form EXTRACT")
}

// TODO: Ignoring type parameter now
// TODO: Ignoring type parameters (EXCEPT DECIMAL) now
override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex {
val type = node.asType
val arg0 = visitExprCoerce(node.value, ctx)
Expand All @@ -532,7 +542,17 @@ internal object RexConverter {
is Type.Real -> TODO("Static Type does not have REAL type")
is Type.Float32 -> TODO("Static Type does not have FLOAT32 type")
is Type.Float64 -> rex(StaticType.FLOAT, call("cast_float64", arg0))
is Type.Decimal -> rex(StaticType.DECIMAL, call("cast_decimal", arg0))
is Type.Decimal -> {
if (type.precision != null) {
// CONSTRAINED — cast_decimal(arg, precision, scale)
val p = rex(StaticType.INT4, rexOpLit(int32Value(type.precision)))
val s = rex(StaticType.INT4, rexOpLit(int32Value(type.scale ?: 0)))
rex(StaticType.DECIMAL, call("cast_decimal", arg0, p, s))
} else {
// UNCONSTRAINED — cast_decimal(arg)
rex(StaticType.DECIMAL, call("cast_decimal", arg0))
}
}
is Type.Numeric -> rex(StaticType.DECIMAL, call("cast_numeric", arg0))
is Type.Char -> rex(StaticType.CHAR, call("cast_char", arg0))
is Type.Varchar -> rex(StaticType.STRING, call("cast_varchar", arg0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ internal sealed class FnMatch<T : FunctionSignature> {
* @property candidates an ordered list of potentially applicable functions to dispatch dynamically.
*/
public data class Dynamic<T : FunctionSignature>(
public val candidates: List<Ok<T>>
public val candidates: List<Ok<T>>,
) : FnMatch<T>()

public data class Error<T : FunctionSignature>(
Expand Down Expand Up @@ -354,7 +354,36 @@ internal class FnResolver(private val header: Header) {
*
* But what about parameterized types? Are the parameters dropped in casts, or do parameters become arguments?
*/
private fun castName(type: PartiQLValueType) = "cast_${type.name.lowercase()}"
private fun castName(type: PartiQLValueType): String = when (type) {
ANY -> "cast_any" // TODO remove, only added for backwards compatibility in next release.
BOOL -> "cast_bool"
INT8 -> "cast_int8"
INT16 -> "cast_int16"
INT32 -> "cast_int32"
INT64 -> "cast_int64"
INT -> "cast_int"
DECIMAL -> "cast_decimal"
DECIMAL_ARBITRARY -> "cast_decimal"
FLOAT32 -> "cast_float32"
FLOAT64 -> "cast_float64"
CHAR -> "cast_char"
STRING -> "cast_string"
SYMBOL -> "cast_symbol"
BINARY -> "cast_binary"
BYTE -> "cast_byte"
BLOB -> "cast_blob"
CLOB -> "cast_clob"
DATE -> "cast_date"
TIME -> "cast_time"
TIMESTAMP -> "cast_timestamp"
INTERVAL -> "cast_interval"
BAG -> "cast_bag"
LIST -> "cast_list"
SEXP -> "cast_sexp"
STRUCT -> "cast_struct"
PartiQLValueType.NULL -> "cast_null" // TODO remove, only added for backwards compatibility in next release.
PartiQLValueType.MISSING -> "cast_missing" // TODO remove, only added for backwards compatibility in next release.
}

internal fun cast(operand: PartiQLValueType, target: PartiQLValueType) =
FunctionSignature.Scalar(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ import org.partiql.types.AnyType
import org.partiql.types.BagType
import org.partiql.types.BoolType
import org.partiql.types.CollectionType
import org.partiql.types.DecimalType
import org.partiql.types.IntType
import org.partiql.types.ListType
import org.partiql.types.SexpType
Expand All @@ -91,6 +92,7 @@ import org.partiql.types.StructType
import org.partiql.types.TupleConstraint
import org.partiql.types.function.FunctionSignature
import org.partiql.value.BoolValue
import org.partiql.value.Int32Value
import org.partiql.value.MissingValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.TextValue
Expand Down Expand Up @@ -634,12 +636,37 @@ internal class PlanTyper(
}
}

// TODO we have to pull out decimal type parameters here because V0 drops the type in CAST.
if (newFn.signature.name == "cast_decimal" && newFn.signature.parameters.size == 3) {
val p = getIntOrErr(newArgs[1].op)
val s = getIntOrErr(newArgs[2].op)
val returns = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(p, s))
val op = rexOpCallStatic(newFn, newArgs)
return rex(returns, op)
}

// Type return
val returns = newFn.signature.returns
val op = rexOpCallStatic(newFn, newArgs)
return rex(returns.toStaticType().flatten(), op)
}

/**
* For `cast_decimal(v, precision, scale)` we make the precision and scale literal 32-bit integers.
*/
private fun getIntOrErr(op: Rex.Op): Int {
if (op !is Rex.Op.Lit) {
error("Unrecoverable, expected Rex.Op.Lit found ${op::class}. This should be unreachable.")
}
if (op.value !is Int32Value) {
error("Unrecoverable, expected Int32Value found ${op.value::class}. This should be unreachable.")
}
if (op.value.value == null) {
error("Int32Value cannot be null. This should be unreachable.")
}
return op.value.value!!
}

override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex {
// Rewrite CASE-WHEN branches
val oldBranches = node.branches.toTypedArray()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ internal class TypeLattice private constructor(
INT32 to unsafe(),
INT64 to unsafe(),
INT to unsafe(),
DECIMAL to unsafe(),
DECIMAL_ARBITRARY to unsafe(),
STRING to coercion(),
SYMBOL to explicit(),
CLOB to coercion(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ import org.partiql.plugins.memory.MemoryConnector
import org.partiql.spi.connector.ConnectorMetadata
import org.partiql.types.AnyType
import org.partiql.types.BagType
import org.partiql.types.DecimalType
import org.partiql.types.ListType
import org.partiql.types.SexpType
import org.partiql.types.StaticType
import org.partiql.types.StaticType.Companion.ANY
import org.partiql.types.StaticType.Companion.DECIMAL
import org.partiql.types.StaticType.Companion.INT
import org.partiql.types.StaticType.Companion.INT4
import org.partiql.types.StaticType.Companion.INT8
Expand Down Expand Up @@ -287,6 +289,80 @@ class PlanTyperTestsPorted {
@JvmStatic
fun structs() = listOf<TestCase>()

@JvmStatic
fun decimalCastCases() = listOf<TestCase>(
SuccessTestCase(
name = "cast decimal",
query = "CAST(1 AS DECIMAL)",
expected = StaticType.DECIMAL,
),
SuccessTestCase(
name = "cast decimal(1)",
query = "CAST(1 AS DECIMAL(1))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 0)),
),
SuccessTestCase(
name = "cast decimal(1,0)",
query = "CAST(1 AS DECIMAL(1,0))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 0)),
),
SuccessTestCase(
name = "cast decimal(1,1)",
query = "CAST(1 AS DECIMAL(1,1))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 1)),
),
SuccessTestCase(
name = "cast decimal(38)",
query = "CAST(1 AS DECIMAL(38))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 0)),
),
SuccessTestCase(
name = "cast decimal(38,0)",
query = "CAST(1 AS DECIMAL(38,0))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 0)),
),
SuccessTestCase(
name = "cast decimal(38,38)",
query = "CAST(1 AS DECIMAL(38,38))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 38)),
),
SuccessTestCase(
name = "cast decimal string",
query = "CAST('1' AS DECIMAL)",
expected = StaticType.DECIMAL,
),
SuccessTestCase(
name = "cast decimal(1) string",
query = "CAST('1' AS DECIMAL(1))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 0)),
),
SuccessTestCase(
name = "cast decimal(1,0) string",
query = "CAST('1' AS DECIMAL(1,0))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 0)),
),
SuccessTestCase(
name = "cast decimal(1,1) string",
query = "CAST('1' AS DECIMAL(1,1))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(1, 1)),
),
SuccessTestCase(
name = "cast decimal(38) string",
query = "CAST('1' AS DECIMAL(38))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 0)),
),
SuccessTestCase(
name = "cast decimal(38,0) string",
query = "CAST('1' AS DECIMAL(38,0))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 0)),
),
SuccessTestCase(
name = "cast decimal(38,38) string",
query = "CAST('1' AS DECIMAL(38,38))",
expected = DecimalType(DecimalType.PrecisionScaleConstraint.Constrained(38, 38)),
),
)

@JvmStatic
fun selectStar() = listOf<TestCase>(
SuccessTestCase(
Expand Down Expand Up @@ -3452,6 +3528,11 @@ class PlanTyperTestsPorted {
@Execution(ExecutionMode.CONCURRENT)
fun testCollections(tc: TestCase) = runTest(tc)

@ParameterizedTest
@MethodSource("decimalCastCases")
@Execution(ExecutionMode.CONCURRENT)
fun testDecimalCast(tc: TestCase) = runTest(tc)

@ParameterizedTest
@MethodSource("selectStar")
@Execution(ExecutionMode.CONCURRENT)
Expand Down
Loading