Catch BaseException on UCX read error #6996

Merged

8 commits merged on Sep 30, 2022

Conversation

pentschev
Member

This change will ensure `CancelledError`s are caught upon shutting down the Dask cluster, which may otherwise raise various errors.

  • Tests added / passed
  • Passes pre-commit run --all-files

See also dask#6574.
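As a standalone illustration (not the distributed/comm/ucx.py code itself), the sketch below shows why `except Exception` is not enough here: asyncio.CancelledError derives from BaseException rather than Exception on Python 3.8+, so a read that is cancelled during shutdown sails straight past an Exception handler.

import asyncio


async def fake_read():
    # Stand-in for awaiting `self.ep.recv(frame)` inside UCX.read().
    try:
        await asyncio.sleep(3600)
    except Exception:
        # Never reached for a cancellation: CancelledError subclasses
        # BaseException, not Exception, on Python 3.8+.
        raise
    except BaseException as exc:
        # This is the branch the PR adds: abort the endpoint and raise
        # CommClosedError instead of letting the raw error escape unhandled.
        print(f"BaseException handler saw: {exc!r}")
        raise


async def main():
    task = asyncio.create_task(fake_read())
    await asyncio.sleep(0)  # let fake_read start and block on its await
    task.cancel()           # what shutdown does to a pending read
    try:
        await task
    except asyncio.CancelledError:
        print("read was cancelled")


asyncio.run(main())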
@pentschev marked this pull request as ready for review September 2, 2022 20:35
@github-actions
Contributor

github-actions bot commented Sep 2, 2022

Unit Test Results

See test report for an extended history of previous test failures. This is useful for diagnosing flaky tests.

15 files ±0    15 suites ±0    6h 14m 37s ⏱️ -8m 39s
3 146 tests ±0    3 061 ✔️ +2    85 💤 ±0    0 ❌ -2
23 286 runs ±0    22 377 ✔️ +2    909 💤 ±0    0 ❌ -2

Results for commit 8de1557. ± Comparison against base commit 68e5a6a.

♻️ This comment has been updated with latest results.

Contributor

@wence- left a comment


In addition to these comments, I think it is also necessary to wrap the same exception handling around write (in case the other end hung up in a strange way). Or is that not necessary?

# Depending on the UCP protocol selected, it may raise either
# `asyncio.TimeoutError` or `CommClosedError`, so validate either one.
with pytest.raises((asyncio.TimeoutError, CommClosedError)):
    await asyncio.wait_for(reader.read(), 0.01)
Contributor

OK, so here we're waiting for a read that will never be matched by a write, and so eventually we'll fail.

Member Author

That's right.
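For reference, here is a minimal, self-contained sketch of that test pattern using plain asyncio and pytest; the coroutine and ConnectionError merely stand in for distributed's reader.read() and CommClosedError.

import asyncio

import pytest


async def never_matched_read():
    # Stand-in for `reader.read()` when the peer never sends anything.
    await asyncio.Event().wait()


def test_unmatched_read_times_out():
    # pytest.raises accepts a tuple of exception types, so the test passes
    # whichever of the two the transport happens to raise; ConnectionError
    # stands in for distributed's CommClosedError here.
    with pytest.raises((asyncio.TimeoutError, ConnectionError)):
        asyncio.run(asyncio.wait_for(never_matched_read(), 0.01))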

distributed/comm/ucx.py (outdated)
        except BaseException:
            # In addition to UCX exceptions, may be CancelledError or another
            # "low-level" exception. The only safe thing to do is to abort.
            # (See also https://github.com/dask/distributed/pull/6574).
            self.abort()
            raise CommClosedError("Connection closed by writer")
        else:
Contributor

I think we also need to catch connection issues on line 354.

So perhaps lines 353 and 354 should be replaced by:

try:
    for frame in recv_frames:
        await self.ep.recv(frame)
except BaseException as e:
    raise CommClosedError(f"Connection closed by writer.\nInner exception: {e!r}")

I had thought that one might be able to reduce synchronisation a little bit by using:

await asyncio.gather(*map(self.ep.recv, recv_frames))

With a matching change in write of await asyncio.gather(*map(self.ep.send, send_frames)).

But I am unsure of the semantics of UCX wrt message overtaking. I think this could potentially result in the second (say) sent frame ending up in the first receive slot, which would be bad.

Member Author

I think we also need to catch connection issues on line 354.

So perhaps lines 353 and 354 should be replaced by:

try:
    for frame in recv_frames:
        await self.ep.recv(frame)
except BaseException as e:
    raise CommClosedError(f"Connection closed by writer.\nInner exception: {e!r}")

I'm not entirely sure we want that; maybe it never occurred in practice, or just raising the original exception may be fine. I'm mostly concerned about unforeseen side effects this may cause and would prefer not to touch it now, given it hasn't been a problem so far. WDYT?

I had thought that one might be able to reduce synchronisation a little bit by using:

await asyncio.gather(*map(self.ep.recv, recv_frames))

With a matching change in write of await asyncio.gather(*map(self.ep.send, send_frames)).

But I am unsure of the semantics of UCX wrt message overtaking. I think this could potentially result in the second (say) sent frame ending up in the first receive slot, which would be bad.

I would expect that as well; I had done it once and had to revert #5505 because that caused various issues, unfortunately. In any case, with the introduction of "multi-transfers" in the C++ UCX implementation this will anyway be reduced to a single future, so I will not try to improve this code in its current form.
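For illustration only, the two shapes being compared look roughly like the sketch below; FakeEndpoint is a hypothetical stand-in and makes no claim about real UCX matching semantics.

import asyncio


class FakeEndpoint:
    # Hypothetical stand-in for a UCX endpoint; not the real ucp API and no
    # claim about its message-matching semantics.
    async def recv(self, frame):
        await asyncio.sleep(0)  # pretend a matching send arrives and fills `frame`


async def recv_sequential(ep, recv_frames):
    # Current shape of the code: one outstanding recv at a time, so completion
    # order trivially matches posting order.
    for frame in recv_frames:
        await ep.recv(frame)


async def recv_gathered(ep, recv_frames):
    # The alternative discussed above: post all recvs at once and wait for all
    # of them. Whether a later send could land in an earlier receive slot is
    # exactly the open question about UCX semantics.
    await asyncio.gather(*map(ep.recv, recv_frames))


async def main():
    ep, frames = FakeEndpoint(), [bytearray(8) for _ in range(3)]
    await recv_sequential(ep, frames)
    await recv_gathered(ep, frames)


asyncio.run(main())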

Contributor

I'm not entirely sure we want that; maybe it never occurred in practice, or just raising the original exception may be fine. I'm mostly concerned about unforeseen side effects this may cause and would prefer not to touch it now, given it hasn't been a problem so far. WDYT?

What could happen (although my guess is that it would be low likelihood) is that we're receiving a bunch of frames, each await yields to the event loop, and in between awaits the remote endpoint is closed for some other reason.
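To make that scenario concrete, here is a toy sketch with a hypothetical endpoint (not the real ucp API): the first frame arrives, the peer then goes away while a later recv is still pending, and the error surfaces partway through the read loop.

import asyncio


class FlakyEndpoint:
    # Hypothetical endpoint: delivers the first frame, then the peer hangs up
    # while later recvs are still pending. Not the real ucp API.
    def __init__(self):
        self.delivered = False
        self.peer_gone = asyncio.Event()

    async def recv(self, frame):
        if not self.delivered:
            self.delivered = True  # first frame arrives normally
            return
        await self.peer_gone.wait()  # later frames never arrive...
        raise ConnectionResetError("peer hung up mid-message")


async def read_frames(ep, frames):
    for frame in frames:      # each await yields to the event loop, so the
        await ep.recv(frame)  # connection can die between two frames


async def main():
    ep = FlakyEndpoint()
    task = asyncio.create_task(read_frames(ep, [bytearray(4) for _ in range(3)]))
    await asyncio.sleep(0.01)  # first frame received, second recv pending
    ep.peer_gone.set()         # remote endpoint closes between awaits
    try:
        await task
    except ConnectionResetError as exc:
        print(f"read failed partway through: {exc!r}")


asyncio.run(main())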

Member Author

Yes, but I fear that by raising a different exception now we may end up in some different control path that we didn't expect. I'm hoping that this patch can make it into the next Distributed release so it can be included in RAPIDS 22.10. I would be fine trying that out afterwards, but I'm a bit nervous about breaking something close to release time.

Contributor

OK, thanks, makes sense.

@pentschev
Member Author

Thanks @wence-, I replied to your comments; please take another look when you have a chance.

@wence-
Contributor

wence- commented Sep 28, 2022

Thanks @wence-, I replied to your comments; please take another look when you have a chance.

I think we need to catch errors (other than UCXBaseException) on the write side as well as read (since those awaitables may also be cancelled).

Like this perhaps?

diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py
index 16ce82d4..b3a595a5 100644
--- a/distributed/comm/ucx.py
+++ b/distributed/comm/ucx.py
@@ -254,27 +254,27 @@ class UCX(Comm):
     ) -> int:
         if self.closed():
             raise CommClosedError("Endpoint is closed -- unable to send message")
-        try:
-            if serializers is None:
-                serializers = ("cuda", "dask", "pickle", "error")
-            # msg can also be a list of dicts when sending batched messages
-            frames = await to_frames(
-                msg,
-                serializers=serializers,
-                on_error=on_error,
-                allow_offload=self.allow_offload,
-            )
-            nframes = len(frames)
-            cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames)
-            sizes = tuple(nbytes(f) for f in frames)
-            cuda_send_frames, send_frames = zip(
-                *(
-                    (is_cuda, each_frame)
-                    for is_cuda, each_frame in zip(cuda_frames, frames)
-                    if nbytes(each_frame) > 0
-                )
+        if serializers is None:
+            serializers = ("cuda", "dask", "pickle", "error")
+        # msg can also be a list of dicts when sending batched messages
+        frames = await to_frames(
+            msg,
+            serializers=serializers,
+            on_error=on_error,
+            allow_offload=self.allow_offload,
+        )
+        nframes = len(frames)
+        cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames)
+        sizes = tuple(nbytes(f) for f in frames)
+        cuda_send_frames, send_frames = zip(
+            *(
+                (is_cuda, each_frame)
+                for is_cuda, each_frame in zip(cuda_frames, frames)
+                if nbytes(each_frame) > 0
             )
+        )
 
+        try:
             # Send meta data
 
             # Send close flag and number of frames (_Bool, int64)
@@ -297,10 +297,11 @@ class UCX(Comm):
 
             for each_frame in send_frames:
                 await self.ep.send(each_frame)
-            return sum(sizes)
-        except (ucp.exceptions.UCXBaseException):
+        except BaseException as e:
             self.abort()
-            raise CommClosedError("While writing, the connection was closed")
+            raise CommClosedError("While writing, the connection was closed.\n"
+                                  f"Inner exception: {e!r}")
+        return sum(sizes)
 
     @log_errors
     async def read(self, deserializers=("cuda", "dask", "pickle", "error")):

I've moved the scope of the try/except block because we're now catching a broader range of exceptions.

distributed/comm/ucx.py (outdated)
@pentschev
Member Author

I think we need to catch errors (other than UCXBaseException) on the write side as well as read (since those awaitables may also be cancelled).

I've moved the scope of the try/except block because we're now catching a broader range of exceptions.

As with the try/except changes for read above, I'm a bit concerned about breaking something else right now. I propose we open a new PR to address those two changes and wait to merge them after the upcoming release. WDYT?

Co-authored-by: Lawrence Mitchell <[email protected]>
Comment on lines 333 to 334:

raise CommClosedError("Connection closed by writer.\n"
                      f"Inner exception: {e!r}")
Contributor

Suggested change:

-raise CommClosedError("Connection closed by writer.\n"
-                      f"Inner exception: {e!r}")
+raise CommClosedError(f"Connection closed by writer.\nInner exception: {e!r}")

To pacify the linter.

Contributor

@wence- left a comment


Modulo the lint pacification, this looks good, and we can revisit catching exceptions around write, etc. later.

Can confirm, too, that this fixes ugly-looking UCX errors on disconnect during cluster shutdown for me.

@quasiben
Member

rerun tests

@quasiben
Member

This failure is due to changes in strides for broadcast arrays that occurred in NumPy 1.23.
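For context, a broadcast view reports zero strides along its broadcast axes, which is the kind of metadata such checks look at; this snippet only shows how to inspect it and does not reproduce the NumPy 1.23 change itself.

import numpy as np

a = np.arange(3.0)               # float64, strides (8,)
b = np.broadcast_to(a, (4, 3))   # read-only view, no data copied

print(b.shape)    # (4, 3)
print(b.strides)  # (0, 8): the broadcast axis has stride 0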

@quasiben
Member

@gmarkall pointed me to numpy/numpy#21477 -- I'll fix the pinning in the Docker image to <1.22. Graham, is NumPy 1.23 compatibility something being tracked in Numba?

@gmarkall
Contributor

gmarkall commented Sep 29, 2022

Graham, is NumPy 1.23 compatibility something being tracked in Numba?

Numba 0.56.2 is compatible with NumPy 1.23. I need to determine whether the issue here is np.empty_like() not quite creating an empty array like its argument, or if the array compatibility check in Numba needs a fix for this case.

@gmarkall
Contributor

PR to work around (WAR) the issue: #7089

@quasiben
Member

Numba 0.56.2 is compatible with NumPy 1.23. I need to determine whether the issue here is np.empty_like() not quite creating an empty array like its argument, or if the array compatibility check in Numba needs a fix for this case.

Apologies for the mischaracterization!

@quasiben
Member

Merged with main (including #7089); when tests pass I'll merge.

@quasiben
Member

This is now passing -- merging in. Thanks @pentschev and @wence-!

@pentschev deleted the ucx-base-exception branch October 10, 2022 12:47
gjoseph92 pushed a commit to gjoseph92/distributed that referenced this pull request Oct 31, 2022