-
Notifications
You must be signed in to change notification settings - Fork 482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix clickhouse and optimize compatibility when using SQL link #1496
Changes from all commits
5579cd2
8afed9d
eb04771
55964ef
6a47c4c
930cf6d
253dd9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import base64 | ||
|
||
import pandas as pd | ||
|
||
BASE64_PREFIX = 'base64:'


class Base64Mixin:
    """Mixin that round-trips raw ``bytes`` through a base64 text encoding.

    Intended for SQL backends that cannot store arbitrary binary columns:
    bytes go in as ``'base64:<payload>'`` strings and are decoded back to
    ``bytes`` on the way out. Non-bytes values pass through untouched.
    """

    def convert_data_format(self, data):
        """Convert byte data to base64 format for storage in the database."""
        if not isinstance(data, bytes):
            return data
        payload = base64.b64encode(data).decode('utf-8')
        return BASE64_PREFIX + payload

    def recover_data_format(self, data):
        """Recover byte data from base64 format stored in the database."""
        is_encoded = isinstance(data, str) and data.startswith(BASE64_PREFIX)
        if not is_encoded:
            return data
        return base64.b64decode(data[len(BASE64_PREFIX):])

    def process_schema_types(self, schema_mapping):
        """Convert bytes to string in the schema."""
        for key in schema_mapping:
            if schema_mapping[key] == 'Bytes':
                schema_mapping[key] = 'String'
        return schema_mapping
|
||
|
||
class DBHelper:
    """Default, dialect-agnostic database helper.

    Every hook is an identity pass-through; dialect-specific subclasses
    override only the hooks their backend needs.
    """

    # Dialect string used by ``get_db_helper`` to select a subclass.
    match_dialect = 'base'

    def __init__(self, dialect):
        self.dialect = dialect

    def process_before_insert(self, table_name, datas):
        """Prepare records for insertion; returns ``(table_name, DataFrame)``."""
        frame = pd.DataFrame(datas)
        return table_name, frame

    def process_schema_types(self, schema_mapping):
        """Return the schema mapping unchanged."""
        return schema_mapping

    def convert_data_format(self, data):
        """Return the data unchanged (no storage conversion needed)."""
        return data

    def recover_data_format(self, data):
        """Return the data unchanged (no storage conversion needed)."""
        return data
|
||
|
||
class ClickHouseHelper(Base64Mixin, DBHelper):
    """ClickHouse helper: base64-encodes bytes and backtick-quotes table names."""

    match_dialect = 'clickhouse'

    def process_before_insert(self, table_name, datas):
        # Backtick-quote so identifiers containing special characters
        # (e.g. the generated `_outputs_*` names) are accepted by ClickHouse.
        quoted = '`{}`'.format(table_name)
        return quoted, pd.DataFrame(datas)
|
||
|
||
def get_db_helper(dialect) -> DBHelper:
    """Get the insert processor for the given dialect."""
    matched = next(
        (h for h in DBHelper.__subclasses__() if h.match_dialect == dialect),
        None,
    )
    # Unknown dialects fall back to the pass-through base helper.
    if matched is None:
        return DBHelper(dialect)
    return matched(dialect)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,6 @@ | |
import types | ||
import typing as t | ||
|
||
import ibis | ||
import pandas | ||
|
||
from superduperdb import Document, logging | ||
|
@@ -21,6 +20,7 @@ | |
_ReprMixin, | ||
) | ||
from superduperdb.backends.ibis.cursor import SuperDuperIbisResult | ||
from superduperdb.backends.ibis.utils import get_output_table_name | ||
from superduperdb.base.serializable import Variable | ||
from superduperdb.components.component import Component | ||
from superduperdb.components.encoder import Encoder | ||
|
@@ -156,12 +156,9 @@ def _get_all_fields(self, db): | |
for tab in component_tables: | ||
fields_copy = tab.schema.fields.copy() | ||
if '_outputs' in tab.identifier and self.renamings: | ||
model = tab.identifier.split('/')[1] | ||
match = re.search(r"_outputs_(.*?)_(\d+)", tab.identifier) | ||
for k in self.renamings.values(): | ||
if ( | ||
re.match(f'^_outputs/{model}/[0-9]+$', tab.identifier) | ||
is not None | ||
): | ||
if match: | ||
fields_copy[k] = fields_copy['output'] | ||
del fields_copy['output'] | ||
else: | ||
|
@@ -179,7 +176,6 @@ def select_table(self): | |
def _execute_with_pre_like(self, db): | ||
assert self.pre_like is not None | ||
assert self.post_like is None | ||
similar_scores = None | ||
similar_ids, similar_scores = self.pre_like.execute(db) | ||
similar_scores = dict(zip(similar_ids, similar_scores)) | ||
|
||
|
@@ -301,7 +297,7 @@ def model_update( # type: ignore[override] | |
for r in table_records: | ||
if isinstance(r['output'], dict) and '_content' in r['output']: | ||
r['output'] = r['output']['_content']['bytes'] | ||
db.databackend.insert(f'_outputs/{model}/{version}', table_records) | ||
db.databackend.insert(get_output_table_name(model, version), table_records) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same db.databackend.get_table_name(model, version, output=True) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
def add_fold(self, fold: str) -> Select: | ||
if self.query_linker is not None: | ||
|
@@ -444,7 +440,7 @@ def _select_ids_of_missing_outputs( | |
self, key: str, model: str, query_id: str, version: int | ||
): | ||
output_table = IbisQueryTable( | ||
identifier=f'_outputs/{model}/{version}', | ||
identifier=get_output_table_name(model, version), | ||
primary_id='output_id', | ||
) | ||
filtered = output_table.filter( | ||
|
@@ -468,10 +464,10 @@ def _outputs(self, query_id: str, **kwargs): | |
model, version = model.split('/') | ||
symbol_table = IbisQueryTable( | ||
identifier=( | ||
f'_outputs/{model}/{version}' | ||
get_output_table_name(model, version) | ||
if version is not None | ||
else Variable( | ||
f'_outputs/{model}' + '/{version}', | ||
get_output_table_name(model, '{version}'), | ||
lambda db, value: value.format( | ||
version=db.show('model', model)[-1] | ||
), | ||
|
@@ -497,7 +493,7 @@ def _outputs(self, query_id: str, **kwargs): | |
self, self.table_or_collection.primary_id | ||
) | ||
other_query = self.join(symbol_table, symbol_table.input_id == attr) | ||
other_query = other_query.filter(symbol_table.key == key) | ||
other_query = other_query.filter(other_query.key == key) | ||
return other_query | ||
|
||
def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None): | ||
|
@@ -519,6 +515,10 @@ def execute(self, db): | |
raise IbisBackendError( | ||
f'{native_query} Wrong query or not supported yet :: {exc}' | ||
) | ||
for column in result.columns: | ||
result[column] = result[column].map( | ||
db.databackend.db_helper.recover_data_format | ||
) | ||
return result | ||
|
||
|
||
|
@@ -552,9 +552,7 @@ def pre_create(self, db: 'Datalayer'): | |
return | ||
|
||
try: | ||
db.databackend.create_ibis_table( # type: ignore[attr-defined] | ||
self.identifier, schema=ibis.schema(self.schema.raw) | ||
) | ||
db.databackend.create_table_and_schema(self.identifier, self.schema.raw) | ||
except Exception as e: | ||
if 'already exists' in str(e): | ||
pass | ||
|
@@ -666,7 +664,7 @@ def select_ids_of_missing_outputs( | |
self, key: str, model: str, version: int | ||
) -> Select: | ||
output_table = IbisQueryTable( | ||
identifier=f'_outputs/{model}/{version}', | ||
identifier=get_output_table_name(model, version), | ||
primary_id='output_id', | ||
) | ||
query_id = str(hash(self)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
def get_output_table_name(model_identifier, version):
    """Get the output table name for the given model."""
    # Join with underscores rather than '/': some databases reject
    # slashes in table names.
    return '_'.join(['_outputs', str(model_identifier), str(version)])
Comment on lines
+1
to
+4
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some databases do not support using "/" in table name There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good spot There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, we can, but it doesn’t seem to make much sense. Because we only need to mark the model text between the first A more elegant way is to provide a method in the future to handle the table name processing function of a specific database, so we support a lot of databases, and different databases may have different restrictions. WDYT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @blythed @jieguangzhou so for mongodb
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need a db helper to adapt the different databases in the future. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import json | ||
from typing import Tuple | ||
|
||
from sqlalchemy import ( | ||
Boolean, | ||
DateTime, | ||
Integer, | ||
String, | ||
Text, | ||
TypeDecorator, | ||
) | ||
|
||
DEFAULT_LENGTH = 255 | ||
|
||
|
||
class JsonMixin:
    """Mixin for JSON type columns.

    Serializes dicts to JSON strings before saving to the database and
    parses JSON strings back into dicts when loading from the database.
    ``None`` passes through unchanged in both directions.
    """

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return json.loads(value)
|
||
|
||
class JsonAsString(JsonMixin, TypeDecorator):
    # JSON column backed by a bounded VARCHAR; for small payloads.
    impl = String(DEFAULT_LENGTH)
|
||
|
||
class JsonAsText(JsonMixin, TypeDecorator):
    # JSON column backed by an unbounded TEXT column; for large payloads.
    impl = Text
|
||
|
||
class DefaultConfig:
    """Column types and table-creation args used for any SQL dialect by default."""

    type_string = String(DEFAULT_LENGTH)
    type_json_as_string = JsonAsString
    type_json_as_text = JsonAsText
    type_integer = Integer
    type_datetime = DateTime
    type_boolean = Boolean

    # No extra CREATE TABLE arguments are needed for standard dialects;
    # dialect-specific configs (e.g. ClickHouse) override these.
    query_id_table_args: Tuple = tuple()
    job_table_args: Tuple = tuple()
    parent_child_association_table_args: Tuple = tuple()
    component_table_args: Tuple = tuple()
    meta_table_args: Tuple = tuple()
|
||
|
||
def create_clickhouse_config():
    """Build and return the ClickHouse-specific metastore config class.

    The clickhouse_sqlalchemy import is deferred so the package is only
    required when the clickhouse dialect is actually selected.
    """
    try:
        from clickhouse_sqlalchemy import engines, types
    except ImportError:
        raise ImportError(
            'The clickhouse_sqlalchemy package is required to use the '
            'clickhouse dialect. Please install it with pip install '
            'clickhouse-sqlalchemy'
        )

    # Local alias: every table below uses the same engine family.
    merge_tree = engines.MergeTree

    class ClickHouseConfig:
        class JsonAsString(JsonMixin, TypeDecorator):
            impl = types.String

        class JsonAsText(JsonMixin, TypeDecorator):
            impl = types.String

        type_string = types.String
        type_json_as_string = JsonAsString
        type_json_as_text = JsonAsText
        type_integer = types.Int32
        type_datetime = types.DateTime
        type_boolean = types.Boolean

        # ClickHouse requires an engine clause on every CREATE TABLE.
        query_id_table_args = (merge_tree(order_by='query_id'),)
        job_table_args = (merge_tree(order_by='identifier'),)
        parent_child_association_table_args = (merge_tree(order_by='parent_id'),)
        component_table_args = (merge_tree(order_by='id'),)
        meta_table_args = (merge_tree(order_by='key'),)

    return ClickHouseConfig
|
||
|
||
def get_db_config(dialect):
    """Return the metastore config class for the given SQL dialect."""
    if dialect != 'clickhouse':
        return DefaultConfig
    return create_clickhouse_config()
Comment on lines
+91
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If other databases have other table creation behaviors on the metastore in the future, you can add them here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Databackend should have
def get_table_name(..., ouput=True/False):
...
@jieguangzhou
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#1496 (comment)