From 7b4b7280a7622e1c28e8c132f7159d1398ff9950 Mon Sep 17 00:00:00 2001 From: Joshua Carp Date: Thu, 4 Mar 2021 10:06:52 -0500 Subject: [PATCH] Address comments from code review. --- core/dbt/contracts/connection.py | 4 +-- .../dbt/adapters/bigquery/connections.py | 28 ++++++++++--------- test/unit/test_bigquery_adapter.py | 12 ++++++-- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index 22765b583d8..109482e252b 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -9,7 +9,7 @@ from dbt.logger import GLOBAL_LOGGER as logger from typing_extensions import Protocol from dbt.dataclass_schema import ( - dbtClassMixin, StrEnum, ExtensibleDbtClassMixin, + dbtClassMixin, StrEnum, ExtensibleDbtClassMixin, HyphenatedDbtClassMixin, ValidatedStringMixin, register_pattern ) from dbt.contracts.util import Replaceable @@ -212,7 +212,7 @@ def to_target_dict(self): @dataclass -class QueryComment(dbtClassMixin): +class QueryComment(HyphenatedDbtClassMixin): comment: str = DEFAULT_QUERY_COMMENT append: bool = False job_label: bool = False diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py index 99e68f7313f..69ae2178532 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/connections.py +++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py @@ -307,23 +307,15 @@ def raw_execute(self, sql, fetch=False, *, use_legacy_sql=False): logger.debug('On {}: {}', conn.name, sql) - labels = {} + if self.profile.query_comment.job_label: + query_comment = self.query_header.comment.query_comment + labels = self._labels_from_query_comment(query_comment) + else: + labels = {} if active_user: labels['dbt_invocation_id'] = active_user.invocation_id - if self.profile.query_comment.job_label: - try: - comment_labels = json.loads( - self.query_header.comment.query_comment - ) - labels.update({ - _sanitize_label(key): _sanitize_label(str(value)) - for key, value in comment_labels.items() - }) - except (TypeError, ValueError): - pass - job_params = {'use_legacy_sql': use_legacy_sql, 'labels': labels} priority = conn.credentials.priority @@ -558,6 +550,16 @@ def _retry_generator(self): initial=self.DEFAULT_INITIAL_DELAY, maximum=self.DEFAULT_MAXIMUM_DELAY) + def _labels_from_query_comment(self, comment: str) -> Dict: + try: + comment_labels = json.loads(comment) + except (TypeError, ValueError): + return {'query_comment': _sanitize_label(comment)} + return { + _sanitize_label(key): _sanitize_label(str(value)) + for key, value in comment_labels.items() + } + class _ErrorCounter(object): """Counts errors seen up to a threshold then raises the next error.""" diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index 00a4ef3b885..2df7a56ffba 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -1,5 +1,6 @@ import agate import decimal +import json import re import unittest from contextlib import contextmanager @@ -588,7 +589,6 @@ def test_query_and_results(self, mock_bq): self.mock_client.query.assert_called_once_with( 'sql', job_config=mock_bq.QueryJobConfig()) - def test_copy_bq_table_appends(self): self._copy_table( write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND) @@ -615,12 +615,20 @@ def test_copy_bq_table_truncates(self): kwargs['job_config'].write_disposition, dbt.adapters.bigquery.impl.WRITE_TRUNCATE) + def test_job_labels_valid_json(self): + expected = {"key": "value"} + labels = self.connections._labels_from_query_comment(json.dumps(expected)) + self.assertEqual(labels, expected) + + def test_job_labels_invalid_json(self): + labels = self.connections._labels_from_query_comment("not json") + self.assertEqual(labels, {"query_comment": "not_json"}) + def _table_ref(self, proj, ds, table, conn): return google.cloud.bigquery.table.TableReference.from_string( '{}.{}.{}'.format(proj, ds, table)) def _copy_table(self, write_disposition): - self.connections.table_ref = self._table_ref source = BigQueryRelation.create( database='project', schema='dataset', identifier='table1')