Skip to content

Commit

Permalink
[PLUGINS] Bump Version [sqlalchemy]
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Feb 6, 2025
1 parent 32680e8 commit 606dc31
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 7 deletions.
23 changes: 21 additions & 2 deletions plugins/sqlalchemy/plugin_test/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from test.utils.database import metadata as metadata_utils

import pytest
from superduper import CFG
from superduper import CFG, model, superduper

from superduper_sqlalchemy.metadata import SQLAlchemyMetadata

DATABASE_URL = CFG.metadata_store or "sqlite:///:memory:"
DATABASE_URL = CFG.metadata_store or "sqlite://"


@pytest.fixture
Expand All @@ -30,3 +30,22 @@ def test_job(metadata):

def test_artifact_relation(metadata):
metadata_utils.test_artifact_relation(metadata)


def test_cleanup_metadata():

db = superduper(DATABASE_URL)

@model
def test(x): return x + 1

db.apply(test, force=True)

assert 'test' in db.show('model'), 'The model was not added to metadata'

db.remove('model', 'test', force=True)

assert not db.show(), 'The metadata was not cleared up'

assert not db.metadata._cache, f'Cache not cleared: {db.metadata._cache}'

2 changes: 1 addition & 1 deletion plugins/sqlalchemy/superduper_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .metadata import SQLAlchemyMetadata as MetaDataStore

__version__ = "0.5.5"
__version__ = "0.5.6"

__all__ = ['MetaDataStore']
29 changes: 25 additions & 4 deletions plugins/sqlalchemy/superduper_sqlalchemy/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def __init__(self):
self._uuid2metadata: t.Dict[str, t.Dict] = {}
self._type_id_identifier2metadata = defaultdict(dict)

def __bool__(self):
return bool(self._uuid2metadata) or bool(self._type_id_identifier2metadata)

def replace_metadata(
self, metadata, uuid=None, type_id=None, version=None, identifier=None
):
Expand All @@ -93,10 +96,18 @@ def expire(self, uuid):
if (type_id, identifier) in self._type_id_identifier2metadata:
del self._type_id_identifier2metadata[(type_id, identifier)]

def expire_identifier(self, type_id, identifier):
def expire_version(self, type_id, identifier, version):
if (type_id, identifier) in self._type_id_identifier2metadata:
del self._type_id_identifier2metadata[(type_id, identifier)]
try:
r = self._type_id_identifier2metadata[(type_id, identifier)][version]
except KeyError:
return

del self._type_id_identifier2metadata[(type_id, identifier)][version]
del self._uuid2metadata[r['uuid']]
if not self._type_id_identifier2metadata[(type_id, identifier)]:
del self._type_id_identifier2metadata[(type_id, identifier)]

def add_metadata(self, metadata):
metadata = copy.deepcopy(metadata)
if 'dict' in metadata:
Expand Down Expand Up @@ -455,7 +466,17 @@ def create_component(
new_info = self._refactor_component_info(info)
with self.session_context(commit=not self.batched) as session:
if not self.batched:
stmt = insert(self.component_table).values(new_info)
primary_key_value = new_info['id']
exists = session.execute(
select(self.component_table).
where(self.component_table.c.id == primary_key_value)
).scalar() is not None
if exists:
return
stmt = (
insert(self.component_table)
.values(new_info)
)
session.execute(stmt)
else:
self._insert_flush['component'].append(copy.deepcopy(new_info))
Expand Down Expand Up @@ -537,7 +558,7 @@ def delete_component_version(self, type_id: str, identifier: str, version: int):
self.component_table.c.id == cv['id']
)
session.execute(stmt_delete)
self._cache.expire_identifier(type_id, identifier)
self._cache.expire_version(type_id, identifier, version)

if cv:
self.delete_parent_child(cv['id'])
Expand Down
4 changes: 4 additions & 0 deletions superduper/backends/base/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def __init__(self, uri: t.Optional[str]):
self.uri: t.Optional[str] = uri
self.queue: t.Dict = defaultdict(lambda: [])

def clear(self):
"""Clear the queue."""
self.queue = defaultdict(lambda: [])

@abstractmethod
def build_consumer(self, **kwargs):
"""Build a consumer instance."""
Expand Down
43 changes: 43 additions & 0 deletions test/integration/usecase/test_reapply.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
import numpy

from superduper import Model
from superduper.components.dataset import Data
from superduper.components.model import model
from superduper.components.schema import Schema
from superduper.components.table import Table
from superduper.components.template import Template
from superduper.components.datatype import pickle_encoder


class MyModel(Model):
Expand Down Expand Up @@ -48,3 +56,38 @@ def build(name, data):
listener_2_update = build('second', '2')

db.apply(listener_2_update)


def test_template_component_deps(db):

@model
def test(x):
return x + 1

test.datatype = pickle_encoder

t = Template(
template=test,
identifier='test_template',
default_tables=[
Table(
'test_table',
schema=Schema('test_schema', fields={'x': 'str', 'y': pickle_encoder}),
data=Data('test_data', raw_data=[{'x': '1', 'y': numpy.random.randn(3)}])
)
]
)

db.apply(t, force=True)

m = t()

db.apply(m, force=True)

db.remove('model', 'test', recursive=True, force=True)

m = t()

db.apply(m, force=True)

db.remove('model', 'test', recursive=True, force=True)

0 comments on commit 606dc31

Please sign in to comment.