From df70b69a8f8552cc1bebed770187099a27c3bfdf Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 1 Mar 2024 12:07:59 +0100 Subject: [PATCH 1/5] move macro body to the internals module --- src/implementation.jl | 2 + src/macro_impl.jl | 191 ++++++++++++++++++++++++++++++++++++++++++ src/macros.jl | 189 +---------------------------------------- 3 files changed, 194 insertions(+), 188 deletions(-) create mode 100644 src/macro_impl.jl diff --git a/src/implementation.jl b/src/implementation.jl index 83740a30..ef6d2736 100644 --- a/src/implementation.jl +++ b/src/implementation.jl @@ -13,6 +13,8 @@ using BangBang: append!! using ChunkSplitters: ChunkSplitters +include("macro_impl.jl") + function auto_disable_chunking_warning() @warn("You passed in a `ChunkSplitters.Chunk` but also a scheduler that has "* "chunking enabled. Will turn off internal chunking to proceed.\n"* diff --git a/src/macro_impl.jl b/src/macro_impl.jl new file mode 100644 index 00000000..0c114d8f --- /dev/null +++ b/src/macro_impl.jl @@ -0,0 +1,191 @@ +Base.@kwdef mutable struct Settings + scheduler::Expr = :(DynamicScheduler()) + reducer::Union{Expr, Symbol, Nothing} = nothing + collect::Bool = false +end + +function tasks_macro(args...) + length(args) + forex = last(args) + if forex.head != :for || length(args) > 1 + throw(ErrorException("Expected a for loop after `@tasks`.")) + else + it = forex.args[1] + itvar = it.args[1] + itrng = it.args[2] + forbody = forex.args[2] + end + + settings = Settings() + + # kwexs = args[begin:(end - 1)] + # for ex in kwexs + # name, val = _kwarg_to_tuple(ex) + # if name == :scheduler + # settings.scheduler = val isa Symbol ? _sym2scheduler(val) : val + # elseif name == :reducer + # settings.reducer = val + # else + # throw(ArgumentError("Unknown keyword argument: $name")) + # end + # end + + inits_before, init_inner = _maybe_handle_init_block!(forbody.args) + _maybe_handle_set_block!(settings, forbody.args) + + forbody = esc(forbody) + itrng = esc(itrng) + itvar = esc(itvar) + + # @show settings + q = if !isnothing(settings.reducer) + quote + tmapreduce( + $(settings.reducer), $(itrng); scheduler = $(settings.scheduler)) do $(itvar) + $(init_inner) + $(forbody) + end + end + elseif settings.collect + quote + tmap($(itrng); scheduler = $(settings.scheduler)) do $(itvar) + $(init_inner) + $(forbody) + end + end + else + quote + tforeach($(itrng); scheduler = $(settings.scheduler)) do $(itvar) + $(init_inner) + $(forbody) + end + end + end + + # wrap everything in a let ... end block + # and, potentially, define the `TaskLocalValue`s. + result = :(let + end) + push!(result.args[2].args, q) + if !isnothing(inits_before) + for x in inits_before + push!(result.args[1].args, x) + end + end + + result +end + +function _sym2scheduler(s) + if s == :dynamic + :(DynamicScheduler()) + elseif s == :static + :(StaticScheduler()) + elseif s == :greedy + :(GreedyScheduler()) + else + throw(ArgumentError("Unknown scheduler symbol.")) + end +end + +function _maybe_handle_init_block!(args) + inits_before = nothing + init_inner = nothing + tlsidx = findfirst(args) do arg + arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@init") + end + if !isnothing(tlsidx) + inits_before, init_inner = _unfold_init_block(args[tlsidx].args[3]) + deleteat!(args, tlsidx) + end + return inits_before, init_inner +end + +function _unfold_init_block(ex) + inits_before = Expr[] + if ex.head == :(=) + initb, init_inner = _init_assign_to_exprs(ex) + push!(inits_before, initb) + elseif ex.head == :block + tlsexprs = filter(x -> x isa Expr, ex.args) # skip LineNumberNode + init_inner = quote end + for x in tlsexprs + initb, initi = _init_assign_to_exprs(x) + push!(inits_before, initb) + push!(init_inner.args, initi) + end + else + throw(ErrorException("Wrong usage of @init. You must either provide a typed assignment or multiple typed assignments in a `begin ... end` block.")) + end + return inits_before, init_inner +end + +function _init_assign_to_exprs(ex) + left_ex = ex.args[1] + if left_ex isa Symbol || left_ex.head != :(::) + throw(ErrorException("Wrong usage of @init. Expected typed assignment, e.g. `A::Matrix{Float} = rand(2,2)`.")) + end + tls_sym = esc(left_ex.args[1]) + tls_type = esc(left_ex.args[2]) + tls_def = esc(ex.args[2]) + @gensym tls_storage + init_before = :($(tls_storage) = OhMyThreads.TaskLocalValue{$tls_type}(() -> $(tls_def))) + init_inner = :($(tls_sym) = $(tls_storage)[]) + return init_before, init_inner +end + +function _maybe_handle_set_block!(settings, args) + idcs = findall(args) do arg + arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set") + end + isnothing(idcs) && return # no set block found + for i in idcs + ex = args[i].args[3] + if ex.head == :(=) + _handle_set_single_assign!(settings, ex) + elseif ex.head == :block + exprs = filter(x -> x isa Expr, ex.args) # skip LineNumberNode + _handle_set_single_assign!.(Ref(settings), exprs) + else + throw(ErrorException("Wrong usage of @set. You must either provide an assignment or multiple assignments in a `begin ... end` block.")) + end + end + deleteat!(args, idcs) + # check incompatible settings + if settings.collect && !isnothing(settings.reducer) + throw(ArgumentError("Specifying both collect and reducer isn't supported.")) + end +end + +function _handle_set_single_assign!(settings, ex) + if ex.head != :(=) + throw(ErrorException("Wrong usage of @set. Expected assignment, e.g. `scheduler = StaticScheduler()`.")) + end + sym = ex.args[1] + if !hasfield(Settings, sym) + throw(ArgumentError("Unknown setting \"$(sym)\". Must be ∈ $(fieldnames(Settings)).")) + end + def = ex.args[2] + if sym == :collect && !(def isa Bool) + throw(ArgumentError("Setting collect can only be true or false.")) + #TODO support specifying the OutputElementType + end + def = if def isa QuoteNode + _sym2scheduler(def.value) + elseif def isa Bool + def + else + esc(def) + end + setfield!(settings, sym, def) +end + +# function _kwarg_to_tuple(ex) +# ex.head != :(=) && +# throw(ArgumentError("Invalid keyword argument. Doesn't contain '='.")) +# name, val = ex.args +# !(name isa Symbol) && +# throw(ArgumentError("First part of keyword argument isn't a symbol.")) +# val isa QuoteNode && (val = val.value) +# (name, val) +# end diff --git a/src/macros.jl b/src/macros.jl index 5ae959b5..ed56b416 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -1,123 +1,3 @@ -Base.@kwdef mutable struct Settings - scheduler::Expr = :(DynamicScheduler()) - reducer::Union{Expr, Symbol, Nothing} = nothing - collect::Bool = false -end - -function _sym2scheduler(s) - if s == :dynamic - :(DynamicScheduler()) - elseif s == :static - :(StaticScheduler()) - elseif s == :greedy - :(GreedyScheduler()) - else - throw(ArgumentError("Unknown scheduler symbol.")) - end -end - -function _maybe_handle_init_block!(args) - inits_before = nothing - init_inner = nothing - tlsidx = findfirst(args) do arg - arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@init") - end - if !isnothing(tlsidx) - inits_before, init_inner = _unfold_init_block(args[tlsidx].args[3]) - deleteat!(args, tlsidx) - end - return inits_before, init_inner -end - -function _unfold_init_block(ex) - inits_before = Expr[] - if ex.head == :(=) - initb, init_inner = _init_assign_to_exprs(ex) - push!(inits_before, initb) - elseif ex.head == :block - tlsexprs = filter(x -> x isa Expr, ex.args) # skip LineNumberNode - init_inner = quote end - for x in tlsexprs - initb, initi = _init_assign_to_exprs(x) - push!(inits_before, initb) - push!(init_inner.args, initi) - end - else - throw(ErrorException("Wrong usage of @init. You must either provide a typed assignment or multiple typed assignments in a `begin ... end` block.")) - end - return inits_before, init_inner -end - -function _init_assign_to_exprs(ex) - left_ex = ex.args[1] - if left_ex isa Symbol || left_ex.head != :(::) - throw(ErrorException("Wrong usage of @init. Expected typed assignment, e.g. `A::Matrix{Float} = rand(2,2)`.")) - end - tls_sym = esc(left_ex.args[1]) - tls_type = esc(left_ex.args[2]) - tls_def = esc(ex.args[2]) - @gensym tls_storage - init_before = :($(tls_storage) = OhMyThreads.TaskLocalValue{$tls_type}(() -> $(tls_def))) - init_inner = :($(tls_sym) = $(tls_storage)[]) - return init_before, init_inner -end - -function _maybe_handle_set_block!(settings, args) - idcs = findall(args) do arg - arg isa Expr && arg.head == :macrocall && arg.args[1] == Symbol("@set") - end - isnothing(idcs) && return # no set block found - for i in idcs - ex = args[i].args[3] - if ex.head == :(=) - _handle_set_single_assign!(settings, ex) - elseif ex.head == :block - exprs = filter(x -> x isa Expr, ex.args) # skip LineNumberNode - _handle_set_single_assign!.(Ref(settings), exprs) - else - throw(ErrorException("Wrong usage of @set. You must either provide an assignment or multiple assignments in a `begin ... end` block.")) - end - end - deleteat!(args, idcs) - # check incompatible settings - if settings.collect && !isnothing(settings.reducer) - throw(ArgumentError("Specifying both collect and reducer isn't supported.")) - end -end - -function _handle_set_single_assign!(settings, ex) - if ex.head != :(=) - throw(ErrorException("Wrong usage of @set. Expected assignment, e.g. `scheduler = StaticScheduler()`.")) - end - sym = ex.args[1] - if !hasfield(Settings, sym) - throw(ArgumentError("Unknown setting \"$(sym)\". Must be ∈ $(fieldnames(Settings)).")) - end - def = ex.args[2] - if sym == :collect && !(def isa Bool) - throw(ArgumentError("Setting collect can only be true or false.")) - #TODO support specifying the OutputElementType - end - def = if def isa QuoteNode - _sym2scheduler(def.value) - elseif def isa Bool - def - else - esc(def) - end - setfield!(settings, sym, def) -end - -# function _kwarg_to_tuple(ex) -# ex.head != :(=) && -# throw(ArgumentError("Invalid keyword argument. Doesn't contain '='.")) -# name, val = ex.args -# !(name isa Symbol) && -# throw(ArgumentError("First part of keyword argument isn't a symbol.")) -# val isa QuoteNode && (val = val.value) -# (name, val) -# end - """ @tasks for ... end @@ -170,74 +50,7 @@ end ``` """ macro tasks(args...) - forex = last(args) - if forex.head != :for || length(args) > 1 - throw(ErrorException("Expected a for loop after `@tasks`.")) - else - it = forex.args[1] - itvar = it.args[1] - itrng = it.args[2] - forbody = forex.args[2] - end - - settings = Settings() - - # kwexs = args[begin:(end - 1)] - # for ex in kwexs - # name, val = _kwarg_to_tuple(ex) - # if name == :scheduler - # settings.scheduler = val isa Symbol ? _sym2scheduler(val) : val - # elseif name == :reducer - # settings.reducer = val - # else - # throw(ArgumentError("Unknown keyword argument: $name")) - # end - # end - - inits_before, init_inner = _maybe_handle_init_block!(forbody.args) - _maybe_handle_set_block!(settings, forbody.args) - - forbody = esc(forbody) - itrng = esc(itrng) - itvar = esc(itvar) - - # @show settings - q = if !isnothing(settings.reducer) - quote - tmapreduce( - $(settings.reducer), $(itrng); scheduler = $(settings.scheduler)) do $(itvar) - $(init_inner) - $(forbody) - end - end - elseif settings.collect - quote - tmap($(itrng); scheduler = $(settings.scheduler)) do $(itvar) - $(init_inner) - $(forbody) - end - end - else - quote - tforeach($(itrng); scheduler = $(settings.scheduler)) do $(itvar) - $(init_inner) - $(forbody) - end - end - end - - # wrap everything in a let ... end block - # and, potentially, define the `TaskLocalValue`s. - result = :(let - end) - push!(result.args[2].args, q) - if !isnothing(inits_before) - for x in inits_before - push!(result.args[1].args, x) - end - end - - result + Implementation.tasks_macro(args...) end """ From 7363c09229a7b00ded37a319184e7c1ff62e704c Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 1 Mar 2024 12:12:14 +0100 Subject: [PATCH 2/5] only accept one argument, error if the user gives a complex loop assignemnt --- src/macro_impl.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/macro_impl.jl b/src/macro_impl.jl index 0c114d8f..248d27d3 100644 --- a/src/macro_impl.jl +++ b/src/macro_impl.jl @@ -4,12 +4,13 @@ Base.@kwdef mutable struct Settings collect::Bool = false end -function tasks_macro(args...) - length(args) - forex = last(args) - if forex.head != :for || length(args) > 1 +function tasks_macro(forex) + if forex.head != :for throw(ErrorException("Expected a for loop after `@tasks`.")) else + if forex.args[1].head != :(=) + throw(ErrorException("`@tasks` currently only supports a single threaded loop, got $(forex.args[1])")) + end it = forex.args[1] itvar = it.args[1] itrng = it.args[2] From 1965cfbcb9182d781fffb80fccb061c0ba4dc8a1 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 1 Mar 2024 12:16:29 +0100 Subject: [PATCH 3/5] remove double quoting, reformat comment --- src/macro_impl.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/macro_impl.jl b/src/macro_impl.jl index 248d27d3..8577cb68 100644 --- a/src/macro_impl.jl +++ b/src/macro_impl.jl @@ -9,6 +9,10 @@ function tasks_macro(forex) throw(ErrorException("Expected a for loop after `@tasks`.")) else if forex.args[1].head != :(=) + # this'll catch cases like + # @tasks for _ ∈ 1:10, _ ∈ 1:10 + # body + # end throw(ErrorException("`@tasks` currently only supports a single threaded loop, got $(forex.args[1])")) end it = forex.args[1] @@ -130,7 +134,7 @@ function _init_assign_to_exprs(ex) tls_type = esc(left_ex.args[2]) tls_def = esc(ex.args[2]) @gensym tls_storage - init_before = :($(tls_storage) = OhMyThreads.TaskLocalValue{$tls_type}(() -> $(tls_def))) + init_before = :($(tls_storage) = TaskLocalValue{$tls_type}(() -> $(tls_def))) init_inner = :($(tls_sym) = $(tls_storage)[]) return init_before, init_inner end From 5ab323e4a878edf3764d11ea646112bd7df28e37 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 1 Mar 2024 12:17:01 +0100 Subject: [PATCH 4/5] remove outdated code --- src/macro_impl.jl | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/src/macro_impl.jl b/src/macro_impl.jl index 8577cb68..d3f20e82 100644 --- a/src/macro_impl.jl +++ b/src/macro_impl.jl @@ -23,18 +23,6 @@ function tasks_macro(forex) settings = Settings() - # kwexs = args[begin:(end - 1)] - # for ex in kwexs - # name, val = _kwarg_to_tuple(ex) - # if name == :scheduler - # settings.scheduler = val isa Symbol ? _sym2scheduler(val) : val - # elseif name == :reducer - # settings.reducer = val - # else - # throw(ArgumentError("Unknown keyword argument: $name")) - # end - # end - inits_before, init_inner = _maybe_handle_init_block!(forbody.args) _maybe_handle_set_block!(settings, forbody.args) @@ -184,13 +172,3 @@ function _handle_set_single_assign!(settings, ex) end setfield!(settings, sym, def) end - -# function _kwarg_to_tuple(ex) -# ex.head != :(=) && -# throw(ArgumentError("Invalid keyword argument. Doesn't contain '='.")) -# name, val = ex.args -# !(name isa Symbol) && -# throw(ArgumentError("First part of keyword argument isn't a symbol.")) -# val isa QuoteNode && (val = val.value) -# (name, val) -# end From 113cebcee1a0cdb1f2a5c0fa75e435d0e49b1bf0 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 1 Mar 2024 12:17:39 +0100 Subject: [PATCH 5/5] re-org --- src/macro_impl.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/macro_impl.jl b/src/macro_impl.jl index d3f20e82..995d24be 100644 --- a/src/macro_impl.jl +++ b/src/macro_impl.jl @@ -1,9 +1,3 @@ -Base.@kwdef mutable struct Settings - scheduler::Expr = :(DynamicScheduler()) - reducer::Union{Expr, Symbol, Nothing} = nothing - collect::Bool = false -end - function tasks_macro(forex) if forex.head != :for throw(ErrorException("Expected a for loop after `@tasks`.")) @@ -69,6 +63,12 @@ function tasks_macro(forex) result end +Base.@kwdef mutable struct Settings + scheduler::Expr = :(DynamicScheduler()) + reducer::Union{Expr, Symbol, Nothing} = nothing + collect::Bool = false +end + function _sym2scheduler(s) if s == :dynamic :(DynamicScheduler())