Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhanced install script to enforce usage of a warehouse or cluster when skip-validation is set to False #213

Merged
merged 14 commits into from
Apr 4, 2024
1 change: 0 additions & 1 deletion labs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ install:
require_running_cluster: false
require_databricks_connect: true
script: src/databricks/labs/remorph/install.py
warehouse_types: ["SERVERLESS", "PRO"]
uninstall:
script: src/databricks/labs/remorph/uninstall.py
entrypoint: src/databricks/labs/remorph/cli.py
Expand Down
10 changes: 8 additions & 2 deletions src/databricks/labs/remorph/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from databricks.labs.blueprint.cli import App
from databricks.labs.blueprint.entrypoint import get_logger
from databricks.labs.blueprint.installation import Installation
from databricks.sdk import WorkspaceClient

from databricks.labs.remorph.config import MorphConfig
Expand All @@ -29,6 +30,10 @@ def transpile(
):
"""transpiles source dialect to databricks dialect"""
logger.info(f"user: {w.current_user.me()}")
installation = Installation.current(w, 'remorph')
default_config = installation.load(MorphConfig)

# TODO refactor cli based on the default config

if source.lower() not in {"snowflake", "tsql"}:
raise_validation_exception(
Expand All @@ -37,20 +42,21 @@ def transpile(
if not os.path.exists(input_sql) or input_sql in {None, ""}:
raise_validation_exception(f"Error: Invalid value for '--input_sql': Path '{input_sql}' does not exist.")
if output_folder == "":
output_folder = None
output_folder = default_config.output_folder if default_config.output_folder else None
if skip_validation.lower() not in {"true", "false"}:
raise_validation_exception(
f"Error: Invalid value for '--skip_validation': '{skip_validation}' is not one of 'true', 'false'. "
)

sdk_config = default_config.sdk_config if default_config.sdk_config else None
config = MorphConfig(
source=source.lower(),
input_sql=input_sql,
output_folder=output_folder,
skip_validation=skip_validation.lower() == "true", # convert to bool
catalog_name=catalog_name,
schema_name=schema_name,
sdk_config=w.config,
sdk_config=sdk_config,
)

status = morph(w, config)
Expand Down
4 changes: 1 addition & 3 deletions src/databricks/labs/remorph/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
from dataclasses import dataclass

from databricks.sdk.core import Config

logger = logging.getLogger(__name__)


Expand All @@ -12,7 +10,7 @@ class MorphConfig:
__version__ = 1

source: str
sdk_config: Config | None
sdk_config: dict[str, str] | None
input_sql: str | None = None
output_folder: str | None = None
skip_validation: bool = False
Expand Down
7 changes: 5 additions & 2 deletions src/databricks/labs/remorph/helpers/db_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@


def get_sql_backend(ws: WorkspaceClient, config: MorphConfig) -> SqlBackend:
sdk_config = ws.config
warehouse_id = isinstance(sdk_config.warehouse_id, str) and sdk_config.warehouse_id
sdk_config = config.sdk_config
warehouse_id = sdk_config.get("warehouse_id", None) if sdk_config else None
cluster_id = sdk_config.get("cluster_id", None) if sdk_config else None
catalog_name = config.catalog_name
schema_name = config.schema_name
if warehouse_id:
sql_backend = StatementExecutionBackend(ws, warehouse_id, catalog=catalog_name, schema=schema_name)
else:
# assigning cluster id explicitly to the config as user can provide them during installation
ws.config.cluster_id = cluster_id if cluster_id else ws.config.cluster_id
sql_backend = RuntimeBackend() if "DATABRICKS_RUNTIME_VERSION" in os.environ else DatabricksConnectBackend(ws)
try:
sql_backend.execute(f"use catalog {catalog_name}")
Expand Down
72 changes: 60 additions & 12 deletions src/databricks/labs/remorph/install.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import time
import webbrowser
from datetime import timedelta
from pathlib import Path
Expand All @@ -12,13 +13,19 @@
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.retries import retried
from databricks.sdk.service.sql import (
CreateWarehouseRequestWarehouseType,
EndpointInfoWarehouseType,
SpotInstancePolicy,
)

from databricks.labs.remorph.__about__ import __version__
from databricks.labs.remorph.config import MorphConfig

logger = logging.getLogger(__name__)

PRODUCT_INFO = ProductInfo(__file__)
WAREHOUSE_PREFIX = "Remorph Transpiler Validation"


class WorkspaceInstaller:
Expand Down Expand Up @@ -50,38 +57,79 @@ def configure(self) -> MorphConfig:
logger.debug(f"Cannot find previous installation: {err}")
logger.info("Please answer a couple of questions to configure Remorph")

# default params
catalog_name = "transpiler_test"
schema_name = "convertor_test"
ws_config = None

source_prompt = self._prompts.choice("Select the source", ["snowflake", "tsql"])
source = source_prompt.lower()

skip_validation = self._prompts.confirm("Do you want to Skip Validation")

catalog_name = self._prompts.question("Enter catalog_name")

try:
self._catalog_setup.get(catalog_name)
except NotFound:
self.setup_catalog(catalog_name)
if not skip_validation:
ws_config = self._configure_runtime()
catalog_name = self._prompts.question("Enter catalog_name")
try:
self._catalog_setup.get(catalog_name)
except NotFound:
self.setup_catalog(catalog_name)

schema_name = self._prompts.question("Enter schema_name")
schema_name = self._prompts.question("Enter schema_name")

try:
self._catalog_setup.get_schema(f"{catalog_name}.{schema_name}")
except NotFound:
self.setup_schema(catalog_name, schema_name)
try:
self._catalog_setup.get_schema(f"{catalog_name}.{schema_name}")
except NotFound:
self.setup_schema(catalog_name, schema_name)

config = MorphConfig(
source=source,
skip_validation=skip_validation,
catalog_name=catalog_name,
schema_name=schema_name,
sdk_config=None,
sdk_config=ws_config,
)

ws_file_url = self._installation.save(config)
if self._prompts.confirm("Open config file in the browser and continue installing?"):
webbrowser.open(ws_file_url)
return config

def _configure_runtime(self) -> dict[str, str]:
    """Prompt for the compute used for SQL validation.

    Returns:
        A dict stored as ``MorphConfig.sdk_config`` containing either a
        ``"warehouse_id"`` or a ``"cluster_id"`` entry — the exact keys that
        ``helpers.db_sql.get_sql_backend`` reads back at transpile time.
    """
    if self._prompts.confirm("Do you want to use SQL Warehouse for validation?"):
        warehouse_id = self._configure_warehouse()
        return {"warehouse_id": warehouse_id}

    if self._ws.config.cluster_id:
        logger.info(f"Using cluster {self._ws.config.cluster_id} for validation")
        # BUG FIX: key must be "cluster_id" — get_sql_backend() looks up
        # sdk_config.get("cluster_id"), so the old "cluster" key was
        # silently ignored and the configured cluster never used.
        return {"cluster_id": self._ws.config.cluster_id}

    cluster_id = self._prompts.question("Enter a valid cluster_id to proceed")
    return {"cluster_id": cluster_id}

def _configure_warehouse(self):
    """Let the user pick an existing PRO/SERVERLESS warehouse or create one.

    Returns the chosen (or newly created) warehouse id.
    """

    def describe_type(endpoint):
        # Serverless endpoints report their base type; surface them as SERVERLESS.
        if endpoint.enable_serverless_compute:
            return "SERVERLESS"
        return endpoint.warehouse_type.value

    # Offer "create new" first, then every existing PRO warehouse.
    choices = {"[Create new PRO SQL warehouse]": "create_new"}
    for endpoint in self._ws.warehouses.list():
        if endpoint.warehouse_type != EndpointInfoWarehouseType.PRO:
            continue
        label = f"{endpoint.name} ({endpoint.id}, {describe_type(endpoint)}, {endpoint.state.value})"
        choices[label] = endpoint.id

    warehouse_id = self._prompts.choice_from_dict(
        "Select PRO or SERVERLESS SQL warehouse to run validation on", choices
    )
    if warehouse_id != "create_new":
        return warehouse_id

    # User asked for a fresh warehouse: provision a small, cost-optimized PRO one.
    created = self._ws.warehouses.create(
        name=f"{WAREHOUSE_PREFIX} {time.time_ns()}",
        spot_instance_policy=SpotInstancePolicy.COST_OPTIMIZED,
        warehouse_type=CreateWarehouseRequestWarehouseType.PRO,
        cluster_size="Small",
        max_num_clusters=1,
    )
    return created.id

@retried(on=[NotFound], timeout=timedelta(minutes=5))
def setup_catalog(self, catalog_name: str):
allow_catalog_creation = self._prompts.confirm(
Expand Down
9 changes: 5 additions & 4 deletions src/databricks/labs/remorph/transpiler/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ def morph(workspace_client: WorkspaceClient, config: MorphConfig):
skip_validation = config.skip_validation
status = []
result = MorphStatus([], 0, 0, 0, [])
validator = Validator(db_sql.get_sql_backend(workspace_client, config))
validator = None
if not config.skip_validation:
validator = Validator(db_sql.get_sql_backend(workspace_client, config))

if input_sql.is_file():
if is_sql_file(input_sql):
msg = f"Processing for sqls under this file: {input_sql}"
Expand Down Expand Up @@ -157,16 +160,14 @@ def morph(workspace_client: WorkspaceClient, config: MorphConfig):
validate_error_count = result.validate_error_count

error_list_count = parse_error_count + validate_error_count

if not skip_validation:
logger.info(f"No of Sql Failed while Validating: {validate_error_count}")

error_log_file = "None"
if error_list_count > 0:
error_log_file = Path.cwd() / f"err_{os.getpid()}.lst"
with error_log_file.open("a") as e:
e.writelines(f"{err}\n" for err in result.error_log_list)
else:
error_log_file = "None"

status.append(
{
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def mock_workspace_client():
yield client


@pytest.fixture(scope="session")
def morph_config(mock_databricks_config):
@pytest.fixture()
def morph_config():
yield MorphConfig(
sdk_config=mock_databricks_config,
sdk_config={"cluster_id": "test_cluster"},
source="snowflake",
input_sql="input_sql",
output_folder="output_folder",
Expand Down
42 changes: 23 additions & 19 deletions tests/unit/helpers/test_db_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,64 +6,68 @@
from databricks.labs.remorph.helpers.db_sql import get_sql_backend


@pytest.fixture()
def morph_config_sqlbackend(morph_config):
    """Per-test MorphConfig each test can mutate (e.g. swap sdk_config) safely."""
    # NOTE: the stray `@pytest.mark.usefixtures(...)` line above this fixture was
    # a leftover removed-diff artifact; marks have no effect on fixtures anyway.
    return morph_config


@patch('databricks.labs.remorph.helpers.db_sql.StatementExecutionBackend')
def test_get_sql_backend_with_warehouse_id(
    stmt_execution_backend,
    mock_workspace_client,
    morph_config_sqlbackend,
):
    """get_sql_backend builds a StatementExecutionBackend when sdk_config has a warehouse_id."""
    # (Stale pre-change lines from the rendered diff — duplicate `morph_config`
    # parameter and assertions against it — removed; span was invalid Python.)
    morph_config_sqlbackend.sdk_config = {"warehouse_id": "test_warehouse_id"}
    sql_backend = get_sql_backend(mock_workspace_client, morph_config_sqlbackend)
    stmt_execution_backend.assert_called_once_with(
        mock_workspace_client,
        "test_warehouse_id",
        catalog=morph_config_sqlbackend.catalog_name,
        schema=morph_config_sqlbackend.schema_name,
    )
    assert isinstance(sql_backend, stmt_execution_backend.return_value.__class__)


@patch('databricks.labs.remorph.helpers.db_sql.DatabricksConnectBackend')
def test_get_sql_backend_without_warehouse_id(
    databricks_connect_backend,
    mock_workspace_client,
    morph_config_sqlbackend,
):
    """Falls back to DatabricksConnectBackend when sdk_config carries no warehouse_id."""
    mock_dbc_backend_instance = databricks_connect_backend.return_value
    # morph config mock object has cluster id
    sql_backend = get_sql_backend(mock_workspace_client, morph_config_sqlbackend)
    databricks_connect_backend.assert_called_once_with(mock_workspace_client)
    mock_dbc_backend_instance.execute.assert_any_call(f"use catalog {morph_config_sqlbackend.catalog_name}")
    mock_dbc_backend_instance.execute.assert_any_call(f"use {morph_config_sqlbackend.schema_name}")
    assert isinstance(sql_backend, databricks_connect_backend.return_value.__class__)


@pytest.mark.usefixtures("monkeypatch")
@patch('databricks.labs.remorph.helpers.db_sql.RuntimeBackend')
def test_get_sql_backend_without_warehouse_id_in_notebook(
    runtime_backend,
    mock_workspace_client,
    morph_config_sqlbackend,
    monkeypatch,
):
    """Uses RuntimeBackend when DATABRICKS_RUNTIME_VERSION is set and sdk_config is None."""
    monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3")
    mock_runtime_backend_instance = runtime_backend.return_value
    morph_config_sqlbackend.sdk_config = None
    sql_backend = get_sql_backend(mock_workspace_client, morph_config_sqlbackend)
    runtime_backend.assert_called_once()
    mock_runtime_backend_instance.execute.assert_any_call(f"use catalog {morph_config_sqlbackend.catalog_name}")
    mock_runtime_backend_instance.execute.assert_any_call(f"use {morph_config_sqlbackend.schema_name}")
    assert isinstance(sql_backend, runtime_backend.return_value.__class__)


@patch('databricks.labs.remorph.helpers.db_sql.DatabricksConnectBackend')
def test_get_sql_backend_with_error(
    databricks_connect_backend,
    mock_workspace_client,
    morph_config_sqlbackend,
):
    """Errors raised while issuing `use catalog`/`use schema` propagate to the caller."""
    mock_dbc_backend_instance = databricks_connect_backend.return_value
    mock_dbc_backend_instance.execute.side_effect = DatabricksError("Test error")
    with pytest.raises(DatabricksError):
        get_sql_backend(mock_workspace_client, morph_config_sqlbackend)
Loading
Loading