diff --git a/superduperdb/backends/local/artifacts.py b/superduperdb/backends/local/artifacts.py index f37fadec8..f0add7310 100644 --- a/superduperdb/backends/local/artifacts.py +++ b/superduperdb/backends/local/artifacts.py @@ -7,7 +7,6 @@ from superduperdb import logging from superduperdb.backends.base.artifact import ArtifactStore -from superduperdb.base import exceptions from superduperdb.misc.colors import Colors @@ -38,51 +37,31 @@ def delete(self, file_id: str): Delete artifact from artifact store :param file_id: File id uses to identify artifact in store """ - try: - os.remove(f'{self.conn}/{file_id}') - except Exception as e: - raise exceptions.ArtifactStoreDeleteException( - f'Error while deleting {file_id}' - ) from e + os.remove(f'{self.conn}/{file_id}') def drop(self, force: bool = False): """ Drop the artifact store. """ - try: - if not force: - if not click.confirm( - f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' - f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' - 'Are you sure you want to drop all artifacts? ', - default=False, - ): - logging.warn('Aborting...') - shutil.rmtree(self.conn, ignore_errors=force) - except Exception as e: - raise exceptions.ArtifactStoreDeleteException( - 'Error while dropping in artifact store' - ) from e + if not force: + if not click.confirm( + f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' + f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' + 'Are you sure you want to drop all artifacts? ', + default=False, + ): + logging.warn('Aborting...') + shutil.rmtree(self.conn, ignore_errors=force) def save_artifact(self, serialized: bytes) -> t.Any: - try: - h = uuid.uuid4().hex - with open(os.path.join(self.conn, h), 'wb') as f: - f.write(serialized) - return h - except Exception as e: - raise exceptions.ArtifactStoreSaveException( - 'Error while saving artifacts' - ) from e + h = uuid.uuid4().hex + with open(os.path.join(self.conn, h), 'wb') as f: + f.write(serialized) + return h def load_bytes(self, file_id: str) -> bytes: - try: - with open(os.path.join(self.conn, file_id), 'rb') as f: - return f.read() - except Exception as e: - raise exceptions.ArtifactStoreLoadException( - 'Error while loading artifacts' - ) from e + with open(os.path.join(self.conn, file_id), 'rb') as f: + return f.read() def disconnect(self): """ diff --git a/superduperdb/backends/mongodb/artifacts.py b/superduperdb/backends/mongodb/artifacts.py index c39a47590..1a457f8de 100644 --- a/superduperdb/backends/mongodb/artifacts.py +++ b/superduperdb/backends/mongodb/artifacts.py @@ -3,7 +3,6 @@ from superduperdb import logging from superduperdb.backends.base.artifact import ArtifactStore -from superduperdb.base import exceptions from superduperdb.misc.colors import Colors @@ -24,44 +23,24 @@ def url(self): return self.conn.HOST + ':' + str(self.conn.PORT) + '/' + self.name def drop(self, force: bool = False): - try: - if not force: - if not click.confirm( - f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' - f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' - 'Are you sure you want to drop all artifacts? ', - default=False, - ): - logging.warn('Aborting...') - return self.db.client.drop_database(self.db.name) - except Exception as e: - raise exceptions.ArtifactStoreDeleteException( - 'Error while dropping in artifact store' - ) from e + if not force: + if not click.confirm( + f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' + f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' + 'Are you sure you want to drop all artifacts? 
', + default=False, + ): + logging.warn('Aborting...') + return self.db.client.drop_database(self.db.name) def delete(self, file_id: str): - try: - return self.filesystem.delete(file_id) - except Exception as e: - raise exceptions.ArtifactStoreDeleteException( - 'Error while dropping in artifact store' - ) from e + return self.filesystem.delete(file_id) def load_bytes(self, file_id: str): - try: - return self.filesystem.get(file_id).read() - except Exception as e: - raise exceptions.ArtifactStoreLoadException( - 'Error while saving artifacts' - ) from e + return self.filesystem.get(file_id).read() def save_artifact(self, serialized: bytes): - try: - return self.filesystem.put(serialized) - except Exception as e: - raise exceptions.ArtifactStoreSaveException( - 'Error while saving artifacts' - ) from e + return self.filesystem.put(serialized) def disconnect(self): """ diff --git a/superduperdb/backends/mongodb/metadata.py b/superduperdb/backends/mongodb/metadata.py index cba275a40..1d15697ae 100644 --- a/superduperdb/backends/mongodb/metadata.py +++ b/superduperdb/backends/mongodb/metadata.py @@ -6,7 +6,6 @@ from superduperdb import logging from superduperdb.backends.base.metadata import MetaDataStore -from superduperdb.base import exceptions from superduperdb.components.component import Component from superduperdb.misc.colors import Colors @@ -37,54 +36,34 @@ def url(self): return self.conn.HOST + ':' + str(self.conn.PORT) + '/' + self.name def drop(self, force: bool = False): - try: - if not force: - if not click.confirm( - f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' - f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' - 'Are you sure you want to drop all meta-data? ', - default=False, - ): - logging.warn('Aborting...') - self.db.drop_collection(self.meta_collection.name) - self.db.drop_collection(self.component_collection.name) - self.db.drop_collection(self.job_collection.name) - self.db.drop_collection(self.parent_child_mappings.name) - except Exception as e: - raise exceptions.MetaDataStoreDeleteException( - 'Error while dropping in metadata store' - ) from e + if not force: + if not click.confirm( + f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' + f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' + 'Are you sure you want to drop all meta-data? 
', + default=False, + ): + logging.warn('Aborting...') + self.db.drop_collection(self.meta_collection.name) + self.db.drop_collection(self.component_collection.name) + self.db.drop_collection(self.job_collection.name) + self.db.drop_collection(self.parent_child_mappings.name) def create_parent_child(self, parent: str, child: str) -> None: - try: - self.parent_child_mappings.insert_one( - { - 'parent': parent, - 'child': child, - } - ) - except Exception as e: - raise exceptions.MetaDataStoreDeleteException( - 'Error while creating parent child' - ) from e + self.parent_child_mappings.insert_one( + { + 'parent': parent, + 'child': child, + } + ) def create_component(self, info: t.Dict) -> InsertOneResult: - try: - if 'hidden' not in info: - info['hidden'] = False - return self.component_collection.insert_one(info) - except Exception as e: - raise exceptions.MetaDataStoreCreateException( - 'Error while creating component in metadata store' - ) from e + if 'hidden' not in info: + info['hidden'] = False + return self.component_collection.insert_one(info) def create_job(self, info: t.Dict) -> InsertOneResult: - try: - return self.job_collection.insert_one(info) - except Exception as e: - raise exceptions.MetaDataStoreJobException( - 'Error while creating job in metadata store' - ) from e + return self.job_collection.insert_one(info) def get_parent_child_relations(self): c = self.parent_child_mappings.find() @@ -94,38 +73,16 @@ def get_component_version_children(self, unique_id: str): return self.parent_child_mappings.distinct('child', {'parent': unique_id}) def get_job(self, identifier: str): - try: - return self.job_collection.find_one({'identifier': identifier}) - except Exception as e: - raise exceptions.MetaDataStoreJobException( - 'Error while getting job in metadata store' - ) from e + return self.job_collection.find_one({'identifier': identifier}) def create_metadata(self, key: str, value: str): - try: - return self.meta_collection.insert_one({'key': key, 'value': value}) - except Exception as e: - raise exceptions.MetaDataStoreCreateException( - 'Error while creating metadata in metadata store' - ) from e + return self.meta_collection.insert_one({'key': key, 'value': value}) def get_metadata(self, key: str): - try: - return self.meta_collection.find_one({'key': key})['value'] - except Exception as e: - raise exceptions.MetadatastoreException( - 'Error while getting metadata in metadata store' - ) from e + return self.meta_collection.find_one({'key': key})['value'] def update_metadata(self, key: str, value: str): - try: - return self.meta_collection.update_one( - {'key': key}, {'$set': {'value': value}} - ) - except Exception as e: - raise exceptions.MetaDataStoreUpdateException( - 'Error while updating metadata in metadata store' - ) from e + return self.meta_collection.update_one({'key': key}, {'$set': {'value': value}}) def get_latest_version( self, type_id: str, identifier: str, allow_hidden: bool = False diff --git a/superduperdb/backends/sqlalchemy/metadata.py b/superduperdb/backends/sqlalchemy/metadata.py index f7c6e93d2..7b0c71257 100644 --- a/superduperdb/backends/sqlalchemy/metadata.py +++ b/superduperdb/backends/sqlalchemy/metadata.py @@ -9,7 +9,6 @@ from superduperdb import logging from superduperdb.backends.base.metadata import MetaDataStore, NonExistentMetadataError -from superduperdb.base import exceptions from superduperdb.base.serializable import Serializable from superduperdb.components.component import Component as _Component from superduperdb.misc.colors import Colors 
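
Note on the caller-facing effect of the store changes above: with the wrapper classes (ArtifactStoreDeleteException, ArtifactStoreLoadException, ArtifactStoreSaveException, MetaDataStore*Exception) removed, the filesystem and MongoDB backends now let the underlying errors — FileNotFoundError/OSError from os.remove, gridfs and pymongo errors — propagate directly to the caller. A minimal sketch of that contract, assuming the local store class is FileSystemArtifactStore and that its constructor accepts the target directory as `conn` (both inferred from how `self.conn` is used in the hunk above, not shown in this diff):

    # Sketch only: class name and constructor usage are assumptions, not part of this PR.
    import tempfile

    from superduperdb.backends.local.artifacts import FileSystemArtifactStore

    store = FileSystemArtifactStore(conn=tempfile.mkdtemp())

    file_id = store.save_artifact(b'example payload')   # writes <conn>/<uuid hex>
    assert store.load_bytes(file_id) == b'example payload'

    store.delete(file_id)
    try:
        store.delete(file_id)        # second delete: the file is already gone
    except FileNotFoundError:
        # With the wrapper exceptions removed, the plain OSError subclass
        # raised by os.remove() is what callers now see.
        pass
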
@@ -114,20 +113,15 @@ def drop(self, force: bool = False): """ Drop the metadata store. """ - try: - if not force: - if not click.confirm( - f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' - f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' - 'Are you sure you want to drop all meta-data? ', - default=False, - ): - logging.warn('Aborting...') - Base.metadata.drop_all(self.conn) - except Exception as e: - raise exceptions.MetaDataStoreDeleteException( - 'Error while dropping in metadata store' - ) from e + if not force: + if not click.confirm( + f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' + f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' + 'Are you sure you want to drop all meta-data? ', + default=False, + ): + logging.warn('Aborting...') + Base.metadata.drop_all(self.conn) @contextmanager def session_context(self): @@ -159,28 +153,16 @@ def component_version_has_parents( ) def create_component(self, info: t.Dict): - try: - if 'hidden' not in info: - info['hidden'] = False - info['id'] = f'{info["type_id"]}/{info["identifier"]}/{info["version"]}' - with self.session_context() as session: - session.add(Component(**info)) - except Exception as e: - raise exceptions.MetaDataStoreCreateException( - 'Error while creating component in metadata store' - ) from e + if 'hidden' not in info: + info['hidden'] = False + info['id'] = f'{info["type_id"]}/{info["identifier"]}/{info["version"]}' + with self.session_context() as session: + session.add(Component(**info)) def create_parent_child(self, parent_id: str, child_id: str): - try: - with self.session_context() as session: - association = ParentChildAssociation( - parent_id=parent_id, child_id=child_id - ) - session.add(association) - except Exception as e: - raise exceptions.MetaDataStoreCreateException( - 'Error while creating parent child in metadata store' - ) from e + with self.session_context() as session: + association = ParentChildAssociation(parent_id=parent_id, child_id=child_id) + session.add(association) def delete_component_version(self, type_id: str, identifier: str, version: int): with self.session_context() as session: @@ -336,22 +318,12 @@ def _update_object( # --------------- JOBS ----------------- def create_job(self, info: t.Dict): - try: - with self.session_context() as session: - session.add(Job(**info)) - except Exception as e: - raise exceptions.MetaDataStoreJobException( - 'Error while creating job in metadata store' - ) from e + with self.session_context() as session: + session.add(Job(**info)) def get_job(self, job_id: str): - try: - with self.session_context() as session: - return session.query(Job).filter(Job.identifier == job_id).first() - except Exception as e: - raise exceptions.MetaDataStoreJobException( - 'Error while getting job in metadata store' - ) from e + with self.session_context() as session: + return session.query(Job).filter(Job.identifier == job_id).first() def listen_job(self, identifier: str): # Not supported currently @@ -362,13 +334,8 @@ def show_jobs(self): return [j.identifier for j in session.query(Job).all()] def update_job(self, job_id: str, key: str, value: t.Any): - try: - with self.session_context() as session: - session.query(Job).filter(Job.identifier == job_id).update({key: value}) - except Exception as e: - raise exceptions.MetaDataStoreJobException( - 'Error while updating job in metadata store' - ) from e + with self.session_context() as session: + session.query(Job).filter(Job.identifier == job_id).update({key: value}) def write_output_to_job(self, identifier, msg, stream): # Not supported 
currently @@ -377,31 +344,16 @@ def write_output_to_job(self, identifier, msg, stream): # --------------- METADATA ----------------- def create_metadata(self, key, value): - try: - with self.session_context() as session: - session.add(Meta(key=key, value=value)) - except Exception as e: - raise exceptions.MetaDataStoreCreateException( - 'Error while creating metadata in metadata store' - ) from e + with self.session_context() as session: + session.add(Meta(key=key, value=value)) def get_metadata(self, key): - try: - with self.session_context() as session: - return session.query(Meta).filter(Meta.key == key).first().value - except Exception as e: - raise exceptions.MetadatastoreException( - 'Error while getting metadata in metadata store' - ) from e + with self.session_context() as session: + return session.query(Meta).filter(Meta.key == key).first().value def update_metadata(self, key, value): - try: - with self.session_context() as session: - session.query(Meta).filter(Meta.key == key).update({Meta.value: value}) - except Exception as e: - raise exceptions.MetaDataStoreUpdateException( - 'Error while updating metadata in metadata store' - ) from e + with self.session_context() as session: + session.query(Meta).filter(Meta.key == key).update({Meta.value: value}) # --------------- Query ID ----------------- def add_query(self, query: 'Select', model: str): diff --git a/superduperdb/base/datalayer.py b/superduperdb/base/datalayer.py index d240eacd6..7a5bc05eb 100644 --- a/superduperdb/base/datalayer.py +++ b/superduperdb/base/datalayer.py @@ -17,8 +17,7 @@ from superduperdb.backends.base.data_backend import BaseDataBackend from superduperdb.backends.base.metadata import MetaDataStore from superduperdb.backends.base.query import Delete, Insert, RawQuery, Select, Update -from superduperdb.backends.ibis.data_backend import IbisDataBackend -from superduperdb.backends.ibis.query import Table, RawSQL +from superduperdb.backends.ibis.query import Table from superduperdb.backends.local.compute import LocalComputeBackend from superduperdb.base import exceptions, serializable from superduperdb.base.cursor import SuperDuperCursor @@ -121,36 +120,31 @@ def server_mode(self, is_server: bool): def initialize_vector_searcher( self, identifier, searcher_type: t.Optional[str] = None, backfill=False ) -> BaseVectorSearcher: - try: - searcher_type = searcher_type or s.CFG.vector_search - logging.info(f"loading of vectors of vector-index: '{identifier}'") - vi = self.vector_indices[identifier] + searcher_type = searcher_type or s.CFG.vector_search + logging.info(f"loading of vectors of vector-index: '{identifier}'") + vi = self.vector_indices[identifier] - clt = vi.indexing_listener.select.table_or_collection + clt = vi.indexing_listener.select.table_or_collection - if self.cdc.running: - msg = 'CDC only supported for vector search via lance format' - assert s.CFG.vector_search == 'lance', msg + if self.cdc.running: + msg = 'CDC only supported for vector search via lance format' + assert s.CFG.vector_search == 'lance', msg - vector_search_cls = vector_searcher_implementations[searcher_type] - vector_comparison = vector_search_cls( - identifier=vi.identifier, - dimensions=vi.dimensions, - measure=vi.measure, - ) - assert isinstance(clt.identifier, str), 'clt.identifier must be a string' + vector_search_cls = vector_searcher_implementations[searcher_type] + vector_comparison = vector_search_cls( + identifier=vi.identifier, + dimensions=vi.dimensions, + measure=vi.measure, + ) + assert isinstance(clt.identifier, 
str), 'clt.identifier must be a string' - if self.cdc.running: - # In this case, loading has already happened on disk via CDC mechanism - return vector_comparison - if backfill or s.CFG.cluster.vector_search is None: - self.backfill_vector_search(vi, vector_comparison) + if self.cdc.running: + # In this case, loading has already happened on disk via CDC mechanism + return vector_comparison + if backfill or s.CFG.cluster.vector_search is None: + self.backfill_vector_search(vi, vector_comparison) - return FastVectorSearcher(self, vector_comparison, vi.identifier) - except Exception as e: - raise exceptions.VectorSearchException( - f"Failed to initialize vector search index '{identifier}'" - ) from e + return FastVectorSearcher(self, vector_comparison, vi.identifier) def backfill_vector_search(self, vi, searcher): if vi.indexing_listener.select is None: @@ -226,22 +220,9 @@ def drop(self, force: bool = False): ): logging.warn('Aborting...') - try: - self.databackend.drop(force=True) - except Exception as e: - raise exceptions.DatabackendException('Failed to drop data backend') from e - try: - self.metadata.drop(force=True) - except Exception as e: - raise exceptions.MetadatastoreException( - 'Failed to drop metadata store' - ) from e - try: - self.artifact_store.drop(force=True) - except Exception as e: - raise exceptions.ArtifactStoreException( - 'Failed to drop artifact store' - ) from e + self.databackend.drop(force=True) + self.metadata.drop(force=True) + self.artifact_store.drop(force=True) def validate( self, @@ -336,21 +317,15 @@ async def apredict( if context_select is not None: context = self._get_context(model, context_select, context_key) - try: - out = await model.apredict( - input.unpack() if isinstance(input, Document) else input, - one=True, - context=context, - **kwargs, - ) - except Exception as e: - raise exceptions.ModelException('Error while model async predict') from e + out = await model.apredict( + input.unpack() if isinstance(input, Document) else input, + one=True, + context=context, + **kwargs, + ) - try: - if model.encoder is not None: - out = model.encoder(out) - except Exception as e: - raise exceptions.EncoderException('Error while encoding model') from e + if model.encoder is not None: + out = model.encoder(out) if context is not None: return Document(out), [Document(x) for x in context] @@ -379,21 +354,15 @@ def predict( if context_select is not None: context = self._get_context(model, context_select, context_key) - try: - out = model.predict( - input.unpack() if isinstance(input, Document) else input, - one=True, - context=context, - **kwargs, - ) - except Exception as e: - raise exceptions.ModelException('Error while model predict') from e + out = model.predict( + input.unpack() if isinstance(input, Document) else input, + one=True, + context=context, + **kwargs, + ) - try: - if model.encoder is not None: - out = model.encoder(out) - except Exception as e: - raise exceptions.EncoderException('Error while encoding model') from e + if model.encoder is not None: + out = model.encoder(out) if context is not None: return Document(out), [Document(x) for x in context] @@ -406,27 +375,18 @@ def execute(self, query: ExecuteQuery, *args, **kwargs) -> ExecuteResult: :param query: select, insert, delete, update, """ - if isinstance(query, str): - assert isinstance(self.databackend, IbisDataBackend) - query = RawSQL(query) - - try: - if isinstance(query, Delete): - return self.delete(query, *args, **kwargs) - if isinstance(query, Insert): - return self.insert(query, 
*args, **kwargs) - if isinstance(query, Select): - return self.select(query, *args, **kwargs) - if isinstance(query, Table): - return self.select(query.to_query(), *args, **kwargs) - if isinstance(query, Update): - return self.update(query, *args, **kwargs) - if isinstance(query, RawQuery): - return query.execute(self) - except Exception as e: - breakpoint() - QueryExceptionCls = exceptions.query_exceptions(query) - raise QueryExceptionCls(f"Error while executing {str(query)} query") from e + if isinstance(query, Delete): + return self.delete(query, *args, **kwargs) + if isinstance(query, Insert): + return self.insert(query, *args, **kwargs) + if isinstance(query, Select): + return self.select(query, *args, **kwargs) + if isinstance(query, Table): + return self.select(query.to_query(), *args, **kwargs) + if isinstance(query, Update): + return self.update(query, *args, **kwargs) + if isinstance(query, RawQuery): + return query.execute(self) raise TypeError( f'Wrong type of {query}; ' @@ -493,21 +453,13 @@ def refresh_after_delete( :param ids: ids which reduce scopy of computations :param verbose: Toggle to ``True`` to get more output """ - try: - task_workflow: TaskWorkflow = self._build_delete_task_workflow( - query, - ids=ids, - verbose=verbose, - ) - except Exception as e: - raise exceptions.TaskWorkflowException( - 'Error while building task workflow' - ) from e - try: - task_workflow.run_jobs() - return task_workflow - except Exception as e: - raise exceptions.JobException('Error while running task workflow') from e + task_workflow: TaskWorkflow = self._build_delete_task_workflow( + query, + ids=ids, + verbose=verbose, + ) + task_workflow.run_jobs() + return task_workflow def refresh_after_update_or_insert( self, @@ -522,20 +474,12 @@ def refresh_after_update_or_insert( :param ids: ids which reduce scopy of computations :param verbose: Toggle to ``True`` to get more output """ - try: - task_workflow: TaskWorkflow = self._build_task_workflow( - query.select_table, # TODO can be replaced by select_using_ids - ids=ids, - verbose=verbose, - ) - except Exception as e: - raise exceptions.TaskWorkflowException( - 'Error while building task workflow' - ) from e - try: - task_workflow.run_jobs() - except Exception as e: - raise exceptions.JobException('Error while running job') from e + task_workflow: TaskWorkflow = self._build_task_workflow( + query.select_table, # TODO can be replaced by select_using_ids + ids=ids, + verbose=verbose, + ) + task_workflow.run_jobs() return task_workflow def update(self, update: Update, refresh: bool = True) -> UpdateResult: @@ -659,52 +603,56 @@ def load( components :param info_only: toggle to ``True`` to return metadata only """ - try: - info = self.metadata.get_component( - type_id=type_id, - identifier=identifier, - version=version, - allow_hidden=allow_hidden, - ) - - if info is None: - raise Exception( - f'No such object of type "{type_id}", ' - f'"{identifier}" has been registered.' - ) - - if info_only: - return info + info = self.metadata.get_component( + type_id=type_id, + identifier=identifier, + version=version, + allow_hidden=allow_hidden, + ) - def get_children(info): - return { - k: v - for k, v in info['dict'].items() - if isinstance(v, dict) and serializable.is_component_metadata(v) - } + if info is None: + raise exceptions.MetadataException( + f'No such object of type "{type_id}", ' + f'"{identifier}" has been registered.' 
+ ) - def replace_children(r): - if isinstance(r, dict): - children = get_children(r) - for k, v in children.items(): - r['dict'][k] = replace_children( - self.metadata.get_component(**v, allow_hidden=True) + if info_only: + return info + + def get_children(info): + return { + k: v + for k, v in info['dict'].items() + if isinstance(v, dict) and serializable.is_component_metadata(v) + } + + def replace_children(r): + if isinstance(r, dict): + children = get_children(r) + for k, v in children.items(): + c = replace_children( + self.metadata.get_component(**v, allow_hidden=True) + ) + try: + r['dict'][k] = c + except KeyError: + raise exceptions.MetadataException( + 'Children {k} not found in `dict`' ) - return r + return r - info = replace_children(info) - info = self.artifact_store.load(info, lazy=True) + info = replace_children(info) + info = self.artifact_store.load(info, lazy=True) - m = Component.deserialize(info) - m.on_load(self) + m = Component.deserialize(info) + m.on_load(self) - if cm := self.type_id_to_cache_mapping.get(type_id): + if cm := self.type_id_to_cache_mapping.get(type_id): + try: getattr(self, cm)[m.identifier] = m - return m - except Exception as e: - raise exceptions.ComponentException( - f'Error while loading {type_id} with id {identifier}' - ) from e + except KeyError: + raise exceptions.ComponentException('%s not found in %s cache'.format()) + return m def _build_delete_task_workflow( self, @@ -907,42 +855,42 @@ def _add( parent: t.Optional[str] = None, ): jobs = [] - try: - object.pre_create(self) - assert hasattr(object, 'identifier') - assert hasattr(object, 'version') + object.pre_create(self) + assert hasattr(object, 'identifier') + assert hasattr(object, 'version') - existing_versions = self.show(object.type_id, object.identifier) - if isinstance(object.version, int) and object.version in existing_versions: - s.logging.debug(f'{object.unique_id} already exists - doing nothing') - return [], object + existing_versions = self.show(object.type_id, object.identifier) + if isinstance(object.version, int) and object.version in existing_versions: + s.logging.debug(f'{object.unique_id} already exists - doing nothing') + return [], object - if existing_versions: - object.version = max(existing_versions) + 1 - else: - object.version = 0 + if existing_versions: + object.version = max(existing_versions) + 1 + else: + object.version = 0 - if serialized is None: - serialized, artifacts = object.serialized - artifact_info = self.artifact_store.save(artifacts) - serialized = self.artifact_store.replace(serialized, artifact_info) + if serialized is None: + serialized, artifacts = object.serialized + artifact_info = self.artifact_store.save(artifacts) + serialized = self.artifact_store.replace(serialized, artifact_info) - else: + else: + try: serialized['version'] = object.version serialized['dict']['version'] = object.version + except KeyError: + raise exceptions.MetadataException( + '`dict` or `version` not found in serialized dict.' 
+ ) - jobs.extend(self._create_children(object, serialized)) - self.metadata.create_component(serialized) - if parent is not None: - self.metadata.create_parent_child(parent, object.unique_id) + jobs.extend(self._create_children(object, serialized)) + self.metadata.create_component(serialized) + if parent is not None: + self.metadata.create_parent_child(parent, object.unique_id) - object.post_create(self) - object.on_load(self) - return object.schedule_jobs(self, dependencies=dependencies), object - except Exception as e: - raise exceptions.DatalayerException( - f'Error while adding object with id: {object.identifier}' - ) from e + object.post_create(self) + object.on_load(self) + return object.schedule_jobs(self, dependencies=dependencies), object def _create_children(self, component: Component, serialized: t.Dict): jobs = [] @@ -1208,30 +1156,25 @@ def replace( :param upsert: toggle to ``True`` to enable even if object doesn't exist yet """ try: - try: - info = self.metadata.get_component( - object.type_id, object.identifier, version=object.version - ) - except FileNotFoundError as e: - if upsert: - return self.add( - object, - ) - raise exceptions.FileNotFoundException(str(e)) from e - - # If object has no version, update the last version - object.version = info['version'] - new_info = self.artifact_store.update(object, metadata_info=info) - self.metadata.replace_object( - new_info, - identifier=object.identifier, - type_id='model', - version=object.version, + info = self.metadata.get_component( + object.type_id, object.identifier, version=object.version ) - except Exception as e: - raise exceptions.ComponentException( - f'Error while replacing component {object.identifier}' - ) from e + except FileNotFoundError: + if upsert: + return self.add( + object, + ) + raise FileNotFoundError + + # If object has no version, update the last version + object.version = info['version'] + new_info = self.artifact_store.update(object, metadata_info=info) + self.metadata.replace_object( + new_info, + identifier=object.identifier, + type_id='model', + version=object.version, + ) def select_nearest( self, @@ -1241,20 +1184,15 @@ def select_nearest( outputs: t.Optional[Document] = None, n: int = 100, ) -> t.Tuple[t.List[str], t.List[float]]: - try: - like = self._get_content_for_filter(like) - vi = self.vector_indices[vector_index] - if outputs is None: - outs = {} - else: - outs = outputs.encode() - if not isinstance(outs, dict): - raise TypeError(f'Expected dict, got {type(outputs)}') - return vi.get_nearest(like, db=self, ids=ids, n=n, outputs=outs) - except Exception as e: - raise exceptions.VectorSearchException( - f"Error while vector search on index '{vector_index}'" - ) from e + like = self._get_content_for_filter(like) + vi = self.vector_indices[vector_index] + if outputs is None: + outs = {} + else: + outs = outputs.encode() + if not isinstance(outs, dict): + raise TypeError(f'Expected dict, got {type(outputs)}') + return vi.get_nearest(like, db=self, ids=ids, n=n, outputs=outs) def close(self): """ diff --git a/superduperdb/base/exceptions.py b/superduperdb/base/exceptions.py index 6c898df45..99b7526d2 100644 --- a/superduperdb/base/exceptions.py +++ b/superduperdb/base/exceptions.py @@ -23,33 +23,9 @@ def __str__(self): return self.msg -class ComponentException(BaseException): - ''' - ComponentException - ''' - - -class ComponentAddException(ComponentException): - ''' - ComponentAddException - ''' - - -class ComponentReplaceException(ComponentException): - ''' - ComponentReplaceException - ''' 
- - -class ComponentLoadException(ComponentException): +class QueryException(BaseException): ''' - ComponentLoadException - ''' - - -class DatabaseConnectionException(BaseException): - ''' - DatabackendException + QueryException ''' @@ -59,176 +35,13 @@ class DatabackendException(BaseException): ''' -class MetadatastoreException(BaseException): - ''' - MetadatastoreException - ''' - - -class ArtifactStoreException(BaseException): - ''' - ArtifactStoreException - ''' - - -class ArtifactStoreDeleteException(ArtifactStoreException): - ''' - ArtifactStoreException - ''' - - -class ArtifactStoreLoadException(ArtifactStoreException): - ''' - ArtifactStoreException - ''' - - -class ArtifactStoreSaveException(ArtifactStoreException): - ''' - ArtifactStoreException - ''' - - -class DatalayerException(BaseException): - ''' - DatalayerException - ''' - - -class FileNotFoundException(BaseException): +class MetadataException(BaseException): ''' - FileNotFoundException + MetadataException ''' -class ServiceRequestException(BaseException): - ''' - ServiceException - ''' - - -class ModelException(BaseException): - ''' - ModelException - ''' - - -class VectorSearchException(ComponentException): - ''' - VectorSearchException - ''' - - -class EncoderException(ComponentException): - ''' - EncoderException - ''' - - -class QueryException(ComponentException): - ''' - QueryException - ''' - - -class SelectQueryException(QueryException): - ''' - SelectQueryException - ''' - - -class DeleteQueryException(QueryException): - ''' - DeleteQueryException - ''' - - -class InsertQueryException(QueryException): - ''' - InsertQueryException - ''' - - -class UpdateQueryException(QueryException): - ''' - UpdateQueryException - ''' - - -class TableQueryException(QueryException): - ''' - TableQueryException - ''' - - -class RawQueryException(QueryException): - ''' - RawQueryException - ''' - - -class JobException(ComponentException): - ''' - JobException - ''' - - -class TaskWorkflowException(ComponentException): - ''' - TaskWorkflowException - ''' - - -class MetaDataStoreDeleteException(MetadatastoreException): - ''' - MetaDataStoreDeleteException - ''' - - -class MetaDataStoreJobException(MetadatastoreException): - ''' - MetaDataStoreJobException - ''' - - -class MetaDataStoreCreateException(MetadatastoreException): - ''' - MetaDataStoreCreateException - ''' - - -class MetaDataStoreUpdateException(MetadatastoreException): - ''' - MetaDataStoreUpdateException - ''' - - -class ModelPredictException(ModelException): - ''' - ModelPredictException - ''' - - -class ModelFitException(ModelException): +class ComponentException(BaseException): ''' - ModelFitException + ComponentException ''' - - -_query_exceptions = { - 'Delete': DeleteQueryException, - 'Update': UpdateQueryException, - 'Table': TableQueryException, - 'Insert': InsertQueryException, - 'Select': SelectQueryException, - 'RawQuery': RawQueryException, -} - - -def query_exceptions(query): - query = str(query) - for k, v in _query_exceptions.items(): - if k in query: - return v - else: - return QueryException diff --git a/superduperdb/components/encoder.py b/superduperdb/components/encoder.py index 1e15e0688..bbb7ccdc8 100644 --- a/superduperdb/components/encoder.py +++ b/superduperdb/components/encoder.py @@ -3,7 +3,6 @@ import pickle import typing as t -from superduperdb.base import exceptions from superduperdb.base.artifact import Artifact from superduperdb.components.component import Component @@ -68,14 +67,8 @@ def __call__( return Encodable(self, x=x, uri=uri) def 
decode(self, b: bytes) -> t.Any: - try: - assert isinstance(self.decoder, Artifact) - return self(self.decoder.artifact(b)) - except Exception as e: - raise exceptions.EncoderException( - f'Error while decoding bytes \ - Encoder: {self.identifier} Shape: {self.shape}' - ) from e + assert isinstance(self.decoder, Artifact) + return self(self.decoder.artifact(b)) def dump(self, other): return self.encoder.artifact(other) @@ -86,37 +79,32 @@ def encode( uri: t.Optional[str] = None, wrap: bool = True, ) -> t.Union[t.Optional[str], t.Dict[str, t.Any]]: - try: - # TODO clarify what is going on here - def _wrap_content(x): - return { - '_content': { - 'bytes': self.encoder.artifact(x), - 'encoder': self.identifier, - } + # TODO clarify what is going on here + def _wrap_content(x): + return { + '_content': { + 'bytes': self.encoder.artifact(x), + 'encoder': self.identifier, } + } - if self.encoder is not None: - if x is not None: - if wrap: - return _wrap_content(x) - return self.encoder.artifact(x) # type: ignore[union-attr] - else: - if wrap: - return { - '_content': { - 'uri': uri, - 'encoder': self.identifier, - } - } - return uri + if self.encoder is not None: + if x is not None: + if wrap: + return _wrap_content(x) + return self.encoder.artifact(x) # type: ignore[union-attr] else: - assert x is not None - return x - except Exception as e: - raise exceptions.EncoderException( - f'Error while encoding x Encoder: {self.identifier} Shape: {self.shape}' - ) from e + if wrap: + return { + '_content': { + 'uri': uri, + 'encoder': self.identifier, + } + } + return uri + else: + assert x is not None + return x @dc.dataclass diff --git a/superduperdb/components/model.py b/superduperdb/components/model.py index ba317a0e6..53028b8eb 100644 --- a/superduperdb/components/model.py +++ b/superduperdb/components/model.py @@ -16,7 +16,6 @@ from superduperdb.backends.ibis.field_types import FieldType from superduperdb.backends.ibis.query import IbisCompoundSelect, Table from superduperdb.backends.query_dataset import QueryDataset -from superduperdb.base import exceptions from superduperdb.base.artifact import Artifact from superduperdb.base.serializable import Serializable from superduperdb.components.component import Component @@ -201,89 +200,74 @@ def predict( overwrite: bool = False, **kwargs, ) -> t.Any: - predict_type = 'batch prediction' + if one: + assert db is None, 'db must be None when ``one=True`` (direct call)' - try: - if one: - assert db is None, 'db must be None when ``one=True`` (direct call)' - - if isinstance(select, dict): - select = Serializable.deserialize(select) - - if isinstance(select, Table): - select = select.to_query() - - if db is not None: - if isinstance(select, IbisCompoundSelect): - try: - _ = db.metadata.get_query(str(hash(select))) - except NonExistentMetadataError: - logging.info(f'Query {select} not found in metadata, adding...') - db.metadata.add_query(select, self.identifier) - logging.info('Done') - logging.info(f'Adding model {self.identifier} to db') - assert isinstance(self, Component) - db.add(self) - - if listen: + if isinstance(select, dict): + select = Serializable.deserialize(select) + + if isinstance(select, Table): + select = select.to_query() + + if db is not None: + if isinstance(select, IbisCompoundSelect): + try: + _ = db.metadata.get_query(str(hash(select))) + except NonExistentMetadataError: + logging.info(f'Query {select} not found in metadata, adding...') + db.metadata.add_query(select, self.identifier) + logging.info('Done') + logging.info(f'Adding 
model {self.identifier} to db') + assert isinstance(self, Component) + db.add(self) + + if listen: + assert db is not None + assert select is not None + return self._predict_and_listen( + X=X, + db=db, + select=select, + max_chunk_size=max_chunk_size, + **kwargs, + ) + + if db is not None and db.compute.type == 'distributed': + return self.create_predict_job( + X, + select=select, + ids=ids, + max_chunk_size=max_chunk_size, + overwrite=overwrite, + **kwargs, + )(db=db, dependencies=dependencies) + else: + if select is not None and ids is None: assert db is not None - assert select is not None - return self._predict_and_listen( + return self._predict_with_select( X=X, - db=db, select=select, + db=db, + in_memory=in_memory, max_chunk_size=max_chunk_size, + overwrite=overwrite, **kwargs, ) - - if db is not None and db.compute.type == 'distributed': - return self.create_predict_job( - X, + elif select is not None and ids is not None: + assert db is not None + return self._predict_with_select_and_ids( + X=X, select=select, ids=ids, + db=db, max_chunk_size=max_chunk_size, - overwrite=overwrite, + in_memory=in_memory, **kwargs, - )(db=db, dependencies=dependencies) + ) else: - if select is not None and ids is None: - assert db is not None - return self._predict_with_select( - X=X, - select=select, - db=db, - in_memory=in_memory, - max_chunk_size=max_chunk_size, - overwrite=overwrite, - **kwargs, - ) - elif select is not None and ids is not None: - assert db is not None - return self._predict_with_select_and_ids( - X=X, - select=select, - ids=ids, - db=db, - max_chunk_size=max_chunk_size, - in_memory=in_memory, - **kwargs, - ) - else: - predict_type = 'single prediction' - if self.takes_context: - kwargs['context'] = context - return self._predict(X, one=one, **kwargs) - except Exception as e: - select_collection = None - if select: - assert isinstance(select, Select) - select_collection = select.table_or_collection.identifier - - raise exceptions.ModelPredictException( - f'Error while model prediction\ - \ninfo: predict type: {predict_type}, \ - Collection: {select_collection}, Model: {self.identifier}' - ) from e + if self.takes_context: + kwargs['context'] = context + return self._predict(X, one=one, **kwargs) async def apredict( self, @@ -292,14 +276,9 @@ async def apredict( one: bool = False, **kwargs, ): - try: - if self.takes_context: - kwargs['context'] = context - return await self._apredict(X, one=one, **kwargs) - except Exception as e: - raise exceptions.ModelPredictException( - f'Error while async model prediction Model: {self.identifier}' - ) from e + if self.takes_context: + kwargs['context'] = context + return await self._apredict(X, one=one, **kwargs) def _predict_and_listen( self, @@ -660,46 +639,38 @@ def fit( :param select: The select to use for training (optional) :param validation_sets: The validation ``Dataset`` instances to use (optional) """ - try: - if isinstance(select, dict): - select = Serializable.deserialize(select) - - if validation_sets: - validation_sets = list(validation_sets) - for i, vs in enumerate(validation_sets): - if isinstance(vs, Dataset): - assert db is not None - db.add(vs) - validation_sets[i] = vs.identifier - - if db is not None: - db.add(self) - - if db is not None and db.compute.type == 'distributed': - return self.create_fit_job( - X, - select=select, - y=y, - **kwargs, - )(db=db, dependencies=dependencies) - else: - return self._fit( - X, - y=y, - configuration=configuration, - data_prefetch=data_prefetch, - db=db, - metrics=metrics, - 
select=select, - validation_sets=validation_sets, - **kwargs, - ) - except Exception: - assert isinstance(select, Select) - raise exceptions.ModelFitException( - f"Error while model fit\ - \ninfo: Collection:{select.table_or_collection.identifier} \ - Model:{self.identifier}" + if isinstance(select, dict): + select = Serializable.deserialize(select) + + if validation_sets: + validation_sets = list(validation_sets) + for i, vs in enumerate(validation_sets): + if isinstance(vs, Dataset): + assert db is not None + db.add(vs) + validation_sets[i] = vs.identifier + + if db is not None: + db.add(self) + + if db is not None and db.compute.type == 'distributed': + return self.create_fit_job( + X, + select=select, + y=y, + **kwargs, + )(db=db, dependencies=dependencies) + else: + return self._fit( + X, + y=y, + configuration=configuration, + data_prefetch=data_prefetch, + db=db, + metrics=metrics, + select=select, + validation_sets=validation_sets, + **kwargs, ) diff --git a/superduperdb/jobs/job.py b/superduperdb/jobs/job.py index 54b750011..eae91393c 100644 --- a/superduperdb/jobs/job.py +++ b/superduperdb/jobs/job.py @@ -4,7 +4,6 @@ from abc import abstractmethod import superduperdb as s -from superduperdb.base import exceptions from superduperdb.jobs.tasks import callable_job, method_job @@ -115,21 +114,18 @@ def submit(self, dependencies=()): :param dependencies: list of dependencies """ - try: - self.future = self.db.compute.submit( - callable_job, - cfg=s.CFG, - function_to_call=self.callable, - job_id=self.identifier, - args=self.args, - kwargs=self.kwargs, - dependencies=dependencies, - db=self.db if self.db.compute.type == 'local' else None, - local=self.db.compute.type == 'local', - ) - - except Exception as e: - raise exceptions.JobException('Error while submitting job') from e + self.future = self.db.compute.submit( + callable_job, + cfg=s.CFG, + function_to_call=self.callable, + job_id=self.identifier, + args=self.args, + kwargs=self.kwargs, + dependencies=dependencies, + db=self.db if self.db.compute.type == 'local' else None, + local=self.db.compute.type == 'local', + ) + return def __call__(self, db: t.Any = None, dependencies=()): @@ -185,22 +181,19 @@ def submit(self, dependencies=()): Submit job for execution :param dependencies: list of dependencies """ - try: - self.future = self.db.compute.submit( - method_job, - cfg=s.CFG, - type_id=self.type_id, - identifier=self.component_identifier, - method_name=self.method_name, - job_id=self.identifier, - args=self.args, - kwargs=self.kwargs, - dependencies=dependencies, - db=self.db if self.db.compute.type == 'local' else None, - local=self.db.compute.type == 'local', - ) - except Exception as e: - raise exceptions.JobException('Error while submitting job') from e + self.future = self.db.compute.submit( + method_job, + cfg=s.CFG, + type_id=self.type_id, + identifier=self.component_identifier, + method_name=self.method_name, + job_id=self.identifier, + args=self.args, + kwargs=self.kwargs, + dependencies=dependencies, + db=self.db if self.db.compute.type == 'local' else None, + local=self.db.compute.type == 'local', + ) return def __call__(self, db: t.Any = None, dependencies=()): diff --git a/superduperdb/jobs/task_workflow.py b/superduperdb/jobs/task_workflow.py index 92541b46f..bbeabc82c 100644 --- a/superduperdb/jobs/task_workflow.py +++ b/superduperdb/jobs/task_workflow.py @@ -7,8 +7,6 @@ import networkx from networkx import DiGraph, ancestors -from superduperdb.base import exceptions - from .job import ComponentJob, FunctionJob, 
Job if t.TYPE_CHECKING: @@ -53,16 +51,10 @@ def run_jobs( for node in current_group: job: Job = self.G.nodes[node]['job'] dependencies = [self.G.nodes[a]['job'].future for a in pred(node)] - try: - job( - self.database, - dependencies=dependencies, - ) - except Exception as e: - raise exceptions.TaskWorkflowException( - f'Error while running job {job} with \ - dependencies {dependencies}' - ) from e + job( + self.database, + dependencies=dependencies, + ) done.add(node) current_group = [ diff --git a/superduperdb/misc/download.py b/superduperdb/misc/download.py index 6a6ed62d9..735077dfb 100644 --- a/superduperdb/misc/download.py +++ b/superduperdb/misc/download.py @@ -15,7 +15,6 @@ from superduperdb import CFG, logging from superduperdb.backends.base.query import Insert, Select -from superduperdb.base import exceptions from superduperdb.base.document import Document from superduperdb.base.serializable import Serializable @@ -363,19 +362,19 @@ def download_content( if n_download_workers is None: try: n_download_workers = db.metadata.get_metadata(key='n_download_workers') - except exceptions.MetadatastoreException: + except TypeError: n_download_workers = 0 if headers is None: try: headers = db.metadata.get_metadata(key='headers') - except exceptions.MetadatastoreException: + except TypeError: pass if timeout is None: try: timeout = db.metadata.get_metadata(key='download_timeout') - except exceptions.MetadatastoreException: + except TypeError: pass if CFG.hybrid_storage: diff --git a/superduperdb/vector_search/interface.py b/superduperdb/vector_search/interface.py index 758751b5c..a6a950ae2 100644 --- a/superduperdb/vector_search/interface.py +++ b/superduperdb/vector_search/interface.py @@ -3,7 +3,6 @@ import numpy as np from superduperdb import CFG -from superduperdb.base import exceptions from superduperdb.misc.server import request_server from superduperdb.vector_search.base import BaseVectorSearcher, VectorItem @@ -36,29 +35,19 @@ def add(self, items: t.Sequence[VectorItem]) -> None: :param items: t.Sequence of VectorItems """ - try: - vector_items = [{'vector': i.vector, 'id': i.id} for i in items] - if CFG.cluster.vector_search: - request_server( - service='vector_search', - data=vector_items, - endpoint='add/search', - args={ - 'vector_index': self.vector_index, - }, - ) - return - - return self.searcher.add(items) - except Exception as e: - local_msg = ( - 'remote vector search service' - if CFG.cluster.vector_search - else 'local vector search' + vector_items = [{'vector': i.vector, 'id': i.id} for i in items] + if CFG.cluster.vector_search: + request_server( + service='vector_search', + data=vector_items, + endpoint='add/search', + args={ + 'vector_index': self.vector_index, + }, ) - raise exceptions.VectorSearchException( - f'Error while adding vector to {local_msg}' - ) from e + return + + return self.searcher.add(items) def delete(self, ids: t.Sequence[str]) -> None: """ @@ -66,28 +55,18 @@ def delete(self, ids: t.Sequence[str]) -> None: :param ids: t.Sequence of ids of vectors. 
""" - try: - if CFG.cluster.vector_search: - request_server( - service='vector_search', - data=ids, - endpoint='delete/search', - args={ - 'vector_index': self.vector_index, - }, - ) - return - - return self.searcher.delete(ids) - except Exception as e: - local_msg = ( - 'remote vector search service' - if CFG.cluster.vector_search - else 'local vector search' + if CFG.cluster.vector_search: + request_server( + service='vector_search', + data=ids, + endpoint='delete/search', + args={ + 'vector_index': self.vector_index, + }, ) - raise exceptions.VectorSearchException( - f'Error while deleting ids {ids} from {local_msg}' - ) from e + return + + return self.searcher.delete(ids) def find_nearest_from_id( self, @@ -101,29 +80,15 @@ def find_nearest_from_id( :param _id: id of the vector :param n: number of nearest vectors to return """ - try: - if CFG.cluster.vector_search: - response = request_server( - service='vector_search', - endpoint='query/id/search', - args={'vector_index': self.vector_index, 'n': n, 'id': _id}, - ) - return response['ids'], response['scores'] - - return self.searcher.find_nearest_from_id(_id, n=n, within_ids=within_ids) - except Exception as e: - local_msg = ( - 'remote vector search service' - if CFG.cluster.vector_search - else 'local vector search' + if CFG.cluster.vector_search: + response = request_server( + service='vector_search', + endpoint='query/id/search', + args={'vector_index': self.vector_index, 'n': n, 'id': _id}, ) - raise exceptions.VectorSearchException( - f'Error while finding nearest array from {local_msg} \n\ - The problem might be either wrong id {_id} provided or vector database \ - is empty (Not initialized properly), check if model/listener outputs \ - are successfully computed, generally have issues \ - when computes is distributed to dask for example. ' - ) from e + return response['ids'], response['scores'] + + return self.searcher.find_nearest_from_id(_id, n=n, within_ids=within_ids) def find_nearest_from_array( self, @@ -137,29 +102,13 @@ def find_nearest_from_array( :param h: vector :param n: number of nearest vectors to return """ - try: - if CFG.cluster.vector_search: - response = request_server( - service='vector_search', - data=h, - endpoint='query/search', - args={'vector_index': self.vector_index, 'n': n}, - ) - return response['ids'], response['scores'] - - return self.searcher.find_nearest_from_array( - h=h, n=n, within_ids=within_ids - ) - except Exception as e: - local_msg = ( - 'remote vector search service' - if CFG.cluster.vector_search - else 'local vector search' + if CFG.cluster.vector_search: + response = request_server( + service='vector_search', + data=h, + endpoint='query/search', + args={'vector_index': self.vector_index, 'n': n}, ) - raise exceptions.VectorSearchException( - f'Error while finding nearest array from {local_msg} \n\ - The problem might be either wrong vector provided or vector database \ - is empty (Not initialized properly), check if model/listener outputs \ - are successfully computed, generally have issues \ - when computes is distributed to dask for example. 
' - ) from e + return response['ids'], response['scores'] + + return self.searcher.find_nearest_from_array(h=h, n=n, within_ids=within_ids) diff --git a/test/unittest/base/test_datalayer.py b/test/unittest/base/test_datalayer.py index 7720d8e68..eaac83ccc 100644 --- a/test/unittest/base/test_datalayer.py +++ b/test/unittest/base/test_datalayer.py @@ -19,8 +19,7 @@ from superduperdb.backends.ibis.query import Table from superduperdb.backends.mongodb.data_backend import MongoDataBackend from superduperdb.backends.mongodb.query import Collection -from superduperdb.base import exceptions -from superduperdb.base.artifact import Artifact +from superduperdb.base.artifact import Artifact, ArtifactSavingError from superduperdb.base.datalayer import Datalayer from superduperdb.base.document import Document from superduperdb.base.exceptions import ComponentInUseError, ComponentInUseWarning @@ -145,7 +144,7 @@ def test_add_version(db): def test_add_component_with_bad_artifact(db): artifact = Artifact({'data': lambda x: x}, serializer='pickle') component = TestComponent(identifier='test', artifact=artifact) - with pytest.raises(exceptions.DatalayerException): + with pytest.raises(ArtifactSavingError): db.add(component) @@ -177,7 +176,7 @@ def test_add_child(db): assert parents == [component.unique_id] component_2 = TestComponent(identifier='test-2', child='child-2') - with pytest.raises(exceptions.DatalayerException): + with pytest.raises(FileNotFoundError): db.add(component_2) child_component_2 = TestComponent(identifier='child-2') @@ -279,7 +278,7 @@ def test_remove_component_from_data_layer_dict(db): test_encoder = Encoder(identifier='test_encoder', version=0) db.add(test_encoder) db._remove_component_version('encoder', 'test_encoder', 0, force=True) - with pytest.raises(exceptions.ComponentException): + with pytest.raises(FileNotFoundError): db.encoders['test_encoder'] @@ -339,7 +338,7 @@ def test_remove_multi_version(db): "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_not_exist_component(db): - with pytest.raises(exceptions.ComponentException) as e: + with pytest.raises(FileNotFoundError) as e: db.remove('test-component', 'test', 0, force=True) assert 'test' in str(e)
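
The updated tests spell out the new caller-facing contract of this refactor: instead of the removed wrapper hierarchy, code now catches the concrete errors (FileNotFoundError, ArtifactSavingError) plus the few classes kept in superduperdb.base.exceptions (QueryException, DatabackendException, MetadataException, ComponentException). A rough sketch of application code under that contract — the helper name and the use of the 'encoder' type_id are illustrative, and the Datalayer setup is elided; `db` is assumed to be an already-connected instance:

    # Sketch only, based on the test changes above and the new Datalayer.load() body.
    from superduperdb.base import exceptions


    def load_encoder_or_none(db, identifier: str):
        """Return a registered encoder, or None if it is not in the metadata store."""
        try:
            return db.load('encoder', identifier)
        except FileNotFoundError:
            # Raised by the metadata backends themselves now that load() no
            # longer re-wraps everything in ComponentException.
            return None
        except exceptions.MetadataException:
            # Raised by Datalayer.load() when no such component is registered.
            return None
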