PR feedback

philandstuff committed Nov 25, 2024
1 parent db1cbef commit 8630036
Showing 2 changed files with 52 additions and 63 deletions.
95 changes: 45 additions & 50 deletions python/cog/server/worker.py
@@ -84,15 +84,13 @@ class CancelRequest:
     tag: Optional[str]
 
 
+@define
 class PredictionState:
-    def __init__(
-        self, tag: Optional[str], payload: Dict[str, Any], result: "Future[Done]"
-    ) -> None:
-        self.tag = tag
-        self.payload = payload
-        self.result = result
+    tag: Optional[str]
+    payload: Dict[str, Any]
+    result: "Future[Done]"
 
-        self.cancel_sent = False
+    cancel_sent: bool = False
 
 
 class Worker:
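
The hunk above swaps a hand-written __init__ for an attrs-style class. For reference, a minimal standalone sketch of that pattern, assuming the @define used here is attrs' define decorator; the class name and fields below are illustrative, not cog's:

from typing import Any, Dict, Optional

from attrs import define


@define
class ExampleState:
    # attrs generates __init__, __repr__ and __eq__ from these annotations,
    # so the constructor boilerplate above is no longer needed.
    tag: Optional[str]
    payload: Dict[str, Any]
    cancel_sent: bool = False  # fields with defaults must follow required ones


state = ExampleState(tag="p1", payload={"prompt": "hi"})
assert state.cancel_sent is False
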
@@ -154,7 +152,7 @@ def predict(
             self._assert_state(WorkerState.READY)
             result = Future()
             self._predictions_in_flight[tag] = PredictionState(tag, payload, result)
-            self._request_send_conn.send(PredictionRequest(tag))
+        self._request_send_conn.send(PredictionRequest(tag))
         return result
 
     def subscribe(
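
For orientation, a self-contained sketch of the handoff predict() performs in the hunk above: record per-tag state, nudge the event-consuming thread over a multiprocessing Pipe, and hand the caller a Future that the consumer later resolves. The names and string payloads are illustrative, not cog's API.

import threading
from concurrent.futures import Future
from multiprocessing import Pipe

request_recv, request_send = Pipe(duplex=False)
in_flight = {}  # tag -> the caller's Future


def consume_requests() -> None:
    # Stand-in for the worker's event thread: receive a tag, "run" the
    # prediction, and resolve the Future the caller is holding.
    tag = request_recv.recv()
    in_flight[tag].set_result(f"result for {tag}")


def predict(tag: str) -> Future:
    fut: Future = Future()
    in_flight[tag] = fut    # register the in-flight prediction first...
    request_send.send(tag)  # ...then wake the consumer over the pipe
    return fut


threading.Thread(target=consume_requests).start()
print(predict("p1").result(timeout=1))  # prints "result for p1"
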
@@ -236,7 +234,6 @@ def _consume_events_inner(self) -> None:
         # If we didn't get a done event, the child process died.
         if not done:
             exitcode = self._child.exitcode
-            assert self._setup_result
             self._setup_result.set_exception(
                 FatalWorkerException(
                     f"Predictor setup failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
@@ -245,7 +242,6 @@
             self._state = WorkerState.DEFUNCT
             return
         if done.error:
-            assert self._setup_result
             self._setup_result.set_exception(
                 FatalWorkerException(
                     "Predictor errored during setup: " + done.error_detail
@@ -266,42 +262,43 @@
             read_socks, _, _ = select.select(
                 [self._request_recv_conn, self._events], [], [], 0.1
             )
-            for sock in read_socks:
-                if sock == self._request_recv_conn:
-                    ev = self._request_recv_conn.recv()
-                    if isinstance(ev, PredictionRequest):
-                        with self._predictions_lock:
-                            state = self._predictions_in_flight[ev.tag]
-
-                        # Prepare payload (download URLPath objects)
-                        # FIXME this blocks the event loop, which is bad in concurrent mode
-                        try:
-                            _prepare_payload(state.payload)
-                        except Exception as e:
-                            done = Done(error=True, error_detail=str(e))
-                            self._publish(Envelope(done, state.tag))
-                            self._complete_prediction(done, state.tag)
-                        else:
-                            # Start the prediction
-                            self._events.send(
-                                Envelope(
-                                    event=PredictionInput(payload=state.payload),
-                                    tag=state.tag,
-                                )
-                            )
-                    if isinstance(ev, CancelRequest):
-                        with self._predictions_lock:
-                            predict_state = self._predictions_in_flight.get(ev.tag)
-                            if predict_state and not predict_state.cancel_sent:
-                                self._child.send_cancel()
-                                self._events.send(Envelope(event=Cancel(), tag=ev.tag))
-                                predict_state.cancel_sent = True
-
-                else: # sock == self._events
-                    ev = self._events.recv()
-                    self._publish(ev)
-                    if isinstance(ev.event, Done):
-                        self._complete_prediction(ev.event, ev.tag)
+            if self._request_recv_conn in read_socks:
+                ev = self._request_recv_conn.recv()
+                if isinstance(ev, PredictionRequest):
+                    with self._predictions_lock:
+                        state = self._predictions_in_flight[ev.tag]
+
+                    # Prepare payload (download URLPath objects)
+                    # FIXME this blocks the event loop, which is bad in concurrent mode
+                    try:
+                        _prepare_payload(state.payload)
+                    except Exception as e:
+                        done = Done(error=True, error_detail=str(e))
+                        self._publish(Envelope(done, state.tag))
+                        self._complete_prediction(done, state.tag)
+                    else:
+                        # Start the prediction
+                        self._events.send(
+                            Envelope(
+                                event=PredictionInput(payload=state.payload),
+                                tag=state.tag,
+                            )
+                        )
+                elif isinstance(ev, CancelRequest):
+                    with self._predictions_lock:
+                        predict_state = self._predictions_in_flight.get(ev.tag)
+                        if predict_state and not predict_state.cancel_sent:
+                            self._child.send_cancel()
+                            self._events.send(Envelope(event=Cancel(), tag=ev.tag))
+                            predict_state.cancel_sent = True
+                else:
+                    log.warn(f"unrecognized request event: {ev}")
+
+            if self._events in read_socks:
+                ev = self._events.recv()
+                self._publish(ev)
+                if isinstance(ev.event, Done):
+                    self._complete_prediction(ev.event, ev.tag)
 
         # If we dropped off the end of the loop, it's because the
         # child process died. First, process any remaining messages on the connection
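
A standalone sketch of the polling pattern in the hunk above: select.select can wait on several multiprocessing connections at once (they expose fileno()), so the thread only calls recv() on connections that are actually readable. The connection names and the 0.1s timeout are illustrative, and select on Connection objects is Unix-specific (multiprocessing.connection.wait() is the portable variant).

import select
from multiprocessing import Pipe

request_recv, request_send = Pipe(duplex=False)
events_recv, events_send = Pipe(duplex=False)

request_send.send("prediction-request")
events_send.send("done-event")

while True:
    # Wait up to 0.1s for either connection to become readable.
    read_ready, _, _ = select.select([request_recv, events_recv], [], [], 0.1)
    if not read_ready:
        break  # nothing arrived before the timeout; a real loop would re-check child liveness
    if request_recv in read_ready:
        print("request:", request_recv.recv())
    if events_recv in read_ready:
        print("event:", events_recv.recv())
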
@@ -314,22 +311,20 @@ def _consume_events_inner(self) -> None:
         if not self._terminating:
             self._state = WorkerState.DEFUNCT
             with self._predictions_lock:
-                for tag in list(self._predictions_in_flight.keys()):
+                for state in self._predictions_in_flight.values():
                     exitcode = self._child.exitcode
-                    self._predictions_in_flight[tag].result.set_exception(
+                    state.result.set_exception(
                         FatalWorkerException(
                             f"Prediction failed for an unknown reason. It might have run out of memory? (exitcode {exitcode})"
                         )
                     )
-                    del self._predictions_in_flight[tag]
+                self._predictions_in_flight.clear()
 
     def _complete_prediction(self, done: Done, tag: Optional[str]) -> None:
         # We update the in-flight dictionary before completing the prediction
         # future, so that we can immediately accept work.
         with self._predictions_lock:
             predict_state = self._predictions_in_flight.pop(tag)
             if len(self._predictions_in_flight) == 0:
                 self._state = WorkerState.READY
         predict_state.result.set_result(done)
 
     def _publish(self, e: Envelope) -> None:
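
The final worker.py hunk tightens the failure path: when the child dies, every in-flight prediction's Future is failed with the same exception and the dict is then cleared in one step, rather than deleting keys one by one. A minimal sketch of that shape, with illustrative names rather than cog's internals:

import threading
from concurrent.futures import Future


class FatalWorkerError(Exception):
    pass


lock = threading.Lock()
in_flight = {"p1": Future(), "p2": Future()}  # tag -> the caller's Future

with lock:
    for fut in in_flight.values():
        # Fail every waiter with the same error, then drop them all at once.
        fut.set_exception(FatalWorkerError("worker process died"))
    in_flight.clear()

assert not in_flight
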
20 changes: 7 additions & 13 deletions python/tests/server/test_worker.py
@@ -262,10 +262,6 @@ def test_can_subscribe_for_a_specific_tag(worker):
 
 @uses_worker("sleep_async", max_concurrency=5)
 def test_can_run_predictions_concurrently_on_async_predictor(worker):
-    tag = "123"
-
-    result = Result()
-    subid = worker.subscribe(result.handle_event, tag=tag)
     subids = []
 
     try:
@@ -299,7 +295,8 @@ def test_can_run_predictions_concurrently_on_async_predictor(worker):
         assert result.output == "done in 0.5 seconds"
 
     finally:
-        worker.unsubscribe(subid)
+        for subid in subids:
+            worker.unsubscribe(subid)
 
 
 @uses_worker("stream_redirector_race_condition")
@@ -513,7 +510,6 @@ class PredictState:
 
 class FakeChildWorker:
     exitcode = None
-    cancel_sent = False
     alive = True
     pid: int = 0
 
@@ -524,7 +520,7 @@ def is_alive(self):
         return self.alive
 
     def send_cancel(self):
-        self.cancel_sent = True
+        pass
 
     def terminate(self):
         pass
@@ -735,17 +731,15 @@ def await_predict(self, state: PredictState):
     def cancel(self, state: PredictState):
        self.worker.cancel(tag=state.tag)
 
-        if state.canceled:
-            # if this has previously been canceled, we expect no Cancel event
-            # sent to the child
-            assert not self.child_events.poll(timeout=0.1)
-        else:
+        if not state.canceled:
             # if this prediction has not previously been canceled, Worker will
             # send a Cancel event to the child. We need to consume this event to
             # ensure we stay synced up on the child connection
             assert self.child_events.poll(timeout=0.5)
             e = self.child_events.recv()
             assert isinstance(e, Envelope)
             assert isinstance(e.event, Cancel)
             assert e.tag == state.tag
-            assert self.child.cancel_sent
 
         return evolve(state, canceled=True)

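The cancel rule above ends with evolve(state, canceled=True). A tiny standalone sketch of that attrs pattern, assuming the attrs package; the State class below is illustrative, not the test's PredictState:

from attrs import define, evolve


@define(frozen=True)
class State:
    tag: str
    canceled: bool = False


s0 = State(tag="p1")
s1 = evolve(s0, canceled=True)  # copy with one field changed; s0 is untouched
assert not s0.canceled and s1.canceled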
