Skip to content

Commit

Permalink
Fix possible data corruption due to race condition in StreamRedirector (
Browse files Browse the repository at this point in the history
#1773)

* fix: StreamRedirector may cause data corruption due to race condition

Signed-off-by: Yen-Nan (Maso) Lin <[email protected]>
Signed-off-by: Maso Lin <[email protected]>

* fix: formating error

Signed-off-by: Maso Lin <[email protected]>

* fix: type checking fails

Signed-off-by: Maso Lin <[email protected]>

* Set explicit 10s deadline for test_stream_redirector_race_condition

---------

Signed-off-by: Yen-Nan (Maso) Lin <[email protected]>
Signed-off-by: Maso Lin <[email protected]>
Co-authored-by: Maso Lin <[email protected]>
  • Loading branch information
mattt and masolin authored Jun 26, 2024
1 parent ae395ab commit e8e2e91
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 10 deletions.
29 changes: 19 additions & 10 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def __init__(
self._events = events
self._tee_output = tee_output
self._cancelable = False
self._events_lock = _spawn.Lock()

super().__init__()

Expand Down Expand Up @@ -200,7 +201,8 @@ def _setup(self) -> None:
raise
finally:
self._stream_redirector.drain()
self._events.send(done)
with self._events_lock:
self._events.send(done)

def _loop(self) -> None:
while True:
Expand All @@ -221,13 +223,18 @@ def _predict(self, payload: Dict[str, Any]) -> None:
result = predict(**payload)

if result:
if isinstance(result, types.GeneratorType):
self._events.send(PredictionOutputType(multi=True))
for r in result:
self._events.send(PredictionOutput(payload=make_encodeable(r)))
else:
self._events.send(PredictionOutputType(multi=False))
self._events.send(PredictionOutput(payload=make_encodeable(result)))
with self._events_lock:
if isinstance(result, types.GeneratorType):
self._events.send(PredictionOutputType(multi=True))
for r in result:
self._events.send(
PredictionOutput(payload=make_encodeable(r))
)
else:
self._events.send(PredictionOutputType(multi=False))
self._events.send(
PredictionOutput(payload=make_encodeable(result))
)
except CancelationException:
done.canceled = True
except Exception as e:
Expand All @@ -237,7 +244,8 @@ def _predict(self, payload: Dict[str, Any]) -> None:
finally:
self._cancelable = False
self._stream_redirector.drain()
self._events.send(done)
with self._events_lock:
self._events.send(done)

def _signal_handler(self, signum: int, frame: Optional[types.FrameType]) -> None:
if signum == signal.SIGUSR1 and self._cancelable:
Expand All @@ -249,4 +257,5 @@ def _stream_write_hook(
if self._tee_output:
original_stream.write(data)
original_stream.flush()
self._events.send(Log(data, source=stream_name))
with self._events_lock:
self._events.send(Log(data, source=stream_name))
17 changes: 17 additions & 0 deletions python/tests/server/fixtures/stream_redirector_race_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from cog import BasePredictor
import threading


def keep_printing():
for _ in range(10000):
print("hello")


class Predictor(BasePredictor):
def setup(self):
self.print_thread = threading.Thread(target=keep_printing)

def predict(self) -> str:
self.print_thread.start()
output = "output" * 100000 # bigger output increases the chance of race condition
return output
24 changes: 24 additions & 0 deletions python/tests/server/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,30 @@ def test_no_exceptions_from_recoverable_failures(data, name, payloads):
w.terminate()


@given(data=st.data())
@settings(deadline=10000) # 10 seconds
def test_stream_redirector_race_condition(data):
"""
StreamRedirector and _ChildWorker are using the same _events pipe to send data.
When there are multiple threads trying to write to the same pipe, it can cause data corruption by race condition.
The data corruption will cause pipe receiver to raise an exception due to unpickling error.
"""
w = Worker(
predictor_ref=_fixture_path("stream_redirector_race_condition"),
tee_output=False,
)

try:
result = _process(w.setup())
assert not result.done.error

payload = data.draw(st.fixed_dictionaries({}))
_process(w.predict(payload))

finally:
w.terminate()


@pytest.mark.parametrize("name,payloads,output_generator", OUTPUT_FIXTURES)
@given(data=st.data())
def test_output(data, name, payloads, output_generator):
Expand Down

0 comments on commit e8e2e91

Please sign in to comment.