Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lambdas #27

Merged
merged 9 commits into from
Feb 6, 2020
Merged

Lambdas #27

merged 9 commits into from
Feb 6, 2020

Conversation

MikeInnes
Copy link
Member

@MikeInnes MikeInnes commented Sep 16, 2019

An in-IR representation of closures/lamdbas. The main goal is that it's useful for external IRs which make use of lambdas for control flow (e.g. ONNX, XLA, Myia), though this could also dovetail nicely with first-class closure support in Julia's compiler too (JuliaLang/julia#31253).

ir = IR()
x = argument!(ir)

t_block = IR()
return!(t_block, xcall(:getindex, argument!(t_block), 1))
t = push!(ir, Expr(:lambda, t_block, x))

f_block = IR()
return!(f_block, xcall(:zero, xcall(:getindex, argument!(f_block), 1)))
f = push!(ir, Expr(:lambda, f_block, x))

push!(ir, xcall(:ifelse, xcall(:>, x, 0), t, f))

produces:

1: (%1)
  %2 = λ: (%1)
    1: (%1)
      %2 = Base.getindex(%1, 1)
      return %2

  %3 = λ: (%1)
    1: (%1)
      %2 = Base.getindex(%1, 1)
      %3 = Base.zero(%2)
      return %3

  %4 = %1 > 0
  %5 = Base.ifelse(%4, %2, %3)

The semantics are that the tuple of values in λ: (...) get passed, as a tuple, as the lambda's first argument. So in this case the lambdas are zero-argument thunks, and the getindex calls are retrieving the original input x.

I could imagine coming up with nicer syntax sugar / printing for this (e.g. maybe %x' asks for x in the enclosing scope) but for now this is OK. We just need some tests (might need to finally implement that interpreter for those).

@MikeInnes MikeInnes marked this pull request as ready for review January 16, 2020 15:04
@MikeInnes
Copy link
Member Author

MikeInnes commented Feb 5, 2020

Small update: I wrote a quick utility to turn our current CFGs into a pure-functional closure nest (equivalent to A-Normal Form).

julia> relu(x) = x > 0 ? x : 0
relu (generic function with 1 method)

julia> functional(@code_ir relu(1))
1: (%1, %2)
  %3 = %2 > 0
  %4 = λ :
    1: (%1)
      return 0
  %5 = λ :
    1: (%1)
      return %2'
  %6 = IRTools.cond(%3, %5, %4)
  return %6

julia> function pow(x, n)
         r = 1
         while n > 0
           n -= 1
           r *= x
         end
         return r
       end
pow (generic function with 1 method)

julia> functional(@code_ir pow(2, 3))
1: (%1, %2, %3)
  %4 = λ :
    1: (%1, %2, %3)
      %5 = %2 > 0
      %6 = λ :
        1: (%1)
          return %3'
      %7 = λ :
        1: (%1)
          %6 = %2' - 1
          %7 = %3' * %2''
          %8 = (%1')(%6, %7)
          return %8
      %8 = IRTools.cond(%5, %7, %6)
      return %8
  %5 = (%4)(%3, 1)
  return %5

(cond(c, t, f) = c ? t() : f() and recursion is implemented by calling the usual self parameter.)

Not that useful in itself, but it's easy to translate this form into CPS and get continuations. I also need to implement support in dynamos so that they can create closures, and then we can test this.

@MikeInnes
Copy link
Member Author

Added a small demo of how to do the CPS transformation. Sneak peek:

julia> foo(x) = f(g(h(x)))
foo (generic function with 1 method)

julia> cpstransform(@code_ir foo(1))
1: (%1, %2, %3)
  %4 = λ :
    1: (%1, %2)
      %4 = λ :
        1: (%1, %2)
          %4 = λ :
            1: (%1, %2)
              %4 = (%1''')(%2)
              return %4
          %5 = cps(%4, f, %2)
          return %5
      %5 = cps(%4, g, %2)
      return %5
  %5 = cps(%4, h, %3)
  return %5

julia> cpstransform(@code_ir pow(2, 3))
1: (%1, %2, %3, %4)
  %5 = λ :
    1: (%1, %2, %3, %4)
      %6 = λ :
        1: (%1, %2)
          %8 = λ :
            1: (%1, %2)
              %4 = (%2)(%4'')
              return %4
          %9 = λ :
            1: (%1, %2)
              %7 = λ :
                1: (%1, %2)
                  %7 = λ :
                    1: (%1, %2)
                      %6 = λ :
                        1: (%1, %2)
                          %4 = (%2''')(%2)
                          return %4
                      %7 = cps(%6, %1'''', %2', %2)
                      return %7
                  %8 = cps(%7, *, %4''', %3'''')
                  return %8
              %8 = cps(%7, -, %3'', 1)
              return %8
          %10 = λ :
            1: (%1, %2)
              %4 = (%2'')(%2)
              return %4
          %11 = cps(%10, cond, %2, %9, %8)
          return %11
      %7 = cps(%6, >, %3, 0)
      return %7
  %6 = λ :
    1: (%1, %2)
      %4 = (%1')(%2)
      return %4
  %7 = cps(%6, %5, %4, 1)
  return %7

@MikeInnes MikeInnes merged commit 6f66487 into master Feb 6, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant