diff --git a/docs/hr/content/docs/06_glossary.md b/docs/hr/content/docs/06_glossary.md index 091591503..56c36c5e5 100644 --- a/docs/hr/content/docs/06_glossary.md +++ b/docs/hr/content/docs/06_glossary.md @@ -11,7 +11,7 @@ tags: | [**`Datalayer`**](07_datalayer_overview.md) | the main class used to connect with SuperDuperDB | | [**`db`**](04_connecting.md) | name, by convention, of the instance of the `Datalayer` built at the beginning of all `superduperdb` programs | | [**`Component`**](09_component_abstraction.md) | the base class which manages meta-data and serialization of all developer-added functionality | -| [**`Predictor`**] | A mixin class for `Component` descendants, to implement predictions | +| [**`Predictor`**](17_supported_ai_frameworks.md) | A mixin class for `Component` descendants, to implement predictions | | [**`Model`**](17_supported_ai_frameworks.md) | the `Component` type responsible for wrapping AI models | | [**`Document`**](10_document_encoder_abstraction.md) | the wrapper around dictionaries which `superduperdb` uses to send data back-and-forth to `db` | | [**`Encoder`**](10_document_encoder_abstraction.md) | the `Component` type responsible for encoding special data-types | diff --git a/docs/hr/content/docs/13_sql_query_API.md b/docs/hr/content/docs/13_sql_query_API.md index 6ce553efa..3aeb13c37 100644 --- a/docs/hr/content/docs/13_sql_query_API.md +++ b/docs/hr/content/docs/13_sql_query_API.md @@ -74,9 +74,9 @@ db.execute( ) ``` -### Coming soon: support for raw-sql +### Support for raw-sql -... the first query above will be equivalent to: +... 
the first query above is equivalent to: ```python db.execute( diff --git a/docs/hr/content/docs/29_developer_vs_production_mode.md b/docs/hr/content/docs/29_developer_vs_production_mode.md index a3d6e3aae..9a8daac79 100644 --- a/docs/hr/content/docs/29_developer_vs_production_mode.md +++ b/docs/hr/content/docs/29_developer_vs_production_mode.md @@ -19,7 +19,7 @@ To set production mode, configure: ```python from superduperdb import CFG -s.CFG.production = True +CFG.mode = 'production' ``` With production mode configured, the system assumes the existence of: diff --git a/docs/hr/content/docs/31_non_blocking_dask_jobs.md b/docs/hr/content/docs/31_non_blocking_dask_jobs.md index 59eac2408..90ba8b897 100644 --- a/docs/hr/content/docs/31_non_blocking_dask_jobs.md +++ b/docs/hr/content/docs/31_non_blocking_dask_jobs.md @@ -12,7 +12,7 @@ To configure this feature, configure: ```python from superduperdb import CFG -CFG.production = True +CFG.mode = 'production' ``` When this is so-configured the following functions push their computations to the `dask` cluster: diff --git a/superduperdb/__init__.py b/superduperdb/__init__.py index 69262449e..0a4215ae4 100644 --- a/superduperdb/__init__.py +++ b/superduperdb/__init__.py @@ -1,6 +1,6 @@ # ruff: noqa: E402 from .base import config, configs, jsonable, logger -from .misc.superduper import superduper +from .base.superduper import superduper ICON = '🔮' CFG = configs.CFG diff --git a/superduperdb/base/datalayer.py b/superduperdb/base/datalayer.py index 488a44524..14246e35d 100644 --- a/superduperdb/base/datalayer.py +++ b/superduperdb/base/datalayer.py @@ -16,6 +16,7 @@ from superduperdb.backends.ibis.query import Table from superduperdb.base import serializable from superduperdb.base.document import Document +from superduperdb.base.superduper import superduper from superduperdb.cdc.cdc import DatabaseChangeDataCapture from superduperdb.components.component import Component from superduperdb.components.encoder import Encodable, 
Encoder @@ -76,7 +77,8 @@ def __init__( :param metadata: metadata object containing connection to Metadatastore :param artifact_store: artifact_store object containing connection to Artifactstore - :param distributed_client: + :param distributed_client: distributed_client object containing connection to + ``dask`` cluster (leave alone) """ logging.info("Building Data Layer") @@ -483,7 +485,7 @@ def update(self, update: Update, refresh: bool = True) -> UpdateResult: def add( self, - object: t.Union[Component, t.Sequence[Component]], + object: t.Union[Component, t.Sequence[Component], t.Any], dependencies: t.Sequence[Job] = (), ): """ @@ -506,9 +508,7 @@ def add( elif isinstance(object, Component): return self._add(object=object, dependencies=dependencies) else: - raise ValueError( - 'object should be a sequence of `Component` or `Component`' - ) + return self._add(superduper(object)) def remove( self, diff --git a/superduperdb/misc/superduper.py b/superduperdb/base/superduper.py similarity index 93% rename from superduperdb/misc/superduper.py rename to superduperdb/base/superduper.py index 67bd86061..831ffafc7 100644 --- a/superduperdb/misc/superduper.py +++ b/superduperdb/base/superduper.py @@ -6,7 +6,7 @@ __all__ = ('superduper',) -def superduper(item: t.Any, **kwargs) -> t.Any: +def superduper(item: t.Optional[t.Any] = None, **kwargs) -> t.Any: """ Attempts to automatically wrap an item in a superduperdb container by using duck typing to recognize it. 
@@ -14,6 +14,11 @@ def superduper(item: t.Any, **kwargs) -> t.Any: :param item: A database or model """ + if item is None: + from superduperdb.base.build import build_datalayer + + return build_datalayer() + if isinstance(item, str): return _auto_identify_connection_string(item, **kwargs) @@ -35,7 +40,7 @@ def _auto_identify_connection_string(item: str, **kwargs) -> t.Any: else: if re.match(r'^[a-zA-Z0-9]+://', item) is None: - raise NotImplementedError(f'{item} is not a valid connection string') + raise ValueError(f'{item} is not a valid connection string') CFG.data_backend = item return build_datalayer(CFG, **kwargs) @@ -48,7 +53,7 @@ class _DuckTyper: def run(item: t.Any, **kwargs) -> t.Any: dts = [dt for dt in _DuckTyper._DUCK_TYPES if dt.accept(item)] if not dts: - raise NotImplementedError( + raise ValueError( f'Couldn\'t auto-identify {item}, please wrap explicitly using ' '``superduperdb.container.*``' ) diff --git a/superduperdb/components/dataset.py b/superduperdb/components/dataset.py index cac3bd4c5..2576a4440 100644 --- a/superduperdb/components/dataset.py +++ b/superduperdb/components/dataset.py @@ -35,8 +35,9 @@ class Dataset(Component): random_seed: t.Optional[int] = None creation_date: t.Optional[str] = None raw_data: t.Optional[t.Union[Artifact, t.Any]] = None - version: t.Optional[int] = None + # Don't set these manually + version: t.Optional[int] = None type_id: t.ClassVar[str] = 'dataset' @override diff --git a/superduperdb/components/encoder.py b/superduperdb/components/encoder.py index a88edeeb1..bbb7ccdc8 100644 --- a/superduperdb/components/encoder.py +++ b/superduperdb/components/encoder.py @@ -37,11 +37,7 @@ class Encoder(Component): `CFG.hybrid` mode """ - type_id: t.ClassVar[str] = 'encoder' - # TODO what's this for? 
- encoders: t.ClassVar[t.List] = [] artifact_artibutes: t.ClassVar[t.Sequence[str]] = ['decoder', 'encoder'] - identifier: str decoder: t.Union[t.Callable, Artifact] = dc.field( default_factory=lambda: Artifact(artifact=_pickle_decoder) @@ -50,9 +46,14 @@ class Encoder(Component): default_factory=lambda: Artifact(artifact=_pickle_encoder) ) shape: t.Optional[t.Sequence] = None - version: t.Optional[int] = None load_hybrid: bool = True + # Don't set this manually + version: t.Optional[int] = None + type_id: t.ClassVar[str] = 'encoder' + # TODO what's this for? + encoders: t.ClassVar[t.List] = [] + def __post_init__(self): self.encoders.append(self.identifier) if isinstance(self.decoder, t.Callable): diff --git a/superduperdb/components/listener.py b/superduperdb/components/listener.py index 617391a67..02b741dce 100644 --- a/superduperdb/components/listener.py +++ b/superduperdb/components/listener.py @@ -34,6 +34,8 @@ class Listener(Component): active: bool = True identifier: t.Optional[str] = None predict_kwargs: t.Optional[t.Dict] = dc.field(default_factory=dict) + + # Don't set this manually version: t.Optional[int] = None type_id: t.ClassVar[str] = 'listener' diff --git a/superduperdb/components/model.py b/superduperdb/components/model.py index 8bc956055..6bab95533 100644 --- a/superduperdb/components/model.py +++ b/superduperdb/components/model.py @@ -3,6 +3,7 @@ import dataclasses as dc import multiprocessing import typing as t +from abc import abstractmethod from functools import wraps import tqdm @@ -43,11 +44,11 @@ class _TrainingConfiguration(Component): :param identifier: Unique identifier of configuration :param **kwargs: Key-values pairs, the variables which configure training. 
- :param version: Version number of the configuration """ identifier: str kwargs: t.Optional[t.Dict] = None + version: t.Optional[int] = None type_id: t.ClassVar[str] = 'training_configuration' @@ -59,18 +60,21 @@ def get(self, k, default=None): return self.kwargs.get(k, default) -class PredictMixin: +class Predictor: """ Mixin class for components which can predict. :param identifier: Unique identifier of model :param encoder: Encoder instance (optional) + :param output_schema: Output schema (mapping of encoders) (optional) + :param flatten: Flatten the model outputs :param preprocess: Preprocess function (optional) :param postprocess: Postprocess function (optional) :param collate_fn: Collate function (optional) :param batch_predict: Whether to batch predict (optional) :param takes_context: Whether the model takes context into account (optional) :param to_call: The method to use for prediction (optional) + :param model_update_kwargs: The kwargs to use for model update (optional) """ identifier: str @@ -80,12 +84,14 @@ class PredictMixin: preprocess: t.Union[t.Callable, Artifact, None] = None postprocess: t.Union[t.Callable, Artifact, None] = None collate_fn: t.Union[t.Callable, Artifact, None] = None - version: t.Optional[int] = None batch_predict: bool takes_context: bool to_call: t.Callable model_update_kwargs: t.Dict + version: t.Optional[int] = None + type_id: t.ClassVar[str] = 'model' + def create_predict_job( self, X: str, @@ -161,7 +167,6 @@ def _predict( elif self.collate_fn is not None: X = self.collate_fn(X) - # TODO: Miss hanlding the context in _forward outputs = self._forward(X, **predict_kwargs) if isinstance(self.postprocess, Artifact): @@ -409,7 +414,7 @@ def _predict_with_select_and_ids( @dc.dataclass -class Model(Component, PredictMixin): +class Model(Component, Predictor): """Model component which wraps a model to become serializable :param identifier: Unique identifier of model @@ -425,7 +430,15 @@ class Model(Component, PredictMixin): :param 
model_to_device_method: The method to transfer the model to a device :param batch_predict: Whether to batch predict (optional) :param takes_context: Whether the model takes context into account (optional) - :param serializer: Serializer to store model to artifact store(optional) + :param train_X: The key of the input data to use for training (optional) + :param train_y: The key of the target data to use for training (optional) + :param training_select: The select to use for training (optional) + :param metric_values: The metric values (optional) + :param training_configuration: The training configuration (optional) + :param model_update_kwargs: The kwargs to use for model update (optional) + :param serializer: Serializer to store model to artifact store (optional) + :param device: The device to use (optional) + :param preferred_devices: The preferred devices to use (optional) """ identifier: str @@ -448,16 +461,14 @@ class Model(Component, PredictMixin): training_configuration: t.Union[str, _TrainingConfiguration, None] = None model_update_kwargs: dict = dc.field(default_factory=dict) serializer: str = 'dill' - - version: t.Optional[int] = None - future: t.Optional[Future] = None device: str = "cpu" - - # TODO: handle situation with multiple GPUs preferred_devices: t.Union[None, t.Sequence[str]] = ("cuda", "mps", "cpu") - artifact_attributes: t.ClassVar[t.Sequence[str]] = ['object'] + # Don't set these manually + future: t.Optional[Future] = None + version: t.Optional[int] = None + artifact_attributes: t.ClassVar[t.Sequence[str]] = ['object'] type_id: t.ClassVar[str] = 'model' def __post_init__(self): @@ -585,6 +596,7 @@ def create_fit_job( }, ) + @abstractmethod def _fit( self, X: t.Any, @@ -597,7 +609,7 @@ def _fit( select: t.Optional[Select] = None, validation_sets: t.Optional[t.Sequence[t.Union[str, Dataset]]] = None, ): - raise NotImplementedError + pass def fit( self, @@ -615,6 +627,17 @@ def fit( ) -> t.Optional[Pipeline]: """ Fit the model on the given 
data. + + :param X: The key of the input data to use for training + :param y: The key of the target data to use for training + :param configuration: The training configuration (optional) + :param data_prefetch: Whether to prefetch the data (optional) + :param db: The datalayer (optional) + :param dependencies: The dependencies (optional) + :param distributed: Whether to distribute the job (optional) + :param metrics: The metrics to evaluate on (optional) + :param select: The select to use for training (optional) + :param validation_sets: The validation ``Dataset`` instances to use (optional) """ if isinstance(select, dict): select = Serializable.deserialize(select) diff --git a/superduperdb/ext/anthropic/model.py b/superduperdb/ext/anthropic/model.py index 2a69c10ba..16ad55b79 100644 --- a/superduperdb/ext/anthropic/model.py +++ b/superduperdb/ext/anthropic/model.py @@ -6,7 +6,7 @@ from superduperdb.components.component import Component from superduperdb.components.encoder import Encoder -from superduperdb.components.model import PredictMixin +from superduperdb.components.model import Predictor from superduperdb.ext.utils import format_prompt, get_key from superduperdb.misc.retry import Retry @@ -18,7 +18,7 @@ @dc.dataclass -class Anthropic(Component, PredictMixin): +class Anthropic(Component, Predictor): """Anthropic predictor. :param model: The model to use, e.g. ``'claude-2'``. 
diff --git a/superduperdb/ext/cohere/model.py b/superduperdb/ext/cohere/model.py index 5d3dcf471..525b477e8 100644 --- a/superduperdb/ext/cohere/model.py +++ b/superduperdb/ext/cohere/model.py @@ -7,7 +7,7 @@ from superduperdb.components.component import Component from superduperdb.components.encoder import Encoder -from superduperdb.components.model import PredictMixin +from superduperdb.components.model import Predictor from superduperdb.components.vector_index import vector from superduperdb.ext.utils import format_prompt, get_key from superduperdb.misc.retry import Retry @@ -18,7 +18,7 @@ @dc.dataclass -class Cohere(Component, PredictMixin): +class Cohere(Component, Predictor): """Cohere predictor :param model: The model to use, e.g. ``'base-light'``. diff --git a/superduperdb/ext/openai/model.py b/superduperdb/ext/openai/model.py index 6dd7ed6be..52804d2d3 100644 --- a/superduperdb/ext/openai/model.py +++ b/superduperdb/ext/openai/model.py @@ -18,7 +18,7 @@ from superduperdb.components.component import Component from superduperdb.components.encoder import Encoder -from superduperdb.components.model import PredictMixin +from superduperdb.components.model import Predictor from superduperdb.components.vector_index import vector from superduperdb.misc.compat import cache from superduperdb.misc.retry import Retry @@ -32,7 +32,7 @@ def _available_models(): @dc.dataclass -class OpenAI(Component, PredictMixin): +class OpenAI(Component, Predictor): """OpenAI predictor. :param model: The model to use, e.g. ``'text-embedding-ada-002'``. 
@@ -49,9 +49,6 @@ class OpenAI(Component, PredictMixin): encoder: t.Union[Encoder, str, None] = None model_update_kwargs: dict = dc.field(default_factory=dict) - #: A unique name for the class - type_id: t.ClassVar[str] = 'model' - @property def child_components(self): if self.encoder is not None: diff --git a/test/conftest.py b/test/conftest.py index ab7251f6e..7cf5e8607 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -10,7 +10,7 @@ import superduperdb as s from superduperdb import logging -from superduperdb.misc import superduper +from superduperdb.base import superduper try: import torch diff --git a/test/unittest/component/test_model.py b/test/unittest/component/test_model.py index f60f9a425..2dad2a329 100644 --- a/test/unittest/component/test_model.py +++ b/test/unittest/component/test_model.py @@ -16,7 +16,7 @@ from superduperdb.components.metric import Metric from superduperdb.components.model import ( Model, - PredictMixin, + Predictor, TrainingConfiguration, _TrainingConfiguration, ) @@ -73,13 +73,13 @@ def mock_forward(self, x, **kwargs): return to_call(x) -class TestModel(Component, PredictMixin): +class TestModel(Component, Predictor): ... 
@pytest.fixture -def predict_mixin(request) -> PredictMixin: - cls_ = getattr(request, 'param', PredictMixin) +def predict_mixin(request) -> Predictor: + cls_ = getattr(request, 'param', Predictor) predict_mixin = cls_() predict_mixin.identifier = 'test' predict_mixin.to_call = to_call @@ -135,7 +135,7 @@ def test_pm_predict_one(predict_mixin): ], ) def test_pm_forward(batch_predict, num_workers, expect_type): - predict_mixin = PredictMixin() + predict_mixin = Predictor() X = np.random.randn(4, 5) predict_mixin.to_call = to_call @@ -146,7 +146,7 @@ def test_pm_forward(batch_predict, num_workers, expect_type): assert np.allclose(output, to_call(X)) -@patch.object(PredictMixin, '_forward', mock_forward) +@patch.object(Predictor, '_forward', mock_forward) def test_pm_core_predict(predict_mixin): X = np.random.randn(4, 5) @@ -286,7 +286,7 @@ def return_value(select_type): assert kwargs.get('ids') == ids_of_missing_outputs -@patch.object(PredictMixin, '_predict') +@patch.object(Predictor, '_predict') def test_pm_predict_with_select_ids(predict_mock, predict_mixin): xs = [np.random.randn(4) for _ in range(10)] ys = [int(random.random() > 0.5) for i in range(10)] diff --git a/test/unittest/misc/test_superduper.py b/test/unittest/misc/test_superduper.py index 891b054d4..0b8b190c3 100644 --- a/test/unittest/misc/test_superduper.py +++ b/test/unittest/misc/test_superduper.py @@ -6,7 +6,7 @@ torch = None from superduperdb import superduper -from superduperdb.misc.superduper import SklearnTyper, TorchTyper +from superduperdb.base.superduper import SklearnTyper, TorchTyper def test_sklearn_typer(): @@ -31,10 +31,10 @@ def test_superduper_model(): def test_superduper_raise(): - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError): superduper(1) - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError): superduper("string")