Skip to content

Commit

Permalink
make progress
Browse files Browse the repository at this point in the history
  • Loading branch information
nfx committed Sep 19, 2023
1 parent 37c4cb1 commit 5b6a044
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
49 changes: 47 additions & 2 deletions src/databricks/labs/ucx/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import webbrowser
from dataclasses import replace
from pathlib import Path
from typing import Any

import yaml
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import DatabricksError
from databricks.sdk.service import compute, jobs
from databricks.sdk.service.sql import EndpointInfoWarehouseType, SpotInstancePolicy
from databricks.sdk.service.workspace import ImportFormat

from databricks.labs.ucx.__about__ import __version__
Expand Down Expand Up @@ -90,10 +92,10 @@ def _create_dashboards(self):
def _warehouse_id(self) -> str:
if self._current_config.warehouse_id is not None:
return self._current_config.warehouse_id
warehouses = self._ws.warehouses.list()
warehouses = [_ for _ in self._ws.warehouses.list() if _.warehouse_type == EndpointInfoWarehouseType.PRO]
warehouse_id = self._current_config.warehouse_id
if not warehouse_id and not warehouses:
msg = "need either configured warehouse_id or an existing SQL warehouse"
msg = "need either configured warehouse_id or an existing PRO SQL warehouse"
raise ValueError(msg)
if not warehouse_id:
warehouse_id = warehouses[0].id
Expand Down Expand Up @@ -140,6 +142,25 @@ def _configure(self):

logger.info("Please answer a couple of questions to configure Unity Catalog migration")
inventory_database = self._question("Inventory Database", default="ucx")

pro_warehouses = {"[Create new PRO SQL warehouse]": "create_new"} | {
f"{_.name} ({_.id}, {_.warehouse_type.value}, {_.state.value})": _.id
for _ in self._ws.warehouses.list()
if _.warehouse_type == EndpointInfoWarehouseType.PRO
}
warehouse_id = self._choice_from_dict(
"Select PRO or SERVERLESS SQL warehouse to run assessment dashboards on", pro_warehouses
)
if warehouse_id == "create_new":
new_warehouse = self._ws.warehouses.create(
name="Unity Catalog Migration",
spot_instance_policy=SpotInstancePolicy.COST_OPTIMIZED,
warehouse_type=EndpointInfoWarehouseType.PRO,
cluster_size="Small",
max_num_clusters=1,
)
warehouse_id = new_warehouse.id

selected_groups = self._question(
"Comma-separated list of workspace group names to migrate (empty means all)", default="<ALL>"
)
Expand All @@ -157,6 +178,7 @@ def _configure(self):
inventory_database=inventory_database,
groups=GroupsConfig(**groups_config_args),
tacl=TaclConfig(auto=True),
warehouse_id=warehouse_id,
log_level=log_level,
num_threads=num_threads,
)
Expand Down Expand Up @@ -254,6 +276,28 @@ def _create_debug(self, remote_wheel: str):
def _notebook_link(self, path: str) -> str:
return f"{self._ws.config.host}/#workspace{path}"

def _choice_from_dict(self, text: str, choices: dict[str, Any]) -> Any:
key = self._choice(text, list(choices.keys()))
return choices[key]

def _choice(self, text: str, choices: list[Any]) -> str:
if not self._prompts:
return "any"
choices = sorted(choices)
numbered = "\n".join(f"\033[1m[{i}]\033[0m \033[36m{v}\033[0m" for i, v in enumerate(choices))
prompt = f"\033[1m{text}\033[0m\n{numbered}\nEnter a number between 0 and {len(choices)-1}: "
while True:
res = input(prompt)
try:
res = int(res)
except ValueError:
print(f"\033[31m[ERROR] Invalid number: {res}\033[0m\n")
continue
if res >= len(choices) or res < 0:
print(f"\033[31m[ERROR] Out of range: {res}\033[0m\n")
continue
return choices[res]

@staticmethod
def _question(text: str, *, default: str | None = None) -> str:
default_help = "" if default is None else f"\033[36m (default: {default})\033[0m"
Expand Down Expand Up @@ -310,6 +354,7 @@ def _job_task(self, task: Task, dbfs_path: str) -> jobs.Task:
def _job_dashboard_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
return replace(
jobs_task,
job_cluster_key=None,
sql_task=jobs.SqlTask(
warehouse_id=self._warehouse_id,
dashboard=jobs.SqlTaskDashboard(dashboard_id=self._dashboards[task.dashboard]),
Expand Down
19 changes: 18 additions & 1 deletion tests/unit/test_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
Dashboard,
DataSource,
EndpointInfo,
EndpointInfoWarehouseType,
Query,
State,
Visualization,
Widget,
)
Expand All @@ -36,8 +38,12 @@ def not_found(_):
ws.current_user.me = lambda: iam.User(user_name="[email protected]", groups=[iam.ComplexValue(display="admins")])
ws.config.host = "https://foo"
ws.workspace.get_status = not_found
ws.warehouses.list = lambda **_: [
EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO, state=State.RUNNING)
]

install = Installer(ws)
install._choice = lambda _1, _2: "None (abc, PRO, RUNNING)"
install._configure()

ws.workspace.upload.assert_called_with(
Expand All @@ -52,6 +58,7 @@ def not_found(_):
tacl:
auto: true
version: 1
warehouse_id: abc
workspace_start_path: /
""",
format=ImportFormat.AUTO,
Expand Down Expand Up @@ -89,9 +96,13 @@ def mock_question(text: str, *, default: str | None = None) -> str:
ws.current_user.me = lambda: iam.User(user_name="[email protected]", groups=[iam.ComplexValue(display="admins")])
ws.config.host = "https://foo"
ws.workspace.get_status = not_found
ws.warehouses.list = lambda **_: [
EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO, state=State.RUNNING)
]

install = Installer(ws)
install._question = mock_question
install._choice = lambda _1, _2: "None (abc, PRO, RUNNING)"
install._configure()

ws.workspace.upload.assert_called_with(
Expand All @@ -105,6 +116,7 @@ def mock_question(text: str, *, default: str | None = None) -> str:
tacl:
auto: true
version: 1
warehouse_id: abc
workspace_start_path: /
""",
format=ImportFormat.AUTO,
Expand All @@ -126,9 +138,13 @@ def mock_question(text: str, *, default: str | None = None) -> str:
ws.current_user.me = lambda: iam.User(user_name="[email protected]", groups=[iam.ComplexValue(display="admins")])
ws.config.host = "https://foo"
ws.workspace.get_status = not_found
ws.warehouses.list = lambda **_: [
EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO, state=State.RUNNING)
]

install = Installer(ws)
install._question = mock_question
install._choice = lambda _1, _2: "None (abc, PRO, RUNNING)"
install._configure()

ws.workspace.upload.assert_called_with(
Expand All @@ -145,6 +161,7 @@ def mock_question(text: str, *, default: str | None = None) -> str:
tacl:
auto: true
version: 1
warehouse_id: abc
workspace_start_path: /
""",
format=ImportFormat.AUTO,
Expand All @@ -165,7 +182,7 @@ def test_main_with_existing_conf_does_not_recreate_config(mocker):
ws.workspace.download = lambda _: io.BytesIO(config_bytes)
ws.workspace.get_status = lambda _: ObjectInfo(object_id=123)
ws.data_sources.list = lambda: [DataSource(id="bcd", warehouse_id="abc")]
ws.warehouses.list = lambda **_: [EndpointInfo(id="abc")]
ws.warehouses.list = lambda **_: [EndpointInfo(id="abc", warehouse_type=EndpointInfoWarehouseType.PRO)]
ws.dashboards.create.return_value = Dashboard(id="abc")
ws.queries.create.return_value = Query(id="abc")
ws.query_visualizations.create.return_value = Visualization(id="abc")
Expand Down

0 comments on commit 5b6a044

Please sign in to comment.