From ab1ceaaa87c3183157629f1ebea150f4703c46fb Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Mon, 16 Sep 2019 14:34:20 +0100 Subject: [PATCH 1/9] lambdas --- src/ir/print.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/ir/print.jl b/src/ir/print.jl index a53c7f5..3e08c24 100644 --- a/src/ir/print.jl +++ b/src/ir/print.jl @@ -95,3 +95,11 @@ end printers[:pop_exception] = function (io, ex) print(io, "pop exception $(ex.args[1])") end + +printers[:lambda] = function (io, ex) + io = IOContext(io, :indent=>get(io, :indent, 0)+2) + print(io, "λ: ") + printargs(io, ex.args[2:end]) + println(io) + print(io, ex.args[1]) +end From 3b2e086aad92a622e2fddf86784171635911bacd Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Feb 2020 18:17:15 +0000 Subject: [PATCH 2/9] ANF conversion --- src/IRTools.jl | 3 ++- src/ir/ir.jl | 14 +++++++--- src/passes/cps.jl | 63 ++++++++++++++++++++++++++++++++++++++++++++ src/passes/passes.jl | 5 ++-- 4 files changed, 78 insertions(+), 7 deletions(-) create mode 100644 src/passes/cps.jl diff --git a/src/IRTools.jl b/src/IRTools.jl index 3101eff..bf0a9b6 100644 --- a/src/IRTools.jl +++ b/src/IRTools.jl @@ -22,6 +22,7 @@ module Inner include("reflection/dynamo.jl") include("passes/passes.jl") + include("passes/cps.jl") include("passes/relooper.jl") include("passes/stackifier.jl") @@ -44,7 +45,7 @@ let exports = :[ # Passes/Analysis definitions, usages, dominators, domtree, domorder, domorder!, renumber, merge_returns!, expand!, prune!, ssa!, inlineable!, log!, pis!, func, evalir, - Simple, Loop, Multiple, reloop, stackify, + Simple, Loop, Multiple, reloop, stackify, functional, # Reflection, Dynamo Meta, TypedMeta, meta, typed_meta, dynamo, transform, refresh, recurse!, self, varargs!, slots!, diff --git a/src/ir/ir.jl b/src/ir/ir.jl index a622a05..106ca3d 100644 --- a/src/ir/ir.jl +++ b/src/ir/ir.jl @@ -246,6 +246,10 @@ end BasicBlock(b::Block) = b.ir.blocks[b.id] branches(b::Block) = branches(BasicBlock(b)) + +branches(ir::IR) = length(blocks(ir)) == 1 ? branches(block(ir, 1)) : + error("IR has multiple blocks, so `branches(ir)` is ambiguous.") + arguments(b::Block) = arguments(BasicBlock(b)) arguments(ir::IR) = arguments(block(ir, 1)) @@ -843,6 +847,8 @@ var!(p::Pipe) = NewVariable(p.var += 1) substitute!(p::Pipe, x, y) = (p.map[x] = y; x) substitute(p::Pipe, x::Union{Variable,NewVariable}) = p.map[x] substitute(p::Pipe, x) = get(p.map, x, x) +substitute(p::Pipe, x::Statement) = stmt(x, expr = substitute(p, x.expr)) +substitute(p::Pipe, x::Expr) = Expr(x.head, substitute.((p,), x.args)...) substitute(p::Pipe) = x -> substitute(p, x) function Pipe(ir) @@ -876,7 +882,7 @@ function iterate(p::Pipe, (ks, b, i) = (pipestate(p.from), 1, 1)) end v = ks[b][i] st = p.from[v] - substitute!(p, v, push!(p.to, prewalk(substitute(p), st))) + substitute!(p, v, push!(p.to, substitute(p, st))) ((v, st), (ks, b, i+1)) end @@ -890,7 +896,7 @@ setindex!(p::Pipe, x, v) = p.to[substitute(p, v)] = prewalk(substitute(p), x) function Base.push!(p::Pipe, x) tmp = var!(p) - substitute!(p, tmp, push!(p.to, prewalk(substitute(p), x))) + substitute!(p, tmp, push!(p.to, substitute(p, x))) return tmp end @@ -913,7 +919,7 @@ end function insert!(p::Pipe, v, x; after = false) v′ = substitute(p, v) - x = prewalk(substitute(p), x) + x = substitute(p, x) tmp = var!(p) if islastdef(p.to, v′) # we can make this case efficient by renumbering if after @@ -933,7 +939,7 @@ argument!(p::Pipe, a...; kw...) = substitute!(p, var!(p), argument!(p.to, a...; kw...)) function branch!(ir::Pipe, b, args...; kw...) - args = map(a -> postwalk(substitute(ir), a), args) + args = map(a -> substitute(ir, a), args) branch!(blocks(ir.to)[end], b, args...; kw...) return ir end diff --git a/src/passes/cps.jl b/src/passes/cps.jl new file mode 100644 index 0000000..9a7b6e6 --- /dev/null +++ b/src/passes/cps.jl @@ -0,0 +1,63 @@ +cond(c, t, f) = c ? t() : f() + +allsuccs(ir, (id,chs)) = + union(map(b -> b.id, successors(block(ir, id))), + allsuccs.((ir,), chs)...) + +function captures(ir, (id, chs), cs = Dict()) + bl = block(ir, id) + captures.((ir,), chs, (cs,)) + cs[id] = setdiff(union([cs[i] for (i, _) in chs]..., usages(bl)), definitions(bl)) + return cs +end + +function return_thunk(x) + ir = IR() + return!(ir, xcall(:getindex, argument!(ir), 1)) + Expr(:lambda, ir, x) +end + +function functionalbranches!(bl, pr, labels) + if length(branches(bl)) == 1 + br = branches(bl)[1] + if !isreturn(br) + r = push!(pr, Expr(:call, labels[br.block], br.args...)) + empty!(branches(pr.to)) + return!(pr, r) + end + else + @assert length(branches(bl)) == 2 + f, t = branches(bl) + function brfunc(br) + isreturn(br) && return return_thunk(returnvalue(br)) + @assert isempty(arguments(br)) + labels[br.block] + end + r = push!(pr, xcall(IRTools, :cond, f.condition, brfunc(t), brfunc(f))) + empty!(branches(pr.to)) + return!(pr, r) + end +end + +function _functional(ir, tree, vars = [], cs = captures(ir, tree)) + id = tree[1] + bl = IR(block(ir, id)) + labels = Dict() + labels[id] = self = id == 1 ? arguments(bl)[1] : argument!(bl, at = 1) + pr = Pipe(bl) + for (i, v) in enumerate(vars) + v′ = push!(pr, xcall(:getindex, self, i)) + v isa Integer ? (labels[v] = v′) : substitute!(pr, v, substitute(pr, v′)) + end + for _ in pr end + for t in reverse(tree[2]) + bs = filter(id -> haskey(labels, id), allsuccs(ir, t)) + λ = Expr(:lambda, _functional(ir, t, [bs..., cs[t[1]]...]), + map(b -> labels[b], bs)..., cs[t[1]]...) + labels[t[1]] = push!(pr, λ) + end + functionalbranches!(bl, pr, labels) + return finish(pr) +end + +functional(ir) = _functional(explicitbranch!(ir), domtree(ir)) diff --git a/src/passes/passes.jl b/src/passes/passes.jl index 6b37def..e33a9fd 100644 --- a/src/passes/passes.jl +++ b/src/passes/passes.jl @@ -30,7 +30,6 @@ Base.adjoint(cfg::CFG) = transpose(cfg) function definitions(b::Block) defs = [Variable(i) for i = 1:length(b.ir.defs) if b.ir.defs[i][1] == b.id] - append!(defs, arguments(b)) end function usages(b::Block) @@ -72,7 +71,7 @@ function dominators(cfg; entry = 1) return doms end -function domtree(cfg; entry = 1) +function domtree(cfg::CFG; entry = 1) doms = dominators(cfg, entry = entry) doms = Dict(b => filter(c -> b != c && b in doms[c], 1:length(cfg)) for b in 1:length(cfg)) children(b) = filter(c -> !(c in union(map(c -> doms[c], doms[b])...)), doms[b]) @@ -80,6 +79,8 @@ function domtree(cfg; entry = 1) tree(entry) end +domtree(ir::IR; entry = 1) = domtree(CFG(ir), entry = entry) + function idoms(cfg; entry = 1) ds = zeros(Int, length(cfg)) _idoms((a, bs)) = foreach(((b, cs),) -> (ds[b] = a; _idoms(b=>cs)), bs) From 96feb5d7389618137fe0d1ad01d42b3476f29a92 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Feb 2020 18:21:18 +0000 Subject: [PATCH 3/9] lambda support in dynamo --- src/reflection/dynamo.jl | 50 +++++++++++++++++++++++++++++++++++++++- src/reflection/utils.jl | 12 ++++++++++ test/compiler.jl | 17 ++++++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) diff --git a/src/reflection/dynamo.jl b/src/reflection/dynamo.jl index af1ff84..4c0c123 100644 --- a/src/reflection/dynamo.jl +++ b/src/reflection/dynamo.jl @@ -7,6 +7,20 @@ end struct Self end const self = Self() +# S -> function signature +# I -> lambda index +# T -> environment type +struct Lambda{S,I,T} + data::T +end + +Lambda{D,S}(data...) where {D,S} = + Lambda{D,S,typeof(data)}(data) + +@inline Base.getindex(l::Lambda, i::Integer) = l.data[i] + +Base.show(io::IO, l::Lambda) = print(io, "λ") + function Base.showerror(io::IO, err::CompileError) println(io, "Error compiling @dynamo $(err.transform) on $(err.args):") showerror(io, err.err) @@ -21,6 +35,24 @@ function fallthrough(args...) Expr(:call, [:(args[$i]) for i = 1:length(args)]...)) end +function lambdalift!(ir, S, I = ()) + i = 0 + for (v, st) in ir + isexpr(st.expr, :lambda) || continue + ir[v] = Expr(:call, Lambda{S,(I...,i+=1)}, st.expr.args[2:end]...) + end +end + +function getlambda(ir, I) + isempty(I) && return ir + i = 0 + for (v, st) in ir + isexpr(st.expr, :lambda) || continue + (i += 1) == I[1] && return getlambda(st.expr.args[1], Base.tail(I)) + end + error("Something has gone wrong in IRTools; couldn't find lambda in IR") +end + # Used only for its CodeInfo dummy(args...) = nothing @@ -32,6 +64,7 @@ function dynamo(f, args...) end ir isa Expr && return ir ir == nothing && return fallthrough(args...) + lambdalift!(ir, Tuple{f,args...}, ()) if ir.meta isa Meta m = ir.meta ir = varargs!(m, ir) @@ -46,6 +79,16 @@ function dynamo(f, args...) return update!(m.code, ir) end +function dynamo_lambda(f::Type{<:Lambda{S,I}}, args...) where {S,I} + ir = transform(S.parameters...) + ir = getlambda(ir, I) + lambdalift!(ir, S, I) + closureargs!(ir) + m = @meta dummy(1) + m.code.method_for_inference_limit_heuristics = nothing + return update!(m.code, ir) +end + unesc(x) = prewalk(x -> isexpr(x, :escape) ? x.args[1] : x, x) function lifttype(x) @@ -64,7 +107,12 @@ macro dynamo(ex) f, T = isexpr(name, :(::)) ? (length(name.args) == 1 ? (esc(gensym()), esc(name.args[1])) : esc.(name.args)) : (esc(gensym()), :(Core.Typeof($(esc(name))))) - gendef = :(@generated ($f::$T)($(esc(:args))...) where $(Ts...) = return IRTools.dynamo($f, args...)) + gendef = quote + @generated ($f::$T)($(esc(:args))...) where $(Ts...) = + return IRTools.dynamo($f, args...) + @generated (f::IRTools.Inner.Lambda{<:Tuple{<:$T,Vararg{Any}}})(args...) where $(Ts...) = + return IRTools.Inner.dynamo_lambda(f, args...) + end quote $(isexpr(name, :(::)) || esc(:(function $name end))) function IRTools.transform(::Type{<:$T}, $(esc.(lifttype.(args))...)) where $(Ts...) diff --git a/src/reflection/utils.jl b/src/reflection/utils.jl index bc2cbd1..fc30b71 100644 --- a/src/reflection/utils.jl +++ b/src/reflection/utils.jl @@ -102,6 +102,18 @@ function varargs!(meta, ir::IR, n = 0) return ir end +function closureargs!(ir::IR) + args = arguments(ir)[2:end] + deletearg!(ir, 2:length(arguments(ir))) + argtuple = argument!(ir) + env = Dict() + for (i, a) in reverse(collect(enumerate(args))) + env[a] = pushfirst!(ir, xcall(:getindex, argtuple, i)) + end + prewalk!(x -> get(env, x, x), ir) + return ir +end + # TODO this is hacky and leaves `ir.defs` incorrect function splicearg!(ir::IR) args = arguments(ir) diff --git a/test/compiler.jl b/test/compiler.jl index 73178b0..9952238 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -151,3 +151,20 @@ mul = func(ir) end @test ir_add(5, 2) == 7 + +@dynamo function test_lambda(x) + λ = IR() + self = argument!(λ) + y = argument!(λ) + x = push!(λ, xcall(:getindex, self, 1)) + return!(λ, xcall(:+, x, y)) + ir = IR() + args = argument!(ir) + x = push!(ir, xcall(:getindex, args, 1)) + return!(ir, Expr(:lambda, λ, x)) +end + +let + f = test_lambda(3) + @test f(6) == 9 +end From add949f81dff8d5a229db91f05086d88e85695a7 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Feb 2020 18:50:41 +0000 Subject: [PATCH 4/9] test anf with dynamo --- src/IRTools.jl | 2 +- src/reflection/dynamo.jl | 5 +++-- test/compiler.jl | 12 +++++++++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/IRTools.jl b/src/IRTools.jl index bf0a9b6..b331666 100644 --- a/src/IRTools.jl +++ b/src/IRTools.jl @@ -45,7 +45,7 @@ let exports = :[ # Passes/Analysis definitions, usages, dominators, domtree, domorder, domorder!, renumber, merge_returns!, expand!, prune!, ssa!, inlineable!, log!, pis!, func, evalir, - Simple, Loop, Multiple, reloop, stackify, functional, + Simple, Loop, Multiple, reloop, stackify, functional, cond, # Reflection, Dynamo Meta, TypedMeta, meta, typed_meta, dynamo, transform, refresh, recurse!, self, varargs!, slots!, diff --git a/src/reflection/dynamo.jl b/src/reflection/dynamo.jl index 4c0c123..913b538 100644 --- a/src/reflection/dynamo.jl +++ b/src/reflection/dynamo.jl @@ -128,9 +128,10 @@ macro code_ir(dy, ex) :(transform(typeof($(esc(dy))), meta($typesof($(esc(f)), $(esc.(args)...))))) end -function recurse!(ir) +function recurse!(ir, to = self) for (x, st) in ir isexpr(st.expr, :call) || continue - ir[x] = Expr(:call, self, st.expr.args...) + ir[x] = Expr(:call, to, st.expr.args...) end + return ir end diff --git a/test/compiler.jl b/test/compiler.jl index 9952238..d37b7c9 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,6 +1,6 @@ using IRTools, MacroTools, InteractiveUtils, Test using IRTools: @dynamo, IR, meta, isexpr, xcall, self, insertafter!, recurse!, - argument!, return!, func, var + argument!, return!, func, var, functional @dynamo roundtrip(a...) = IR(a...) @@ -168,3 +168,13 @@ let f = test_lambda(3) @test f(6) == 9 end + +anf(f::Core.IntrinsicFunction, args...) = f(args...) + +@dynamo function anf(args...) + ir = IR(args...) + ir == nothing && return + functional(recurse!(ir, anf)) +end + +@test anf(pow, 2, 3) == 8 From cbe044f99a03e72d3df4c61c1fb6d3c9ef3512b8 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 5 Feb 2020 22:05:19 +0000 Subject: [PATCH 5/9] first pass at CPS --- examples/continuations.jl | 91 +++++++++++++++++++++++++++++++++++++++ src/ir/wrap.jl | 2 +- 2 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 examples/continuations.jl diff --git a/examples/continuations.jl b/examples/continuations.jl new file mode 100644 index 0000000..3c8f767 --- /dev/null +++ b/examples/continuations.jl @@ -0,0 +1,91 @@ +using IRTools.All + +function captures(ir, vs) + us = Set() + for v in vs + isexpr(ir[v].expr) || continue + foreach(x -> x isa Variable && push!(us, x), ir[v].expr.args) + end + return setdiff(us, vs) +end + +rename(env, x) = x +rename(env, x::Variable) = env[x] +rename(env, x::Expr) = Expr(x.head, rename.((env,), x.args)...) +rename(env, x::Statement) = stmt(x, expr = rename(env, x.expr)) + +function continuation(ir, vs, cs, in, ret) + bl = empty(ir) + env = Dict() + rename(x) = Main.rename(env, x) + self = argument!(bl) + env[in] = argument!(bl) + for (i, c) in enumerate(cs) + env[c] = pushfirst!(bl, xcall(:getindex, self, i)) + end + while true + if isempty(vs) + return!(bl, rename(Expr(:call, ret, returnvalue(block(ir, 1))))) + return bl + elseif isexpr(ir[vs[1]].expr, :call) + break + else + v = popfirst!(vs) + env[v] = push!(bl, rename(ir[v])) + end + end + v = popfirst!(vs) + st = ir[v] + cs = [ret, setdiff(captures(ir, vs), [v])...] + next = push!(bl, Expr(:lambda, continuation(ir, vs, cs, v, ret), rename.(cs)...)) + ret = push!(bl, stmt(st, expr = xcall(Main, :cps, next, rename(st.expr).args...))) + return!(bl, ret) +end + +function cpstransform(ir) + ir = functional(ir) + k = argument!(ir, at = 1) + bl = empty(ir) + env = Dict() + for arg in arguments(ir) + env[arg] = argument!(bl) + end + cs = arguments(ir) + cont = push!(bl, Expr(:lambda, continuation(ir, keys(ir), cs, nothing, k), rename.((env,), cs)...)) + return!(bl, Expr(:call, cont, nothing)) + return bl +end + +cps(k, f::Core.IntrinsicFunction, args...) = k(f(args...)) +cps(k, ::typeof(cond), c, t, f) = c ? cps(k, t) : cps(k, f) +cps(k, ::typeof(cps), args...) = k(cps(args...)) + +@dynamo function cps(k, args...) + ir = IR(args...) + ir == nothing && return :(args[1](args[2](args[3:end]...))) + cpstransform(IR(args...)) +end + +function pow(x, n) + r = 1 + while n > 0 + n -= 1 + r *= x + end + return r +end + +cps(identity, pow, 2, 3) + +# shift/reset + +reset(f) = cps(identity, f) +shift(f) = error("`shift` must be called inside `reset`") +cps(k, ::typeof(shift), f) = f(k) + +macro reset(ex) + :(reset(() -> $(esc(ex)))) +end + +k = @reset shift(identity)^2 +k(4) diff --git a/src/ir/wrap.jl b/src/ir/wrap.jl index cc02103..ba5a1cb 100644 --- a/src/ir/wrap.jl +++ b/src/ir/wrap.jl @@ -160,7 +160,7 @@ slotname(ci, s) = Symbol(:_, s.id) function IR(ci::CodeInfo, nargs::Integer; meta = nothing) bs = blockstarts(ci) - ir = IR([ci.linetable...], meta = meta) + ir = IR(Core.LineInfoNode[ci.linetable...], meta = meta) _rename = Dict() rename(ex) = prewalk(ex) do x haskey(_rename, x) && return _rename[x] From 721c28320151eed533fbfc3b33a8f4b78d632a8e Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 6 Feb 2020 01:08:29 +0000 Subject: [PATCH 6/9] ir caching --- src/reflection/dynamo.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/reflection/dynamo.jl b/src/reflection/dynamo.jl index 913b538..e18a645 100644 --- a/src/reflection/dynamo.jl +++ b/src/reflection/dynamo.jl @@ -41,6 +41,7 @@ function lambdalift!(ir, S, I = ()) isexpr(st.expr, :lambda) || continue ir[v] = Expr(:call, Lambda{S,(I...,i+=1)}, st.expr.args[2:end]...) end + return ir end function getlambda(ir, I) @@ -56,7 +57,7 @@ end # Used only for its CodeInfo dummy(args...) = nothing -function dynamo(f, args...) +function dynamo(cache, f, args...) try ir = transform(f, args...)::Union{IR,Expr,Nothing} catch e @@ -64,7 +65,8 @@ function dynamo(f, args...) end ir isa Expr && return ir ir == nothing && return fallthrough(args...) - lambdalift!(ir, Tuple{f,args...}, ()) + cache[args] = ir + ir = lambdalift!(copy(ir), Tuple{f,args...}) if ir.meta isa Meta m = ir.meta ir = varargs!(m, ir) @@ -79,10 +81,10 @@ function dynamo(f, args...) return update!(m.code, ir) end -function dynamo_lambda(f::Type{<:Lambda{S,I}}, args...) where {S,I} - ir = transform(S.parameters...) +function dynamo_lambda(cache, f::Type{<:Lambda{S,I}}) where {S,I} + ir = cache[(S.parameters[2:end]...,)] ir = getlambda(ir, I) - lambdalift!(ir, S, I) + ir = lambdalift!(copy(ir), S, I) closureargs!(ir) m = @meta dummy(1) m.code.method_for_inference_limit_heuristics = nothing @@ -108,10 +110,11 @@ macro dynamo(ex) (length(name.args) == 1 ? (esc(gensym()), esc(name.args[1])) : esc.(name.args)) : (esc(gensym()), :(Core.Typeof($(esc(name))))) gendef = quote + local cache = Dict() @generated ($f::$T)($(esc(:args))...) where $(Ts...) = - return IRTools.dynamo($f, args...) + return IRTools.dynamo(cache, $f, args...) @generated (f::IRTools.Inner.Lambda{<:Tuple{<:$T,Vararg{Any}}})(args...) where $(Ts...) = - return IRTools.Inner.dynamo_lambda(f, args...) + return IRTools.Inner.dynamo_lambda(cache, f) end quote $(isexpr(name, :(::)) || esc(:(function $name end))) From 8b14b0efecf6d7acefa2bb6ded1b06c73dde1896 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 6 Feb 2020 12:03:06 +0000 Subject: [PATCH 7/9] continuations example --- examples/amb.jl | 52 ++++++++++++++++++++++++++++++ examples/continuations.jl | 67 +++++++++++++++++++++++---------------- src/IRTools.jl | 4 +-- src/ir/ir.jl | 2 +- src/passes/cps.jl | 11 +++++-- 5 files changed, 104 insertions(+), 32 deletions(-) create mode 100644 examples/amb.jl diff --git a/examples/amb.jl b/examples/amb.jl new file mode 100644 index 0000000..5facd38 --- /dev/null +++ b/examples/amb.jl @@ -0,0 +1,52 @@ +# Implementation of the amb operator using shift/reset +# http://community.schemewiki.org/?amb + +include("continuations.jl"); + +struct Backtrack end + +function require(x) + x || throw(Backtrack()) + return +end + +unwrap(e) = e +unwrap(e::CapturedException) = e.ex + +function amb(iter) + shift() do k + for x in iter + try + return k(x) + catch e + unwrap(e) isa Backtrack || rethrow() + end + end + throw(Backtrack()) + end +end + +function ambrun(f) + try + @reset f() + catch e + e isa Backtrack || rethrow() + error("No possible combination found.") + end +end + +ambrun() do + x = amb([1, 2, 3]) + y = amb([1, 2, 3]) + require(x^2 + y == 7) + (x, y) +end + +ambrun() do + N = 20 + i = amb(1:N) + j = amb(i:N) + k = amb(j:N) + require(i*i + j*j == k*k) + (i, j, k) +end diff --git a/examples/continuations.jl b/examples/continuations.jl index 3c8f767..b4788d7 100644 --- a/examples/continuations.jl +++ b/examples/continuations.jl @@ -1,5 +1,18 @@ +# An implementation of delimited continuations (the shift/reset operators) in +# Julia. Works by transforming all Julia code to continuation passing style. +# The `shift` operator then just has to return the continuation. + +# https://en.wikipedia.org/wiki/Delimited_continuation +# https://en.wikipedia.org/wiki/Continuation-passing_style + using IRTools.All +struct Func + f # Avoid over-specialising on the continuation object. +end + +(f::Func)(args...) = f.f(args...) + function captures(ir, vs) us = Set() for v in vs @@ -14,6 +27,8 @@ rename(env, x::Variable) = env[x] rename(env, x::Expr) = Expr(x.head, rename.((env,), x.args)...) rename(env, x::Statement) = stmt(x, expr = rename(env, x.expr)) +excluded = [GlobalRef(Base, :getindex)] + function continuation(ir, vs, cs, in, ret) bl = empty(ir) env = Dict() @@ -23,28 +38,28 @@ function continuation(ir, vs, cs, in, ret) for (i, c) in enumerate(cs) env[c] = pushfirst!(bl, xcall(:getindex, self, i)) end + local v, st while true - if isempty(vs) - return!(bl, rename(Expr(:call, ret, returnvalue(block(ir, 1))))) - return bl - elseif isexpr(ir[vs[1]].expr, :call) - break - else - v = popfirst!(vs) - env[v] = push!(bl, rename(ir[v])) - end + isempty(vs) && return return!(bl, rename(Expr(:call, ret, returnvalue(block(ir, 1))))) + v = popfirst!(vs) + st = ir[v] + isexpr(st.expr, :call) && !(st.expr.args[1] ∈ excluded) && break + isexpr(st.expr, :lambda) && + (st = stmt(st, expr = Expr(:lambda, cpslambda(st.expr.args[1]), st.expr.args[2:end]...))) + env[v] = push!(bl, rename(st)) end - v = popfirst!(vs) - st = ir[v] cs = [ret, setdiff(captures(ir, vs), [v])...] next = push!(bl, Expr(:lambda, continuation(ir, vs, cs, v, ret), rename.(cs)...)) + next = xcall(Main, :Func, next) ret = push!(bl, stmt(st, expr = xcall(Main, :cps, next, rename(st.expr).args...))) return!(bl, ret) end -function cpstransform(ir) - ir = functional(ir) - k = argument!(ir, at = 1) +cpslambda(ir) = cpstransform(ir, true) + +function cpstransform(ir, lambda = false) + lambda || (ir = functional(ir)) + k = argument!(ir, at = lambda ? 2 : 1) bl = empty(ir) env = Dict() for arg in arguments(ir) @@ -57,26 +72,21 @@ function cpstransform(ir) end cps(k, f::Core.IntrinsicFunction, args...) = k(f(args...)) +cps(k, f::IRTools.Lambda{<:Tuple{typeof(cps),Vararg{Any}}}, args...) = f(k, args...) cps(k, ::typeof(cond), c, t, f) = c ? cps(k, t) : cps(k, f) cps(k, ::typeof(cps), args...) = k(cps(args...)) +# Speed up compilation +for f in [Broadcast.broadcasted, Broadcast.materialize] + @eval cps(k, ::typeof($f), args...) = k($f(args...)) +end + @dynamo function cps(k, args...) ir = IR(args...) ir == nothing && return :(args[1](args[2](args[3:end]...))) cpstransform(IR(args...)) end -function pow(x, n) - r = 1 - while n > 0 - n -= 1 - r *= x - end - return r -end - -cps(identity, pow, 2, 3) - # shift/reset reset(f) = cps(identity, f) @@ -87,5 +97,8 @@ macro reset(ex) :(reset(() -> $(esc(ex)))) end -k = @reset shift(identity)^2 -k(4) +k = @reset begin + shift(k -> k)^2 +end + +k(4) == 16 diff --git a/src/IRTools.jl b/src/IRTools.jl index b331666..693f1c7 100644 --- a/src/IRTools.jl +++ b/src/IRTools.jl @@ -37,7 +37,7 @@ end let exports = :[ # IR - IR, Block, BasicBlock, Variable, Statement, Branch, Pipe, CFG, Slot, branch, var, stmt, arguments, argtypes, + IRTools, IR, Block, BasicBlock, Variable, Statement, Branch, Pipe, CFG, Slot, branch, var, stmt, arguments, argtypes, branches, undef, unreachable, isreturn, isconditional, block!, deleteblock!, branch!, argument!, return!, canbranch, returnvalue, returntype, emptyargs!, deletearg!, block, blocks, successors, predecessors, xcall, exprtype, exprline, isexpr, insertafter!, explicitbranch!, prewalk, postwalk, @@ -47,7 +47,7 @@ let exports = :[ merge_returns!, expand!, prune!, ssa!, inlineable!, log!, pis!, func, evalir, Simple, Loop, Multiple, reloop, stackify, functional, cond, # Reflection, Dynamo - Meta, TypedMeta, meta, typed_meta, dynamo, transform, refresh, recurse!, self, + Meta, TypedMeta, Lambda, meta, typed_meta, dynamo, transform, refresh, recurse!, self, varargs!, slots!, ].args append!(exports, Symbol.(["@code_ir", "@dynamo", "@meta", "@typed_meta"])) diff --git a/src/ir/ir.jl b/src/ir/ir.jl index 106ca3d..a86ba2f 100644 --- a/src/ir/ir.jl +++ b/src/ir/ir.jl @@ -902,7 +902,7 @@ end function Base.pushfirst!(p::Pipe, x) tmp = var!(p) - substitute!(p, tmp, pushfirst!(p.to, prewalk(substitute(p), x))) + substitute!(p, tmp, pushfirst!(p.to, substitute(p, x))) return tmp end diff --git a/src/passes/cps.jl b/src/passes/cps.jl index 9a7b6e6..2a1c1d7 100644 --- a/src/passes/cps.jl +++ b/src/passes/cps.jl @@ -17,6 +17,13 @@ function return_thunk(x) Expr(:lambda, ir, x) end +function br_thunk(args...) + ir = IR() + self = argument!(ir) + return!(ir, xcall([xcall(:getindex, self, i) for i = 1:length(args)]...)) + Expr(:lambda, ir, args...) +end + function functionalbranches!(bl, pr, labels) if length(branches(bl)) == 1 br = branches(bl)[1] @@ -30,8 +37,8 @@ function functionalbranches!(bl, pr, labels) f, t = branches(bl) function brfunc(br) isreturn(br) && return return_thunk(returnvalue(br)) - @assert isempty(arguments(br)) - labels[br.block] + isempty(arguments(br)) && return labels[br.block] + br_thunk(labels[br.block], arguments(br)...) end r = push!(pr, xcall(IRTools, :cond, f.condition, brfunc(t), brfunc(f))) empty!(branches(pr.to)) From 5c36ce037b1e149cdd9cf13ec7ba352c9f607b03 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 6 Feb 2020 12:43:29 +0000 Subject: [PATCH 8/9] nicer lambda printing --- src/ir/print.jl | 37 ++++++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/src/ir/print.jl b/src/ir/print.jl index 3e08c24..386a673 100644 --- a/src/ir/print.jl +++ b/src/ir/print.jl @@ -1,7 +1,10 @@ import Base: show # TODO: real expression printing -Base.show(io::IO, x::Variable) = print(io, "%", x.id) +function Base.show(io::IO, x::Variable) + bs = get(io, :bindings, Dict()) + haskey(bs, x) ? print(io, bs[x]) : print(io, "%", x.id) +end const printers = Dict{Symbol,Any}() @@ -13,12 +16,12 @@ function show(io::IO, b::Branch) if b == unreachable print(io, "unreachable") elseif isreturn(b) - print(io, "return $(repr(b.args[1]))") + print(io, "return ", b.args[1]) else print(io, "br $(b.block)") if !isempty(b.args) print(io, " (") - join(io, repr.(b.args), ", ") + join(io, b.args, ", ") print(io, ")") end b.condition != nothing && print(io, " unless $(b.condition)") @@ -39,6 +42,7 @@ end function show(io::IO, b::Block) indent = get(io, :indent, 0) + bs = get(io, :bindings, Dict()) bb = BasicBlock(b) print(io, tab^indent) print(io, b.id, ":") @@ -47,6 +51,7 @@ function show(io::IO, b::Block) printargs(io, bb.args, bb.argtypes) end for (x, st) in b + haskey(bs, x) && continue println(io) print(io, tab^indent, " ") x == nothing || print(io, string("%", x.id), " = ") @@ -87,7 +92,7 @@ printers[:catch] = function (io, ex) args = ex.args[2:end] if !isempty(args) print(io, " (") - join(io, repr.(args), ", ") + join(io, args, ", ") print(io, ")") end end @@ -96,10 +101,28 @@ printers[:pop_exception] = function (io, ex) print(io, "pop exception $(ex.args[1])") end +function lambdacx(io, ex) + bs = get(io, :bindings, Dict()) + ir = ex.args[1] + args = ex.args[2:end] + bs′ = Dict() + for (v, st) in ir + ex = st.expr + if iscall(ex, GlobalRef(Base, :getindex)) && + ex.args[2] == arguments(ir)[1] && + ex.args[3] isa Integer + x = args[ex.args[3]] + bs′[v] = string(get(bs, x, x), "'") + end + end + return bs′ +end + printers[:lambda] = function (io, ex) - io = IOContext(io, :indent=>get(io, :indent, 0)+2) - print(io, "λ: ") - printargs(io, ex.args[2:end]) + print(io, "λ :") + # printargs(io, ex.args[2:end]) + io = IOContext(io, :indent => get(io, :indent, 0)+2, + :bindings => lambdacx(io, ex)) println(io) print(io, ex.args[1]) end From ee046a7f57256f15ba442bb476d039a43944af58 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 6 Feb 2020 13:28:22 +0000 Subject: [PATCH 9/9] better cps output --- examples/continuations.jl | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/examples/continuations.jl b/examples/continuations.jl index b4788d7..9dd762b 100644 --- a/examples/continuations.jl +++ b/examples/continuations.jl @@ -29,15 +29,8 @@ rename(env, x::Statement) = stmt(x, expr = rename(env, x.expr)) excluded = [GlobalRef(Base, :getindex)] -function continuation(ir, vs, cs, in, ret) - bl = empty(ir) - env = Dict() +function continuation!(bl, ir, env, vs, ret) rename(x) = Main.rename(env, x) - self = argument!(bl) - env[in] = argument!(bl) - for (i, c) in enumerate(cs) - env[c] = pushfirst!(bl, xcall(:getindex, self, i)) - end local v, st while true isempty(vs) && return return!(bl, rename(Expr(:call, ret, returnvalue(block(ir, 1))))) @@ -49,12 +42,27 @@ function continuation(ir, vs, cs, in, ret) env[v] = push!(bl, rename(st)) end cs = [ret, setdiff(captures(ir, vs), [v])...] - next = push!(bl, Expr(:lambda, continuation(ir, vs, cs, v, ret), rename.(cs)...)) - next = xcall(Main, :Func, next) + if isempty(vs) + next = rename(ret) + else + next = push!(bl, Expr(:lambda, continuation(ir, vs, cs, v, ret), rename.(cs)...)) + next = xcall(Main, :Func, next) + end ret = push!(bl, stmt(st, expr = xcall(Main, :cps, next, rename(st.expr).args...))) return!(bl, ret) end +function continuation(ir, vs, cs, in, ret) + bl = empty(ir) + env = Dict() + self = argument!(bl) + env[in] = argument!(bl) + for (i, c) in enumerate(cs) + env[c] = pushfirst!(bl, xcall(:getindex, self, i)) + end + continuation!(bl, ir, env, vs, ret) +end + cpslambda(ir) = cpstransform(ir, true) function cpstransform(ir, lambda = false) @@ -65,10 +73,7 @@ function cpstransform(ir, lambda = false) for arg in arguments(ir) env[arg] = argument!(bl) end - cs = arguments(ir) - cont = push!(bl, Expr(:lambda, continuation(ir, keys(ir), cs, nothing, k), rename.((env,), cs)...)) - return!(bl, Expr(:call, cont, nothing)) - return bl + continuation!(bl, ir, env, keys(ir), k) end cps(k, f::Core.IntrinsicFunction, args...) = k(f(args...))