diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 51b632ace..1a8094437 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -4,7 +4,7 @@ parse = (?P\d+) \.(?P\d+) \.(?P\d+) ((?Pa|b|rc)(?P\d+))? -serialize = +serialize = {major}.{minor}.{patch}{prerelease}{num} {major}.{minor}.{patch} commit = False @@ -13,7 +13,7 @@ tag = False [bumpversion:part:prerelease] first_value = a optional_value = final -values = +values = a b rc diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..f39d154c0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,12 @@ +[flake8] +select = + E + W + F +ignore = + W503 # makes Flake8 work like black + W504 + E203 # makes Flake8 work like black + E741 + E501 +exclude = test diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 0295cce93..7f41f4c88 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -18,4 +18,4 @@ resolves # - [ ] I have signed the [CLA](https://docs.getdbt.com/docs/contributor-license-agreements) - [ ] I have run this code in development and it appears to resolve the stated issue - [ ] This PR includes tests, or tests are not required/relevant for this PR -- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-snowflake next" section. \ No newline at end of file +- [ ] I have updated the `CHANGELOG.md` and added information about my change to the "dbt-snowflake next" section. 
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 93819e735..59894df8c 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -221,9 +221,9 @@ jobs: post-failure: runs-on: ubuntu-latest - needs: test + needs: test if: ${{ failure() }} - + steps: - name: Posting scheduled run failures uses: ravsamhq/notify-slack-action@v1 diff --git a/.github/workflows/jira-creation.yml b/.github/workflows/jira-creation.yml index c84e106a7..b4016befc 100644 --- a/.github/workflows/jira-creation.yml +++ b/.github/workflows/jira-creation.yml @@ -13,7 +13,7 @@ name: Jira Issue Creation on: issues: types: [opened, labeled] - + permissions: issues: write diff --git a/.github/workflows/jira-label.yml b/.github/workflows/jira-label.yml index fd533a170..3da2e3a38 100644 --- a/.github/workflows/jira-label.yml +++ b/.github/workflows/jira-label.yml @@ -13,7 +13,7 @@ name: Jira Label Mirroring on: issues: types: [labeled, unlabeled] - + permissions: issues: read @@ -24,4 +24,3 @@ jobs: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} - diff --git a/.github/workflows/jira-transition.yml b/.github/workflows/jira-transition.yml index 71273c7a9..ed9f9cd4f 100644 --- a/.github/workflows/jira-transition.yml +++ b/.github/workflows/jira-transition.yml @@ -21,4 +21,4 @@ jobs: secrets: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} - JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} \ No newline at end of file + JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6254fc339..1ae22c400 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,19 +37,10 @@ defaults: jobs: code-quality: - name: ${{ matrix.toxenv }} + name: code-quality runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - toxenv: [flake8] 
- - env: - TOXENV: ${{ matrix.toxenv }} - PYTEST_ADDOPTS: "-v --color=yes" - steps: - name: Check out the repository uses: actions/checkout@v2 @@ -62,12 +53,16 @@ jobs: - name: Install python dependencies run: | pip install --user --upgrade pip - pip install tox + pip install pre-commit + pip install mypy==0.782 + pip install -r dev_requirements.txt pip --version - tox --version + pre-commit --version + mypy --version + dbt --version - - name: Run tox - run: tox + - name: Run pre-commit hooks + run: pre-commit run --all-files --show-diff-on-failure unit: name: unit test / python ${{ matrix.python-version }} diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 2848ce8f7..a56455d55 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,5 +13,3 @@ jobs: stale-pr-message: "This PR has been marked as Stale because it has been open for 180 days with no activity. If you would like the PR to remain open, please remove the stale label or comment on the PR, or it will be closed in 7 days." # mark issues/PRs stale when they haven't seen activity in 180 days days-before-stale: 180 - # ignore checking issues with the following labels - exempt-issue-labels: "epic, discussion" \ No newline at end of file diff --git a/.github/workflows/version-bump.yml b/.github/workflows/version-bump.yml index 4913a6e84..b0a3174df 100644 --- a/.github/workflows/version-bump.yml +++ b/.github/workflows/version-bump.yml @@ -1,16 +1,16 @@ # **what?** # This workflow will take a version number and a dry run flag. With that -# it will run versionbump to update the version number everywhere in the +# it will run versionbump to update the version number everywhere in the # code base and then generate an update Docker requirements file. If this # is a dry run, a draft PR will open with the changes. If this isn't a dry # run, the changes will be committed to the branch this is run on. 
# **why?** -# This is to aid in releasing dbt and making sure we have updated +# This is to aid in releasing dbt and making sure we have updated # the versions and Docker requirements in all places. # **when?** -# This is triggered either manually OR +# This is triggered either manually OR # from the repository_dispatch event "version-bump" which is sent from # the dbt-release repo Action @@ -25,11 +25,11 @@ on: is_dry_run: description: 'Creates a draft PR to allow testing instead of committing to a branch' required: true - default: 'true' + default: 'true' repository_dispatch: types: [version-bump] -jobs: +jobs: bump: runs-on: ubuntu-latest steps: @@ -57,19 +57,19 @@ jobs: run: | python3 -m venv env source env/bin/activate - pip install --upgrade pip - + pip install --upgrade pip + - name: Create PR branch if: ${{ steps.variables.outputs.IS_DRY_RUN == 'true' }} run: | git checkout -b bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID git push origin bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID git branch --set-upstream-to=origin/bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_$GITHUB_RUN_ID - + - name: Bumping version run: | source env/bin/activate - pip install -r dev_requirements.txt + pip install -r dev_requirements.txt env/bin/bumpversion --allow-dirty --new-version ${{steps.variables.outputs.VERSION_NUMBER}} major git status @@ -99,4 +99,4 @@ jobs: draft: true base: ${{github.ref}} title: 'Bumping version to ${{steps.variables.outputs.VERSION_NUMBER}}' - branch: 'bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_${{GITHUB.RUN_ID}}' + branch: 'bumping-version/${{steps.variables.outputs.VERSION_NUMBER}}_${{GITHUB.RUN_ID}}' diff --git a/.gitignore b/.gitignore index 43724b61e..780d98f70 100644 --- a/.gitignore +++ b/.gitignore @@ -49,9 +49,7 @@ coverage.xml *,cover .hypothesis/ test.env - -# Mypy -.mypy_cache/ 
+*.pytest_cache/ # Translations *.mo @@ -66,10 +64,10 @@ docs/_build/ # PyBuilder target/ -#Ipython Notebook +# Ipython Notebook .ipynb_checkpoints -#Emacs +# Emacs *~ # Sublime Text @@ -78,6 +76,7 @@ target/ # Vim *.sw* +# Pyenv .python-version # Vim @@ -90,6 +89,7 @@ venv/ # AWS credentials .aws/ +# MacOS .DS_Store # vscode diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..ccaa093bf --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,66 @@ +# For more on configuring pre-commit hooks (see https://pre-commit.com/) + +# TODO: remove global exclusion of tests when testing overhaul is complete +exclude: '^tests/.*' + +# Force all unspecified python hooks to run python 3.8 +default_language_version: + python: python3.8 + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: check-yaml + args: [--unsafe] + - id: check-json + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-case-conflict +- repo: https://github.com/psf/black + rev: 21.12b0 + hooks: + - id: black + additional_dependencies: ['click==8.0.4'] + args: + - "--line-length=99" + - "--target-version=py38" + - id: black + alias: black-check + stages: [manual] + additional_dependencies: ['click==8.0.4'] + args: + - "--line-length=99" + - "--target-version=py38" + - "--check" + - "--diff" +- repo: https://gitlab.com/pycqa/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + - id: flake8 + alias: flake8-check + stages: [manual] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.782 + hooks: + - id: mypy + # N.B.: Mypy is... a bit fragile. + # + # By using `language: system` we run this hook in the local + # environment instead of a pre-commit isolated one. This is needed + # to ensure mypy correctly parses the project. + + # It may cause trouble in that it adds environmental variables out + # of our control to the mix. 
Unfortunately, there's nothing we can + # do about per pre-commit's author. + # See https://github.com/pre-commit/pre-commit/issues/730 for details. + args: [--show-error-codes, --ignore-missing-imports] + files: ^dbt/adapters/.* + language: system + - id: mypy + alias: mypy-check + stages: [manual] + args: [--show-error-codes, --pretty, --ignore-missing-imports] + files: ^dbt/adapters + language: system diff --git a/CHANGELOG.md b/CHANGELOG.md index c7cec0658..f44cd9ba5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## dbt-snowflake 1.2.0 (tbd) + +### Under the hood +- Add precommits for this repo ([#107](https://github.com/dbt-labs/dbt-snowflake/pull/107)) + ## dbt-snowflake 1.1.0b1 (March 23, 2022) ### Features diff --git a/MANIFEST.in b/MANIFEST.in index 78412d5b8..cfbc714ed 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1 @@ -recursive-include dbt/include *.sql *.yml *.md \ No newline at end of file +recursive-include dbt/include *.sql *.yml *.md diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..05af9a2ca --- /dev/null +++ b/Makefile @@ -0,0 +1,61 @@ +.DEFAULT_GOAL:=help + +.PHONY: dev +dev: ## Installs adapter in develop mode along with development depedencies + @\ + pip install -r dev_requirements.txt && pre-commit install + +.PHONY: mypy +mypy: ## Runs mypy against staged changes for static type checking. + @\ + pre-commit run --hook-stage manual mypy-check | grep -v "INFO" + +.PHONY: flake8 +flake8: ## Runs flake8 against staged changes to enforce style guide. + @\ + pre-commit run --hook-stage manual flake8-check | grep -v "INFO" + +.PHONY: black +black: ## Runs black against staged changes to enforce style guide. + @\ + pre-commit run --hook-stage manual black-check -v | grep -v "INFO" + +.PHONY: lint +lint: ## Runs flake8 and mypy code checks against staged changes. 
+ @\ + pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run mypy-check --hook-stage manual | grep -v "INFO" + +.PHONY: linecheck +linecheck: ## Checks for all Python lines 100 characters or more + @\ + find dbt -type f -name "*.py" -exec grep -I -r -n '.\{100\}' {} \; + +.PHONY: unit +unit: ## Runs unit tests with py38. + @\ + tox -e py38 + +.PHONY: test +test: ## Runs unit tests with py38 and code checks against staged changes. + @\ + tox -p -e py38; \ + pre-commit run black-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run flake8-check --hook-stage manual | grep -v "INFO"; \ + pre-commit run mypy-check --hook-stage manual | grep -v "INFO" + +.PHONY: integration +integration: ## Runs snowflake integration tests with py38. + @\ + tox -e py38-snowflake -- + +.PHONY: clean + @echo "cleaning repo" + @git clean -f -X + +.PHONY: help +help: ## Show this help message. + @echo 'usage: make [target]' + @echo + @echo 'targets:' + @grep -E '^[7+a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/dbt/adapters/snowflake/__init__.py b/dbt/adapters/snowflake/__init__.py index bb4e17c9b..e5f855722 100644 --- a/dbt/adapters/snowflake/__init__.py +++ b/dbt/adapters/snowflake/__init__.py @@ -4,10 +4,9 @@ from dbt.adapters.snowflake.relation import SnowflakeRelation # noqa from dbt.adapters.snowflake.impl import SnowflakeAdapter -from dbt.adapters.base import AdapterPlugin -from dbt.include import snowflake +from dbt.adapters.base import AdapterPlugin # type: ignore +from dbt.include import snowflake # type: ignore Plugin = AdapterPlugin( - adapter=SnowflakeAdapter, - credentials=SnowflakeCredentials, - include_path=snowflake.PACKAGE_PATH) + adapter=SnowflakeAdapter, credentials=SnowflakeCredentials, include_path=snowflake.PACKAGE_PATH +) diff --git a/dbt/adapters/snowflake/__version__.py b/dbt/adapters/snowflake/__version__.py index a86cb5c59..56ec17a89 100644 
--- a/dbt/adapters/snowflake/__version__.py +++ b/dbt/adapters/snowflake/__version__.py @@ -1 +1 @@ -version = '1.1.0b1' +version = "1.1.0b1" diff --git a/dbt/adapters/snowflake/column.py b/dbt/adapters/snowflake/column.py index d7afb307f..ac9fdfb62 100644 --- a/dbt/adapters/snowflake/column.py +++ b/dbt/adapters/snowflake/column.py @@ -12,20 +12,32 @@ def is_integer(self) -> bool: def is_numeric(self) -> bool: return self.dtype.lower() in [ - 'int', 'integer', 'bigint', 'smallint', 'tinyint', 'byteint', - 'numeric', 'decimal', 'number' + "int", + "integer", + "bigint", + "smallint", + "tinyint", + "byteint", + "numeric", + "decimal", + "number", ] def is_float(self): return self.dtype.lower() in [ - 'float', 'float4', 'float8', 'double', 'double precision', 'real', + "float", + "float4", + "float8", + "double", + "double precision", + "real", ] def string_size(self) -> int: if not self.is_string(): raise RuntimeException("Called string_size() on non-string field!") - if self.dtype == 'text' or self.char_size is None: + if self.dtype == "text" or self.char_size is None: return 16777216 else: return int(self.char_size) diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index e4567ee09..ed8140924 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -15,17 +15,20 @@ import snowflake.connector.errors from dbt.exceptions import ( - InternalException, RuntimeException, FailedToConnectException, - DatabaseException, warn_or_error + InternalException, + RuntimeException, + FailedToConnectException, + DatabaseException, + warn_or_error, ) -from dbt.adapters.base import Credentials +from dbt.adapters.base import Credentials # type: ignore from dbt.contracts.connection import AdapterResponse -from dbt.adapters.sql import SQLConnectionManager -from dbt.events import AdapterLogger +from dbt.adapters.sql import SQLConnectionManager # type: ignore +from dbt.events import AdapterLogger # type: 
ignore logger = AdapterLogger("Snowflake") -_TOKEN_REQUEST_URL = 'https://{}.snowflakecomputing.com/oauth/token-request' +_TOKEN_REQUEST_URL = "https://{}.snowflakecomputing.com/oauth/token-request" @dataclass @@ -60,19 +63,18 @@ class SnowflakeCredentials(Credentials): insecure_mode: Optional[bool] = False def __post_init__(self): - if ( - self.authenticator != 'oauth' and - (self.oauth_client_secret or self.oauth_client_id or self.token) + if self.authenticator != "oauth" and ( + self.oauth_client_secret or self.oauth_client_id or self.token ): # the user probably forgot to set 'authenticator' like I keep doing warn_or_error( - 'Authenticator is not set to oauth, but an oauth-only ' - 'parameter is set! Did you mean to set authenticator: oauth?' + "Authenticator is not set to oauth, but an oauth-only " + "parameter is set! Did you mean to set authenticator: oauth?" ) @property def type(self): - return 'snowflake' + return "snowflake" @property def unique_field(self): @@ -80,8 +82,13 @@ def unique_field(self): def _connection_keys(self): return ( - 'account', 'user', 'database', 'schema', 'warehouse', 'role', - 'client_session_keep_alive' + "account", + "user", + "database", + "schema", + "warehouse", + "role", + "client_session_keep_alive", ) def auth_args(self): @@ -89,20 +96,20 @@ def auth_args(self): # let connector handle the actual arg validation result = {} if self.password: - result['password'] = self.password + result["password"] = self.password if self.host: - result['host'] = self.host + result["host"] = self.host if self.port: - result['port'] = self.port + result["port"] = self.port if self.proxy_host: - result['proxy_host'] = self.proxy_host + result["proxy_host"] = self.proxy_host if self.proxy_port: - result['proxy_port'] = self.proxy_port + result["proxy_port"] = self.proxy_port if self.protocol: - result['protocol'] = self.protocol + result["protocol"] = self.protocol if self.authenticator: - result['authenticator'] = self.authenticator - if 
self.authenticator == 'oauth': + result["authenticator"] = self.authenticator + if self.authenticator == "oauth": token = self.token # if we have a client ID/client secret, the token is a refresh # token, not an access token @@ -110,54 +117,51 @@ def auth_args(self): token = self._get_access_token() elif self.oauth_client_id: warn_or_error( - 'Invalid profile: got an oauth_client_id, but not an ' - 'oauth_client_secret!' + "Invalid profile: got an oauth_client_id, but not an " + "oauth_client_secret!" ) elif self.oauth_client_secret: warn_or_error( - 'Invalid profile: got an oauth_client_secret, but not ' - 'an oauth_client_id!' + "Invalid profile: got an oauth_client_secret, but not " + "an oauth_client_id!" ) - result['token'] = token + result["token"] = token # enable id token cache for linux - result['client_store_temporary_credential'] = True + result["client_store_temporary_credential"] = True # enable mfa token cache for linux - result['client_request_mfa_token'] = True - result['private_key'] = self._get_private_key() + result["client_request_mfa_token"] = True + result["private_key"] = self._get_private_key() return result def _get_access_token(self) -> str: - if self.authenticator != 'oauth': - raise InternalException('Can only get access tokens for oauth') + if self.authenticator != "oauth": + raise InternalException("Can only get access tokens for oauth") missing = any( - x is None for x in - (self.oauth_client_id, self.oauth_client_secret, self.token) + x is None for x in (self.oauth_client_id, self.oauth_client_secret, self.token) ) if missing: raise InternalException( - 'need a client ID a client secret, and a refresh token to get ' - 'an access token' + "need a client ID a client secret, and a refresh token to get " "an access token" ) # should the full url be a config item? 
token_url = _TOKEN_REQUEST_URL.format(self.account) # I think this is only used to redirect on success, which we ignore # (it does not have to match the integration's settings in snowflake) - redirect_uri = 'http://localhost:9999' + redirect_uri = "http://localhost:9999" data = { - 'grant_type': 'refresh_token', - 'refresh_token': self.token, - 'redirect_uri': redirect_uri + "grant_type": "refresh_token", + "refresh_token": self.token, + "redirect_uri": redirect_uri, } auth = base64.b64encode( - f'{self.oauth_client_id}:{self.oauth_client_secret}' - .encode('ascii') - ).decode('ascii') + f"{self.oauth_client_id}:{self.oauth_client_secret}".encode("ascii") + ).decode("ascii") headers = { - 'Authorization': f'Basic {auth}', - 'Content-type': 'application/x-www-form-urlencoded;charset=utf-8' + "Authorization": f"Basic {auth}", + "Content-type": "application/x-www-form-urlencoded;charset=utf-8", } result_json = None @@ -170,15 +174,19 @@ def _get_access_token(self) -> str: break except ValueError as e: message = result.text - logger.debug(f"Got a non-json response ({result.status_code}): \ - {e}, message: {message}") + logger.debug( + f"Got a non-json response ({result.status_code}): \ + {e}, message: {message}" + ) sleep(0.05) if result_json is None: - raise DatabaseException(f"""Did not receive valid json with access_token. - Showing json response: {result_json}""") + raise DatabaseException( + f"""Did not receive valid json with access_token. 
+ Showing json response: {result_json}""" + ) - return result_json['access_token'] + return result_json["access_token"] def _get_private_key(self): """Get Snowflake private key by path or None.""" @@ -190,20 +198,20 @@ def _get_private_key(self): else: encoded_passphrase = None - with open(self.private_key_path, 'rb') as key: + with open(self.private_key_path, "rb") as key: p_key = serialization.load_pem_private_key( - key.read(), - password=encoded_passphrase, - backend=default_backend()) + key.read(), password=encoded_passphrase, backend=default_backend() + ) return p_key.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption()) + encryption_algorithm=serialization.NoEncryption(), + ) class SnowflakeConnectionManager(SQLConnectionManager): - TYPE = 'snowflake' + TYPE = "snowflake" @contextmanager def exception_handler(self, sql): @@ -212,23 +220,25 @@ def exception_handler(self, sql): except snowflake.connector.errors.ProgrammingError as e: msg = str(e) - logger.debug('Snowflake query id: {}'.format(e.sfqid)) - logger.debug('Snowflake error: {}'.format(msg)) + logger.debug("Snowflake query id: {}".format(e.sfqid)) + logger.debug("Snowflake error: {}".format(msg)) - if 'Empty SQL statement' in msg: + if "Empty SQL statement" in msg: logger.debug("got empty sql statement, moving on") - elif 'This session does not have a current database' in msg: + elif "This session does not have a current database" in msg: raise FailedToConnectException( - ('{}\n\nThis error sometimes occurs when invalid ' - 'credentials are provided, or when your default role ' - 'does not have access to use the specified database. ' - 'Please double check your profile and try again.') - .format(msg)) + ( + "{}\n\nThis error sometimes occurs when invalid " + "credentials are provided, or when your default role " + "does not have access to use the specified database. 
" + "Please double check your profile and try again." + ).format(msg) + ) else: raise DatabaseException(msg) except Exception as e: if isinstance(e, snowflake.connector.errors.Error): - logger.debug('Snowflake query id: {}'.format(e.sfqid)) + logger.debug("Snowflake query id: {}".format(e.sfqid)) logger.debug("Error running SQL: {}", sql) logger.debug("Rolling back transaction.") @@ -242,8 +252,8 @@ def exception_handler(self, sql): @classmethod def open(cls, connection): - if connection.state == 'open': - logger.debug('Connection is already open, skipping open.') + if connection.state == "open": + logger.debug("Connection is already open, skipping open.") return connection creds = connection.credentials @@ -259,80 +269,90 @@ def open(cls, connection): role=creds.role, autocommit=True, client_session_keep_alive=creds.client_session_keep_alive, - application='dbt', + application="dbt", insecure_mode=creds.insecure_mode, - **creds.auth_args() + **creds.auth_args(), ) if creds.query_tag: handle.cursor().execute( - ("alter session set query_tag = '{}'") - .format(creds.query_tag)) + ("alter session set query_tag = '{}'").format(creds.query_tag) + ) connection.handle = handle - connection.state = 'open' + connection.state = "open" break except snowflake.connector.errors.DatabaseError as e: - if (creds.retry_on_database_errors or creds.retry_all) \ - and creds.connect_retries > 0: + if ( + creds.retry_on_database_errors or creds.retry_all + ) and creds.connect_retries > 0: error = e - logger.warning("Got an error when attempting to open a " - "snowflake connection. Retrying due to " - "either retry configuration set to true." - "This was attempt number: {attempt} of " - "{retry_limit}. " - "Retrying in {timeout} " - "seconds. Error: '{error}'" - .format(attempt=attempt, - retry_limit=creds.connect_retries, - timeout=creds.connect_timeout, - error=e)) + logger.warning( + "Got an error when attempting to open a " + "snowflake connection. 
Retrying due to " + "either retry configuration set to true." + "This was attempt number: {attempt} of " + "{retry_limit}. " + "Retrying in {timeout} " + "seconds. Error: '{error}'".format( + attempt=attempt, + retry_limit=creds.connect_retries, + timeout=creds.connect_timeout, + error=e, + ) + ) sleep(creds.connect_timeout) else: - logger.debug("Got an error when attempting to open a " - "snowflake connection. No retries " - "attempted: '{}'" - .format(e)) + logger.debug( + "Got an error when attempting to open a " + "snowflake connection. No retries " + "attempted: '{}'".format(e) + ) connection.handle = None - connection.state = 'fail' + connection.state = "fail" raise FailedToConnectException(str(e)) except snowflake.connector.errors.Error as e: if creds.retry_all and creds.connect_retries > 0: error = e - logger.warning("Got an error when attempting to open a " - "snowflake connection. Retrying due to " - "'retry_all' configuration set to true." - "This was attempt number: {attempt} of " - "{retry_limit}. " - "Retrying in {timeout} " - "seconds. Error: '{error}'" - .format(attempt=attempt, - retry_limit=creds.connect_retries, - timeout=creds.connect_timeout, - error=e)) + logger.warning( + "Got an error when attempting to open a " + "snowflake connection. Retrying due to " + "'retry_all' configuration set to true." + "This was attempt number: {attempt} of " + "{retry_limit}. " + "Retrying in {timeout} " + "seconds. Error: '{error}'".format( + attempt=attempt, + retry_limit=creds.connect_retries, + timeout=creds.connect_timeout, + error=e, + ) + ) sleep(creds.connect_timeout) else: - logger.debug("Got an error when attempting to open a " - "snowflake connection. No retries " - "attempted: '{}'" - .format(e)) + logger.debug( + "Got an error when attempting to open a " + "snowflake connection. 
No retries " + "attempted: '{}'".format(e) + ) connection.handle = None - connection.state = 'fail' + connection.state = "fail" raise FailedToConnectException(str(e)) else: - logger.debug("Got an error when attempting to open a snowflake " - "connection: '{}'" - .format(error)) + logger.debug( + "Got an error when attempting to open a snowflake " + "connection: '{}'".format(error) + ) connection.handle = None - connection.state = 'fail' + connection.state = "fail" raise FailedToConnectException(str(error)) def cancel(self, connection): @@ -341,7 +361,7 @@ def cancel(self, connection): connection_name = connection.name - sql = 'select system$abort_session({})'.format(sid) + sql = "select system$abort_session({})".format(sid) logger.debug("Cancelling query '{}' ({})".format(connection_name, sid)) @@ -355,14 +375,14 @@ def get_response(cls, cursor) -> SnowflakeAdapterResponse: code = cursor.sqlstate if code is None: - code = 'SUCCESS' + code = "SUCCESS" return SnowflakeAdapterResponse( _message="{} {}".format(code, cursor.rowcount), rows_affected=cursor.rowcount, code=code, - query_id=cursor.sfqid - ) + query_id=cursor.sfqid, + ) # type: ignore # disable transactional logic by default on Snowflake # except for DML statements where explicitly defined @@ -410,8 +430,7 @@ def process_results(cls, column_names, rows): return super().process_results(column_names, fixed) - def add_query(self, sql, auto_begin=True, - bindings=None, abridge_sql_log=False): + def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): connection = None cursor = None @@ -428,24 +447,22 @@ def add_query(self, sql, auto_begin=True, # empty queries. 
this avoids using exceptions as flow control, # and also allows us to return the status of the last cursor without_comments = re.sub( - re.compile( - r'(\".*?\"|\'.*?\')|(/\*.*?\*/|--[^\r\n]*$)', re.MULTILINE - ), - '', individual_query).strip() + re.compile(r"(\".*?\"|\'.*?\')|(/\*.*?\*/|--[^\r\n]*$)", re.MULTILINE), + "", + individual_query, + ).strip() if without_comments == "": continue connection, cursor = super().add_query( - individual_query, auto_begin, - bindings=bindings, - abridge_sql_log=abridge_sql_log + individual_query, auto_begin, bindings=bindings, abridge_sql_log=abridge_sql_log ) if cursor is None: conn = self.get_thread_connection() if conn is None or conn.name is None: - conn_name = '' + conn_name = "" else: conn_name = conn.name @@ -453,8 +470,7 @@ def add_query(self, sql, auto_begin=True, "Tried to run an empty query on model '{}'. If you are " "conditionally running\nsql, eg. in a model hook, make " "sure your `else` clause contains valid sql!\n\n" - "Provided SQL:\n{}" - .format(conn_name, sql) + "Provided SQL:\n{}".format(conn_name, sql) ) return connection, cursor diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index e609ba5f5..f2af32795 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -4,7 +4,7 @@ import agate from dbt.adapters.base.impl import AdapterConfig -from dbt.adapters.sql import SQLAdapter +from dbt.adapters.sql import SQLAdapter # type: ignore from dbt.adapters.sql.impl import ( LIST_SCHEMAS_MACRO_NAME, LIST_RELATIONS_MACRO_NAME, @@ -13,9 +13,7 @@ from dbt.adapters.snowflake import SnowflakeRelation from dbt.adapters.snowflake import SnowflakeColumn from dbt.contracts.graph.manifest import Manifest -from dbt.exceptions import ( - raise_compiler_error, RuntimeException, DatabaseException -) +from dbt.exceptions import raise_compiler_error, RuntimeException, DatabaseException from dbt.utils import filter_null_values @@ -43,14 +41,10 @@ def date_function(cls): return 
"CURRENT_TIMESTAMP()" @classmethod - def _catalog_filter_table( - cls, table: agate.Table, manifest: Manifest - ) -> agate.Table: + def _catalog_filter_table(cls, table: agate.Table, manifest: Manifest) -> agate.Table: # On snowflake, users can set QUOTED_IDENTIFIERS_IGNORE_CASE, so force # the column names to their lowercased forms. - lowered = table.rename( - column_names=[c.lower() for c in table.column_names] - ) + lowered = table.rename(column_names=[c.lower() for c in table.column_names]) return super()._catalog_filter_table(lowered, manifest) def _make_match_kwargs(self, database, schema, identifier): @@ -69,105 +63,85 @@ def _make_match_kwargs(self, database, schema, identifier): ) def _get_warehouse(self) -> str: - _, table = self.execute( - 'select current_warehouse() as warehouse', - fetch=True - ) + _, table = self.execute("select current_warehouse() as warehouse", fetch=True) if len(table) == 0 or len(table[0]) == 0: # can this happen? - raise RuntimeException( - 'Could not get current warehouse: no results' - ) + raise RuntimeException("Could not get current warehouse: no results") return str(table[0][0]) def _use_warehouse(self, warehouse: str): """Use the given warehouse. 
Quotes are never applied.""" - self.execute('use warehouse {}'.format(warehouse)) + self.execute("use warehouse {}".format(warehouse)) def pre_model_hook(self, config: Mapping[str, Any]) -> Optional[str]: default_warehouse = self.config.credentials.warehouse - warehouse = config.get('snowflake_warehouse', default_warehouse) + warehouse = config.get("snowflake_warehouse", default_warehouse) if warehouse == default_warehouse or warehouse is None: return None previous = self._get_warehouse() self._use_warehouse(warehouse) return previous - def post_model_hook( - self, config: Mapping[str, Any], context: Optional[str] - ) -> None: + def post_model_hook(self, config: Mapping[str, Any], context: Optional[str]) -> None: if context is not None: self._use_warehouse(context) def list_schemas(self, database: str) -> List[str]: try: - results = self.execute_macro( - LIST_SCHEMAS_MACRO_NAME, - kwargs={'database': database} - ) + results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}) except DatabaseException as exc: - msg = ( - f'Database error while listing schemas in database ' - f'"{database}"\n{exc}' - ) + msg = f"Database error while listing schemas in database " f'"{database}"\n{exc}' raise RuntimeException(msg) # this uses 'show terse schemas in database', and the column name we # want is 'name' - return [row['name'] for row in results] + return [row["name"] for row in results] def get_columns_in_relation(self, relation): try: return super().get_columns_in_relation(relation) except DatabaseException as exc: - if 'does not exist or not authorized' in str(exc): + if "does not exist or not authorized" in str(exc): return [] else: raise def list_relations_without_caching( - self, schema_relation: SnowflakeRelation + self, schema_relation: SnowflakeRelation ) -> List[SnowflakeRelation]: - kwargs = {'schema_relation': schema_relation} + kwargs = {"schema_relation": schema_relation} try: - results = self.execute_macro( - LIST_RELATIONS_MACRO_NAME, - 
kwargs=kwargs - ) + results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs) except DatabaseException as exc: # if the schema doesn't exist, we just want to return. # Alternatively, we could query the list of schemas before we start # and skip listing the missing ones, which sounds expensive. - if 'Object does not exist' in str(exc): + if "Object does not exist" in str(exc): return [] raise relations = [] - quote_policy = { - 'database': True, - 'schema': True, - 'identifier': True - } + quote_policy = {"database": True, "schema": True, "identifier": True} - columns = ['database_name', 'schema_name', 'name', 'kind'] + columns = ["database_name", "schema_name", "name", "kind"] for _database, _schema, _identifier, _type in results.select(columns): try: _type = self.Relation.get_relation_type(_type.lower()) except ValueError: _type = self.Relation.External - relations.append(self.Relation.create( - database=_database, - schema=_schema, - identifier=_identifier, - quote_policy=quote_policy, - type=_type - )) + relations.append( + self.Relation.create( + database=_database, + schema=_schema, + identifier=_identifier, + quote_policy=quote_policy, + type=_type, + ) + ) return relations - def quote_seed_column( - self, column: str, quote_config: Optional[bool] - ) -> str: + def quote_seed_column(self, column: str, quote_config: Optional[bool]) -> str: quote_columns: bool = False if isinstance(quote_config, bool): quote_columns = quote_config @@ -176,7 +150,7 @@ def quote_seed_column( else: raise_compiler_error( f'The seed configuration value of "quote_columns" has an ' - f'invalid type {type(quote_config)}' + f"invalid type {type(quote_config)}" ) if quote_columns: @@ -184,7 +158,5 @@ def quote_seed_column( else: return column - def timestamp_add_sql( - self, add_to: str, number: int = 1, interval: str = 'hour' - ) -> str: - return f'DATEADD({interval}, {number}, {add_to})' + def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> 
str: + return f"DATEADD({interval}, {number}, {add_to})" diff --git a/dbt/include/snowflake/__init__.py b/dbt/include/snowflake/__init__.py index 564a3d1e8..b177e5d49 100644 --- a/dbt/include/snowflake/__init__.py +++ b/dbt/include/snowflake/__init__.py @@ -1,2 +1,3 @@ import os + PACKAGE_PATH = os.path.dirname(__file__) diff --git a/dbt/include/snowflake/macros/adapters.sql b/dbt/include/snowflake/macros/adapters.sql index 6bd6a8258..c5f07ff5f 100644 --- a/dbt/include/snowflake/macros/adapters.sql +++ b/dbt/include/snowflake/macros/adapters.sql @@ -63,12 +63,12 @@ {{ sql_header if sql_header is not none }} create or replace {% if secure -%} secure - {%- endif %} view {{ relation }} + {%- endif %} view {{ relation }} {% if config.persist_column_docs() -%} {% set model_columns = model.columns %} {% set query_columns = get_columns_in_query(sql) %} {{ get_persist_docs_column_list(model_columns, query_columns) }} - + {%- endif %} {% if copy_grants -%} copy grants {%- endif %} as ( {{ sql }} @@ -214,13 +214,13 @@ {% do run_query("alter session unset query_tag") %} {% endif %} {% endif %} -{% endmacro %} +{% endmacro %} {% macro snowflake__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} - + {% if add_columns %} - + {% set sql -%} alter {{ relation.type }} {{ relation }} add column {% for column in add_columns %} @@ -233,16 +233,16 @@ {% endif %} {% if remove_columns %} - + {% set sql -%} alter {{ relation.type }} {{ relation }} drop column {% for column in remove_columns %} {{ column.name }}{{ ',' if not loop.last }} {% endfor %} {%- endset -%} - + {% do run_query(sql) %} - + {% endif %} {% endmacro %} @@ -250,7 +250,7 @@ {% macro snowflake_dml_explicit_transaction(dml) %} {# - Use this macro to wrap all INSERT, MERGE, UPDATE, DELETE, and TRUNCATE + Use this macro to wrap all INSERT, MERGE, UPDATE, DELETE, and TRUNCATE statements before passing them into run_query(), or calling in the 'main' statement of a materialization #} @@ -259,7 +259,7 
@@ {{ dml }}; commit; {%- endset %} - + {% do return(dml_transaction) %} {% endmacro %} diff --git a/dbt/include/snowflake/macros/materializations/incremental.sql b/dbt/include/snowflake/macros/materializations/incremental.sql index 200eb938b..5710284f3 100644 --- a/dbt/include/snowflake/macros/materializations/incremental.sql +++ b/dbt/include/snowflake/macros/materializations/incremental.sql @@ -25,7 +25,7 @@ {% endmacro %} {% materialization incremental, adapter='snowflake' -%} - + {% set original_query_tag = set_query_tag() %} {%- set unique_key = config.get('unique_key') -%} @@ -43,16 +43,16 @@ {% if existing_relation is none %} {% set build_sql = create_table_as(False, target_relation, sql) %} - + {% elif existing_relation.is_view %} {#-- Can't overwrite a view with a table - we must drop --#} {{ log("Dropping relation " ~ target_relation ~ " because it is a view and this model is a table.") }} {% do adapter.drop_relation(existing_relation) %} {% set build_sql = create_table_as(False, target_relation, sql) %} - + {% elif full_refresh_mode %} {% set build_sql = create_table_as(False, target_relation, sql) %} - + {% else %} {% do run_query(create_table_as(True, tmp_relation, sql)) %} {% do adapter.expand_target_column_types( @@ -64,7 +64,7 @@ {% set dest_columns = adapter.get_columns_in_relation(existing_relation) %} {% endif %} {% set build_sql = dbt_snowflake_get_incremental_sql(strategy, tmp_relation, target_relation, unique_key, dest_columns) %} - + {% endif %} {%- call statement('main') -%} @@ -80,4 +80,4 @@ {{ return({'relations': [target_relation]}) }} -{%- endmaterialization %} \ No newline at end of file +{%- endmaterialization %} diff --git a/dbt/include/snowflake/macros/materializations/merge.sql b/dbt/include/snowflake/macros/materializations/merge.sql index 6a8f6b89c..689e93440 100644 --- a/dbt/include/snowflake/macros/materializations/merge.sql +++ b/dbt/include/snowflake/macros/materializations/merge.sql @@ -26,7 +26,7 @@ {%- endif -%} {%- endset 
-%} - + {% do return(snowflake_dml_explicit_transaction(dml)) %} {% endmacro %} diff --git a/dev_requirements.txt b/dev_requirements.txt index 79fc271d7..b5943e34b 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -3,6 +3,8 @@ git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-core&subdirectory=core git+https://github.com/dbt-labs/dbt-core.git#egg=dbt-tests-adapter&subdirectory=tests/adapter +black==21.12b0 +click~=8.0.4 bumpversion flake8 flaky @@ -10,6 +12,7 @@ freezegun==0.3.12 ipdb mypy==0.782 pip-tools +pre-commit pytest pytest-dotenv pytest-logbook diff --git a/scripts/build-dist.sh b/scripts/build-dist.sh index 65e6dbc97..3c3808399 100755 --- a/scripts/build-dist.sh +++ b/scripts/build-dist.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash set -eo pipefail diff --git a/setup.py b/setup.py index ba9bb5bd7..888da6449 100644 --- a/setup.py +++ b/setup.py @@ -5,41 +5,39 @@ # require python 3.7 or newer if sys.version_info < (3, 7): - print('Error: dbt does not support this version of Python.') - print('Please upgrade to Python 3.7 or higher.') + print("Error: dbt does not support this version of Python.") + print("Please upgrade to Python 3.7 or higher.") sys.exit(1) # require version of setuptools that supports find_namespace_packages from setuptools import setup + try: from setuptools import find_namespace_packages except ImportError: # the user has a downlevel version of setuptools. 
- print('Error: dbt requires setuptools v40.1.0 or higher.') - print('Please upgrade setuptools with "pip install --upgrade setuptools" ' - 'and try again') + print("Error: dbt requires setuptools v40.1.0 or higher.") + print('Please upgrade setuptools with "pip install --upgrade setuptools" ' "and try again") sys.exit(1) # pull long description from README this_directory = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(this_directory, 'README.md')) as f: +with open(os.path.join(this_directory, "README.md")) as f: long_description = f.read() # get this package's version from dbt/adapters//__version__.py def _get_plugin_version_dict(): - _version_path = os.path.join( - this_directory, 'dbt', 'adapters', 'snowflake', '__version__.py' - ) - _semver = r'''(?P\d+)\.(?P\d+)\.(?P\d+)''' - _pre = r'''((?Pa|b|rc)(?P
\d+))?'''
-    _version_pattern = fr'''version\s*=\s*["']{_semver}{_pre}["']'''
+    _version_path = os.path.join(this_directory, "dbt", "adapters", "snowflake", "__version__.py")
+    _semver = r"""(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"""
+    _pre = r"""((?P<prekind>a|b|rc)(?P<pre>\d+))?"""
+    _version_pattern = fr"""version\s*=\s*["']{_semver}{_pre}["']"""
     with open(_version_path) as f:
         match = re.search(_version_pattern, f.read().strip())
         if match is None:
-            raise ValueError(f'invalid version at {_version_path}')
+            raise ValueError(f"invalid version at {_version_path}")
         return match.groupdict()
 
 
@@ -47,7 +45,7 @@ def _get_plugin_version_dict():
 def _get_dbt_core_version():
     parts = _get_plugin_version_dict()
     minor = "{major}.{minor}.0".format(**parts)
-    pre = (parts["prekind"]+"1" if parts["prekind"] else "")
+    pre = parts["prekind"] + "1" if parts["prekind"] else ""
     return f"{minor}{pre}"
 
 
@@ -61,31 +59,28 @@ def _get_dbt_core_version():
     version=package_version,
     description=description,
     long_description=long_description,
-    long_description_content_type='text/markdown',
+    long_description_content_type="text/markdown",
     author="dbt Labs",
     author_email="info@dbtlabs.com",
     url="https://github.com/dbt-labs/dbt-snowflake",
-    packages=find_namespace_packages(include=['dbt', 'dbt.*']),
+    packages=find_namespace_packages(include=["dbt", "dbt.*"]),
     include_package_data=True,
     install_requires=[
-        'dbt-core~={}'.format(dbt_core_version),
-        'snowflake-connector-python[secure-local-storage]>=2.4.1,<2.8.0',
-        'requests<3.0.0',
-        'cryptography>=3.2,<4',
+        "dbt-core~={}".format(dbt_core_version),
+        "snowflake-connector-python[secure-local-storage]>=2.4.1,<2.8.0",
+        "requests<3.0.0",
+        "cryptography>=3.2,<4",
     ],
     zip_safe=False,
     classifiers=[
-        'Development Status :: 5 - Production/Stable',
-
-        'License :: OSI Approved :: Apache Software License',
-
-        'Operating System :: Microsoft :: Windows',
-        'Operating System :: MacOS :: MacOS X',
-        'Operating System :: POSIX :: Linux',
-
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
+        "Development Status :: 5 - Production/Stable",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: Microsoft :: Windows",
+        "Operating System :: MacOS :: MacOS X",
+        "Operating System :: POSIX :: Linux",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
     ],
     python_requires=">=3.7",
 )
diff --git a/test.env.example b/test.env.example
index a56137a00..9fdc37a21 100644
--- a/test.env.example
+++ b/test.env.example
@@ -1,8 +1,8 @@
-# Note: Make sure you have a Snowflake account that is set up so these fields are easy to complete. 
-If you don't have an account set up yet, then take note of these required fields that way. When you're getting set up, 
+# Note: Make sure you have a Snowflake account that is set up so these fields are easy to complete.
+If you don't have an account set up yet, then take note of these required fields that way. When you're getting set up,
 you can use them later to build your Snowflake project.
 
-### Test Environment field definitions 
+### Test Environment field definitions
 # These will all be gathered from account information or created by you.
 
 # SNOWFLAKE_TEST_ACCOUNT: The name that uniquely identifies your Snowflake account.
@@ -28,4 +28,4 @@ SNOWFLAKE_TEST_OAUTH_REFRESH_TOKEN=TRUE
 SNOWFLAKE_TEST_PASSWORD=my_password
 SNOWFLAKE_TEST_QUOTED_DATABASE=my_quoted_database_name
 SNOWFLAKE_TEST_USER=my_username
-SNOWFLAKE_TEST_WAREHOUSE=my_warehouse_name
\ No newline at end of file
+SNOWFLAKE_TEST_WAREHOUSE=my_warehouse_name
diff --git a/tox.ini b/tox.ini
index 9e6e264b3..406972020 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,15 +1,6 @@
 [tox]
 skipsdist = True
-envlist = py37,py38,py39,flake8
-
-[testenv:flake8]
-description = flake8 code checks
-basepython = python3.8
-skip_install = true
-commands = flake8 --select=E,W,F --ignore=W504,E741 --max-line-length 99 \
-  dbt
-deps =
-  -rdev_requirements.txt
+envlist = py37,py38,py39
 
 [testenv:{unit,py37,py38,py39,py}]
 description = unit testing