Demo Plan: wrapping literals #2

Open
oxinabox opened this issue Feb 14, 2019 · 1 comment

@oxinabox
Contributor

Use a pass to find any literals in the IR and replace each one with a call to a function that takes the literal as its input.
That function could be a convert, a constructor, or whatever else is needed.
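
To make the idea concrete without Cassette, here is a rough sketch of the same rewrite applied to a quoted, surface-level Julia expression rather than to lowered IR. The function name wrap_literals_ast is hypothetical. Treating only identifiers, line information, and the callee of a call as "not a literal" is a simplifying assumption.

```julia
# Hypothetical helper: return a copy of `x` in which every literal constant `c`
# becomes the call `wrapper(c)`. It works on parsed (surface) Julia code, not on
# the lowered IR that a Cassette pass sees.
function wrap_literals_ast(wrapper, x)
    if x isa Symbol || x isa LineNumberNode
        # Identifiers and line information are not literals.
        return x
    elseif x isa Expr
        args = Any[]
        for (k, arg) in enumerate(x.args)
            # Assumption: leave the callee of a call (e.g. :+) untouched and
            # recurse into every other sub-expression.
            keep = x.head === :call && k == 1
            push!(args, keep ? arg : wrap_literals_ast(wrapper, arg))
        end
        return Expr(x.head, args...)
    else
        # Numbers, strings, `nothing`, QuoteNodes, ... are wrapped.
        return Expr(:call, wrapper, x)
    end
end
```

For example, wrap_literals_ast(:W, :(y = 2x + 1)) yields the same expression as :(y = W(2) * x + W(1)). The lowered-IR version has to do more bookkeeping because each wrapper call becomes its own numbered statement.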

@staticfloat mentioned this would be useful for wrapping XLALiterals.

@jrevels
Owner

jrevels commented Feb 20, 2019

Here's a naive version of the transform. I haven't tested it much, so there might be bugs, but this should at least help with getting started.

There's probably more stuff you'd want to handle specially, e.g. gotoifnot statements, _apply calls (or maybe the literal tuples there are fine already?), etc.

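```julia
using Cassette
using Core: CodeInfo, SlotNumber, NewvarNode, GotoNode, SSAValue

# Treat anything that is not an expression, a reference (global, SSA value,
# slot), or a control-flow/scoping node as a literal constant.
is_literal(x) = !(isa(x, Expr) ||
                  isa(x, GlobalRef) ||
                  isa(x, SSAValue) ||
                  isa(x, GotoNode) ||
                  isa(x, SlotNumber) ||
                  isa(x, NewvarNode))

# Right-hand side of an assignment statement, or the statement itself.
rhs(stmt) = Base.Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt

function wrap_literals!(wrapper, info::CodeInfo)
    # How many statements should replace statement `i` (`nothing` = leave it):
    #   - a bare literal, or `x = literal`, is rewritten in place: 1
    #   - a call with k literal arguments becomes k wrapper calls plus the call: k + 1
    count_literal_args = (stmt, i) -> begin
        stmt = rhs(stmt)
        if is_literal(stmt)
            return 1
        elseif Base.Meta.isexpr(stmt, :call)
            literal_count = count(is_literal, stmt.args[2:end])
            literal_count > 0 && return literal_count + 1
        end
        return nothing
    end
    # Build the replacement statements for statement `i`. The wrapper calls
    # land at positions i, i+1, ..., so the k-th literal argument of a call is
    # replaced by SSAValue(i + k - 1), and the rewritten statement comes last.
    wrap_literal_args = (old, i) -> begin
        stmt = rhs(old)
        inserted = Any[]
        if is_literal(stmt)
            stmt = Expr(:call, wrapper, stmt)
        else
            # Only calls with at least one literal argument reach this branch.
            offset = 0
            argidx = 2
            for arg in stmt.args[2:end]  # iterates a copy, so mutating stmt.args is safe
                if is_literal(arg)
                    push!(inserted, Expr(:call, wrapper, arg))
                    stmt.args[argidx] = SSAValue(i + offset)
                    offset += 1
                end
                argidx += 1
            end
        end
        # Put the rewritten right-hand side back into the assignment, if any.
        if Base.Meta.isexpr(old, :(=))
            old.args[2] = stmt
        else
            old = stmt
        end
        push!(inserted, old)
        return inserted
    end
    # Splice the new statements into the code; Cassette renumbers the SSA
    # references and jump targets that shift because of the insertions.
    Cassette.insert_statements!(info.code, info.codelocs,
                                count_literal_args,
                                wrap_literal_args)
    return info
end
```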
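
The remark above about handling more node types specially matters more on newer Julia versions. As far as I know, lowered IR there represents conditional jumps and returns with dedicated node types (Core.GotoIfNot, Core.ReturnNode) instead of Expr(:gotoifnot, ...) and Expr(:return, ...). Because the is_literal above only excludes Expr, GlobalRef, SSAValue, GotoNode, SlotNumber, and NewvarNode, it would classify such nodes as literals and wrap them. The following is a hedged sketch of a more defensive predicate. The name is_literal_v2 and the exact list of excluded node types are assumptions and are not part of Cassette. Each name is looked up only if Core defines it, so the same code loads on older and newer versions.

```julia
# Assumption: these Core node types reference other values or encode control
# flow rather than carrying a constant. Which of them exist depends on the
# Julia version, so only the ones Core actually defines are kept.
const NON_LITERAL_NODE_NAMES = (:SSAValue, :SlotNumber, :TypedSlot, :Argument,
                                :GotoNode, :GotoIfNot, :ReturnNode, :NewvarNode,
                                :PhiNode, :PiNode, :UpsilonNode, :PhiCNode)

const NON_LITERAL_NODE_TYPES = Tuple(getfield(Core, n) for n in NON_LITERAL_NODE_NAMES
                                     if isdefined(Core, n))

# Hypothetical variant of is_literal that also excludes the node types above
# and line-number information.
is_literal_v2(x) = !(x isa Expr || x isa GlobalRef || x isa LineNumberNode ||
                     any(T -> x isa T, NON_LITERAL_NODE_TYPES))
```

It could be dropped into wrap_literals! in place of is_literal without changing anything else.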

Example in action:

```julia
julia> function rosenbrock(x::Vector)
           a = 1.0
           b = 100.0
           result = 0.0
           for i in 1:length(x)-1
               result += (a - x[i])^2 + b*(x[i+1] - x[i]^2)^2
           end
           return result
       end
rosenbrock (generic function with 1 method)
```

```julia
julia> info = @code_lowered rosenbrock(rand(2))
CodeInfo(
1 ─       a = 1.0
│         b = 100.0
│         result = 0.0
│   %4  = Main.length(x)
│   %5  = %4 - 1
│   %6  = 1:%5
│         #temp# = Base.iterate(%6)
│   %8  = #temp# === nothing
│   %9  = Base.not_int(%8)
└──       goto #4 if not %9
2 ┄ %11 = #temp#
│         i = Core.getfield(%11, 1)
│   %13 = Core.getfield(%11, 2)
│   %14 = result
│   %15 = a
│   %16 = Base.getindex(x, i)
│   %17 = %15 - %16
│   %18 = Core.apply_type(Base.Val, 2)
│   %19 = (%18)()
│   %20 = Base.literal_pow(Main.:^, %17, %19)
│   %21 = b
│   %22 = i + 1
│   %23 = Base.getindex(x, %22)
│   %24 = Base.getindex(x, i)
│   %25 = Core.apply_type(Base.Val, 2)
│   %26 = (%25)()
│   %27 = Base.literal_pow(Main.:^, %24, %26)
│   %28 = %23 - %27
│   %29 = Core.apply_type(Base.Val, 2)
│   %30 = (%29)()
│   %31 = Base.literal_pow(Main.:^, %28, %30)
│   %32 = %21 * %31
│   %33 = %20 + %32
│         result = %14 + %33
│         #temp# = Base.iterate(%6, %13)
│   %36 = #temp# === nothing
│   %37 = Base.not_int(%36)
└──       goto #4 if not %37
3 ─       goto #2
4 ┄       return result
)
```

```julia
julia> wrap_literals!(:MYWRAPPER, info)
CodeInfo(
1 ─       a = MYWRAPPER(1.0)
│         b = MYWRAPPER(100.0)
│         result = MYWRAPPER(0.0)
│   %4  = Main.length(x)
│   %5  = MYWRAPPER(1)
│   %6  = %4 - %5
│   %7  = MYWRAPPER(1)
│   %8  = %7:%6
│         #temp# = Base.iterate(%8)
│   %10 = MYWRAPPER(nothing)
│   %11 = #temp# === %10
│   %12 = Base.not_int(%11)
└──       goto #4 if not %12
2 ┄ %14 = #temp#
│   %15 = MYWRAPPER(1)
│         i = Core.getfield(%14, %15)
│   %17 = MYWRAPPER(2)
│   %18 = Core.getfield(%14, %17)
│   %19 = result
│   %20 = a
│   %21 = Base.getindex(x, i)
│   %22 = %20 - %21
│   %23 = MYWRAPPER(2)
│   %24 = Core.apply_type(Base.Val, %23)
│   %25 = (%24)()
│   %26 = Base.literal_pow(Main.:^, %22, %25)
│   %27 = b
│   %28 = MYWRAPPER(1)
│   %29 = i + %28
│   %30 = Base.getindex(x, %29)
│   %31 = Base.getindex(x, i)
│   %32 = MYWRAPPER(2)
│   %33 = Core.apply_type(Base.Val, %32)
│   %34 = (%33)()
│   %35 = Base.literal_pow(Main.:^, %31, %34)
│   %36 = %30 - %35
│   %37 = MYWRAPPER(2)
│   %38 = Core.apply_type(Base.Val, %37)
│   %39 = (%38)()
│   %40 = Base.literal_pow(Main.:^, %36, %39)
│   %41 = %27 * %40
│   %42 = %26 + %41
│         result = %19 + %42
│         #temp# = Base.iterate(%8, %18)
│   %45 = MYWRAPPER(nothing)
│   %46 = #temp# === %45
│   %47 = Base.not_int(%46)
└──       goto #4 if not %47
3 ─       goto #2
4 ┄       return result
)
```

EDIT: fix assignment bug
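
In the rewritten IR, the wrapper is applied to every literal. That includes the integers used as getfield indices and Val parameters, and the nothing that the loop's iteration check compares against. A wrapper that changes those values would therefore change what the code does. The first comment suggested a convert-style wrapper. Assuming that approach, one option is to dispatch on the literal's type and pass everything else through unchanged. In the sketch below, the wrapper to_f32 is hypothetical. The function rosenbrock_wrapped is written by hand to show roughly what rosenbrock computes once its literals go through to_f32; it is not output produced by the pass.

```julia
# Hypothetical convert-style wrapper: narrow Float64 literals to Float32 and
# leave every other literal (Ints used as indices or Val parameters, `nothing`
# in the iteration check, ...) unchanged.
to_f32(x::Float64) = Float32(x)
to_f32(x) = x

# Hand-written illustration of rosenbrock with each literal passed through to_f32.
function rosenbrock_wrapped(x::Vector)
    a = to_f32(1.0)
    b = to_f32(100.0)
    result = to_f32(0.0)
    for i in to_f32(1):length(x)-to_f32(1)
        result += (a - x[i])^to_f32(2) + b*(x[i+to_f32(1)] - x[i]^to_f32(2))^to_f32(2)
    end
    return result
end
```

With Float32 input this version stays in Float32. The original rosenbrock instead returns a Float64, because its Float64 literals promote the arithmetic.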
