Skip to content

Commit

Permalink
like, cast, coalesce
Browse files Browse the repository at this point in the history
  • Loading branch information
yliuuuu committed Sep 22, 2023
1 parent da56996 commit 95bdf80
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import org.partiql.ast.Expr
import org.partiql.ast.Identifier
import org.partiql.ast.builder.AstFactory
import org.partiql.transpiler.ProblemCallback
import org.partiql.transpiler.error
import org.partiql.transpiler.info
import org.partiql.transpiler.sql.SqlArgs
import org.partiql.transpiler.sql.SqlCallFn
import org.partiql.transpiler.sql.SqlCalls
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import org.partiql.value.symbolValue

@OptIn(PartiQLValueExperimental::class)
Expand Down Expand Up @@ -53,5 +55,65 @@ public class RedshiftCalls(private val onProblem: ProblemCallback) : SqlCalls()
exprVar(id, Expr.Var.Scope.DEFAULT)
}

override fun rewriteCast(type: PartiQLValueType, args: SqlArgs): Expr = Ast.create {
when (type) {
PartiQLValueType.ANY -> {
onProblem.error("PartiQL `ANY` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.INT8 -> {
onProblem.error("PartiQL `INT8` type (1-byte integer) not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.INT -> {
onProblem.error("PartiQL `INT` type (arbitrary precision integer) not supported in Redshift")
// this needs a extra safety renaming because int refers to int4 in redshift.
exprCast(args[0].expr, typeCustom("Arbitrary Precision Integer"))
}
PartiQLValueType.MISSING -> {
onProblem.error("PartiQL `MISSING` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.SYMBOL -> {
onProblem.error("PartiQL `SYMBOL` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.INTERVAL -> {
onProblem.error("PartiQL `INTERVAL` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.BLOB -> {
onProblem.error("PartiQL `BLOB` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.CLOB -> {
onProblem.error("PartiQL `CLOB` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.BAG -> {
onProblem.error("PartiQL `BAG` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.LIST -> {
onProblem.error("PartiQL `LIST` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.SEXP -> {
onProblem.error("PartiQL `SEXP` type not supported in Redshift")
super.rewriteCast(type, args)
}
PartiQLValueType.STRUCT -> {
onProblem.error("PartiQL `STRUCT` type not supported in Redshift")
super.rewriteCast(type, args)
}
// using the customer type to rename type
PartiQLValueType.FLOAT32 -> exprCast(args[0].expr, typeCustom("FLOAT4"))
PartiQLValueType.FLOAT64 -> exprCast(args[0].expr, typeCustom("FLOAT8"))
PartiQLValueType.BINARY -> exprCast(args[0].expr, typeCustom("VARBYTE"))
PartiQLValueType.BYTE -> TODO("Mapping to VARBYTE(1), do this after supporting parameterized type")
else -> super.rewriteCast(type, args)
}
}

private fun AstFactory.id(symbol: String) = identifierSymbol(symbol, Identifier.CaseSensitivity.INSENSITIVE)
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public abstract class SqlDialect : AstBaseVisitor<SqlBlock, SqlBlock>() {
override fun visitTypeInterval(node: Type.Interval, head: SqlBlock) = head concat type("INTERVAL", node.precision)

// unsupported
override fun visitTypeCustom(node: Type.Custom, head: SqlBlock) = defaultReturn(node, head)
override fun visitTypeCustom(node: Type.Custom, head: SqlBlock) = head concat r(node.name)

// Expressions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import org.partiql.ast.AstNode
import org.partiql.ast.DatetimeField
import org.partiql.ast.Expr
import org.partiql.ast.Select
import org.partiql.ast.Type
import org.partiql.ast.visitor.AstBaseVisitor
import org.partiql.plan.Identifier
import org.partiql.plan.Plan
Expand All @@ -15,6 +16,8 @@ import org.partiql.planner.Env
import org.partiql.planner.typer.toNonNullStaticType
import org.partiql.planner.typer.toStaticType
import org.partiql.types.StaticType
import org.partiql.types.TimeType
import org.partiql.types.TimestampType
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.boolValue
Expand Down Expand Up @@ -329,8 +332,47 @@ internal object RexConverter {
TODO("SQL Special Form EXTRACT")
}

override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex {
TODO("SQL Special Form CAST")
// TODO: Ignoring type parameter now
override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex = transform {
val type = node.asType
val arg0 = visitExpr(node.value, ctx)
when(type) {
is Type.NullType -> rex(StaticType.NULL, call("cast_null", arg0))
is Type.Missing -> rex(StaticType.MISSING, call("cast_missing", arg0))
is Type.Bool -> rex(StaticType.BOOL, call("cast_bool", arg0))
is Type.Tinyint -> TODO("Static Type does not have TINYINT type")
is Type.Smallint, is Type.Int2 -> rex(StaticType.INT2, call("cast_int16", arg0))
is Type.Int4 -> rex(StaticType.INT4, call("cast_int32", arg0))
is Type.Bigint, is Type.Int8 -> rex(StaticType.INT8, call("cast_int64", arg0))
is Type.Int -> rex(StaticType.INT, call("cast_int", arg0))
is Type.Real -> TODO("Static Type does not have REAL type")
is Type.Float32 -> TODO("Static Type does not have FLOAT32 type")
is Type.Float64 -> rex(StaticType.FLOAT, call("cast_float64", arg0))
is Type.Decimal -> rex(StaticType.DECIMAL, call("cast_decimal", arg0))
is Type.Numeric -> rex(StaticType.DECIMAL, call("cast_numeric", arg0))
is Type.Char -> rex(StaticType.CHAR, call("cast_char", arg0))
is Type.Varchar -> rex(StaticType.STRING, call("cast_varchar", arg0))
is Type.String -> rex(StaticType.STRING, call("cast_string", arg0))
is Type.Symbol -> rex(StaticType.SYMBOL, call("cast_symbol", arg0))
is Type.Bit -> TODO("Static Type does not have Bit type")
is Type.BitVarying -> TODO("Static Type does not have BitVarying type")
is Type.ByteString -> TODO("Static Type does not have ByteString type")
is Type.Blob -> rex(StaticType.BLOB, call("cast_blob", arg0))
is Type.Clob -> rex(StaticType.CLOB, call("cast_clob", arg0))
is Type.Date -> rex(StaticType.DATE, call("cast_date", arg0))
is Type.Time -> rex(StaticType.TIME, call("cast_time", arg0))
is Type.TimeWithTz -> rex(TimeType(null, true), call("cast_timeWithTz", arg0))
is Type.Timestamp -> TODO("Need to rebase main")
is Type.TimestampWithTz -> rex(StaticType.TIMESTAMP, call("cast_timeWithTz", arg0))
is Type.Interval -> TODO("Static Type does not have Interval type")
is Type.Bag -> rex(StaticType.BAG, call("cast_bag", arg0))
is Type.List -> rex(StaticType.LIST, call("cast_list", arg0))
is Type.Sexp -> rex(StaticType.SEXP, call("cast_sexp", arg0))
is Type.Tuple -> rex(StaticType.STRUCT, call("cast_tuple", arg0))
is Type.Struct -> rex(StaticType.STRUCT, call("cast_struct", arg0))
is Type.Any -> rex(StaticType.ANY, call("cast_any", arg0))
is Type.Custom -> TODO("Custom type not supported ")
}
}

override fun visitExprCanCast(node: Expr.CanCast, ctx: Env): Rex {
Expand Down

0 comments on commit 95bdf80

Please sign in to comment.