Skip to content

Commit

Permalink
Add support for multiple subscribers
Browse files Browse the repository at this point in the history
sondreso committed Sep 15, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 6eb7862 commit adb5f57
Showing 2 changed files with 123 additions and 58 deletions.
93 changes: 93 additions & 0 deletions src/ert/experiment_server/experiment_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import asyncio
import queue
from multiprocessing.queues import Queue
from typing import Dict, List

from fastapi.encoders import jsonable_encoder

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.ensemble_evaluator.event import EndEvent, _UpdateEvent
from ert.run_models.base_run_model import BaseRunModel, StatusEvents


class EndTaskEvent:
pass

class Subscriber:
def __init__(self) -> None:
self.index = 0
self._event = asyncio.Event()

def notify(self):
self._event.set()

async def wait_for_event(self):
await self._event.wait()
self._event.clear()

class ExperimentTask:
def __init__(self, _id: str, model: BaseRunModel, status_queue: "Queue[StatusEvents]" ) -> None:
self._id = _id
self._model = model
self._status_queue = status_queue
self._subscribers: Dict[str, Subscriber] = {}
self._events: List[StatusEvents] = []

def cancel(self) -> None:
self._model.cancel()

async def run(self):
loop = asyncio.get_running_loop()
print(f"Starting experiment {self._id}")

port_range = None
if self._model.queue_system == QueueSystem.LOCAL:
port_range = range(49152, 51819)
evaluator_server_config = EvaluatorServerConfig(custom_port_range=port_range)

simulation_future = loop.run_in_executor(
None,
lambda: self._model.start_simulations_thread(
evaluator_server_config
),
)

while True:
try:
item: StatusEvents = self._status_queue.get(block=False)
except queue.Empty:
await asyncio.sleep(0.01)
continue

if isinstance(item, _UpdateEvent):
item.snapshot = item.snapshot.to_dict()
# print(item)
# print()
# print()
event = jsonable_encoder(item)
self._events.append(event)
for sub in self._subscribers.values():
sub.notify()
await asyncio.sleep(0.1)

if isinstance(item, EndEvent):
self._events.append(EndTaskEvent())
for sub in self._subscribers.values():
sub.notify()
break

await simulation_future
print(f"Experiment {self._id} done")

async def get_event(self, subscriber_id: str) -> StatusEvents:
if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = Subscriber()
subscriber = self._subscribers[subscriber_id]

while subscriber.index >= len(self._events):
await subscriber.wait_for_event()

event = self._events[subscriber.index]
self._subscribers[subscriber_id].index += 1
return event
88 changes: 30 additions & 58 deletions src/ert/experiment_server/main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import asyncio
import multiprocessing as mp
import os
import queue
import uuid
from concurrent.futures import ProcessPoolExecutor
from contextlib import asynccontextmanager
from multiprocessing.queues import Queue
from typing import Dict, Tuple, Union
from typing import Dict, Union

from fastapi import BackgroundTasks, FastAPI, HTTPException, WebSocket
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, Field

from ert.config import ErtConfig, QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.ensemble_evaluator.event import EndEvent, _UpdateEvent
from ert.config import ErtConfig
from ert.gui.simulation.ensemble_experiment_panel import (
Arguments as EnsembleExperimentArguments,
)
@@ -31,10 +26,12 @@
Arguments as MultipleDataAssimilationArguments,
)
from ert.gui.simulation.single_test_run_panel import Arguments as SingleTestRunArguments
from ert.run_models.base_run_model import BaseRunModel, StatusEvents
from ert.run_models.base_run_model import StatusEvents
from ert.run_models.model_factory import create_model
from ert.storage import open_storage

from .experiment_task import EndTaskEvent, ExperimentTask


class Experiment(BaseModel):
args: Union[
@@ -50,32 +47,21 @@ class Experiment(BaseModel):


mp_ctx = mp.get_context("fork")
process_pool = ProcessPoolExecutor(
max_workers=max((os.cpu_count() or 1) - 2, 1), mp_context=mp_ctx
)
app = FastAPI()
experiments: Dict[str, Tuple[BaseRunModel, "Queue[StatusEvents]"]] = {}
experiments: Dict[str, ExperimentTask] = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup actions
yield
# Shutdown actions

app = FastAPI(lifespan=lifespan)

@app.get("/")
async def root():
return {"message": "ping"}


async def run_experiment(
experiment_id: str, evaluator_server_config: EvaluatorServerConfig
):
loop = asyncio.get_running_loop()
print(f"Starting experiment {experiment_id}")
await loop.run_in_executor(
None,
lambda: experiments[experiment_id][0].start_simulations_thread(
evaluator_server_config
),
)
print(f"Experiment {experiment_id} done")


@app.post("/experiments/")
async def submit_experiment(experiment: Experiment, background_tasks: BackgroundTasks):
storage = open_storage(experiment.ert_config.ens_path, "w")
@@ -93,48 +79,34 @@ async def submit_experiment(experiment: Experiment, background_tasks: Background
detail=f"{experiment.args.mode} was not valid, failed with: {e}",
)

port_range = None
if model.queue_system == QueueSystem.LOCAL:
port_range = range(49152, 51819)
evaluator_server_config = EvaluatorServerConfig(custom_port_range=port_range)

experiment_id = str(uuid.uuid4())
experiments[experiment_id] = (model, status_queue)

background_tasks.add_task(
run_experiment, experiment_id, evaluator_server_config=evaluator_server_config
)
task = ExperimentTask(_id=experiment_id, model=model, status_queue=status_queue)
experiments[experiment_id] = task
background_tasks.add_task(task.run)
return {"message": "Experiment Started", "experiment_id": experiment_id}


@app.put("/experiments/{experiment_id}/cancel")
async def cancel_experiment(experiment_id: str):
if experiment_id in experiments:
experiments[experiment_id][0].cancel()
return {"message": "Experiment Canceled", "experiment_id": experiment_id}
if experiment_id not in experiments:
return HTTPException(
status_code=404,
detail=f"Experiment with id {experiment_id} does not exist.",
)
experiments[experiment_id].cancel()
return {"message": "Experiment canceled", "experiment_id": experiment_id}


@app.websocket("/experiments/{experiment_id}/events")
async def websocket_endpoint(websocket: WebSocket, experiment_id: str):
if experiment_id not in experiments:
return
subscriber_id = str(uuid.uuid4())
await websocket.accept()
print(experiment_id)
print(experiments)
q = experiments[experiment_id][1]
task = experiments[experiment_id]
while True:
try:
item: StatusEvents = q.get(block=False)
except queue.Empty:
await asyncio.sleep(0.01)
continue

if isinstance(item, _UpdateEvent):
item.snapshot = item.snapshot.to_dict()
print(item)
print()
print()
await websocket.send_json(jsonable_encoder(item))
await asyncio.sleep(0.1)
if isinstance(item, EndEvent):
event = await task.get_event(subscriber_id=subscriber_id)
if isinstance(event, EndTaskEvent):
break
await websocket.send_json(event)
await asyncio.sleep(0.1)

0 comments on commit adb5f57

Please sign in to comment.