From 4cf2b78fcaef7f4cd5e8c89e484e7ed03f889b24 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 2 Jul 2020 07:58:39 -0600 Subject: [PATCH] move the sql for getting different rows into dbt proper, from the test suite. Bump pytest dependency. --- core/dbt/adapters/base/impl.py | 58 +++++++++++++++++++ core/dbt/adapters/base/relation.py | 4 ++ dev_requirements.txt | 2 +- .../bigquery/dbt/adapters/bigquery/impl.py | 14 +++++ test/integration/base.py | 43 ++------------ 5 files changed, 81 insertions(+), 40 deletions(-) diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index f4defc204b3..3795766277d 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -1107,6 +1107,64 @@ def post_model_hook(self, config: Mapping[str, Any], context: Any) -> None: """ pass + def get_rows_different_sql( + self, + relation_a: BaseRelation, + relation_b: BaseRelation, + column_names: Optional[List[str]] = None, + except_operator: str = 'EXCEPT', + ) -> str: + """Generate SQL for a query that returns a single row with a two + columns: the number of rows that are different between the two + relations and the number of mismatched rows. + """ + # This method only really exists for test reasons. + names: List[str] + if column_names is None: + columns = self.get_columns_in_relation(relation_a) + names = sorted((self.quote(c.name) for c in columns)) + else: + names = sorted((self.quote(n) for n in column_names)) + columns_csv = ', '.join(names) + + sql = COLUMNS_EQUAL_SQL.format( + columns=columns_csv, + relation_a=str(relation_a), + relation_b=str(relation_b), + except_op=except_operator, + ) + + return sql + + +COLUMNS_EQUAL_SQL = ''' +with diff_count as ( + SELECT + 1 as id, + COUNT(*) as num_missing FROM ( + (SELECT {columns} FROM {relation_a} {except_op} + SELECT {columns} FROM {relation_b}) + UNION ALL + (SELECT {columns} FROM {relation_b} {except_op} + SELECT {columns} FROM {relation_a}) + ) as a +), table_a as ( + SELECT COUNT(*) as num_rows FROM {relation_a} +), table_b as ( + SELECT COUNT(*) as num_rows FROM {relation_b} +), row_count_diff as ( + select + 1 as id, + table_a.num_rows - table_b.num_rows as difference + from table_a, table_b +) +select + row_count_diff.difference as row_count_difference, + diff_count.num_missing as num_mismatched +from row_count_diff +join diff_count using (id) +'''.strip() + def catch_as_completed( futures # typing: List[Future[agate.Table]] diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 10dfc69677a..3f4026fe8e5 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -51,6 +51,10 @@ def __eq__(self, other): def get_default_quote_policy(cls) -> Policy: return cls._get_field_named('quote_policy').default + @classmethod + def get_default_include_policy(cls) -> Policy: + return cls._get_field_named('include_policy').default + def get(self, key, default=None): """Override `.get` to return a metadata object so we don't break dbt_utils. diff --git a/dev_requirements.txt b/dev_requirements.txt index d7992ad5c50..8946a6a8cb3 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,5 +1,5 @@ freezegun==0.3.12 -pytest==4.4.0 +pytest==5.4.3 flake8>=3.5.0 pytz==2017.2 bumpversion==0.5.3 diff --git a/plugins/bigquery/dbt/adapters/bigquery/impl.py b/plugins/bigquery/dbt/adapters/bigquery/impl.py index 1ae97769553..7b80f096f14 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/impl.py +++ b/plugins/bigquery/dbt/adapters/bigquery/impl.py @@ -738,3 +738,17 @@ def grant_access_to(self, entity, entity_type, role, grant_target_dict): access_entries.append(AccessEntry(role, entity_type, entity)) dataset.access_entries = access_entries client.update_dataset(dataset, ['access_entries']) + + def get_rows_different_sql( + self, + relation_a: BigQueryRelation, + relation_b: BigQueryRelation, + column_names: Optional[List[str]] = None, + except_operator='EXCEPT DISTINCT' + ) -> str: + return super().get_rows_different_sql( + relation_a=relation_a, + relation_b=relation_b, + column_names=column_names, + except_operator=except_operator, + ) diff --git a/test/integration/base.py b/test/integration/base.py index 9079dfb50ea..00e21c54207 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -880,46 +880,11 @@ def get_models_in_schema(self, schema=None): def _assertTablesEqualSql(self, relation_a, relation_b, columns=None): if columns is None: columns = self.get_relation_columns(relation_a) + column_names = [c[0] for c in columns] - columns_csv = ', '.join([self.adapter.quote(record[0]) for record in columns]) - - if self.adapter_type == 'bigquery': - except_operator = 'EXCEPT DISTINCT' - else: - except_operator = 'EXCEPT' - - sql = """ - with diff_count as ( - SELECT - 1 as id, - COUNT(*) as num_missing FROM ( - (SELECT {columns} FROM {relation_a} {except_op} - SELECT {columns} FROM {relation_b}) - UNION ALL - (SELECT {columns} FROM {relation_b} {except_op} - SELECT {columns} FROM {relation_a}) - ) as a - ), table_a as ( - SELECT COUNT(*) as num_rows FROM {relation_a} - ), table_b as ( - SELECT COUNT(*) as num_rows FROM {relation_b} - ), row_count_diff as ( - select - 1 as id, - table_a.num_rows - table_b.num_rows as difference - from table_a, table_b - ) - select - row_count_diff.difference as row_count_difference, - diff_count.num_missing as num_mismatched - from row_count_diff - join diff_count using (id) - """.strip().format( - columns=columns_csv, - relation_a=str(relation_a), - relation_b=str(relation_b), - except_op=except_operator - ) + sql = self.adapter.get_rows_different_sql( + relation_a, relation_b, column_names + ) return sql