diff --git a/NEWS.md b/NEWS.md index c7c03bfa164f2..74ff3ef6d259f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,8 +10,11 @@ New language features * Generators and comprehensions support filtering using `if` ([#550]) and nested iteration using multiple `for` keywords ([#4867]). - * Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]), + * Fused broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]), and nested `f.(g.(args...))` calls are fused into a single `broadcast` loop ([#17300]). + Similarly, the syntax `x .= ...` is equivalent to a `broadcast!(identity, x, ...)` + call and fuses with nested "dot" calls; also, `x .+= y` and similar is now + equivalent to `x .= x .+ y`, rather than `=` ([#17510]). * Macro expander functions are now generic, so macros can have multiple definitions (e.g. for different numbers of arguments, or optional arguments) ([#8846], [#9627]). @@ -355,3 +358,4 @@ Deprecated or removed [#17393]: https://github.com/JuliaLang/julia/issues/17393 [#17402]: https://github.com/JuliaLang/julia/issues/17402 [#17404]: https://github.com/JuliaLang/julia/issues/17404 +[#17510]: https://github.com/JuliaLang/julia/issues/17510 diff --git a/base/broadcast.jl b/base/broadcast.jl index cca42e514c214..62e96bc9ee4b6 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -15,6 +15,19 @@ export broadcast_getindex, broadcast_setindex! broadcast(f) = f() broadcast(f, x::Number...) = f(x...) +# special cases for "X .= ..." (broadcast!) assignments +broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x) +broadcast!(f, X::AbstractArray) = fill!(X, f()) +broadcast!(f, X::AbstractArray, x::Number...) 
= fill!(X, f(x...)) +# `x .= y` for arrays of equal rank: plain copy when the shapes agree, generic broadcast otherwise +function broadcast!{T,S,N}(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N}) + check_broadcast_shape(size(x), size(y)) + # copy! only overwrites the first length(y) elements of x, so it equals broadcasting + # only when the shapes match exactly; a y with singleton dimensions must be expanded + # by the generic method (`z -> z` is not `identity`, so this call cannot recurse here) + size(x) == size(y) ? copy!(x, y) : broadcast!(z -> z, x, y) +end + ## Calculate the broadcast shape of the arguments, or error if incompatible # array inputs broadcast_shape() = () diff --git a/base/show.jl b/base/show.jl index 51b04e057ae47..b6679dcc62c80 100644 --- a/base/show.jl +++ b/base/show.jl @@ -408,8 +408,10 @@ show_unquoted(io::IO, ex, ::Int,::Int) = show(io, ex) const indent_width = 4 const quoted_syms = Set{Symbol}([:(:),:(::),:(:=),:(=),:(==),:(!=),:(===),:(!==),:(=>),:(>=),:(<=)]) const uni_ops = Set{Symbol}([:(+), :(-), :(!), :(¬), :(~), :(<:), :(>:), :(√), :(∛), :(∜)]) -const expr_infix_wide = Set{Symbol}([:(=), :(+=), :(-=), :(*=), :(/=), :(\=), :(&=), - :(|=), :($=), :(>>>=), :(>>=), :(<<=), :(&&), :(||), :(<:), :(=>), :(÷=)]) +const expr_infix_wide = Set{Symbol}([ + :(=), :(+=), :(-=), :(*=), :(/=), :(\=), :(^=), :(&=), :(|=), :(÷=), :(%=), :(>>>=), :(>>=), :(<<=), + :(.=), :(.+=), :(.-=), :(.*=), :(./=), :(.\=), :(.^=), :(.&=), :(.|=), :(.÷=), :(.%=), :(.>>>=), :(.>>=), :(.<<=), + :(&&), :(||), :(<:), :(=>), :($=)]) const expr_infix = Set{Symbol}([:(:), :(->), Symbol("::")]) const expr_infix_any = union(expr_infix, expr_infix_wide) const all_ops = union(quoted_syms, uni_ops, expr_infix_any) diff --git a/doc/manual/arrays.rst b/doc/manual/arrays.rst index 44100b50b9094..ceadc254dc644 100644 --- a/doc/manual/arrays.rst +++ b/doc/manual/arrays.rst @@ -568,7 +568,7 @@ function elementwise: 1.71056 0.847604 1.73659 0.873631 -Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a :func:`broadcast!` function to specify an explicit destination, and :func:`broadcast_getindex` and :func:`broadcast_setindex!` that broadcast the indices before indexing. 
Moreover, ``f.(args...)`` is equivalent to ``broadcast(f, args...)``, providing a convenient syntax to broadcast any function (:ref:`man-dot-vectorizing`:.). +Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a :func:`broadcast!` function to specify an explicit destination, and :func:`broadcast_getindex` and :func:`broadcast_setindex!` that broadcast the indices before indexing. Moreover, ``f.(args...)`` is equivalent to ``broadcast(f, args...)``, providing a convenient syntax to broadcast any function (:ref:`man-dot-vectorizing`). Implementation -------------- diff --git a/doc/manual/functions.rst b/doc/manual/functions.rst index 0b1b77aef8102..e898f663201ef 100644 --- a/doc/manual/functions.rst +++ b/doc/manual/functions.rst @@ -652,9 +652,20 @@ the fusion stops as soon as a "non-dot" function is encountered; for example, in ``sin.(sort(cos.(X)))`` the ``sin`` and ``cos`` loops cannot be merged because of the intervening ``sort`` function. +Finally, the maximum efficiency is typically achieved when the output +array of a vectorized operation is *pre-allocated*, so that repeated +calls do not allocate new arrays over and over again for the results +(:ref:`man-preallocation`). A convenient syntax for this is +``X .= ...``, which is equivalent to ``broadcast!(identity, X, ...)`` +except that, as above, the ``broadcast!`` loop is fused with any nested +"dot" calls. For example, ``X .= sin.(Y)`` is equivalent to +``broadcast!(sin, X, Y)``, overwriting ``X`` with ``sin.(Y)`` in-place. + (In future versions of Julia, operators like ``.*`` will also be handled with the same mechanism: they will be equivalent to ``broadcast`` calls and -will be fused with other nested "dot" calls.) +will be fused with other nested "dot" calls. ``x .+= y`` is equivalent +to ``x .= x .+ y`` and will eventually result in a fused in-place assignment. +Similarly for ``.*=`` etcetera.) 
Further Reading --------------- diff --git a/doc/manual/performance-tips.rst b/doc/manual/performance-tips.rst index eb092420787bc..f98db0d5b8130 100644 --- a/doc/manual/performance-tips.rst +++ b/doc/manual/performance-tips.rst @@ -944,7 +944,10 @@ above, we could have passed a :class:`SubArray` rather than an :class:`Array`, had we so desired. Taken to its extreme, pre-allocation can make your code uglier, so -performance measurements and some judgment may be required. +performance measurements and some judgment may be required. However, +for "vectorized" (element-wise) functions, the convenient syntax +``x .= f.(y)`` can be used for in-place operations with fused loops +and no temporary arrays (:ref:`man-dot-vectorizing`). Avoid string interpolation for I/O diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 4cfcf09e11fb6..1531efc41777c 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1418,12 +1418,12 @@ `(call ,(cadr e) ,(expand-forms a) ,(expand-forms b)))))) ;; convert `a+=b` to `a=a+b` -(define (expand-update-operator- op lhs rhs declT) +(define (expand-update-operator- op op= lhs rhs declT) (let ((e (remove-argument-side-effects lhs))) `(block ,@(cdr e) ,(if (null? declT) - `(= ,(car e) (call ,op ,(car e) ,rhs)) - `(= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs)))))) + `(,op= ,(car e) (call ,op ,(car e) ,rhs)) + `(,op= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs)))))) (define (partially-expand-ref e) (let ((a (cadr e)) @@ -1443,7 +1443,7 @@ ,@(append stmts stuff) (call getindex ,arr ,@new-idxs)))))) -(define (expand-update-operator op lhs rhs . declT) +(define (expand-update-operator op op= lhs rhs . declT) (cond ((and (pair? lhs) (eq? 
(car lhs) 'ref)) ;; expand indexing inside op= first, to remove "end" and ":" (let* ((ex (partially-expand-ref lhs)) @@ -1451,23 +1451,24 @@ (refex (last (cdr ex))) (nuref `(ref ,(caddr refex) ,@(cdddr refex)))) `(block ,@stmts - ,(expand-update-operator- op nuref rhs declT)))) + ,(expand-update-operator- op op= nuref rhs declT)))) ((and (pair? lhs) (eq? (car lhs) '|::|)) ;; (+= (:: x T) rhs) (let ((e (remove-argument-side-effects (cadr lhs))) (T (caddr lhs))) `(block ,@(cdr e) - ,(expand-update-operator op (car e) rhs T)))) + ,(expand-update-operator op op= (car e) rhs T)))) (else - (expand-update-operator- op lhs rhs declT)))) + (expand-update-operator- op op= lhs rhs declT)))) (define (lower-update-op e) (expand-forms - (expand-update-operator - (let ((str (string (car e)))) - (symbol (string.sub str 0 (- (length str) 1)))) - (cadr e) - (caddr e)))) + (let ((str (string (car e)))) + (expand-update-operator + (symbol (string.sub str 0 (- (length str) 1))) + (if (= (string.char str 0) #\.) '.= '=) ; a leading dot (.+= etc.) selects the broadcasting assignment .=, otherwise plain = + (cadr e) + (caddr e))))) (define (expand-and e) (let ((e (cdr (flatten-ex '&& e)))) @@ -1546,11 +1547,9 @@ (cadr expr) ;; eta reduce `x->f(x)` => `f` `(-> ,argname (block ,@splat ,expr))))) -(define (getfield-field? x) ; whether x from (|.| f x) is a getfield call - (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$))) - -;; fuse nested calls to f.(args...) into a single broadcast call -(define (expand-fuse-broadcast f args) +;; fuse nested calls to rhs == f.(args...) into a single broadcast call, +;; or a broadcast! call if lhs is non-null. +(define (expand-fuse-broadcast lhs rhs) (define (fuse? e) (and (pair? e) (eq? (car e) 'fuse))) (define (anyfuse? exprs) (if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? 
(cdr exprs))))) @@ -1594,28 +1593,31 @@ oldarg)) fargs args))) (let ,fbody ,@(reverse (fuse-lets fargs args '())))))) - (define (make-fuse f args) ; check for nested (fuse f args) exprs and combine - (define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args - (define (sk args kwargs pargs) - (if (null? args) - (cons kwargs pargs) - (if (kwarg? (car args)) - (sk (cdr args) (cons (car args) kwargs) pargs) - (sk (cdr args) kwargs (cons (car args) pargs))))) - (if (has-parameters? args) - (sk (reverse (cdr args)) (cdar args) '()) - (sk (reverse args) '() '()))) - (define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args) - (if (and (pair? e) (eq? (car e) '|.|) (not (getfield-field? (caddr e)))) - (make-fuse (cadr e) (cdaddr e)) - e)) - (let* ((kws.args (split-kwargs args)) - (kws (car kws.args)) - (args (cdr kws.args)) ; fusing occurs on positional args only - (args_ (map dot-to-fuse args))) - (if (anyfuse? args_) - `(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_)) - `(fuse ,(to-lambda f args kws) ,args_)))) + (define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args) + (define (make-fuse f args) ; check for nested (fuse f args) exprs and combine + (define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args + (define (sk args kwargs pargs) + (if (null? args) + (cons kwargs pargs) + (if (kwarg? (car args)) + (sk (cdr args) (cons (car args) kwargs) pargs) + (sk (cdr args) kwargs (cons (car args) pargs))))) + (if (has-parameters? args) + (sk (reverse (cdr args)) (cdar args) '()) + (sk (reverse args) '() '()))) + (let* ((kws.args (split-kwargs args)) + (kws (car kws.args)) + (args (cdr kws.args)) ; fusing occurs on positional args only + (args_ (map dot-to-fuse args))) + (if (anyfuse? args_) + `(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_)) + `(fuse ,(to-lambda f args kws) ,args_)))) + (if (and (pair? e) (eq? 
(car e) '|.|)) + (let ((f (cadr e)) (x (caddr e))) + (if (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)) + `(call (core getfield) ,f ,x) + (make-fuse f (cdr x)))) + e)) ; given e == (fuse lambda args), compress the argument list by removing (pure) ; duplicates in args, inlining literals, and moving any varargs to the end: (define (compress-fuse e) @@ -1623,43 +1625,51 @@ (if (eq? arg (car args)) (car fargs) (findfarg arg (cdr args) (cdr fargs)))) - (let ((f (cadr e)) - (args (caddr e))) - (define (cf old-fargs old-args new-fargs new-args renames varfarg vararg) - (if (null? old-args) - (let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs))) - (nargs (if (null? vararg) new-args (cons vararg new-args)))) - `(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames)) - ,(reverse nargs))) - (let ((farg (car old-fargs)) (arg (car old-args))) - (cond - ((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument - (if (null? varfarg) - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args renames farg arg) - (if (eq? (cadr vararg) (cadr arg)) + (if (fuse? e) + (let ((f (cadr e)) + (args (caddr e))) + (define (cf old-fargs old-args new-fargs new-args renames varfarg vararg) + (if (null? old-args) + (let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs))) + (nargs (if (null? vararg) new-args (cons vararg new-args)))) + `(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames)) + ,(reverse nargs))) + (let ((farg (car old-fargs)) (arg (car old-args))) + (cond + ((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument + (if (null? varfarg) (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames) - varfarg vararg) - (error "multiple splatted args cannot be fused into a single broadcast")))) - ((number? 
arg) ; inline numeric literals - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args - (cons (cons farg arg) renames) - varfarg vararg)) - ((and (symbol? arg) (memq arg new-args)) ; combine duplicate args - ; (note: calling memq for every arg is O(length(args)^2) ... - ; ... would be better to replace with a hash table if args is long) - (cf (cdr old-fargs) (cdr old-args) - new-fargs new-args - (cons (cons farg (findfarg arg new-args new-fargs)) renames) - varfarg vararg)) - (else - (cf (cdr old-fargs) (cdr old-args) - (cons farg new-fargs) (cons arg new-args) renames varfarg vararg)))))) - (cf (cdadr f) args '() '() '() '() '()))) - (let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args) - (expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e))))) + new-fargs new-args renames farg arg) + (if (eq? (cadr vararg) (cadr arg)) + (cf (cdr old-fargs) (cdr old-args) + new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames) + varfarg vararg) + (error "multiple splatted args cannot be fused into a single broadcast")))) + ((number? arg) ; inline numeric literals + (cf (cdr old-fargs) (cdr old-args) + new-fargs new-args + (cons (cons farg arg) renames) + varfarg vararg)) + ((and (symbol? arg) (memq arg new-args)) ; combine duplicate args + ; (note: calling memq for every arg is O(length(args)^2) ... + ; ... would be better to replace with a hash table if args is long) + (cf (cdr old-fargs) (cdr old-args) + new-fargs new-args + (cons (cons farg (findfarg arg new-args new-fargs)) renames) + varfarg vararg)) + (else + (cf (cdr old-fargs) (cdr old-args) + (cons farg new-fargs) (cons arg new-args) renames varfarg vararg)))))) + (cf (cdadr f) args '() '() '() '() '())) + e)) ; (not (fuse? e)) + (let ((e (compress-fuse (dot-to-fuse rhs)))) ; an expression '(fuse func args) if rhs is a dot call + (if (fuse? e) + (if (null? lhs) + (expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e))) + (expand-forms `(call broadcast! 
,(from-lambda (cadr e)) ,lhs ,@(caddr e)))) + (if (null? lhs) + (expand-forms e) + (expand-forms `(call broadcast! identity ,lhs ,e)))))) ;; table mapping expression head to a function expanding that form (define expand-table @@ -1697,13 +1707,11 @@ '|.| (lambda (e) ; e = (|.| f x) - (let ((f (cadr e)) - (x (caddr e))) - (if (getfield-field? x) - `(call (core getfield) ,(expand-forms f) ,(expand-forms x)) - ; otherwise, came from f.(args...) --> broadcast(f, args...), - ; where we want to fuse with any nested broadcast calls. - (expand-fuse-broadcast f (cdr x))))) + (expand-fuse-broadcast '() e)) + + '.= + (lambda (e) + (expand-fuse-broadcast (cadr e) (caddr e))) '|<:| syntactic-op-to-call '|>:| syntactic-op-to-call @@ -2008,11 +2016,16 @@ '%= lower-update-op '.%= lower-update-op '|\|=| lower-update-op + '|.\|=| lower-update-op '&= lower-update-op + '.&= lower-update-op '$= lower-update-op '<<= lower-update-op + '.<<= lower-update-op '>>= lower-update-op + '.>>= lower-update-op '>>>= lower-update-op + '.>>>= lower-update-op ': (lambda (e) diff --git a/test/broadcast.jl b/test/broadcast.jl index 2d6adca165282..31a2166a4ee14 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -248,6 +248,25 @@ let x = [1:4;] @test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1) end +# PR #17510: Fused in-place assignment +let x = [1:4;], y = x + y .= 2:5 + @test y === x == [2:5;] + y .= factorial.(x) + @test y === x == [2,6,24,120] + y .= 7 + @test y === x == [7,7,7,7] + y .= factorial.(3) + @test y === x == [6,6,6,6] + f17510() = 9 + y .= f17510.() + @test y === x == [9,9,9,9] + y .-= 1 + @test y === x == [8,8,8,8] + y .-= 1:4 + @test y === x == [7,6,5,4] +end + # PR 16988 @test Base.promote_op(+, Bool) === Int @test isa(broadcast(+, [true]), Array{Int,1}) diff --git a/test/show.jl b/test/show.jl index 3eadf5e1e434c..55234b9b46e09 100644 --- a/test/show.jl +++ b/test/show.jl @@ -515,3 +515,7 @@ end @test repr(:(x for x in y if aa for z in w if bb)) == ":(x 
for x = y if aa for z = w if bb)" @test repr(:([x for x = y])) == ":([x for x = y])" @test repr(:([x for x = y if z])) == ":([x for x = y if z])" + +for op in (:(.=), :(.+=), :(.&=)) + @test repr(parse("x $op y")) == ":(x $op y)" +end