Skip to content

Commit

Permalink
Merge pull request #141 from jwills/jwills_plugins
Browse files Browse the repository at this point in the history
Create a simple plugin system for loading data from external sources
  • Loading branch information
jwills authored Apr 14, 2023
2 parents 5de6b58 + 3e094de commit f665487
Show file tree
Hide file tree
Showing 22 changed files with 700 additions and 76 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,51 @@ jobs:
name: fsspec_results_${{ matrix.python-version }}-${{ steps.date.outputs.date }}.csv
path: fsspec_results.csv

plugins:
name: plugins test / python ${{ matrix.python-version }}

runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
python-version: ['3.9']

env:
TOXENV: "plugins"
PYTEST_ADDOPTS: "-v --color=yes --csv plugins_results.csv"

steps:
- name: Check out the repository
uses: actions/checkout@v3
with:
persist-credentials: false

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install python dependencies
run: |
python -m pip install tox
python -m pip --version
tox --version
- name: Run tox
run: tox

- name: Get current date
if: always()
id: date
run: echo "date=$(date +'%Y-%m-%dT%H_%M_%S')" >> $GITHUB_OUTPUT #no colons allowed for artifacts

- uses: actions/upload-artifact@v3
if: always()
with:
name: plugins_results_${{ matrix.python-version }}-${{ steps.date.outputs.date }}.csv
path: plugins_results.csv

build:
name: build packages

Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ target/

.DS_Store
.idea/
.vscode/
19 changes: 17 additions & 2 deletions dbt/adapters/duckdb/buenavista.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import psycopg2

from . import credentials
from . import utils
from .environments import Environment
from dbt.contracts.connection import AdapterResponse

Expand All @@ -29,6 +30,9 @@ def handle(self):
cursor.close()
return conn

def get_binding_char(self) -> str:
return "%s"

def submit_python_job(self, handle, parsed_model: dict, compiled_code: str) -> AdapterResponse:
identifier = parsed_model["alias"]
payload = {
Expand All @@ -42,5 +46,16 @@ def submit_python_job(self, handle, parsed_model: dict, compiled_code: str) -> A
handle.cursor().execute(json.dumps(payload))
return AdapterResponse(_message="OK")

def get_binding_char(self) -> str:
return "%s"
def load_source(self, plugin_name: str, source_config: utils.SourceConfig):
handle = self.handle()
payload = {
"method": "dbt_load_source",
"params": {
"plugin_name": plugin_name,
"source_config": source_config.as_dict(),
},
}
cursor = handle.cursor()
cursor.execute(json.dumps(payload))
cursor.close()
handle.close()
25 changes: 16 additions & 9 deletions dbt/adapters/duckdb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,31 @@

class DuckDBConnectionManager(SQLConnectionManager):
TYPE = "duckdb"
LOCK = threading.RLock()
ENV = None
_LOCK = threading.RLock()
_ENV = None

def __init__(self, profile: AdapterRequiredConfig):
super().__init__(profile)

@classmethod
def env(cls) -> environments.Environment:
with cls._LOCK:
if not cls._ENV:
raise Exception("DuckDBConnectionManager environment requested before creation!")
return cls._ENV

@classmethod
def open(cls, connection: Connection) -> Connection:
if connection.state == ConnectionState.OPEN:
logger.debug("Connection is already open, skipping open.")
return connection

credentials = cls.get_credentials(connection.credentials)
with cls.LOCK:
with cls._LOCK:
try:
if not cls.ENV:
cls.ENV = environments.create(credentials)
connection.handle = cls.ENV.handle()
if not cls._ENV:
cls._ENV = environments.create(credentials)
connection.handle = cls._ENV.handle()
connection.state = ConnectionState.OPEN

except RuntimeError as e:
Expand Down Expand Up @@ -79,9 +86,9 @@ def get_response(cls, cursor) -> AdapterResponse:

@classmethod
def close_all_connections(cls):
with cls.LOCK:
if cls.ENV is not None:
cls.ENV = None
with cls._LOCK:
if cls._ENV is not None:
cls._ENV = None


atexit.register(DuckDBConnectionManager.close_all_connections)
21 changes: 20 additions & 1 deletion dbt/adapters/duckdb/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,20 @@ def to_sql(self) -> str:
return base


@dataclass
class PluginConfig(dbtClassMixin):
# The name that this plugin will be referred to by in sources/models; must
# be unique within the project
name: str

# The fully-specified class name of the plugin code to use, which must be a
# subclass of dbt.adapters.duckdb.plugins.Plugin.
impl: str

# A plugin-specific set of configuration options
config: Optional[Dict[str, Any]] = None


@dataclass
class Remote(dbtClassMixin):
host: str
Expand All @@ -61,7 +75,7 @@ class DuckDBCredentials(Credentials):
# to DuckDB (e.g., if we need to enable using unsigned extensions)
config_options: Optional[Dict[str, Any]] = None

# any extensions we want to install and load (httpfs, parquet, etc.)
# any DuckDB extensions we want to install and load (httpfs, parquet, etc.)
extensions: Optional[Tuple[str, ...]] = None

# any additional pragmas we want to configure on our DuckDB connections;
Expand Down Expand Up @@ -95,6 +109,11 @@ class DuckDBCredentials(Credentials):
# Used to configure remote environments/connections
remote: Optional[Remote] = None

# A list of dbt-duckdb plugins that can be used to customize the
# behavior of loading source data and/or storing the relations that are
# created by SQL or Python models; see the plugins module for more details.
plugins: Optional[List[PluginConfig]] = None

@classmethod
def __pre_deserialize__(cls, data: Dict[Any, Any]) -> Dict[Any, Any]:
data = super().__pre_deserialize__(data)
Expand Down
61 changes: 49 additions & 12 deletions dbt/adapters/duckdb/environments.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import abc
import importlib.util
import os
import tempfile
from typing import Dict

import duckdb

from .credentials import DuckDBCredentials
from .plugins import Plugin
from .utils import SourceConfig
from dbt.contracts.connection import AdapterResponse
from dbt.exceptions import DbtRuntimeError

Expand Down Expand Up @@ -53,18 +57,21 @@ def cursor(self):
return self._cursor


class Environment:
class Environment(abc.ABC):
@abc.abstractmethod
def handle(self):
raise NotImplementedError

def cursor(self):
raise NotImplementedError
pass

@abc.abstractmethod
def submit_python_job(self, handle, parsed_model: dict, compiled_code: str) -> AdapterResponse:
raise NotImplementedError
pass

def get_binding_char(self) -> str:
return "?"

def close(self, cursor):
raise NotImplementedError
@abc.abstractmethod
def load_source(self, plugin_name: str, source_config: SourceConfig) -> str:
pass

@classmethod
def initialize_db(cls, creds: DuckDBCredentials):
Expand Down Expand Up @@ -93,7 +100,7 @@ def initialize_db(cls, creds: DuckDBCredentials):
return conn

@classmethod
def initialize_cursor(cls, creds, cursor):
def initialize_cursor(cls, creds: DuckDBCredentials, cursor):
# Extensions/settings need to be configured per cursor
for ext in creds.extensions or []:
cursor.execute(f"LOAD '{ext}'")
Expand All @@ -103,6 +110,21 @@ def initialize_cursor(cls, creds, cursor):
cursor.execute(f"SET {key} = '{value}'")
return cursor

@classmethod
def initialize_plugins(cls, creds: DuckDBCredentials) -> Dict[str, Plugin]:
ret = {}
for plugin in creds.plugins or []:
if plugin.name in ret:
raise Exception("Duplicate plugin name: " + plugin.name)
else:
if plugin.impl in Plugin.WELL_KNOWN_PLUGINS:
plugin.impl = Plugin.WELL_KNOWN_PLUGINS[plugin.impl]
try:
ret[plugin.name] = Plugin.create(plugin.impl, plugin.config or {})
except Exception as e:
raise Exception(f"Error attempting to create plugin {plugin.name}", e)
return ret

@classmethod
def run_python_job(cls, con, load_df_function, identifier: str, compiled_code: str):
mod_file = tempfile.NamedTemporaryFile(suffix=".py", delete=False)
Expand Down Expand Up @@ -136,13 +158,11 @@ def run_python_job(cls, con, load_df_function, identifier: str, compiled_code: s
finally:
os.unlink(mod_file.name)

def get_binding_char(self) -> str:
return "?"


class LocalEnvironment(Environment):
def __init__(self, credentials: DuckDBCredentials):
self.conn = self.initialize_db(credentials)
self._plugins = self.initialize_plugins(credentials)
self.creds = credentials

def handle(self):
Expand All @@ -159,6 +179,23 @@ def ldf(table_name):
self.run_python_job(con, ldf, parsed_model["alias"], compiled_code)
return AdapterResponse(_message="OK")

def load_source(self, plugin_name: str, source_config: SourceConfig):
if plugin_name not in self._plugins:
raise Exception(
f"Plugin {plugin_name} not found; known plugins are: "
+ ",".join(self._plugins.keys())
)
df = self._plugins[plugin_name].load(source_config)
assert df is not None
handle = self.handle()
cursor = handle.cursor()
materialization = source_config.meta.get("materialization", "table")
cursor.execute(
f"CREATE OR REPLACE {materialization} {source_config.table_name()} AS SELECT * FROM df"
)
cursor.close()
handle.close()

def close(self):
if self.conn:
self.conn.close()
Expand Down
9 changes: 3 additions & 6 deletions dbt/adapters/duckdb/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def use_database(self) -> bool:

@available
def get_binding_char(self):
return DuckDBConnectionManager.ENV.get_binding_char()
return DuckDBConnectionManager.env().get_binding_char()

@available
def external_write_options(self, write_location: str, rendered_options: dict) -> str:
Expand Down Expand Up @@ -144,11 +144,8 @@ def submit_python_job(self, parsed_model: dict, compiled_code: str) -> AdapterRe
connection = self.connections.get_if_exists()
if not connection:
connection = self.connections.get_thread_connection()
if DuckDBConnectionManager.ENV:
env = DuckDBConnectionManager.ENV
return env.submit_python_job(connection.handle, parsed_model, compiled_code)
else:
raise Exception("No ENV defined to execute dbt-duckdb python models!")
env = DuckDBConnectionManager.env()
return env.submit_python_job(connection.handle, parsed_model, compiled_code)

def get_rows_different_sql(
self,
Expand Down
39 changes: 39 additions & 0 deletions dbt/adapters/duckdb/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import abc
import importlib
from typing import Any
from typing import Dict

from ..utils import SourceConfig
from dbt.dataclass_schema import dbtClassMixin


class PluginConfig(dbtClassMixin):
"""A helper class for defining the configuration settings a particular plugin uses."""

pass


class Plugin(abc.ABC):
WELL_KNOWN_PLUGINS = {
"excel": "dbt.adapters.duckdb.plugins.excel.ExcelPlugin",
"gsheet": "dbt.adapters.duckdb.plugins.gsheet.GSheetPlugin",
"iceberg": "dbt.adapters.duckdb.plugins.iceberg.IcebergPlugin",
"sqlalchemy": "dbt.adapters.duckdb.plugins.sqlalchemy.SQLAlchemyPlugin",
}

@classmethod
def create(cls, impl: str, config: Dict[str, Any]) -> "Plugin":
module_name, class_name = impl.rsplit(".", 1)
module = importlib.import_module(module_name)
Class = getattr(module, class_name)
if not issubclass(Class, Plugin):
raise TypeError(f"{impl} is not a subclass of Plugin")
return Class(config)

@abc.abstractmethod
def __init__(self, plugin_config: Dict):
pass

def load(self, source_config: SourceConfig):
"""Load data from a source config and return it as a DataFrame-like object that DuckDB can read."""
raise NotImplementedError
19 changes: 19 additions & 0 deletions dbt/adapters/duckdb/plugins/excel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pathlib
from typing import Dict

import pandas as pd

from . import Plugin
from ..utils import SourceConfig


class ExcelPlugin(Plugin):
def __init__(self, config: Dict):
self._config = config

def load(self, source_config: SourceConfig):
ext_location = source_config.meta["external_location"]
ext_location = ext_location.format(**source_config.as_dict())
source_location = pathlib.Path(ext_location.strip("'"))
sheet_name = source_config.meta.get("sheet_name", 0)
return pd.read_excel(source_location, sheet_name=sheet_name)
Loading

0 comments on commit f665487

Please sign in to comment.