Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add highlevel endpoint API #24

Merged
merged 2 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.0"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
UCX_jll = "16e4e860-d6b8-5056-a518-93e88b6392ae"

Expand Down
25 changes: 12 additions & 13 deletions examples/client_server.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
using UCX
using Sockets

using UCX: UCXEndpoint
using UCX: recv, send

using Base.Threads

const default_port = 8890
const expected_clients = Atomic{Int}(0)

function echo_server(ep::UCXEndpoint)
function echo_server(ep::UCX.Endpoint)
size = Int[0]
recv(ep.worker, size, sizeof(Int), 777)
recv(ep, size, sizeof(Int))
data = Array{UInt8}(undef, size[1])
recv(ep.worker, data, sizeof(data), 777)
send(ep, data, sizeof(data), 777)
recv(ep, data, sizeof(data))
send(ep, data, sizeof(data))
atomic_sub!(expected_clients, 1)
end

Expand All @@ -26,10 +25,10 @@ function start_server(ch_port = Channel{Int}(1), port = default_port)
conn_request = UCX.UCXConnectionRequest(conn_request_h)
Threads.@spawn begin
try
echo_server(UCXEndpoint($worker, $conn_request))
echo_server(UCX.Endpoint($worker, $conn_request))
catch err
showerror(stderr, err, catch_backtrace())
exit(-1)
exit(-1) # Fatal
end
end
nothing
Expand All @@ -49,13 +48,13 @@ end
function start_client(port=default_port)
ctx = UCX.UCXContext()
worker = UCX.UCXWorker(ctx)
ep = UCX.UCXEndpoint(worker, IPv4("127.0.0.1"), port)
ep = UCX.Endpoint(worker, IPv4("127.0.0.1"), port)

data = "Hello world"
send(ep, Int[sizeof(data)], sizeof(Int), 777)
send(ep, data, sizeof(data), 777)
send(ep, Int[sizeof(data)], sizeof(Int))
send(ep, data, sizeof(data))
buffer = Array{UInt8}(undef, sizeof(data))
recv(worker, buffer, sizeof(buffer), 777)
recv(ep, buffer, sizeof(buffer))
@assert String(buffer) == data
end

Expand All @@ -70,10 +69,10 @@ if !isinteractive()
elseif kind =="test"
ch_port = Channel{Int}(1)
@sync begin
UCX.@async_showerr start_server(ch_port, nothing)
UCX.@spawn_showerr start_server(ch_port, nothing)
port = take!(ch_port)
for i in 1:expected_clients[]
UCX.@async_showerr start_client(port)
UCX.@spawn_showerr start_client(port)
end
end
end
Expand Down
10 changes: 5 additions & 5 deletions examples/client_server_stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ using Base.Threads
const default_port = 8890
const expected_clients = Atomic{Int}(0)

function echo_server(ep::UCXEndpoint)
function echo_server(ep::UCX.Endpoint)
size = Int[0]
recv(ep.worker, size, sizeof(Int), 777)
recv(ep, size, sizeof(Int))
data = Array{UInt8}(undef, size[1])
stream_recv(ep, data, sizeof(data))
stream_send(ep, data, sizeof(data))
Expand All @@ -26,7 +26,7 @@ function start_server(ch_port = Channel{Int}(1), port = default_port)
conn_request = UCX.UCXConnectionRequest(conn_request_h)
Threads.@spawn begin
try
echo_server(UCXEndpoint($worker, $conn_request))
echo_server(UCX.Endpoint($worker, $conn_request))
catch err
showerror(stderr, err, catch_backtrace())
exit(-1) # Fatal error
Expand All @@ -49,10 +49,10 @@ end
function start_client(port = default_port)
ctx = UCX.UCXContext()
worker = UCX.UCXWorker(ctx)
ep = UCX.UCXEndpoint(worker, IPv4("127.0.0.1"), port)
ep = UCX.Endpoint(worker, IPv4("127.0.0.1"), port)

data = "Hello world"
send(ep, Int[sizeof(data)], sizeof(Int), 777)
send(ep, Int[sizeof(data)], sizeof(Int))
stream_send(ep, data, sizeof(data))
buffer = Array{UInt8}(undef, sizeof(data))
stream_recv(ep, buffer, sizeof(buffer))
Expand Down
117 changes: 103 additions & 14 deletions src/UCX.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module UCX

using Sockets: InetAddr, IPv4, listenany
using Random

include("api.jl")

Expand Down Expand Up @@ -376,10 +377,12 @@ end
##

function send_callback(request::Ptr{Cvoid}, status::API.ucs_status_t)
@check status
nothing
end

function recv_callback(request::Ptr{Cvoid}, status::API.ucs_status_t, info::Ptr{API.ucp_tag_recv_info_t})
@check status
nothing
end

Expand Down Expand Up @@ -417,7 +420,7 @@ function recv(worker::UCXWorker, buffer, nbytes, tag, tag_mask=~zero(UCX.API.ucp
dt = ucp_dt_make_contig(1)
cb = @cfunction(recv_callback, Cvoid, (Ptr{Cvoid}, API.ucs_status_t, Ptr{API.ucp_tag_recv_info_t}))

GC.@preserve buffer begin
GC.@preserve buffer cb begin
data = pointer(buffer)
ptr = API.ucp_tag_recv_nb(worker, data, nbytes, dt, tag, tag_mask, cb)
return handle_request(worker, ptr)
Expand Down Expand Up @@ -453,37 +456,123 @@ end

# UCX stream interface

function stream_send(ep::UCXEndpoint, buffer, nbytes)
dt = ucp_dt_make_contig(1) # since we are sending nbytes
cb = @cfunction(send_callback, Cvoid, (Ptr{Cvoid}, API.ucs_status_t))
function stream_recv_callback(request::Ptr{Cvoid}, status::API.ucs_status_t, length::Csize_t)
@check status
nothing
end

function stream_send(ep::UCXEndpoint, buffer, nbytes)
GC.@preserve buffer begin
data = pointer(buffer)
stream_send(ep, data, nbytes)
end
end

ptr = API.ucp_stream_send_nb(ep, data, nbytes, dt, cb, #=flags=# 0)
return handle_request(ep, ptr)
function stream_send(ep::UCXEndpoint, ref::Ref{T}) where T
GC.@preserve ref begin
data = Base.unsafe_convert(Ptr{Cvoid}, ref)
stream_send(ep, data, sizeof(T))
end
end

function stream_recv(ep::UCXEndpoint, buffer, nbytes)
function stream_send(ep::UCXEndpoint, data::Ptr, nbytes)
dt = ucp_dt_make_contig(1) # since we are sending nbytes
cb = @cfunction(send_callback, Cvoid, (Ptr{Cvoid}, API.ucs_status_t))

ptr = API.ucp_stream_send_nb(ep, data, nbytes, dt, cb, #=flags=# 0)
return handle_request(ep, ptr)
end

function stream_recv(ep::UCXEndpoint, buffer, nbytes)
GC.@preserve buffer begin
data = pointer(buffer)
stream_recv(ep, data, nbytes)
end
end

length = Ref{Csize_t}(0)
ptr = API.ucp_stream_recv_nb(ep, data, nbytes, dt, cb, length, API.UCP_STREAM_RECV_FLAG_WAITALL)
return handle_request(ep, ptr)
function stream_recv(ep::UCXEndpoint, ref::Ref{T}) where T
GC.@preserve ref begin
data = Base.unsafe_convert(Ptr{Cvoid}, ref)
stream_recv(ep, data, sizeof(T))
end
end

function stream_recv(ep::UCXEndpoint, data::Ptr, nbytes)
dt = ucp_dt_make_contig(1) # since we are sending nbytes
cb = @cfunction(stream_recv_callback, Cvoid, (Ptr{Cvoid}, API.ucs_status_t, Csize_t))

length = Ref{Csize_t}(0)
ptr = API.ucp_stream_recv_nb(ep, data, nbytes, dt, cb, length, API.UCP_STREAM_RECV_FLAG_WAITALL)
return handle_request(ep, ptr)
end

### TODO: stream_recv_data_nb
### TODO: stream_recv_nbx

## RMA

## Atomics

## AM

## Collectives

# Higher-Level API

mutable struct Endpoint
ep::UCXEndpoint
msg_tag_send::API.ucp_tag_t
msg_tag_recv::API.ucp_tag_t
end

# TODO: Tag structure
# OMPI uses msg_tag (24) | source_rank (20) | context_id (20)

tag(kind, seed, port) = hash(kind, hash(seed, hash(port)))

function Endpoint(worker::UCXWorker, addr, port)
ep = UCX.UCXEndpoint(worker, addr, port)
Endpoint(ep, false)
end

function Endpoint(worker::UCXWorker, connection::UCXConnectionRequest)
ep = UCX.UCXEndpoint(worker, connection)
Endpoint(ep, true)
end

function Endpoint(ep::UCXEndpoint, listener)
seed = rand(UInt128)
pid = getpid()
msg_tag = tag(:ctrl, seed, pid)

send_tag = Ref(msg_tag)
recv_tag = Ref(msg_tag)
if listener
stream_send(ep, send_tag)
stream_recv(ep, recv_tag)
else
stream_recv(ep, recv_tag)
stream_send(ep, send_tag)
end
@assert msg_tag !== recv_tag[]

Endpoint(ep, msg_tag, recv_tag[])
end

# RMA
function send(ep::Endpoint, buffer, nbytes)
send(ep.ep, buffer, nbytes, ep.msg_tag_send)
end

# Atomics
function recv(ep::Endpoint, buffer, nbytes)
recv(ep.ep.worker, buffer, nbytes, ep.msg_tag_recv)
end

# AM
function stream_send(ep::Endpoint, args...)
stream_send(ep.ep, args...)
end

# Collectives
function stream_recv(ep::Endpoint, args...)
stream_recv(ep.ep, args...)
end

end #module
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ end

@testset "Client-Server" begin
script = joinpath(examples_dir, "client_server.jl")
for i in 0:0
for i in 0:2
@test success(pipeline(`$cmd $script test $(2^i)`, stderr=stderr, stdout=stdout))
end
end
Expand Down