Skip to content

Commit

Permalink
Some tweaks to the @tasks macro (#62)
Browse files Browse the repository at this point in the history
* move macro body to the internals module

* only accept one argument, error if the user gives a complex loop assignemnt

* remove double quoting, reformat comment

* remove outdated code

* re-org
  • Loading branch information
MasonProtter authored Mar 1, 2024
1 parent 7b8af0e commit 2d9acb4
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 188 deletions.
2 changes: 2 additions & 0 deletions src/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using BangBang: append!!

using ChunkSplitters: ChunkSplitters

include("macro_impl.jl")

function auto_disable_chunking_warning()
@warn("You passed in a `ChunkSplitters.Chunk` but also a scheduler that has "*
"chunking enabled. Will turn off internal chunking to proceed.\n"*
Expand Down
174 changes: 174 additions & 0 deletions src/macro_impl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
function tasks_macro(forex)
if forex.head != :for
throw(ErrorException("Expected a for loop after `@tasks`."))
else
if forex.args[1].head != :(=)
# this'll catch cases like
# @tasks for _ ∈ 1:10, _ ∈ 1:10
# body
# end
throw(ErrorException("`@tasks` currently only supports a single threaded loop, got $(forex.args[1])"))
end
it = forex.args[1]
itvar = it.args[1]
itrng = it.args[2]
forbody = forex.args[2]
end

settings = Settings()

inits_before, init_inner = _maybe_handle_init_block!(forbody.args)
_maybe_handle_set_block!(settings, forbody.args)

forbody = esc(forbody)
itrng = esc(itrng)
itvar = esc(itvar)

# @show settings
q = if !isnothing(settings.reducer)
quote
tmapreduce(
$(settings.reducer), $(itrng); scheduler = $(settings.scheduler)) do $(itvar)
$(init_inner)
$(forbody)
end
end
elseif settings.collect
quote
tmap($(itrng); scheduler = $(settings.scheduler)) do $(itvar)
$(init_inner)
$(forbody)
end
end
else
quote
tforeach($(itrng); scheduler = $(settings.scheduler)) do $(itvar)
$(init_inner)
$(forbody)
end
end
end

# wrap everything in a let ... end block
# and, potentially, define the `TaskLocalValue`s.
result = :(let
end)
push!(result.args[2].args, q)
if !isnothing(inits_before)
for x in inits_before
push!(result.args[1].args, x)
end
end

result
end

Base.@kwdef mutable struct Settings
scheduler::Expr = :(DynamicScheduler())
reducer::Union{Expr, Symbol, Nothing} = nothing
collect::Bool = false
end

function _sym2scheduler(s)
if s == :dynamic
:(DynamicScheduler())
elseif s == :static
:(StaticScheduler())
elseif s == :greedy
:(GreedyScheduler())
else
throw(ArgumentError("Unknown scheduler symbol."))
end
end

function _maybe_handle_init_block!(args)
inits_before = nothing
init_inner = nothing
tlsidx = findfirst(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@init")
end
if !isnothing(tlsidx)
inits_before, init_inner = _unfold_init_block(args[tlsidx].args[3])
deleteat!(args, tlsidx)
end
return inits_before, init_inner
end

function _unfold_init_block(ex)
inits_before = Expr[]
if ex.head == :(=)
initb, init_inner = _init_assign_to_exprs(ex)
push!(inits_before, initb)
elseif ex.head == :block
tlsexprs = filter(x -> x isa Expr, ex.args) # skip LineNumberNode
init_inner = quote end
for x in tlsexprs
initb, initi = _init_assign_to_exprs(x)
push!(inits_before, initb)
push!(init_inner.args, initi)
end
else
throw(ErrorException("Wrong usage of @init. You must either provide a typed assignment or multiple typed assignments in a `begin ... end` block."))
end
return inits_before, init_inner
end

function _init_assign_to_exprs(ex)
left_ex = ex.args[1]
if left_ex isa Symbol || left_ex.head != :(::)
throw(ErrorException("Wrong usage of @init. Expected typed assignment, e.g. `A::Matrix{Float} = rand(2,2)`."))
end
tls_sym = esc(left_ex.args[1])
tls_type = esc(left_ex.args[2])
tls_def = esc(ex.args[2])
@gensym tls_storage
init_before = :($(tls_storage) = TaskLocalValue{$tls_type}(() -> $(tls_def)))
init_inner = :($(tls_sym) = $(tls_storage)[])
return init_before, init_inner
end

function _maybe_handle_set_block!(settings, args)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set")
end
isnothing(idcs) && return # no set block found
for i in idcs
ex = args[i].args[3]
if ex.head == :(=)
_handle_set_single_assign!(settings, ex)
elseif ex.head == :block
exprs = filter(x -> x isa Expr, ex.args) # skip LineNumberNode
_handle_set_single_assign!.(Ref(settings), exprs)
else
throw(ErrorException("Wrong usage of @set. You must either provide an assignment or multiple assignments in a `begin ... end` block."))
end
end
deleteat!(args, idcs)
# check incompatible settings
if settings.collect && !isnothing(settings.reducer)
throw(ArgumentError("Specifying both collect and reducer isn't supported."))
end
end

function _handle_set_single_assign!(settings, ex)
if ex.head != :(=)
throw(ErrorException("Wrong usage of @set. Expected assignment, e.g. `scheduler = StaticScheduler()`."))
end
sym = ex.args[1]
if !hasfield(Settings, sym)
throw(ArgumentError("Unknown setting \"$(sym)\". Must be ∈ $(fieldnames(Settings))."))
end
def = ex.args[2]
if sym == :collect && !(def isa Bool)
throw(ArgumentError("Setting collect can only be true or false."))
#TODO support specifying the OutputElementType
end
def = if def isa QuoteNode
_sym2scheduler(def.value)
elseif def isa Bool
def
else
esc(def)
end
setfield!(settings, sym, def)
end
189 changes: 1 addition & 188 deletions src/macros.jl
Original file line number Diff line number Diff line change
@@ -1,123 +1,3 @@
Base.@kwdef mutable struct Settings
scheduler::Expr = :(DynamicScheduler())
reducer::Union{Expr, Symbol, Nothing} = nothing
collect::Bool = false
end

function _sym2scheduler(s)
if s == :dynamic
:(DynamicScheduler())
elseif s == :static
:(StaticScheduler())
elseif s == :greedy
:(GreedyScheduler())
else
throw(ArgumentError("Unknown scheduler symbol."))
end
end

function _maybe_handle_init_block!(args)
inits_before = nothing
init_inner = nothing
tlsidx = findfirst(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@init")
end
if !isnothing(tlsidx)
inits_before, init_inner = _unfold_init_block(args[tlsidx].args[3])
deleteat!(args, tlsidx)
end
return inits_before, init_inner
end

function _unfold_init_block(ex)
inits_before = Expr[]
if ex.head == :(=)
initb, init_inner = _init_assign_to_exprs(ex)
push!(inits_before, initb)
elseif ex.head == :block
tlsexprs = filter(x -> x isa Expr, ex.args) # skip LineNumberNode
init_inner = quote end
for x in tlsexprs
initb, initi = _init_assign_to_exprs(x)
push!(inits_before, initb)
push!(init_inner.args, initi)
end
else
throw(ErrorException("Wrong usage of @init. You must either provide a typed assignment or multiple typed assignments in a `begin ... end` block."))
end
return inits_before, init_inner
end

function _init_assign_to_exprs(ex)
left_ex = ex.args[1]
if left_ex isa Symbol || left_ex.head != :(::)
throw(ErrorException("Wrong usage of @init. Expected typed assignment, e.g. `A::Matrix{Float} = rand(2,2)`."))
end
tls_sym = esc(left_ex.args[1])
tls_type = esc(left_ex.args[2])
tls_def = esc(ex.args[2])
@gensym tls_storage
init_before = :($(tls_storage) = OhMyThreads.TaskLocalValue{$tls_type}(() -> $(tls_def)))
init_inner = :($(tls_sym) = $(tls_storage)[])
return init_before, init_inner
end

function _maybe_handle_set_block!(settings, args)
idcs = findall(args) do arg
arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set")
end
isnothing(idcs) && return # no set block found
for i in idcs
ex = args[i].args[3]
if ex.head == :(=)
_handle_set_single_assign!(settings, ex)
elseif ex.head == :block
exprs = filter(x -> x isa Expr, ex.args) # skip LineNumberNode
_handle_set_single_assign!.(Ref(settings), exprs)
else
throw(ErrorException("Wrong usage of @set. You must either provide an assignment or multiple assignments in a `begin ... end` block."))
end
end
deleteat!(args, idcs)
# check incompatible settings
if settings.collect && !isnothing(settings.reducer)
throw(ArgumentError("Specifying both collect and reducer isn't supported."))
end
end

function _handle_set_single_assign!(settings, ex)
if ex.head != :(=)
throw(ErrorException("Wrong usage of @set. Expected assignment, e.g. `scheduler = StaticScheduler()`."))
end
sym = ex.args[1]
if !hasfield(Settings, sym)
throw(ArgumentError("Unknown setting \"$(sym)\". Must be ∈ $(fieldnames(Settings))."))
end
def = ex.args[2]
if sym == :collect && !(def isa Bool)
throw(ArgumentError("Setting collect can only be true or false."))
#TODO support specifying the OutputElementType
end
def = if def isa QuoteNode
_sym2scheduler(def.value)
elseif def isa Bool
def
else
esc(def)
end
setfield!(settings, sym, def)
end

# function _kwarg_to_tuple(ex)
# ex.head != :(=) &&
# throw(ArgumentError("Invalid keyword argument. Doesn't contain '='."))
# name, val = ex.args
# !(name isa Symbol) &&
# throw(ArgumentError("First part of keyword argument isn't a symbol."))
# val isa QuoteNode && (val = val.value)
# (name, val)
# end

"""
@tasks for ... end
Expand Down Expand Up @@ -170,74 +50,7 @@ end
```
"""
macro tasks(args...)
forex = last(args)
if forex.head != :for || length(args) > 1
throw(ErrorException("Expected a for loop after `@tasks`."))
else
it = forex.args[1]
itvar = it.args[1]
itrng = it.args[2]
forbody = forex.args[2]
end

settings = Settings()

# kwexs = args[begin:(end - 1)]
# for ex in kwexs
# name, val = _kwarg_to_tuple(ex)
# if name == :scheduler
# settings.scheduler = val isa Symbol ? _sym2scheduler(val) : val
# elseif name == :reducer
# settings.reducer = val
# else
# throw(ArgumentError("Unknown keyword argument: $name"))
# end
# end

inits_before, init_inner = _maybe_handle_init_block!(forbody.args)
_maybe_handle_set_block!(settings, forbody.args)

forbody = esc(forbody)
itrng = esc(itrng)
itvar = esc(itvar)

# @show settings
q = if !isnothing(settings.reducer)
quote
tmapreduce(
$(settings.reducer), $(itrng); scheduler = $(settings.scheduler)) do $(itvar)
$(init_inner)
$(forbody)
end
end
elseif settings.collect
quote
tmap($(itrng); scheduler = $(settings.scheduler)) do $(itvar)
$(init_inner)
$(forbody)
end
end
else
quote
tforeach($(itrng); scheduler = $(settings.scheduler)) do $(itvar)
$(init_inner)
$(forbody)
end
end
end

# wrap everything in a let ... end block
# and, potentially, define the `TaskLocalValue`s.
result = :(let
end)
push!(result.args[2].args, q)
if !isnothing(inits_before)
for x in inits_before
push!(result.args[1].args, x)
end
end

result
Implementation.tasks_macro(args...)
end

"""
Expand Down

0 comments on commit 2d9acb4

Please sign in to comment.