diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index e4cd7128..f92a27c3 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -305,6 +305,51 @@ jobs:
           name: buenavista_results_${{ matrix.python-version }}-${{ steps.date.outputs.date }}.csv
           path: buenavista_results.csv
 
+  unity:
+    name: Unity Catalog functional test / python ${{ matrix.python-version }}
+
+    runs-on: ubuntu-latest
+
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [ '3.9' ]
+
+    env:
+      TOXENV: "unity"
+      PYTEST_ADDOPTS: "-v --color=yes --csv unity_results.csv"
+
+    steps:
+      - name: Check out the repository
+        uses: actions/checkout@v4
+        with:
+          persist-credentials: false
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install python dependencies
+        run: |
+          python -m pip install tox
+          python -m pip --version
+          tox --version
+
+      - name: Mock Unity Catalog server and run tox
+        run: ./scripts/test-unity.sh
+
+      - name: Get current date
+        if: always()
+        id: date
+        run: echo "date=$(date +'%Y-%m-%dT%H_%M_%S')" >> $GITHUB_OUTPUT #no colons allowed for artifacts
+
+      - uses: actions/upload-artifact@v3
+        if: always()
+        with:
+          name: unity_results_${{ matrix.python-version }}-${{ steps.date.outputs.date }}.csv
+          path: unity_results.csv
+
   fsspec:
     name: fsspec test / python ${{ matrix.python-version }}
 
diff --git a/.gitignore b/.gitignore
index 75c102c1..cd6bad96 100644
--- a/.gitignore
+++ b/.gitignore
@@ -80,3 +80,4 @@ target/
 .idea/
 .vscode/
 .env
+.venv/*
diff --git a/.mock-uc-server-stats.yml b/.mock-uc-server-stats.yml
new file mode 100644
index 00000000..1aed4174
--- /dev/null
+++ b/.mock-uc-server-stats.yml
@@ -0,0 +1,2 @@
+configured_endpoints: 25
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/eventual%2Funitycatalog-afb7536b3c70b699dd3090dc47209959c1591beae2945ee5fe14e1b53139fe83.yml
diff --git a/dbt/adapters/duckdb/credentials.py b/dbt/adapters/duckdb/credentials.py
index 56d3faf3..82f3c5dc 100644
--- a/dbt/adapters/duckdb/credentials.py
+++ b/dbt/adapters/duckdb/credentials.py
@@ -5,14 +5,12 @@
 from typing import Dict
 from typing import List
 from typing import Optional
-from typing import Tuple
 from urllib.parse import urlparse
 
 from dbt_common.dataclass_schema import dbtClassMixin
 from dbt_common.exceptions import DbtRuntimeError
 
 from dbt.adapters.contracts.connection import Credentials
-from dbt.adapters.duckdb.secrets import DEFAULT_SECRET_PREFIX
 from dbt.adapters.duckdb.secrets import Secret
 
 
@@ -77,6 +75,12 @@ class Retries(dbtClassMixin):
     retryable_exceptions: List[str] = field(default_factory=lambda: ["IOException"])
 
 
+@dataclass
+class Extension(dbtClassMixin):
+    name: str
+    repository: Optional[str] = None
+
+
 @dataclass
 class DuckDBCredentials(Credentials):
     database: str = "main"
@@ -88,7 +92,7 @@ class DuckDBCredentials(Credentials):
     config_options: Optional[Dict[str, Any]] = None
 
     # any DuckDB extensions we want to install and load (httpfs, parquet, etc.)
- extensions: Optional[Tuple[str, ...]] = None + extensions: Optional[List[Extension]] = None # any additional pragmas we want to configure on our DuckDB connections; # a list of the built-in pragmas can be found here: @@ -177,11 +181,12 @@ def __post_init__(self): if self.secrets: self._secrets = [ Secret.create( - secret_type=secret.pop("type"), - name=secret.pop("name", f"{DEFAULT_SECRET_PREFIX}{num + 1}"), + secret_type=secret_type, + name=secret.pop("name", f"__default_{secret_type}"), **secret, ) - for num, secret in enumerate(self.secrets) + for secret in self.secrets + if (secret_type := secret.get("type")) ] def secrets_sql(self) -> List[str]: diff --git a/dbt/adapters/duckdb/environments/__init__.py b/dbt/adapters/duckdb/environments/__init__.py index cab331e6..f6dfceda 100644 --- a/dbt/adapters/duckdb/environments/__init__.py +++ b/dbt/adapters/duckdb/environments/__init__.py @@ -167,8 +167,19 @@ def initialize_db( # install any extensions on the connection if creds.extensions is not None: for extension in creds.extensions: - conn.install_extension(extension) - conn.load_extension(extension) + if extension.repository: + conn.execute(f"SET custom_extension_repository = '{extension.repository}'") + else: + conn.execute( + "SET custom_extension_repository = 'http://extensions.duckdb.org'" + ) + conn.install_extension(extension.name) + conn.load_extension(extension.name) + + # install any secrets on the connection + if creds.secrets: + for sql in creds.secrets_sql(): + conn.execute(sql) # Attach any fsspec filesystems on the database if creds.filesystems: @@ -207,9 +218,6 @@ def initialize_cursor( # to the correct type cursor.execute(f"SET {key} = '{value}'") - for sql in creds.secrets_sql(): - cursor.execute(sql) - # update cursor if something is lost in the copy # of the parent connection if plugins: diff --git a/dbt/adapters/duckdb/environments/local.py b/dbt/adapters/duckdb/environments/local.py index 1897175c..8925e24d 100644 --- a/dbt/adapters/duckdb/environments/local.py +++ b/dbt/adapters/duckdb/environments/local.py @@ -1,10 +1,13 @@ import threading +import pyarrow from dbt_common.exceptions import DbtRuntimeError +from duckdb import CatalogException from . import Environment from .. import credentials from .. 
import utils +from ..utils import get_retry_decorator from dbt.adapters.contracts.connection import AdapterResponse from dbt.adapters.contracts.connection import Connection @@ -29,6 +32,7 @@ def execute(self, sql, bindings=None): class DuckDBConnectionWrapper: def __init__(self, cursor, env): + self._conn = env.conn self._cursor = DuckDBCursorWrapper(cursor) self._env = env @@ -98,7 +102,8 @@ def load_source(self, plugin_name: str, source_config: utils.SourceConfig): handle = self.handle() cursor = handle.cursor() - if source_config.schema: + # Schema creation is currently not supported by the uc_catalog duckdb extension + if source_config.schema and plugin_name != "unity": cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {source_config.schema}") save_mode = source_config.get("save_mode", "overwrite") @@ -133,13 +138,49 @@ def load_source(self, plugin_name: str, source_config: utils.SourceConfig): # save to df instance to register on each cursor creation self._REGISTERED_DF[df_name] = df - cursor.execute( - f"CREATE OR REPLACE {materialization} {source_table_name} AS SELECT * FROM {df_name}" - ) + # CREATE OR REPLACE table creation is currently not supported by the uc_catalog duckdb extension + if plugin_name != "unity": + cursor.execute( + f"CREATE OR REPLACE {materialization} {source_table_name} AS SELECT * FROM {df_name}" + ) cursor.close() handle.close() + def get_arrow_dataframe( + self, compiled_code: str, retries: int, wait_time: float + ) -> pyarrow.lib.Table: + """Get the arrow dataframe from the compiled code. + + :param compiled_code: Compiled code + :param retries: Number of retries + :param wait_time: Wait time between retries + + :returns: Arrow dataframe + """ + + @get_retry_decorator(retries, wait_time, CatalogException) + def execute_query(): + try: + # Get the handle and cursor + handle = self.handle() + cursor = handle.cursor() + + # Execute the compiled code + df = cursor.sql(compiled_code).arrow() + + return df + except CatalogException as e: + # Reset the connection to refresh the catalog + self.conn = None + + # Raise the exception to retry the operation + raise CatalogException( + f"{str(e)}: failed to execute compiled code {compiled_code}" + ) + + return execute_query() + def store_relation(self, plugin_name: str, target_config: utils.TargetConfig) -> None: if plugin_name not in self._plugins: if plugin_name.startswith("glue|"): @@ -155,7 +196,25 @@ def store_relation(self, plugin_name: str, target_config: utils.TargetConfig) -> + ",".join(self._plugins.keys()) ) plugin = self._plugins[plugin_name] - plugin.store(target_config) + + handle = self.handle() + cursor = handle.cursor() + + # Get the number of retries and the wait time for a dbt model + retries = int(target_config.config.get("retries", 20)) + wait_time = float(target_config.config.get("wait_time", 0.05)) + + # Get the arrow dataframe + df = self.get_arrow_dataframe( + compiled_code=target_config.config.model.compiled_code, + retries=retries, + wait_time=wait_time, + ) + + plugin.store(target_config, df) + + cursor.close() + handle.close() def close(self): if self.conn: diff --git a/dbt/adapters/duckdb/plugins/__init__.py b/dbt/adapters/duckdb/plugins/__init__.py index 9e610d3a..45f8b3f5 100644 --- a/dbt/adapters/duckdb/plugins/__init__.py +++ b/dbt/adapters/duckdb/plugins/__init__.py @@ -90,6 +90,7 @@ def __init__( """ self.name = name self.creds = credentials + self.plugin_config = plugin_config self.initialize(plugin_config) def initialize(self, plugin_config: Dict[str, Any]): @@ -134,7 +135,7 @@ def 
load(self, source_config: SourceConfig): """ raise NotImplementedError(f"load method not implemented for {self.name}") - def store(self, target_config: TargetConfig): + def store(self, target_config: TargetConfig, df=None): raise NotImplementedError(f"store method not implemented for {self.name}") def configure_cursor(self, cursor): diff --git a/dbt/adapters/duckdb/plugins/delta.py b/dbt/adapters/duckdb/plugins/delta.py index c6b0aa2a..2c3fd901 100644 --- a/dbt/adapters/duckdb/plugins/delta.py +++ b/dbt/adapters/duckdb/plugins/delta.py @@ -1,10 +1,155 @@ +from __future__ import annotations + +from enum import Enum from typing import Any from typing import Dict +import pyarrow as pa +import pyarrow.compute as pc from deltalake import DeltaTable +from deltalake import write_deltalake +from deltalake._internal import TableNotFoundError from . import BasePlugin from ..utils import SourceConfig +from ..utils import TargetConfig + + +class WriteModes(str, Enum): + """Enum class for the write modes supported by the plugin.""" + + OVERWRITE_PARTITION = "overwrite_partition" + MERGE = "merge" + OVERWRITE = "overwrite" + + +class PartitionKeyMissingError(Exception): + """Exception raised when the partition key is missing from the target configuration.""" + + pass + + +class UniqueKeyMissingError(Exception): + """Exception raised when the unique key is missing from the target configuration.""" + + pass + + +class DeltaTablePathMissingError(Exception): + """Exception raised when the delta table path is missing from the source configuration.""" + + pass + + +def delta_table_exists(table_path: str, storage_options: dict) -> bool: + """Check if a delta table exists at the given path.""" + try: + DeltaTable(table_path, storage_options=storage_options) + except TableNotFoundError: + return False + return True + + +def create_insert_partition( + table_path: str, data: pa.lib.Table, partitions: list, storage_options: dict +) -> None: + """Create a new delta table with partitions or overwrite an existing one.""" + + if delta_table_exists(table_path, storage_options): + partition_expr = [ + (partition_name, "=", partition_value) + for (partition_name, partition_value) in partitions + ] + print( + f"Overwriting delta table under: {table_path} \nwith partition expr: {partition_expr}" + ) + write_deltalake(table_path, data, partition_filters=partition_expr, mode="overwrite") + else: + partitions = [partition_name for (partition_name, partition_value) in partitions] + print(f"Creating delta table under: {table_path} \nwith partitions: {partitions}") + write_deltalake(table_path, data, partition_by=partitions) + + +def delta_write( + mode: WriteModes, + table_path: str, + df: pa.lib.Table, + storage_options: dict, + partition_key: str | list[str], + unique_key: str | list[str], +): + if storage_options is None: + storage_options = {} + + if mode == WriteModes.OVERWRITE_PARTITION: + if not partition_key: + raise PartitionKeyMissingError( + "'partition_key' has to be defined for mode 'overwrite_partition'!" 
+ ) + + if isinstance(partition_key, str): + partition_key = [partition_key] + + partition_dict = [] + # TODO: Add support overwriting multiple partitions + for each_key in partition_key: + unique_key_array = pc.unique(df[each_key]) + + if len(unique_key_array) == 1: + partition_dict.append((each_key, str(unique_key_array[0]))) + else: + raise Exception( + f"'{each_key}' column has not one unique value, values are: {str(unique_key_array)}" + ) + create_insert_partition(table_path, df, partition_dict, storage_options) + elif mode == WriteModes.MERGE: + if not unique_key: + raise UniqueKeyMissingError("'unique_key' has to be defined when mode 'merge'!") + if isinstance(unique_key, str): + unique_key = [unique_key] + + predicate_stm = " and ".join( + [ + f'source."{each_unique_key}" = target."{each_unique_key}"' + for each_unique_key in unique_key + ] + ) + + if not delta_table_exists(table_path, storage_options): + write_deltalake(table_or_uri=table_path, data=df, storage_options=storage_options) + + target_dt = DeltaTable(table_path, storage_options=storage_options) + # TODO there is a problem if the column name is uppercase + target_dt.merge( + source=df, + predicate=predicate_stm, + source_alias="source", + target_alias="target", + ).when_not_matched_insert_all().execute() + elif mode == WriteModes.OVERWRITE: + write_deltalake( + table_or_uri=table_path, + data=df, + mode="overwrite", + storage_options=storage_options, + ) + else: + raise NotImplementedError(f"Mode {mode} not supported!") + + # TODO: Add support for OPTIMIZE + + +def delta_load(table_path: str, storage_options: dict, as_of_version: int, as_of_datetime: str): + """Load a delta table as a pyarrow dataset.""" + dt = DeltaTable(table_path, storage_options=storage_options) + + if as_of_version: + dt.load_version(as_of_version) + + if as_of_datetime: + dt.load_with_datetime(as_of_datetime) + + return dt.to_pyarrow_dataset() class Plugin(BasePlugin): @@ -16,33 +161,48 @@ def configure_cursor(self, cursor): def load(self, source_config: SourceConfig): if "delta_table_path" not in source_config: - raise Exception("'delta_table_path' is a required argument for the delta table!") + raise DeltaTablePathMissingError( + "'delta_table_path' is a required argument for the delta table!" + ) + # Get required variables from the source configuration table_path = source_config["delta_table_path"] - storage_options = source_config.get("storage_options", None) - - if storage_options: - dt = DeltaTable(table_path, storage_options=storage_options) - else: - dt = DeltaTable(table_path) - # delta attributes + # Get optional variables from the source configuration as_of_version = source_config.get("as_of_version", None) as_of_datetime = source_config.get("as_of_datetime", None) + storage_options = source_config.get("storage_options", {}) - if as_of_version: - dt.load_version(as_of_version) - - if as_of_datetime: - dt.load_with_datetime(as_of_datetime) - - df = dt.to_pyarrow_dataset() + df = delta_load( + table_path=table_path, + storage_options=storage_options, + as_of_version=as_of_version, + as_of_datetime=as_of_datetime, + ) return df def default_materialization(self): return "view" - -# Future -# TODO add databricks catalog + def store(self, target_config: TargetConfig, df: pa.lib.Table = None): + # Assert that the target_config has a location and relation identifier + assert target_config.location is not None, "Location is required for storing data!" 
+ + # Get required variables from the target configuration + table_path = target_config.location.path + + # Get optional variables from the target configuration + mode = target_config.config.get("mode", "overwrite") + storage_options = target_config.config.get("storage_options", {}) + partition_key = target_config.config.get("partition_key", None) + unique_key = target_config.config.get("unique_key", None) + + delta_write( + mode=mode, + table_path=table_path, + df=df, + storage_options=storage_options, + partition_key=partition_key, + unique_key=unique_key, + ) diff --git a/dbt/adapters/duckdb/plugins/excel.py b/dbt/adapters/duckdb/plugins/excel.py index 2f5ac0f8..be10c0d8 100644 --- a/dbt/adapters/duckdb/plugins/excel.py +++ b/dbt/adapters/duckdb/plugins/excel.py @@ -42,7 +42,7 @@ def load(self, source_config: SourceConfig): sheet_name = source_config.get("sheet_name", 0) return pd.read_excel(source_location, sheet_name=sheet_name) - def store(self, target_config: TargetConfig): + def store(self, target_config: TargetConfig, df=None): plugin_output_config = self._config["output"] # Create the writer on the first instance of the call to store. diff --git a/dbt/adapters/duckdb/plugins/glue.py b/dbt/adapters/duckdb/plugins/glue.py index 78fb7581..6422deee 100644 --- a/dbt/adapters/duckdb/plugins/glue.py +++ b/dbt/adapters/duckdb/plugins/glue.py @@ -349,7 +349,7 @@ def initialize(self, config: Dict[str, Any]): self.database = config.get("glue_database", "default") self.delimiter = config.get("delimiter", ",") - def store(self, target_config: TargetConfig): + def store(self, target_config: TargetConfig, df=None): assert target_config.location is not None assert target_config.relation.identifier is not None table: str = target_config.relation.identifier diff --git a/dbt/adapters/duckdb/plugins/sqlalchemy.py b/dbt/adapters/duckdb/plugins/sqlalchemy.py index b66c8de8..58ceaf91 100644 --- a/dbt/adapters/duckdb/plugins/sqlalchemy.py +++ b/dbt/adapters/duckdb/plugins/sqlalchemy.py @@ -30,7 +30,7 @@ def load(self, source_config: SourceConfig) -> pd.DataFrame: with self.engine.connect() as conn: return pd.read_sql_table(table, con=conn) - def store(self, target_config: TargetConfig): + def store(self, target_config: TargetConfig, df=None): # first, load the data frame from the external location df = pd_utils.target_to_df(target_config) table_name = target_config.relation.identifier diff --git a/dbt/adapters/duckdb/plugins/unity.py b/dbt/adapters/duckdb/plugins/unity.py new file mode 100644 index 00000000..43a3563e --- /dev/null +++ b/dbt/adapters/duckdb/plugins/unity.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import json +import sys +from enum import Enum +from typing import Any +from typing import Dict +from typing import Literal + +import pyarrow as pa +from unitycatalog import Unitycatalog +from unitycatalog.types.table_create_params import Column + +from . 
import BasePlugin +from ..utils import find_secrets_by_type +from ..utils import SourceConfig +from ..utils import TargetConfig + + +class StorageFormat(str, Enum): + """Enum class for the storage formats supported by the plugin.""" + + DELTA = "DELTA" + + +def uc_schema_exists(client: Unitycatalog, schema_name: str, catalog_name: str = "unity") -> bool: + """Check if a UC schema exists in the catalog.""" + schema_list_request = client.schemas.list(catalog_name=catalog_name) + + if not schema_list_request.schemas: + return False + + return schema_name in [schema.name for schema in schema_list_request.schemas] + + +def uc_table_exists( + client: Unitycatalog, table_name: str, schema_name: str, catalog_name: str = "unity" +) -> bool: + """Check if a UC table exists in the catalog.""" + + table_list_request = client.tables.list(catalog_name=catalog_name, schema_name=schema_name) + + if not table_list_request.tables: + return False + + return table_name in [table.name for table in table_list_request.tables] + + +UCSupportedTypeLiteral = Literal[ + "BOOLEAN", + "BYTE", + "SHORT", + "INT", + "LONG", + "FLOAT", + "DOUBLE", + "DATE", + "TIMESTAMP", + "TIMESTAMP_NTZ", + "STRING", + "BINARY", + "DECIMAL", + "INTERVAL", + "ARRAY", + "STRUCT", + "MAP", + "CHAR", + "NULL", + "USER_DEFINED_TYPE", + "TABLE_TYPE", +] + +UCSupportedFormatLiteral = Literal["DELTA", "CSV", "JSON", "AVRO", "PARQUET", "ORC", "TEXT"] + + +def pyarrow_type_to_supported_uc_json_type(data_type: pa.DataType) -> UCSupportedTypeLiteral: + """Convert a PyArrow data type to a supported Unitycatalog JSON type.""" + if pa.types.is_boolean(data_type): + return "BOOLEAN" + elif pa.types.is_int8(data_type): + return "BYTE" + elif pa.types.is_int16(data_type): + return "SHORT" + elif pa.types.is_int32(data_type): + return "INT" + elif pa.types.is_int64(data_type): + return "LONG" + elif pa.types.is_float32(data_type): + return "FLOAT" + elif pa.types.is_float64(data_type): + return "DOUBLE" + elif pa.types.is_date32(data_type): + return "DATE" + elif pa.types.is_timestamp(data_type): + return "TIMESTAMP" + elif pa.types.is_string(data_type): + return "STRING" + elif pa.types.is_binary(data_type): + return "BINARY" + elif pa.types.is_decimal(data_type): + return "DECIMAL" + elif pa.types.is_duration(data_type): + return "INTERVAL" + elif pa.types.is_list(data_type): + return "ARRAY" + elif pa.types.is_struct(data_type): + return "STRUCT" + elif pa.types.is_map(data_type): + return "MAP" + elif pa.types.is_null(data_type): + return "NULL" + else: + raise NotImplementedError(f"Type {data_type} not supported") + + +def pyarrow_schema_to_columns(schema: pa.Schema) -> list[Column]: + """Convert a PyArrow schema to a list of Unitycatalog Column objects.""" + columns = [] + + for i, field in enumerate(schema): + data_type = field.type + json_type = pyarrow_type_to_supported_uc_json_type(data_type) + + column = Column( + name=field.name, + type_name=json_type, + nullable=field.nullable, + comment=f"Field {field.name}", # Generic comment, modify as needed + position=i, + type_json=json.dumps( + { + "name": field.name, + "type": json_type, + "nullable": field.nullable, + "metadata": field.metadata or {}, + } + ), + type_precision=0, + type_scale=0, + type_text=json_type, + ) + + # Adjust type precision and scale for decimal types + if pa.types.is_decimal(data_type): + column["type_precision"] = data_type.precision + column["type_scale"] = data_type.scale + + columns.append(column) + + return columns + + +def create_table_if_not_exists( + uc_client: 
Unitycatalog, + table_name: str, + schema_name: str, + catalog_name: str, + storage_location: str, + schema: list[Column], + storage_format: UCSupportedFormatLiteral, +): + """Create or update a Unitycatalog table.""" + + if not uc_schema_exists(uc_client, schema_name, catalog_name): + uc_client.schemas.create(catalog_name=catalog_name, name=schema_name) + + if not uc_table_exists(uc_client, table_name, schema_name, catalog_name): + uc_client.tables.create( + catalog_name=catalog_name, + columns=schema, + data_source_format=storage_format, + name=table_name, + schema_name=schema_name, + table_type="EXTERNAL", + storage_location=storage_location, + ) + else: + # TODO: Add support for schema checks/schema evolution with existing schema and dataframe schema + pass + + +class Plugin(BasePlugin): + # The name of the catalog + catalog_name: str = "unity" + + # The default storage format + default_format = StorageFormat.DELTA + + # The Unitycatalog client + uc_client: Unitycatalog + + def initialize(self, config: Dict[str, Any]): + # Assert that the credentials and secrets are present + assert self.creds is not None, "Credentials are required for the plugin!" + assert self.creds.secrets is not None, "Secrets are required for the plugin!" + + # Find the UC secret + uc_secret = find_secrets_by_type(self.creds.secrets, "UC") + + # Get the endpoint from the UC secret + host_and_port = uc_secret["endpoint"] + + # Construct the full base URL + catalog_base_url = f"{host_and_port}/api/2.1/unity-catalog" + + # Prism mocks the UC server to http://127.0.0.1:4010 with no option to specify a basePath (api/2.1/unity-catalog) + # https://github.com/stoplightio/prism/discussions/906 + # This is why we need to check if we are running in pytest and only use the host_and_port + # Otherwise we will not be able to connect to the mock UC server + if "pytest" in sys.modules: + self.uc_client: Unitycatalog = Unitycatalog(base_url=host_and_port) + else: + # Otherwise, use the full base URL + self.uc_client: Unitycatalog = Unitycatalog(base_url=catalog_base_url) + + def load(self, source_config: SourceConfig): + # Assert that the source_config has a name, schema, and database + assert source_config.identifier is not None, "Name is required for loading data!" + assert source_config.schema is not None, "Schema is required for loading data!" + assert source_config.get("location") is not None, "Location is required for loading data!" 
+
+        # Get the required variables from the source configuration
+        table_path = source_config.get("location")
+        table_name = source_config.identifier
+        schema_name = source_config.schema
+
+        # Get the optional variables from the source configuration
+        storage_format = source_config.get("format", self.default_format)
+        storage_options = source_config.get("storage_options", {})
+        as_of_version = source_config.get("as_of_version", None)
+        as_of_datetime = source_config.get("as_of_datetime", None)
+
+        if storage_format == StorageFormat.DELTA:
+            from .delta import delta_load
+
+            df = delta_load(
+                table_path=table_path,
+                storage_options=storage_options,
+                as_of_version=as_of_version,
+                as_of_datetime=as_of_datetime,
+            )
+        else:
+            raise NotImplementedError(f"Loading storage format {storage_format} not supported!")
+
+        converted_schema = pyarrow_schema_to_columns(schema=df.schema)
+
+        # Create the table in the Unitycatalog if it does not exist
+        create_table_if_not_exists(
+            uc_client=self.uc_client,
+            table_name=table_name,
+            schema_name=schema_name,
+            catalog_name=self.catalog_name,
+            storage_location=table_path,
+            schema=converted_schema,
+            storage_format=storage_format,
+        )
+
+        return df
+
+    def store(self, target_config: TargetConfig, df: pa.lib.Table = None):
+        # Assert that the target_config has a location and relation identifier
+        assert target_config.location is not None, "Location is required for storing data!"
+        assert (
+            target_config.relation.identifier is not None
+        ), "Relation identifier is required to name the table!"
+
+        # Get required variables from the target configuration
+        table_path = target_config.location.path
+        table_name = target_config.relation.identifier
+
+        # Get optional variables from the target configuration
+        mode = target_config.config.get("mode", "overwrite")
+        schema_name = target_config.config.get("schema")
+
+        # If schema is not provided or empty, set to default
+        if not schema_name or schema_name == "":
+            schema_name = "default"
+
+        storage_options = target_config.config.get("storage_options", {})
+        partition_key = target_config.config.get("partition_key", None)
+        unique_key = target_config.config.get("unique_key", None)
+
+        # Get the storage format from the plugin configuration
+        storage_format = self.plugin_config.get("format", self.default_format)
+
+        # Convert the pa schema to columns
+        converted_schema = pyarrow_schema_to_columns(schema=df.schema)
+
+        # Create the table in the Unitycatalog if it does not exist
+        create_table_if_not_exists(
+            uc_client=self.uc_client,
+            table_name=table_name,
+            schema_name=schema_name,
+            catalog_name=self.catalog_name,
+            storage_location=table_path,
+            schema=converted_schema,
+            storage_format=storage_format,
+        )
+
+        if storage_format == StorageFormat.DELTA:
+            from .delta import delta_write
+
+            delta_write(
+                mode=mode,
+                table_path=table_path,
+                df=df,
+                storage_options=storage_options,
+                partition_key=partition_key,
+                unique_key=unique_key,
+            )
+        else:
+            raise NotImplementedError(f"Writing storage format {storage_format} not supported!")
diff --git a/dbt/adapters/duckdb/utils.py b/dbt/adapters/duckdb/utils.py
index 19d3486f..88c1e258 100644
--- a/dbt/adapters/duckdb/utils.py
+++ b/dbt/adapters/duckdb/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
 from typing import Any
 from typing import Dict
@@ -5,9 +7,17 @@
 from typing import Optional
 from typing import Sequence
 
+from duckdb.duckdb import DatabaseError
+from tenacity import retry
+from tenacity import retry_if_exception_type
+from tenacity import stop_after_attempt
+from tenacity import wait_incrementing
+
 from dbt.adapters.base.column import Column
 from dbt.adapters.base.relation import BaseRelation
 from dbt.adapters.contracts.relation import RelationConfig
+
+
 # TODO
 # from dbt.context.providers import RuntimeConfigObject
 
@@ -88,3 +98,26 @@ def as_dict(self) -> Dict[str, Any]:
         if self.location:
             base["location"] = self.location.as_dict()
         return base
+
+
+def find_secrets_by_type(secrets: list[dict], secret_type: str) -> dict:
+    """Find secrets of a specific type in the secrets dictionary."""
+    for secret in secrets:
+        if secret.get("type") == secret_type:
+            return secret
+    raise SecretTypeMissingError(f"Secret type {secret_type} not found in the secrets!")
+
+
+class SecretTypeMissingError(Exception):
+    """Exception raised when the secret type is missing from the secrets dictionary."""
+
+    pass
+
+
+def get_retry_decorator(max_attempts: int, wait_time: float, exception: DatabaseError):
+    return retry(
+        stop=stop_after_attempt(max_attempts),
+        wait=wait_incrementing(start=wait_time, increment=0.05),
+        retry=retry_if_exception_type(exception),
+        reraise=True,
+    )
diff --git a/dbt/include/duckdb/macros/materializations/external_table.sql b/dbt/include/duckdb/macros/materializations/external_table.sql
new file mode 100644
index 00000000..420e475f
--- /dev/null
+++ b/dbt/include/duckdb/macros/materializations/external_table.sql
@@ -0,0 +1,19 @@
+{% materialization external_table, adapter="duckdb", supported_languages=['sql', 'python'] %}
+{{ log("External macro") }}
+
+{%- set target_relation = this.incorporate(type='view') %}
+
+{%- set plugin_name = config.get('plugin') -%}
+{%- set location = render(config.get('location', default=external_location(this, config))) -%}
+{%- set format = config.get('format', 'parquet') -%}
+
+{% do store_relation(plugin_name, target_relation, location, format, config) %}
+
+{% call statement('main', language='sql') -%}
+
+{%- endcall %}
+
+-- we have to load this table as df and create target_relation view
+
+{{ return({'relations': [target_relation]}) }}
+{% endmaterialization %}
diff --git a/dbt/include/duckdb/macros/materializations/ref.sql b/dbt/include/duckdb/macros/materializations/ref.sql
new file mode 100644
index 00000000..901e30f5
--- /dev/null
+++ b/dbt/include/duckdb/macros/materializations/ref.sql
@@ -0,0 +1,51 @@
+{% macro ref() %}
+    -- default ref: https://docs.getdbt.com/reference/dbt-jinja-functions/builtins
+    -- extract user-provided positional and keyword arguments
+    {% set version = kwargs.get('version') or kwargs.get('v') %}
+    {% set packagename = none %}
+    {%- if (varargs | length) == 1 -%}
+        {% set modelname = varargs[0] %}
+    {%- else -%}
+        {% set packagename = varargs[0] %}
+        {% set modelname = varargs[1] %}
+    {% endif %}
+
+    -- call builtins.ref based on provided positional arguments
+    {% set rel = None %}
+    {% if packagename is not none %}
+        {% set rel = builtins.ref(packagename, modelname, version=version) %}
+    {% else %}
+        {% set rel = builtins.ref(modelname, version=version) %}
+    {% endif %}
+
+    {% if execute %}
+        {% if graph.get('nodes') %}
+            {% for node in graph.nodes.values() | selectattr("name", "equalto", modelname) %}
+                -- Get the associated materialization from the node config
+                {% set materialization = node.config.materialized %}
+                -- Get the associated plugin from the node config
+                {% set plugin = node.config.plugin %}
+
+                {% if plugin == 'unity' and materialization == 'external_table' %}
+                    -- Retrieve the catalog 
value from the active target configuration + {% set catalog = target.get("catalog", "unity") %} + -- Get the associated schema from the node config + {% set schema = node.config.schema %} + + {% if not schema %} + {% set schema = 'default' %} + {% endif %} + + {% set new_rel = catalog ~ '.' ~ schema ~ '.' ~ rel.identifier %} + + {% do return(new_rel) %} + {% else %} + {% do return(rel) %} + {% endif %} + {% endfor %} + {% endif %} + {% endif %} + + -- return the original relation object by default + {% do return(rel) %} +{% endmacro %} diff --git a/dev-requirements.txt b/dev-requirements.txt index 0b9b1f7f..21a34f1c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -34,3 +34,4 @@ tox>=3.13 twine wheel deltalake +unitycatalog diff --git a/scripts/mock-uc-server.sh b/scripts/mock-uc-server.sh new file mode 100755 index 00000000..8740539d --- /dev/null +++ b/scripts/mock-uc-server.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +if [[ -n "$1" && "$1" != '--'* ]]; then + URL="$1" + shift +else + URL="$(grep 'openapi_spec_url' .mock-uc-server-stats.yml | cut -d' ' -f2)" +fi + +# Check if the URL is empty +if [ -z "$URL" ]; then + echo "Error: No OpenAPI spec path/url provided or found in ..mock-uc-server-stats.yml" + exit 1 +fi + +echo "==> Starting mock server with URL ${URL}" + +# Run prism mock on the given spec +if [ "$1" == "--daemon" ]; then + npm exec --package=@stoplight/prism-cli@~5.8 -- prism mock "$URL" &> .prism.log & + + # Wait for server to come online + echo -n "Waiting for server" + while ! grep -q "✖ fatal\|Prism is listening" ".prism.log" ; do + echo -n "." + sleep 0.1 + done + + if grep -q "✖ fatal" ".prism.log"; then + cat .prism.log + exit 1 + fi + + echo +else + npm exec --package=@stoplight/prism-cli@~5.8 -- prism mock "$URL" -p 4010 +fi diff --git a/scripts/test-unity.sh b/scripts/test-unity.sh new file mode 100755 index 00000000..3f83d6c6 --- /dev/null +++ b/scripts/test-unity.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[0;33m' +NC='\033[0m' # No Color + +function prism_is_running() { + curl --silent "http://localhost:4010" >/dev/null 2>&1 +} + +kill_server_on_port() { + pids=$(lsof -t -i tcp:"$1" || echo "") + if [ "$pids" != "" ]; then + kill "$pids" + echo "Stopped $pids." + fi +} + +function is_overriding_api_base_url() { + [ -n "$TEST_API_BASE_URL" ] +} + +if ! is_overriding_api_base_url && ! prism_is_running ; then + # When we exit this script, make sure to kill the background mock server process + trap 'kill_server_on_port 4010' EXIT + + # Start the dev server + ./scripts/mock-uc-server.sh --daemon +fi + +if is_overriding_api_base_url ; then + echo -e "${GREEN}✔ Running tests against ${TEST_API_BASE_URL}${NC}" + echo +elif ! prism_is_running ; then + echo -e "${RED}ERROR:${NC} The test suite will not run without a mock Prism server" + echo -e "running against your OpenAPI spec." 
+ echo + echo -e "To run the server, pass in the path or url of your OpenAPI" + echo -e "spec to the prism command:" + echo + echo -e " \$ ${YELLOW}npm exec --package=@stoplight/prism-cli@~5.3.2 -- prism mock path/to/your.openapi.yml${NC}" + echo + + exit 1 +else + echo -e "${GREEN}✔ Mock prism server is running with your OpenAPI spec${NC}" + echo +fi + +echo "==> Running tests" +tox "$@" diff --git a/setup.cfg b/setup.cfg index 957d04b8..ea1cdbdd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ install_requires= dbt-common>=1,<2 dbt-adapters>=1,<2 duckdb>=1.0.0 + tenacity>=7.0.0 # add dbt-core to ensure backwards compatibility of installation, this is not a functional dependency dbt-core>=1.8.0 python_requires = >=3.8 @@ -44,6 +45,13 @@ requires = ["setuptools >= 61.2", "pbr>=1.9"] glue = boto3 mypy-boto3-glue +unity = + unitycatalog==0.1.1 + deltalake==0.18.2 + pyarrow==17.0.0 +delta = + deltalake==0.18.2 + pyarrow==17.0.0 [files] packages = diff --git a/tests/conftest.py b/tests/conftest.py index cb257622..2b290b52 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -72,6 +72,21 @@ def dbt_profile_target(profile_type, bv_server_process, tmp_path_factory): profile["token"] = os.environ.get(TEST_MOTHERDUCK_TOKEN) profile["disable_transactions"] = True profile["path"] = "md:test" + elif profile_type == "unity": + profile["extensions"] = [{"name": "uc_catalog", + "repository": "http://nightly-extensions.duckdb.org"}] + profile["attach"] = [ + {"path": "unity", + "alias": "unity", + "type": "UC_CATALOG"}, + ] + profile["secrets"] = [{ + "type": "UC", + # here our mock uc server is running, prism defaults to 4010 + "endpoint": "http://127.0.0.1:4010", + "token": "test", + "aws_region": "eu-west-1" + }] elif profile_type == "memory": pass # use the default path-less profile else: diff --git a/tests/create_function_plugin.py b/tests/create_function_plugin.py index c4d52d6d..be277bc8 100644 --- a/tests/create_function_plugin.py +++ b/tests/create_function_plugin.py @@ -12,5 +12,5 @@ class Plugin(BasePlugin): def configure_connection(self, conn: DuckDBPyConnection): conn.create_function("foo", foo) - def store(self, target_config: TargetConfig): + def store(self, target_config: TargetConfig, df=None): assert target_config.config.get("key") == "value" diff --git a/tests/functional/plugins/test_delta.py b/tests/functional/plugins/test_delta.py index 0262f70f..29a1ba53 100644 --- a/tests/functional/plugins/test_delta.py +++ b/tests/functional/plugins/test_delta.py @@ -8,6 +8,8 @@ ) from deltalake.writer import write_deltalake +from tests.functional.plugins.utils import get_table_row_count + delta_schema_yml = """ version: 2 sources: @@ -55,38 +57,26 @@ class TestPlugins: @pytest.fixture(scope="class") def delta_test_table1(self): - td = tempfile.TemporaryDirectory() - path = Path(td.name) - table_path = path / "test_delta_table1" - - df = pd.DataFrame({"x": [1, 2, 3]}) - write_deltalake(table_path, df, mode="overwrite") + with tempfile.TemporaryDirectory() as tmpdir: + table_path = Path(tmpdir) / "test_delta_table1" - yield table_path + df = pd.DataFrame({"x": [1, 2, 3]}) + write_deltalake(table_path, df, mode="overwrite") - td.cleanup() + yield table_path @pytest.fixture(scope="class") def delta_test_table2(self): - td = tempfile.TemporaryDirectory() - path = Path(td.name) - table_path = path / "test_delta_table2" - - df = pd.DataFrame({ - "x": [1], - "y": ["a"] - }) - write_deltalake(table_path, df, mode="overwrite") + with tempfile.TemporaryDirectory() as tmpdir: + table_path = 
Path(tmpdir) / "test_delta_table2" - df = pd.DataFrame({ - "x": [1, 2], - "y": ["a","b"] - }) - write_deltalake(table_path, df, mode="overwrite") + df1 = pd.DataFrame({"x": [1], "y": ["a"]}) + write_deltalake(table_path, df1, mode="overwrite") - yield table_path + df2 = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]}) + write_deltalake(table_path, df2, mode="overwrite") - td.cleanup() + yield table_path @pytest.fixture(scope="class") def profiles_config_update(self, dbt_profile_target): @@ -121,12 +111,12 @@ def test_plugins(self, project): results = run_dbt() assert len(results) == 4 - # check_relations_equal( - # project.adapter, - # [ - # "delta_table3", - # "delta_table3_expected", - # ], - # ) - # res = project.run_sql("SELECT count(1) FROM 'delta_table3'", fetch="one") - # assert res[0] == 2 + delta_table1_row_count = get_table_row_count(project, "main.delta_table1") + assert delta_table1_row_count == 3 + + delta_table2_row_count = get_table_row_count(project, "main.delta_table2") + assert delta_table2_row_count == 1 + + delta_table3_row_count = get_table_row_count(project, "main.delta_table3") + assert delta_table3_row_count == 0 + diff --git a/tests/functional/plugins/test_delta_write.py b/tests/functional/plugins/test_delta_write.py new file mode 100644 index 00000000..023be8e9 --- /dev/null +++ b/tests/functional/plugins/test_delta_write.py @@ -0,0 +1,66 @@ +import tempfile +from pathlib import Path + +import pytest +from dbt.tests.util import ( + run_dbt, ) + +from tests.functional.plugins.utils import get_table_row_count + +ref1 = """ +select 2 as a, 'test' as b +""" + + +def delta_table_sql(location: str) -> str: + return f""" + {{{{ config( + materialized='table', + plugin = 'delta', + location = '{location}', + mode = 'merge', + unique_key = 'a' + + ) }}}} + select * from {{{{ref('ref1')}}}} +""" + + +@pytest.mark.skip_profile("buenavista", "md") +class TestPlugins: + @pytest.fixture(scope="class") + def delta_test_table(self): + with tempfile.TemporaryDirectory() as tmpdir: + table_path = Path(tmpdir) / "test_delta_table" + yield table_path + + @pytest.fixture(scope="class") + def profiles_config_update(self, dbt_profile_target): + plugins = [{"module": "delta"}] + return { + "test": { + "outputs": { + "dev": { + "type": "duckdb", + "path": dbt_profile_target.get("path", ":memory:"), + "plugins": plugins, + } + }, + "target": "dev", + } + } + + @pytest.fixture(scope="class") + def models(self, delta_test_table): + return { + + "delta_table.sql": delta_table_sql(str(delta_test_table)), + "ref1.sql": ref1 + } + + def test_plugins(self, project): + results = run_dbt() + assert len(results) == 2 + + delta_table_row_count = get_table_row_count(project, "main.delta_table") + assert delta_table_row_count == 1 diff --git a/tests/functional/plugins/test_glue.py b/tests/functional/plugins/test_glue.py index ef220dc7..e233c60a 100644 --- a/tests/functional/plugins/test_glue.py +++ b/tests/functional/plugins/test_glue.py @@ -35,7 +35,7 @@ def seeds(self): @pytest.fixture(scope="class") def dbt_profile_target(self, dbt_profile_target): dbt_profile_target["external_root"] = "s3://duckdbtest/glue_test" - dbt_profile_target["extensions"] = ["httpfs"] + dbt_profile_target["extensions"] = [{"name": "httpfs"}] dbt_profile_target["settings"] = { "s3_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"), "s3_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"), diff --git a/tests/functional/plugins/test_motherduck.py b/tests/functional/plugins/test_motherduck.py index a2b01bd0..d1ada48a 100644 --- 
a/tests/functional/plugins/test_motherduck.py +++ b/tests/functional/plugins/test_motherduck.py @@ -1,15 +1,16 @@ -import pytest from unittest import mock -from unittest.mock import Mock + +import pytest from dbt.tests.util import ( run_dbt, ) -from dbt.adapters.duckdb.environments import Environment +from dbt.version import __version__ + +from dbt.adapters.duckdb.__version__ import version as plugin_version from dbt.adapters.duckdb.credentials import DuckDBCredentials from dbt.adapters.duckdb.credentials import PluginConfig +from dbt.adapters.duckdb.environments import Environment from dbt.adapters.duckdb.plugins.motherduck import Plugin -from dbt.adapters.duckdb.__version__ import version as plugin_version -from dbt.version import __version__ random_logs_sql = """ {{ config(materialized='table', meta=dict(temp_schema_name='dbt_temp_test')) }} @@ -43,7 +44,8 @@ group by all """ -@pytest.mark.skip_profile("buenavista", "file", "memory") + +@pytest.mark.skip_profile("buenavista", "file", "memory", "unity") class TestMDPlugin: @pytest.fixture(scope="class") def profiles_config_update(self, dbt_profile_target): @@ -65,7 +67,7 @@ def profiles_config_update(self, dbt_profile_target): @pytest.fixture(scope="class") def database_name(self, dbt_profile_target): return dbt_profile_target["path"].replace("md:", "") - + @pytest.fixture(scope="class") def md_sql(self, database_name): # Reads from a MD database in my test account in the cloud @@ -106,48 +108,47 @@ def test_incremental(self, project): res = project.run_sql("SELECT count(*) FROM summary_of_logs_test", fetch="one") assert res == (105,) - res = project.run_sql("SELECT schema_name FROM information_schema.schemata WHERE catalog_name = 'test'", fetch="all") + res = project.run_sql("SELECT schema_name FROM information_schema.schemata WHERE catalog_name = 'test'", + fetch="all") assert "dbt_temp_test" in [_r for (_r,) in res] def test_incremental_temp_table_exists(self, project): - project.run_sql('create or replace table test.dbt_temp_test.summary_of_logs_test as (select 1 from generate_series(1,10) g(x))') + project.run_sql( + 'create or replace table test.dbt_temp_test.summary_of_logs_test as (select 1 from generate_series(1,10) g(x))') run_dbt() res = project.run_sql("SELECT count(*) FROM summary_of_logs_test", fetch="one") assert res == (70,) -@pytest.fixture -def mock_md_plugin(): - return Plugin.create("motherduck") - - -@pytest.fixture -def mock_creds(dbt_profile_target): - plugin_config = PluginConfig(module="motherduck", config={"token": "quack"}) - if "md:" in dbt_profile_target["path"]: - return DuckDBCredentials(path=dbt_profile_target["path"], plugins=[plugin_config]) - return DuckDBCredentials(path=dbt_profile_target["path"]) + @pytest.fixture + def mock_md_plugin(self): + return Plugin.create("motherduck") + @pytest.fixture + def mock_creds(self, dbt_profile_target): + plugin_config = PluginConfig(module="motherduck", config={"token": "quack"}) + if "md:" in dbt_profile_target["path"]: + return DuckDBCredentials(path=dbt_profile_target["path"], plugins=[plugin_config]) + return DuckDBCredentials(path=dbt_profile_target["path"]) -@pytest.fixture -def mock_plugins(mock_creds, mock_md_plugin): - plugins = {} - if mock_creds.is_motherduck: - plugins["motherduck"] = mock_md_plugin - return plugins - - -def test_motherduck_user_agent(dbt_profile_target, mock_plugins, mock_creds): - with mock.patch("dbt.adapters.duckdb.environments.duckdb.connect") as mock_connect: - mock_creds.settings = {"custom_user_agent": "downstream-dep"} - 
Environment.initialize_db(mock_creds, plugins=mock_plugins) + @pytest.fixture + def mock_plugins(self, mock_creds, mock_md_plugin): + plugins = {} if mock_creds.is_motherduck: - kwargs = { - 'read_only': False, - 'config': { - 'custom_user_agent': f'dbt/{__version__} dbt-duckdb/{plugin_version} downstream-dep', - 'motherduck_token': 'quack' + plugins["motherduck"] = mock_md_plugin + return plugins + + def test_motherduck_user_agent(self, dbt_profile_target, mock_plugins, mock_creds): + with mock.patch("dbt.adapters.duckdb.environments.duckdb.connect") as mock_connect: + mock_creds.settings = {"custom_user_agent": "downstream-dep"} + Environment.initialize_db(mock_creds, plugins=mock_plugins) + if mock_creds.is_motherduck: + kwargs = { + 'read_only': False, + 'config': { + 'custom_user_agent': f'dbt/{__version__} dbt-duckdb/{plugin_version} downstream-dep', + 'motherduck_token': 'quack' + } } - } - mock_connect.assert_called_with(dbt_profile_target["path"], **kwargs) - else: - mock_connect.assert_called_with(dbt_profile_target["path"], read_only=False, config = {}) + mock_connect.assert_called_with(dbt_profile_target["path"], **kwargs) + else: + mock_connect.assert_called_with(dbt_profile_target["path"], read_only=False, config={}) diff --git a/tests/functional/plugins/test_plugins.py b/tests/functional/plugins/test_plugins.py index 124ee178..ccd858ce 100644 --- a/tests/functional/plugins/test_plugins.py +++ b/tests/functional/plugins/test_plugins.py @@ -41,14 +41,19 @@ """ -@pytest.mark.skip_profile("buenavista", "md") +@pytest.mark.skip_profile("buenavista", "md", "unity") class TestPlugins: @pytest.fixture(scope="class") def sqlite_test_db(self): path = "/tmp/satest.db" db = sqlite3.connect(path) cursor = db.cursor() - cursor.execute("CREATE TABLE tt1 (id int, name text)") + + # clean up + cursor.execute("DROP TABLE IF EXISTS tt1") + cursor.execute("DROP TABLE IF EXISTS test_table2") + + cursor.execute("CREATE TABLE tt1 (id int, name text)") cursor.execute("INSERT INTO tt1 VALUES (1, 'John Doe')") cursor.execute("INSERT INTO tt1 VALUES (2, 'Jane Smith')") cursor.execute("CREATE TABLE test_table2 (a int, b int, c int)") diff --git a/tests/functional/plugins/test_unity.py b/tests/functional/plugins/test_unity.py new file mode 100644 index 00000000..97c2b316 --- /dev/null +++ b/tests/functional/plugins/test_unity.py @@ -0,0 +1,144 @@ +import tempfile +from pathlib import Path + +import pandas as pd +import pytest +from dbt.tests.util import ( + run_dbt, +) +from deltalake.writer import write_deltalake + +unity_schema_yml = """ +version: 2 +sources: + - name: default + meta: + plugin: unity + tables: + - name: unity_source_table + description: "A UC table" + meta: + location: "{unity_source_table_location}" + format: DELTA + + - name: test + meta: + plugin: unity + tables: + - name: unity_source_table_with_version + description: "A UC table that loads a specific version of the table" + meta: + location: "{unity_source_table_with_version_location}" + format: DELTA + as_of_version: 0 +""" + +ref1 = """ +select 2 as a, 'test' as b +""" + + +def unity_create_table_sql(location: str) -> str: + return f""" + {{{{ config( + materialized='external_table', + plugin = 'unity', + location = '{location}' + ) }}}} + select * from {{{{ref('ref1')}}}} +""" + + +def unity_create_table_and_schema_sql(location: str) -> str: + return f""" + {{{{ config( + materialized='external_table', + plugin = 'unity', + schema = 'test_schema', + location = '{location}' + ) }}}} + select * from {{{{ref('ref1')}}}} +""" + + 
+@pytest.mark.skip_profile("buenavista", "file", "memory", "md") +class TestPlugins: + @pytest.fixture(scope="class") + def unity_source_table(self): + with tempfile.TemporaryDirectory() as tmpdir: + table_path = Path(tmpdir) / "unity_source_table" + + df = pd.DataFrame({"x": [1, 2, 3]}) + write_deltalake(table_path, df, mode="overwrite") + + yield table_path + + @pytest.fixture(scope="class") + def unity_source_table_with_version(self): + with tempfile.TemporaryDirectory() as tmpdir: + table_path = Path(tmpdir) / "unity_source_table_with_version" + + df1 = pd.DataFrame({"x": [1], "y": ["a"]}) + write_deltalake(table_path, df1, mode="overwrite") + + df2 = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]}) + write_deltalake(table_path, df2, mode="overwrite") + + yield table_path + + @pytest.fixture(scope="class") + def unity_create_table(self): + td = tempfile.TemporaryDirectory() + path = Path(td.name) + table_path = path / "test_unity_create_table" + + yield table_path + + td.cleanup() + + @pytest.fixture(scope="class") + def unity_create_table_and_schema(self): + td = tempfile.TemporaryDirectory() + path = Path(td.name) + table_path = path / "test_unity_create_table_and_schema" + + yield table_path + + td.cleanup() + + @pytest.fixture(scope="class") + def profiles_config_update(self, dbt_profile_target): + plugins = [{"module": "unity"}] + extensions = dbt_profile_target.get("extensions") + extensions.extend([{"name": "delta"}]) + return { + "test": { + "outputs": { + "dev": { + "type": "duckdb", + "path": dbt_profile_target.get("path", ":memory:"), + "plugins": plugins, + "extensions": extensions, + "secrets": dbt_profile_target.get("secrets"), + "attach": dbt_profile_target.get("attach") + } + }, + "target": "dev", + } + } + + @pytest.fixture(scope="class") + def models(self, unity_create_table, unity_create_table_and_schema, unity_source_table, unity_source_table_with_version): + return { + "source_schema.yml": unity_schema_yml.format( + unity_source_table_location=unity_source_table, + unity_source_table_with_version_location=unity_source_table_with_version + ), + "unity_create_table.sql": unity_create_table_sql(str(unity_create_table)), + "unity_create_table_and_schema.sql": unity_create_table_and_schema_sql(str(unity_create_table_and_schema)), + "ref1.sql": ref1 + } + + def test_plugins(self, project): + results = run_dbt() + assert len(results) == 3 diff --git a/tests/functional/plugins/utils.py b/tests/functional/plugins/utils.py new file mode 100644 index 00000000..5c29d917 --- /dev/null +++ b/tests/functional/plugins/utils.py @@ -0,0 +1,6 @@ + +def get_table_row_count(dbt_project, full_table_name: str): + """Get the row count of a table.""" + count_result, *_ = dbt_project.run_sql(f"SELECT COUNT(*) FROM {full_table_name}", fetch="all") + row_count, *_ = count_result + return row_count diff --git a/tests/unit/test_credentials.py b/tests/unit/test_credentials.py index fd0756a3..529bc064 100644 --- a/tests/unit/test_credentials.py +++ b/tests/unit/test_credentials.py @@ -105,7 +105,7 @@ def test_add_unsupported_secret_param(): ) sql = creds.secrets_sql()[0] assert sql == \ -"""CREATE OR REPLACE SECRET _dbt_secret_1 ( +"""CREATE OR REPLACE SECRET __default_s3 ( type s3, password 'secret' )""" diff --git a/tests/unit/test_duckdb_adapter.py b/tests/unit/test_duckdb_adapter.py index f0942938..521f3bff 100644 --- a/tests/unit/test_duckdb_adapter.py +++ b/tests/unit/test_duckdb_adapter.py @@ -115,8 +115,8 @@ def test_create_secret(self, connector): 
DuckDBConnectionManager.close_all_connections() connection = self.adapter.acquire_connection("dummy") assert connection.handle - connection.handle._cursor._cursor.execute.assert_called_with( -"""CREATE OR REPLACE SECRET _dbt_secret_1 ( + connection.handle._conn.execute.assert_called_with( +"""CREATE OR REPLACE SECRET __default_s3 ( type s3, key_id 'abc', secret 'xyz', diff --git a/tox.ini b/tox.ini index 8cecc051..303bf88a 100644 --- a/tox.ini +++ b/tox.ini @@ -39,6 +39,15 @@ deps = -rdev-requirements.txt -e. +[testenv:{unity,py39}] +description = adapter functional testing using a Unity Catalog server +skip_install = True +passenv = * +commands = {envpython} -m pytest --profile=unity --maxfail=2 {posargs} tests/functional/adapter tests/functional/plugins/test_unity.py +deps = + -rdev-requirements.txt + -e. + [testenv:{md,py39}] description = adapter function testing using MotherDuck skip_install = True