Skip to content

Commit

Permalink
Add ibis backend db helper to adapt to different databases
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou committed Dec 13, 2023
1 parent b1cdd38 commit 592e8b5
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 39 deletions.
27 changes: 6 additions & 21 deletions superduperdb/backends/ibis/data_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import base64
import typing as t
from warnings import warn

Expand All @@ -7,7 +6,7 @@
from ibis.backends.base import BaseBackend

from superduperdb.backends.base.data_backend import BaseDataBackend
from superduperdb.backends.ibis.db_helper import get_insert_processor
from superduperdb.backends.ibis.db_helper import get_db_helper
from superduperdb.backends.ibis.field_types import FieldType, dtype
from superduperdb.backends.ibis.query import Table
from superduperdb.backends.ibis.utils import get_output_table_name
Expand All @@ -23,6 +22,8 @@ class IbisDataBackend(BaseDataBackend):
def __init__(self, conn: BaseBackend, name: str, in_memory: bool = False):
super().__init__(conn=conn, name=name)
self.in_memory = in_memory
self.dialect = getattr(conn, 'name', 'base')
self.db_helper = get_db_helper(self.dialect)

def url(self):
return self.conn.con.url + self.name
Expand All @@ -39,32 +40,15 @@ def create_ibis_table(self, identifier: str, schema: Schema):
def insert(self, table_name, raw_documents):
for doc in raw_documents:
for k, v in doc.items():
doc[k] = self.convert_data_format(v)
table_name, raw_documents = get_insert_processor(self.conn.name)(
doc[k] = self.db_helper.convert_data_format(v)
table_name, raw_documents = self.db_helper.process_before_insert(
table_name, raw_documents
)
if not self.in_memory:
self.conn.insert(table_name, raw_documents)
else:
self.conn.create_table(table_name, pandas.DataFrame(raw_documents))

@staticmethod
def convert_data_format(data):
"""Convert byte data to base64 format for storage in the database."""

if isinstance(data, bytes):
return BASE64_PREFIX + base64.b64encode(data).decode('utf-8')
else:
return data

@staticmethod
def recover_data_format(data):
"""Recover byte data from base64 format stored in the database."""
if isinstance(data, str) and data.startswith(BASE64_PREFIX):
return base64.b64decode(data[len(BASE64_PREFIX) :])
else:
return data

def create_model_table_or_collection(self, model: t.Union[Model, APIModel]):
msg = (
"Model must have an encoder to create with the"
Expand Down Expand Up @@ -95,6 +79,7 @@ def create_table_and_schema(self, identifier: str, mapping: dict):
"""

try:
mapping = self.db_helper.process_schema_types(mapping)
t = self.conn.create_table(identifier, schema=ibis.schema(mapping))
except Exception as e:
if 'exists' in str(e):
Expand Down
66 changes: 55 additions & 11 deletions superduperdb/backends/ibis/db_helper.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,63 @@
import base64

import pandas as pd

BASE64_PREFIX = 'base64:'


class Base64Mixin:
    """Mixin that stores ``bytes`` values as prefixed base64 strings.

    ``bytes`` values are encoded as ``BASE64_PREFIX`` followed by their
    base64 text before storage and decoded again on retrieval; ``'Bytes'``
    schema entries are mapped to ``'String'`` to match that representation.
    """

    def convert_data_format(self, data):
        """Convert byte data to base64 format for storage in the database.

        :param data: Any value; only ``bytes`` instances are transformed.
        :return: ``BASE64_PREFIX`` + base64 text for ``bytes`` input,
            otherwise ``data`` unchanged.
        """
        if isinstance(data, bytes):
            return BASE64_PREFIX + base64.b64encode(data).decode('utf-8')
        return data

    def recover_data_format(self, data):
        """Recover byte data from base64 format stored in the database.

        :param data: Any value read back from the database.
        :return: The decoded ``bytes`` when ``data`` is a string starting
            with ``BASE64_PREFIX``, otherwise ``data`` unchanged.
        """
        if isinstance(data, str) and data.startswith(BASE64_PREFIX):
            # Strip the marker prefix before decoding the payload.
            return base64.b64decode(data[len(BASE64_PREFIX):])
        return data

    def process_schema_types(self, schema_mapping):
        """Convert bytes to string in the schema.

        A new dict is returned so the caller's mapping is not modified.

        :param schema_mapping: Mapping of column name to type identifier.
        :return: A dict with every ``'Bytes'`` value replaced by ``'String'``.
        """
        return {
            key: 'String' if value == 'Bytes' else value
            for key, value in schema_mapping.items()
        }


class DBHelper:
match_dialect = 'base'

def _default_insert_processor(table_name, datas):
"""Default insert processor for SQL dialects."""
return table_name, datas
def __init__(self, dialect):
self.dialect = dialect

def process_before_insert(self, table_name, datas):
return table_name, pd.DataFrame(datas)

def _clickhouse_insert_processor(table_name, datas):
"""Insert processor for ClickHouse."""
return f'`{table_name}`', pd.DataFrame(datas)
def process_schema_types(self, schema_mapping):
return schema_mapping

def convert_data_format(self, data):
return data

def get_insert_processor(dialect):
def recover_data_format(self, data):
return data


class ClickHouseHelper(Base64Mixin, DBHelper):
    """Database helper selected for the ``'clickhouse'`` dialect.

    Inherits data and schema conversion from ``Base64Mixin`` and overrides
    the pre-insert step to quote the table name with backticks.
    """

    match_dialect = 'clickhouse'

    def process_before_insert(self, table_name, datas):
        """Return the backtick-quoted table name and ``datas`` as a DataFrame."""
        quoted_name = f'`{table_name}`'
        frame = pd.DataFrame(datas)
        return quoted_name, frame


def get_db_helper(dialect) -> DBHelper:
"""Get the database helper for the given dialect."""
funcs = {
'clickhouse': _clickhouse_insert_processor,
}
return funcs.get(dialect, _default_insert_processor)
for helper in DBHelper.__subclasses__():
if helper.match_dialect == dialect:
return helper(dialect)

return DBHelper(dialect)
9 changes: 4 additions & 5 deletions superduperdb/backends/ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import types
import typing as t

import ibis
import pandas

from superduperdb import Document, logging
Expand Down Expand Up @@ -517,7 +516,9 @@ def execute(self, db):
f'{native_query} Wrong query or not supported yet :: {exc}'
)
for column in result.columns:
result[column] = result[column].map(db.databackend.recover_data_format)
result[column] = result[column].map(
db.databackend.db_helper.recover_data_format
)
return result


Expand Down Expand Up @@ -551,9 +552,7 @@ def pre_create(self, db: 'Datalayer'):
return

try:
db.databackend.create_ibis_table( # type: ignore[attr-defined]
self.identifier, schema=ibis.schema(self.schema.raw)
)
db.databackend.create_table_and_schema(self.identifier, self.schema.raw)
except Exception as e:
if 'already exists' in str(e):
pass
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def pre_create(self, db) -> None:
@property
def raw(self):
return {
k: (v.identifier if not isinstance(v, Encoder) else 'String')
k: (v.identifier if not isinstance(v, Encoder) else 'Bytes')
for k, v in self.fields.items()
}

Expand Down
2 changes: 1 addition & 1 deletion test/unittest/test_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ALLOWABLE_DEFECTS = {
'cast': 15, # Try to keep this down
'noqa': 4, # This should never change
'type_ignore': 30, # This should only ever increase in obscure edge cases
'type_ignore': 29, # This should only ever increase in obscure edge cases
}


Expand Down

0 comments on commit 592e8b5

Please sign in to comment.