Skip to content

Commit

Permalink
Fix "remove()" not thread-safe causing possible RuntimeError (#1183)
Browse files Browse the repository at this point in the history
  • Loading branch information
Delgan committed Oct 6, 2024
1 parent aa6fdd8 commit 91e3646
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 5 deletions.
7 changes: 3 additions & 4 deletions loguru/_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,17 +1040,16 @@ def remove(self, handler_id=None):
)

with self._core.lock:
handlers = self._core.handlers.copy()

if handler_id is not None and handler_id not in handlers:
if handler_id is not None and handler_id not in self._core.handlers:
raise ValueError("There is no existing handler with id %d" % handler_id) from None

if handler_id is None:
handler_ids = list(handlers.keys())
handler_ids = list(self._core.handlers)
else:
handler_ids = [handler_id]

for handler_id in handler_ids:
handlers = self._core.handlers.copy()
handler = handlers.pop(handler_id)

# This needs to be done first in case "stop()" raises an exception
Expand Down
61 changes: 60 additions & 1 deletion tests/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@


class NonSafeSink:
def __init__(self, sleep_time):
def __init__(self, sleep_time, stop_time=0):
self.sleep_time = sleep_time
self.stop_time = stop_time
self.written = ""
self.stopped = False

Expand All @@ -21,6 +22,7 @@ def write(self, message):
self.written += message[length:]

def stop(self):
time.sleep(self.stop_time)
self.stopped = True


Expand Down Expand Up @@ -111,6 +113,63 @@ def thread_2():
assert sink.written == "aaa0bbb\n"


def test_safe_removing_all_while_logging(capsys):
barrier = Barrier(2)

for _ in range(1000):
logger.add(lambda _: None, format="{message}", catch=False)

def thread_1():
barrier.wait()
logger.remove()

def thread_2():
barrier.wait()
for _ in range(100):
logger.info("Some message")

threads = [Thread(target=thread_1), Thread(target=thread_2)]

for thread in threads:
thread.start()

for thread in threads:
thread.join()

out, err = capsys.readouterr()
assert out == ""
assert err == ""


def test_safe_slow_removing_all_while_logging(capsys):
barrier = Barrier(2)

for _ in range(10):
sink = NonSafeSink(0.1, 0.1)
logger.add(sink, format="{message}", catch=False)

def thread_1():
barrier.wait()
logger.remove()

def thread_2():
barrier.wait()
time.sleep(0.5)
logger.info("Some message")

threads = [Thread(target=thread_1), Thread(target=thread_2)]

for thread in threads:
thread.start()

for thread in threads:
thread.join()

out, err = capsys.readouterr()
assert out == ""
assert err == ""


def test_safe_writing_after_removing(capsys):
barrier = Barrier(2)

Expand Down

0 comments on commit 91e3646

Please sign in to comment.