From ae7a83baa077fc7aec7ffae145211e02ca111b23 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 2 Dec 2021 16:45:49 -0500 Subject: [PATCH 1/2] Add error handler to endpoint --- src/UCX.jl | 70 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/src/UCX.jl b/src/UCX.jl index 5b89c83..5ee540f 100644 --- a/src/UCX.jl +++ b/src/UCX.jl @@ -532,9 +532,20 @@ mutable struct UCXEndpoint end Base.unsafe_convert(::Type{API.ucp_ep_h}, ep::UCXEndpoint) = ep.handle -function UCXEndpoint(worker::UCXWorker, ip::IPv4, port) +function ucp_err_handler(arg::Ptr{Cvoid}, ep::API.ucp_ep_h, status::API.ucs_status_t) + @error "Endpoint error" exception=UCXException(status) + # TODO should we throw here and close the endpoint? + return nothing +end + +function UCXEndpoint(worker::UCXWorker, ip::IPv4, port; + error_handling=true) field_mask = API.UCP_EP_PARAM_FIELD_FLAGS | API.UCP_EP_PARAM_FIELD_SOCK_ADDR + if error_handling + field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | + API.UCP_EP_PARAM_FIELD_ERR_HANDLER + end flags = API.UCP_EP_PARAMS_FLAGS_CLIENT_SERVER sockaddr = Ref(IP.sockaddr_in(InetAddr(ip, port))) @@ -549,8 +560,15 @@ function UCXEndpoint(worker::UCXWorker, ip::IPv4, port) set!(params, :field_mask, field_mask) set!(params, :sockaddr, ucs_sockaddr) set!(params, :flags, flags) - - # TODO: Error callback + if error_handling + err_handler = API.ucp_err_handler( + @cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)), + C_NULL + ) + + set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER) + set!(params, :err_handler, err_handler) + end @check API.ucp_ep_create(worker, params, r_handle) end @@ -558,9 +576,15 @@ function UCXEndpoint(worker::UCXWorker, ip::IPv4, port) UCXEndpoint(worker, r_handle[]) end -function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest) +function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest; + error_handling=true) field_mask = API.UCP_EP_PARAM_FIELD_FLAGS | API.UCP_EP_PARAM_FIELD_CONN_REQUEST + if error_handling + field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | + API.UCP_EP_PARAM_FIELD_ERR_HANDLER + end + flags = API.UCP_EP_PARAMS_FLAGS_NO_LOOPBACK params = Ref{API.ucp_ep_params}() @@ -568,8 +592,15 @@ function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest) set!(params, :field_mask, field_mask) set!(params, :conn_request, conn_request.handle) set!(params, :flags, flags) - - # TODO: Error callback + if error_handling + err_handler = API.ucp_err_handler( + @cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)), + C_NULL + ) + + set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER) + set!(params, :err_handler, err_handler) + end r_handle = Ref{API.ucp_ep_h}() @check API.ucp_ep_create(worker, params, r_handle) @@ -577,29 +608,42 @@ function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest) UCXEndpoint(worker, r_handle[]) end -function UCXEndpoint(worker::UCXWorker, addr::UCXAddress) +function UCXEndpoint(worker::UCXWorker, addr::UCXAddress; + error_handling=true) GC.@preserve addr begin - _UCXEndpoint(worker, addr.handle) + _UCXEndpoint(worker, addr.handle, error_handling) end end -function UCXEndpoint(worker::UCXWorker, addr_buf::Vector{UInt8}) +function UCXEndpoint(worker::UCXWorker, addr_buf::Vector{UInt8}; + error_handling=true) GC.@preserve addr_buf begin addr = Base.unsafe_convert(Ptr{API.ucp_address_t}, pointer(addr_buf)) - _UCXEndpoint(worker, addr) + _UCXEndpoint(worker, addr, error_handling) end end -function _UCXEndpoint(worker::UCXWorker, addr::Ptr{API.ucp_address_t}) +function _UCXEndpoint(worker::UCXWorker, addr::Ptr{API.ucp_address_t}, error_handling) field_mask = API.UCP_EP_PARAM_FIELD_REMOTE_ADDRESS + if error_handling + field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | + API.UCP_EP_PARAM_FIELD_ERR_HANDLER + end r_handle = Ref{API.ucp_ep_h}() params = Ref{API.ucp_ep_params}() memzero!(params) set!(params, :field_mask, field_mask) set!(params, :address, addr) - - # TODO: Error callback + if error_handling + err_handler = API.ucp_err_handler( + @cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)), + C_NULL + ) + + set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER) + set!(params, :err_handler, err_handler) + end @check API.ucp_ep_create(worker, params, r_handle) From 728c64d1109b483a202e6deb28976f665214d963 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 12 Dec 2021 20:45:26 -0500 Subject: [PATCH 2/2] add test --- src/UCX.jl | 30 ++++++++++++++++-------------- test/runtests.jl | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/UCX.jl b/src/UCX.jl index 5ee540f..9337800 100644 --- a/src/UCX.jl +++ b/src/UCX.jl @@ -271,11 +271,7 @@ mutable struct UCXWorker end worker = new(handle, fd, context, IdDict{Any,Nothing}(), Dict{UInt16, Any}(), fill(false, Base.Threads.nthreads()), true, progress_mode) - finalizer(worker) do worker - worker.open = false - @assert isempty(worker.inflight) - API.ucp_worker_destroy(worker) - end + finalizer(destroy, worker) return worker end end @@ -313,6 +309,15 @@ function fence(worker::UCXWorker) @check API.ucp_worker_fence(worker) end +function destroy(worker::UCXWorker) + if worker.handle != C_NULL + close(worker) + @assert isempty(worker.inflight) + API.ucp_worker_destroy(worker) + worker.handle = C_NULL + end +end + function lock_am(worker::UCXWorker) tid = Base.Threads.threadid() if worker.in_amhandler[tid] @@ -387,18 +392,16 @@ function Base.notify(worker::UCXWorker) end function Base.isopen(worker::UCXWorker) - worker.open + worker.open && worker.handle != C_NULL end function Base.close(worker::UCXWorker) - @debug "Close worker" - worker.open = false - notify(worker) + if isopen(worker) + worker.open = false + notify(worker) + end end - - - """ AMHandler(func) @@ -533,8 +536,7 @@ end Base.unsafe_convert(::Type{API.ucp_ep_h}, ep::UCXEndpoint) = ep.handle function ucp_err_handler(arg::Ptr{Cvoid}, ep::API.ucp_ep_h, status::API.ucs_status_t) - @error "Endpoint error" exception=UCXException(status) - # TODO should we throw here and close the endpoint? + throw(UCXException(status)) return nothing end diff --git a/test/runtests.jl b/test/runtests.jl index 70bfc14..f4956e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,45 @@ end @test addr.len > 0 end +@testset "Error handler" begin + ctx = UCX.UCXContext() + server = UCX.UCXWorker(ctx) + + UCX.@spawn_showerr begin + while isopen(server) + wait(server) + end + close(server) + end + + barrier = Base.Event() + + am_called = Ref{Int}(0) + AM_TEST = 1 + function am_test(worker, header, header_length, data, length, _param) + am_called[] += 1 + notify(barrier) + return UCX.API.UCS_OK + end + UCX.AMHandler(server, am_test, AM_TEST) + + server_addr = UCX.UCXAddress(server) + client = UCX.UCXWorker(ctx) + ep = UCX.UCXEndpoint(client, server_addr) + + req = UCX.am_send(ep, AM_TEST, Int[]) + wait(req) # wait on request to be send before suspending in `take!` + wait(barrier) + + @test am_called[] == 1 + + UCX.destroy(server) + barrier = Base.Event() + req = UCX.am_send(ep, AM_TEST, Int[]) + @test_throws UCX.UCXException wait(req) # wait on request to be send before suspending in `take!` + @test am_called[] == 1 +end + @testset "Active Messages" begin cmd = Base.julia_cmd() if Base.JLOptions().project != C_NULL