Skip to content

Commit

Permalink
Revert dask client to remote (#1424)
Browse files Browse the repository at this point in the history
  • Loading branch information
thejumpman2323 authored Nov 29, 2023
1 parent faf523e commit c4c3116
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 20 deletions.
4 changes: 2 additions & 2 deletions test/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def fake_updates(database_with_default_encoders_and_model):


@pytest.fixture
def local_dask_client(monkeypatch, request):
def dask_client(monkeypatch, request):
db_name = "test_db"
data_backend = f'mongodb://superduper:superduper@localhost:27017/{db_name}'

Expand All @@ -121,7 +121,7 @@ def local_dask_client(monkeypatch, request):
# Change the default value
client = DaskComputeBackend(
address='tcp://localhost:8786',
local=True,
local=False,
)

yield client
Expand Down
30 changes: 14 additions & 16 deletions test/integration/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def add_and_cleanup_listener(database, collection_name):

@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_taskgraph_futures_with_dask(
local_dask_client, database_with_default_encoders_and_model, fake_updates
dask_client, database_with_default_encoders_and_model, fake_updates
):
collection_name = str(uuid.uuid4())
database_with_default_encoders_and_model.set_compute(local_dask_client)
database_with_default_encoders_and_model.set_compute(dask_client)
_, graph = database_with_default_encoders_and_model.execute(
Collection(identifier=collection_name).insert_many(fake_updates)
)
Expand All @@ -49,7 +49,7 @@ def test_taskgraph_futures_with_dask(
Collection(identifier=collection_name).find({'update': True})
)
)
local_dask_client.wait_all_pending_tasks()
dask_client.wait_all_pending_tasks()

nodes = graph.G.nodes
jobs = [nodes[node]['job'] for node in nodes]
Expand All @@ -59,12 +59,12 @@ def test_taskgraph_futures_with_dask(

@pytest.mark.skipif(not torch, reason='Torch not installed')
@pytest.mark.parametrize(
'local_dask_client, test_db',
'dask_client, test_db',
[('test_insert_with_distributed', 'test_insert_with_distributed')],
indirect=True,
)
def test_insert_with_dask(
local_dask_client, database_with_default_encoders_and_model, fake_updates
dask_client, database_with_default_encoders_and_model, fake_updates
):
collection_name = str(uuid.uuid4())

Expand All @@ -73,14 +73,14 @@ def test_insert_with_dask(
collection_name,
) as db:
# Submit job
db.set_compute(local_dask_client)
db.set_compute(dask_client)
db.execute(Collection(identifier=collection_name).insert_many(fake_updates))

# Barrier
local_dask_client.wait_all_pending_tasks()
dask_client.wait_all_pending_tasks()

# Get distributed logs
logs = local_dask_client.client.get_worker_logs()
logs = dask_client.client.get_worker_logs()

logging.info("worker logs", logs)

Expand All @@ -91,9 +91,7 @@ def test_insert_with_dask(


@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_dependencies_with_dask(
local_dask_client, database_with_default_encoders_and_model
):
def test_dependencies_with_dask(dask_client, database_with_default_encoders_and_model):
def test_node_1(*args, **kwargs):
return 1

Expand All @@ -103,7 +101,7 @@ def test_node_2(*args, **kwargs):
# Set Dask as Compute engine.
# ------------------------------
database = database_with_default_encoders_and_model
database.set_compute(local_dask_client)
database.set_compute(dask_client)

# Build Task Graph
# ------------------------------
Expand All @@ -125,11 +123,11 @@ def test_node_2(*args, **kwargs):
# Run Job
# ------------------------------
g.run_jobs()
local_dask_client.wait_all_pending_tasks()
dask_client.wait_all_pending_tasks()

# Validate Output
# ------------------------------
futures = list(local_dask_client.list_all_pending_tasks().values())
futures = list(dask_client.list_all_pending_tasks().values())
assert len(futures) == 2
assert futures[0].status == 'finished'
assert futures[1].status == 'finished'
Expand All @@ -138,11 +136,11 @@ def test_node_2(*args, **kwargs):


def test_model_job_logs(
local_dask_client, database_with_default_encoders_and_model, fake_updates
dask_client, database_with_default_encoders_and_model, fake_updates
):
# Set Dask as compute engine.
# ------------------------------
database_with_default_encoders_and_model.set_compute(local_dask_client)
database_with_default_encoders_and_model.set_compute(dask_client)

# Set Collection Listener
# ------------------------------
Expand Down
4 changes: 2 additions & 2 deletions test/integration/test_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def check_outputs():


@pytest.fixture
def distributed_db(monkeypatch, test_db, local_dask_client):
def distributed_db(monkeypatch, test_db, dask_client):
from superduperdb import CFG

existing_databackend = CFG.data_backend
Expand All @@ -107,7 +107,7 @@ def distributed_db(monkeypatch, test_db, local_dask_client):
vector_search = 'http://localhost:8000'
monkeypatch.setattr(CFG.cluster, 'cdc', cdc)
monkeypatch.setattr(CFG.cluster, 'vector_search', vector_search)
test_db.set_compute(local_dask_client)
test_db.set_compute(dask_client)

def update_syspath():
import sys
Expand Down

0 comments on commit c4c3116

Please sign in to comment.