diff --git a/src/databricks/labs/remorph/cli.py b/src/databricks/labs/remorph/cli.py
index 2e87d94717..98fd863dcf 100644
--- a/src/databricks/labs/remorph/cli.py
+++ b/src/databricks/labs/remorph/cli.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 from pathlib import Path
@@ -85,8 +86,10 @@ def transpile(
         mode=mode,
         sdk_config=sdk_config,
     )
-    status = do_transpile(ctx.workspace_client, engine, config)
+    status, errors = asyncio.run(do_transpile(ctx.workspace_client, engine, config))
+    for error in errors:
+        print(str(error))
     print(json.dumps(status))
diff --git a/src/databricks/labs/remorph/config.py b/src/databricks/labs/remorph/config.py
index 1edde865e8..1a2f204c99 100644
--- a/src/databricks/labs/remorph/config.py
+++ b/src/databricks/labs/remorph/config.py
@@ -41,7 +41,7 @@ def output_path(self):
 
     @property
     def error_path(self):
-        return None if self.error_file_path is None else Path(self.error_file_path)
+        return Path(self.error_file_path) if self.error_file_path else None
 
     @property
     def target_dialect(self):
diff --git a/src/databricks/labs/remorph/transpiler/execute.py b/src/databricks/labs/remorph/transpiler/execute.py
index f2d205c67b..b15f74bea4 100644
--- a/src/databricks/labs/remorph/transpiler/execute.py
+++ b/src/databricks/labs/remorph/transpiler/execute.py
@@ -2,7 +2,7 @@
 import datetime
 import logging
 from pathlib import Path
-from typing import cast
+from typing import cast, Any
 
 from databricks.labs.remorph.__about__ import __version__
 from databricks.labs.remorph.config import (
@@ -33,7 +33,7 @@
 logger = logging.getLogger(__name__)
 
 
-def _process_file(
+async def _process_file(
     config: TranspileConfig,
     validator: Validator | None,
     transpiler: TranspileEngine,
@@ -46,8 +46,8 @@ def _process_file(
     with input_path.open("r") as f:
         source_sql = remove_bom(f.read())
 
-    transpile_result = asyncio.run(
-        _transpile(transpiler, config.source_dialect, config.target_dialect, source_sql, input_path)
+    transpile_result = await _transpile(
+        transpiler, config.source_dialect, config.target_dialect, source_sql, input_path
     )
     error_list.extend(transpile_result.error_list)
 
@@ -71,7 +71,7 @@ def _process_file(
     return transpile_result.success_count, error_list
 
 
-def _process_directory(
+async def _process_directory(
     config: TranspileConfig,
     validator: Validator | None,
     transpiler: TranspileEngine,
@@ -93,14 +93,14 @@ def _process_directory(
             continue
         output_file_name = output_folder_base / file.name
-        success_count, error_list = _process_file(config, validator, transpiler, file, output_file_name)
+        success_count, error_list = await _process_file(config, validator, transpiler, file, output_file_name)
         counter = counter + success_count
         all_errors.extend(error_list)
 
     return counter, all_errors
 
 
-def _process_input_dir(config: TranspileConfig, validator: Validator | None, transpiler: TranspileEngine):
+async def _process_input_dir(config: TranspileConfig, validator: Validator | None, transpiler: TranspileEngine):
     error_list = []
     file_list = []
     counter = 0
@@ -112,13 +112,13 @@ def _process_input_dir(config: TranspileConfig, validator: Validator | None, tra
             msg = f"Processing for sqls under this folder: {folder}"
             logger.info(msg)
         file_list.extend(files)
-        no_of_sqls, errors = _process_directory(config, validator, transpiler, root, base_root, files)
+        no_of_sqls, errors = await _process_directory(config, validator, transpiler, root, base_root, files)
         counter = counter + no_of_sqls
         error_list.extend(errors)
     return TranspileStatus(file_list, counter, error_list)
 
 
-def _process_input_file(
+async def _process_input_file(
     config: TranspileConfig, validator: Validator | None, transpiler: TranspileEngine
 ) -> TranspileStatus:
     if not is_sql_file(config.input_path):
@@ -135,12 +135,23 @@ def _process_input_file(
     make_dir(output_path)
     output_file = output_path / config.input_path.name
-    no_of_sqls, error_list = _process_file(config, validator, transpiler, config.input_path, output_file)
+    no_of_sqls, error_list = await _process_file(config, validator, transpiler, config.input_path, output_file)
     return TranspileStatus([config.input_path], no_of_sqls, error_list)
 
 
 @timeit
-def transpile(workspace_client: WorkspaceClient, engine: TranspileEngine, config: TranspileConfig):
+async def transpile(
+    workspace_client: WorkspaceClient, engine: TranspileEngine, config: TranspileConfig
+) -> tuple[list[dict[str, Any]], list[TranspileError]]:
+    await engine.initialize(config)
+    status, errors = await _do_transpile(workspace_client, engine, config)
+    await engine.shutdown()
+    return status, errors
+
+
+async def _do_transpile(
+    workspace_client: WorkspaceClient, engine: TranspileEngine, config: TranspileConfig
+) -> tuple[list[dict[str, Any]], list[TranspileError]]:
     """
     [Experimental] Transpiles the SQL queries from one dialect to another.
@@ -162,9 +173,9 @@ def transpile(workspace_client: WorkspaceClient, engine: TranspileEngine, config
     if config.input_source is None:
         raise InvalidInputException("Missing input source!")
     if config.input_path.is_dir():
-        result = _process_input_dir(config, validator, engine)
+        result = await _process_input_dir(config, validator, engine)
     elif config.input_path.is_file():
-        result = _process_input_file(config, validator, engine)
+        result = await _process_input_file(config, validator, engine)
     else:
         msg = f"{config.input_source} does not exist."
         logger.error(msg)
@@ -194,7 +205,7 @@ def transpile(workspace_client: WorkspaceClient, engine: TranspileEngine, config
             "error_log_file": str(error_log_path),
         }
     )
-    return status
+    return status, result.error_list
 
 
 def verify_workspace_client(workspace_client: WorkspaceClient) -> WorkspaceClient:
@@ -213,9 +224,9 @@ def verify_workspace_client(workspace_client: WorkspaceClient) -> WorkspaceClien
 
 
 async def _transpile(
-    transpiler: TranspileEngine, from_dialect: str, to_dialect: str, source_code: str, input_path: Path
+    engine: TranspileEngine, from_dialect: str, to_dialect: str, source_code: str, input_path: Path
 ) -> TranspileResult:
-    return await transpiler.transpile(from_dialect, to_dialect, source_code, input_path)
+    return await engine.transpile(from_dialect, to_dialect, source_code, input_path)
 
 
 def _validation(
diff --git a/src/databricks/labs/remorph/transpiler/lsp/lsp_engine.py b/src/databricks/labs/remorph/transpiler/lsp/lsp_engine.py
index 6e966f280e..7dc82e7ad7 100644
--- a/src/databricks/labs/remorph/transpiler/lsp/lsp_engine.py
+++ b/src/databricks/labs/remorph/transpiler/lsp/lsp_engine.py
@@ -30,7 +30,10 @@
     TextDocumentIdentifier,
     METHOD_TO_TYPES,
     LanguageKind,
+    Range as LSPRange,
+    Position as LSPPosition,
     _SPECIAL_PROPERTIES,
+    DiagnosticSeverity,
 )
 from pygls.lsp.client import BaseLanguageClient
 from pygls.exceptions import FeatureRequestError
@@ -40,7 +43,13 @@
 from databricks.labs.remorph.config import TranspileConfig, TranspileResult
 from databricks.labs.remorph.errors.exceptions import IllegalStateException
 from databricks.labs.remorph.transpiler.transpile_engine import TranspileEngine
-from databricks.labs.remorph.transpiler.transpile_status import TranspileError, ErrorSeverity, ErrorKind
+from databricks.labs.remorph.transpiler.transpile_status import (
+    TranspileError,
+    ErrorKind,
+    ErrorSeverity,
+    CodeRange,
+    CodePosition,
+)
 
 logger = logging.getLogger(__name__)
@@ -187,15 +196,18 @@ def wrapper(params):
 class ChangeManager(abc.ABC):
 
     @classmethod
-    def apply(cls, source_code: str, file_path: Path, changes: Sequence[TextEdit]) -> TranspileResult:
-        if not changes:
+    def apply(
+        cls, source_code: str, changes: Sequence[TextEdit], diagnostics: Sequence[Diagnostic], file_path: Path
+    ) -> TranspileResult:
+        if not changes and not diagnostics:
             return TranspileResult(source_code, 1, [])
+        transpile_errors = [DiagnosticConverter.apply(file_path, diagnostic) for diagnostic in diagnostics]
         try:
             lines = source_code.split("\n")
             for change in changes:
                 lines = cls._apply(lines, change)
             transpiled_code = "\n".join(lines)
-            return TranspileResult(transpiled_code, 1, [])
+            return TranspileResult(transpiled_code, 1, transpile_errors)
         except IndexError as e:
             logger.error("Failed to apply changes", exc_info=e)
             error = TranspileError(
@@ -205,7 +217,8 @@ def apply(cls, source_code: str, file_path: Path, changes: Sequence[TextEdit]) -
                 path=file_path,
                 message="Internal error, failed to apply changes",
             )
-            return TranspileResult(source_code, 1, [error])
+            transpile_errors.append(error)
+            return TranspileResult(source_code, 1, transpile_errors)
 
     @classmethod
     def _apply(cls, lines: list[str], change: TextEdit) -> list[str]:
@@ -247,6 +260,46 @@ def _is_full_document_change(cls, lines: list[str], change: TextEdit) -> bool:
         )
 
 
+class DiagnosticConverter(abc.ABC):
+
+    _KIND_NAMES = {e.name for e in ErrorKind}
+
+    @classmethod
+    def apply(cls, file_path: Path, diagnostic: Diagnostic) -> TranspileError:
+        code = str(diagnostic.code)
+        kind = ErrorKind.INTERNAL
+        parts = code.split("-")
+        if len(parts) >= 2 and parts[0] in cls._KIND_NAMES:
+            kind = ErrorKind[parts[0]]
+            parts.pop(0)
+            code = "-".join(parts)
+        severity = cls._convert_severity(diagnostic.severity)
+        lsp_range = cls._convert_range(diagnostic.range)
+        return TranspileError(
+            code=code, kind=kind, severity=severity, path=file_path, message=diagnostic.message, range=lsp_range
+        )
+
+    @classmethod
+    def _convert_range(cls, lsp_range: LSPRange | None) -> CodeRange | None:
+        if not lsp_range:
+            return None
+        return CodeRange(cls._convert_position(lsp_range.start), cls._convert_position(lsp_range.end))
+
+    @classmethod
+    def _convert_position(cls, lsp_position: LSPPosition) -> CodePosition:
+        return CodePosition(lsp_position.line, lsp_position.character)
+
+    @classmethod
+    def _convert_severity(cls, severity: DiagnosticSeverity | None) -> ErrorSeverity:
+        if severity == DiagnosticSeverity.Information:
+            return ErrorSeverity.INFO
+        if severity == DiagnosticSeverity.Warning:
+            return ErrorSeverity.WARNING
+        if severity == DiagnosticSeverity.Error:
+            return ErrorSeverity.ERROR
+        return ErrorSeverity.INFO
+
+
 class LSPEngine(TranspileEngine):
 
     @classmethod
@@ -335,7 +388,7 @@ async def transpile(
         self.open_document(file_path, source_code=source_code)
         response = await self.transpile_document(file_path)
         self.close_document(file_path)
-        return ChangeManager.apply(source_code, file_path, response.changes)
+        return ChangeManager.apply(source_code, response.changes, response.diagnostics, file_path)
 
     def open_document(self, file_path: Path, encoding="utf-8", source_code: str | None = None) -> None:
         if source_code is None:
diff --git a/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py b/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py
index 7286d8f3b1..8e9e5aeff0 100644
--- a/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py
+++ b/src/databricks/labs/remorph/transpiler/sqlglot/sqlglot_engine.py
@@ -9,7 +9,7 @@
 from sqlglot.expressions import Expression
 from sqlglot.tokens import Token, TokenType
 
-from databricks.labs.remorph.config import TranspileResult
+from databricks.labs.remorph.config import TranspileResult, TranspileConfig
 from databricks.labs.remorph.helpers.string_utils import format_error_message
 from databricks.labs.remorph.transpiler.sqlglot import lca_utils
 from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect
@@ -68,6 +68,12 @@
                 problem_list.append(ParserProblem(parsed_expression.original_sql, error))
         return transpiled_sqls, problem_list
 
+    async def initialize(self, config: TranspileConfig) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
     async def transpile(
         self, source_dialect: str, target_dialect: str, source_code: str, file_path: Path
     ) -> TranspileResult:
diff --git a/src/databricks/labs/remorph/transpiler/transpile_engine.py b/src/databricks/labs/remorph/transpiler/transpile_engine.py
index 1d26500834..9c801fa660 100644
--- a/src/databricks/labs/remorph/transpiler/transpile_engine.py
+++ b/src/databricks/labs/remorph/transpiler/transpile_engine.py
@@ -2,7 +2,7 @@
 import abc
 from pathlib import Path
 
-from databricks.labs.remorph.config import TranspileResult
+from databricks.labs.remorph.config import TranspileResult, TranspileConfig
 
 
 class TranspileEngine(abc.ABC):
@@ -24,6 +24,12 @@ def load_engine(cls, transpiler_config_path: Path) -> TranspileEngine:
         return LSPEngine.from_config_path(transpiler_config_path)
 
+    @abc.abstractmethod
+    async def initialize(self, config: TranspileConfig) -> None: ...
+
+    @abc.abstractmethod
+    async def shutdown(self) -> None: ...
+
     @abc.abstractmethod
     async def transpile(
         self, source_dialect: str, target_dialect: str, source_code: str, file_path: Path
diff --git a/tests/resources/lsp_transpiler/internal.sql b/tests/resources/lsp_transpiler/internal.sql
new file mode 100644
index 0000000000..fc563464ee
--- /dev/null
+++ b/tests/resources/lsp_transpiler/internal.sql
@@ -0,0 +1 @@
+create table stuff(name varchar(12))
diff --git a/tests/resources/lsp_transpiler/lsp_server.py b/tests/resources/lsp_transpiler/lsp_server.py
index 488cbf64b6..e5f02adc11 100644
--- a/tests/resources/lsp_transpiler/lsp_server.py
+++ b/tests/resources/lsp_transpiler/lsp_server.py
@@ -1,6 +1,7 @@
 import os
 import sys
 from collections.abc import Sequence
+from pathlib import Path
 from typing import Any, Literal
 from uuid import uuid4
 
@@ -23,6 +24,7 @@
     Position,
     METHOD_TO_TYPES,
     _SPECIAL_PROPERTIES,
+    DiagnosticSeverity,
 )
 from pygls.lsp.server import LanguageServer
@@ -112,14 +114,39 @@ async def did_initialize(self, init_params: InitializeParams) -> None:
     def transpile_to_databricks(self, params: TranspileDocumentParams) -> TranspileDocumentResult:
         source_sql = self.workspace.get_text_document(params.uri).source
         source_lines = source_sql.split("\n")
-        transpiled_sql = source_sql.upper()
-        changes = [
-            TextEdit(
-                range=Range(start=Position(0, 0), end=Position(len(source_lines), len(source_lines[-1]))),
-                new_text=transpiled_sql,
+        range = Range(start=Position(0, 0), end=Position(len(source_lines), len(source_lines[-1])))
+        transpiled_sql, diagnostics = self._transpile(Path(params.uri).name, range, source_sql)
+        changes = [TextEdit(range=range, new_text=transpiled_sql)]
+        return TranspileDocumentResult(uri=params.uri, changes=changes, diagnostics=diagnostics)
+
+    def _transpile(self, file_name: str, lsp_range: Range, source_sql: str) -> tuple[str, list[Diagnostic]]:
+        if file_name == "no_transpile.sql":
+            diagnostic = Diagnostic(
+                range=lsp_range,
+                message="No transpilation required",
+                severity=DiagnosticSeverity.Information,
+                code="GENERATION-NOT_REQUIRED",
             )
-        ]
-        return TranspileDocumentResult(uri=params.uri, changes=changes, diagnostics=[])
+            return source_sql, [diagnostic]
+        elif file_name == "unsupported_lca.sql":
+            diagnostic = Diagnostic(
+                range=lsp_range,
+                message="LCA conversion not supported",
+                severity=DiagnosticSeverity.Error,
+                code="ANALYSIS-UNSUPPORTED_LCA",
+            )
+            return source_sql, [diagnostic]
+        elif file_name == "internal.sql":
+            diagnostic = Diagnostic(
+                range=lsp_range,
+                message="Something went wrong",
+                severity=DiagnosticSeverity.Warning,
+                code="SOME_ERROR_CODE",
+            )
+            return source_sql, [diagnostic]
+        else:
+            # general test case
+            return source_sql.upper(), []
 
 
 server = TestLspServer("test-lsp-server", "v0.1")
diff --git a/tests/resources/lsp_transpiler/no_transpile.sql b/tests/resources/lsp_transpiler/no_transpile.sql
new file mode 100644
index 0000000000..fc563464ee
--- /dev/null
+++ b/tests/resources/lsp_transpiler/no_transpile.sql
@@ -0,0 +1 @@
+create table stuff(name varchar(12))
diff --git a/tests/resources/lsp_transpiler/unsupported_lca.sql b/tests/resources/lsp_transpiler/unsupported_lca.sql
new file mode 100644
index 0000000000..fc563464ee
--- /dev/null
+++ b/tests/resources/lsp_transpiler/unsupported_lca.sql
@@ -0,0 +1 @@
+create table stuff(name varchar(12))
diff --git a/tests/unit/test_cli_transpile.py b/tests/unit/test_cli_transpile.py
index 80133e6c44..d6873e5ad7 100644
--- a/tests/unit/test_cli_transpile.py
+++ b/tests/unit/test_cli_transpile.py
@@ -1,4 +1,5 @@
-from unittest.mock import create_autospec, patch, PropertyMock, ANY
+import asyncio
+from unittest.mock import create_autospec, patch, PropertyMock, ANY, MagicMock
 
 import pytest
 
@@ -8,6 +9,7 @@
 from databricks.sdk import WorkspaceClient
 
 from databricks.labs.remorph.transpiler.transpile_engine import TranspileEngine
+from tests.unit.conftest import path_to_resource
 
 
 def test_transpile_with_missing_installation():
@@ -32,11 +34,22 @@
     )
 
 
+def patch_do_transpile():
+    mock_transpile = MagicMock(return_value=({}, []))
+
+    # coroutine wrapper so the patched do_transpile can be driven by asyncio.run()
+    async def patched_do_transpile(*args, **kwargs):
+        return mock_transpile(*args, **kwargs)
+
+    return mock_transpile, patched_do_transpile
+
+
 def test_transpile_with_no_sdk_config():
     workspace_client = create_autospec(WorkspaceClient)
+    mock_transpile, patched_do_transpile = patch_do_transpile()
     with (
         patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context,
-        patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile,
+        patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile),
         patch("os.path.exists", return_value=True),
     ):
         default_config = TranspileConfig(
@@ -85,10 +98,11 @@
 
 def test_transpile_with_warehouse_id_in_sdk_config():
     workspace_client = create_autospec(WorkspaceClient)
+    mock_transpile, patched_do_transpile = patch_do_transpile()
     with (
         patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context,
         patch("os.path.exists", return_value=True),
-        patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile,
+        patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile),
     ):
         sdk_config = {"warehouse_id": "w_id"}
         default_config = TranspileConfig(
@@ -137,10 +151,11 @@
 
 def test_transpile_with_cluster_id_in_sdk_config():
     workspace_client = create_autospec(WorkspaceClient)
+    mock_transpile, patched_do_transpile = patch_do_transpile()
    with (
         patch("databricks.labs.remorph.cli.ApplicationContext", autospec=True) as mock_app_context,
         patch("os.path.exists", return_value=True),
-        patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile,
+        patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile),
     ):
         sdk_config = {"cluster_id": "c_id"}
         default_config = TranspileConfig(
@@ -292,9 +307,10 @@ def test_transpile_with_valid_input(mock_workspace_client_cli):
     mode = "current"
     sdk_config = {'cluster_id': 'test_cluster'}
 
+    mock_transpile, patched_do_transpile = patch_do_transpile()
     with (
         patch("os.path.exists", return_value=True),
-        patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile,
+        patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile),
     ):
         cli.transpile(
             mock_workspace_client_cli,
@@ -326,6 +342,50 @@ def test_transpile_with_valid_input(mock_workspace_client_cli):
     )
 
 
+def test_transpile_with_valid_transpiler(mock_workspace_client_cli):
+    transpiler_config_path = path_to_resource("lsp_transpiler", "lsp_config.yml")
+    source_dialect = "snowflake"
+    input_source = path_to_resource("functional", "snowflake", "aggregates", "least_1.sql")
+    output_folder = path_to_resource("lsp_transpiler")
+    error_file = ""
+    skip_validation = "true"
+    catalog_name = "my_catalog"
+    schema_name = "my_schema"
+    mode = "current"
+    sdk_config = {'cluster_id': 'test_cluster'}
+
+    mock_transpile, patched_do_transpile = patch_do_transpile()
+    with (patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile),):
+        cli.transpile(
+            mock_workspace_client_cli,
+            transpiler_config_path,
+            source_dialect,
+            input_source,
+            output_folder,
+            error_file,
+            skip_validation,
+            catalog_name,
+            schema_name,
+            mode,
+        )
+    mock_transpile.assert_called_once_with(
+        mock_workspace_client_cli,
+        ANY,
+        TranspileConfig(
+            transpiler_config_path=transpiler_config_path,
+            source_dialect=source_dialect,
+            input_source=input_source,
+            output_folder=output_folder,
+            error_file_path=error_file,
+            sdk_config=sdk_config,
+            skip_validation=True,
+            catalog_name=catalog_name,
+            schema_name=schema_name,
+            mode=mode,
+        ),
+    )
+
+
 def test_transpile_empty_output_folder(mock_workspace_client_cli):
     transpiler = "sqlglot"
     source_dialect = "snowflake"
@@ -339,9 +399,10 @@
     mode = "current"
     sdk_config = {'cluster_id': 'test_cluster'}
 
+    mock_transpile, patched_do_transpile = patch_do_transpile()
     with (
         patch("os.path.exists", return_value=True),
-        patch("databricks.labs.remorph.cli.do_transpile", return_value={}) as mock_transpile,
+        patch("databricks.labs.remorph.cli.do_transpile", new=patched_do_transpile),
     ):
         cli.transpile(
             mock_workspace_client_cli,
@@ -400,3 +461,30 @@
             schema_name,
             mode,
         )
+
+
+def test_transpile_prints_errors(capsys, tmp_path, mock_workspace_client_cli):
+    transpiler_config_path = path_to_resource("lsp_transpiler", "lsp_config.yml")
+    source_dialect = "snowflake"
+    input_source = path_to_resource("lsp_transpiler", "unsupported_lca.sql")
+    output_folder = str(tmp_path)
+    error_file = None
+    skip_validation = "true"
+    catalog_name = "my_catalog"
+    schema_name = "my_schema"
+    mode = "current"
+    cli.transpile(
+        mock_workspace_client_cli,
+        transpiler_config_path,
+        source_dialect,
+        input_source,
+        output_folder,
+        error_file,
+        skip_validation,
+        catalog_name,
+        schema_name,
+        mode,
+    )
+    captured = capsys.readouterr()
+    assert "TranspileError" in captured.out
+    assert "UNSUPPORTED_LCA" in captured.out
diff --git a/tests/unit/transpiler/test_execute.py b/tests/unit/transpiler/test_execute.py
index ba6832de69..57442a8412 100644
--- a/tests/unit/transpiler/test_execute.py
+++ b/tests/unit/transpiler/test_execute.py
@@ -1,3 +1,4 @@
+import asyncio
 import re
 import shutil
 from pathlib import Path
@@ -8,11 +9,13 @@
 from databricks.connect import DatabricksSession
 from databricks.labs.lsql.backends import MockBackend
 from databricks.labs.lsql.core import Row
+from databricks.sdk import WorkspaceClient
+
 from databricks.labs.remorph.config import TranspileConfig, ValidationResult
 from databricks.labs.remorph.helpers.file_utils import make_dir
 from databricks.labs.remorph.helpers.validation import Validator
 from databricks.labs.remorph.transpiler.execute import (
-    transpile,
+    transpile as do_transpile,
     transpile_column_exp,
     transpile_sql,
 )
@@ -24,6 +27,10 @@
 # pylint: disable=unspecified-encoding
 
 
+def transpile(workspace_client: WorkspaceClient, engine: SqlglotEngine, config: TranspileConfig):
+    return asyncio.run(do_transpile(workspace_client, engine, config))
+
+
 def safe_remove_dir(dir_path: Path):
     if dir_path.exists():
         shutil.rmtree(dir_path)
@@ -164,7 +171,7 @@ def test_with_dir_skip_validation(initial_setup, tmp_path, mock_workspace_client
 
     # call transpile
     with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()):
-        status = transpile(mock_workspace_client, SqlglotEngine(), config)
+        status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config)
     # assert the status
     assert status is not None, "Status returned by transpile function is None"
     assert isinstance(status, list), "Status returned by transpile function is not a list"
@@ -224,7 +231,7 @@ def test_with_dir_with_output_folder_skip_validation(initial_setup, tmp_path, mo
         skip_validation=True,
     )
     with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()):
-        status = transpile(mock_workspace_client, SqlglotEngine(), config)
+        status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config)
     # assert the status
     assert status is not None, "Status returned by transpile function is None"
     assert isinstance(status, list), "Status returned by transpile function is not a list"
@@ -298,7 +305,7 @@ def test_with_file(initial_setup, tmp_path, mock_workspace_client):
         ),
         patch("databricks.labs.remorph.transpiler.execute.Validator", return_value=mock_validate),
     ):
-        status = transpile(mock_workspace_client, SqlglotEngine(), config)
+        status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config)
 
     # assert the status
     assert status is not None, "Status returned by transpile function is None"
@@ -347,7 +354,7 @@ def test_with_file_with_output_folder_skip_validation(initial_setup, mock_worksp
         'databricks.labs.remorph.helpers.db_sql.get_sql_backend',
         return_value=MockBackend(),
     ):
-        status = transpile(mock_workspace_client, SqlglotEngine(), config)
+        status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config)
 
     # assert the status
     assert status is not None, "Status returned by transpile function is None"
@@ -382,7 +389,7 @@ def test_with_not_a_sql_file_skip_validation(initial_setup, mock_workspace_clien
         'databricks.labs.remorph.helpers.db_sql.get_sql_backend',
         return_value=MockBackend(),
     ):
-        status = transpile(mock_workspace_client, SqlglotEngine(), config)
+        status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config)
 
     # assert the status
     assert status is not None, "Status returned by transpile function is None"
@@ -500,7 +507,7 @@ def test_with_file_with_success(initial_setup, mock_workspace_client):
         ),
         patch("databricks.labs.remorph.transpiler.execute.Validator", return_value=mock_validate),
     ):
-        status = transpile(mock_workspace_client, SqlglotEngine(), config)
+        status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config)
     # assert the status
     assert status is not None, "Status returned by transpile function is None"
     assert isinstance(status, list), "Status returned by transpile function is not a list"
@@ -544,7 +551,7 @@ def test_parse_error_handling(initial_setup, tmp_path, mock_workspace_client):
     )
 
     with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()):
-        status = transpile(mock_workspace_client, SqlglotEngine(), config)
+        status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config)
 
     # assert the status
     assert status is not None, "Status returned by transpile function is None"
@@ -602,7 +609,7 @@ def test_token_error_handling(initial_setup, tmp_path, mock_workspace_client):
     )
 
     with patch('databricks.labs.remorph.helpers.db_sql.get_sql_backend', return_value=MockBackend()):
-        status = transpile(mock_workspace_client, SqlglotEngine(), config)
+        status, _errors = transpile(mock_workspace_client, SqlglotEngine(), config)
     # assert the status
     assert status is not None, "Status returned by transpile function is None"
     assert isinstance(status, list), "Status returned by transpile function is not a list"
diff --git a/tests/unit/transpiler/test_lsp_engine.py b/tests/unit/transpiler/test_lsp_engine.py
index d34bff9825..c537c2c72a 100644
--- a/tests/unit/transpiler/test_lsp_engine.py
+++ b/tests/unit/transpiler/test_lsp_engine.py
@@ -1,4 +1,5 @@
 import asyncio
+import dataclasses
 from pathlib import Path
 from time import sleep
 
@@ -10,6 +11,7 @@
     LSPEngine,
     ChangeManager,
 )
+from databricks.labs.remorph.transpiler.transpile_status import TranspileError, ErrorSeverity, ErrorKind
 from tests.unit.conftest import path_to_resource
 
 
@@ -132,13 +134,58 @@ async def test_server_transpiles_document(lsp_engine, transpile_config):
     ],
 )
 def test_change_mgr_replaces_text(source, changes, expected):
-    result = ChangeManager.apply(source, Path("dummy.sql"), changes)
+    result = ChangeManager.apply(source, changes, [], Path())
     assert result.transpiled_code == expected
 
 
-def test_change_mgr_returns_error():
-    source = "abc"
-    changes = [TextEdit(Range(Position(9, 0), Position(10, 10)), "def")]
-    result = ChangeManager.apply(source, Path("dummy.sql"), changes)
-    assert result.transpiled_code == source
-    assert "Internal error" in result.error_list[0].message
+@pytest.mark.parametrize(
+    "resource, errors",
+    [
+        ("source_stuff.sql", []),
+        (
+            "no_transpile.sql",
+            [
+                TranspileError(
+                    "NOT_REQUIRED",
+                    ErrorKind.GENERATION,
+                    ErrorSeverity.INFO,
+                    Path("no_transpile.sql"),
+                    "No transpilation required",
+                )
+            ],
+        ),
+        (
+            "unsupported_lca.sql",
+            [
+                TranspileError(
+                    "UNSUPPORTED_LCA",
+                    ErrorKind.ANALYSIS,
+                    ErrorSeverity.ERROR,
+                    Path("unsupported_lca.sql"),
+                    "LCA conversion not supported",
+                )
+            ],
+        ),
+        (
+            "internal.sql",
+            [
+                TranspileError(
+                    "SOME_ERROR_CODE",
+                    ErrorKind.INTERNAL,
+                    ErrorSeverity.WARNING,
+                    Path("internal.sql"),
+                    "Something went wrong",
+                )
+            ],
+        ),
+    ],
+)
+async def test_client_translates_diagnostics(lsp_engine, transpile_config, resource, errors):
+    sample_path = Path(path_to_resource("lsp_transpiler", resource))
+    await lsp_engine.initialize(transpile_config)
+    result = await lsp_engine.transpile(
+        transpile_config.source_dialect, "databricks", sample_path.read_text(encoding="utf-8"), sample_path
+    )
+    await lsp_engine.shutdown()
+    actual = [dataclasses.replace(error, path=Path(error.path.name), range=None) for error in result.error_list]
+    assert actual == errors
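
Since `transpile` in `execute.py` is now a coroutine that returns the collected `TranspileError` list alongside the status payload, synchronous callers have to drive it through an event loop, as `cli.py` above now does. A minimal calling sketch, illustrative only (the `workspace_client`, `engine` and `config` values are assumed to be built the same way the CLI builds them):

    import asyncio

    from databricks.labs.remorph.transpiler.execute import transpile as do_transpile

    def run_transpile(workspace_client, engine, config):
        # transpile() awaits engine.initialize(), delegates to _do_transpile(),
        # awaits engine.shutdown(), and returns a (status, errors) tuple
        # instead of the previous bare status.
        status, errors = asyncio.run(do_transpile(workspace_client, engine, config))
        for error in errors:
            print(str(error))
        return status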
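
`TranspileEngine` now declares `initialize` and `shutdown` as abstract coroutines, so every engine implementation must provide them, even as no-ops the way `SqlglotEngine` does. A hypothetical skeleton (`EchoEngine` is an illustrative name; any abstract members of `TranspileEngine` not visible in this diff are omitted):

    from pathlib import Path

    from databricks.labs.remorph.config import TranspileConfig, TranspileResult
    from databricks.labs.remorph.transpiler.transpile_engine import TranspileEngine


    class EchoEngine(TranspileEngine):
        async def initialize(self, config: TranspileConfig) -> None:
            pass  # acquire resources here, e.g. the LSP engine starts its server

        async def shutdown(self) -> None:
            pass  # release whatever initialize() acquired

        async def transpile(
            self, source_dialect: str, target_dialect: str, source_code: str, file_path: Path
        ) -> TranspileResult:
            # trivially "transpile" by echoing the input back as one success
            return TranspileResult(source_code, 1, [])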
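
`DiagnosticConverter` assumes a `KIND-CODE` convention for LSP diagnostic codes: when the segment before the first hyphen names an `ErrorKind`, it becomes the error's kind and is stripped from the code; otherwise the kind falls back to `ErrorKind.INTERNAL` and the code is kept whole, with a missing severity defaulting to `ErrorSeverity.INFO`. The mapping for the codes exercised by the test fixtures above:

    # "GENERATION-NOT_REQUIRED"  -> kind=ErrorKind.GENERATION, code="NOT_REQUIRED"
    # "ANALYSIS-UNSUPPORTED_LCA" -> kind=ErrorKind.ANALYSIS,   code="UNSUPPORTED_LCA"
    # "SOME_ERROR_CODE"          -> kind=ErrorKind.INTERNAL,   code="SOME_ERROR_CODE"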