Skip to content

Commit

Permalink
Added Threshold Query Builder (databrickslabs#188)
Browse files Browse the repository at this point in the history
  • Loading branch information
ravit-db authored Mar 24, 2024
1 parent 218fa91 commit 3b5ffba
Show file tree
Hide file tree
Showing 10 changed files with 563 additions and 226 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ branch = true
parallel = true

[tool.coverage.report]
omit = ["src/databricks/labs/remorph/reconcile/*",
omit = [
"src/databricks/labs/remorph/coverage/*",
"src/databricks/labs/remorph/helpers/execution_time.py",
"__about__.py"]
Expand Down
31 changes: 19 additions & 12 deletions src/databricks/labs/remorph/reconcile/connectors/data_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from abc import ABC, abstractmethod

from databricks.sdk import WorkspaceClient # pylint: disable-next=wrong-import-order
Expand All @@ -6,7 +7,6 @@
from databricks.labs.remorph.reconcile.recon_config import ( # pylint: disable=ungrouped-imports
JdbcReaderOptions,
Schema,
Tables,
)


Expand All @@ -20,11 +20,11 @@ def __init__(self, source: str, spark: SparkSession, ws: WorkspaceClient, scope:
self.scope = scope

@abstractmethod
def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame:
def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame:
return NotImplemented

@abstractmethod
def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]:
def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]:
return NotImplemented

def _get_jdbc_reader(self, query, jdbc_url, driver):
Expand All @@ -36,15 +36,22 @@ def _get_jdbc_reader(self, query, jdbc_url, driver):
)

@staticmethod
def _get_jdbc_reader_options(jdbc_reader_options: JdbcReaderOptions):
def _get_jdbc_reader_options(options: JdbcReaderOptions):
return {
"numPartitions": jdbc_reader_options.number_partitions,
"partitionColumn": jdbc_reader_options.partition_column,
"lowerBound": jdbc_reader_options.lower_bound,
"upperBound": jdbc_reader_options.upper_bound,
"fetchsize": jdbc_reader_options.fetch_size,
"numPartitions": options.number_partitions,
"partitionColumn": options.partition_column,
"lowerBound": options.lower_bound,
"upperBound": options.upper_bound,
"fetchsize": options.fetch_size,
}

def _get_secrets(self, key_name):
key = self.source + '_' + key_name
return self.ws.secrets.get_secret(self.scope, key)
def _get_secrets(self, key):
return self.ws.secrets.get_secret(self.scope, self.source + '_' + key)

@staticmethod
def _get_table_or_query(catalog: str, schema: str, query: str) -> str:
if re.search('select', query, re.IGNORECASE):
return query.format(catalog_name=catalog, schema_name=schema)
if catalog:
return catalog + "." + schema + "." + query
return schema + "." + query
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from pyspark.sql import DataFrame

from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
from databricks.labs.remorph.reconcile.recon_config import Schema, Tables
from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema


class DatabricksDataSource(DataSource):
def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame:
def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame:
# Implement Databricks-specific logic here
return NotImplemented

def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]:
def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]:
# Implement Databricks-specific logic here
return NotImplemented

Expand Down
29 changes: 12 additions & 17 deletions src/databricks/labs/remorph/reconcile/connectors/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
from databricks.labs.remorph.reconcile.constants import SourceDriver
from databricks.labs.remorph.reconcile.recon_config import Schema, Tables
from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema


class OracleDataSource(DataSource):
Expand All @@ -16,32 +16,27 @@ def get_jdbc_url(self) -> str:
f":{self._get_secrets('port')}/{self._get_secrets('database')}"
)

# TODO need to check schema_name,catalog_name is needed
def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame:
def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame:
try:
if table_conf.jdbc_reader_options is None:
return self.reader(query).options(**self._get_timestamp_options()).load()
return (
self.reader(query)
.options(
**self._get_jdbc_reader_options(table_conf.jdbc_reader_options) | self._get_timestamp_options()
)
.load()
)
table_query = self._get_table_or_query(catalog, schema, query)
if options is None:
return self.reader(table_query).options(**self._get_timestamp_options()).load()
options = self._get_jdbc_reader_options(options) | self._get_timestamp_options()
return self.reader(table_query).options(**options).load()
except PySparkException as e:
error_msg = (
f"An error occurred while fetching Oracle Data using the following {query} in OracleDataSource : {e!s}"
)
raise PySparkException(error_msg) from e

def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]:
def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]:
try:
schema_query = self._get_schema_query(table_name, schema_name)
schema_query = self._get_schema_query(table, schema)
schema_df = self.reader(schema_query).load()
return [Schema(field.column_name.lower(), field.data_type.lower()) for field in schema_df.collect()]
except PySparkException as e:
error_msg = (
f"An error occurred while fetching Oracle Schema using the following {table_name} in "
f"An error occurred while fetching Oracle Schema using the following {table} in "
f"OracleDataSource: {e!s}"
)
raise PySparkException(error_msg) from e
Expand All @@ -63,7 +58,7 @@ def reader(self, query: str) -> DataFrameReader:
return self._get_jdbc_reader(query, self.get_jdbc_url, SourceDriver.ORACLE.value)

@staticmethod
def _get_schema_query(table_name: str, owner: str) -> str:
def _get_schema_query(table: str, owner: str) -> str:
return f"""select column_name, case when (data_precision is not null
and data_scale <> 0)
then data_type || '(' || data_precision || ',' || data_scale || ')'
Expand All @@ -75,4 +70,4 @@ def _get_schema_query(table_name: str, owner: str) -> str:
else data_type || '(' || CHAR_LENGTH || ')'
end data_type
FROM ALL_TAB_COLUMNS
WHERE lower(TABLE_NAME) = '{table_name}' and lower(owner) = '{owner}' """
WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}' """
6 changes: 3 additions & 3 deletions src/databricks/labs/remorph/reconcile/connectors/snowflake.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from pyspark.sql import DataFrame

from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
from databricks.labs.remorph.reconcile.recon_config import Schema, Tables
from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema


class SnowflakeDataSource(DataSource):
def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame:
def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame:
# Implement Snowflake-specific logic here
return NotImplemented

def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]:
def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]:
# Implement Snowflake-specific logic here
return NotImplemented

Expand Down
6 changes: 3 additions & 3 deletions src/databricks/labs/remorph/reconcile/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from databricks.labs.blueprint.installation import Installation

from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
from databricks.labs.remorph.reconcile.recon_config import TableRecon, Tables
from databricks.labs.remorph.reconcile.recon_config import Table, TableRecon

logger = logging.getLogger(__name__)

Expand All @@ -27,10 +27,10 @@ def __init__(self, source: DataSource, target: DataSource):
self.source = source
self.target = target

def compare_schemas(self, table_conf: Tables, schema_name: str, catalog_name: str) -> bool:
def compare_schemas(self, table_conf: Table, schema_name: str, catalog_name: str) -> bool:
raise NotImplementedError

def compare_data(self, table_conf: Tables, schema_name: str, catalog_name: str) -> bool:
def compare_data(self, table_conf: Table, schema_name: str, catalog_name: str) -> bool:
raise NotImplementedError


Expand Down
Loading

0 comments on commit 3b5ffba

Please sign in to comment.