Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make trivial fixes to respect current docs #1349

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/hr/content/docs/06_glossary.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ tags:
| [**`Datalayer`**](07_datalayer_overview.md) | the main class used to connect with SuperDuperDB |
| [**`db`**](04_connecting.md) | name, by convention, of the instance of the `Datalayer` built at the beginning of all `superduperdb` programs |
| [**`Component`**](09_component_abstraction.md) | the base class which manages meta-data and serialization of all developer-added functionality |
| [**`Predictor`**] | A mixin class for `Component` descendants, to implement predictions |
| [**`Predictor`**](17_supported_ai_frameworks.md) | A mixin class for `Component` descendants, to implement predictions |
| [**`Model`**](17_supported_ai_frameworks.md) | the `Component` type responsible for wrapping AI models |
| [**`Document`**](10_document_encoder_abstraction.md) | the wrapper around dictionaries which `superduperdb` uses to send data back-and-forth to `db` |
| [**`Encoder`**](10_document_encoder_abstraction.md) | the `Component` type responsible for encoding special data-types |
Expand Down
4 changes: 2 additions & 2 deletions docs/hr/content/docs/13_sql_query_API.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ db.execute(
)
```

### Coming soon: support for raw-sql
### Support for raw-sql

... the first query above will be equivalent to:
... the first query above is equivalent to:

```python
db.execute(
Expand Down
2 changes: 1 addition & 1 deletion docs/hr/content/docs/29_developer_vs_production_mode.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ To set production mode, configure:
```python
from superduperdb import CFG

s.CFG.production = True
s.CFG.mode = 'production'
```

With production mode configured, the system assumes the existence of:
Expand Down
2 changes: 1 addition & 1 deletion docs/hr/content/docs/31_non_blocking_dask_jobs.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ To configure this feature, configure:
```python
from superduperdb import CFG

CFG.production = True
CFG.mode = 'production'
```

When this is so-configured the following functions push their computations to the `dask` cluster:
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# ruff: noqa: E402
from .base import config, configs, jsonable, logger
from .misc.superduper import superduper
from .base.superduper import superduper

ICON = '🔮'
CFG = configs.CFG
Expand Down
10 changes: 5 additions & 5 deletions superduperdb/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from superduperdb.backends.ibis.query import Table
from superduperdb.base import serializable
from superduperdb.base.document import Document
from superduperdb.base.superduper import superduper
from superduperdb.cdc.cdc import DatabaseChangeDataCapture
from superduperdb.components.component import Component
from superduperdb.components.encoder import Encodable, Encoder
Expand Down Expand Up @@ -76,7 +77,8 @@ def __init__(
:param metadata: metadata object containing connection to Metadatastore
:param artifact_store: artifact_store object containing connection to
Artifactstore
:param distributed_client:
:param distributed_client: distributed_client object containing connection to
``dask`` cluster (leave alone)
"""
logging.info("Building Data Layer")

Expand Down Expand Up @@ -483,7 +485,7 @@ def update(self, update: Update, refresh: bool = True) -> UpdateResult:

def add(
self,
object: t.Union[Component, t.Sequence[Component]],
object: t.Union[Component, t.Sequence[Component], t.Any],
dependencies: t.Sequence[Job] = (),
):
"""
Expand All @@ -506,9 +508,7 @@ def add(
elif isinstance(object, Component):
return self._add(object=object, dependencies=dependencies)
else:
raise ValueError(
'object should be a sequence of `Component` or `Component`'
)
return self._add(superduper(object))

def remove(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
__all__ = ('superduper',)


def superduper(item: t.Any, **kwargs) -> t.Any:
def superduper(item: t.Optional[t.Any] = None, **kwargs) -> t.Any:
"""
Attempts to automatically wrap an item in a superduperdb container by
using duck typing to recognize it.

:param item: A database or model
"""

if item is None:
from superduperdb.base.build import build_datalayer

return build_datalayer()

if isinstance(item, str):
return _auto_identify_connection_string(item, **kwargs)

Expand All @@ -35,7 +40,7 @@ def _auto_identify_connection_string(item: str, **kwargs) -> t.Any:

else:
if re.match(r'^[a-zA-Z0-9]+://', item) is None:
raise NotImplementedError(f'{item} is not a valid connection string')
raise ValueError(f'{item} is not a valid connection string')
CFG.data_backend = item
return build_datalayer(CFG, **kwargs)

Expand All @@ -48,7 +53,7 @@ class _DuckTyper:
def run(item: t.Any, **kwargs) -> t.Any:
dts = [dt for dt in _DuckTyper._DUCK_TYPES if dt.accept(item)]
if not dts:
raise NotImplementedError(
raise ValueError(
f'Couldn\'t auto-identify {item}, please wrap explicitly using '
'``superduperdb.container.*``'
)
Expand Down
3 changes: 2 additions & 1 deletion superduperdb/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ class Dataset(Component):
random_seed: t.Optional[int] = None
creation_date: t.Optional[str] = None
raw_data: t.Optional[t.Union[Artifact, t.Any]] = None
version: t.Optional[int] = None

# Don't set these manually
version: t.Optional[int] = None
type_id: t.ClassVar[str] = 'dataset'

@override
Expand Down
11 changes: 6 additions & 5 deletions superduperdb/components/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ class Encoder(Component):
`CFG.hybrid` mode
"""

type_id: t.ClassVar[str] = 'encoder'
# TODO what's this for?
encoders: t.ClassVar[t.List] = []
artifact_artibutes: t.ClassVar[t.Sequence[str]] = ['decoder', 'encoder']

identifier: str
decoder: t.Union[t.Callable, Artifact] = dc.field(
default_factory=lambda: Artifact(artifact=_pickle_decoder)
Expand All @@ -50,9 +46,14 @@ class Encoder(Component):
default_factory=lambda: Artifact(artifact=_pickle_encoder)
)
shape: t.Optional[t.Sequence] = None
version: t.Optional[int] = None
load_hybrid: bool = True

# Don't set this manually
version: t.Optional[int] = None
type_id: t.ClassVar[str] = 'encoder'
# TODO what's this for?
encoders: t.ClassVar[t.List] = []

def __post_init__(self):
self.encoders.append(self.identifier)
if isinstance(self.decoder, t.Callable):
Expand Down
2 changes: 2 additions & 0 deletions superduperdb/components/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Listener(Component):
active: bool = True
identifier: t.Optional[str] = None
predict_kwargs: t.Optional[t.Dict] = dc.field(default_factory=dict)

# Don't set this manually
version: t.Optional[int] = None

type_id: t.ClassVar[str] = 'listener'
Expand Down
49 changes: 36 additions & 13 deletions superduperdb/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dataclasses as dc
import multiprocessing
import typing as t
from abc import abstractmethod
from functools import wraps

import tqdm
Expand Down Expand Up @@ -43,11 +44,11 @@ class _TrainingConfiguration(Component):

:param identifier: Unique identifier of configuration
:param **kwargs: Key-values pairs, the variables which configure training.
:param version: Version number of the configuration
"""

identifier: str
kwargs: t.Optional[t.Dict] = None

version: t.Optional[int] = None

type_id: t.ClassVar[str] = 'training_configuration'
Expand All @@ -59,18 +60,21 @@ def get(self, k, default=None):
return self.kwargs.get(k, default)


class PredictMixin:
class Predictor:
"""
Mixin class for components which can predict.

:param identifier: Unique identifier of model
:param encoder: Encoder instance (optional)
:param output_schema: Output schema (mapping of encoders) (optional)
:param flatten: Flatten the model outputs
:param preprocess: Preprocess function (optional)
:param postprocess: Postprocess function (optional)
:param collate_fn: Collate function (optional)
:param batch_predict: Whether to batch predict (optional)
:param takes_context: Whether the model takes context into account (optional)
:param to_call: The method to use for prediction (optional)
:param model_update_kwargs: The kwargs to use for model update (optional)
"""

identifier: str
Expand All @@ -80,12 +84,14 @@ class PredictMixin:
preprocess: t.Union[t.Callable, Artifact, None] = None
postprocess: t.Union[t.Callable, Artifact, None] = None
collate_fn: t.Union[t.Callable, Artifact, None] = None
version: t.Optional[int] = None
batch_predict: bool
takes_context: bool
to_call: t.Callable
model_update_kwargs: t.Dict

version: t.Optional[int] = None
type_id: t.ClassVar[str] = 'model'

def create_predict_job(
self,
X: str,
Expand Down Expand Up @@ -161,7 +167,6 @@ def _predict(
elif self.collate_fn is not None:
X = self.collate_fn(X)

# TODO: Miss hanlding the context in _forward
outputs = self._forward(X, **predict_kwargs)

if isinstance(self.postprocess, Artifact):
Expand Down Expand Up @@ -409,7 +414,7 @@ def _predict_with_select_and_ids(


@dc.dataclass
class Model(Component, PredictMixin):
class Model(Component, Predictor):
"""Model component which wraps a model to become serializable

:param identifier: Unique identifier of model
Expand All @@ -425,7 +430,15 @@ class Model(Component, PredictMixin):
:param model_to_device_method: The method to transfer the model to a device
:param batch_predict: Whether to batch predict (optional)
:param takes_context: Whether the model takes context into account (optional)
:param serializer: Serializer to store model to artifact store(optional)
:param train_X: The key of the input data to use for training (optional)
:param train_y: The key of the target data to use for training (optional)
:param training_select: The select to use for training (optional)
:param metric_values: The metric values (optional)
:param training_configuration: The training configuration (optional)
:param model_update_kwargs: The kwargs to use for model update (optional)
:param serializer: Serializer to store model to artifact store (optional)
:param device: The device to use (optional)
:param preferred_devices: The preferred devices to use (optional)
"""

identifier: str
Expand All @@ -448,16 +461,14 @@ class Model(Component, PredictMixin):
training_configuration: t.Union[str, _TrainingConfiguration, None] = None
model_update_kwargs: dict = dc.field(default_factory=dict)
serializer: str = 'dill'

version: t.Optional[int] = None
future: t.Optional[Future] = None
device: str = "cpu"

# TODO: handle situation with multiple GPUs
preferred_devices: t.Union[None, t.Sequence[str]] = ("cuda", "mps", "cpu")

artifact_attributes: t.ClassVar[t.Sequence[str]] = ['object']
# Don't set these manually
future: t.Optional[Future] = None
version: t.Optional[int] = None

artifact_attributes: t.ClassVar[t.Sequence[str]] = ['object']
type_id: t.ClassVar[str] = 'model'

def __post_init__(self):
Expand Down Expand Up @@ -585,6 +596,7 @@ def create_fit_job(
},
)

@abstractmethod
def _fit(
self,
X: t.Any,
Expand All @@ -597,7 +609,7 @@ def _fit(
select: t.Optional[Select] = None,
validation_sets: t.Optional[t.Sequence[t.Union[str, Dataset]]] = None,
):
raise NotImplementedError
pass

def fit(
self,
Expand All @@ -615,6 +627,17 @@ def fit(
) -> t.Optional[Pipeline]:
"""
Fit the model on the given data.

:param X: The key of the input data to use for training
:param y: The key of the target data to use for training
:param configuration: The training configuration (optional)
:param data_prefetch: Whether to prefetch the data (optional)
:param db: The datalayer (optional)
:param dependencies: The dependencies (optional)
:param distributed: Whether to distribute the job (optional)
:param metrics: The metrics to evaluate on (optional)
:param select: The select to use for training (optional)
:param validation_sets: The validation ``Dataset`` instances to use (optional)
"""
if isinstance(select, dict):
select = Serializable.deserialize(select)
Expand Down
4 changes: 2 additions & 2 deletions superduperdb/ext/anthropic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from superduperdb.components.component import Component
from superduperdb.components.encoder import Encoder
from superduperdb.components.model import PredictMixin
from superduperdb.components.model import Predictor
from superduperdb.ext.utils import format_prompt, get_key
from superduperdb.misc.retry import Retry

Expand All @@ -18,7 +18,7 @@


@dc.dataclass
class Anthropic(Component, PredictMixin):
class Anthropic(Component, Predictor):
"""Anthropic predictor.

:param model: The model to use, e.g. ``'claude-2'``.
Expand Down
4 changes: 2 additions & 2 deletions superduperdb/ext/cohere/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from superduperdb.components.component import Component
from superduperdb.components.encoder import Encoder
from superduperdb.components.model import PredictMixin
from superduperdb.components.model import Predictor
from superduperdb.components.vector_index import vector
from superduperdb.ext.utils import format_prompt, get_key
from superduperdb.misc.retry import Retry
Expand All @@ -18,7 +18,7 @@


@dc.dataclass
class Cohere(Component, PredictMixin):
class Cohere(Component, Predictor):
"""Cohere predictor

:param model: The model to use, e.g. ``'base-light'``.
Expand Down
7 changes: 2 additions & 5 deletions superduperdb/ext/openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from superduperdb.components.component import Component
from superduperdb.components.encoder import Encoder
from superduperdb.components.model import PredictMixin
from superduperdb.components.model import Predictor
from superduperdb.components.vector_index import vector
from superduperdb.misc.compat import cache
from superduperdb.misc.retry import Retry
Expand All @@ -32,7 +32,7 @@ def _available_models():


@dc.dataclass
class OpenAI(Component, PredictMixin):
class OpenAI(Component, Predictor):
"""OpenAI predictor.

:param model: The model to use, e.g. ``'text-embedding-ada-002'``.
Expand All @@ -49,9 +49,6 @@ class OpenAI(Component, PredictMixin):
encoder: t.Union[Encoder, str, None] = None
model_update_kwargs: dict = dc.field(default_factory=dict)

#: A unique name for the class
type_id: t.ClassVar[str] = 'model'

@property
def child_components(self):
if self.encoder is not None:
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import superduperdb as s
from superduperdb import logging
from superduperdb.misc import superduper
from superduperdb.base import superduper

try:
import torch
Expand Down
Loading