diff --git a/.devcontainer/devcontainer-lock.json b/.devcontainer/devcontainer-lock.json new file mode 100644 index 00000000..a3500695 --- /dev/null +++ b/.devcontainer/devcontainer-lock.json @@ -0,0 +1,9 @@ +{ + "features": { + "ghcr.io/devcontainers/features/docker-in-docker:2": { + "version": "2.7.1", + "resolved": "ghcr.io/devcontainers/features/docker-in-docker@sha256:f6a73ee06601d703db7d95d03e415cab229e78df92bb5002e8559bcfc047fec6", + "integrity": "sha256:f6a73ee06601d703db7d95d03e415cab229e78df92bb5002e8559bcfc047fec6" + } + } +} diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..457926e3 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,11 @@ +{ + "name": "Python 3", + "image": "mcr.microsoft.com/devcontainers/python:1-3.10-bookworm", + "features": { + "ghcr.io/devcontainers/features/docker-in-docker:2": {} + }, + "postStartCommand": "bash .devcontainer/setup_odbc.sh && bash .devcontainer/install_pyenv.sh && bash .devcontainer/setup_env.sh", + "runArgs": [ + "--env-file", "${localWorkspaceFolder}/.devcontainer/test.env" + ] +} diff --git a/.devcontainer/install_pyenv.sh b/.devcontainer/install_pyenv.sh new file mode 100644 index 00000000..4d9a0cff --- /dev/null +++ b/.devcontainer/install_pyenv.sh @@ -0,0 +1,6 @@ +#/bin/bash +curl https://pyenv.run | bash + +echo 'export PYENV_ROOT="$HOME/.pyenv" +[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH" +eval "$(pyenv init -)"' >> ~/.bashrc diff --git a/.devcontainer/setup_env.sh b/.devcontainer/setup_env.sh new file mode 100644 index 00000000..b3854530 --- /dev/null +++ b/.devcontainer/setup_env.sh @@ -0,0 +1,6 @@ +pyenv install 3.10.7 +pyenv virtualenv 3.10.7 dbt-sqlserver +pyenv activate dbt-sqlserver + +make dev +make server diff --git a/.devcontainer/setup_odbc.sh b/.devcontainer/setup_odbc.sh new file mode 100644 index 00000000..9a0821b9 --- /dev/null +++ b/.devcontainer/setup_odbc.sh @@ -0,0 +1,18 @@ +curl https://packages.microsoft.com/keys/microsoft.asc | sudo tee /etc/apt/trusted.gpg.d/microsoft.asc + +#Download appropriate package for the OS version +#Choose only ONE of the following, corresponding to your OS version + +#Debian 12 +curl https://packages.microsoft.com/config/debian/12/prod.list | sudo tee /etc/apt/sources.list.d/mssql-release.list + +sudo apt-get update +sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18 +# optional: for bcp and sqlcmd +sudo ACCEPT_EULA=Y apt-get install -y mssql-tools18 +echo 'export PATH="$PATH:/opt/mssql-tools18/bin"' >> ~/.bashrc +source ~/.bashrc +# optional: for unixODBC development headers +sudo apt-get install -y unixodbc-dev +# optional: kerberos library for debian-slim distributions +sudo apt-get install -y libgssapi-krb5-2 diff --git a/.github/workflows/integration-tests-azure.yml b/.github/workflows/integration-tests-azure.yml deleted file mode 100644 index b4f4d257..00000000 --- a/.github/workflows/integration-tests-azure.yml +++ /dev/null @@ -1,104 +0,0 @@ ---- -name: Integration tests on Azure -on: # yamllint disable-line rule:truthy - workflow_dispatch: - push: - branches: - - master - - v* - pull_request: - branches: - - master - - v* - -jobs: - integration-tests-azure: - name: Regular - strategy: - fail-fast: false - matrix: - python_version: ["3.7", "3.8", "3.9", "3.10", "3.11"] - profile: ["ci_azure_cli", "ci_azure_auto", "ci_azure_environment", "ci_azure_basic"] - msodbc_version: ["17", "18"] - runs-on: ubuntu-latest - container: - image: ghcr.io/${{ github.repository }}:CI-${{ matrix.python_version }}-msodbc${{ matrix.msodbc_version }} - steps: - - name: AZ CLI login - run: az login --service-principal --username="${AZURE_CLIENT_ID}" --password="${AZURE_CLIENT_SECRET}" --tenant="${AZURE_TENANT_ID}" - env: - AZURE_CLIENT_ID: ${{ secrets.DBT_AZURE_SP_NAME }} - AZURE_CLIENT_SECRET: ${{ secrets.DBT_AZURE_SP_SECRET }} - AZURE_TENANT_ID: ${{ secrets.DBT_AZURE_TENANT }} - - - uses: actions/checkout@v3 - - - name: Install dependencies - run: pip install -r dev_requirements.txt - - - name: Wake up server - env: - DBT_AZURESQL_SERVER: ${{ secrets.DBT_AZURESQL_SERVER }} - DBT_AZURESQL_DB: ${{ secrets.DBT_AZURESQL_DB }} - DBT_AZURESQL_UID: ${{ secrets.DBT_AZURESQL_UID }} - DBT_AZURESQL_PWD: ${{ secrets.DBT_AZURESQL_PWD }} - MSODBC_VERSION: ${{ matrix.msodbc_version }} - run: python devops/scripts/wakeup_azure.py - - - name: Configure test users - run: sqlcmd -b -I -i devops/scripts/init.sql - env: - DBT_TEST_USER_1: DBT_TEST_USER_1 - DBT_TEST_USER_2: DBT_TEST_USER_2 - DBT_TEST_USER_3: DBT_TEST_USER_3 - SQLCMDUSER: ${{ secrets.DBT_AZURESQL_UID }} - SQLCMDPASSWORD: ${{ secrets.DBT_AZURESQL_PWD }} - SQLCMDSERVER: ${{ secrets.DBT_AZURESQL_SERVER }} - SQLCMDDBNAME: ${{ secrets.DBT_AZURESQL_DB }} - - - name: Run functional tests - env: - DBT_AZURESQL_SERVER: ${{ secrets.DBT_AZURESQL_SERVER }} - DBT_AZURESQL_DB: ${{ secrets.DBT_AZURESQL_DB }} - DBT_AZURESQL_UID: ${{ secrets.DBT_AZURESQL_UID }} - DBT_AZURESQL_PWD: ${{ secrets.DBT_AZURESQL_PWD }} - AZURE_CLIENT_ID: ${{ secrets.DBT_AZURE_SP_NAME }} - AZURE_CLIENT_SECRET: ${{ secrets.DBT_AZURE_SP_SECRET }} - AZURE_TENANT_ID: ${{ secrets.DBT_AZURE_TENANT }} - DBT_TEST_USER_1: DBT_TEST_USER_1 - DBT_TEST_USER_2: DBT_TEST_USER_2 - DBT_TEST_USER_3: DBT_TEST_USER_3 - SQLSERVER_TEST_DRIVER: 'ODBC Driver ${{ matrix.msodbc_version }} for SQL Server' - run: pytest --ignore=tests/functional/adapter/test_provision_users.py -ra -v tests/functional --profile "${{ matrix.profile }}" - - flaky-tests-azure: - name: Flaky tests on Azure - runs-on: ubuntu-latest - container: - image: ghcr.io/${{ github.repository }}:CI-3.11-msodbc18 - steps: - - uses: actions/checkout@v3 - - - name: Install dependencies - run: pip install -r dev_requirements.txt - - - name: Wake up server - env: - DBT_AZURESQL_SERVER: ${{ secrets.DBT_AZURESQL_SERVER }} - DBT_AZURESQL_DB: ${{ secrets.DBT_AZURESQL_DB }} - DBT_AZURESQL_UID: ${{ secrets.DBT_AZURESQL_UID }} - DBT_AZURESQL_PWD: ${{ secrets.DBT_AZURESQL_PWD }} - MSODBC_VERSION: ${{ matrix.msodbc_version }} - run: python devops/scripts/wakeup_azure.py - - - name: Run auto provisioning tests - env: - DBT_AZURESQL_SERVER: ${{ secrets.DBT_AZURESQL_SERVER }} - DBT_AZURESQL_DB: ${{ secrets.DBT_AZURESQL_DB }} - AZURE_CLIENT_ID: ${{ secrets.DBT_AZURE_SP_NAME }} - AZURE_CLIENT_SECRET: ${{ secrets.DBT_AZURE_SP_SECRET }} - AZURE_TENANT_ID: ${{ secrets.DBT_AZURE_TENANT }} - DBT_TEST_AAD_PRINCIPAL_1: ${{ secrets.DBT_TEST_AAD_PRINCIPAL_1 }} - DBT_TEST_AAD_PRINCIPAL_2: ${{ secrets.DBT_TEST_AAD_PRINCIPAL_2 }} - SQLSERVER_TEST_DRIVER: 'ODBC Driver 18 for SQL Server' - run: pytest -ra -v tests/functional/adapter/test_provision_users.py --profile "ci_azure_environment" diff --git a/.github/workflows/integration-tests-sqlserver.yml b/.github/workflows/integration-tests-sqlserver.yml index 7eca7015..1de39167 100644 --- a/.github/workflows/integration-tests-sqlserver.yml +++ b/.github/workflows/integration-tests-sqlserver.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - python_version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python_version: ["3.8", "3.9", "3.10", "3.11"] msodbc_version: ["17", "18"] sqlserver_version: ["2017", "2019", "2022"] collation: ["SQL_Latin1_General_CP1_CS_AS", "SQL_Latin1_General_CP1_CI_AS"] diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 2aed4961..1ca0f563 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -16,7 +16,7 @@ jobs: name: Unit tests strategy: matrix: - python_version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python_version: ["3.8", "3.9", "3.10", "3.11"] runs-on: ubuntu-latest permissions: contents: read diff --git a/.gitignore b/.gitignore index 38ba195c..3b122e2f 100644 --- a/.gitignore +++ b/.gitignore @@ -95,3 +95,4 @@ venv/ ENV/ env.bak/ venv.bak/ +.mise.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e23b4d9f..a916f3ea 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ default_language_version: - python: python3.9 + python: python3.10 repos: - repo: 'https://github.com/pre-commit/pre-commit-hooks' rev: v4.4.0 @@ -62,7 +62,7 @@ repos: - manual args: - '--line-length=99' - - '--target-version=py39' + - '--target-version=py310' - '--check' - '--diff' - repo: 'https://github.com/pycqa/flake8' @@ -94,4 +94,5 @@ repos: - '--show-error-codes' - '--pretty' - '--ignore-missing-imports' + - '--explicit-package-bases' files: '^dbt/adapters' diff --git a/CHANGELOG.md b/CHANGELOG.md index 699b4b31..1cac1ab4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,24 @@ # Changelog +### v1.7.2 + +Updated to use dbt-fabric as the upstream adapter (https://github.com/dbt-msft/dbt-sqlserver/issues/441#issuecomment-1815837171)[https://github.com/dbt-msft/dbt-sqlserver/issues/441#issuecomment-1815837171] and (https://github.com/microsoft/dbt-fabric/issues/105)[https://github.com/microsoft/dbt-fabric/issues/105] + +As the fabric adapter implements the majority of auth and required t-sql, this adapter delegates primarily to SQL auth and SQL Server specific +adaptations (using `SELECT INTO` vs `CREATE TABLE AS`). + +Additional major changes pulled from fabric adapter: + +* `TIMESTAMP` changing from `DATETIMEOFFSET` to `DATETIME2(6)` +* `STRING` changing from `VARCHAR(MAX)` to `VARCHAR(8000)` + + +#### Future work to be validated + +* Fabric specific items that need further over-rides (clone for example needed overriding) +* Azure Auth elements to be deferred to Fabric, but should be validated +* T-SQL Package to be updated and validated with these changes. + ### v1.4.3 Another minor release to follow up on the 1.4 releases. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f01e4bc0..28bdc43b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,6 +20,10 @@ make help [Pre-commit](https://pre-commit.com/) helps us to maintain a consistent style and code quality across the entire project. After running `make dev`, pre-commit will automatically validate your commits and fix any formatting issues whenever possible. +## Devcontainer + +A devcontainer file has been added since 1.7.2 to simpify creating the development environment. + ## Testing The functional tests require a running SQL Server instance. You can easily spin up a local instance with the following command: diff --git a/dbt/adapters/sqlserver/__init__.py b/dbt/adapters/sqlserver/__init__.py index eb6b2573..6bd51201 100644 --- a/dbt/adapters/sqlserver/__init__.py +++ b/dbt/adapters/sqlserver/__init__.py @@ -11,6 +11,7 @@ adapter=SQLServerAdapter, credentials=SQLServerCredentials, include_path=sqlserver.PACKAGE_PATH, + dependencies=["fabric"], ) __all__ = [ diff --git a/dbt/adapters/sqlserver/__version__.py b/dbt/adapters/sqlserver/__version__.py index 8fb4690e..2196826f 100644 --- a/dbt/adapters/sqlserver/__version__.py +++ b/dbt/adapters/sqlserver/__version__.py @@ -1 +1 @@ -version = "1.4.3" +version = "1.7.2" diff --git a/dbt/adapters/sqlserver/sql_server_adapter.py b/dbt/adapters/sqlserver/sql_server_adapter.py index 6a621ac2..dde85aa7 100644 --- a/dbt/adapters/sqlserver/sql_server_adapter.py +++ b/dbt/adapters/sqlserver/sql_server_adapter.py @@ -1,175 +1,55 @@ -from typing import List, Optional - -import agate -from dbt.adapters.base.relation import BaseRelation -from dbt.adapters.cache import _make_ref_key_msg -from dbt.adapters.sql import SQLAdapter -from dbt.adapters.sql.impl import CREATE_SCHEMA_MACRO_NAME -from dbt.events.functions import fire_event -from dbt.events.types import SchemaCreation +# https://github.com/microsoft/dbt-fabric/blob/main/dbt/adapters/fabric/fabric_adapter.py +from dbt.adapters.fabric import FabricAdapter from dbt.adapters.sqlserver.sql_server_column import SQLServerColumn from dbt.adapters.sqlserver.sql_server_configs import SQLServerConfigs from dbt.adapters.sqlserver.sql_server_connection_manager import SQLServerConnectionManager +# from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support + -class SQLServerAdapter(SQLAdapter): +class SQLServerAdapter(FabricAdapter): ConnectionManager = SQLServerConnectionManager Column = SQLServerColumn AdapterSpecificConfigs = SQLServerConfigs - def create_schema(self, relation: BaseRelation) -> None: - relation = relation.without_identifier() - fire_event(SchemaCreation(relation=_make_ref_key_msg(relation))) - macro_name = CREATE_SCHEMA_MACRO_NAME - kwargs = { - "relation": relation, - } - - if self.config.credentials.schema_authorization: - kwargs["schema_authorization"] = self.config.credentials.schema_authorization - macro_name = "sqlserver__create_schema_with_authorization" - - self.execute_macro(macro_name, kwargs=kwargs) - self.commit_if_has_connection() + # _capabilities: CapabilityDict = CapabilityDict( + # { + # Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full), + # Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full), + # } + # ) + + # region - these are implement in fabric but not in sqlserver + # _capabilities: CapabilityDict = CapabilityDict( + # { + # Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full), + # Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full), + # } + # ) + # CONSTRAINT_SUPPORT = { + # ConstraintType.check: ConstraintSupport.NOT_SUPPORTED, + # ConstraintType.not_null: ConstraintSupport.ENFORCED, + # ConstraintType.unique: ConstraintSupport.ENFORCED, + # ConstraintType.primary_key: ConstraintSupport.ENFORCED, + # ConstraintType.foreign_key: ConstraintSupport.ENFORCED, + # } + + # @available.parse(lambda *a, **k: []) + # def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]: + # """Get a list of the Columns with names and data types from the given sql.""" + # _, cursor = self.connections.add_select_query(sql) + + # columns = [ + # self.Column.create( + # column_name, self.connections.data_type_code_to_name(column_type_code) + # ) + # # https://peps.python.org/pep-0249/#description + # for column_name, column_type_code, *_ in cursor.description + # ] + # return columns + # endregion @classmethod def date_function(cls): return "getdate()" - - @classmethod - def convert_text_type(cls, agate_table, col_idx): - column = agate_table.columns[col_idx] - # see https://github.com/fishtown-analytics/dbt/pull/2255 - lens = [len(d.encode("utf-8")) for d in column.values_without_nulls()] - max_len = max(lens) if lens else 64 - length = max_len if max_len > 16 else 16 - return "varchar({})".format(length) - - @classmethod - def convert_datetime_type(cls, agate_table, col_idx): - return "datetime" - - @classmethod - def convert_boolean_type(cls, agate_table, col_idx): - return "bit" - - @classmethod - def convert_number_type(cls, agate_table, col_idx): - decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) - return "float" if decimals else "int" - - @classmethod - def convert_time_type(cls, agate_table, col_idx): - return "datetime" - - # Methods used in adapter tests - def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str: - # note: 'interval' is not supported for T-SQL - # for backwards compatibility, we're compelled to set some sort of - # default. A lot of searching has lead me to believe that the - # '+ interval' syntax used in postgres/redshift is relatively common - # and might even be the SQL standard's intention. - return f"DATEADD({interval},{number},{add_to})" - - def string_add_sql( - self, - add_to: str, - value: str, - location="append", - ) -> str: - """ - `+` is T-SQL's string concatenation operator - """ - if location == "append": - return f"{add_to} + '{value}'" - elif location == "prepend": - return f"'{value}' + {add_to}" - else: - raise ValueError(f'Got an unexpected location value of "{location}"') - - def get_rows_different_sql( - self, - relation_a: BaseRelation, - relation_b: BaseRelation, - column_names: Optional[List[str]] = None, - except_operator: str = "EXCEPT", - ) -> str: - """ - note: using is not supported on Synapse so COLUMNS_EQUAL_SQL is adjsuted - Generate SQL for a query that returns a single row with a two - columns: the number of rows that are different between the two - relations and the number of mismatched rows. - """ - # This method only really exists for test reasons. - names: List[str] - if column_names is None: - columns = self.get_columns_in_relation(relation_a) - names = sorted((self.quote(c.name) for c in columns)) - else: - names = sorted((self.quote(n) for n in column_names)) - columns_csv = ", ".join(names) - - sql = COLUMNS_EQUAL_SQL.format( - columns=columns_csv, - relation_a=str(relation_a), - relation_b=str(relation_b), - except_op=except_operator, - ) - - return sql - - def valid_incremental_strategies(self): - """The set of standard builtin strategies which this adapter supports out-of-the-box. - Not used to validate custom strategies defined by end users. - """ - return ["append", "delete+insert", "merge", "insert_overwrite"] - - # This is for use in the test suite - def run_sql_for_tests(self, sql, fetch, conn): - cursor = conn.handle.cursor() - try: - cursor.execute(sql) - if not fetch: - conn.handle.commit() - if fetch == "one": - return cursor.fetchone() - elif fetch == "all": - return cursor.fetchall() - else: - return - except BaseException: - if conn.handle and not getattr(conn.handle, "closed", True): - conn.handle.rollback() - raise - finally: - conn.transaction_open = False - - -COLUMNS_EQUAL_SQL = """ -with diff_count as ( - SELECT - 1 as id, - COUNT(*) as num_missing FROM ( - (SELECT {columns} FROM {relation_a} {except_op} - SELECT {columns} FROM {relation_b}) - UNION ALL - (SELECT {columns} FROM {relation_b} {except_op} - SELECT {columns} FROM {relation_a}) - ) as a -), table_a as ( - SELECT COUNT(*) as num_rows FROM {relation_a} -), table_b as ( - SELECT COUNT(*) as num_rows FROM {relation_b} -), row_count_diff as ( - select - 1 as id, - table_a.num_rows - table_b.num_rows as difference - from table_a, table_b -) -select - row_count_diff.difference as row_count_difference, - diff_count.num_missing as num_mismatched -from row_count_diff -join diff_count on row_count_diff.id = diff_count.id -""".strip() diff --git a/dbt/adapters/sqlserver/sql_server_column.py b/dbt/adapters/sqlserver/sql_server_column.py index 4f2e01a4..2ee5f7b7 100644 --- a/dbt/adapters/sqlserver/sql_server_column.py +++ b/dbt/adapters/sqlserver/sql_server_column.py @@ -1,20 +1,5 @@ -from typing import Any, ClassVar, Dict +from dbt.adapters.fabric import FabricColumn -from dbt.adapters.base import Column - -class SQLServerColumn(Column): - TYPE_LABELS: ClassVar[Dict[str, str]] = { - "STRING": "VARCHAR(MAX)", - "TIMESTAMP": "DATETIMEOFFSET", - "FLOAT": "FLOAT", - "INTEGER": "INT", - "BOOLEAN": "BIT", - } - - @classmethod - def string_type(cls, size: int) -> str: - return f"varchar({size if size > 0 else 'MAX'})" - - def literal(self, value: Any) -> str: - return "cast('{}' as {})".format(value, self.data_type) +class SQLServerColumn(FabricColumn): + ... diff --git a/dbt/adapters/sqlserver/sql_server_configs.py b/dbt/adapters/sqlserver/sql_server_configs.py index bf6d2d1e..2804179b 100644 --- a/dbt/adapters/sqlserver/sql_server_configs.py +++ b/dbt/adapters/sqlserver/sql_server_configs.py @@ -1,9 +1,8 @@ from dataclasses import dataclass -from typing import Optional -from dbt.adapters.protocol import AdapterConfig +from dbt.adapters.fabric import FabricConfigs @dataclass -class SQLServerConfigs(AdapterConfig): - auto_provision_aad_principals: Optional[bool] = False +class SQLServerConfigs(FabricConfigs): + ... diff --git a/dbt/adapters/sqlserver/sql_server_connection_manager.py b/dbt/adapters/sqlserver/sql_server_connection_manager.py index f5ac8546..cf42bff4 100644 --- a/dbt/adapters/sqlserver/sql_server_connection_manager.py +++ b/dbt/adapters/sqlserver/sql_server_connection_manager.py @@ -1,97 +1,28 @@ -import datetime as dt -import struct -import time -from contextlib import contextmanager -from itertools import chain, repeat -from typing import Any, Callable, Dict, Mapping, Optional, Tuple - -import agate -import dbt.exceptions +from typing import Callable, Mapping + import pyodbc from azure.core.credentials import AccessToken -from azure.identity import ( - AzureCliCredential, - ClientSecretCredential, - DefaultAzureCredential, - EnvironmentCredential, - ManagedIdentityCredential, +from azure.identity import ClientSecretCredential, ManagedIdentityCredential +from dbt.adapters.fabric import FabricConnectionManager +from dbt.adapters.fabric.fabric_connection_manager import ( + AZURE_AUTH_FUNCTIONS as AZURE_AUTH_FUNCTIONS_FABRIC, +) +from dbt.adapters.fabric.fabric_connection_manager import ( + AZURE_CREDENTIAL_SCOPE, + bool_to_connection_string_arg, + get_pyodbc_attrs_before, ) -from dbt.adapters.sql import SQLConnectionManager -from dbt.clients.agate_helper import empty_table -from dbt.contracts.connection import AdapterResponse, Connection, ConnectionState +from dbt.contracts.connection import Connection, ConnectionState from dbt.events import AdapterLogger from dbt.adapters.sqlserver import __version__ from dbt.adapters.sqlserver.sql_server_credentials import SQLServerCredentials -AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default" -_TOKEN: Optional[AccessToken] = None AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials], AccessToken] logger = AdapterLogger("SQLServer") -def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes: - """ - Convert bytes to a Microsoft windows byte string. - - Parameters - ---------- - value : bytes - The bytes. - - Returns - ------- - out : bytes - The Microsoft byte string. - """ - encoded_bytes = bytes(chain.from_iterable(zip(value, repeat(0)))) - return struct.pack(" bytes: - """ - Convert an access token to a Microsoft windows byte string. - - Parameters - ---------- - token : AccessToken - The token. - - Returns - ------- - out : bytes - The Microsoft byte string. - """ - value = bytes(token.token, "UTF-8") - return convert_bytes_to_mswindows_byte_string(value) - - -def get_cli_access_token(credentials: SQLServerCredentials) -> AccessToken: - """ - Get an Azure access token using the CLI credentials - - First login with: - - ```bash - az login - ``` - - Parameters - ---------- - credentials: SQLServerConnectionManager - The credentials. - - Returns - ------- - out : AccessToken - Access token. - """ - _ = credentials - token = AzureCliCredential().get_token(AZURE_CREDENTIAL_SCOPE) - return token - - def get_msi_access_token(credentials: SQLServerCredentials) -> AccessToken: """ Get an Azure access token from the system's managed identity @@ -110,42 +41,6 @@ def get_msi_access_token(credentials: SQLServerCredentials) -> AccessToken: return token -def get_auto_access_token(credentials: SQLServerCredentials) -> AccessToken: - """ - Get an Azure access token automatically through azure-identity - - Parameters - ----------- - credentials: SQLServerCredentials - Credentials. - - Returns - ------- - out : AccessToken - The access token. - """ - token = DefaultAzureCredential().get_token(AZURE_CREDENTIAL_SCOPE) - return token - - -def get_environment_access_token(credentials: SQLServerCredentials) -> AccessToken: - """ - Get an Azure access token by reading environment variables - - Parameters - ----------- - credentials: SQLServerCredentials - Credentials. - - Returns - ------- - out : AccessToken - The access token. - """ - token = EnvironmentCredential().get_token(AZURE_CREDENTIAL_SCOPE) - return token - - def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken: """ Get an Azure access token using the SP credentials. @@ -161,144 +56,23 @@ def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken: The access token. """ token = ClientSecretCredential( - str(credentials.tenant_id), str(credentials.client_id), str(credentials.client_secret) + str(credentials.tenant_id), + str(credentials.client_id), + str(credentials.client_secret), ).get_token(AZURE_CREDENTIAL_SCOPE) return token AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = { + **AZURE_AUTH_FUNCTIONS_FABRIC, "serviceprincipal": get_sp_access_token, - "cli": get_cli_access_token, "msi": get_msi_access_token, - "auto": get_auto_access_token, - "environment": get_environment_access_token, } -def get_pyodbc_attrs_before(credentials: SQLServerCredentials) -> Dict: - """ - Get the pyodbc attrs before. - - Parameters - ---------- - credentials : SQLServerCredentials - Credentials. - - Returns - ------- - out : Dict - The pyodbc attrs before. - - Source - ------ - Authentication for SQL server with an access token: - https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory?view=sql-server-ver15#authenticating-with-an-access-token - """ - global _TOKEN - attrs_before: Dict - MAX_REMAINING_TIME = 300 - - authentication = str(credentials.authentication).lower() - if authentication in AZURE_AUTH_FUNCTIONS: - time_remaining = (_TOKEN.expires_on - time.time()) if _TOKEN else MAX_REMAINING_TIME - - if _TOKEN is None or (time_remaining < MAX_REMAINING_TIME): - azure_auth_function = AZURE_AUTH_FUNCTIONS[authentication] - _TOKEN = azure_auth_function(credentials) - - token_bytes = convert_access_token_to_mswindows_byte_string(_TOKEN) - sql_copt_ss_access_token = 1256 # see source in docstring - attrs_before = {sql_copt_ss_access_token: token_bytes} - else: - attrs_before = {} - - return attrs_before - - -def bool_to_connection_string_arg(key: str, value: bool) -> str: - """ - Convert a boolean to a connection string argument. - - Parameters - ---------- - key : str - The key to use in the connection string. - value : bool - The boolean to convert. - - Returns - ------- - out : str - The connection string argument. - """ - return f'{key}={"Yes" if value else "No"}' - - -def byte_array_to_datetime(value: bytes) -> dt.datetime: - """ - Converts a DATETIMEOFFSET byte array to a timezone-aware datetime object - - Parameters - ---------- - value : buffer - A binary value conforming to SQL_SS_TIMESTAMPOFFSET_STRUCT - - Returns - ------- - out : datetime - - Source - ------ - SQL_SS_TIMESTAMPOFFSET datatype and SQL_SS_TIMESTAMPOFFSET_STRUCT layout: - https://learn.microsoft.com/sql/relational-databases/native-client-odbc-date-time/data-type-support-for-odbc-date-and-time-improvements - """ - # unpack 20 bytes of data into a tuple of 9 values - tup = struct.unpack("<6hI2h", value) - - # construct a datetime object - return dt.datetime( - year=tup[0], - month=tup[1], - day=tup[2], - hour=tup[3], - minute=tup[4], - second=tup[5], - microsecond=tup[6] // 1000, # https://bugs.python.org/issue15443 - tzinfo=dt.timezone(dt.timedelta(hours=tup[7], minutes=tup[8])), - ) - - -class SQLServerConnectionManager(SQLConnectionManager): +class SQLServerConnectionManager(FabricConnectionManager): TYPE = "sqlserver" - @contextmanager - def exception_handler(self, sql): - try: - yield - - except pyodbc.DatabaseError as e: - logger.debug("Database error: {}".format(str(e))) - - try: - # attempt to release the connection - self.release() - except pyodbc.Error: - logger.debug("Failed to release connection!") - - raise dbt.exceptions.DbtDatabaseError(str(e).strip()) from e - - except Exception as e: - logger.debug(f"Error running SQL: {sql}") - logger.debug("Rolling back transaction.") - self.release() - if isinstance(e, dbt.exceptions.DbtRuntimeError): - # during a sql query, an internal to dbt exception was raised. - # this sounds a lot like a signal handler and probably has - # useful information, so raise it without modification. - raise - - raise dbt.exceptions.DbtRuntimeError(e) - @classmethod def open(cls, connection: Connection) -> Connection: if connection.state == ConnectionState.OPEN: @@ -306,6 +80,10 @@ def open(cls, connection: Connection) -> Connection: return connection credentials = cls.get_credentials(connection.credentials) + if credentials.authentication != "sql": + return super().open(connection) + + # sql login authentication con_str = [f"DRIVER={{{credentials.driver}}}"] @@ -320,20 +98,8 @@ def open(cls, connection: Connection) -> Connection: assert credentials.authentication is not None - if "ActiveDirectory" in credentials.authentication: - con_str.append(f"Authentication={credentials.authentication}") - - if credentials.authentication == "ActiveDirectoryPassword": - con_str.append(f"UID={{{credentials.UID}}}") - con_str.append(f"PWD={{{credentials.PWD}}}") - elif credentials.authentication == "ActiveDirectoryInteractive": - con_str.append(f"UID={{{credentials.UID}}}") - - elif credentials.windows_login: - con_str.append("trusted_connection=Yes") - elif credentials.authentication == "sql": - con_str.append(f"UID={{{credentials.UID}}}") - con_str.append(f"PWD={{{credentials.PWD}}}") + con_str.append(f"UID={{{credentials.UID}}}") + con_str.append(f"PWD={{{credentials.PWD}}}") # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15 assert credentials.encrypt is not None @@ -390,95 +156,3 @@ def connect(): retry_limit=credentials.retries, retryable_exceptions=retryable_exceptions, ) - - def cancel(self, connection: Connection): - logger.debug("Cancel query") - - def add_begin_query(self): - # return self.add_query('BEGIN TRANSACTION', auto_begin=False) - pass - - def add_commit_query(self): - # return self.add_query('COMMIT TRANSACTION', auto_begin=False) - pass - - def add_query( - self, - sql: str, - auto_begin: bool = True, - bindings: Optional[Any] = None, - abridge_sql_log: bool = False, - ) -> Tuple[Connection, Any]: - connection = self.get_thread_connection() - - if auto_begin and connection.transaction_open is False: - self.begin() - - logger.debug('Using {} connection "{}".'.format(self.TYPE, connection.name)) - - with self.exception_handler(sql): - if abridge_sql_log: - logger.debug("On {}: {}....".format(connection.name, sql[0:512])) - else: - logger.debug("On {}: {}".format(connection.name, sql)) - pre = time.time() - - cursor = connection.handle.cursor() - - # pyodbc does not handle a None type binding! - if bindings is None: - cursor.execute(sql) - else: - cursor.execute(sql, bindings) - - # convert DATETIMEOFFSET binary structures to datetime ojbects - # https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 - connection.handle.add_output_converter(-155, byte_array_to_datetime) - - logger.debug( - "SQL status: {} in {:0.2f} seconds".format( - self.get_response(cursor), (time.time() - pre) - ) - ) - - return connection, cursor - - @classmethod - def get_credentials(cls, credentials: SQLServerCredentials) -> SQLServerCredentials: - return credentials - - @classmethod - def get_response(cls, cursor: Any) -> AdapterResponse: - # message = str(cursor.statusmessage) - message = "OK" - rows = cursor.rowcount - # status_message_parts = message.split() if message is not None else [] - # status_messsage_strings = [ - # part - # for part in status_message_parts - # if not part.isdigit() - # ] - # code = ' '.join(status_messsage_strings) - return AdapterResponse( - _message=message, - # code=code, - rows_affected=rows, - ) - - def execute( - self, sql: str, auto_begin: bool = True, fetch: bool = False - ) -> Tuple[AdapterResponse, agate.Table]: - _, cursor = self.add_query(sql, auto_begin) - response = self.get_response(cursor) - if fetch: - # Get the result of the first non-empty result set (if any) - while cursor.description is None: - if not cursor.nextset(): - break - table = self.get_result_from_cursor(cursor) - else: - table = empty_table() - # Step through all result sets so we process all errors - while cursor.nextset(): - pass - return response, table diff --git a/dbt/adapters/sqlserver/sql_server_credentials.py b/dbt/adapters/sqlserver/sql_server_credentials.py index 21b85b24..db9274d3 100644 --- a/dbt/adapters/sqlserver/sql_server_credentials.py +++ b/dbt/adapters/sqlserver/sql_server_credentials.py @@ -1,69 +1,17 @@ from dataclasses import dataclass from typing import Optional -from dbt.contracts.connection import Credentials +from dbt.adapters.fabric import FabricCredentials @dataclass -class SQLServerCredentials(Credentials): - driver: str - host: str - database: str - schema: str +class SQLServerCredentials(FabricCredentials): port: Optional[int] = 1433 - UID: Optional[str] = None - PWD: Optional[str] = None - windows_login: Optional[bool] = False - tenant_id: Optional[str] = None - client_id: Optional[str] = None - client_secret: Optional[str] = None authentication: Optional[str] = "sql" - encrypt: Optional[bool] = True # default value in MS ODBC Driver 18 as well - trust_cert: Optional[bool] = False # default value in MS ODBC Driver 18 as well - retries: int = 1 - schema_authorization: Optional[str] = None - login_timeout: Optional[int] = 0 - query_timeout: Optional[int] = 0 - - _ALIASES = { - "user": "UID", - "username": "UID", - "pass": "PWD", - "password": "PWD", - "server": "host", - "trusted_connection": "windows_login", - "auth": "authentication", - "app_id": "client_id", - "app_secret": "client_secret", - "TrustServerCertificate": "trust_cert", - "schema_auth": "schema_authorization", - } @property def type(self): return "sqlserver" def _connection_keys(self): - # return an iterator of keys to pretty-print in 'dbt debug' - # raise NotImplementedError - if self.windows_login is True: - self.authentication = "Windows Login" - - return ( - "server", - "database", - "schema", - "port", - "UID", - "client_id", - "authentication", - "encrypt", - "trust_cert", - "retries", - "login_timeout", - "query_timeout", - ) - - @property - def unique_field(self): - return self.host + return super()._connection_keys() + ("port",) diff --git a/dbt/include/sqlserver/macros/.gitkeep b/dbt/include/sqlserver/macros/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/dbt/include/sqlserver/macros/adapters/.gitkeep b/dbt/include/sqlserver/macros/adapters/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/dbt/include/sqlserver/macros/adapters/apply_grants.sql b/dbt/include/sqlserver/macros/adapters/apply_grants.sql deleted file mode 100644 index 253c34b3..00000000 --- a/dbt/include/sqlserver/macros/adapters/apply_grants.sql +++ /dev/null @@ -1,75 +0,0 @@ -{% macro sqlserver__get_show_grant_sql(relation) %} - select - GRANTEE as grantee, - PRIVILEGE_TYPE as privilege_type - from INFORMATION_SCHEMA.TABLE_PRIVILEGES {{ information_schema_hints() }} - where TABLE_CATALOG = '{{ relation.database }}' - and TABLE_SCHEMA = '{{ relation.schema }}' - and TABLE_NAME = '{{ relation.identifier }}' -{% endmacro %} - - -{%- macro sqlserver__get_grant_sql(relation, privilege, grantees) -%} - {%- set grantees_safe = [] -%} - {%- for grantee in grantees -%} - {%- set grantee_safe = "[" ~ grantee ~ "]" -%} - {%- do grantees_safe.append(grantee_safe) -%} - {%- endfor -%} - grant {{ privilege }} on {{ relation }} to {{ grantees_safe | join(', ') }} -{%- endmacro -%} - - -{%- macro sqlserver__get_revoke_sql(relation, privilege, grantees) -%} - {%- set grantees_safe = [] -%} - {%- for grantee in grantees -%} - {%- set grantee_safe = "[" ~ grantee ~ "]" -%} - {%- do grantees_safe.append(grantee_safe) -%} - {%- endfor -%} - revoke {{ privilege }} on {{ relation }} from {{ grantees_safe | join(', ') }} -{%- endmacro -%} - - -{% macro get_provision_sql(relation, privilege, grantees) %} - {% for grantee in grantees %} - if not exists(select name from sys.database_principals where name = '{{ grantee }}') - create user [{{ grantee }}] from external provider; - {% endfor %} -{% endmacro %} - - -{% macro sqlserver__apply_grants(relation, grant_config, should_revoke=True) %} - {#-- If grant_config is {} or None, this is a no-op --#} - {% if grant_config %} - {% if should_revoke %} - {#-- We think previous grants may have carried over --#} - {#-- Show current grants and calculate diffs --#} - {% set current_grants_table = run_query(get_show_grant_sql(relation)) %} - {% set current_grants_dict = adapter.standardize_grants_dict(current_grants_table) %} - {% set needs_granting = diff_of_two_dicts(grant_config, current_grants_dict) %} - {% set needs_revoking = diff_of_two_dicts(current_grants_dict, grant_config) %} - {% if not (needs_granting or needs_revoking) %} - {{ log('On ' ~ relation ~': All grants are in place, no revocation or granting needed.')}} - {% endif %} - {% else %} - {#-- We don't think there's any chance of previous grants having carried over. --#} - {#-- Jump straight to granting what the user has configured. --#} - {% set needs_revoking = {} %} - {% set needs_granting = grant_config %} - {% endif %} - {% if needs_granting or needs_revoking %} - {% set revoke_statement_list = get_dcl_statement_list(relation, needs_revoking, get_revoke_sql) %} - - {% if config.get('auto_provision_aad_principals', False) %} - {% set provision_statement_list = get_dcl_statement_list(relation, needs_granting, get_provision_sql) %} - {% else %} - {% set provision_statement_list = [] %} - {% endif %} - - {% set grant_statement_list = get_dcl_statement_list(relation, needs_granting, get_grant_sql) %} - {% set dcl_statement_list = revoke_statement_list + provision_statement_list + grant_statement_list %} - {% if dcl_statement_list %} - {{ call_dcl_statements(dcl_statement_list) }} - {% endif %} - {% endif %} - {% endif %} -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/columns.sql b/dbt/include/sqlserver/macros/adapters/columns.sql index 315c7a3d..bce8e5ee 100644 --- a/dbt/include/sqlserver/macros/adapters/columns.sql +++ b/dbt/include/sqlserver/macros/adapters/columns.sql @@ -1,46 +1,3 @@ -{% macro sqlserver__get_columns_in_relation(relation) -%} - {% call statement('get_columns_in_relation', fetch_result=True) %} - - with mapping as ( - select - row_number() over (partition by object_name(c.object_id) order by c.column_id) as ordinal_position, - c.name collate database_default as column_name, - t.name as data_type, - c.max_length as character_maximum_length, - c.precision as numeric_precision, - c.scale as numeric_scale - from [{{ 'tempdb' if '#' in relation.identifier else relation.database }}].sys.columns c {{ information_schema_hints() }} - inner join sys.types t {{ information_schema_hints() }} - on c.user_type_id = t.user_type_id - where c.object_id = object_id('{{ 'tempdb..' ~ relation.include(database=false, schema=false) if '#' in relation.identifier else relation }}') - ) - - select - column_name, - data_type, - character_maximum_length, - numeric_precision, - numeric_scale - from mapping - order by ordinal_position - - {% endcall %} - {% set table = load_result('get_columns_in_relation').table %} - {{ return(sql_convert_columns_in_relation(table)) }} -{% endmacro %} - - -{% macro sqlserver__get_columns_in_query(select_sql) %} - {% call statement('get_columns_in_query', fetch_result=True, auto_begin=False) -%} - select TOP 0 * from ( - {{ select_sql }} - ) as __dbt_sbq - where 0 = 1 - {% endcall %} - - {{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }} -{% endmacro %} - {% macro sqlserver__alter_column_type(relation, column_name, new_column_type) %} {%- set tmp_column = column_name + "__dbt_alter" -%} @@ -59,18 +16,3 @@ {%- endcall -%} {% endmacro %} - - -{% macro sqlserver__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} - {% call statement('add_drop_columns') -%} - {% if add_columns %} - alter {{ relation.type }} {{ relation }} - add {% for column in add_columns %}"{{ column.name }}" {{ column.data_type }}{{ ', ' if not loop.last }}{% endfor %}; - {% endif %} - - {% if remove_columns %} - alter {{ relation.type }} {{ relation }} - drop column {% for column in remove_columns %}"{{ column.name }}"{{ ',' if not loop.last }}{% endfor %}; - {% endif %} - {%- endcall -%} -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/metadata.sql b/dbt/include/sqlserver/macros/adapters/metadata.sql deleted file mode 100644 index 84a55dce..00000000 --- a/dbt/include/sqlserver/macros/adapters/metadata.sql +++ /dev/null @@ -1,166 +0,0 @@ -{% macro information_schema_hints() %} - {{ return(adapter.dispatch('information_schema_hints')()) }} -{% endmacro %} - -{% macro default__information_schema_hints() %}{% endmacro %} -{% macro sqlserver__information_schema_hints() %}with (nolock){% endmacro %} - -{% macro sqlserver__get_catalog(information_schemas, schemas) -%} - - {%- call statement('catalog', fetch_result=True) -%} - - with - principals as ( - select - name as principal_name, - principal_id as principal_id - from - sys.database_principals {{ information_schema_hints() }} - ), - - schemas as ( - select - name as schema_name, - schema_id as schema_id, - principal_id as principal_id - from - sys.schemas {{ information_schema_hints() }} - ), - - tables as ( - select - name as table_name, - schema_id as schema_id, - principal_id as principal_id, - 'BASE TABLE' as table_type - from - sys.tables {{ information_schema_hints() }} - ), - - tables_with_metadata as ( - select - table_name, - schema_name, - coalesce(tables.principal_id, schemas.principal_id) as owner_principal_id, - table_type - from - tables - join schemas on tables.schema_id = schemas.schema_id - ), - - views as ( - select - name as table_name, - schema_id as schema_id, - principal_id as principal_id, - 'VIEW' as table_type - from - sys.views {{ information_schema_hints() }} - ), - - views_with_metadata as ( - select - table_name, - schema_name, - coalesce(views.principal_id, schemas.principal_id) as owner_principal_id, - table_type - from - views - join schemas on views.schema_id = schemas.schema_id - ), - - tables_and_views as ( - select - table_name, - schema_name, - principal_name, - table_type - from - tables_with_metadata - join principals on tables_with_metadata.owner_principal_id = principals.principal_id - union all - select - table_name, - schema_name, - principal_name, - table_type - from - views_with_metadata - join principals on views_with_metadata.owner_principal_id = principals.principal_id - ), - - cols as ( - - select - table_catalog as table_database, - table_schema, - table_name, - column_name, - ordinal_position as column_index, - data_type as column_type - from INFORMATION_SCHEMA.COLUMNS {{ information_schema_hints() }} - - ) - - select - cols.table_database, - tv.schema_name as table_schema, - tv.table_name, - tv.table_type, - null as table_comment, - tv.principal_name as table_owner, - cols.column_name, - cols.column_index, - cols.column_type, - null as column_comment - from tables_and_views tv - join cols on tv.schema_name = cols.table_schema and tv.table_name = cols.table_name - order by column_index - - {%- endcall -%} - - {{ return(load_result('catalog').table) }} - -{%- endmacro %} - -{% macro sqlserver__information_schema_name(database) -%} - {%- if database -%} - [{{ database }}].INFORMATION_SCHEMA - {%- else -%} - INFORMATION_SCHEMA - {%- endif -%} -{%- endmacro %} - -{% macro sqlserver__list_schemas(database) %} - {% call statement('list_schemas', fetch_result=True, auto_begin=False) -%} - USE {{ database }}; - select name as [schema] - from sys.schemas {{ information_schema_hints() }} - {% endcall %} - {{ return(load_result('list_schemas').table) }} -{% endmacro %} - -{% macro sqlserver__check_schema_exists(information_schema, schema) -%} - {% call statement('check_schema_exists', fetch_result=True, auto_begin=False) -%} - --USE {{ database_name }} - SELECT count(*) as schema_exist FROM sys.schemas WHERE name = '{{ schema }}' - {%- endcall %} - {{ return(load_result('check_schema_exists').table) }} -{% endmacro %} - -{% macro sqlserver__list_relations_without_caching(schema_relation) %} - {% call statement('list_relations_without_caching', fetch_result=True) -%} - select - table_catalog as [database], - table_name as [name], - table_schema as [schema], - case when table_type = 'BASE TABLE' then 'table' - when table_type = 'VIEW' then 'view' - else table_type - end as table_type - - from [{{ schema_relation.database }}].INFORMATION_SCHEMA.TABLES {{ information_schema_hints() }} - where table_schema = '{{ schema_relation.schema }}' - {% endcall %} - {{ return(load_result('list_relations_without_caching').table) }} -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/persist_docs.sql b/dbt/include/sqlserver/macros/adapters/persist_docs.sql deleted file mode 100644 index 8b3e4f90..00000000 --- a/dbt/include/sqlserver/macros/adapters/persist_docs.sql +++ /dev/null @@ -1,4 +0,0 @@ -{# we don't support "persist docs" today, but we'd like to! - https://github.com/dbt-msft/dbt-sqlserver/issues/134 - - #} diff --git a/dbt/include/sqlserver/macros/adapters/relation.sql b/dbt/include/sqlserver/macros/adapters/relation.sql index 88b5e168..57defbd1 100644 --- a/dbt/include/sqlserver/macros/adapters/relation.sql +++ b/dbt/include/sqlserver/macros/adapters/relation.sql @@ -1,63 +1,5 @@ -{% macro sqlserver__make_temp_relation(base_relation, suffix) %} - {%- set temp_identifier = '#' ~ base_relation.identifier ~ suffix -%} - {%- set temp_relation = base_relation.incorporate( - path={"identifier": temp_identifier}) -%} - - {{ return(temp_relation) }} -{% endmacro %} - -{% macro sqlserver__drop_relation(relation) -%} - {% call statement('drop_relation', auto_begin=False) -%} - {{ sqlserver__drop_relation_script(relation) }} - {%- endcall %} -{% endmacro %} - -{% macro sqlserver__drop_relation_script(relation) -%} - {% call statement('find_references', fetch_result=true) %} - USE [{{ relation.database }}]; - select - sch.name as schema_name, - obj.name as view_name - from sys.sql_expression_dependencies refs - inner join sys.objects obj - on refs.referencing_id = obj.object_id - inner join sys.schemas sch - on obj.schema_id = sch.schema_id - where refs.referenced_database_name = '{{ relation.database }}' - and refs.referenced_schema_name = '{{ relation.schema }}' - and refs.referenced_entity_name = '{{ relation.identifier }}' - and refs.referencing_class = 1 - and obj.type = 'V' - {% endcall %} - {% set references = load_result('find_references')['data'] %} - {% for reference in references -%} - -- dropping referenced view {{ reference[0] }}.{{ reference[1] }} - {{ sqlserver__drop_relation_script(relation.incorporate( - type="view", - path={"schema": reference[0], "identifier": reference[1]})) }} - {% endfor %} - {% if relation.type == 'view' -%} - {% set object_id_type = 'V' %} - {% elif relation.type == 'table'%} - {% set object_id_type = 'U' %} - {%- else -%} - {{ exceptions.raise_not_implemented('Invalid relation being dropped: ' ~ relation) }} - {% endif %} - USE [{{ relation.database }}]; - if object_id ('{{ relation.include(database=False) }}','{{ object_id_type }}') is not null - begin - drop {{ relation.type }} {{ relation.include(database=False) }} - end -{% endmacro %} - -{% macro sqlserver__rename_relation(from_relation, to_relation) -%} - {% call statement('rename_relation') -%} - USE [{{ to_relation.database }}]; - EXEC sp_rename '{{ from_relation.schema }}.{{ from_relation.identifier }}', '{{ to_relation.identifier }}' - IF EXISTS( - SELECT * - FROM sys.indexes {{ information_schema_hints() }} - WHERE name='{{ from_relation.schema }}_{{ from_relation.identifier }}_cci' and object_id = OBJECT_ID('{{ from_relation.schema }}.{{ to_relation.identifier }}')) - EXEC sp_rename N'{{ from_relation.schema }}.{{ to_relation.identifier }}.{{ from_relation.schema }}_{{ from_relation.identifier }}_cci', N'{{ from_relation.schema }}_{{ to_relation.identifier }}_cci', N'INDEX' - {%- endcall %} +{% macro sqlserver__truncate_relation(relation) %} + {% call statement('truncate_relation') -%} + truncate table {{ relation }} + {%- endcall %} {% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/schema.sql b/dbt/include/sqlserver/macros/adapters/schema.sql deleted file mode 100644 index cf625f64..00000000 --- a/dbt/include/sqlserver/macros/adapters/schema.sql +++ /dev/null @@ -1,38 +0,0 @@ -{% macro sqlserver__create_schema(relation) -%} - {% call statement('create_schema') -%} - USE [{{ relation.database }}]; - IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ relation.schema }}') - BEGIN - EXEC('CREATE SCHEMA [{{ relation.schema }}]') - END - {% endcall %} -{% endmacro %} - -{% macro sqlserver__create_schema_with_authorization(relation, schema_authorization) -%} - {% call statement('create_schema') -%} - USE [{{ relation.database }}]; - IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ relation.schema }}') - BEGIN - EXEC('CREATE SCHEMA [{{ relation.schema }}] AUTHORIZATION [{{ schema_authorization }}]') - END - {% endcall %} -{% endmacro %} - -{% macro sqlserver__drop_schema(relation) -%} - {%- set relations_in_schema = list_relations_without_caching(relation) %} - - {% for row in relations_in_schema %} - {%- set schema_relation = api.Relation.create(database=relation.database, - schema=relation.schema, - identifier=row[1], - type=row[3] - ) -%} - {% do drop_relation(schema_relation) %} - {%- endfor %} - - {% call statement('drop_schema') -%} - IF EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ relation.schema }}') - BEGIN - EXEC('DROP SCHEMA {{ relation.schema }}') - END {% endcall %} -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/incremental/incremental_strategies.sql b/dbt/include/sqlserver/macros/materializations/models/incremental/incremental_strategies.sql deleted file mode 100644 index 7261e911..00000000 --- a/dbt/include/sqlserver/macros/materializations/models/incremental/incremental_strategies.sql +++ /dev/null @@ -1,9 +0,0 @@ -{% macro sqlserver__get_incremental_default_sql(arg_dict) %} - - {% if arg_dict["unique_key"] %} - {% do return(get_incremental_delete_insert_sql(arg_dict)) %} - {% else %} - {% do return(get_incremental_append_sql(arg_dict)) %} - {% endif %} - -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/incremental/merge.sql b/dbt/include/sqlserver/macros/materializations/models/incremental/merge.sql deleted file mode 100644 index 38202a9f..00000000 --- a/dbt/include/sqlserver/macros/materializations/models/incremental/merge.sql +++ /dev/null @@ -1,57 +0,0 @@ - {# global project no longer includes semi-colons in merge statements, so - default macro are invoked below w/ a semi-colons after it. - more context: - https://github.com/dbt-labs/dbt-core/pull/3510 - https://getdbt.slack.com/archives/C50NEBJGG/p1636045535056600 - #} - -{% macro sqlserver__get_merge_sql(target, source, unique_key, dest_columns, incremental_predicates=none) %} - {{ default__get_merge_sql(target, source, unique_key, dest_columns, incremental_predicates) }}; -{% endmacro %} - -{% macro sqlserver__get_insert_overwrite_merge_sql(target, source, dest_columns, predicates, include_sql_header) %} - {{ default__get_insert_overwrite_merge_sql(target, source, dest_columns, predicates, include_sql_header) }}; -{% endmacro %} - -{% macro sqlserver__get_delete_insert_merge_sql(target, source, unique_key, dest_columns, incremental_predicates=none) %} - {%- set dest_cols_csv = get_quoted_csv(dest_columns | map(attribute="name")) -%} - - {% if unique_key %} - {% if unique_key is sequence and unique_key is not string %} - delete from {{ target }} - where exists ( - select null - from {{ source }} - where - {% for key in unique_key %} - {{ source }}.{{ key }} = {{ target }}.{{ key }} - {{ "and " if not loop.last }} - {% endfor %} - - ) - {% if incremental_predicates %} - {% for predicate in incremental_predicates %} - and {{ predicate }} - {% endfor %} - {% endif %}; - {% else %} - delete from {{ target }} - where ( - {{ unique_key }}) in ( - select ({{ unique_key }}) - from {{ source }} - ) - {%- if incremental_predicates %} - {% for predicate in incremental_predicates %} - and {{ predicate }} - {% endfor %} - {%- endif -%}; - {% endif %} - {% endif %} - - insert into {{ target }} ({{ dest_cols_csv }}) - ( - select {{ dest_cols_csv }} - from {{ source }} - ) -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/table/clone.sql b/dbt/include/sqlserver/macros/materializations/models/table/clone.sql new file mode 100644 index 00000000..c1a1ecc7 --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/models/table/clone.sql @@ -0,0 +1,3 @@ +{% macro sqlserver__can_clone_table() %} + {{ return(False) }} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/table/create_table_as.sql b/dbt/include/sqlserver/macros/materializations/models/table/create_table_as.sql index 604ca7bb..d8365717 100644 --- a/dbt/include/sqlserver/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/sqlserver/macros/materializations/models/table/create_table_as.sql @@ -1,35 +1,38 @@ +{# +Fabric uses the 'CREATE TABLE XYZ AS SELECT * FROM ABC' syntax to create tables. +SQL Server doesnt support this, so we use the 'SELECT * INTO XYZ FROM ABC' syntax instead. +#} + {% macro sqlserver__create_table_as(temporary, relation, sql) -%} - {#- TODO: add contracts here when in dbt 1.5 -#} - {%- set sql_header = config.get('sql_header', none) -%} - {%- set as_columnstore = config.get('as_columnstore', default=true) -%} - {%- set temp_view_sql = sql.replace("'", "''") -%} - {%- set tmp_relation = relation.incorporate( - path={"identifier": relation.identifier.replace("#", "") ~ '_temp_view'}, - type='view') -%} - - {{- sql_header if sql_header is not none -}} - - -- drop previous temp view - {{- sqlserver__drop_relation_script(tmp_relation) }} - - -- create temp view - USE [{{ relation.database }}]; - EXEC('create view {{ tmp_relation.include(database=False) }} as - {{ temp_view_sql }} - '); - - -- select into the table and create it that way - {# TempDB schema is ignored, always goes to dbo #} - SELECT * - INTO {{ relation.include(database=False, schema=(not temporary)) }} - FROM {{ tmp_relation }} - - -- drop temp view - {{ sqlserver__drop_relation_script(tmp_relation) }} - - {%- if not temporary and as_columnstore -%} - -- add columnstore index - {{ sqlserver__create_clustered_columnstore_index(relation) }} - {%- endif -%} + + {% set tmp_relation = relation.incorporate( + path={"identifier": relation.identifier.replace("#", "") ~ '_temp_view'}, + type='view')-%} + {% do run_query(fabric__drop_relation_script(tmp_relation)) %} + {% do run_query(fabric__drop_relation_script(relation)) %} + + {% set contract_config = config.get('contract') %} + + {{ fabric__create_view_as(tmp_relation, sql) }} + {% if contract_config.enforced %} + + CREATE TABLE [{{relation.database}}].[{{relation.schema}}].[{{relation.identifier}}] + {{ fabric__table_columns_and_constraints(relation) }} + {{ get_assert_columns_equivalent(sql) }} + + {% set listColumns %} + {% for column in model['columns'] %} + {{ "["~column~"]" }}{{ ", " if not loop.last }} + {% endfor %} + {%endset%} + + INSERT INTO [{{relation.database}}].[{{relation.schema}}].[{{relation.identifier}}] + ({{listColumns}}) SELECT {{listColumns}} FROM [{{tmp_relation.database}}].[{{tmp_relation.schema}}].[{{tmp_relation.identifier}}]; + + {%- else %} + EXEC('SELECT * INTO [{{relation.database}}].[{{relation.schema}}].[{{relation.identifier}}] FROM [{{tmp_relation.database}}].[{{tmp_relation.schema}}].[{{tmp_relation.identifier}}];'); + {% endif %} + + {{ fabric__drop_relation_script(tmp_relation) }} {% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/view/create_view_as.sql b/dbt/include/sqlserver/macros/materializations/models/view/create_view_as.sql deleted file mode 100644 index fb54495f..00000000 --- a/dbt/include/sqlserver/macros/materializations/models/view/create_view_as.sql +++ /dev/null @@ -1,14 +0,0 @@ -{% macro sqlserver__create_view_as(relation, sql) -%} - {%- set sql_header = config.get('sql_header', none) -%} - {{ sql_header if sql_header is not none }} - USE [{{ relation.database }}]; - {{ sqlserver__create_view_exec(relation, sql) }} -{% endmacro %} - -{% macro sqlserver__create_view_exec(relation, sql) -%} - {#- TODO: add contracts here when in dbt 1.5 -#} - {%- set temp_view_sql = sql.replace("'", "''") -%} - execute('create view {{ relation.include(database=False) }} as - {{ temp_view_sql }} - '); -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/seeds/helpers.sql b/dbt/include/sqlserver/macros/materializations/seeds/helpers.sql deleted file mode 100644 index f1597227..00000000 --- a/dbt/include/sqlserver/macros/materializations/seeds/helpers.sql +++ /dev/null @@ -1,57 +0,0 @@ -{% macro sqlserver__get_binding_char() %} - {{ return('?') }} -{% endmacro %} - -{% macro sqlserver__get_batch_size() %} - {{ return(400) }} -{% endmacro %} - -{% macro calc_batch_size(num_columns) %} - {# - SQL Server allows for a max of 2100 parameters in a single statement. - Check if the max_batch_size fits with the number of columns, otherwise - reduce the batch size so it fits. - #} - {% set max_batch_size = get_batch_size() %} - {% set calculated_batch = (2100 / num_columns)|int %} - {% set batch_size = [max_batch_size, calculated_batch] | min %} - - {{ return(batch_size) }} -{% endmacro %} - -{% macro sqlserver__load_csv_rows(model, agate_table) %} - {% set cols_sql = get_seed_column_quoted_csv(model, agate_table.column_names) %} - {% set batch_size = calc_batch_size(agate_table.column_names|length) %} - {% set bindings = [] %} - {% set statements = [] %} - - {{ log("Inserting batches of " ~ batch_size ~ " records") }} - - {% for chunk in agate_table.rows | batch(batch_size) %} - {% set bindings = [] %} - - {% for row in chunk %} - {% do bindings.extend(row) %} - {% endfor %} - - {% set sql %} - insert into {{ this.render() }} ({{ cols_sql }}) values - {% for row in chunk -%} - ({%- for column in agate_table.column_names -%} - {{ get_binding_char() }} - {%- if not loop.last%},{%- endif %} - {%- endfor -%}) - {%- if not loop.last%},{%- endif %} - {%- endfor %} - {% endset %} - - {% do adapter.add_query(sql, bindings=bindings, abridge_sql_log=True) %} - - {% if loop.index0 == 0 %} - {% do statements.append(sql) %} - {% endif %} - {% endfor %} - - {# Return SQL so we can render it out into the compiled files #} - {{ return(statements[0]) }} -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql index 928ac5fb..a863c4e5 100644 --- a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql +++ b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql @@ -1,19 +1,52 @@ -{% macro sqlserver__post_snapshot(staging_relation) %} - -- Clean up the snapshot temp table - {% do drop_relation(staging_relation) %} -{% endmacro %} +{# +Fabric uses the 'CREATE TABLE XYZ AS SELECT * FROM ABC' syntax to create tables. +SQL Server doesnt support this, so we use the 'SELECT * INTO XYZ FROM ABC' syntax instead. +#} {% macro sqlserver__create_columns(relation, columns) %} {# default__ macro uses "add column" TSQL preferes just "add" #} - {% for column in columns %} - {% call statement() %} - alter table {{ relation }} add "{{ column.name }}" {{ column.data_type }}; - {% endcall %} - {% endfor %} -{% endmacro %} -{% macro sqlserver__get_true_sql() %} - {{ return('1=1') }} + {% set columns %} + {% for column in columns %} + , CAST(NULL AS {{column.data_type}}) AS {{column_name}} + {% endfor %} + {% endset %} + + {% set tempTableName %} + [{{relation.database}}].[{{ relation.schema }}].[{{ relation.identifier }}_{{ range(1300, 19000) | random }}] + {% endset %} + + {% set tempTable %} + SELECT * INTO {{tempTableName}} {{columns}} FROM [{{relation.database}}].[{{ relation.schema }}].[{{ relation.identifier }}] {{ information_schema_hints() }} + {% endset %} + + {% call statement('create_temp_table') -%} + {{ tempTable }} + {%- endcall %} + + {% set dropTable %} + DROP TABLE [{{relation.database}}].[{{ relation.schema }}].[{{ relation.identifier }}] + {% endset %} + + {% call statement('drop_table') -%} + {{ dropTable }} + {%- endcall %} + + {% set createTable %} + SELECT * INTO {{ relation }} FROM {{tempTableName}} {{ information_schema_hints() }} + {% endset %} + + {% call statement('create_Table') -%} + {{ createTable }} + {%- endcall %} + + {% set dropTempTable %} + DROP TABLE {{tempTableName}} + {% endset %} + + {% call statement('drop_temp_table') -%} + {{ dropTempTable }} + {%- endcall %} {% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot_merge.sql b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot_merge.sql deleted file mode 100644 index ff27ae31..00000000 --- a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot_merge.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% macro sqlserver__snapshot_merge_sql(target, source, insert_cols) %} - {{ default__snapshot_merge_sql(target, source, insert_cols) }}; -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshots/strategies.sql b/dbt/include/sqlserver/macros/materializations/snapshots/strategies.sql deleted file mode 100644 index 9420f4aa..00000000 --- a/dbt/include/sqlserver/macros/materializations/snapshots/strategies.sql +++ /dev/null @@ -1,5 +0,0 @@ -{% macro sqlserver__snapshot_hash_arguments(args) %} - CONVERT(VARCHAR(32), HashBytes('MD5', {% for arg in args %} - coalesce(cast({{ arg }} as varchar(max)), '') {% if not loop.last %} + '|' + {% endif %} - {% endfor %}), 2) -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/tests/helpers.sql b/dbt/include/sqlserver/macros/materializations/tests/helpers.sql deleted file mode 100644 index 1122df16..00000000 --- a/dbt/include/sqlserver/macros/materializations/tests/helpers.sql +++ /dev/null @@ -1,12 +0,0 @@ -{% macro sqlserver__get_test_sql(main_sql, fail_calc, warn_if, error_if, limit) -%} - select - {{ "top (" ~ limit ~ ')' if limit != none }} - {{ fail_calc }} as failures, - case when {{ fail_calc }} {{ warn_if }} - then 'true' else 'false' end as should_warn, - case when {{ fail_calc }} {{ error_if }} - then 'true' else 'false' end as should_error - from ( - {{ main_sql }} - ) dbt_internal_test -{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/tests/test.sql b/dbt/include/sqlserver/macros/materializations/tests/test.sql deleted file mode 100644 index ceebea5f..00000000 --- a/dbt/include/sqlserver/macros/materializations/tests/test.sql +++ /dev/null @@ -1,48 +0,0 @@ -{%- materialization test, adapter='sqlserver' -%} - - {% set relations = [] %} - - {% set identifier = model['alias'] %} - {% set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) %} - {% set target_relation = api.Relation.create( - identifier=identifier, schema=schema, database=database, type='table') -%} %} - - - {% if old_relation %} - {% do adapter.drop_relation(old_relation) %} - {% elif not old_relation %} - {% do adapter.create_schema(target_relation) %} - {% endif %} - - {% call statement(auto_begin=True) %} - {{ create_table_as(False, target_relation, sql) }} - {% endcall %} - - {% set main_sql %} - select * - from {{ target_relation }} - {% endset %} - - {{ adapter.commit() }} - - - {% set limit = config.get('limit') %} - {% set fail_calc = config.get('fail_calc') %} - {% set warn_if = config.get('warn_if') %} - {% set error_if = config.get('error_if') %} - - {% call statement('main', fetch_result=True) -%} - - {{ get_test_sql(main_sql, fail_calc, warn_if, error_if, limit)}} - - {%- endcall %} - - {% if should_store_failures() %} - {% do relations.append(target_relation) %} - {% elif not should_store_failures() %} - {% do adapter.drop_relation(target_relation) %} - {% endif %} - - {{ return({'relations': relations}) }} - -{%- endmaterialization -%} diff --git a/dbt/include/sqlserver/macros/utils/any_value.sql b/dbt/include/sqlserver/macros/utils/any_value.sql deleted file mode 100644 index 6dcf8ec2..00000000 --- a/dbt/include/sqlserver/macros/utils/any_value.sql +++ /dev/null @@ -1,5 +0,0 @@ -{% macro sqlserver__any_value(expression) -%} - - min({{ expression }}) - -{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/array_construct.sql b/dbt/include/sqlserver/macros/utils/array_construct.sql deleted file mode 100644 index 5088c9ac..00000000 --- a/dbt/include/sqlserver/macros/utils/array_construct.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% macro sqlserver__array_construct(inputs, data_type) -%} - JSON_ARRAY({{ inputs|join(' , ') }}) -{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/cast_bool_to_text.sql b/dbt/include/sqlserver/macros/utils/cast_bool_to_text.sql deleted file mode 100644 index 9771afbf..00000000 --- a/dbt/include/sqlserver/macros/utils/cast_bool_to_text.sql +++ /dev/null @@ -1,7 +0,0 @@ -{% macro sqlserver__cast_bool_to_text(field) %} - case {{ field }} - when 1 then 'true' - when 0 then 'false' - else null - end -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/concat.sql b/dbt/include/sqlserver/macros/utils/concat.sql deleted file mode 100644 index 705e7f56..00000000 --- a/dbt/include/sqlserver/macros/utils/concat.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% macro sqlserver__concat(fields) -%} - concat({{ fields|join(', ') }}) -{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/date_trunc.sql b/dbt/include/sqlserver/macros/utils/date_trunc.sql deleted file mode 100644 index 85b4ce32..00000000 --- a/dbt/include/sqlserver/macros/utils/date_trunc.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% macro sqlserver__date_trunc(datepart, date) %} - CAST(DATEADD({{datepart}}, DATEDIFF({{datepart}}, 0, {{date}}), 0) AS DATE) -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/dateadd.sql b/dbt/include/sqlserver/macros/utils/dateadd.sql deleted file mode 100644 index 605379e3..00000000 --- a/dbt/include/sqlserver/macros/utils/dateadd.sql +++ /dev/null @@ -1,9 +0,0 @@ -{% macro sqlserver__dateadd(datepart, interval, from_date_or_timestamp) %} - - dateadd( - {{ datepart }}, - {{ interval }}, - cast({{ from_date_or_timestamp }} as datetime) - ) - -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/hash.sql b/dbt/include/sqlserver/macros/utils/hash.sql deleted file mode 100644 index e170631e..00000000 --- a/dbt/include/sqlserver/macros/utils/hash.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% macro sqlserver__hash(field) %} - lower(convert(varchar(50), hashbytes('md5', coalesce(convert(varchar(max), {{field}}), '')), 2)) -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/last_day.sql b/dbt/include/sqlserver/macros/utils/last_day.sql deleted file mode 100644 index c523d944..00000000 --- a/dbt/include/sqlserver/macros/utils/last_day.sql +++ /dev/null @@ -1,13 +0,0 @@ -{% macro sqlserver__last_day(date, datepart) -%} - - {%- if datepart == 'quarter' -%} - CAST(DATEADD(QUARTER, DATEDIFF(QUARTER, 0, {{ date }}) + 1, -1) AS DATE) - {%- elif datepart == 'month' -%} - EOMONTH ( {{ date }}) - {%- elif datepart == 'year' -%} - CAST(DATEADD(YEAR, DATEDIFF(year, 0, {{ date }}) + 1, -1) AS DATE) - {%- else -%} - {{dbt_utils.default_last_day(date, datepart)}} - {%- endif -%} - -{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/length.sql b/dbt/include/sqlserver/macros/utils/length.sql deleted file mode 100644 index ee9431ac..00000000 --- a/dbt/include/sqlserver/macros/utils/length.sql +++ /dev/null @@ -1,5 +0,0 @@ -{% macro sqlserver__length(expression) %} - - len( {{ expression }} ) - -{%- endmacro -%} diff --git a/dbt/include/sqlserver/macros/utils/listagg.sql b/dbt/include/sqlserver/macros/utils/listagg.sql deleted file mode 100644 index 4d6ab215..00000000 --- a/dbt/include/sqlserver/macros/utils/listagg.sql +++ /dev/null @@ -1,8 +0,0 @@ -{% macro sqlserver__listagg(measure, delimiter_text, order_by_clause, limit_num) -%} - - string_agg({{ measure }}, {{ delimiter_text }}) - {%- if order_by_clause != None %} - within group ({{ order_by_clause }}) - {%- endif %} - -{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/position.sql b/dbt/include/sqlserver/macros/utils/position.sql deleted file mode 100644 index bd3f6577..00000000 --- a/dbt/include/sqlserver/macros/utils/position.sql +++ /dev/null @@ -1,8 +0,0 @@ -{% macro sqlserver__position(substring_text, string_text) %} - - CHARINDEX( - {{ substring_text }}, - {{ string_text }} - ) - -{%- endmacro -%} diff --git a/dbt/include/sqlserver/macros/utils/safe_cast.sql b/dbt/include/sqlserver/macros/utils/safe_cast.sql deleted file mode 100644 index 4ae065a7..00000000 --- a/dbt/include/sqlserver/macros/utils/safe_cast.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% macro sqlserver__safe_cast(field, type) %} - try_cast({{field}} as {{type}}) -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/split_part.sql b/dbt/include/sqlserver/macros/utils/split_part.sql index a0e2a1c3..2e94e1fa 100644 --- a/dbt/include/sqlserver/macros/utils/split_part.sql +++ b/dbt/include/sqlserver/macros/utils/split_part.sql @@ -3,10 +3,15 @@ On Azure SQL and SQL Server 2019, we can use the string_split function instead of the XML trick. But since we don't know which version of SQL Server the user is using, we'll stick with the XML trick in this adapter. However, since the XML data type is not supported in Synapse, it has to be overriden in that adapter. + + To adjust for negative part numbers, aka 'from the end of the split', we take the position and subtract from last to get the specific part. + Since the input is '-1' for the last, '-2' for second last, we add 1 to the part number to get the correct position. #} {% macro sqlserver__split_part(string_text, delimiter_text, part_number) %} - - LTRIM(CAST((''+REPLACE({{ string_text }},{{ delimiter_text }} ,'')+'') AS XML).value('(/X)[{{ part_number }}]', 'VARCHAR(128)')) - + {% if part_number >= 0 %} + LTRIM(CAST((''+REPLACE({{ string_text }},{{ delimiter_text }} ,'')+'') AS XML).value('(/X)[{{ part_number }}]', 'VARCHAR(128)')) + {% else %} + LTRIM(CAST((''+REPLACE({{ string_text }},{{ delimiter_text }} ,'')+'') AS XML).value('(/X)[position() = last(){{ part_number }}+1][1]', 'VARCHAR(128)')) + {% endif %} {% endmacro %} diff --git a/dev_requirements.txt b/dev_requirements.txt index ad5a8b5d..6bb328e1 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,10 +1,10 @@ -pytest==7.3.1 +pytest==7.4.2 twine==4.0.2 -wheel==0.40.0 -pre-commit==2.21.0;python_version<"3.8" -pre-commit==3.3.2;python_version>="3.8" +wheel==0.42 +pre-commit==3.5 pytest-dotenv==0.5.2 -dbt-tests-adapter~=1.4.5 +dbt-tests-adapter==1.7.2 +dbt-fabric==1.7.2 flaky==3.7.0 -pytest-xdist==3.3.1 +pytest-xdist==3.5.0 -e . diff --git a/devops/server.Dockerfile b/devops/server.Dockerfile index 6ab402dd..8f7c3ece 100644 --- a/devops/server.Dockerfile +++ b/devops/server.Dockerfile @@ -2,7 +2,7 @@ ARG MSSQL_VERSION="2022" FROM mcr.microsoft.com/mssql/server:${MSSQL_VERSION}-latest ENV COLLATION="SQL_Latin1_General_CP1_CI_AS" - +USER root RUN mkdir -p /opt/init_scripts WORKDIR /opt/init_scripts COPY scripts/* /opt/init_scripts/ diff --git a/setup.py b/setup.py index f45a9338..749553ed 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ package_name = "dbt-sqlserver" authors_list = ["Mikael Ene", "Anders Swanson", "Sam Debruyn", "Cor Zuurmond"] -dbt_version = "1.4" +dbt_version = "1.7" description = """A Microsoft SQL Server adapter plugin for dbt""" this_directory = os.path.abspath(os.path.dirname(__file__)) @@ -66,8 +66,9 @@ def run(self): packages=find_namespace_packages(include=["dbt", "dbt.*"]), include_package_data=True, install_requires=[ - "dbt-core~=1.4.5", - "pyodbc~=4.0.35,!=4.0.36,!=4.0.37", + "dbt-core~=1.7.2", + "dbt-fabric~=1.7.2", + "pyodbc>=4.0.35,<5.1.0", "azure-identity>=1.12.0", ], cmdclass={ diff --git a/tests/functional/adapter/test_basic.py b/tests/functional/adapter/test_basic.py index a263d590..f55e8e0b 100644 --- a/tests/functional/adapter/test_basic.py +++ b/tests/functional/adapter/test_basic.py @@ -13,6 +13,7 @@ from dbt.tests.adapter.basic.test_singular_tests_ephemeral import BaseSingularTestsEphemeral from dbt.tests.adapter.basic.test_snapshot_check_cols import BaseSnapshotCheckCols from dbt.tests.adapter.basic.test_snapshot_timestamp import BaseSnapshotTimestamp +from dbt.tests.adapter.basic.test_table_materialization import BaseTableMaterialization from dbt.tests.adapter.basic.test_validate_connection import BaseValidateConnection @@ -69,3 +70,7 @@ class TestBaseCachingSQLServer(BaseAdapterMethod): class TestValidateConnectionSQLServer(BaseValidateConnection): pass + + +class TestTableMaterializationSQLServer(BaseTableMaterialization): + ... diff --git a/tests/functional/adapter/test_changing_relation_type.py b/tests/functional/adapter/test_changing_relation_type.py index aaa43fa0..f135cf8f 100644 --- a/tests/functional/adapter/test_changing_relation_type.py +++ b/tests/functional/adapter/test_changing_relation_type.py @@ -1,5 +1,7 @@ +import pytest from dbt.tests.adapter.relations.test_changing_relation_type import BaseChangeRelationTypeValidator +@pytest.mark.skip(reason="CTAS is not supported without a underlying table definition.") class TestChangeRelationTypesSQLServer(BaseChangeRelationTypeValidator): pass diff --git a/tests/functional/adapter/test_data_types.py b/tests/functional/adapter/test_data_types.py index c091120d..28d67889 100644 --- a/tests/functional/adapter/test_data_types.py +++ b/tests/functional/adapter/test_data_types.py @@ -46,7 +46,7 @@ def seeds(self): - name: expected config: column_types: - timestamp_col: "datetimeoffset" + timestamp_col: "datetime2" """ return { diff --git a/tests/functional/adapter/test_debug.py b/tests/functional/adapter/test_debug.py index 84738fbc..7e5457cb 100644 --- a/tests/functional/adapter/test_debug.py +++ b/tests/functional/adapter/test_debug.py @@ -1,7 +1,9 @@ import os import re +import pytest import yaml +from dbt.cli.exceptions import DbtUsageException from dbt.tests.adapter.dbt_debug.test_dbt_debug import BaseDebug, BaseDebugProfileVariable from dbt.tests.util import run_dbt @@ -48,7 +50,9 @@ def test_badproject(self, project): self.check_project(splitout) def test_not_found_project(self, project): - run_dbt(["debug", "--project-dir", "nopass"], expect_pass=False) + with pytest.raises(DbtUsageException) as dbt_exeption: + run_dbt(["debug", "--project-dir", "nopass"], expect_pass=False) + dbt_exeption = dbt_exeption splitout = self.capsys.readouterr().out.split("\n") self.check_project(splitout, msg="ERROR not found") diff --git a/tests/functional/adapter/test_docs.py b/tests/functional/adapter/test_docs.py index c9039950..1130295e 100644 --- a/tests/functional/adapter/test_docs.py +++ b/tests/functional/adapter/test_docs.py @@ -29,7 +29,7 @@ def expected_catalog(self, project): role=os.getenv("DBT_TEST_USER_1"), id_type="int", text_type="varchar", - time_type="datetime", + time_type="datetime2", view_type="VIEW", table_type="BASE TABLE", model_stats=no_stats(), @@ -49,7 +49,7 @@ def expected_catalog(self, project): role=os.getenv("DBT_TEST_USER_1"), id_type="int", text_type="varchar", - time_type="datetime", + time_type="datetime2", bigint_type="int", view_type="VIEW", table_type="BASE TABLE", diff --git a/tests/functional/adapter/test_grants.py b/tests/functional/adapter/test_grants.py index aaac1c4d..1f297dbd 100644 --- a/tests/functional/adapter/test_grants.py +++ b/tests/functional/adapter/test_grants.py @@ -2,7 +2,11 @@ from dbt.tests.adapter.grants.test_invalid_grants import BaseInvalidGrants from dbt.tests.adapter.grants.test_model_grants import BaseModelGrants from dbt.tests.adapter.grants.test_seed_grants import BaseSeedGrants -from dbt.tests.adapter.grants.test_snapshot_grants import BaseSnapshotGrants +from dbt.tests.adapter.grants.test_snapshot_grants import ( + BaseSnapshotGrants, + user2_snapshot_schema_yml, +) +from dbt.tests.util import get_manifest, run_dbt, run_dbt_and_capture, write_file class TestIncrementalGrantsSQLServer(BaseIncrementalGrants): @@ -26,4 +30,34 @@ class TestSeedGrantsSQLServer(BaseSeedGrants): class TestSnapshotGrantsSQLServer(BaseSnapshotGrants): - pass + def test_snapshot_grants(self, project, get_test_users): + test_users = get_test_users + select_privilege_name = self.privilege_grantee_name_overrides()["select"] + + # run the snapshot + results = run_dbt(["snapshot"]) + assert len(results) == 1 + manifest = get_manifest(project.project_root) + snapshot_id = "snapshot.test.my_snapshot" + snapshot = manifest.nodes[snapshot_id] + expected = {select_privilege_name: [test_users[0]]} + assert snapshot.config.grants == expected + self.assert_expected_grants_match_actual(project, "my_snapshot", expected) + + # run it again, nothing should have changed + # we do expect to see the grant again. + # dbt selects into a temporary table, drops existing, selects into original table name + # this means we need to grant select again, so we will see the grant again + (results, log_output) = run_dbt_and_capture(["--debug", "snapshot"]) + assert len(results) == 1 + assert "revoke " not in log_output + assert "grant " in log_output + self.assert_expected_grants_match_actual(project, "my_snapshot", expected) + + # change the grantee, assert it updates + updated_yaml = self.interpolate_name_overrides(user2_snapshot_schema_yml) + write_file(updated_yaml, project.project_root, "snapshots", "schema.yml") + (results, log_output) = run_dbt_and_capture(["--debug", "snapshot"]) + assert len(results) == 1 + expected = {select_privilege_name: [test_users[1]]} + self.assert_expected_grants_match_actual(project, "my_snapshot", expected) diff --git a/tests/functional/adapter/test_incremental.py b/tests/functional/adapter/test_incremental.py index c4d14c2d..0c3356dc 100644 --- a/tests/functional/adapter/test_incremental.py +++ b/tests/functional/adapter/test_incremental.py @@ -117,4 +117,9 @@ class TestIncrementalPredicatesDeleteInsertSQLServer(BaseIncrementalPredicates): class TestPredicatesDeleteInsertSQLServer(BaseIncrementalPredicates): @pytest.fixture(scope="class") def project_config_update(self): - return {"models": {"+predicates": ["id != 2"], "+incremental_strategy": "delete+insert"}} + return { + "models": { + "+predicates": ["id != 2"], + "+incremental_strategy": "delete+insert", + } + } diff --git a/tests/functional/adapter/test_seed.py b/tests/functional/adapter/test_seed.py index 0eb26b66..bd173910 100644 --- a/tests/functional/adapter/test_seed.py +++ b/tests/functional/adapter/test_seed.py @@ -1,7 +1,8 @@ import os import pytest -from dbt.tests.adapter.simple_seed.seeds import seeds__expected_sql +from dbt.tests.adapter.simple_seed.fixtures import models__downstream_from_seed_actual +from dbt.tests.adapter.simple_seed.seeds import seed__actual_csv, seeds__expected_sql from dbt.tests.adapter.simple_seed.test_seed import SeedConfigBase from dbt.tests.adapter.simple_seed.test_seed import TestBasicSeedTests as BaseBasicSeedTests from dbt.tests.adapter.simple_seed.test_seed import ( @@ -23,13 +24,13 @@ seeds__disabled_in_config_csv, seeds__enabled_in_config_csv, ) -from dbt.tests.util import get_connection, run_dbt +from dbt.tests.util import check_relations_equal, check_table_does_exist, get_connection, run_dbt from dbt.adapters.sqlserver import SQLServerAdapter -fixed_setup_sql = seeds__expected_sql.replace("TIMESTAMP WITHOUT TIME ZONE", "DATETIME").replace( - "TEXT", "VARCHAR(255)" -) +fixed_setup_sql = seeds__expected_sql.replace( + "TIMESTAMP WITHOUT TIME ZONE", "DATETIME2(6)" +).replace("TEXT", "VARCHAR(255)") seeds__tricky_csv = """ seed_id,seed_id_str,a_bool,looks_like_a_bool,a_date,looks_like_a_date,relative,weekday @@ -104,7 +105,7 @@ - name: a_date tests: - column_type: - type: datetime + type: datetime2 - name: looks_like_a_date tests: - column_type: @@ -145,12 +146,35 @@ class TestBasicSeedTestsSQLServer(BaseBasicSeedTests): def setUp(self, project): project.run_sql(fixed_setup_sql) + def test_simple_seed(self, project): + """Build models and observe that run truncates a seed and re-inserts rows""" + self._build_relations_for_test(project) + self._check_relation_end_state(run_result=run_dbt(["seed"]), project=project, exists=True) + + def test_simple_seed_full_refresh_flag(self, project): + """Drop the seed_actual table and re-create. + Verifies correct behavior by the absence of the + model which depends on seed_actual.""" + self._build_relations_for_test(project) + self._check_relation_end_state( + run_result=run_dbt(["seed", "--full-refresh"]), project=project, exists=True + ) + class TestSeedConfigFullRefreshOnSQLServer(BaseSeedConfigFullRefreshOn): @pytest.fixture(scope="class", autouse=True) def setUp(self, project): project.run_sql(fixed_setup_sql) + def test_simple_seed_full_refresh_config(self, project): + """Drop the seed_actual table and re-create. + Verifies correct behavior by the absence of the + model which depends on seed_actual.""" + self._build_relations_for_test(project) + self._check_relation_end_state( + run_result=run_dbt(["seed", "--full-refresh"]), project=project, exists=True + ) + class TestSeedConfigFullRefreshOffSQLServer(BaseSeedConfigFullRefreshOff): @pytest.fixture(scope="class", autouse=True) @@ -217,5 +241,51 @@ def test_custom_batch_size(self, project, logs_dir): run_dbt(["seed"]) with open(os.path.join(logs_dir, "dbt.log"), "r") as fp: logs = "".join(fp.readlines()) + # this is changed from 350. + # Fabric goes -1 of min batch of (2100/number of columns -1) or 400 + assert "Inserting batches of 349.0 records" in logs + + +class SeedConfigBase: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "seeds": { + "quote_columns": False, + }, + } + + +class SeedTestBase(SeedConfigBase): + @pytest.fixture(scope="class", autouse=True) + def setUp(self, project): + """Create table for ensuring seeds and models used in tests build correctly""" + project.run_sql(seeds__expected_sql) + + @pytest.fixture(scope="class") + def seeds(self, test_data_dir): + return {"seed_actual.csv": seed__actual_csv} + + @pytest.fixture(scope="class") + def models(self): + return { + "models__downstream_from_seed_actual.sql": models__downstream_from_seed_actual, + } - assert "Inserting batches of 350 records" in logs + def _build_relations_for_test(self, project): + """The testing environment needs seeds and models to interact with""" + seed_result = run_dbt(["seed"]) + assert len(seed_result) == 1 + check_relations_equal(project.adapter, ["seed_expected", "seed_actual"]) + + run_result = run_dbt() + assert len(run_result) == 1 + check_relations_equal( + project.adapter, ["models__downstream_from_seed_actual", "seed_expected"] + ) + + def _check_relation_end_state(self, run_result, project, exists: bool): + assert len(run_result) == 1 + check_relations_equal(project.adapter, ["seed_actual", "seed_expected"]) + if exists: + check_table_does_exist(project.adapter, "models__downstream_from_seed_actual") diff --git a/tests/functional/adapter/test_utils.py b/tests/functional/adapter/test_utils.py index e1072473..be166f29 100644 --- a/tests/functional/adapter/test_utils.py +++ b/tests/functional/adapter/test_utils.py @@ -45,16 +45,16 @@ def macros(self): } -class TestAnyValueSQLServer(BaseFixedMacro, BaseAnyValue): +class TestAnyValueSQLServer(BaseAnyValue): pass @pytest.mark.skip("bool_or not supported in this adapter") -class TestBoolOrSQLServer(BaseFixedMacro, BaseBoolOr): +class TestBoolOrSQLServer(BaseBoolOr): pass -class TestCastBoolToTextSQLServer(BaseFixedMacro, BaseCastBoolToText): +class TestCastBoolToTextSQLServer(BaseCastBoolToText): @pytest.fixture(scope="class") def models(self): models__test_cast_bool_to_text_sql = """ @@ -82,7 +82,7 @@ def models(self): } -class TestConcatSQLServer(BaseFixedMacro, BaseConcat): +class TestConcatSQLServer(BaseConcat): @pytest.fixture(scope="class") def seeds(self): return { @@ -94,7 +94,7 @@ def seeds(self): } -class TestDateTruncSQLServer(BaseFixedMacro, BaseDateTrunc): +class TestDateTruncSQLServer(BaseDateTrunc): pass @@ -105,41 +105,41 @@ class TestDateTruncSQLServer(BaseFixedMacro, BaseDateTrunc): ,d41d8cd98f00b204e9800998ecf8427e""" -class TestHashSQLServer(BaseFixedMacro, BaseHash): +class TestHashSQLServer(BaseHash): @pytest.fixture(scope="class") def seeds(self): return {"data_hash.csv": seeds__data_hash_csv} -class TestStringLiteralSQLServer(BaseFixedMacro, BaseStringLiteral): +class TestStringLiteralSQLServer(BaseStringLiteral): pass -class TestSplitPartSQLServer(BaseFixedMacro, BaseSplitPart): +class TestSplitPartSQLServer(BaseSplitPart): pass -class TestDateDiffSQLServer(BaseFixedMacro, BaseDateDiff): +class TestDateDiffSQLServer(BaseDateDiff): pass -class TestEscapeSingleQuotesSQLServer(BaseFixedMacro, BaseEscapeSingleQuotesQuote): +class TestEscapeSingleQuotesSQLServer(BaseEscapeSingleQuotesQuote): pass -class TestIntersectSQLServer(BaseFixedMacro, BaseIntersect): +class TestIntersectSQLServer(BaseIntersect): pass -class TestLastDaySQLServer(BaseFixedMacro, BaseLastDay): +class TestLastDaySQLServer(BaseLastDay): pass -class TestLengthSQLServer(BaseFixedMacro, BaseLength): +class TestLengthSQLServer(BaseLength): pass -class TestListaggSQLServer(BaseFixedMacro, BaseListagg): +class TestListaggSQLServer(BaseListagg): # Only supported in SQL Server 2017 and later or cloud versions # DISTINCT not supported # limit not supported @@ -221,15 +221,15 @@ def models(self): } -class TestRightSQLServer(BaseFixedMacro, BaseRight): +class TestRightSQLServer(BaseRight): pass -class TestSafeCastSQLServer(BaseFixedMacro, BaseSafeCast): +class TestSafeCastSQLServer(BaseSafeCast): pass -class TestDateAddSQLServer(BaseFixedMacro, BaseDateAdd): +class TestDateAddSQLServer(BaseDateAdd): @pytest.fixture(scope="class") def project_config_update(self): return { @@ -247,15 +247,15 @@ def project_config_update(self): } -class TestExceptSQLServer(BaseFixedMacro, BaseExcept): +class TestExceptSQLServer(BaseExcept): pass -class TestPositionSQLServer(BaseFixedMacro, BasePosition): +class TestPositionSQLServer(BasePosition): pass -class TestReplaceSQLServer(BaseFixedMacro, BaseReplace): +class TestReplaceSQLServer(BaseReplace): pass @@ -264,15 +264,15 @@ class TestCurrentTimestampSQLServer(BaseCurrentTimestampNaive): @pytest.mark.skip(reason="arrays not supported") -class TestArrayAppendSQLServer(BaseFixedMacro, BaseArrayAppend): +class TestArrayAppendSQLServer(BaseArrayAppend): pass @pytest.mark.skip(reason="arrays not supporteTd") -class TestArrayConcatSQLServer(BaseFixedMacro, BaseArrayConcat): +class TestArrayConcatSQLServer(BaseArrayConcat): pass @pytest.mark.skip(reason="arrays not supported") -class TestArrayConstructSQLServer(BaseFixedMacro, BaseArrayConstruct): +class TestArrayConstructSQLServer(BaseArrayConstruct): pass diff --git a/tests/unit/adapters/sqlserver/test_sql_server_connection_manager.py b/tests/unit/adapters/sqlserver/test_sql_server_connection_manager.py index b321da0c..010ecd1a 100644 --- a/tests/unit/adapters/sqlserver/test_sql_server_connection_manager.py +++ b/tests/unit/adapters/sqlserver/test_sql_server_connection_manager.py @@ -1,13 +1,8 @@ -import datetime as dt -import json -from unittest import mock - import pytest from azure.identity import AzureCliCredential -from dbt.adapters.sqlserver.sql_server_connection_manager import ( +from dbt.adapters.sqlserver.sql_server_connection_manager import ( # byte_array_to_datetime, bool_to_connection_string_arg, - byte_array_to_datetime, get_pyodbc_attrs_before, ) from dbt.adapters.sqlserver.sql_server_credentials import SQLServerCredentials @@ -28,22 +23,22 @@ def credentials() -> SQLServerCredentials: return credentials -@pytest.fixture -def mock_cli_access_token() -> str: - access_token = "access token" - expected_expires_on = 1602015811 - successful_output = json.dumps( - { - "expiresOn": dt.datetime.fromtimestamp(expected_expires_on).strftime( - "%Y-%m-%d %H:%M:%S.%f" - ), - "accessToken": access_token, - "subscription": "some-guid", - "tenant": "some-guid", - "tokenType": "Bearer", - } - ) - return successful_output +# @pytest.fixture +# def mock_cli_access_token() -> str: +# access_token = "access token" +# expected_expires_on = 1602015811 +# successful_output = json.dumps( +# { +# "expiresOn": dt.datetime.fromtimestamp(expected_expires_on).strftime( +# "%Y-%m-%d %H:%M:%S.%f" +# ), +# "accessToken": access_token, +# "subscription": "some-guid", +# "tenant": "some-guid", +# "tokenType": "Bearer", +# } +# ) +# return successful_output def test_get_pyodbc_attrs_before_empty_dict_when_service_principal( @@ -56,77 +51,78 @@ def test_get_pyodbc_attrs_before_empty_dict_when_service_principal( assert attrs_before == {} -@pytest.mark.parametrize("authentication", ["CLI", "cli", "cLi"]) -def test_get_pyodbc_attrs_before_contains_access_token_key_for_cli_authentication( - credentials: SQLServerCredentials, - authentication: str, - mock_cli_access_token: str, -) -> None: - """ - When the cli authentication is used, the attrs before should contain an - access token key. - """ - credentials.authentication = authentication - with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=mock_cli_access_token)): - attrs_before = get_pyodbc_attrs_before(credentials) - assert 1256 in attrs_before.keys() +# @pytest.mark.parametrize("authentication", ["CLI", "cli", "cLi"]) +# def test_get_pyodbc_attrs_before_contains_access_token_key_for_cli_authentication( +# credentials: SQLServerCredentials, +# authentication: str, +# mock_cli_access_token: str, +# ) -> None: +# """ +# When the cli authentication is used, the attrs before should contain an +# access token key. +# """ +# credentials.authentication = authentication +# with mock.patch(CHECK_OUTPUT, mock.Mock(return_value=mock_cli_access_token)): +# attrs_before = get_pyodbc_attrs_before(credentials) +# assert 1256 in attrs_before.keys() @pytest.mark.parametrize( - "key, value, expected", [("somekey", False, "somekey=No"), ("somekey", True, "somekey=Yes")] + "key, value, expected", + [("somekey", False, "somekey=No"), ("somekey", True, "somekey=Yes")], ) def test_bool_to_connection_string_arg(key: str, value: bool, expected: str) -> None: assert bool_to_connection_string_arg(key, value) == expected -@pytest.mark.parametrize( - "value, expected_datetime, expected_str", - [ - ( - bytes( - [ - 0xE6, - 0x07, # 2022 year unsigned short - 0x0C, - 0x00, # 12 month unsigned short - 0x11, - 0x00, # 17 day unsigned short - 0x11, - 0x00, # 17 hour unsigned short - 0x34, - 0x00, # 52 minute unsigned short - 0x12, - 0x00, # 18 second unsigned short - 0xBC, - 0xCC, - 0x5B, - 0x07, # 123456700 10⁻⁷ second unsigned long - 0xFE, - 0xFF, # -2 offset hour signed short - 0xE2, - 0xFF, # -30 offset minute signed short - ] - ), - dt.datetime( - year=2022, - month=12, - day=17, - hour=17, - minute=52, - second=18, - microsecond=123456700 // 1000, # 10⁻⁶ second - tzinfo=dt.timezone(dt.timedelta(hours=-2, minutes=-30)), - ), - "2022-12-17 17:52:18.123456-02:30", - ) - ], -) -def test_byte_array_to_datetime( - value: bytes, expected_datetime: dt.datetime, expected_str: str -) -> None: - """ - Assert SQL_SS_TIMESTAMPOFFSET_STRUCT bytes are converted to datetime and str - https://learn.microsoft.com/sql/relational-databases/native-client-odbc-date-time/data-type-support-for-odbc-date-and-time-improvements#sql_ss_timestampoffset_struct - """ - assert byte_array_to_datetime(value) == expected_datetime - assert str(byte_array_to_datetime(value)) == expected_str +# @pytest.mark.parametrize( +# "value, expected_datetime, expected_str", +# [ +# ( +# bytes( +# [ +# 0xE6, +# 0x07, # 2022 year unsigned short +# 0x0C, +# 0x00, # 12 month unsigned short +# 0x11, +# 0x00, # 17 day unsigned short +# 0x11, +# 0x00, # 17 hour unsigned short +# 0x34, +# 0x00, # 52 minute unsigned short +# 0x12, +# 0x00, # 18 second unsigned short +# 0xBC, +# 0xCC, +# 0x5B, +# 0x07, # 123456700 10⁻⁷ second unsigned long +# 0xFE, +# 0xFF, # -2 offset hour signed short +# 0xE2, +# 0xFF, # -30 offset minute signed short +# ] +# ), +# dt.datetime( +# year=2022, +# month=12, +# day=17, +# hour=17, +# minute=52, +# second=18, +# microsecond=123456700 // 1000, # 10⁻⁶ second +# tzinfo=dt.timezone(dt.timedelta(hours=-2, minutes=-30)), +# ), +# "2022-12-17 17:52:18.123456-02:30", +# ) +# ], +# ) +# def test_byte_array_to_datetime( +# value: bytes, expected_datetime: dt.datetime, expected_str: str +# ) -> None: +# """ +# Assert SQL_SS_TIMESTAMPOFFSET_STRUCT bytes are converted to datetime and str +# https://learn.microsoft.com/sql/relational-databases/native-client-odbc-date-time/data-type-support-for-odbc-date-and-time-improvements#sql_ss_timestampoffset_struct +# """ +# assert byte_array_to_datetime(value) == expected_datetime +# assert str(byte_array_to_datetime(value)) == expected_str