Improve Exception handling (#392)
closes #388
sundarshankar89 authored May 27, 2024
1 parent eb0fd83 commit b0a607f
Showing 7 changed files with 55 additions and 39 deletions.
5 changes: 3 additions & 2 deletions src/databricks/labs/remorph/reconcile/connectors/databricks.py
@@ -1,6 +1,7 @@
 import logging
 import re

+from pyspark.errors import PySparkException
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.functions import col
 from sqlglot import Dialect
@@ -53,7 +54,7 @@ def read_data(
         try:
             df = self._spark.sql(table_query)
             return df.select([col(column).alias(column.lower()) for column in df.columns])
-        except RuntimeError as e:
+        except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "data", table_query)

     def get_schema(
@@ -66,5 +67,5 @@ def get_schema(
         try:
             schema_df = self._spark.sql(schema_query).where("col_name not like '#%'").distinct()
             return [Schema(field.col_name.lower(), field.data_type.lower()) for field in schema_df.collect()]
-        except RuntimeError as e:
+        except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
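
Note: all three connectors funnel failures through the same log_and_throw_exception helper, presumably defined on the shared DataSource base class, which this diff does not show. A minimal sketch of that path, assuming the helper logs the error and re-raises it as the DataSourceRuntimeException that execute.py imports below (the message format is illustrative):

import logging

from pyspark.errors import PySparkException

from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException

logger = logging.getLogger(__name__)


def log_and_throw_exception(exception: Exception, fetch_type: str, query: str):
    # Sketch only: the real method lives on the DataSource base class.
    # Normalize driver-level RuntimeErrors and Spark's own exception hierarchy
    # into one domain exception the reconcile layer can handle uniformly.
    error_msg = f"Runtime exception occurred while fetching {fetch_type} using {query}: {exception}"
    logger.warning(error_msg)
    raise DataSourceRuntimeException(error_msg) from exception

The point of widening the except clauses: PySpark raises its own PySparkException hierarchy for analysis and execution failures, so a bare RuntimeError clause let those escape the handler.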
5 changes: 3 additions & 2 deletions src/databricks/labs/remorph/reconcile/connectors/oracle.py
@@ -1,5 +1,6 @@
 import re

+from pyspark.errors import PySparkException
 from pyspark.sql import DataFrame, DataFrameReader, SparkSession
 from sqlglot import Dialect

@@ -59,7 +60,7 @@ def read_data(
                 return self.reader(table_query).options(**self._get_timestamp_options()).load()
             options = self._get_jdbc_reader_options(options) | self._get_timestamp_options()
             return self.reader(table_query).options(**options).load()
-        except RuntimeError as e:
+        except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "data", table_query)

     def get_schema(
@@ -76,7 +77,7 @@ def get_schema(
         try:
             schema_df = self.reader(schema_query).load()
             return [Schema(field.column_name.lower(), field.data_type.lower()) for field in schema_df.collect()]
-        except RuntimeError as e:
+        except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)

     @staticmethod
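
A side note on the option handling in Oracle's read_data above: the | operator (PEP 584 dict union) merges the two option dicts, and keys from the right-hand operand win on conflict, so the timestamp options override any overlapping JDBC reader options. A quick illustration with made-up option names and values:

jdbc_opts = {"numPartitions": "4", "fetchsize": "100"}  # illustrative only
timestamp_opts = {"fetchsize": "1000"}                  # illustrative only
merged = jdbc_opts | timestamp_opts
print(merged)  # {'numPartitions': '4', 'fetchsize': '1000'} -- right side wins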
5 changes: 3 additions & 2 deletions src/databricks/labs/remorph/reconcile/connectors/snowflake.py
@@ -1,6 +1,7 @@
 import logging
 import re

+from pyspark.errors import PySparkException
 from pyspark.sql import DataFrame, DataFrameReader, SparkSession
 from pyspark.sql.functions import col
 from sqlglot import Dialect
@@ -63,7 +64,7 @@ def read_data(
                 .load()
             )
             return df.select([col(column).alias(column.lower()) for column in df.columns])
-        except RuntimeError as e:
+        except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "data", table_query)

     def get_schema(
@@ -80,7 +81,7 @@ def get_schema(
         try:
             schema_df = self.reader(schema_query).load()
             return [Schema(field.column_name.lower(), field.data_type.lower()) for field in schema_df.collect()]
-        except RuntimeError as e:
+        except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)

     def reader(self, query: str) -> DataFrameReader:
23 changes: 13 additions & 10 deletions src/databricks/labs/remorph/reconcile/connectors/source_adapter.py
@@ -11,13 +11,16 @@
 from databricks.sdk import WorkspaceClient


-class DataSourceAdapter:
-    @staticmethod
-    def create_adapter(engine: Dialect, spark: SparkSession, ws: WorkspaceClient, secret_scope: str) -> DataSource:
-        if isinstance(engine, Snow):
-            return SnowflakeDataSource(engine, spark, ws, secret_scope)
-        if isinstance(engine, Oracle):
-            return OracleDataSource(engine, spark, ws, secret_scope)
-        if isinstance(engine, Databricks):
-            return DatabricksDataSource(engine, spark, ws, secret_scope)
-        raise ValueError(f"Unsupported source type --> {engine}")
+def create_adapter(
+    engine: Dialect,
+    spark: SparkSession,
+    ws: WorkspaceClient,
+    secret_scope: str,
+) -> DataSource:
+    if isinstance(engine, Snow):
+        return SnowflakeDataSource(engine, spark, ws, secret_scope)
+    if isinstance(engine, Oracle):
+        return OracleDataSource(engine, spark, ws, secret_scope)
+    if isinstance(engine, Databricks):
+        return DatabricksDataSource(engine, spark, ws, secret_scope)
+    raise ValueError(f"Unsupported source type --> {engine}")
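
The refactor above replaces the DataSourceAdapter class, whose only member was a static method, with a plain module-level factory, so callers drop the throwaway instantiation. A usage sketch with mocks standing in for the real session and workspace client; the unknown-dialect branch mirrors the test at the bottom of this diff:

from unittest.mock import create_autospec

import pytest
from pyspark.sql import SparkSession
from sqlglot import Dialect

from databricks.labs.remorph.reconcile.connectors.source_adapter import create_adapter
from databricks.sdk import WorkspaceClient

spark = create_autospec(SparkSession)
ws = create_autospec(WorkspaceClient)

# Before: DataSourceAdapter().create_adapter(engine, spark, ws, "scope")
# After:  a direct call; a Snow/Oracle/Databricks dialect returns the matching DataSource.
with pytest.raises(ValueError, match="Unsupported source type"):
    create_adapter(Dialect(), spark, ws, "scope")  # a base Dialect matches no connector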
12 changes: 3 additions & 9 deletions src/databricks/labs/remorph/reconcile/execute.py
@@ -12,9 +12,7 @@
     reconcile_data,
 )
 from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
-from databricks.labs.remorph.reconcile.connectors.source_adapter import (
-    DataSourceAdapter,
-)
+from databricks.labs.remorph.reconcile.connectors.source_adapter import create_adapter
 from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException
 from databricks.labs.remorph.reconcile.query_builder.hash_query import HashQueryBuilder
 from databricks.labs.remorph.reconcile.query_builder.sampling_query import (
@@ -111,10 +109,8 @@ def initialise_data_source(
     engine: Dialect,
     secret_scope: str,
 ):
-    source = DataSourceAdapter().create_adapter(engine=engine, spark=spark, ws=ws, secret_scope=secret_scope)
-    target = DataSourceAdapter().create_adapter(
-        engine=get_dialect("databricks"), spark=spark, ws=ws, secret_scope=secret_scope
-    )
+    source = create_adapter(engine=engine, spark=spark, ws=ws, secret_scope=secret_scope)
+    target = create_adapter(engine=get_dialect("databricks"), spark=spark, ws=ws, secret_scope=secret_scope)

     return source, target

@@ -398,8 +394,6 @@ def _run_reconcile_data(
         return reconciler.reconcile_data(table_conf=table_conf, src_schema=src_schema, tgt_schema=tgt_schema)
     except DataSourceRuntimeException as e:
         return DataReconcileOutput(exception=str(e))
-    except PySparkException as e:
-        return DataReconcileOutput(exception=str(e))


 def _run_reconcile_schema(
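
The deletion in _run_reconcile_data is the payoff of the connector changes above: PySparkException is now translated into DataSourceRuntimeException at the connector boundary, so the reconcile layer never sees the raw Spark exception and the second handler was dead code. A condensed sketch of the surviving flow (signature simplified; DataReconcileOutput is imported in execute.py from a module this diff elides):

from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException


def _run_reconcile_data(reconciler, table_conf, src_schema, tgt_schema):
    try:
        return reconciler.reconcile_data(table_conf=table_conf, src_schema=src_schema, tgt_schema=tgt_schema)
    except DataSourceRuntimeException as e:
        # One domain-level handler now also covers Spark failures, because the
        # connectors re-raise PySparkException before it can propagate here.
        return DataReconcileOutput(exception=str(e))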
32 changes: 25 additions & 7 deletions tests/unit/reconcile/test_execute.py
@@ -586,7 +586,10 @@ def mock_for_report_type_data(


 def test_recon_for_report_type_is_data(
-    mock_workspace_client, mock_spark, report_tables_schema, mock_for_report_type_data
+    mock_workspace_client,
+    mock_spark,
+    report_tables_schema,
+    mock_for_report_type_data,
 ):
     recon_schema, metrics_schema, details_schema = report_tables_schema
     table_recon, source, target = mock_for_report_type_data
@@ -1173,7 +1176,10 @@ def mock_for_report_type_row(table_conf_with_opts, table_schema, mock_spark, que


 def test_recon_for_report_type_is_row(
-    mock_workspace_client, mock_spark, mock_for_report_type_row, report_tables_schema
+    mock_workspace_client,
+    mock_spark,
+    mock_for_report_type_row,
+    report_tables_schema,
 ):
     recon_schema, metrics_schema, details_schema = report_tables_schema
     source, target, table_recon = mock_for_report_type_row
@@ -1297,7 +1303,10 @@ def mock_for_recon_exception(table_conf_with_opts, setup_metadata_table):


 def test_schema_recon_with_data_source_exception(
-    mock_workspace_client, mock_spark, report_tables_schema, mock_for_recon_exception
+    mock_workspace_client,
+    mock_spark,
+    report_tables_schema,
+    mock_for_recon_exception,
 ):
     recon_schema, metrics_schema, details_schema = report_tables_schema
     table_recon, source, target = mock_for_recon_exception
@@ -1359,7 +1368,10 @@ def test_schema_recon_with_data_source_exception(


 def test_schema_recon_with_general_exception(
-    mock_workspace_client, mock_spark, report_tables_schema, mock_for_report_type_schema
+    mock_workspace_client,
+    mock_spark,
+    report_tables_schema,
+    mock_for_report_type_schema,
 ):
     recon_schema, metrics_schema, details_schema = report_tables_schema
     table_recon, source, target = mock_for_report_type_schema
@@ -1423,7 +1435,10 @@ def test_schema_recon_with_general_exception(


 def test_data_recon_with_general_exception(
-    mock_workspace_client, mock_spark, report_tables_schema, mock_for_report_type_schema
+    mock_workspace_client,
+    mock_spark,
+    report_tables_schema,
+    mock_for_report_type_schema,
 ):
     recon_schema, metrics_schema, details_schema = report_tables_schema
     table_recon, source, target = mock_for_report_type_schema
@@ -1438,7 +1453,7 @@ def test_data_recon_with_general_exception(
         ),
         patch("databricks.labs.remorph.reconcile.execute.Reconciliation.reconcile_data") as data_source_mock,
     ):
-        data_source_mock.side_effect = PySparkException("Unknown Error")
+        data_source_mock.side_effect = DataSourceRuntimeException("Unknown Error")
         mock_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185)
         recon_datetime.now.return_value = datetime(2024, 5, 23, 9, 21, 25, 122185)
         recon_id = recon(mock_workspace_client, mock_spark, table_recon, get_dialect("snowflake"), "data")
@@ -1487,7 +1502,10 @@ def test_data_recon_with_general_exception(


 def test_data_recon_with_source_exception(
-    mock_workspace_client, mock_spark, report_tables_schema, mock_for_report_type_schema
+    mock_workspace_client,
+    mock_spark,
+    report_tables_schema,
+    mock_for_report_type_schema,
 ):
     recon_schema, metrics_schema, details_schema = report_tables_schema
     table_recon, source, target = mock_for_report_type_schema
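
The test updates in this file are either cosmetic (one fixture per line) or follow the pattern condensed below: the failure-path test now feeds the mock the domain exception the connectors actually raise rather than a raw PySparkException. A sketch assuming recon, get_dialect, and the fixtures (mock_workspace_client, mock_spark, table_recon) are available exactly as in the suite above:

from unittest.mock import patch

from databricks.labs.remorph.reconcile.exception import DataSourceRuntimeException

with patch("databricks.labs.remorph.reconcile.execute.Reconciliation.reconcile_data") as data_source_mock:
    # Simulate the failure the way it now actually surfaces at this layer.
    data_source_mock.side_effect = DataSourceRuntimeException("Unknown Error")
    recon_id = recon(mock_workspace_client, mock_spark, table_recon, get_dialect("snowflake"), "data")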
12 changes: 5 additions & 7 deletions tests/unit/reconcile/test_source_adapter.py
@@ -7,9 +7,7 @@
 from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource
 from databricks.labs.remorph.reconcile.connectors.oracle import OracleDataSource
 from databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource
-from databricks.labs.remorph.reconcile.connectors.source_adapter import (
-    DataSourceAdapter,
-)
+from databricks.labs.remorph.reconcile.connectors.source_adapter import create_adapter
 from databricks.sdk import WorkspaceClient


@@ -19,7 +17,7 @@ def test_create_adapter_for_snowflake_dialect():
     ws = create_autospec(WorkspaceClient)
     scope = "scope"

-    data_source = DataSourceAdapter().create_adapter(engine, spark, ws, scope)
+    data_source = create_adapter(engine, spark, ws, scope)
     snowflake_data_source = SnowflakeDataSource(engine, spark, ws, scope).__class__

     assert isinstance(data_source, snowflake_data_source)
@@ -31,7 +29,7 @@ def test_create_adapter_for_oracle_dialect():
     ws = create_autospec(WorkspaceClient)
     scope = "scope"

-    data_source = DataSourceAdapter().create_adapter(engine, spark, ws, scope)
+    data_source = create_adapter(engine, spark, ws, scope)
     oracle_data_source = OracleDataSource(engine, spark, ws, scope).__class__

     assert isinstance(data_source, oracle_data_source)
@@ -43,7 +41,7 @@ def test_create_adapter_for_databricks_dialect():
     ws = create_autospec(WorkspaceClient)
     scope = "scope"

-    data_source = DataSourceAdapter().create_adapter(engine, spark, ws, scope)
+    data_source = create_adapter(engine, spark, ws, scope)
     databricks_data_source = DatabricksDataSource(engine, spark, ws, scope).__class__

     assert isinstance(data_source, databricks_data_source)
@@ -56,4 +54,4 @@ def test_raise_exception_for_unknown_dialect():
     scope = "scope"

     with pytest.raises(ValueError, match=f"Unsupported source type --> {engine}"):
-        DataSourceAdapter().create_adapter(engine, spark, ws, scope)
+        create_adapter(engine, spark, ws, scope)
