Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

treat .= as syntactic sugar for broadcast! #17510

Merged
merged 7 commits into from
Jul 21, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ New language features
* Generators and comprehensions support filtering using `if` ([#550]) and nested
iteration using multiple `for` keywords ([#4867]).

* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
* Fused broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
and nested `f.(g.(args...))` calls are fused into a single `broadcast` loop ([#17300]).
Similarly, the syntax `x .= ...` is equivalent to a `broadcast!(identity, x, ...)`
call and fuses with nested "dot" calls; also, `x .+= y` and similar is now
equivalent to `x .= x .+ y`, rather than `=` ([#17510]).

* Macro expander functions are now generic, so macros can have multiple definitions
(e.g. for different numbers of arguments, or optional arguments) ([#8846], [#9627]).
Expand Down Expand Up @@ -355,3 +358,4 @@ Deprecated or removed
[#17393]: https://github.com/JuliaLang/julia/issues/17393
[#17402]: https://github.com/JuliaLang/julia/issues/17402
[#17404]: https://github.com/JuliaLang/julia/issues/17404
[#17510]: https://github.com/JuliaLang/julia/issues/17510
9 changes: 9 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ export broadcast_getindex, broadcast_setindex!
broadcast(f) = f()
broadcast(f, x::Number...) = f(x...)

# special cases for "X .= ..." (broadcast!) assignments
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
broadcast!(f, X::AbstractArray) = fill!(X, f())
broadcast!(f, X::AbstractArray, x::Number...) = fill!(X, f(x...))
function broadcast!{T,S,N}(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N})
check_broadcast_shape(size(x), size(y))
copy!(x, y)
end

## Calculate the broadcast shape of the arguments, or error if incompatible
# array inputs
broadcast_shape() = ()
Expand Down
6 changes: 4 additions & 2 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,10 @@ show_unquoted(io::IO, ex, ::Int,::Int) = show(io, ex)
const indent_width = 4
const quoted_syms = Set{Symbol}([:(:),:(::),:(:=),:(=),:(==),:(!=),:(===),:(!==),:(=>),:(>=),:(<=)])
const uni_ops = Set{Symbol}([:(+), :(-), :(!), :(¬), :(~), :(<:), :(>:), :(√), :(∛), :(∜)])
const expr_infix_wide = Set{Symbol}([:(=), :(+=), :(-=), :(*=), :(/=), :(\=), :(&=),
:(|=), :($=), :(>>>=), :(>>=), :(<<=), :(&&), :(||), :(<:), :(=>), :(÷=)])
const expr_infix_wide = Set{Symbol}([
:(=), :(+=), :(-=), :(*=), :(/=), :(\=), :(^=), :(&=), :(|=), :(÷=), :(%=), :(>>>=), :(>>=), :(<<=),
:(.=), :(.+=), :(.-=), :(.*=), :(./=), :(.\=), :(.^=), :(.&=), :(.|=), :(.÷=), :(.%=), :(.>>>=), :(.>>=), :(.<<=),
:(&&), :(||), :(<:), :(=>), :($=)])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does .$= not work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tkelman, .$= is not supported by the parser – it was not added by #17393 because the $ is pretty special (e.g. in macros you can do a.$b) and it seemed like it would require more parser hacking to support .$= than it was worth at this point in the 0.5 cycle.

const expr_infix = Set{Symbol}([:(:), :(->), Symbol("::")])
const expr_infix_any = union(expr_infix, expr_infix_wide)
const all_ops = union(quoted_syms, uni_ops, expr_infix_any)
Expand Down
2 changes: 1 addition & 1 deletion doc/manual/arrays.rst
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ function elementwise:
1.71056 0.847604
1.73659 0.873631

Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a :func:`broadcast!` function to specify an explicit destination, and :func:`broadcast_getindex` and :func:`broadcast_setindex!` that broadcast the indices before indexing. Moreover, ``f.(args...)`` is equivalent to ``broadcast(f, args...)``, providing a convenient syntax to broadcast any function (:ref:`man-dot-vectorizing`:.).
Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a :func:`broadcast!` function to specify an explicit destination, and :func:`broadcast_getindex` and :func:`broadcast_setindex!` that broadcast the indices before indexing. Moreover, ``f.(args...)`` is equivalent to ``broadcast(f, args...)``, providing a convenient syntax to broadcast any function (:ref:`man-dot-vectorizing`:).

Implementation
--------------
Expand Down
13 changes: 12 additions & 1 deletion doc/manual/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -652,9 +652,20 @@ the fusion stops as soon as a "non-dot" function is encountered; for example,
in ``sin.(sort(cos.(X)))`` the ``sin`` and ``cos`` loops cannot be merged
because of the intervening ``sort`` function.

Finally, the maximum efficiency is typically achieved when the output
array of a vectorized operation is *pre-allocated*, so that repeated
calls do not allocate new arrays over and over again for the results
(:ref:`man-preallocation`:). A convenient syntax for this is
``X .= ...``, which is equivalent to ``broadcast!(identity, X, ...)``
except that, as above, the ``broadcast!`` loop is fused with any nested
"dot" calls. For example, ``X .= sin.(Y)`` is equivalent to
``broadcast!(sin, X, Y)``, overwriting ``X`` with ``sin.(Y)`` in-place.

(In future versions of Julia, operators like ``.*`` will also be handled with
the same mechanism: they will be equivalent to ``broadcast`` calls and
will be fused with other nested "dot" calls.)
will be fused with other nested "dot" calls. ``x .+= y`` is equivalent
to ``x .= x .+ y`` and will eventually result in a fused in-place assignment.
Similarly for ``.*=`` etcetera.)

Further Reading
---------------
Expand Down
5 changes: 4 additions & 1 deletion doc/manual/performance-tips.rst
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,10 @@ above, we could have passed a :class:`SubArray` rather than an :class:`Array`,
had we so desired.

Taken to its extreme, pre-allocation can make your code uglier, so
performance measurements and some judgment may be required.
performance measurements and some judgment may be required. However,
for "vectorized" (element-wise) functions, the convenient syntax
``x .= f.(y)`` can be used for in-place operations with fused loops
and no temporary arrays (:ref:`dot-vectorizing`).


Avoid string interpolation for I/O
Expand Down
177 changes: 95 additions & 82 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -1418,12 +1418,12 @@
`(call ,(cadr e) ,(expand-forms a) ,(expand-forms b))))))

;; convert `a+=b` to `a=a+b`
(define (expand-update-operator- op lhs rhs declT)
(define (expand-update-operator- op op= lhs rhs declT)
(let ((e (remove-argument-side-effects lhs)))
`(block ,@(cdr e)
,(if (null? declT)
`(= ,(car e) (call ,op ,(car e) ,rhs))
`(= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs))))))
`(,op= ,(car e) (call ,op ,(car e) ,rhs))
`(,op= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs))))))

(define (partially-expand-ref e)
(let ((a (cadr e))
Expand All @@ -1443,31 +1443,32 @@
,@(append stmts stuff)
(call getindex ,arr ,@new-idxs))))))

(define (expand-update-operator op lhs rhs . declT)
(define (expand-update-operator op op= lhs rhs . declT)
(cond ((and (pair? lhs) (eq? (car lhs) 'ref))
;; expand indexing inside op= first, to remove "end" and ":"
(let* ((ex (partially-expand-ref lhs))
(stmts (butlast (cdr ex)))
(refex (last (cdr ex)))
(nuref `(ref ,(caddr refex) ,@(cdddr refex))))
`(block ,@stmts
,(expand-update-operator- op nuref rhs declT))))
,(expand-update-operator- op op= nuref rhs declT))))
((and (pair? lhs) (eq? (car lhs) '|::|))
;; (+= (:: x T) rhs)
(let ((e (remove-argument-side-effects (cadr lhs)))
(T (caddr lhs)))
`(block ,@(cdr e)
,(expand-update-operator op (car e) rhs T))))
,(expand-update-operator op op= (car e) rhs T))))
(else
(expand-update-operator- op lhs rhs declT))))
(expand-update-operator- op op= lhs rhs declT))))

(define (lower-update-op e)
(expand-forms
(expand-update-operator
(let ((str (string (car e))))
(symbol (string.sub str 0 (- (length str) 1))))
(cadr e)
(caddr e))))
(let ((str (string (car e))))
(expand-update-operator
(symbol (string.sub str 0 (- (length str) 1)))
(if (= (string.char str 0) #\.) '.= '=)
(cadr e)
(caddr e)))))

(define (expand-and e)
(let ((e (cdr (flatten-ex '&& e))))
Expand Down Expand Up @@ -1546,11 +1547,9 @@
(cadr expr) ;; eta reduce `x->f(x)` => `f`
`(-> ,argname (block ,@splat ,expr)))))

(define (getfield-field? x) ; whether x from (|.| f x) is a getfield call
(or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)))

;; fuse nested calls to f.(args...) into a single broadcast call
(define (expand-fuse-broadcast f args)
; fuse nested calls to expr == f.(args...) into a single broadcast call,
; or a broadcast! call if lhs is non-null.
(define (expand-fuse-broadcast lhs rhs)
(define (fuse? e) (and (pair? e) (eq? (car e) 'fuse)))
(define (anyfuse? exprs)
(if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs)))))
Expand Down Expand Up @@ -1594,72 +1593,83 @@
oldarg))
fargs args)))
(let ,fbody ,@(reverse (fuse-lets fargs args '()))))))
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
(define (sk args kwargs pargs)
(if (null? args)
(cons kwargs pargs)
(if (kwarg? (car args))
(sk (cdr args) (cons (car args) kwargs) pargs)
(sk (cdr args) kwargs (cons (car args) pargs)))))
(if (has-parameters? args)
(sk (reverse (cdr args)) (cdar args) '())
(sk (reverse args) '() '())))
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
(if (and (pair? e) (eq? (car e) '|.|) (not (getfield-field? (caddr e))))
(make-fuse (cadr e) (cdaddr e))
e))
(let* ((kws.args (split-kwargs args))
(kws (car kws.args))
(args (cdr kws.args)) ; fusing occurs on positional args only
(args_ (map dot-to-fuse args)))
(if (anyfuse? args_)
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
`(fuse ,(to-lambda f args kws) ,args_))))
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
(define (sk args kwargs pargs)
(if (null? args)
(cons kwargs pargs)
(if (kwarg? (car args))
(sk (cdr args) (cons (car args) kwargs) pargs)
(sk (cdr args) kwargs (cons (car args) pargs)))))
(if (has-parameters? args)
(sk (reverse (cdr args)) (cdar args) '())
(sk (reverse args) '() '())))
(let* ((kws.args (split-kwargs args))
(kws (car kws.args))
(args (cdr kws.args)) ; fusing occurs on positional args only
(args_ (map dot-to-fuse args)))
(if (anyfuse? args_)
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
`(fuse ,(to-lambda f args kws) ,args_))))
(if (and (pair? e) (eq? (car e) '|.|))
(let ((f (cadr e)) (x (caddr e)))
(if (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$))
`(call (core getfield) ,f ,x)
(make-fuse f (cdr x))))
e))
; given e == (fuse lambda args), compress the argument list by removing (pure)
; duplicates in args, inlining literals, and moving any varargs to the end:
(define (compress-fuse e)
(define (findfarg arg args fargs) ; for arg in args, return corresponding farg
(if (eq? arg (car args))
(car fargs)
(findfarg arg (cdr args) (cdr fargs))))
(let ((f (cadr e))
(args (caddr e)))
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
(if (null? old-args)
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
(nargs (if (null? vararg) new-args (cons vararg new-args))))
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
,(reverse nargs)))
(let ((farg (car old-fargs)) (arg (car old-args)))
(cond
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
(if (null? varfarg)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args renames farg arg)
(if (eq? (cadr vararg) (cadr arg))
(if (fuse? e)
(let ((f (cadr e))
(args (caddr e)))
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
(if (null? old-args)
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
(nargs (if (null? vararg) new-args (cons vararg new-args))))
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
,(reverse nargs)))
(let ((farg (car old-fargs)) (arg (car old-args)))
(cond
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
(if (null? varfarg)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
varfarg vararg)
(error "multiple splatted args cannot be fused into a single broadcast"))))
((number? arg) ; inline numeric literals
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg arg) renames)
varfarg vararg))
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
; (note: calling memq for every arg is O(length(args)^2) ...
; ... would be better to replace with a hash table if args is long)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
varfarg vararg))
(else
(cf (cdr old-fargs) (cdr old-args)
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
(cf (cdadr f) args '() '() '() '() '())))
(let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args)
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))))
new-fargs new-args renames farg arg)
(if (eq? (cadr vararg) (cadr arg))
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
varfarg vararg)
(error "multiple splatted args cannot be fused into a single broadcast"))))
((number? arg) ; inline numeric literals
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg arg) renames)
varfarg vararg))
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
; (note: calling memq for every arg is O(length(args)^2) ...
; ... would be better to replace with a hash table if args is long)
(cf (cdr old-fargs) (cdr old-args)
new-fargs new-args
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
varfarg vararg))
(else
(cf (cdr old-fargs) (cdr old-args)
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
(cf (cdadr f) args '() '() '() '() '()))
e)) ; (not (fuse? e))
(let ((e (compress-fuse (dot-to-fuse rhs)))) ; an expression '(fuse func args) if expr is a dot call
(if (fuse? e)
(if (null? lhs)
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))
(expand-forms `(call broadcast! ,(from-lambda (cadr e)) ,lhs ,@(caddr e))))
(if (null? lhs)
(expand-forms e)
(expand-forms `(call broadcast! identity ,lhs ,e))))))

;; table mapping expression head to a function expanding that form
(define expand-table
Expand Down Expand Up @@ -1697,13 +1707,11 @@

'|.|
(lambda (e) ; e = (|.| f x)
(let ((f (cadr e))
(x (caddr e)))
(if (getfield-field? x)
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
; otherwise, came from f.(args...) --> broadcast(f, args...),
; where we want to fuse with any nested broadcast calls.
(expand-fuse-broadcast f (cdr x)))))
(expand-fuse-broadcast '() e))

'.=
(lambda (e)
(expand-fuse-broadcast (cadr e) (caddr e)))

'|<:| syntactic-op-to-call
'|>:| syntactic-op-to-call
Expand Down Expand Up @@ -2008,11 +2016,16 @@
'%= lower-update-op
'.%= lower-update-op
'|\|=| lower-update-op
'|.\|=| lower-update-op
'&= lower-update-op
'.&= lower-update-op
'$= lower-update-op
'<<= lower-update-op
'.<<= lower-update-op
'>>= lower-update-op
'.>>= lower-update-op
'>>>= lower-update-op
'.>>>= lower-update-op

':
(lambda (e)
Expand Down
19 changes: 19 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,25 @@ let x = [1:4;]
@test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1)
end

# PR #17510: Fused in-place assignment
let x = [1:4;], y = x
y .= 2:5
@test y === x == [2:5;]
y .= factorial.(x)
@test y === x == [2,6,24,120]
y .= 7
@test y === x == [7,7,7,7]
y .= factorial.(3)
@test y === x == [6,6,6,6]
f17510() = 9
y .= f17510.()
@test y === x == [9,9,9,9]
y .-= 1
@test y === x == [8,8,8,8]
y .-= 1:4
@test y === x == [7,6,5,4]
end

# PR 16988
@test Base.promote_op(+, Bool) === Int
@test isa(broadcast(+, [true]), Array{Int,1})
Expand Down
4 changes: 4 additions & 0 deletions test/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,7 @@ end
@test repr(:(x for x in y if aa for z in w if bb)) == ":(x for x = y if aa for z = w if bb)"
@test repr(:([x for x = y])) == ":([x for x = y])"
@test repr(:([x for x = y if z])) == ":([x for x = y if z])"

for op in (:(.=), :(.+=), :(.&=))
@test repr(parse("x $op y")) == ":(x $op y)"
end