Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lambdas #27

Merged
merged 9 commits into from
Feb 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions examples/amb.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Implementation of the amb operator using shift/reset
# http://community.schemewiki.org/?amb

include("continuations.jl");

struct Backtrack end

function require(x)
x || throw(Backtrack())
return
end

unwrap(e) = e
unwrap(e::CapturedException) = e.ex

function amb(iter)
shift() do k
for x in iter
try
return k(x)
catch e
unwrap(e) isa Backtrack || rethrow()
end
end
throw(Backtrack())
end
end

function ambrun(f)
try
@reset f()
catch e
e isa Backtrack || rethrow()
error("No possible combination found.")
end
end

ambrun() do
x = amb([1, 2, 3])
y = amb([1, 2, 3])
require(x^2 + y == 7)
(x, y)
end

ambrun() do
N = 20
i = amb(1:N)
j = amb(i:N)
k = amb(j:N)
require(i*i + j*j == k*k)
(i, j, k)
end
109 changes: 109 additions & 0 deletions examples/continuations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# An implementation of delimited continuations (the shift/reset operators) in
# Julia. Works by transforming all Julia code to continuation passing style.
# The `shift` operator then just has to return the continuation.

# https://en.wikipedia.org/wiki/Delimited_continuation
# https://en.wikipedia.org/wiki/Continuation-passing_style

using IRTools.All

struct Func
f # Avoid over-specialising on the continuation object.
end

(f::Func)(args...) = f.f(args...)

function captures(ir, vs)
us = Set()
for v in vs
isexpr(ir[v].expr) || continue
foreach(x -> x isa Variable && push!(us, x), ir[v].expr.args)
end
return setdiff(us, vs)
end

rename(env, x) = x
rename(env, x::Variable) = env[x]
rename(env, x::Expr) = Expr(x.head, rename.((env,), x.args)...)
rename(env, x::Statement) = stmt(x, expr = rename(env, x.expr))

excluded = [GlobalRef(Base, :getindex)]

function continuation!(bl, ir, env, vs, ret)
rename(x) = Main.rename(env, x)
local v, st
while true
isempty(vs) && return return!(bl, rename(Expr(:call, ret, returnvalue(block(ir, 1)))))
v = popfirst!(vs)
st = ir[v]
isexpr(st.expr, :call) && !(st.expr.args[1] ∈ excluded) && break
isexpr(st.expr, :lambda) &&
(st = stmt(st, expr = Expr(:lambda, cpslambda(st.expr.args[1]), st.expr.args[2:end]...)))
env[v] = push!(bl, rename(st))
end
cs = [ret, setdiff(captures(ir, vs), [v])...]
if isempty(vs)
next = rename(ret)
else
next = push!(bl, Expr(:lambda, continuation(ir, vs, cs, v, ret), rename.(cs)...))
next = xcall(Main, :Func, next)
end
ret = push!(bl, stmt(st, expr = xcall(Main, :cps, next, rename(st.expr).args...)))
return!(bl, ret)
end

function continuation(ir, vs, cs, in, ret)
bl = empty(ir)
env = Dict()
self = argument!(bl)
env[in] = argument!(bl)
for (i, c) in enumerate(cs)
env[c] = pushfirst!(bl, xcall(:getindex, self, i))
end
continuation!(bl, ir, env, vs, ret)
end

cpslambda(ir) = cpstransform(ir, true)

function cpstransform(ir, lambda = false)
lambda || (ir = functional(ir))
k = argument!(ir, at = lambda ? 2 : 1)
bl = empty(ir)
env = Dict()
for arg in arguments(ir)
env[arg] = argument!(bl)
end
continuation!(bl, ir, env, keys(ir), k)
end

cps(k, f::Core.IntrinsicFunction, args...) = k(f(args...))
cps(k, f::IRTools.Lambda{<:Tuple{typeof(cps),Vararg{Any}}}, args...) = f(k, args...)
cps(k, ::typeof(cond), c, t, f) = c ? cps(k, t) : cps(k, f)
cps(k, ::typeof(cps), args...) = k(cps(args...))

# Speed up compilation
for f in [Broadcast.broadcasted, Broadcast.materialize]
@eval cps(k, ::typeof($f), args...) = k($f(args...))
end

@dynamo function cps(k, args...)
ir = IR(args...)
ir == nothing && return :(args[1](args[2](args[3:end]...)))
cpstransform(IR(args...))
end

# shift/reset

reset(f) = cps(identity, f)
shift(f) = error("`shift` must be called inside `reset`")
cps(k, ::typeof(shift), f) = f(k)

macro reset(ex)
:(reset(() -> $(esc(ex))))
end

k = @reset begin
shift(k -> k)^2
end

k(4) == 16
7 changes: 4 additions & 3 deletions src/IRTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ module Inner
include("reflection/dynamo.jl")

include("passes/passes.jl")
include("passes/cps.jl")
include("passes/relooper.jl")
include("passes/stackifier.jl")

Expand All @@ -36,17 +37,17 @@ end

let exports = :[
# IR
IR, Block, BasicBlock, Variable, Statement, Branch, Pipe, CFG, Slot, branch, var, stmt, arguments, argtypes,
IRTools, IR, Block, BasicBlock, Variable, Statement, Branch, Pipe, CFG, Slot, branch, var, stmt, arguments, argtypes,
branches, undef, unreachable, isreturn, isconditional, block!, deleteblock!, branch!, argument!, return!,
canbranch, returnvalue, returntype, emptyargs!, deletearg!, block, blocks, successors, predecessors,
xcall, exprtype, exprline, isexpr, insertafter!, explicitbranch!, prewalk, postwalk,
prewalk!, postwalk!, finish, substitute!, substitute,
# Passes/Analysis
definitions, usages, dominators, domtree, domorder, domorder!, renumber,
merge_returns!, expand!, prune!, ssa!, inlineable!, log!, pis!, func, evalir,
Simple, Loop, Multiple, reloop, stackify,
Simple, Loop, Multiple, reloop, stackify, functional, cond,
# Reflection, Dynamo
Meta, TypedMeta, meta, typed_meta, dynamo, transform, refresh, recurse!, self,
Meta, TypedMeta, Lambda, meta, typed_meta, dynamo, transform, refresh, recurse!, self,
varargs!, slots!,
].args
append!(exports, Symbol.(["@code_ir", "@dynamo", "@meta", "@typed_meta"]))
Expand Down
16 changes: 11 additions & 5 deletions src/ir/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ end

BasicBlock(b::Block) = b.ir.blocks[b.id]
branches(b::Block) = branches(BasicBlock(b))

branches(ir::IR) = length(blocks(ir)) == 1 ? branches(block(ir, 1)) :
error("IR has multiple blocks, so `branches(ir)` is ambiguous.")

arguments(b::Block) = arguments(BasicBlock(b))
arguments(ir::IR) = arguments(block(ir, 1))

Expand Down Expand Up @@ -843,6 +847,8 @@ var!(p::Pipe) = NewVariable(p.var += 1)
substitute!(p::Pipe, x, y) = (p.map[x] = y; x)
substitute(p::Pipe, x::Union{Variable,NewVariable}) = p.map[x]
substitute(p::Pipe, x) = get(p.map, x, x)
substitute(p::Pipe, x::Statement) = stmt(x, expr = substitute(p, x.expr))
substitute(p::Pipe, x::Expr) = Expr(x.head, substitute.((p,), x.args)...)
substitute(p::Pipe) = x -> substitute(p, x)

function Pipe(ir)
Expand Down Expand Up @@ -876,7 +882,7 @@ function iterate(p::Pipe, (ks, b, i) = (pipestate(p.from), 1, 1))
end
v = ks[b][i]
st = p.from[v]
substitute!(p, v, push!(p.to, prewalk(substitute(p), st)))
substitute!(p, v, push!(p.to, substitute(p, st)))
((v, st), (ks, b, i+1))
end

Expand All @@ -890,13 +896,13 @@ setindex!(p::Pipe, x, v) = p.to[substitute(p, v)] = prewalk(substitute(p), x)

function Base.push!(p::Pipe, x)
tmp = var!(p)
substitute!(p, tmp, push!(p.to, prewalk(substitute(p), x)))
substitute!(p, tmp, push!(p.to, substitute(p, x)))
return tmp
end

function Base.pushfirst!(p::Pipe, x)
tmp = var!(p)
substitute!(p, tmp, pushfirst!(p.to, prewalk(substitute(p), x)))
substitute!(p, tmp, pushfirst!(p.to, substitute(p, x)))
return tmp
end

Expand All @@ -913,7 +919,7 @@ end

function insert!(p::Pipe, v, x; after = false)
v′ = substitute(p, v)
x = prewalk(substitute(p), x)
x = substitute(p, x)
tmp = var!(p)
if islastdef(p.to, v′) # we can make this case efficient by renumbering
if after
Expand All @@ -933,7 +939,7 @@ argument!(p::Pipe, a...; kw...) =
substitute!(p, var!(p), argument!(p.to, a...; kw...))

function branch!(ir::Pipe, b, args...; kw...)
args = map(a -> postwalk(substitute(ir), a), args)
args = map(a -> substitute(ir, a), args)
branch!(blocks(ir.to)[end], b, args...; kw...)
return ir
end
39 changes: 35 additions & 4 deletions src/ir/print.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import Base: show

# TODO: real expression printing
Base.show(io::IO, x::Variable) = print(io, "%", x.id)
function Base.show(io::IO, x::Variable)
bs = get(io, :bindings, Dict())
haskey(bs, x) ? print(io, bs[x]) : print(io, "%", x.id)
end

const printers = Dict{Symbol,Any}()

Expand All @@ -13,12 +16,12 @@ function show(io::IO, b::Branch)
if b == unreachable
print(io, "unreachable")
elseif isreturn(b)
print(io, "return $(repr(b.args[1]))")
print(io, "return ", b.args[1])
else
print(io, "br $(b.block)")
if !isempty(b.args)
print(io, " (")
join(io, repr.(b.args), ", ")
join(io, b.args, ", ")
print(io, ")")
end
b.condition != nothing && print(io, " unless $(b.condition)")
Expand All @@ -39,6 +42,7 @@ end

function show(io::IO, b::Block)
indent = get(io, :indent, 0)
bs = get(io, :bindings, Dict())
bb = BasicBlock(b)
print(io, tab^indent)
print(io, b.id, ":")
Expand All @@ -47,6 +51,7 @@ function show(io::IO, b::Block)
printargs(io, bb.args, bb.argtypes)
end
for (x, st) in b
haskey(bs, x) && continue
println(io)
print(io, tab^indent, " ")
x == nothing || print(io, string("%", x.id), " = ")
Expand Down Expand Up @@ -87,11 +92,37 @@ printers[:catch] = function (io, ex)
args = ex.args[2:end]
if !isempty(args)
print(io, " (")
join(io, repr.(args), ", ")
join(io, args, ", ")
print(io, ")")
end
end

printers[:pop_exception] = function (io, ex)
print(io, "pop exception $(ex.args[1])")
end

function lambdacx(io, ex)
bs = get(io, :bindings, Dict())
ir = ex.args[1]
args = ex.args[2:end]
bs′ = Dict()
for (v, st) in ir
ex = st.expr
if iscall(ex, GlobalRef(Base, :getindex)) &&
ex.args[2] == arguments(ir)[1] &&
ex.args[3] isa Integer
x = args[ex.args[3]]
bs′[v] = string(get(bs, x, x), "'")
end
end
return bs′
end

printers[:lambda] = function (io, ex)
print(io, "λ :")
# printargs(io, ex.args[2:end])
io = IOContext(io, :indent => get(io, :indent, 0)+2,
:bindings => lambdacx(io, ex))
println(io)
print(io, ex.args[1])
end
2 changes: 1 addition & 1 deletion src/ir/wrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ slotname(ci, s) = Symbol(:_, s.id)

function IR(ci::CodeInfo, nargs::Integer; meta = nothing)
bs = blockstarts(ci)
ir = IR([ci.linetable...], meta = meta)
ir = IR(Core.LineInfoNode[ci.linetable...], meta = meta)
_rename = Dict()
rename(ex) = prewalk(ex) do x
haskey(_rename, x) && return _rename[x]
Expand Down
Loading