Skip to content

Commit

Permalink
Merge pull request #3145 from jmcarp/jmcarp/bigquery-job-labels
Browse files Browse the repository at this point in the history
Parse query comment and use as bigquery job labels.
  • Loading branch information
jtcohen6 authored Mar 22, 2021
2 parents 7435828 + 044a6c6 commit 934c23b
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
### Features
- Add optional configs for `require_partition_filter` and `partition_expiration_days` in BigQuery ([#1843](https://github.com/fishtown-analytics/dbt/issues/1843), [#2928](https://github.com/fishtown-analytics/dbt/pull/2928))
- Fix for EOL SQL comments prevent entire line execution ([#2731](https://github.com/fishtown-analytics/dbt/issues/2731), [#2974](https://github.com/fishtown-analytics/dbt/pull/2974))
- Use query comment JSON as job labels for BigQuery adapter when `query-comment.job-label` is set to `true` ([#2483](https://github.com/fishtown-analytics/dbt/issues/2483)), ([#3145](https://github.com/fishtown-analytics/dbt/pull/3145))

### Under the hood
- Add dependabot configuration for alerting maintainers about keeping dependencies up to date and secure. ([#3061](https://github.com/fishtown-analytics/dbt/issues/3061), [#3062](https://github.com/fishtown-analytics/dbt/pull/3062))
Expand All @@ -26,6 +27,7 @@ Contributors:
- [ran-eh](https://github.com/ran-eh) ([#3036](https://github.com/fishtown-analytics/dbt/pull/3036))
- [@pcasteran](https://github.com/pcasteran) ([#2976](https://github.com/fishtown-analytics/dbt/pull/2976))
- [@VasiliiSurov](https://github.com/VasiliiSurov) ([#3104](https://github.com/fishtown-analytics/dbt/pull/3104))
- [@jmcarp](https://github.com/jmcarp) ([#3145](https://github.com/fishtown-analytics/dbt/pull/3145))
- [@bastienboutonnet](https://github.com/bastienboutonnet) ([#3151](https://github.com/fishtown-analytics/dbt/pull/3151))
- [@techytushar](https://github.com/techytushar) ([#3158](https://github.com/fishtown-analytics/dbt/pull/3158))
- [@cgopalan](https://github.com/cgopalan) ([#3165](https://github.com/fishtown-analytics/dbt/pull/3165))
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/contracts/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dbt.logger import GLOBAL_LOGGER as logger
from typing_extensions import Protocol
from dbt.dataclass_schema import (
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin,
dbtClassMixin, StrEnum, ExtensibleDbtClassMixin, HyphenatedDbtClassMixin,
ValidatedStringMixin, register_pattern
)
from dbt.contracts.util import Replaceable
Expand Down Expand Up @@ -212,9 +212,10 @@ def to_target_dict(self):


@dataclass
class QueryComment(dbtClassMixin):
class QueryComment(HyphenatedDbtClassMixin):
comment: str = DEFAULT_QUERY_COMMENT
append: bool = False
job_label: bool = False


class AdapterRequiredConfig(HasCredentials, Protocol):
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/dataclass_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def __post_serialize__(self, dct):
# performing the conversion to a dict
@classmethod
def __pre_deserialize__(cls, data):
if cls._hyphenated:
# `data` might not be a dict, e.g. for `query_comment`, which accepts
# a dict or a string; only snake-case for dict values.
if cls._hyphenated and isinstance(data, dict):
new_dict = {}
for key in data:
if '-' in key:
Expand Down
34 changes: 30 additions & 4 deletions plugins/bigquery/dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import re
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
Expand Down Expand Up @@ -305,12 +307,16 @@ def raw_execute(self, sql, fetch=False, *, use_legacy_sql=False):

logger.debug('On {}: {}', conn.name, sql)

job_params = {'use_legacy_sql': use_legacy_sql}
if self.profile.query_comment.job_label:
query_comment = self.query_header.comment.query_comment
labels = self._labels_from_query_comment(query_comment)
else:
labels = {}

if active_user:
job_params['labels'] = {
'dbt_invocation_id': active_user.invocation_id
}
labels['dbt_invocation_id'] = active_user.invocation_id

job_params = {'use_legacy_sql': use_legacy_sql, 'labels': labels}

priority = conn.credentials.priority
if priority == Priority.Batch:
Expand Down Expand Up @@ -544,6 +550,16 @@ def _retry_generator(self):
initial=self.DEFAULT_INITIAL_DELAY,
maximum=self.DEFAULT_MAXIMUM_DELAY)

def _labels_from_query_comment(self, comment: str) -> Dict:
try:
comment_labels = json.loads(comment)
except (TypeError, ValueError):
return {'query_comment': _sanitize_label(comment)}
return {
_sanitize_label(key): _sanitize_label(str(value))
for key, value in comment_labels.items()
}


class _ErrorCounter(object):
"""Counts errors seen up to a threshold then raises the next error."""
Expand Down Expand Up @@ -573,3 +589,13 @@ def _is_retryable(error):
e['reason'] == 'rateLimitExceeded' for e in error.errors):
return True
return False


_SANITIZE_LABEL_PATTERN = re.compile(r"[^a-z0-9_-]")


def _sanitize_label(value: str) -> str:
"""Return a legal value for a BigQuery label."""
value = value.strip().lower()
value = _SANITIZE_LABEL_PATTERN.sub("_", value)
return value
26 changes: 24 additions & 2 deletions test/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import agate
import decimal
import json
import re
import pytest
import unittest
from contextlib import contextmanager
from requests.exceptions import ConnectionError
Expand All @@ -15,6 +17,7 @@
from dbt.adapters.bigquery import BigQueryRelation
from dbt.adapters.bigquery import Plugin as BigQueryPlugin
from dbt.adapters.bigquery.connections import BigQueryConnectionManager
from dbt.adapters.bigquery.connections import _sanitize_label
from dbt.adapters.base.query_headers import MacroQueryStringSetter
from dbt.clients import agate_helper
import dbt.exceptions
Expand Down Expand Up @@ -588,7 +591,6 @@ def test_query_and_results(self, mock_bq):
self.mock_client.query.assert_called_once_with(
'sql', job_config=mock_bq.QueryJobConfig())


def test_copy_bq_table_appends(self):
self._copy_table(
write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND)
Expand All @@ -615,12 +617,20 @@ def test_copy_bq_table_truncates(self):
kwargs['job_config'].write_disposition,
dbt.adapters.bigquery.impl.WRITE_TRUNCATE)

def test_job_labels_valid_json(self):
expected = {"key": "value"}
labels = self.connections._labels_from_query_comment(json.dumps(expected))
self.assertEqual(labels, expected)

def test_job_labels_invalid_json(self):
labels = self.connections._labels_from_query_comment("not json")
self.assertEqual(labels, {"query_comment": "not_json"})

def _table_ref(self, proj, ds, table, conn):
return google.cloud.bigquery.table.TableReference.from_string(
'{}.{}.{}'.format(proj, ds, table))

def _copy_table(self, write_disposition):

self.connections.table_ref = self._table_ref
source = BigQueryRelation.create(
database='project', schema='dataset', identifier='table1')
Expand Down Expand Up @@ -931,3 +941,15 @@ def test_convert_time_type(self):
expected = ['time', 'time', 'time']
for col_idx, expect in enumerate(expected):
assert BigQueryAdapter.convert_time_type(agate_table, col_idx) == expect


@pytest.mark.parametrize(
["input", "output"],
[
("ABC", "abc"),
("a c", "a_c"),
("a ", "a"),
],
)
def test_sanitize_label(input, output):
assert _sanitize_label(input) == output

0 comments on commit 934c23b

Please sign in to comment.