diff --git a/pyproject.toml b/pyproject.toml
index 56d48c44b1..2a1dfe56cf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -132,7 +132,7 @@ branch = true
 parallel = true
 
 [tool.coverage.report]
-omit = ["src/databricks/labs/remorph/reconcile/*",
+omit = [
     "src/databricks/labs/remorph/coverage/*",
     "src/databricks/labs/remorph/helpers/execution_time.py",
     "__about__.py"]
diff --git a/src/databricks/labs/remorph/reconcile/connectors/data_source.py b/src/databricks/labs/remorph/reconcile/connectors/data_source.py
index df0a796de0..8487b76b2c 100644
--- a/src/databricks/labs/remorph/reconcile/connectors/data_source.py
+++ b/src/databricks/labs/remorph/reconcile/connectors/data_source.py
@@ -1,3 +1,4 @@
+import re
 from abc import ABC, abstractmethod
 from databricks.sdk import WorkspaceClient
 # pylint: disable-next=wrong-import-order
@@ -6,7 +7,6 @@
 from databricks.labs.remorph.reconcile.recon_config import (  # pylint: disable=ungrouped-imports
     JdbcReaderOptions,
     Schema,
-    Tables,
 )
@@ -20,11 +20,11 @@ def __init__(self, source: str, spark: SparkSession, ws: WorkspaceClient, scope:
         self.scope = scope
 
     @abstractmethod
-    def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame:
+    def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame:
         return NotImplemented
 
     @abstractmethod
-    def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]:
+    def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]:
         return NotImplemented
 
     def _get_jdbc_reader(self, query, jdbc_url, driver):
@@ -36,15 +36,22 @@ def _get_jdbc_reader(self, query, jdbc_url, driver):
         )
 
     @staticmethod
-    def _get_jdbc_reader_options(jdbc_reader_options: JdbcReaderOptions):
+    def _get_jdbc_reader_options(options: JdbcReaderOptions):
         return {
-            "numPartitions": jdbc_reader_options.number_partitions,
-            "partitionColumn": jdbc_reader_options.partition_column,
-            "lowerBound": jdbc_reader_options.lower_bound,
-            "upperBound": jdbc_reader_options.upper_bound,
-            "fetchsize": jdbc_reader_options.fetch_size,
+            "numPartitions": options.number_partitions,
+            "partitionColumn": options.partition_column,
+            "lowerBound": options.lower_bound,
+            "upperBound": options.upper_bound,
+            "fetchsize": options.fetch_size,
         }
 
-    def _get_secrets(self, key_name):
-        key = self.source + '_' + key_name
-        return self.ws.secrets.get_secret(self.scope, key)
+    def _get_secrets(self, key):
+        return self.ws.secrets.get_secret(self.scope, self.source + '_' + key)
+
+    @staticmethod
+    def _get_table_or_query(catalog: str, schema: str, query: str) -> str:
+        if re.search('select', query, re.IGNORECASE):
+            return query.format(catalog_name=catalog, schema_name=schema)
+        if catalog:
+            return catalog + "." + schema + "." + query
+        return schema + "." + query
diff --git a/src/databricks/labs/remorph/reconcile/connectors/databricks.py b/src/databricks/labs/remorph/reconcile/connectors/databricks.py
index d4be097304..991b63a01d 100644
--- a/src/databricks/labs/remorph/reconcile/connectors/databricks.py
+++ b/src/databricks/labs/remorph/reconcile/connectors/databricks.py
@@ -1,15 +1,15 @@
 from pyspark.sql import DataFrame
 
 from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
-from databricks.labs.remorph.reconcile.recon_config import Schema, Tables
+from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema
 
 
 class DatabricksDataSource(DataSource):
 
-    def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame:
+    def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame:
         # Implement Databricks-specific logic here
         return NotImplemented
 
-    def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]:
+    def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]:
         # Implement Databricks-specific logic here
         return NotImplemented
diff --git a/src/databricks/labs/remorph/reconcile/connectors/oracle.py b/src/databricks/labs/remorph/reconcile/connectors/oracle.py
index c3da2bd1f8..49f4a23766 100644
--- a/src/databricks/labs/remorph/reconcile/connectors/oracle.py
+++ b/src/databricks/labs/remorph/reconcile/connectors/oracle.py
@@ -3,7 +3,7 @@
 from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
 from databricks.labs.remorph.reconcile.constants import SourceDriver
-from databricks.labs.remorph.reconcile.recon_config import Schema, Tables
+from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema
 
 
 class OracleDataSource(DataSource):
@@ -16,32 +16,27 @@ def get_jdbc_url(self) -> str:
             f":{self._get_secrets('port')}/{self._get_secrets('database')}"
         )
 
-    # TODO need to check schema_name,catalog_name is needed
-    def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame:
+    def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame:
         try:
-            if table_conf.jdbc_reader_options is None:
-                return self.reader(query).options(**self._get_timestamp_options()).load()
-            return (
-                self.reader(query)
-                .options(
-                    **self._get_jdbc_reader_options(table_conf.jdbc_reader_options) | self._get_timestamp_options()
-                )
-                .load()
-            )
+            table_query = self._get_table_or_query(catalog, schema, query)
+            if options is None:
+                return self.reader(table_query).options(**self._get_timestamp_options()).load()
+            options = self._get_jdbc_reader_options(options) | self._get_timestamp_options()
+            return self.reader(table_query).options(**options).load()
         except PySparkException as e:
             error_msg = (
                 f"An error occurred while fetching Oracle Data using the following {query} in OracleDataSource : {e!s}"
             )
             raise PySparkException(error_msg) from e
 
-    def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]:
+    def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]:
         try:
-            schema_query = self._get_schema_query(table_name, schema_name)
+            schema_query = self._get_schema_query(table, schema)
             schema_df = self.reader(schema_query).load()
             return [Schema(field.column_name.lower(), field.data_type.lower()) for field in schema_df.collect()]
         except PySparkException as e:
             error_msg = (
-                f"An error occurred while fetching Oracle Schema using the following {table_name} in "
+                f"An error occurred while fetching Oracle Schema using the following {table} in "
                 f"OracleDataSource: {e!s}"
             )
             raise PySparkException(error_msg) from e
@@ -63,7 +58,7 @@ def reader(self, query: str) -> DataFrameReader:
         return self._get_jdbc_reader(query, self.get_jdbc_url, SourceDriver.ORACLE.value)
 
     @staticmethod
-    def _get_schema_query(table_name: str, owner: str) -> str:
+    def _get_schema_query(table: str, owner: str) -> str:
         return f"""select column_name, case when (data_precision is not null and data_scale <> 0)
                    then data_type || '(' || data_precision || ',' || data_scale || ')'
@@ -75,4 +70,4 @@ def _get_schema_query(table_name: str, owner: str) -> str:
                    else data_type || '(' || CHAR_LENGTH || ')'
                    end data_type
                    FROM ALL_TAB_COLUMNS
-                   WHERE lower(TABLE_NAME) = '{table_name}' and lower(owner) = '{owner}' """
+                   WHERE lower(TABLE_NAME) = '{table}' and lower(owner) = '{owner}' """
diff --git a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py
index f36c52381f..05bc88dab2 100644
--- a/src/databricks/labs/remorph/reconcile/connectors/snowflake.py
+++ b/src/databricks/labs/remorph/reconcile/connectors/snowflake.py
@@ -1,15 +1,15 @@
 from pyspark.sql import DataFrame
 
 from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
-from databricks.labs.remorph.reconcile.recon_config import Schema, Tables
+from databricks.labs.remorph.reconcile.recon_config import JdbcReaderOptions, Schema
 
 
 class SnowflakeDataSource(DataSource):
 
-    def read_data(self, schema_name: str, catalog_name: str, query: str, table_conf: Tables) -> DataFrame:
+    def read_data(self, catalog: str, schema: str, query: str, options: JdbcReaderOptions) -> DataFrame:
         # Implement Snowflake-specific logic here
         return NotImplemented
 
-    def get_schema(self, table_name: str, schema_name: str, catalog_name: str) -> list[Schema]:
+    def get_schema(self, catalog: str, schema: str, table: str) -> list[Schema]:
         # Implement Snowflake-specific logic here
         return NotImplemented
diff --git a/src/databricks/labs/remorph/reconcile/execute.py b/src/databricks/labs/remorph/reconcile/execute.py
index 7ef53e347d..c23a1ed395 100644
--- a/src/databricks/labs/remorph/reconcile/execute.py
+++ b/src/databricks/labs/remorph/reconcile/execute.py
@@ -4,7 +4,7 @@
 from databricks.labs.blueprint.installation import Installation
 
 from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
-from databricks.labs.remorph.reconcile.recon_config import TableRecon, Tables
+from databricks.labs.remorph.reconcile.recon_config import Table, TableRecon
 
 logger = logging.getLogger(__name__)
@@ -27,10 +27,10 @@ def __init__(self, source: DataSource, target: DataSource):
         self.source = source
         self.target = target
 
-    def compare_schemas(self, table_conf: Tables, schema_name: str, catalog_name: str) -> bool:
+    def compare_schemas(self, table_conf: Table, schema_name: str, catalog_name: str) -> bool:
         raise NotImplementedError
 
-    def compare_data(self, table_conf: Tables, schema_name: str, catalog_name: str) -> bool:
+    def compare_data(self, table_conf: Table, schema_name: str, catalog_name: str) -> bool:
         raise NotImplementedError
diff --git a/src/databricks/labs/remorph/reconcile/query_builder.py b/src/databricks/labs/remorph/reconcile/query_builder.py
index e09177cca3..01ef2ec89a 100644
--- a/src/databricks/labs/remorph/reconcile/query_builder.py
+++ b/src/databricks/labs/remorph/reconcile/query_builder.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from io import StringIO
 
 from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource
@@ -8,165 +9,185 @@
     Constants,
     SourceType,
 )
+from databricks.labs.remorph.reconcile.query_config import QueryConfig
 from databricks.labs.remorph.reconcile.recon_config import (
     ColumnMapping,
     Schema,
-    Tables,
     Transformation,
     TransformRuleMapping,
 )
 
 
-class QueryBuilder:
+class QueryBuilder(ABC):
 
-    def __init__(self, table_conf: Tables, schema: list[Schema], layer: str, source: str):
-        self.table_conf = table_conf
-        self.schema = schema
-        self.layer = layer
-        self.source = source
+    def __init__(self, qrc: QueryConfig):
+        self.qrc = qrc
 
-    def build_hash_query(self) -> str:
-        schema_info = {v.column_name: v for v in self.schema}
+    @abstractmethod
+    def build_query(self):
+        raise NotImplementedError
 
-        columns, key_columns = self._get_column_list()
-        col_transformations = self._generate_transformation_rule_mapping(columns, schema_info)
+    def _get_custom_transformation(
+        self, cols: list[str], transform_dict: dict[str, Transformation], col_mapping: dict[str, ColumnMapping]
+    ) -> list[TransformRuleMapping]:
+        transform_rule_mapping = []
+        for col in cols:
+            if col in transform_dict.keys():
+                transform = self._get_layer_transform(transform_dict, col, self.qrc.layer)
+            else:
+                transform = None
 
-        hash_columns_expr = self._get_column_expr(
-            TransformRuleMapping.get_column_expression_without_alias, col_transformations
-        )
-        hash_expr = self._generate_hash_algorithm(self.source, hash_columns_expr)
+            col_origin, col_alias = self._get_column_alias(self.qrc.layer, col, col_mapping)
 
-        key_column_transformation = self._generate_transformation_rule_mapping(key_columns, schema_info)
-        key_column_expr = self._get_column_expr(
-            TransformRuleMapping.get_column_expression_with_alias, key_column_transformation
-        )
+            transform_rule_mapping.append(TransformRuleMapping(col_origin, transform, col_alias))
 
-        if self.layer == "source":
-            table_name = self.table_conf.source_name
-            query_filter = self.table_conf.filters.source if self.table_conf.filters else " 1 = 1 "
-        else:
-            table_name = self.table_conf.target_name
-            query_filter = self.table_conf.filters.target if self.table_conf.filters else " 1 = 1 "
+        return transform_rule_mapping
 
-        # construct select query
-        select_query = self._construct_hash_query(table_name, query_filter, hash_expr, key_column_expr)
+    def _get_default_transformation(
+        self, cols: list[str], col_mapping: dict[str, ColumnMapping], schema: dict[str, Schema]
+    ) -> list[TransformRuleMapping]:
+        transform_rule_mapping = []
+        for col in cols:
+            col_origin = col if self.qrc.layer == "source" else self._get_column_map(col, col_mapping)
+            col_data_type = schema.get(col_origin).data_type
+            transform = self._get_default_transformation_expr(self.qrc.source, col_data_type).format(col_origin)
 
-        return select_query
+            col_origin, col_alias = self._get_column_alias(self.qrc.layer, col, col_mapping)
 
-    def _get_column_list(self) -> tuple[list[str], list[str]]:
-        column_mapping = self.table_conf.list_to_dict(ColumnMapping, "source_name")
+            transform_rule_mapping.append(TransformRuleMapping(col_origin, transform, col_alias))
 
-        if self.table_conf.join_columns is None:
-            join_columns = set()
-        elif self.layer == "source":
-            join_columns = {col.source_name for col in self.table_conf.join_columns}
-        else:
-            join_columns = {col.target_name for col in self.table_conf.join_columns}
+        return transform_rule_mapping
 
-        if self.table_conf.select_columns is None:
-            select_columns = {sch.column_name for sch in self.schema}
-        else:
-            select_columns = self._get_mapped_columns(self.layer, column_mapping, self.table_conf.select_columns)
+    @staticmethod
+    def _get_default_transformation_expr(data_source: str, data_type: str) -> str:
+        if data_source == SourceType.ORACLE.value:
+            return OracleDataSource.oracle_datatype_mapper.get(data_type, ColumnTransformationType.ORACLE_DEFAULT.value)
+        if data_source == SourceType.SNOWFLAKE.value:
+            return SnowflakeDataSource.snowflake_datatype_mapper.get(
+                data_type, ColumnTransformationType.SNOWFLAKE_DEFAULT.value
+            )
+        if data_source == SourceType.DATABRICKS.value:
+            return DatabricksDataSource.databricks_datatype_mapper.get(
+                data_type, ColumnTransformationType.DATABRICKS_DEFAULT.value
+            )
+        msg = f"Unsupported source type --> {data_source}"
+        raise ValueError(msg)
+
+    def _generate_transform_rule_mapping(self, cols: list[str]) -> list[TransformRuleMapping]:
 
-        if self.table_conf.jdbc_reader_options and self.layer == "source":
-            partition_column = {self.table_conf.jdbc_reader_options.partition_column}
+        # compute custom transformation
+        if self.qrc.transform_dict:
+            cols_with_transform = [col for col in cols if col in self.qrc.transform_dict.keys()]
+            custom_transform = self._get_custom_transformation(
+                cols_with_transform, self.qrc.transform_dict, self.qrc.src_col_mapping
+            )
         else:
-            partition_column = set()
+            custom_transform = []
 
-        # Combine all column names
-        all_columns = join_columns | select_columns
+        # compute default transformation
+        cols_without_transform = [col for col in cols if col not in self.qrc.transform_dict.keys()]
+        default_transform = self._get_default_transformation(
+            cols_without_transform, self.qrc.src_col_mapping, self.qrc.schema_dict
+        )
 
-        # Remove threshold and drop columns
-        threshold_columns = {thresh.column_name for thresh in self.table_conf.thresholds or []}
-        if self.table_conf.drop_columns is None:
-            drop_columns = set()
-        else:
-            drop_columns = self._get_mapped_columns(self.layer, column_mapping, self.table_conf.drop_columns)
+        transform_rule_mapping = custom_transform + default_transform
 
-        columns = sorted(all_columns - threshold_columns - drop_columns)
-        key_columns = sorted(join_columns | partition_column)
+        return transform_rule_mapping
 
-        return columns, key_columns
+    @staticmethod
+    def _get_layer_transform(transform_dict: dict[str, Transformation], col: str, layer: str) -> str:
+        return transform_dict.get(col).source if layer == "source" else transform_dict.get(col).target
 
-    def _generate_transformation_rule_mapping(self, columns: list[str], schema: dict) -> list[TransformRuleMapping]:
-        transformations_dict = self.table_conf.list_to_dict(Transformation, "column_name")
-        column_mapping_dict = self.table_conf.list_to_dict(ColumnMapping, "target_name")
+    @staticmethod
+    def _get_column_expr(func, col_transform: list[TransformRuleMapping]):
+        return [func(transform) for transform in col_transform]
 
-        transformation_rule_mapping = []
-        for column in columns:
-            if column_mapping_dict and self.layer == "target":
-                transform_column = (
-                    column_mapping_dict.get(column).source_name if column_mapping_dict.get(column) else column
-                )
-            else:
-                transform_column = column
+    @staticmethod
+    def _get_column_map(col, col_mapping: dict[str, ColumnMapping]) -> str:
+        return col_mapping.get(col, ColumnMapping(source_name='', target_name=col)).target_name
 
-            if transformations_dict and transform_column in transformations_dict.keys():
-                transformation = self._get_layer_transform(transformations_dict, transform_column, self.layer)
-            else:
-                column_data_type = schema.get(column).data_type
-                transformation = self._get_default_transformation(self.source, column_data_type).format(column)
+    @staticmethod
+    def _get_column_alias(layer: str, col: str, col_mapping: dict[str, ColumnMapping]) -> tuple[str, str]:
+        if col_mapping and col in col_mapping.keys() and layer == "target":
+            col_alias = col_mapping.get(col).source_name
+            col_origin = col_mapping.get(col).target_name
+        else:
+            col_alias = col
+            col_origin = col
 
-            if column_mapping_dict and column in column_mapping_dict.keys():
-                column_alias = column_mapping_dict.get(column).source_name
-            else:
-                column_alias = column
+        return col_origin, col_alias
 
-            transformation_rule_mapping.append(TransformRuleMapping(column, transformation, column_alias))
-        return transformation_rule_mapping
+
+class HashQueryBuilder(QueryBuilder):
 
-    @staticmethod
-    def _get_default_transformation(data_source: str, data_type: str) -> str:
-        if data_source == "oracle":
-            return OracleDataSource.oracle_datatype_mapper.get(data_type, ColumnTransformationType.ORACLE_DEFAULT.value)
-        if data_source == "snowflake":
-            return SnowflakeDataSource.snowflake_datatype_mapper.get(
-                data_type, ColumnTransformationType.SNOWFLAKE_DEFAULT.value
-            )
-        if data_source == "databricks":
-            return DatabricksDataSource.databricks_datatype_mapper.get(
-                data_type, ColumnTransformationType.DATABRICKS_DEFAULT.value
-            )
-        msg = f"Unsupported source type --> {data_source}"
-        raise ValueError(msg)
+    def build_query(self) -> str:
+        hash_cols = sorted(
+            (self.qrc.join_columns | self.qrc.select_columns) - self.qrc.threshold_columns - self.qrc.drop_columns
+        )
+        key_cols = sorted(self.qrc.join_columns | self.qrc.partition_column)
 
-    @staticmethod
-    def _get_layer_transform(transform_dict: dict[str, Transformation], column: str, layer: str) -> str:
-        return transform_dict.get(column).source if layer == "source" else transform_dict.get(column).target
+        # get transformation for columns considered for hashing
+        col_transform = self._generate_transform_rule_mapping(hash_cols)
+        hash_cols_expr = sorted(
+            self._get_column_expr(TransformRuleMapping.get_column_expr_without_alias, col_transform)
+        )
+        hash_expr = self._generate_hash_algorithm(self.qrc.source, hash_cols_expr)
 
-    @staticmethod
-    def _get_column_expr(func, column_transformations: list[TransformRuleMapping]):
-        return [func(transformation) for transformation in column_transformations]
+        # get transformation for columns considered for joining and partition key
+        key_col_transform = self._generate_transform_rule_mapping(key_cols)
+        key_col_expr = sorted(self._get_column_expr(TransformRuleMapping.get_column_expr_with_alias, key_col_transform))
+
+        # construct select hash query
+        select_query = self._construct_hash_query(self.qrc.table_name, self.qrc.filter, hash_expr, key_col_expr)
+
+        return select_query
 
     @staticmethod
-    def _generate_hash_algorithm(source: str, column_expr: list[str]) -> str:
+    def _generate_hash_algorithm(source: str, col_expr: list[str]) -> str:
         if source in {SourceType.DATABRICKS.value, SourceType.SNOWFLAKE.value}:
-            hash_expr = "concat(" + ", ".join(column_expr) + ")"
+            hash_expr = "concat(" + ", ".join(col_expr) + ")"
         else:
-            hash_expr = " || ".join(column_expr)
+            hash_expr = " || ".join(col_expr)
 
-        return (Constants.hash_algorithm_mapping.get(source.lower()).get("source")).format(hash_expr)
+        return (Constants.hash_algorithm_mapping.get(source).get("source")).format(hash_expr)
 
     @staticmethod
-    def _construct_hash_query(table_name: str, query_filter: str, hash_expr: str, key_column_expr: list[str]) -> str:
+    def _construct_hash_query(table: str, query_filter: str, hash_expr: str, key_col_expr: list[str]) -> str:
         sql_query = StringIO()
+        # construct hash expr
         sql_query.write(f"select {hash_expr} as {Constants.hash_column_name}")
 
         # add join column
-        if key_column_expr:
-            sql_query.write(", " + ",".join(key_column_expr))
-        sql_query.write(f" from {table_name} where {query_filter}")
+        if key_col_expr:
+            sql_query.write(", " + ",".join(key_col_expr))
+        sql_query.write(f" from {table} where {query_filter}")
 
         select_query = sql_query.getvalue()
         sql_query.close()
 
         return select_query
 
+
+class ThresholdQueryBuilder(QueryBuilder):
+
+    def build_query(self) -> str:
+        all_columns = set(self.qrc.threshold_columns | self.qrc.join_columns | self.qrc.partition_column)
+
+        query_columns = sorted(
+            all_columns
+            if self.qrc.layer == "source"
+            else self.qrc.get_mapped_columns(self.qrc.src_col_mapping, all_columns)
+        )
+
+        transform_rule_mapping = self._get_custom_transformation(
+            query_columns, self.qrc.transform_dict, self.qrc.src_col_mapping
+        )
+        col_expr = self._get_column_expr(TransformRuleMapping.get_column_expr_with_alias, transform_rule_mapping)
+
+        select_query = self._construct_threshold_query(self.qrc.table_name, self.qrc.filter, col_expr)
+
+        return select_query
+
     @staticmethod
-    def _get_mapped_columns(layer: str, column_mapping: dict, columns: list[str]) -> set[str]:
-        if layer == "source":
-            return set(columns)
-        select_columns = set()
-        for column in columns:
-            select_columns.add(column_mapping.get(column).target_name if column_mapping.get(column) else column)
-        return select_columns
+    def _construct_threshold_query(table, query_filter, col_expr) -> str:
+        expr = ",".join(col_expr)
+        return f"select {expr} from {table} where {query_filter}"
diff --git a/src/databricks/labs/remorph/reconcile/query_config.py b/src/databricks/labs/remorph/reconcile/query_config.py
new file mode 100644
index 0000000000..e0e6d2e628
--- /dev/null
+++ b/src/databricks/labs/remorph/reconcile/query_config.py
@@ -0,0 +1,90 @@
+from databricks.labs.remorph.reconcile.constants import SourceType
+from databricks.labs.remorph.reconcile.recon_config import (
+    ColumnMapping,
+    Schema,
+    Table,
+    Transformation,
+)
+
+
+class QueryConfig:
+    def __init__(self, table_conf: Table, schema: list[Schema], layer: str, source: str):
+        self._table_conf = table_conf
+        self._schema = schema
+        self._layer = layer
+        self._source = source
+
+    @property
+    def source(self):
+        return self._source
+
+    @property
+    def layer(self):
+        return self._layer
+
+    @property
+    def schema_dict(self):
+        return {v.column_name: v for v in self._schema}
+
+    @property
+    def tgt_col_mapping(self):
+        return self._table_conf.list_to_dict(ColumnMapping, "target_name")
+
+    @property
+    def src_col_mapping(self):
+        return self._table_conf.list_to_dict(ColumnMapping, "source_name")
+
+    @property
+    def transform_dict(self):
+        return self._table_conf.list_to_dict(Transformation, "column_name")
+
+    @property
+    def threshold_columns(self) -> set[str]:
+        return {thresh.column_name for thresh in self._table_conf.thresholds or []}
+
+    @property
+    def join_columns(self) -> set[str]:
+        if self._table_conf.join_columns is None:
+            return set()
+        return set(self._table_conf.join_columns)
+
+    @property
+    def select_columns(self) -> set[str]:
+        if self._table_conf.select_columns is None:
+            cols = {sch.column_name for sch in self._schema}
+            return cols if self._layer == "source" else self.get_mapped_columns(self.tgt_col_mapping, cols)
+        return set(self._table_conf.select_columns)
+
+    @property
+    def partition_column(self) -> set[str]:
+        if self._table_conf.jdbc_reader_options and self._layer == "source":
+            return {self._table_conf.jdbc_reader_options.partition_column}
+        return set()
+
+    @property
+    def drop_columns(self) -> set[str]:
+        if self._table_conf.drop_columns is None:
+            return set()
+        return set(self._table_conf.drop_columns)
+
+    @property
+    def table_name(self) -> str:
+        table_name = self._table_conf.source_name if self._layer == "source" else self._table_conf.target_name
+        if self._source == SourceType.ORACLE.value:
+            return f"{{schema_name}}.{table_name}"
+        return f"{{catalog_name}}.{{schema_name}}.{table_name}"
+
+    @property
+    def filter(self) -> str:
+        if self._table_conf.filters is None:
+            return " 1 = 1 "
+        if self._layer == "source":
+            return self._table_conf.filters.source
+        return self._table_conf.filters.target
+
+    @staticmethod
+    def get_mapped_columns(col_mapping: dict[str, ColumnMapping], cols: set[str]) -> set[str]:
+        select_columns = set()
+        for col in cols:
+            select_columns.add(col_mapping.get(col, ColumnMapping(source_name=col, target_name='')).source_name)
+        return select_columns
diff --git a/src/databricks/labs/remorph/reconcile/recon_config.py b/src/databricks/labs/remorph/reconcile/recon_config.py
index f1e08c28ed..059514d094 100644
--- a/src/databricks/labs/remorph/reconcile/recon_config.py
+++ b/src/databricks/labs/remorph/reconcile/recon_config.py
@@ -10,15 +10,13 @@ class TransformRuleMapping:
     transformation: str
     alias_name: str
 
-    def get_column_expression_without_alias(self) -> str:
+    def get_column_expr_without_alias(self) -> str:
         if self.transformation:
             return f"{self.transformation}"
         return f"{self.column_name}"
 
-    def get_column_expression_with_alias(self) -> str:
-        if self.alias_name:
-            return f"{self.get_column_expression_without_alias()} as {self.alias_name}"
-        return f"{self.get_column_expression_without_alias()} as {self.column_name}"
+    def get_column_expr_with_alias(self) -> str:
+        return f"{self.get_column_expr_without_alias()} as {self.alias_name}"
 
 
 @dataclass
@@ -30,12 +28,6 @@ class JdbcReaderOptions:
     fetch_size: int = 100
 
 
-@dataclass
-class JoinColumns:
-    source_name: str
-    target_name: str | None = None
-
-
 @dataclass
 class ColumnMapping:
     source_name: str
@@ -64,10 +56,10 @@ class Filters:
 
 
 @dataclass
-class Tables:
+class Table:
     source_name: str
     target_name: str
-    join_columns: list[JoinColumns] | None = None
+    join_columns: list[str] | None = None
     jdbc_reader_options: JdbcReaderOptions | None = None
     select_columns: list[str] | None = None
     drop_columns: list[str] | None = None
@@ -76,9 +68,9 @@ class Tables:
     thresholds: list[Thresholds] | None = None
     filters: Filters | None = None
 
-    T = TypeVar("T")  # pylint: disable=invalid-name
+    Typ = TypeVar("Typ")
 
-    def list_to_dict(self, cls: type[T], key: str) -> T:
+    def list_to_dict(self, cls: type[Typ], key: str) -> Typ:
         for _, value in self.__dict__.items():
             if isinstance(value, list):
                 if all(isinstance(x, cls) for x in value):
@@ -91,7 +83,7 @@ class TableRecon:
     source_schema: str
     target_catalog: str
     target_schema: str
-    tables: list[Tables]
+    tables: list[Table]
     source_catalog: str | None = None
diff --git a/tests/unit/reconcile/test_query_builder.py b/tests/unit/reconcile/test_query_builder.py
index 546ac3b3e5..f3746acbc6 100644
--- a/tests/unit/reconcile/test_query_builder.py
+++ b/tests/unit/reconcile/test_query_builder.py
@@ -1,17 +1,23 @@
-from databricks.labs.remorph.reconcile.query_builder import QueryBuilder
+import pytest
+
+from databricks.labs.remorph.reconcile.query_builder import (
+    HashQueryBuilder,
+    ThresholdQueryBuilder,
+)
+from databricks.labs.remorph.reconcile.query_config import QueryConfig
 from databricks.labs.remorph.reconcile.recon_config import (
     ColumnMapping,
+    Filters,
     JdbcReaderOptions,
-    JoinColumns,
     Schema,
-    Tables,
+    Table,
     Thresholds,
     Transformation,
 )
 
 
-def test_query_builder_without_join_column():
-    table_conf = Tables(
+def test_hash_query_builder_without_join_column():
+    table_conf = Table(
         source_name="supplier",
         target_name="supplier",
         jdbc_reader_options=None,
@@ -33,13 +39,14 @@ def test_query_builder_without_join_column():
         Schema("s_comment", "varchar"),
     ]
 
-    actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query()
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
+    actual_src_query = HashQueryBuilder(src_qrc).build_query()
     expected_src_query = (
         "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_acctbal),'') || "
         "coalesce(trim(s_address),'') || coalesce(trim(s_comment),'') || "
         "coalesce(trim(s_name),'') || coalesce(trim(s_nationkey),'') || "
         "coalesce(trim(s_phone),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as "
-        "hash_value__recon from supplier "
+        "hash_value__recon from {schema_name}.supplier "
         "where 1 = 1 "
     )
     assert actual_src_query == expected_src_query
@@ -54,24 +61,25 @@ def test_query_builder_without_join_column():
         Schema("s_comment", "varchar"),
     ]
 
-    actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query()
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query()
     expected_tgt_query = (
         "select sha2(concat(coalesce(trim(s_acctbal),''), "
         "coalesce(trim(s_address),''), coalesce(trim(s_comment),''), "
         "coalesce(trim(s_name),''), coalesce(trim(s_nationkey),''), "
         "coalesce(trim(s_phone),''), coalesce(trim(s_suppkey),'')),256) as "
-        "hash_value__recon from supplier "
+        "hash_value__recon from {catalog_name}.{schema_name}.supplier "
        "where 1 = 1 "
     )
     assert actual_tgt_query == expected_tgt_query
 
 
-def test_query_builder_with_defaults():
-    table_conf = Tables(
+def test_hash_query_builder_with_defaults():
+    table_conf = Table(
         source_name="supplier",
         target_name="supplier",
         jdbc_reader_options=None,
-        join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey")],
+        join_columns=["s_suppkey"],
         select_columns=None,
         drop_columns=None,
         column_mapping=None,
@@ -89,13 +97,14 @@ def test_query_builder_with_defaults():
         Schema("s_comment", "varchar"),
     ]
 
-    actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query()
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
+    actual_src_query = HashQueryBuilder(src_qrc).build_query()
     expected_src_query = (
         "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_acctbal),'') || "
         "coalesce(trim(s_address),'') || coalesce(trim(s_comment),'') || "
         "coalesce(trim(s_name),'') || coalesce(trim(s_nationkey),'') || "
         "coalesce(trim(s_phone),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as "
-        "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from supplier "
+        "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from {schema_name}.supplier "
         "where 1 = 1 "
     )
     assert actual_src_query == expected_src_query
@@ -110,24 +119,25 @@ def test_query_builder_with_defaults():
         Schema("s_comment", "varchar"),
     ]
 
-    actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query()
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query()
     expected_tgt_query = (
         "select sha2(concat(coalesce(trim(s_acctbal),''), "
         "coalesce(trim(s_address),''), coalesce(trim(s_comment),''), "
         "coalesce(trim(s_name),''), coalesce(trim(s_nationkey),''), "
         "coalesce(trim(s_phone),''), coalesce(trim(s_suppkey),'')),256) as "
-        "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from supplier "
+        "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from {catalog_name}.{schema_name}.supplier "
         "where 1 = 1 "
     )
     assert actual_tgt_query == expected_tgt_query
 
 
-def test_query_builder_with_select():
-    table_conf = Tables(
+def test_hash_query_builder_with_select():
+    table_conf = Table(
         source_name="supplier",
         target_name="supplier",
         jdbc_reader_options=None,
-        join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey_t")],
+        join_columns=["s_suppkey"],
         select_columns=["s_suppkey", "s_name", "s_address"],
         drop_columns=None,
         column_mapping=[
@@ -148,11 +158,12 @@ def test_query_builder_with_select():
         Schema("s_comment", "varchar"),
     ]
 
-    actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query()
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
+    actual_src_query = HashQueryBuilder(src_qrc).build_query()
     expected_src_query = (
         "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_address),'') || "
         "coalesce(trim(s_name),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as "
-        "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from supplier "
+        "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from {schema_name}.supplier "
         "where 1 = 1 "
     )
     assert actual_src_query == expected_src_query
@@ -167,23 +178,24 @@ def test_query_builder_with_select():
         Schema("s_comment_t", "varchar"),
     ]
 
-    actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query()
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query()
     expected_tgt_query = (
         "select sha2(concat(coalesce(trim(s_address_t),''), "
         "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as "
-        "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier "
+        "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier "
         "where 1 = 1 "
     )
     assert actual_tgt_query == expected_tgt_query
 
 
-def test_query_builder_with_transformations_with_drop_and_default_select():
-    table_conf = Tables(
+def test_hash_query_builder_with_transformations_with_drop_and_default_select():
+    table_conf = Table(
         source_name="supplier",
         target_name="supplier",
         jdbc_reader_options=None,
-        join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey_t")],
+        join_columns=["s_suppkey"],
         select_columns=None,
         drop_columns=["s_comment"],
         column_mapping=[
@@ -217,13 +229,14 @@ def test_query_builder_with_transformations_with_drop_and_default_select():
         Schema("s_comment", "varchar"),
     ]
 
-    actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query()
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
+    actual_src_query = HashQueryBuilder(src_qrc).build_query()
     expected_src_query = (
-        'select lower(RAWTOHEX(STANDARD_HASH(trim(to_char(s_acctbal_t, '
-        "'9999999999.99')) || trim(s_address) || trim(s_name) || "
-        "coalesce(trim(s_nationkey),'') || trim(s_phone) || "
-        "coalesce(trim(s_suppkey),''), 'SHA256'))) as hash_value__recon, "
-        "coalesce(trim(s_suppkey),'') as s_suppkey from supplier where 1 = 1 "
+        "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_nationkey),'') || "
+        "coalesce(trim(s_suppkey),'') || trim(s_address) || trim(s_name) || "
+        "trim(s_phone) || trim(to_char(s_acctbal_t, '9999999999.99')), 'SHA256'))) as "
+        "hash_value__recon, coalesce(trim(s_suppkey),'') as s_suppkey from {schema_name}.supplier "
+        "where 1 = 1 "
     )
     assert actual_src_query == expected_src_query
@@ -237,26 +250,26 @@ def test_query_builder_with_transformations_with_drop_and_default_select():
         Schema("s_comment_t", "varchar"),
     ]
 
-    actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query()
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query()
     expected_tgt_query = (
-        'select sha2(concat(cast(s_acctbal_t as decimal(38,2)), trim(s_address_t), '
-        "trim(s_name), coalesce(trim(s_nationkey_t),''), "
-        "trim(s_phone_t), coalesce(trim(s_suppkey_t),'')),256) as "
-        "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier "
-        'where 1 = 1 '
+        "select sha2(concat(cast(s_acctbal_t as decimal(38,2)), "
+        "coalesce(trim(s_nationkey_t),''), coalesce(trim(s_suppkey_t),''), "
+        'trim(s_address_t), trim(s_name), trim(s_phone_t)),256) as hash_value__recon, '
+        "coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 "
     )
     assert actual_tgt_query == expected_tgt_query
 
 
-def test_query_builder_with_jdbc_reader_options():
-    table_conf = Tables(
+def test_hash_query_builder_with_jdbc_reader_options():
+    table_conf = Table(
         source_name="supplier",
         target_name="supplier",
         jdbc_reader_options=JdbcReaderOptions(
             number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100"
         ),
-        join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey_t")],
+        join_columns=["s_suppkey"],
         select_columns=["s_suppkey", "s_name", "s_address"],
         drop_columns=None,
         column_mapping=[
@@ -277,12 +290,13 @@ def test_query_builder_with_jdbc_reader_options():
         Schema("s_comment", "varchar"),
     ]
 
-    actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query()
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
+    actual_src_query = HashQueryBuilder(src_qrc).build_query()
     expected_src_query = (
         "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_address),'') || "
         "coalesce(trim(s_name),'') || coalesce(trim(s_suppkey),''), 'SHA256'))) as "
         "hash_value__recon, coalesce(trim(s_nationkey),'') as s_nationkey,coalesce(trim(s_suppkey),'') as s_suppkey "
-        "from supplier "
+        "from {schema_name}.supplier "
         "where 1 = 1 "
     )
@@ -298,25 +312,26 @@ def test_query_builder_with_jdbc_reader_options():
         Schema("s_comment_t", "varchar"),
     ]
 
-    actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query()
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query()
     expected_tgt_query = (
         "select sha2(concat(coalesce(trim(s_address_t),''), "
         "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as "
-        "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier "
+        "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier "
         "where 1 = 1 "
     )
     assert actual_tgt_query == expected_tgt_query
 
 
-def test_query_builder_with_threshold():
-    table_conf = Tables(
+def test_hash_query_builder_with_threshold():
+    table_conf = Table(
         source_name="supplier",
         target_name="supplier",
         jdbc_reader_options=JdbcReaderOptions(
             number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100"
         ),
-        join_columns=[JoinColumns(source_name="s_suppkey", target_name="s_suppkey_t")],
+        join_columns=["s_suppkey"],
         select_columns=None,
         drop_columns=None,
         column_mapping=[
@@ -341,14 +356,73 @@ def test_query_builder_with_threshold():
         Schema("s_comment", "varchar"),
     ]
 
-    actual_src_query = QueryBuilder(table_conf, src_schema, "source", "oracle").build_hash_query()
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
+    actual_src_query = HashQueryBuilder(src_qrc).build_query()
+    expected_src_query = (
+        "select lower(RAWTOHEX(STANDARD_HASH(coalesce(trim(s_comment),'') || "
+        "coalesce(trim(s_nationkey),'') || coalesce(trim(s_suppkey),'') || "
+        "trim(s_address) || trim(s_name) || trim(s_phone), 'SHA256'))) as "
+        "hash_value__recon, coalesce(trim(s_nationkey),'') as "
+        "s_nationkey,coalesce(trim(s_suppkey),'') as s_suppkey from {schema_name}.supplier where 1 "
+        '= 1 '
+    )
+    assert actual_src_query == expected_src_query
+
+    tgt_schema = [
+        Schema("s_suppkey_t", "number"),
+        Schema("s_name", "varchar"),
+        Schema("s_address_t", "varchar"),
+        Schema("s_nationkey", "number"),
+        Schema("s_phone", "varchar"),
+        Schema("s_acctbal", "number"),
+        Schema("s_comment", "varchar"),
+    ]
+
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query()
+    expected_tgt_query = (
+        "select sha2(concat(coalesce(trim(s_comment),''), "
+        "coalesce(trim(s_nationkey),''), coalesce(trim(s_suppkey_t),''), "
+        'trim(s_address_t), trim(s_name), trim(s_phone)),256) as hash_value__recon, '
+        "coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 "
+    )
+
+    assert actual_tgt_query == expected_tgt_query
+
+
+def test_hash_query_builder_with_filters():
+    table_conf = Table(
+        source_name="supplier",
+        target_name="supplier",
+        jdbc_reader_options=None,
+        join_columns=["s_suppkey"],
+        select_columns=["s_suppkey", "s_name", "s_address"],
+        drop_columns=None,
+        column_mapping=[
+            ColumnMapping(source_name="s_suppkey", target_name="s_suppkey_t"),
+            ColumnMapping(source_name="s_address", target_name="s_address_t"),
+        ],
+        transformations=None,
+        thresholds=None,
+        filters=Filters(source="s_name='t' and s_address='a'", target="s_name='t' and s_address_t='a'"),
+    )
+    src_schema = [
+        Schema("s_suppkey", "number"),
+        Schema("s_name", "varchar"),
+        Schema("s_address", "varchar"),
+        Schema("s_nationkey", "number"),
+        Schema("s_phone", "varchar"),
+        Schema("s_acctbal", "number"),
+        Schema("s_comment", "varchar"),
+    ]
+
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "snowflake")
+    actual_src_query = HashQueryBuilder(src_qrc).build_query()
     expected_src_query = (
-        'select lower(RAWTOHEX(STANDARD_HASH(trim(s_address) || '
-        "coalesce(trim(s_comment),'') || trim(s_name) || "
-        "coalesce(trim(s_nationkey),'') || trim(s_phone) || "
-        "coalesce(trim(s_suppkey),''), 'SHA256'))) as hash_value__recon, "
-        "coalesce(trim(s_nationkey),'') as s_nationkey,coalesce(trim(s_suppkey),'') "
-        'as s_suppkey from supplier where 1 = 1 '
+        "select sha2(concat(coalesce(trim(s_address),''), coalesce(trim(s_name),''), "
+        "coalesce(trim(s_suppkey),'')),256) as hash_value__recon, "
+        "coalesce(trim(s_suppkey),'') as s_suppkey from {catalog_name}.{schema_name}.supplier where s_name='t' and "
+        "s_address='a'"
     )
     assert actual_src_query == expected_src_query
@@ -356,18 +430,176 @@ def test_query_builder_with_threshold():
         Schema("s_suppkey_t", "number"),
         Schema("s_name", "varchar"),
         Schema("s_address_t", "varchar"),
+        Schema("s_nationkey_t", "number"),
+        Schema("s_phone_t", "varchar"),
+        Schema("s_acctbal_t", "number"),
+        Schema("s_comment_t", "varchar"),
+    ]
+
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = HashQueryBuilder(tgt_qrc).build_query()
+    expected_tgt_query = (
+        "select sha2(concat(coalesce(trim(s_address_t),''), "
+        "coalesce(trim(s_name),''), coalesce(trim(s_suppkey_t),'')),256) as "
+        "hash_value__recon, coalesce(trim(s_suppkey_t),'') as s_suppkey from {catalog_name}.{schema_name}.supplier "
+        "where s_name='t' and s_address_t='a'"
+    )
+
+    assert actual_tgt_query == expected_tgt_query
+
+
+def test_hash_query_builder_with_unsupported_source():
+    table_conf = Table(
+        source_name="supplier",
+        target_name="supplier",
+        jdbc_reader_options=None,
+        join_columns=None,
+        select_columns=None,
+        drop_columns=None,
+        column_mapping=None,
+        transformations=None,
+        thresholds=None,
+        filters=None,
+    )
+    src_schema = [
+        Schema("s_suppkey", "number"),
+        Schema("s_name", "varchar"),
+        Schema("s_address", "varchar"),
+        Schema("s_nationkey", "number"),
+        Schema("s_phone", "varchar"),
+        Schema("s_acctbal", "number"),
+        Schema("s_comment", "varchar"),
+    ]
+
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "abc")
+    query_builder = HashQueryBuilder(src_qrc)
+
+    with pytest.raises(Exception) as exc_info:
+        query_builder.build_query()
+
+    assert str(exc_info.value) == "Unsupported source type --> abc"
+
+
+def test_threshold_query_builder_with_defaults():
+    table_conf = Table(
+        source_name="supplier",
+        target_name="supplier",
+        jdbc_reader_options=None,
+        join_columns=["s_suppkey"],
+        select_columns=None,
+        drop_columns=None,
+        column_mapping=None,
+        transformations=None,
+        thresholds=[Thresholds(column_name="s_acctbal", lower_bound="0", upper_bound="100", type="int")],
+        filters=None,
+    )
+    src_schema = [
+        Schema("s_suppkey", "number"),
+        Schema("s_name", "varchar"),
+        Schema("s_address", "varchar"),
+        Schema("s_nationkey", "number"),
+        Schema("s_phone", "varchar"),
+        Schema("s_acctbal", "number"),
+        Schema("s_comment", "varchar"),
+    ]
+
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
+    actual_src_query = ThresholdQueryBuilder(src_qrc).build_query()
+    expected_src_query = (
+        'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {schema_name}.supplier where 1 = 1 '
+    )
+    assert actual_src_query == expected_src_query
+
+    tgt_schema = [
+        Schema("s_suppkey", "number"),
+        Schema("s_name", "varchar"),
+        Schema("s_address", "varchar"),
+        Schema("s_nationkey", "number"),
+        Schema("s_phone", "varchar"),
+        Schema("s_acctbal", "number"),
+        Schema("s_comment", "varchar"),
+    ]
+
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = ThresholdQueryBuilder(tgt_qrc).build_query()
+    expected_tgt_query = (
+        'select s_acctbal as s_acctbal,s_suppkey as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 '
+    )
+    assert actual_tgt_query == expected_tgt_query
+
+
+def test_threshold_query_builder_with_transformations_and_jdbc():
+    table_conf = Table(
+        source_name="supplier",
+        target_name="supplier",
+        jdbc_reader_options=JdbcReaderOptions(
+            number_partitions=100, partition_column="s_nationkey", lower_bound="0", upper_bound="100"
+        ),
+        join_columns=["s_suppkey"],
+        select_columns=None,
+        drop_columns=["s_comment"],
+        column_mapping=[
+            ColumnMapping(source_name="s_suppkey", target_name="s_suppkey_t"),
+            ColumnMapping(source_name="s_address", target_name="s_address_t"),
+            ColumnMapping(source_name="s_nationkey", target_name="s_nationkey_t"),
+            ColumnMapping(source_name="s_phone", target_name="s_phone_t"),
+            ColumnMapping(source_name="s_acctbal", target_name="s_acctbal_t"),
+            ColumnMapping(source_name="s_comment", target_name="s_comment_t"),
+            ColumnMapping(source_name="s_suppdate", target_name="s_suppdate_t"),
+        ],
+        transformations=[
+            Transformation(column_name="s_suppkey", source="trim(s_suppkey)", target="trim(s_suppkey_t)"),
+            Transformation(column_name="s_address", source="trim(s_address)", target="trim(s_address_t)"),
+            Transformation(column_name="s_phone", source="trim(s_phone)", target="trim(s_phone_t)"),
+            Transformation(column_name="s_name", source="trim(s_name)", target="trim(s_name)"),
+            Transformation(
+                column_name="s_acctbal",
+                source="trim(to_char(s_acctbal, '9999999999.99'))",
+                target="cast(s_acctbal_t as decimal(38,2))",
+            ),
+        ],
+        thresholds=[
+            Thresholds(column_name="s_acctbal", lower_bound="0", upper_bound="100", type="int"),
+            Thresholds(column_name="s_suppdate", lower_bound="-86400", upper_bound="86400", type="timestamp"),
+        ],
+        filters=None,
+    )
+    src_schema = [
+        Schema("s_suppkey", "number"),
+        Schema("s_name", "varchar"),
+        Schema("s_address", "varchar"),
         Schema("s_nationkey", "number"),
         Schema("s_phone", "varchar"),
         Schema("s_acctbal", "number"),
         Schema("s_comment", "varchar"),
+        Schema("s_suppdate", "timestamp"),
+    ]
+
+    src_qrc = QueryConfig(table_conf, src_schema, "source", "oracle")
+    actual_src_query = ThresholdQueryBuilder(src_qrc).build_query()
+    expected_src_query = (
+        "select trim(to_char(s_acctbal, '9999999999.99')) as s_acctbal,s_nationkey "
+        "as s_nationkey,s_suppdate as s_suppdate,trim(s_suppkey) as s_suppkey from "
+        "{schema_name}.supplier where 1 = 1 "
+    )
+    assert actual_src_query == expected_src_query
+
+    tgt_schema = [
+        Schema("s_suppkey_t", "number"),
+        Schema("s_name", "varchar"),
+        Schema("s_address_t", "varchar"),
+        Schema("s_nationkey_t", "number"),
+        Schema("s_phone_t", "varchar"),
+        Schema("s_acctbal_t", "number"),
+        Schema("s_comment_t", "varchar"),
+        Schema("s_suppdate_t", "timestamp"),
     ]
 
-    actual_tgt_query = QueryBuilder(table_conf, tgt_schema, "target", "databricks").build_hash_query()
+    tgt_qrc = QueryConfig(table_conf, tgt_schema, "target", "databricks")
+    actual_tgt_query = ThresholdQueryBuilder(tgt_qrc).build_query()
     expected_tgt_query = (
-        "select sha2(concat(trim(s_address_t), coalesce(trim(s_comment),''), "
-        "trim(s_name), coalesce(trim(s_nationkey),''), trim(s_phone), "
-        "coalesce(trim(s_suppkey_t),'')),256) as hash_value__recon, "
-        "coalesce(trim(s_suppkey_t),'') as s_suppkey from supplier where 1 = 1 "
+        "select cast(s_acctbal_t as decimal(38,2)) as s_acctbal,s_suppdate_t as "
+        "s_suppdate,trim(s_suppkey_t) as s_suppkey from {catalog_name}.{schema_name}.supplier where 1 = 1 "
    )
     assert actual_tgt_query == expected_tgt_query