Skip to content

Commit

Permalink
Merge pull request #48 from JuliaParallel/vc/config2
Browse files Browse the repository at this point in the history
allow fine-grained control over context features
  • Loading branch information
vchuravy authored Dec 13, 2021
2 parents 515a237 + 7b29025 commit de1bed2
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/Documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
branches:
- main
tags: "*"
pull-request:
jobs:
docs:
name: Documentation
Expand All @@ -19,6 +20,7 @@ jobs:
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()
Pkg.build("UCX")
'
- name: run doctests
run: |
Expand Down
88 changes: 76 additions & 12 deletions src/UCX.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module UCX
using Sockets: InetAddr, IPv4, listenany
using Random
import FunctionWrappers: FunctionWrapper
import CEnum

const PROGRESS_MODE = Ref(:idling)

Expand Down Expand Up @@ -117,6 +118,21 @@ function version()
VersionNumber(major[], minor[], patch[])
end

function query()
field_mask = API.UCP_LIB_ATTR_FIELD_MAX_THREAD_LEVEL
r_attr = Ref{API.ucp_lib_attr_t}()

memzero!(r_attr)
set!(r_attr, :field_mask, field_mask)

@check API.ucp_lib_query(r_attr)
r_attr[]
end

function max_thread_level()
query().max_thread_level
end

mutable struct UCXConfig
handle::Ptr{API.ucp_config_t}

Expand Down Expand Up @@ -171,37 +187,57 @@ end

mutable struct UCXContext
handle::API.ucp_context_h
features::UInt32
config::Dict{Symbol, String}

function UCXContext(wakeup = true; kwargs...)
field_mask = API.UCP_PARAM_FIELD_FEATURES
function UCXContext(;wakeup = true, tag = true, stream = true, am = true, rdma = true, amo = true,
shared = false,
kwargs...)
field_mask = API.UCP_PARAM_FIELD_FEATURES

# Note: ucx-py always request UCP_FEATURE_WAKEUP even when in blocking mode
# See <https://github.com/rapidsai/ucx-py/pull/377>

# There is also AMO32 & AMO64 (atomic), RMA,
features = API.UCP_FEATURE_TAG |
API.UCP_FEATURE_STREAM |
API.UCP_FEATURE_AM |
API.UCP_FEATURE_RMA
if shared
field_mask |= API.UCP_PARAM_FIELD_MT_WORKERS_SHARED
end

features = zero(CEnum.basetype(UCX.API.ucp_feature))
if wakeup
features |= API.UCP_FEATURE_WAKEUP
end
if tag
features |= API.UCP_FEATURE_TAG
end
if stream
features |= API.UCP_FEATURE_STREAM
end
if am
features |= API.UCP_FEATURE_AM
end
if rdma
features |= API.UCP_FEATURE_RMA
end
if amo
features |= API.UCP_FEATURE_AMO32 |
API.UCP_FEATURE_AMO64 |
API.UCP_FEATURE_RMA
end

params = Ref{API.ucp_params}()
memzero!(params)
set!(params, :field_mask, field_mask)
set!(params, :features, features)

if shared
set!(params, :mt_workers_shared, true)
end

config = UCXConfig(; kwargs...)

r_handle = Ref{API.ucp_context_h}()
# UCP.ucp_init is a header function so we call, UCP.ucp_init_version
@check API.ucp_init_version(API.UCP_API_MAJOR, API.UCP_API_MINOR,
params, config, r_handle)

context = new(r_handle[], parse(Dict, config))
context = new(r_handle[], features, parse(Dict, config))

finalizer(context) do context
API.ucp_cleanup(context)
Expand Down Expand Up @@ -231,11 +267,20 @@ function info(ucx::UCXContext)
end

function query(ctx::UCXContext)
field_mask = API.UCP_ATTR_FIELD_THREAD_MODE
r_attr = Ref{API.ucp_context_attr_t}()
API.ucp_context_query(ctx, r_attr)

memzero!(r_attr)
set!(r_attr, :field_mask, field_mask)

@check API.ucp_context_query(ctx, r_attr)
r_attr[]
end

function thread_mode(ctx::UCXContext)
query(ctx).thread_mode
end

mutable struct UCXWorker
handle::API.ucp_worker_h
fd::RawFD
Expand Down Expand Up @@ -281,6 +326,25 @@ mutable struct UCXWorker
end
Base.unsafe_convert(::Type{API.ucp_worker_h}, worker::UCXWorker) = worker.handle

function query(worker::UCXWorker)
field_mask = API.UCP_WORKER_ATTR_FIELD_THREAD_MODE |
API.UCP_WORKER_ATTR_FIELD_MAX_AM_HEADER
r_attr = Ref{API.ucp_worker_attr_t}()

memzero!(r_attr)
set!(r_attr, :field_mask, field_mask)
API.ucp_worker_query(worker, r_attr)
r_attr[]
end

function thread_mode(worker::UCXWorker)
query(worker).thread_mode
end

function max_am_header(worker::UCXWorker)
query(worker).max_am_header
end

ispolling(worker::UCXWorker) = worker.fd != RawFD(-1)
progress_mode(worker::UCXWorker) = worker.mode
context(worker::UCXWorker) = worker.context
Expand Down
23 changes: 23 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,29 @@ using UCX
@test ctx.config[:TLS] == "tcp"
end

@testset "query" begin
# TODO support other builds
UCX.max_thread_level() == UCX.API.UCS_THREAD_MODE_MULTI
end

@testset "context" begin
ctx = UCX.UCXContext(TLS="tcp")
@test ctx.config[:TLS] == "tcp"

@test UCX.thread_mode(UCX.UCXContext(; shared=false)) == UCX.API.UCS_THREAD_MODE_SINGLE
@test UCX.thread_mode(UCX.UCXContext(; shared=true)) == UCX.API.UCS_THREAD_MODE_MULTI
end

@testset "Worker" begin
worker = UCX.UCXWorker(UCX.UCXContext())

UCX.thread_mode(worker) == UCX.API.UCS_THREAD_MODE_MULTI
UCX.max_am_header(worker) > 0

worker = UCX.UCXWorker(UCX.UCXContext(; am = false))
UCX.max_am_header(worker) == 0
end

@testset "progress" begin
using UCX
UCX.PROGRESS_MODE[] = :polling
Expand Down

0 comments on commit de1bed2

Please sign in to comment.