From 73d6267fd60c428e33ed4ee3760ffceff8756dd8 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 2 Feb 2021 09:13:24 -0500 Subject: [PATCH 1/2] add highlevel endpoint API --- Project.toml | 1 + examples/client_server.jl | 25 ++++--- examples/client_server_stream.jl | 10 +-- src/UCX.jl | 117 +++++++++++++++++++++++++++---- 4 files changed, 121 insertions(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index 4afc881..ca3f7ab 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.1.0" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" UCX_jll = "16e4e860-d6b8-5056-a518-93e88b6392ae" diff --git a/examples/client_server.jl b/examples/client_server.jl index e79ab1b..a12d6a6 100644 --- a/examples/client_server.jl +++ b/examples/client_server.jl @@ -1,7 +1,6 @@ using UCX using Sockets -using UCX: UCXEndpoint using UCX: recv, send using Base.Threads @@ -9,12 +8,12 @@ using Base.Threads const default_port = 8890 const expected_clients = Atomic{Int}(0) -function echo_server(ep::UCXEndpoint) +function echo_server(ep::UCX.Endpoint) size = Int[0] - recv(ep.worker, size, sizeof(Int), 777) + recv(ep, size, sizeof(Int)) data = Array{UInt8}(undef, size[1]) - recv(ep.worker, data, sizeof(data), 777) - send(ep, data, sizeof(data), 777) + recv(ep, data, sizeof(data)) + send(ep, data, sizeof(data)) atomic_sub!(expected_clients, 1) end @@ -26,10 +25,10 @@ function start_server(ch_port = Channel{Int}(1), port = default_port) conn_request = UCX.UCXConnectionRequest(conn_request_h) Threads.@spawn begin try - echo_server(UCXEndpoint($worker, $conn_request)) + echo_server(UCX.Endpoint($worker, $conn_request)) catch err showerror(stderr, err, catch_backtrace()) - exit(-1) + exit(-1) # Fatal end end nothing @@ -49,13 +48,13 @@ end function 
start_client(port=default_port) ctx = UCX.UCXContext() worker = UCX.UCXWorker(ctx) - ep = UCX.UCXEndpoint(worker, IPv4("127.0.0.1"), port) + ep = UCX.Endpoint(worker, IPv4("127.0.0.1"), port) data = "Hello world" - send(ep, Int[sizeof(data)], sizeof(Int), 777) - send(ep, data, sizeof(data), 777) + send(ep, Int[sizeof(data)], sizeof(Int)) + send(ep, data, sizeof(data)) buffer = Array{UInt8}(undef, sizeof(data)) - recv(worker, buffer, sizeof(buffer), 777) + recv(ep, buffer, sizeof(buffer)) @assert String(buffer) == data end @@ -70,10 +69,10 @@ if !isinteractive() elseif kind =="test" ch_port = Channel{Int}(1) @sync begin - UCX.@async_showerr start_server(ch_port, nothing) + UCX.@spawn_showerr start_server(ch_port, nothing) port = take!(ch_port) for i in 1:expected_clients[] - UCX.@async_showerr start_client(port) + UCX.@spawn_showerr start_client(port) end end end diff --git a/examples/client_server_stream.jl b/examples/client_server_stream.jl index 55a07f4..e72aaf9 100644 --- a/examples/client_server_stream.jl +++ b/examples/client_server_stream.jl @@ -9,9 +9,9 @@ using Base.Threads const default_port = 8890 const expected_clients = Atomic{Int}(0) -function echo_server(ep::UCXEndpoint) +function echo_server(ep::UCX.Endpoint) size = Int[0] - recv(ep.worker, size, sizeof(Int), 777) + recv(ep, size, sizeof(Int)) data = Array{UInt8}(undef, size[1]) stream_recv(ep, data, sizeof(data)) stream_send(ep, data, sizeof(data)) @@ -26,7 +26,7 @@ function start_server(ch_port = Channel{Int}(1), port = default_port) conn_request = UCX.UCXConnectionRequest(conn_request_h) Threads.@spawn begin try - echo_server(UCXEndpoint($worker, $conn_request)) + echo_server(UCX.Endpoint($worker, $conn_request)) catch err showerror(stderr, err, catch_backtrace()) exit(-1) # Fatal error @@ -49,10 +49,10 @@ end function start_client(port = default_port) ctx = UCX.UCXContext() worker = UCX.UCXWorker(ctx) - ep = UCX.UCXEndpoint(worker, IPv4("127.0.0.1"), port) + ep = UCX.Endpoint(worker, 
IPv4("127.0.0.1"), port) data = "Hello world" - send(ep, Int[sizeof(data)], sizeof(Int), 777) + send(ep, Int[sizeof(data)], sizeof(Int)) stream_send(ep, data, sizeof(data)) buffer = Array{UInt8}(undef, sizeof(data)) stream_recv(ep, buffer, sizeof(buffer)) diff --git a/src/UCX.jl b/src/UCX.jl index 16f681a..4b3bab3 100644 --- a/src/UCX.jl +++ b/src/UCX.jl @@ -1,6 +1,7 @@ module UCX using Sockets: InetAddr, IPv4, listenany +using Random include("api.jl") @@ -376,10 +377,12 @@ end ## function send_callback(request::Ptr{Cvoid}, status::API.ucs_status_t) + @check status nothing end function recv_callback(request::Ptr{Cvoid}, status::API.ucs_status_t, info::Ptr{API.ucp_tag_recv_info_t}) + @check status nothing end @@ -417,7 +420,7 @@ function recv(worker::UCXWorker, buffer, nbytes, tag, tag_mask=~zero(UCX.API.ucp dt = ucp_dt_make_contig(1) cb = @cfunction(recv_callback, Cvoid, (Ptr{Cvoid}, API.ucs_status_t, Ptr{API.ucp_tag_recv_info_t})) - GC.@preserve buffer begin + GC.@preserve buffer cb begin data = pointer(buffer) ptr = API.ucp_tag_recv_nb(worker, data, nbytes, dt, tag, tag_mask, cb) return handle_request(worker, ptr) @@ -453,37 +456,123 @@ end # UCX stream interface -function stream_send(ep::UCXEndpoint, buffer, nbytes) - dt = ucp_dt_make_contig(1) # since we are sending nbytes - cb = @cfunction(send_callback, Cvoid, (Ptr{Cvoid}, API.ucs_status_t)) +function stream_recv_callback(request::Ptr{Cvoid}, status::API.ucs_status_t, length::Csize_t) + @check status + nothing +end +function stream_send(ep::UCXEndpoint, buffer, nbytes) GC.@preserve buffer begin data = pointer(buffer) + stream_send(ep, data, nbytes) + end +end - ptr = API.ucp_stream_send_nb(ep, data, nbytes, dt, cb, #=flags=# 0) - return handle_request(ep, ptr) +function stream_send(ep::UCXEndpoint, ref::Ref{T}) where T + GC.@preserve ref begin + data = Base.unsafe_convert(Ptr{Cvoid}, ref) + stream_send(ep, data, sizeof(T)) end end -function stream_recv(ep::UCXEndpoint, buffer, nbytes) +function 
stream_send(ep::UCXEndpoint, data::Ptr, nbytes) dt = ucp_dt_make_contig(1) # since we are sending nbytes cb = @cfunction(send_callback, Cvoid, (Ptr{Cvoid}, API.ucs_status_t)) + ptr = API.ucp_stream_send_nb(ep, data, nbytes, dt, cb, #=flags=# 0) + return handle_request(ep, ptr) +end + +function stream_recv(ep::UCXEndpoint, buffer, nbytes) GC.@preserve buffer begin data = pointer(buffer) + stream_recv(ep, data, nbytes) + end +end - length = Ref{Csize_t}(0) - ptr = API.ucp_stream_recv_nb(ep, data, nbytes, dt, cb, length, API.UCP_STREAM_RECV_FLAG_WAITALL) - return handle_request(ep, ptr) +function stream_recv(ep::UCXEndpoint, ref::Ref{T}) where T + GC.@preserve ref begin + data = Base.unsafe_convert(Ptr{Cvoid}, ref) + stream_recv(ep, data, sizeof(T)) + end +end + +function stream_recv(ep::UCXEndpoint, data::Ptr, nbytes) + dt = ucp_dt_make_contig(1) # since we are receiving nbytes + cb = @cfunction(stream_recv_callback, Cvoid, (Ptr{Cvoid}, API.ucs_status_t, Csize_t)) + + length = Ref{Csize_t}(0) + ptr = API.ucp_stream_recv_nb(ep, data, nbytes, dt, cb, length, API.UCP_STREAM_RECV_FLAG_WAITALL) + return handle_request(ep, ptr) +end + +### TODO: stream_recv_data_nb +### TODO: stream_recv_nbx + +## RMA + +## Atomics + +## AM + +## Collectives + +# Higher-Level API + +mutable struct Endpoint + ep::UCXEndpoint + msg_tag_send::API.ucp_tag_t + msg_tag_recv::API.ucp_tag_t +end + +# TODO: Tag structure +# OMPI uses msg_tag (24) | source_rank (20) | context_id (20) + +tag(kind, seed, port) = hash(kind, hash(seed, hash(port))) + +function Endpoint(worker::UCXWorker, addr, port) + ep = UCX.UCXEndpoint(worker, addr, port) + Endpoint(ep, false) +end + +function Endpoint(worker::UCXWorker, connection::UCXConnectionRequest) + ep = UCX.UCXEndpoint(worker, connection) + Endpoint(ep, true) +end + +function Endpoint(ep::UCXEndpoint, listener) + seed = rand(UInt128) + pid = getpid() + msg_tag = tag(:ctrl, seed, pid) + + send_tag = Ref(msg_tag) + recv_tag = Ref(msg_tag) + if listener +
stream_send(ep, send_tag) + stream_recv(ep, recv_tag) + else + stream_recv(ep, recv_tag) + stream_send(ep, send_tag) end + @assert msg_tag !== recv_tag[] + + Endpoint(ep, msg_tag, recv_tag[]) end -# RMA +function send(ep::Endpoint, buffer, nbytes) + send(ep.ep, buffer, nbytes, ep.msg_tag_send) +end -# Atomics +function recv(ep::Endpoint, buffer, nbytes) + recv(ep.ep.worker, buffer, nbytes, ep.msg_tag_recv) +end -# AM +function stream_send(ep::Endpoint, args...) + stream_send(ep.ep, args...) +end -# Collectives +function stream_recv(ep::Endpoint, args...) + stream_recv(ep.ep, args...) +end end #module From b572fd251fe7a6c330d06622487764b2ac42b3aa Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Tue, 2 Feb 2021 09:14:10 -0500 Subject: [PATCH 2/2] fixup! add highlevel endpoint API --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8f88032..75ec1b7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,7 +22,7 @@ end @testset "Client-Server" begin script = joinpath(examples_dir, "client_server.jl") - for i in 0:0 + for i in 0:2 @test success(pipeline(`$cmd $script test $(2^i)`, stderr=stderr, stdout=stdout)) end end