From e8e2e91d2d94dc11b6959fabf6db328bafa1c2f7 Mon Sep 17 00:00:00 2001 From: Mattt Date: Wed, 26 Jun 2024 07:19:39 -0700 Subject: [PATCH] Fix possible data corruption due to race condition in StreamRedirector (#1773) * fix: StreamRedirector may cause data corruption due to race condition Signed-off-by: Yen-Nan (Maso) Lin Signed-off-by: Maso Lin * fix: formatting error Signed-off-by: Maso Lin * fix: type checking fails Signed-off-by: Maso Lin * Set explicit 10s deadline for test_stream_redirector_race_condition --------- Signed-off-by: Yen-Nan (Maso) Lin Signed-off-by: Maso Lin Co-authored-by: Maso Lin --- python/cog/server/worker.py | 29 ++++++++++++------- .../stream_redirector_race_condition.py | 17 +++++++++++ python/tests/server/test_worker.py | 24 +++++++++++++++ 3 files changed, 60 insertions(+), 10 deletions(-) create mode 100644 python/tests/server/fixtures/stream_redirector_race_condition.py diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index 04e5b866a4..3bc1e81f1e 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -153,6 +153,7 @@ def __init__( self._events = events self._tee_output = tee_output self._cancelable = False + self._events_lock = _spawn.Lock() super().__init__() @@ -200,7 +201,8 @@ def _setup(self) -> None: raise finally: self._stream_redirector.drain() - self._events.send(done) + with self._events_lock: + self._events.send(done) def _loop(self) -> None: while True: @@ -221,13 +223,18 @@ def _predict(self, payload: Dict[str, Any]) -> None: result = predict(**payload) if result: - if isinstance(result, types.GeneratorType): - self._events.send(PredictionOutputType(multi=True)) - for r in result: - self._events.send(PredictionOutput(payload=make_encodeable(r))) - else: - self._events.send(PredictionOutputType(multi=False)) - self._events.send(PredictionOutput(payload=make_encodeable(result))) + with self._events_lock: + if isinstance(result, types.GeneratorType): +
self._events.send(PredictionOutputType(multi=True)) + for r in result: + self._events.send( + PredictionOutput(payload=make_encodeable(r)) + ) + else: + self._events.send(PredictionOutputType(multi=False)) + self._events.send( + PredictionOutput(payload=make_encodeable(result)) + ) except CancelationException: done.canceled = True except Exception as e: @@ -237,7 +244,8 @@ def _predict(self, payload: Dict[str, Any]) -> None: finally: self._cancelable = False self._stream_redirector.drain() - self._events.send(done) + with self._events_lock: + self._events.send(done) def _signal_handler(self, signum: int, frame: Optional[types.FrameType]) -> None: if signum == signal.SIGUSR1 and self._cancelable: @@ -249,4 +257,5 @@ def _stream_write_hook( if self._tee_output: original_stream.write(data) original_stream.flush() - self._events.send(Log(data, source=stream_name)) + with self._events_lock: + self._events.send(Log(data, source=stream_name)) diff --git a/python/tests/server/fixtures/stream_redirector_race_condition.py b/python/tests/server/fixtures/stream_redirector_race_condition.py new file mode 100644 index 0000000000..9063f18084 --- /dev/null +++ b/python/tests/server/fixtures/stream_redirector_race_condition.py @@ -0,0 +1,17 @@ +from cog import BasePredictor +import threading + + +def keep_printing(): + for _ in range(10000): + print("hello") + + +class Predictor(BasePredictor): + def setup(self): + self.print_thread = threading.Thread(target=keep_printing) + + def predict(self) -> str: + self.print_thread.start() + output = "output" * 100000 # bigger output increases the chance of race condition + return output diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index 0bb950c6df..2c1d4dc322 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -201,6 +201,30 @@ def test_no_exceptions_from_recoverable_failures(data, name, payloads): w.terminate() +@given(data=st.data()) +@settings(deadline=10000) 
# 10 seconds +def test_stream_redirector_race_condition(data): + """ + StreamRedirector and _ChildWorker are using the same _events pipe to send data. + When there are multiple threads trying to write to the same pipe, it can cause data corruption by race condition. + The data corruption will cause pipe receiver to raise an exception due to unpickling error. + """ + w = Worker( + predictor_ref=_fixture_path("stream_redirector_race_condition"), + tee_output=False, + ) + + try: + result = _process(w.setup()) + assert not result.done.error + + payload = data.draw(st.fixed_dictionaries({})) + _process(w.predict(payload)) + + finally: + w.terminate() + + @pytest.mark.parametrize("name,payloads,output_generator", OUTPUT_FIXTURES) @given(data=st.data()) def test_output(data, name, payloads, output_generator):