Rename watch to listen everywhere (finish #583) #592

Merged · 1 commit · Aug 1, 2023
8 changes: 4 additions & 4 deletions README.md
@@ -195,10 +195,10 @@ model.predict(
 ```python
 model.predict(
-    X='input_col',
-    db=db,
-    select=coll.find().featurize({'X': '<upstream-model-id>'}), # already registered upstream model-id
-    watch=True,
+    X='input_col',
+    db=db,
+    select=coll.find().featurize({'X': '<upstream-model-id>'}), # already registered upstream model-id
+    listen=True,
 )
 ```

3 changes: 2 additions & 1 deletion docs/cluster/change_data_capture.md
@@ -36,8 +36,9 @@ watcher = DatabaseWatcher(db=db, on=Collection(name='docs'))
 ```
 
 Start the watcher thread to initiate the change stream monitoring:
+
 ```python
-watcher.watch()
+watcher.listen()
 ```
 
 See [here](/how_to/mongo_cdc.html) for an example of usage of CDC.
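
Putting the renamed pieces together: a minimal end-to-end CDC sketch assembled from names in this diff (`DatabaseListener`, `Collection`, `listen`). The `superduper(...)` wrapping step and the `my_mongo_database` handle are assumptions, mirroring the module docstring in `superduperdb/db/base/cdc.py` further down.

```python
from superduperdb import superduper
from superduperdb.db.base.cdc import DatabaseListener
from superduperdb.db.mongodb.query import Collection

# Assumption: my_mongo_database is an existing MongoDB database handle,
# wrapped into a SuperDuperDB datalayer as in the cdc.py module docstring.
db = superduper(my_mongo_database)

# Create a listener on the 'docs' collection and start the
# change-stream monitoring thread (formerly watcher.watch()).
listener = DatabaseListener(db=db, on=Collection(name='docs'))
listener.listen()
```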
4 changes: 2 additions & 2 deletions docs/cluster/jobs.md
@@ -21,8 +21,8 @@ the [configuration stystem](configuration).
 The stdout and status of the job may be monitored using the returned `Job` object:
 
 ```python
->>> job = model.predict(X='my-key', db=db, select=collection.find())
->>> job.watch()
+>> > job = model.predict(X='my-key', db=db, select=collection.find())
+>> > job.listen()
 # ... lots of lines of stdout
 ```

10 changes: 5 additions & 5 deletions docs/index.md
@@ -103,7 +103,7 @@ SuperDuperDB contains components allowing developers to configure models to cont
 ```python
 # Watch the database for incoming data, and process this with a model
 # Model outputs are continuously stored in the input records
-model.predict(X='input_col', db=db, select=coll.find(), watch=True)
+model.predict(X='input_col', db=db, select=coll.find(), listen=True)
 ```

## Use models outputs as inputs to downstream models
@@ -112,10 +112,10 @@ Simply add a simple method `featurize` to your queries, to register the fact tha
 ```python
 model.predict(
-    X='input_col',
-    db=db,
-    select=coll.find().featurize({'X': '<upstream-model-id>'}), # already registered upstream model-id
-    watch=True,
+    X='input_col',
+    db=db,
+    select=coll.find().featurize({'X': '<upstream-model-id>'}), # already registered upstream model-id
+    listen=True,
 )
 ```

2 changes: 1 addition & 1 deletion docs/usage/models.md
@@ -94,7 +94,7 @@ then the models spring into action, processing this new data, and repopulating o
 into the datalayer.
 
 ```python
->>> model.predict(X='input_col', db=db, select=coll.find(), watch=True)
+>> > model.predict(X='input_col', db=db, select=coll.find(), listen=True)
 ```

An equivalent syntax is the following:
6 changes: 3 additions & 3 deletions superduperdb/cli/serve.py
@@ -17,17 +17,17 @@ def serve():
 @command(help='Start local cluster: server, dask and change data capture')
 def local_cluster(on: t.List[str] = []):
     from superduperdb.db.base.build import build_datalayer
-    from superduperdb.db.base.cdc import DatabaseWatcher
+    from superduperdb.db.base.cdc import DatabaseListener
     from superduperdb.db.mongodb.query import Collection
     from superduperdb.server.dask_client import dask_client
     from superduperdb.server.server import serve
 
     db = build_datalayer()
     dask_client(s.CFG.dask, local=True)
     for collection in on:
-        w = DatabaseWatcher(
+        w = DatabaseListener(
             db=db,
             on=Collection(name=collection),
         )
-        w.watch()
+        w.listen()
     serve(db)
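
The renamed `local_cluster` command can also be sketched as a direct call. This assumes the `@command` decorator leaves the function callable (true for typer-style CLIs, an assumption here) and that a MongoDB deployment and dask scheduler are reachable:

```python
from superduperdb.cli.serve import local_cluster

# Builds a datalayer, starts a local dask client, attaches one
# change-data-capture listener per named collection via w.listen(),
# then serves the database.
local_cluster(on=['docs'])
```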
4 changes: 2 additions & 2 deletions superduperdb/container/component.py
@@ -20,7 +20,7 @@
 @dc.dataclass
 class Component(Serializable):
     """
-    Base component which model, watchers, learning tasks etc. inherit from.
+    Base component which model, listeners, learning tasks etc. inherit from.
 
     :param identifier: Unique ID
     """
@@ -82,7 +82,7 @@ def schedule_jobs(
         distributed: bool = False,
         verbose: bool = False,
     ) -> t.Sequence[t.Any]:
-        """Run the job for this watcher
+        """Run the job for this listener
 
         :param database: The db to process
         :param dependencies: A sequence of dependencies,
4 changes: 2 additions & 2 deletions superduperdb/container/job.py
@@ -33,8 +33,8 @@ def __init__(
         self.db = None
         self.future = None
 
-    def watch(self):
-        return self.db.metadata.watch_job(identifier=self.identifier)
+    def listen(self):
+        return self.db.metadata.listen_job(identifier=self.identifier)
 
     def run_locally(self, db):
         try:
4 changes: 2 additions & 2 deletions superduperdb/container/model.py
@@ -156,7 +156,7 @@ def predict(
         ids: t.Optional[t.Sequence[str]] = None,
         max_chunk_size: t.Optional[int] = None,
         dependencies: t.Sequence[Job] = (),
-        watch: bool = False,
+        listen: bool = False,
         one: bool = False,
         context: t.Optional[t.Dict] = None,
         in_memory: bool = True,
@@ -167,7 +167,7 @@
         if isinstance(select, dict):
             select = Serializable.deserialize(select)
 
-        if watch:
+        if listen:
             from superduperdb.container.listener import Listener
 
             if db is None:
6 changes: 3 additions & 3 deletions superduperdb/container/task_workflow.py
@@ -33,10 +33,10 @@ def add_edge(self, node1: str, node2: str) -> None:
     def add_node(self, node: str, job: t.Union[FunctionJob, ComponentJob]) -> None:
         self.G.add_node(node, job=job)
 
-    def watch(self) -> None:
-        """Watch each job in this workflow in topological order"""
+    def listen(self) -> None:
+        """Listen to each job in this workflow in topological order"""
         for node in list(networkx.topological_sort(self.G)):
-            self.G.nodes[node]['job'].watch()
+            self.G.nodes[node]['job'].listen()
 
     def run_jobs(self, distributed: t.Optional[bool] = False):
         """Run all the jobs in this workflow
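
A hypothetical usage sketch of the renamed workflow API: only `add_node`, `add_edge`, `run_jobs`, and `listen` appear in this diff, so the constructor signature and the job objects below are assumptions.

```python
# download_job and predict_job stand in for FunctionJob / ComponentJob
# instances; their construction is not shown in this diff.
workflow = TaskWorkflow(database=db)  # constructor signature assumed

workflow.add_node('download', job=download_job)
workflow.add_node('predict', job=predict_job)
workflow.add_edge('download', 'predict')  # predict depends on download

workflow.run_jobs()
workflow.listen()  # tails each job's output in topological order
```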
2 changes: 1 addition & 1 deletion superduperdb/container/vector_index.py
@@ -43,7 +43,7 @@ class VectorIndex(Component):
     #: Unique string identifier of index
     identifier: str
 
-    #: Watcher which is applied to created vectors
+    #: Listener which is applied to created vectors
     indexing_listener: t.Union[Listener, str]
 
     #: List of additional listeners which can "talk" to the index (e.g. multi-modal)
35 changes: 17 additions & 18 deletions superduperdb/db/base/cdc.py
@@ -19,8 +19,8 @@
 Use this module like this::
     db = any_arbitary_database.connect(...)
     db = superduper(db)
-    watcher = DatabaseWatcher(db=db, on=Collection('test_collection'))
-    watcher.watch()
+    listener = DatabaseListener(db=db, on=Collection('test_collection'))
+    listener.listen()
 """
 
 import threading
@@ -31,42 +31,41 @@
 from superduperdb.db.mongodb import cdc
 from superduperdb.db.mongodb.query import Collection
 
-DBWatcherType = t.TypeVar('DBWatcherType')
+DBListenerType = t.TypeVar('DBListenerType')
 
 
-class DatabaseWatcherFactory(t.Generic[DBWatcherType]):
-    """DatabaseWatcherFactory.
-    A Factory class to create instance of DatabaseWatcher corresponding to the
+class DatabaseListenerFactory(t.Generic[DBListenerType]):
+    """A Factory class to create instance of DatabaseListener corresponding to the
     `db_type`.
     """
 
-    SUPPORTED_WATCHERS: t.List[str] = ['mongodb']
+    SUPPORTED_LISTENERS: t.List[str] = ['mongodb']
 
     def __init__(self, db_type: str = 'mongodb'):
-        if db_type not in self.SUPPORTED_WATCHERS:
+        if db_type not in self.SUPPORTED_LISTENERS:
             raise ValueError(f'{db_type} is not supported yet for CDC.')
-        self.watcher = db_type
+        self.listener = db_type
 
-    def create(self, *args, **kwargs) -> DBWatcherType:
+    def create(self, *args, **kwargs) -> DBListenerType:
         stop_event = threading.Event()
         kwargs['stop_event'] = stop_event
-        watcher = cdc.MongoDatabaseWatcher(*args, **kwargs)
-        return t.cast(DBWatcherType, watcher)
+        listener = cdc.MongoDatabaseListener(*args, **kwargs)
+        return t.cast(DBListenerType, listener)
 
 
-def DatabaseWatcher(
+def DatabaseListener(
     db: DB,
     on: Collection,
     identifier: str = '',
     *args,
     **kwargs,
-) -> cdc.BaseDatabaseWatcher:
+) -> cdc.BaseDatabaseListener:
     """
-    Create an instance of `BaseDatabaseWatcher`
+    Create an instance of `BaseDatabaseListener`
 
     :param db: A superduperdb instance.
-    :param on: Which collection/table watcher service this be invoked on?
-    :param identifier: A identity given to the watcher service.
+    :param on: Which collection/table listener service this be invoked on?
+    :param identifier: A identity given to the listener service.
     """
     it = backends.data_backends.items()
     if types := [k for k, v in it if isinstance(db.databackend, v)]:
@@ -77,6 +76,6 @@ def DatabaseWatcher(
     if db_type != 'mongodb':
         raise NotImplementedError(f'Database {db_type} not supported yet!')
 
-    factory_factory = DatabaseWatcherFactory[cdc.MongoDatabaseWatcher]
+    factory_factory = DatabaseListenerFactory[cdc.MongoDatabaseListener]
     db_factory = factory_factory(db_type=db_type)
     return db_factory.create(db=db, on=on, identifier=identifier, *args, **kwargs)
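
The two renamed entry points compose as follows; every name below appears in this diff, while the `db` handle is assumed to be an existing datalayer.

```python
from superduperdb.db.base.cdc import DatabaseListener, DatabaseListenerFactory
from superduperdb.db.mongodb import cdc
from superduperdb.db.mongodb.query import Collection

# High-level helper: inspects db.databackend and, for MongoDB backends,
# returns a cdc.MongoDatabaseListener.
listener = DatabaseListener(db=db, on=Collection(name='docs'))
listener.listen()

# Lower-level route, mirroring the body of DatabaseListener: the factory
# injects a threading.Event as stop_event before constructing the listener.
factory = DatabaseListenerFactory[cdc.MongoDatabaseListener](db_type='mongodb')
listener = factory.create(db=db, on=Collection(name='docs'), identifier='')
listener.listen()
```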
46 changes: 23 additions & 23 deletions superduperdb/db/base/db.py
@@ -464,11 +464,11 @@ def _build_task_workflow(
                 args=[],
             ),
         )
-        watchers = self.show('listener')
-        if not watchers:
+        listener = self.show('listener')
+        if not listener:
             return G
 
-        for identifier in watchers:
+        for identifier in listener:
             info = self.metadata.get_component('listener', identifier)
             query = info['dict']['select']
             model, key = identifier.split('/')
@@ -487,15 +487,15 @@
             ),
         )
 
-        for identifier in watchers:
+        for identifier in listener:
             model, key = identifier.split('/')
             G.add_edge(
                 f'{download_content.__name__}()',
                 f'{model}.predict({key})',
             )
-            deps = self._get_dependencies_for_watcher(
+            deps = self._get_dependencies_for_listener(
                 identifier
-            ) # TODO remove features as explicit argument to watcher
+            ) # TODO remove features as explicit argument to listener
             for dep in deps:
                 dep_model, dep_key = dep.split('/')
                 G.add_edge(
@@ -631,7 +631,7 @@ def _create_plan(self):
         for identifier in self.metadata.show_components('listener', active=True):
             G.add_node('listener', job=identifier)
         for identifier in self.metadata.show_components('listener'):
-            deps = self._get_dependencies_for_watcher(identifier)
+            deps = self._get_dependencies_for_listener(identifier)
             for dep in deps:
                 G.add_edge(('listener', dep), ('listener', identifier))
         if not networkx.is_directed_acyclic_graph(G):
@@ -763,13 +763,13 @@ def _get_content_for_filter(self, filter) -> Document:
         filter = Document(Document.decode(output, encoders=self.encoders))
         return filter
 
-    def _get_dependencies_for_watcher(self, identifier):
+    def _get_dependencies_for_listener(self, identifier):
         info = self.metadata.get_component('listener', identifier)
         if info is None:
             return []
-        watcher_features = info.get('features', {})
+        listener_features = info.get('features', {})
         out = []
-        for k in watcher_features:
+        for k in listener_features:
             out.append(f'{self.features[k]}/{k}')
         if info['dict']['key'].startswith('_outputs.'):
             _, key, model = info['key'].split('.')
@@ -785,22 +785,22 @@ def _get_file_content(self, r):
     def _get_object_info(self, identifier, type_id, version=None):
         return self.metadata.get_component(type_id, identifier, version=version)
 
-    def _apply_watcher(  # noqa: F811
+    def _apply_listener(  # noqa: F811
         self,
         identifier,
         ids: t.Optional[t.Sequence[str]] = None,
         verbose: bool = False,
         max_chunk_size=5000,
         model=None,
         recompute: bool = False,
-        watcher_info=None,
+        listener_info=None,
         **kwargs,
     ) -> t.List:
         # NOTE: this method is never called anywhere except for itself!
-        if watcher_info is None:
-            watcher_info = self.metadata.get_component('listener', identifier)
+        if listener_info is None:
+            listener_info = self.metadata.get_component('listener', identifier)
 
-        select = Serializable.deserialize(watcher_info['select'])
+        select = Serializable.deserialize(listener_info['select'])
         if ids is None:
             ids = select.get_ids(self)
         else:
@@ -814,28 +814,28 @@ def _apply_watcher(  # noqa: F811
                     'computing chunk '
                     f'({it + 1}/{math.ceil(len(ids) / max_chunk_size)})'
                 )
-                self._apply_watcher(
+                self._apply_listener(
                     identifier,
                     ids=ids[i : i + max_chunk_size],
                     verbose=verbose,
                     max_chunk_size=None,
                     model=model,
                     recompute=recompute,
-                    watcher_info=watcher_info,
+                    listener_info=listener_info,
                     distributed=False,
                     **kwargs,
                 )
             return []
 
-        model_info = self.metadata.get_component('model', watcher_info['model'])
+        model_info = self.metadata.get_component('model', listener_info['model'])
         outputs = self._compute_model_outputs(
             model_info,
             ids,
             select,
-            key=watcher_info['key'],
-            features=watcher_info.get('features', {}),
+            key=listener_info['key'],
+            features=listener_info.get('features', {}),
             model=model,
-            predict_kwargs=watcher_info.get('predict_kwargs', {}),
+            predict_kwargs=listener_info.get('predict_kwargs', {}),
         )
         type = model_info.get('type')
         if type is not None:
@@ -844,8 +844,8 @@ def _apply_watcher(  # noqa: F811
 
         select.model_update(
             db=self,
-            model=watcher_info['model'],
-            key=watcher_info['key'],
+            model=listener_info['model'],
+            key=listener_info['key'],
             outputs=outputs,
             ids=ids,
         )
2 changes: 1 addition & 1 deletion superduperdb/db/base/metadata.py
@@ -39,7 +39,7 @@ def get_job(self, job_id: str):
     def update_job(self, job_id: str, key: str, value: t.Any):
         pass
 
-    def watch_job(self, identifier: str):
+    def listen_job(self, identifier: str):
        try:
            status = 'pending'
            n_lines = 0