Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fusion of nested f.(args) calls into a single broadcast call #17300

Merged
merged 12 commits into from
Jul 12, 2016
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ New language features
* Generators and comprehensions support filtering using `if` ([#550]) and nested
iteration using multiple `for` keywords ([#4867]).

* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]).
* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
and nested `f.(g.(args...))` calls are fused into a single `broadcast` loop ([#17300]).

* Macro expander functions are now generic, so macros can have multiple definitions
(e.g. for different numbers of arguments, or optional arguments) ([#8846], [#9627]).
Expand Down Expand Up @@ -316,4 +317,5 @@ Deprecated or removed
[#17037]: https://github.com/JuliaLang/julia/issues/17037
[#17075]: https://github.com/JuliaLang/julia/issues/17075
[#17266]: https://github.com/JuliaLang/julia/issues/17266
[#17300]: https://github.com/JuliaLang/julia/issues/17300
[#17374]: https://github.com/JuliaLang/julia/issues/17374
16 changes: 16 additions & 0 deletions doc/manual/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,22 @@ then ``f.(pi,A)`` will return a new array consisting of ``f(pi,a)`` for each
consisting of ``f(vector1[i],vector2[i])`` for each index ``i``
(throwing an exception if the vectors have different length).

Moreover, *nested* ``f.(args...)`` calls are *fused* into a single ``broadcast``
loop. For example, ``sin.(cos.(X))`` is equivalent to ``broadcast(x -> sin(cos(x)), X)``,
similar to ``[sin(cos(x)) for x in X]``: there is only a single loop over ``X``,
and a single array is allocated for the result. [In contrast, ``sin(cos(X))``
Copy link
Contributor

@tkelman tkelman Jul 7, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope brackets don't have special meaning in rst

I don't think we use square brackets for parenthetical comments elsewhere in the docs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The usual typographical convention is to use square brackets for parenthetical comments that have nested parentheses.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, I haven't come across that convention. Do you have a citation for it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In math, the usual style is to nest parens as {[(...)]}. In text, it is common to recommend square inside round parens if you must nest, but that isn't possible here because the nested parens are code and hence are constrained by Julia syntax.

http://blog.apastyle.org/apastyle/2013/05/punctuation-junction-parentheses-and-brackets.html
http://www.chicagomanualofstyle.org/16/ch12/ch12_sec026.html
http://www.chicagomanualofstyle.org/16/ch06/ch06_sec099.html

in a typical "vectorized" language would first allocate one temporary array for ``tmp=cos(X)``,
and then compute ``sin(tmp)`` in a separate loop, allocating a second array.]
This loop fusion is not a compiler optimization that may or may not occur, it
is a *syntactic guarantee* whenever nested ``f.(args...)`` calls are encountered. Technically,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why "Technically" ?

Copy link
Member Author

@stevengj stevengj Jul 7, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is a technical caveat (something that most users won't need to think too much about) to the broad informal statement at the beginning of the paragraph that nested dot calls are fused.

the fusion stops as soon as a "non-dot" function is encountered; for example,
in ``sin.(sort(cos.(X)))`` the ``sin`` and ``cos`` loops cannot be merged
because of the intervening ``sort`` function.

(In future versions of Julia, operators like ``.*`` will also be handled with
the same mechanism: they will be equivalent to ``broadcast`` calls and
will be fused with other nested "dot" calls.)

Further Reading
---------------

Expand Down
125 changes: 120 additions & 5 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,121 @@
(cadr expr) ;; eta reduce `x->f(x)` => `f`
`(-> ,argname (block ,@splat ,expr)))))

(define (getfield-field? x) ; whether x from (|.| f x) is a getfield call
(or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)))

;; fuse nested calls to f.(args...) into a single broadcast call
(define (expand-fuse-broadcast f args)
(define (fuse? e) (and (pair? e) (eq? (car e) 'fuse)))
(define (anyfuse? exprs)
(if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs)))))
(define (to-lambda f args kwargs) ; convert f to anonymous function with hygienic tuple args
(define (genarg arg) (if (vararg? arg) (list '... (gensy)) (gensy)))
; (To do: optimize the case where f is already an anonymous function, in which
; case we only need to hygienicize the arguments? But it is quite tricky
; to fully handle splatted args, typed args, keywords, etcetera. And probably
; the extra function call is harmless because it will get inlined anyway.)
(let ((genargs (map genarg args))) ; hygienic formal parameters
(if (null? kwargs)
`(-> ,(cons 'tuple genargs) (call ,f ,@genargs)) ; no keyword args
`(-> ,(cons 'tuple genargs) (call ,f (parameters ,@kwargs) ,@genargs)))))
(define (from-lambda f) ; convert (-> (tuple args...) (call func args...)) back to func
(if (and (pair? f) (eq? (car f) '->) (pair? (cadr f)) (eq? (caadr f) 'tuple)
(pair? (caddr f)) (eq? (caaddr f) 'call) (equal? (cdadr f) (cdr (cdaddr f))))
(car (cdaddr f))
f))
(define (fuse-args oldargs) ; replace (fuse f args) with args in oldargs list
(define (fargs newargs oldargs)
(if (null? oldargs)
newargs
(fargs (if (fuse? (car oldargs))
(append (reverse (caddar oldargs)) newargs)
(cons (car oldargs) newargs))
(cdr oldargs))))
(reverse (fargs '() oldargs)))
(define (fuse-funcs f args) ; for (fuse g a) in args, merge/inline g into f
; any argument A of f that is (fuse g a) gets replaced by let A=(body of g):
(define (fuse-lets fargs args lets)
(if (null? args)
lets
(if (fuse? (car args))
(fuse-lets (cdr fargs) (cdr args) (cons (list '= (car fargs) (caddr (cadar args))) lets))
(fuse-lets (cdr fargs) (cdr args) lets))))
(let ((fargs (cdadr f))
(fbody (caddr f)))
`(->
(tuple ,@(fuse-args (map (lambda (oldarg arg) (if (fuse? arg)
`(fuse _ ,(cdadr (cadr arg)))
oldarg))
fargs args)))
(let ,fbody ,@(reverse (fuse-lets fargs args '()))))))
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
(define (sk args kwargs pargs)
(if (null? args)
(cons kwargs pargs)
(if (kwarg? (car args))
(sk (cdr args) (cons (car args) kwargs) pargs)
(sk (cdr args) kwargs (cons (car args) pargs)))))
(if (has-parameters? args)
(sk (reverse (cdr args)) (cdar args) '())
(sk (reverse args) '() '())))
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
(if (and (pair? e) (eq? (car e) '|.|) (not (getfield-field? (caddr e))))
(make-fuse (cadr e) (cdaddr e))
e))
(let* ((kws.args (split-kwargs args))
(kws (car kws.args))
(args (cdr kws.args)) ; fusing occurs on positional args only
(args_ (map dot-to-fuse args)))
(if (anyfuse? args_)
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
`(fuse ,(to-lambda f args kws) ,args_))))
; given e == (fuse lambda args), compress the argument list by removing (pure)
; duplicates in args, inlining literals, and moving any varargs to the end:
(define (compress-fuse e)
(define (findfarg arg args fargs) ; for arg in args, return corresponding farg
(if (eq? arg (car args))
(car fargs)
(findfarg arg (cdr args) (cdr fargs))))
(let ((f (cadr e))
(args (caddr e)))
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
(if (null? old-args)
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
(nargs (if (null? vararg) new-args (cons vararg new-args))))
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
,(reverse nargs)))
(let ((farg (car old-fargs)) (arg (car old-args)))
(cond
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
(if (null? varfarg)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args renames farg arg)
(if (eq? (cadr vararg) (cadr arg))
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
varfarg vararg)
(error "multiple splatted args cannot be fused into a single broadcast"))))
((number? arg) ; inline numeric literals
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg arg) renames)
varfarg vararg))
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
; (note: calling memq for every arg is O(length(args)^2) ...
; ... would be better to replace with a hash table if args is long)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
varfarg vararg))
(else
(cf (cdr old-fargs) (cdr old-args)
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
(cf (cdadr f) args '() '() '() '() '())))
(let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args)
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))))

;; table mapping expression head to a function expanding that form
(define expand-table
(table
Expand Down Expand Up @@ -1584,11 +1699,11 @@
(lambda (e) ; e = (|.| f x)
(let ((f (cadr e))
(x (caddr e)))
(if (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$))
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
; otherwise, came from f.(args...) --> broadcast(f, args...),
; where x = (tuple args...) at this point:
(expand-forms `(call broadcast ,f ,@(cdr x))))))
(if (getfield-field? x)
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
; otherwise, came from f.(args...) --> broadcast(f, args...),
; where we want to fuse with any nested broadcast calls.
(expand-fuse-broadcast f (cdr x)))))

'|<:| syntactic-op-to-call
'|>:| syntactic-op-to-call
Expand Down
35 changes: 35 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,41 @@ let a = sin.([1, 2])
@test a ≈ [0.8414709848078965, 0.9092974268256817]
end

# PR #17300: loop fusion
@test (x->x+1).((x->x+2).((x->x+3).(1:10))) == collect(7:16)
let A = [sqrt(i)+j for i = 1:3, j=1:4]
@test atan2.(log.(A), sum(A,1)) == broadcast(atan2, broadcast(log, A), sum(A, 1))
end
let x = sin.(1:10)
@test atan2.((x->x+1).(x), (x->x+2).(x)) == atan2(x+1, x+2) == atan2(x.+1, x.+2)
@test sin.(atan2.([x+1,x+2]...)) == sin.(atan2.(x+1,x+2))
@test sin.(atan2.(x, 3.7)) == broadcast(x -> sin(atan2(x,3.7)), x)
@test atan2.(x, 3.7) == broadcast(x -> atan2(x,3.7), x) == broadcast(atan2, x, 3.7)
end
# Use side effects to check for loop fusion. Note that, due to #17314,
# a broadcasted function is currently called an extra time with an argument 1.
let g = Int[]
f17300(x) = begin; push!(g, x); x+1; end
f17300.(f17300.(f17300.(1:3)))
@test g == [1,2,3, 1,2,3, 2,3,4, 3,4,5]
end
# fusion with splatted args:
let x = sin.(1:10), a = [x]
@test cos.(x) == cos.(a...)
@test atan2.(x,x) == atan2.(a..., a...) == atan2.([x, x]...)
@test atan2.(x, cos.(x)) == atan2.(a..., cos.(x)) == atan2(x, cos.(a...)) == atan2(a..., cos.(a...))
@test ((args...)->cos(args[1])).(x) == cos.(x) == ((y,args...)->cos(y)).(x)
end
@test atan2.(3,4) == atan2(3,4) == (() -> atan2(3,4)).()
# fusion with keyword args:
let x = [1:4;]
f17300kw(x; y=0) = x + y
@test f17300kw.(x) == x
@test f17300kw.(x, y=1) == f17300kw.(x; y=1) == f17300kw.(x; [(:y,1)]...) == x .+ 1
@test f17300kw.(sin.(x), y=1) == f17300kw.(sin.(x); y=1) == sin.(x) .+ 1
@test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1)
end

# PR 16988
@test Base.promote_op(+, Bool) === Int
@test isa(broadcast(+, [true]), Array{Int,1})
Expand Down