Skip to content

Commit

Permalink
Extend Snowflake to Databricks functions coverage 2 (#72)
Browse files Browse the repository at this point in the history
Added Snowflake-to-Databricks conversion support for the following functions:

- MONTH_NAME
- OBJECT_CONSTRUCT
- OBJECT_KEYS
- TRY_PARSE_JSON
- DATEADD
- DATEDIFF
- TIMEDIFF
- TIMESTAMPDIFF
- TO_BOOLEAN
- TO_DECIMAL
- TO_DOUBLE
- TO_NUMBER
- TO_NUMERIC
- TO_OBJECT
- TO_TIME
- TIMESTAMP_FROM_PARTS
- TO_VARIANT
- TO_BASE64
- FROM_BASE64
- TRY_BASE64_DECODE_STRING
- TRY_TO_BOOLEAN
- UUID_STRING
  • Loading branch information
vijaypavann-db authored and nfx committed Jan 24, 2024
1 parent 61abf47 commit b4c619d
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 3 deletions.
65 changes: 63 additions & 2 deletions src/databricks/labs/remorph/snow/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,58 @@ def _array_slice(self: Databricks.Generator, expression: local_expression.ArrayS
return func_expr


def _parse_json(self, expr: exp.ParseJSON):
    """
    Render a `PARSE_JSON` expression as a Databricks `FROM_JSON` call.

    Databricks `FROM_JSON` takes a mandatory schema argument
    [FROM_JSON](https://docs.databricks.com/en/sql/language-manual/functions/from_json.html),
    which cannot be derived here. A `{<COL_NAME>_SCHEMA}` placeholder is emitted in its
    place and a warning is printed so the schema can be supplied in the execution
    environment.

    :param expr: the ParseJSON expression to convert
    :return: the generated `FROM_JSON(...)` SQL string
    """
    json_arg = self.sql(expr, "this")
    # The placeholder name is the rendered argument without single quotes, upper-cased.
    column = json_arg.replace("'", "").upper()
    schema_placeholder = f"{{{column}_SCHEMA}}"
    conv_expr = self.func("FROM_JSON", json_arg, schema_placeholder)
    notice = f"\n***Warning***: you need to explicitly specify `SCHEMA` for `{column}` column in expression: `{conv_expr}`"
    print(notice)  # noqa: T201
    return conv_expr


def _to_number(self, expression: local_expression.ToNumber):
    """
    Render a `ToNumber` expression as `CAST(TO_NUMBER(<expr>, <format>) AS DECIMAL(p, s))`.

    :param expression: the ToNumber expression (this, optional format in `expression`,
        optional `precision` and `scale`)
    :return: the generated SQL string
    :raises UnsupportedError: when no format argument is present, since the generated
        `TO_NUMBER` call always needs one
    """
    # `self.sql` yields an empty string for an absent argument, so `or` applies the defaults.
    precision = self.sql(expression, "precision") or 38
    scale = self.sql(expression, "scale") or 0

    # Fail before generating anything: without a format there is no valid TO_NUMBER call.
    if not expression.expression:
        exception_msg = f"""Error Parsing expression {expression}:
                         * `format`: is required in Databricks [mandatory]
                         * `precision` and `scale`: are considered as (38, 0) if not specified.
                      """
        raise UnsupportedError(exception_msg)

    func_expr = self.func("TO_NUMBER", expression.this, expression.expression)
    return f"CAST({func_expr} AS DECIMAL({precision}, {scale}))"


def _uuid(self: Databricks.Generator, expression: local_expression.UUID) -> str:
    """
    Render a `UUID` expression as a Databricks `UUID` call.

    :param expression: the UUID expression; `this` holds the namespace and `name` the name
    :return: `UUID()` unless both namespace and name are present, in which case
        `UUID(<namespace>, <name>)` is returned and a manual-intervention notice is printed
    """
    namespace = self.sql(expression, "this")
    name = self.sql(expression, "name")

    # Argument-less form (or only one of the two arguments given).
    if not (namespace and name):
        return "UUID()"

    print("UUID version 5 is not supported currently. Needs manual intervention.")  # noqa: T201
    return f"UUID({namespace}, {name})"


class Databricks(Databricks):
# Instantiate Databricks Dialect
databricks = Databricks()
Expand Down Expand Up @@ -288,6 +340,15 @@ class Generator(databricks.Generator):
local_expression.ArrayConstructCompact: _array_construct_compact,
local_expression.ArrayIntersection: rename_func("ARRAY_INTERSECT"),
local_expression.ArraySlice: _array_slice,
local_expression.ObjectKeys: rename_func("JSON_OBJECT_KEYS"),
exp.ParseJSON: _parse_json,
local_expression.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
local_expression.ToDouble: rename_func("DOUBLE"),
local_expression.ToVariant: rename_func("TO_JSON"),
local_expression.ToObject: rename_func("TO_JSON"),
exp.ToBase64: rename_func("BASE64"),
local_expression.ToNumber: _to_number,
local_expression.UUID: _uuid,
}

def join_sql(self, expression: exp.Join) -> str:
Expand Down Expand Up @@ -398,8 +459,8 @@ def converttimezone_sql(self, expression: local_expression.ConvertTimeZone):

def splitpart_sql(self, expression: local_expression.SplitPart) -> str:
"""
:param expression: local_expression.Split expression to be parsed
:return: Converted expression (SPLIT) compatible with Databricks
:param expression: local_expression.SplitPart expression to be parsed
:return: Converted expression (SPLIT_PART) compatible with Databricks
"""
delimiter = " "
# To handle default delimiter
Expand Down
42 changes: 42 additions & 0 deletions src/databricks/labs/remorph/snow/local_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,45 @@ class ArrayIntersection(Func):

class ArraySlice(Func):
    """Function expression with three arguments, all flagged `True`: `this`, `from` and `to`."""

    arg_types: ClassVar[dict] = {
        "this": True,
        "from": True,
        "to": True,
    }


class ObjectKeys(Func):
    """Function expression taking a single argument, `this`, flagged `True`."""

    arg_types: ClassVar[dict] = {
        "this": True,
    }


class ToBoolean(Func):
    """
    Function expression for boolean conversion.

    `this` is the value to convert (flagged `True`); `raise_error` is an extra
    argument flagged `False`.
    """

    arg_types: ClassVar[dict] = {
        "this": True,
        "raise_error": False,
    }


class ToDouble(Func):
    """Function expression for double conversion; declares nothing beyond what `Func` provides."""


class ToObject(Func):
    """Function expression for object conversion; declares nothing beyond what `Func` provides."""


class ToNumber(Func):
    """
    Function expression shared by the `TO_DECIMAL`, `TO_NUMBER` and `TO_NUMERIC` SQL names.

    `this` is the value to convert (flagged `True`); `expression`, `precision`
    and `scale` are flagged `False`.
    """

    arg_types: ClassVar[dict] = {
        "this": True,
        "expression": False,
        "precision": False,
        "scale": False,
    }
    _sql_names: ClassVar[list] = [
        "TO_DECIMAL",
        "TO_NUMBER",
        "TO_NUMERIC",
    ]


class TimestampFromParts(Func):
    """
    Function expression that assembles a timestamp from its individual parts.

    The first six arguments are flagged `True`; `nanosec` and `Zone` are flagged `False`.
    """

    # NOTE(review): presumably `this` and `expression` carry the year and month — confirm.
    # NOTE(review): `Zone` is the only capitalised key; kept as-is since renaming it
    # would change how the argument is looked up.
    arg_types: ClassVar[dict] = {
        "this": True,
        "expression": True,
        "day": True,
        "hour": True,
        "min": True,
        "sec": True,
        "nanosec": False,
        "Zone": False,
    }


class ToVariant(Func):
    """Function expression for variant conversion; declares nothing beyond what `Func` provides."""


class UUID(Func):
    """Function expression for UUID generation; both `this` and `name` are flagged `False`."""

    arg_types: ClassVar[dict] = {
        "this": False,
        "name": False,
    }
63 changes: 62 additions & 1 deletion src/databricks/labs/remorph/snow/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import typing as t
from typing import ClassVar

from sqlglot import exp
from sqlglot import exp, parser
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.snowflake import _parse_to_timestamp as parse_to_timestamp
from sqlglot.errors import ParseError
from sqlglot.helper import seq_get
from sqlglot.tokens import Token, TokenType
Expand Down Expand Up @@ -107,6 +108,48 @@ def _parse_trytonumber(args: list) -> local_expression.TryToNumber:
return local_expression.TryToNumber(this=seq_get(args, 0), expression=seq_get(args, 1))


def _parse_monthname(args: list) -> local_expression.DateFormat:
    """
    Build a `DateFormat` expression from the parsed function arguments.

    :param args: parsed arguments; the first is the date value, the optional second a format
    :return: a DateFormat whose format is the string literal `MMM` when exactly one
        argument is given, otherwise the second argument
    """
    date_format = exp.Literal.string("MMM") if len(args) == 1 else seq_get(args, 1)
    return local_expression.DateFormat(this=seq_get(args, 0), expression=date_format)


def _parse_object_construct(args: list) -> exp.StarMap | exp.Struct:
    """
    Build an `exp.Struct` from the parsed function arguments.

    The arguments are first interpreted by `parser.parse_var_map`. A `StarMap`
    result is wrapped as the struct's single member; otherwise each key is paired
    with its value through `key.eq(value)`.

    :param args: parsed function arguments
    :return: the resulting Struct expression
    """
    var_map = parser.parse_var_map(args)

    if isinstance(var_map, exp.StarMap):
        members = [var_map.this]
    else:
        pairs = zip(var_map.keys, var_map.values, strict=False)
        members = [t.cast(exp.Condition, key).eq(value) for key, value in pairs]

    return exp.Struct(expressions=members)


def _parse_to_boolean(*args: list, error: bool) -> local_expression.ToBoolean:
    """
    Build a `ToBoolean` expression from the parsed function arguments.

    :param args: the parsed argument list, passed as one positional value
        (`_parse_to_boolean(args, error=...)`); separate positional arguments are
        also accepted
    :param error: whether the conversion should raise on failure; stored as the
        numeric literal 1 or 0 in `raise_error`
    :return: a ToBoolean whose `this` is the first function argument
    """
    # With `*args`, a call such as `_parse_to_boolean(arg_list, error=True)` makes `args`
    # a 1-tuple wrapping the list. Unwrap it so `this` is the first function argument
    # rather than the whole argument list.
    arg_list = args[0] if len(args) == 1 and isinstance(args[0], list) else args
    this_arg = seq_get(arg_list, 0)
    return local_expression.ToBoolean(this=this_arg, raise_error=exp.Literal.number(1 if error else 0))


def _parse_tonumber(args: list) -> local_expression.ToNumber:
    """
    Build a `ToNumber` expression from the parsed function arguments.

    Arguments are mapped by count:

    * 1: value only
    * 2: value, format
    * 3: value, precision, scale
    * 4: value, format, precision, scale

    :param args: parsed function arguments
    :return: the resulting ToNumber expression
    :raises ParseError: when more than four arguments are given
    """
    # NOTE(review): Snowflake also seems to accept (value, precision) and
    # (value, format, precision); the count-based mapping cannot tell those apart — confirm.
    arg_count = len(args)
    if arg_count > 4:
        error_msg = f"""Error Parsing args `{args}`:
                         * Number of args cannot be more than `4`, given `{arg_count}`
                      """
        raise ParseError(error_msg)

    value = seq_get(args, 0)

    if arg_count == 1:
        return local_expression.ToNumber(this=value)
    if arg_count == 3:
        return local_expression.ToNumber(this=value, precision=seq_get(args, 1), scale=seq_get(args, 2))
    if arg_count == 4:
        return local_expression.ToNumber(
            this=value, expression=seq_get(args, 1), precision=seq_get(args, 2), scale=seq_get(args, 3)
        )

    # Remaining case: the second argument, if any, is the format.
    return local_expression.ToNumber(this=value, expression=seq_get(args, 1))


class Snow(Snowflake):
# Instantiate Snowflake Dialect
snowflake = Snowflake()
Expand Down Expand Up @@ -235,6 +278,24 @@ class Parser(snowflake.Parser):
"ARRAY_CONSTRUCT_COMPACT": local_expression.ArrayConstructCompact.from_arg_list,
"ARRAY_INTERSECTION": local_expression.ArrayIntersection.from_arg_list,
"ARRAY_SLICE": local_expression.ArraySlice.from_arg_list,
"MONTHNAME": _parse_monthname,
"MONTH_NAME": _parse_monthname,
"OBJECT_CONSTRUCT": _parse_object_construct,
"OBJECT_KEYS": local_expression.ObjectKeys.from_arg_list,
"TRY_PARSE_JSON": exp.ParseJSON.from_arg_list,
"TIMEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"TIMESTAMPDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"TO_BOOLEAN": lambda args: _parse_to_boolean(args, error=True),
"TO_DECIMAL": _parse_tonumber,
"TO_DOUBLE": local_expression.ToDouble.from_arg_list,
"TO_NUMBER": _parse_tonumber,
"TO_NUMERIC": _parse_tonumber,
"TO_OBJECT": local_expression.ToObject.from_arg_list,
"TO_TIME": parse_to_timestamp,
"TIMESTAMP_FROM_PARTS": local_expression.TimestampFromParts.from_arg_list,
"TO_VARIANT": local_expression.ToVariant.from_arg_list,
"TRY_TO_BOOLEAN": lambda args: _parse_to_boolean(args, error=False),
"UUID_STRING": local_expression.UUID.from_arg_list,
}

FUNCTION_PARSERS: ClassVar[dict] = {
Expand Down

0 comments on commit b4c619d

Please sign in to comment.