diff --git a/plugins/sqlalchemy/plugin_test/test_metadata.py b/plugins/sqlalchemy/plugin_test/test_metadata.py index 8327ae766..9b16e495e 100644 --- a/plugins/sqlalchemy/plugin_test/test_metadata.py +++ b/plugins/sqlalchemy/plugin_test/test_metadata.py @@ -1,11 +1,11 @@ from test.utils.database import metadata as metadata_utils import pytest -from superduper import CFG +from superduper import CFG, model, superduper from superduper_sqlalchemy.metadata import SQLAlchemyMetadata -DATABASE_URL = CFG.metadata_store or "sqlite:///:memory:" +DATABASE_URL = CFG.metadata_store or "sqlite://" @pytest.fixture @@ -30,3 +30,22 @@ def test_job(metadata): def test_artifact_relation(metadata): metadata_utils.test_artifact_relation(metadata) + + +def test_cleanup_metadata(): + + db = superduper(DATABASE_URL) + + @model + def test(x): return x + 1 + + db.apply(test, force=True) + + assert 'test' in db.show('model'), 'The model was not added to metadata' + + db.remove('model', 'test', force=True) + + assert not db.show(), 'The metadata was not cleared up' + + assert not db.metadata._cache, f'Cache not cleared: {db.metadata._cache}' + diff --git a/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py b/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py index 6235a1560..304e088d4 100644 --- a/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py +++ b/plugins/sqlalchemy/superduper_sqlalchemy/__init__.py @@ -1,5 +1,5 @@ from .metadata import SQLAlchemyMetadata as MetaDataStore -__version__ = "0.5.5" +__version__ = "0.5.6" __all__ = ['MetaDataStore'] diff --git a/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py b/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py index 31427b807..c610b7f3d 100644 --- a/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py +++ b/plugins/sqlalchemy/superduper_sqlalchemy/metadata.py @@ -67,6 +67,9 @@ def __init__(self): self._uuid2metadata: t.Dict[str, t.Dict] = {} self._type_id_identifier2metadata = defaultdict(dict) + def __bool__(self): + return bool(self._uuid2metadata) or bool(self._type_id_identifier2metadata) + def replace_metadata( self, metadata, uuid=None, type_id=None, version=None, identifier=None ): @@ -93,10 +96,18 @@ def expire(self, uuid): if (type_id, identifier) in self._type_id_identifier2metadata: del self._type_id_identifier2metadata[(type_id, identifier)] - def expire_identifier(self, type_id, identifier): + def expire_version(self, type_id, identifier, version): if (type_id, identifier) in self._type_id_identifier2metadata: - del self._type_id_identifier2metadata[(type_id, identifier)] + try: + r = self._type_id_identifier2metadata[(type_id, identifier)][version] + except KeyError: + return + del self._type_id_identifier2metadata[(type_id, identifier)][version] + del self._uuid2metadata[r['uuid']] + if not self._type_id_identifier2metadata[(type_id, identifier)]: + del self._type_id_identifier2metadata[(type_id, identifier)] + def add_metadata(self, metadata): metadata = copy.deepcopy(metadata) if 'dict' in metadata: @@ -455,7 +466,17 @@ def create_component( new_info = self._refactor_component_info(info) with self.session_context(commit=not self.batched) as session: if not self.batched: - stmt = insert(self.component_table).values(new_info) + primary_key_value = new_info['id'] + exists = session.execute( + select(self.component_table). + where(self.component_table.c.id == primary_key_value) + ).scalar() is not None + if exists: + return + stmt = ( + insert(self.component_table) + .values(new_info) + ) session.execute(stmt) else: self._insert_flush['component'].append(copy.deepcopy(new_info)) @@ -537,7 +558,7 @@ def delete_component_version(self, type_id: str, identifier: str, version: int): self.component_table.c.id == cv['id'] ) session.execute(stmt_delete) - self._cache.expire_identifier(type_id, identifier) + self._cache.expire_version(type_id, identifier, version) if cv: self.delete_parent_child(cv['id']) diff --git a/superduper/backends/base/queue.py b/superduper/backends/base/queue.py index bce62e2a9..92dd6a95a 100644 --- a/superduper/backends/base/queue.py +++ b/superduper/backends/base/queue.py @@ -85,6 +85,10 @@ def __init__(self, uri: t.Optional[str]): self.uri: t.Optional[str] = uri self.queue: t.Dict = defaultdict(lambda: []) + def clear(self): + """Clear the queue.""" + self.queue = defaultdict(lambda: []) + @abstractmethod def build_consumer(self, **kwargs): """Build a consumer instance.""" diff --git a/test/integration/usecase/test_reapply.py b/test/integration/usecase/test_reapply.py index 76ba47311..9e48aacf7 100644 --- a/test/integration/usecase/test_reapply.py +++ b/test/integration/usecase/test_reapply.py @@ -1,4 +1,12 @@ +import numpy + from superduper import Model +from superduper.components.dataset import Data +from superduper.components.model import model +from superduper.components.schema import Schema +from superduper.components.table import Table +from superduper.components.template import Template +from superduper.components.datatype import pickle_encoder class MyModel(Model): @@ -48,3 +56,38 @@ def build(name, data): listener_2_update = build('second', '2') db.apply(listener_2_update) + + +def test_template_component_deps(db): + + @model + def test(x): + return x + 1 + + test.datatype = pickle_encoder + + t = Template( + template=test, + identifier='test_template', + default_tables=[ + Table( + 'test_table', + schema=Schema('test_schema', fields={'x': 'str', 'y': pickle_encoder}), + data=Data('test_data', raw_data=[{'x': '1', 'y': numpy.random.randn(3)}]) + ) + ] + ) + + db.apply(t, force=True) + + m = t() + + db.apply(m, force=True) + + db.remove('model', 'test', recursive=True, force=True) + + m = t() + + db.apply(m, force=True) + + db.remove('model', 'test', recursive=True, force=True) \ No newline at end of file