Skip to content

Commit

Permalink
make err_handler optional
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Dec 6, 2021
1 parent e744a27 commit b64a420
Showing 1 changed file with 53 additions and 37 deletions.
90 changes: 53 additions & 37 deletions src/UCX.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,19 +535,17 @@ function ucp_err_handler(arg::Ptr{Cvoid}, ep::API.ucp_ep_h, status::API.ucs_stat
return nothing
end

function UCXEndpoint(worker::UCXWorker, ip::IPv4, port)
function UCXEndpoint(worker::UCXWorker, ip::IPv4, port;
error_handling=true)
field_mask = API.UCP_EP_PARAM_FIELD_FLAGS |
API.UCP_EP_PARAM_FIELD_SOCK_ADDR |
API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
API.UCP_EP_PARAM_FIELD_ERR_HANDLER
API.UCP_EP_PARAM_FIELD_SOCK_ADDR
if error_handling
field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
API.UCP_EP_PARAM_FIELD_ERR_HANDLER
end
flags = API.UCP_EP_PARAMS_FLAGS_CLIENT_SERVER
sockaddr = Ref(IP.sockaddr_in(InetAddr(ip, port)))

err_handler = API.ucp_err_handler(
@cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
C_NULL
)

r_handle = Ref{API.ucp_ep_h}()
GC.@preserve sockaddr begin
ptr = Base.unsafe_convert(Ptr{IP.sockaddr_in}, sockaddr)
Expand All @@ -559,72 +557,90 @@ function UCXEndpoint(worker::UCXWorker, ip::IPv4, port)
set!(params, :field_mask, field_mask)
set!(params, :sockaddr, ucs_sockaddr)
set!(params, :flags, flags)
set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
set!(params, :err_handler, err_handler)
if error_handling
err_handler = API.ucp_err_handler(
@cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
C_NULL
)

set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
set!(params, :err_handler, err_handler)
end

@check API.ucp_ep_create(worker, params, r_handle)
end

UCXEndpoint(worker, r_handle[])
end

function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest)
function UCXEndpoint(worker::UCXWorker, conn_request::UCXConnectionRequest;
error_handling=true)
field_mask = API.UCP_EP_PARAM_FIELD_FLAGS |
API.UCP_EP_PARAM_FIELD_CONN_REQUEST |
API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
API.UCP_EP_PARAM_FIELD_ERR_HANDLER
API.UCP_EP_PARAM_FIELD_CONN_REQUEST
if error_handling
field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
API.UCP_EP_PARAM_FIELD_ERR_HANDLER
end

flags = API.UCP_EP_PARAMS_FLAGS_NO_LOOPBACK

err_handler = API.ucp_err_handler(
@cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
C_NULL
)

params = Ref{API.ucp_ep_params}()
memzero!(params)
set!(params, :field_mask, field_mask)
set!(params, :conn_request, conn_request.handle)
set!(params, :flags, flags)
set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
set!(params, :err_handler, err_handler)
if error_handling
err_handler = API.ucp_err_handler(
@cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
C_NULL
)

set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
set!(params, :err_handler, err_handler)
end

r_handle = Ref{API.ucp_ep_h}()
@check API.ucp_ep_create(worker, params, r_handle)

UCXEndpoint(worker, r_handle[])
end

function UCXEndpoint(worker::UCXWorker, addr::UCXAddress)
function UCXEndpoint(worker::UCXWorker, addr::UCXAddress;
error_handling=true)
GC.@preserve addr begin
_UCXEndpoint(worker, addr.handle)
_UCXEndpoint(worker, addr.handle, error_handling)
end
end

function UCXEndpoint(worker::UCXWorker, addr_buf::Vector{UInt8})
function UCXEndpoint(worker::UCXWorker, addr_buf::Vector{UInt8};
error_handling=true)
GC.@preserve addr_buf begin
addr = Base.unsafe_convert(Ptr{API.ucp_address_t}, pointer(addr_buf))
_UCXEndpoint(worker, addr)
_UCXEndpoint(worker, addr, error_handling)
end
end

function _UCXEndpoint(worker::UCXWorker, addr::Ptr{API.ucp_address_t})
field_mask = API.UCP_EP_PARAM_FIELD_REMOTE_ADDRESS |
API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
API.UCP_EP_PARAM_FIELD_ERR_HANDLER

err_handler = API.ucp_err_handler(
@cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
C_NULL
)
function _UCXEndpoint(worker::UCXWorker, addr::Ptr{API.ucp_address_t}, error_handling)
field_mask = API.UCP_EP_PARAM_FIELD_REMOTE_ADDRESS
if error_handling
field_mask |= API.UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE |
API.UCP_EP_PARAM_FIELD_ERR_HANDLER
end

r_handle = Ref{API.ucp_ep_h}()
params = Ref{API.ucp_ep_params}()
memzero!(params)
set!(params, :field_mask, field_mask)
set!(params, :address, addr)
set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
set!(params, :err_handler, err_handler)
if error_handling
err_handler = API.ucp_err_handler(
@cfunction(ucp_err_handler, Cvoid, (Ptr{Cvoid}, API.ucp_ep_h, API.ucs_status_t)),
C_NULL
)

set!(params, :err_mode, API.UCP_ERR_HANDLING_MODE_PEER)
set!(params, :err_handler, err_handler)
end

@check API.ucp_ep_create(worker, params, r_handle)

Expand Down

0 comments on commit b64a420

Please sign in to comment.