Skip to content

Commit

Permalink
Merge pull request #17300 from JuliaLang/dot-fusion
Browse files Browse the repository at this point in the history
fusion of nested f.(args) calls into a single broadcast call
  • Loading branch information
JeffBezanson authored Jul 12, 2016
2 parents f47d9fe + fb8f1e1 commit 8fdaf91
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 6 deletions.
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ New language features
* Generators and comprehensions support filtering using `if` ([#550]) and nested
iteration using multiple `for` keywords ([#4867]).

* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]).
* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
and nested `f.(g.(args...))` calls are fused into a single `broadcast` loop ([#17300]).

* Macro expander functions are now generic, so macros can have multiple definitions
(e.g. for different numbers of arguments, or optional arguments) ([#8846], [#9627]).
Expand Down Expand Up @@ -319,4 +320,5 @@ Deprecated or removed
[#17037]: https://github.com/JuliaLang/julia/issues/17037
[#17075]: https://github.com/JuliaLang/julia/issues/17075
[#17266]: https://github.com/JuliaLang/julia/issues/17266
[#17300]: https://github.com/JuliaLang/julia/issues/17300
[#17374]: https://github.com/JuliaLang/julia/issues/17374
16 changes: 16 additions & 0 deletions doc/manual/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,22 @@ then ``f.(pi,A)`` will return a new array consisting of ``f(pi,a)`` for each
consisting of ``f(vector1[i],vector2[i])`` for each index ``i``
(throwing an exception if the vectors have different length).

Moreover, *nested* ``f.(args...)`` calls are *fused* into a single ``broadcast``
loop. For example, ``sin.(cos.(X))`` is equivalent to ``broadcast(x -> sin(cos(x)), X)``,
similar to ``[sin(cos(x)) for x in X]``: there is only a single loop over ``X``,
and a single array is allocated for the result. [In contrast, ``sin(cos(X))``
in a typical "vectorized" language would first allocate one temporary array for ``tmp=cos(X)``,
and then compute ``sin(tmp)`` in a separate loop, allocating a second array.]
This loop fusion is not a compiler optimization that may or may not occur; it
is a *syntactic guarantee* whenever nested ``f.(args...)`` calls are encountered. Technically,
the fusion stops as soon as a "non-dot" function is encountered; for example,
in ``sin.(sort(cos.(X)))`` the ``sin`` and ``cos`` loops cannot be merged
because of the intervening ``sort`` function.

(In future versions of Julia, operators like ``.*`` will also be handled with
the same mechanism: they will be equivalent to ``broadcast`` calls and
will be fused with other nested "dot" calls.)

Further Reading
---------------

Expand Down
125 changes: 120 additions & 5 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,121 @@
(cadr expr) ;; eta reduce `x->f(x)` => `f`
`(-> ,argname (block ,@splat ,expr)))))

;; Predicate on the second operand `x` of a `(|.| f x)` form: true when `x` is
;; a `quote`, `inert` or `$` expression, i.e. the form is a field access rather
;; than a dot call `f.(args...)`.
(define (getfield-field? x)
  (let ((head (car x)))
    ;; memq yields a list tail on success; normalize to a plain boolean
    (if (memq head '(quote inert $)) #t #f)))

;; fuse nested calls to f.(args...) into a single broadcast call
;;
;; `f` is the function expression of a dot call and `args` is its argument
;; list (the caller passes the cdr of the `(tuple args...)` node).  Any
;; positional argument that is itself a dot call `(|.| g (tuple ...))` is merged
;; with `f` into one anonymous function; the combined argument list is then
;; compressed, and the result is the expansion of
;; `(call broadcast <func> <args>...)`.
;;
;; Intermediate representation used by the helpers below: `(fuse lambda args)`,
;; where `lambda` is `(-> (tuple fargs...) body)` and `args` holds the
;; expressions to be passed to `broadcast`, one per formal parameter in `fargs`.
(define (expand-fuse-broadcast f args)
; #t if e is a (fuse ...) form
(define (fuse? e) (and (pair? e) (eq? (car e) 'fuse)))
; #t if any element of exprs is a (fuse ...) form
(define (anyfuse? exprs)
(if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs)))))
(define (to-lambda f args kwargs) ; convert f to anonymous function with hygienic tuple args
; one fresh gensym per positional arg; a splatted arg gets `(... gensym)` so the
; splat appears both in the formal parameter list and in the call to f
(define (genarg arg) (if (vararg? arg) (list '... (gensy)) (gensy)))
; (To do: optimize the case where f is already an anonymous function, in which
; case we only need to hygienicize the arguments? But it is quite tricky
; to fully handle splatted args, typed args, keywords, etcetera. And probably
; the extra function call is harmless because it will get inlined anyway.)
(let ((genargs (map genarg args))) ; hygienic formal parameters
(if (null? kwargs)
`(-> ,(cons 'tuple genargs) (call ,f ,@genargs)) ; no keyword args
; keyword args are not turned into parameters: they are captured as-is
; inside the lambda body in a `parameters` block
`(-> ,(cons 'tuple genargs) (call ,f (parameters ,@kwargs) ,@genargs)))))
(define (from-lambda f) ; convert (-> (tuple args...) (call func args...)) back to func
; only matches when the formal parameters are exactly the call's arguments
; (same elements, same order); any other shape is returned unchanged
(if (and (pair? f) (eq? (car f) '->) (pair? (cadr f)) (eq? (caadr f) 'tuple)
(pair? (caddr f)) (eq? (caaddr f) 'call) (equal? (cdadr f) (cdr (cdaddr f))))
(car (cdaddr f))
f))
(define (fuse-args oldargs) ; replace (fuse f args) with args in oldargs list
; fargs accumulates the result in reverse order (hence the reversed splice of
; a fuse form's args); the outer `reverse` restores the original ordering
(define (fargs newargs oldargs)
(if (null? oldargs)
newargs
(fargs (if (fuse? (car oldargs))
(append (reverse (caddar oldargs)) newargs)
(cons (car oldargs) newargs))
(cdr oldargs))))
(reverse (fargs '() oldargs)))
(define (fuse-funcs f args) ; for (fuse g a) in args, merge/inline g into f
; any argument A of f that is (fuse g a) gets replaced by let A=(body of g):
; fuse-lets walks f's formal parameters and the args in lockstep and collects
; `(= A <body of g>)` bindings (in reverse order) for each fused position
(define (fuse-lets fargs args lets)
(if (null? args)
lets
(if (fuse? (car args))
(fuse-lets (cdr fargs) (cdr args) (cons (list '= (car fargs) (caddr (cadar args))) lets))
(fuse-lets (cdr fargs) (cdr args) lets))))
(let ((fargs (cdadr f))
(fbody (caddr f)))
; the new parameter list is f's parameters with each fused parameter replaced
; by g's own parameters; a dummy `(fuse _ <g's params>)` form is built so that
; fuse-args can be reused to do the splicing
`(->
(tuple ,@(fuse-args (map (lambda (oldarg arg) (if (fuse? arg)
`(fuse _ ,(cdadr (cadr arg)))
oldarg))
fargs args)))
; f's body is evaluated inside a `let` that binds each fused parameter of f
; to the body of the corresponding g
(let ,fbody ,@(reverse (fuse-lets fargs args '()))))))
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
; sk is fed the args in reverse and conses onto the front of each result
; list, so both lists come out in their original relative order
(define (sk args kwargs pargs)
(if (null? args)
(cons kwargs pargs)
(if (kwarg? (car args))
(sk (cdr args) (cons (car args) kwargs) pargs)
(sk (cdr args) kwargs (cons (car args) pargs)))))
; when has-parameters? holds, the first element of args is treated as a
; parameters block: its cdr seeds the keyword list and it is dropped from
; the positional scan (presumably the `f(x; kw...)` form -- has-parameters?
; is defined elsewhere, confirm there)
(if (has-parameters? args)
(sk (reverse (cdr args)) (cdar args) '())
(sk (reverse args) '() '())))
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
; field-access forms (see getfield-field?) are left alone; for dot calls
; the recursive make-fuse call handles deeper nesting
(if (and (pair? e) (eq? (car e) '|.|) (not (getfield-field? (caddr e))))
(make-fuse (cadr e) (cdaddr e))
e))
(let* ((kws.args (split-kwargs args))
(kws (car kws.args))
(args (cdr kws.args)) ; fusing occurs on positional args only
(args_ (map dot-to-fuse args)))
; to-lambda receives the un-fused positional args (it only needs their count
; and which ones are splatted); fuse-funcs then merges the nested lambdas in
(if (anyfuse? args_)
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
`(fuse ,(to-lambda f args kws) ,args_))))
; given e == (fuse lambda args), compress the argument list by removing (pure)
; duplicates in args, inlining literals, and moving any varargs to the end:
(define (compress-fuse e)
; no end-of-list check: the caller only invokes this after memq has confirmed
; that arg is present in args
(define (findfarg arg args fargs) ; for arg in args, return corresponding farg
(if (eq? arg (car args))
(car fargs)
(findfarg arg (cdr args) (cdr fargs))))
(let ((f (cadr e))
(args (caddr e)))
; cf walks the formal parameters and actual args in lockstep:
;   old-fargs/old-args : remaining formals / actuals
;   new-fargs/new-args : formals / actuals kept so far, in reverse order
;   renames            : list of (formal . replacement) pairs handed to
;                        replace-vars together with the lambda body
;   varfarg/vararg     : the splatted formal / actual seen so far, or '()
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
(if (null? old-args)
; done: the splatted pair (if any) is consed onto the reversed lists so that
; it ends up last after the final `reverse`
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
(nargs (if (null? vararg) new-args (cons vararg new-args))))
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
,(reverse nargs)))
(let ((farg (car old-fargs)) (arg (car old-args)))
(cond
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
(if (null? varfarg)
; first splat encountered: remember it, emit it at the end
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args renames farg arg)
; a later splat is accepted only if it splats the same thing (compared
; with eq?) as the first; its formal is then renamed to the first one
(if (eq? (cadr vararg) (cadr arg))
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
varfarg vararg)
(error "multiple splatted args cannot be fused into a single broadcast"))))
((number? arg) ; inline numeric literals
; the formal is dropped and renamed to the literal itself
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg arg) renames)
varfarg vararg))
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
; (note: calling memq for every arg is O(length(args)^2) ...
; ... would be better to replace with a hash table if args is long)
; the duplicate's formal is renamed to the formal already kept for arg
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
varfarg vararg))
(else
; anything else is kept as a distinct parameter/argument pair
(cf (cdr old-fargs) (cdr old-args)
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
(cf (cdadr f) args '() '() '() '() '())))
; build the fused form, compress it, undo the lambda wrapping when it is just
; `(args...) -> func(args...)`, and expand the resulting broadcast call
(let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args)
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))))

;; table mapping expression head to a function expanding that form
(define expand-table
(table
Expand Down Expand Up @@ -1584,11 +1699,11 @@
(lambda (e) ; e = (|.| f x)
(let ((f (cadr e))
(x (caddr e)))
(if (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$))
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
; otherwise, came from f.(args...) --> broadcast(f, args...),
; where x = (tuple args...) at this point:
(expand-forms `(call broadcast ,f ,@(cdr x))))))
(if (getfield-field? x)
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
; otherwise, came from f.(args...) --> broadcast(f, args...),
; where we want to fuse with any nested broadcast calls.
(expand-fuse-broadcast f (cdr x)))))

'|<:| syntactic-op-to-call
'|>:| syntactic-op-to-call
Expand Down
35 changes: 35 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,41 @@ let a = sin.([1, 2])
@test a [0.8414709848078965, 0.9092974268256817]
end

# PR #17300: loop fusion
# Three nested dot calls on anonymous functions: every element of 1:10 gets
# +3, then +2, then +1 applied, i.e. 7:16 overall.
@test (x->x+1).((x->x+2).((x->x+3).(1:10))) == collect(7:16)
# A nested dot call (`log.(A)`) next to a non-dot argument (`sum(A,1)`): the
# result must agree with the explicit, unfused nesting of `broadcast` calls.
let A = [sqrt(i)+j for i = 1:3, j=1:4]
@test atan2.(log.(A), sum(A,1)) == broadcast(atan2, broadcast(log, A), sum(A, 1))
end
# Fused calls compared against equivalent non-fused formulations.
let x = sin.(1:10)
# two nested dot calls that both take the same array `x`
@test atan2.((x->x+1).(x), (x->x+2).(x)) == atan2(x+1, x+2) == atan2(x.+1, x.+2)
# inner dot call whose arguments come from a splatted expression
@test sin.(atan2.([x+1,x+2]...)) == sin.(atan2.(x+1,x+2))
# numeric literal as an argument of a nested dot call
@test sin.(atan2.(x, 3.7)) == broadcast(x -> sin(atan2(x,3.7)), x)
# numeric literal as an argument of a single (un-nested) dot call
@test atan2.(x, 3.7) == broadcast(x -> atan2(x,3.7), x) == broadcast(atan2, x, 3.7)
end
# Use side effects to check for loop fusion. Note that, due to #17314,
# a broadcasted function is currently called an extra time with an argument 1.
let g = Int[]
# records every argument it is called with in `g`, then returns x+1
f17300(x) = begin; push!(g, x); x+1; end
f17300.(f17300.(f17300.(1:3)))
# Each group of three is one element run through all three calls back to back
# (x, x+1, x+2): first the extra call with 1, then the elements 1, 2 and 3.
@test g == [1,2,3, 1,2,3, 2,3,4, 3,4,5]
end
# fusion with splatted args:
# `a` is a one-element array holding `x`, so `a...` splats to the single
# argument `x`.
let x = sin.(1:10), a = [x]
@test cos.(x) == cos.(a...)
# the same splat twice in one call, and one splat supplying both arguments
@test atan2.(x,x) == atan2.(a..., a...) == atan2.([x, x]...)
# splats in the outer call, in the nested call, and in both
@test atan2.(x, cos.(x)) == atan2.(a..., cos.(x)) == atan2(x, cos.(a...)) == atan2(a..., cos.(a...))
# broadcasting anonymous functions that themselves declare varargs
@test ((args...)->cos(args[1])).(x) == cos.(x) == ((y,args...)->cos(y)).(x)
end
# dot call whose arguments are all numeric literals, compared with a plain call
# and with a dot call of a zero-argument anonymous function
@test atan2.(3,4) == atan2(3,4) == (() -> atan2(3,4)).()
# fusion with keyword args:
let x = [1:4;]
# helper with one positional and one keyword argument
f17300kw(x; y=0) = x + y
# keyword left at its default
@test f17300kw.(x) == x
# keyword given after a comma, after a semicolon, and via a splatted list of pairs
@test f17300kw.(x, y=1) == f17300kw.(x; y=1) == f17300kw.(x; [(:y,1)]...) == x .+ 1
# keyword call as the outer function, with a nested dot call as positional argument
@test f17300kw.(sin.(x), y=1) == f17300kw.(sin.(x); y=1) == sin.(x) .+ 1
# keyword call as the nested (inner) function
@test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1)
end

# PR 16988
@test Base.promote_op(+, Bool) === Int
@test isa(broadcast(+, [true]), Array{Int,1})
Expand Down

0 comments on commit 8fdaf91

Please sign in to comment.