Let `muladd` accept arrays #37065

Conversation
Now has tests, and |
I think this change broke the following:

```julia
julia> using LinearAlgebra

julia> muladd([2 2; 1 2], [2 1; 1 2], I)
2×2 Array{Int64,2}:
 7  6
 4  6
```

while currently on master:

```julia
julia> using LinearAlgebra

julia> muladd([2 2; 1 2], [2 1; 1 2], I)
ERROR: MethodError: no method matching length(::UniformScaling{Bool})
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:54
  length(::Union{Adjoint{T, S}, Transpose{T, S}} where S where T) at /home/mateusz/repos/julia/julia/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:194
  length(::IdDict) at iddict.jl:136
  ...
Stacktrace:
  [1] _similar_for(c::UnitRange{Int64}, #unused#::Type{Bool}, itr::UniformScaling{Bool}, #unused#::Base.HasLength)
    @ Base ./array.jl:575
  [2] _collect(cont::UnitRange{Int64}, itr::UniformScaling{Bool}, #unused#::Base.HasEltype, isz::Base.HasLength)
    @ Base ./array.jl:608
  [3] collect(itr::UniformScaling{Bool})
    @ Base ./array.jl:602
  [4] broadcastable(x::UniformScaling{Bool})
    @ Base.Broadcast ./broadcast.jl:682
  [5] broadcasted
    @ ./broadcast.jl:1303 [inlined]
  [6] muladd(A::Matrix{Int64}, B::Matrix{Int64}, z::UniformScaling{Bool})
    @ LinearAlgebra ~/repos/julia/julia/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:214
  [7] top-level scope
    @ REPL[12]:1

julia> versioninfo()
Julia Version 1.6.0-DEV.1384
Commit 4711fc3cdf (2020-10-30 14:50 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: Intel(R) Core(TM) i7-4800MQ CPU @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.0 (ORCJIT, haswell)
```
Thanks for reporting this! I guess we should have restricted the third argument here to |
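For what it's worth, a tighter signature would fix this because `UniformScaling` is not an `AbstractArray`: a method accepting only `z::Union{Number, AbstractArray}` would simply not match `I`, and the call would fall back to the generic `muladd(x,y,z) = x*y + z`, which handles `I` fine. (The exact restriction shown here is an assumption, not necessarily the one that was committed.)

```julia
using LinearAlgebra

I isa AbstractArray           # false, so a z::AbstractArray restriction would exclude it
I isa UniformScaling{Bool}    # true
[2 2; 1 2] * [2 1; 1 2] + I   # the generic x*y + z path gives the answer reported above
```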
Oops, I was working on the assumption that the sole existing use for … The error here is from … However, if there are other existing uses of …

```julia
julia> muladd(Diagonal([1,2]), [3,4], [5,6])
2-element SparseArrays.SparseVector{Int64, Int64} with 2 stored entries:
  [1]  =  8
  [2]  =  14

julia> muladd(Diagonal([1,2]), [3,4], [5,6]) # on Julia 1.5
2-element Array{Int64,1}:
  8
 14
```
Wouldn't replacing |
Yea, broadcasting in |
Yes, good point. I don't have any good ideas then 🙂 . |
Thanks for the links. I guess this use of … But I'm a little worried about all the possibilities with |
I was thinking whether there is some way to deduce the type of |
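The element type, at least, can be deduced without forming the product, via `Base.promote_op`; it is the container type (dense vs. sparse vs. structured) that has no easy general answer. A rough sketch of the element-type part, not something from this PR:

```julia
using LinearAlgebra

# Element type that A*B + z would have, computed without doing the multiplication.
muladd_eltype(A, B, z) =
    Base.promote_op(+, Base.promote_op(*, eltype(A), eltype(B)), eltype(z))

muladd_eltype(rand(Int8, 2, 2), rand(2, 2), I)   # Float64, since eltype(I) == Bool
```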
I think quite a lot of the structured types returned by …

```julia
di = Diagonal(1:3)
ut = UpperTriangular(rand(1:9, 3,3))

di * di .+ di isa Diagonal
ut * ut .+ ut isa UpperTriangular
ut * ut .+ di isa UpperTriangular
```

And it looks like only …

```julia
similar(di, 3,3) isa SparseMatrixCSC
similar(parent(di), 3,3) isa Matrix
similar(ut, 3,3) isa Matrix
```

So maybe something like this is almost enough?

```julia
const UpperUnion = Union{UpperTriangular, UnitUpperTriangular, Diagonal}

function Base.muladd(A::UpperUnion, B::UpperUnion, z::UpperUnion)
    T = promote_type(eltype(A), eltype(B), eltype(z))
    C = similar(parent(B), T, axes(A,1), axes(B,2))
    C .= z
    mul!(C, A, B, true, true)
    UpperTriangular(C)
end

Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal) =
    Diagonal(A.diag .* B.diag .+ z.diag)

Base.muladd(A::Union{Diagonal, UniformScaling}, B::Union{Diagonal, UniformScaling}, z::Union{Diagonal, UniformScaling}) =
    Diagonal(_diag_or_value(A) .* _diag_or_value(B) .+ _diag_or_value(z))

Base.muladd(A::UniformScaling, B::UniformScaling, z::UniformScaling) =
    UniformScaling(A.λ * B.λ + z.λ)

_diag_or_value(A::Diagonal) = A.diag
_diag_or_value(A::UniformScaling) = A.λ
```
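If the definitions sketched above were loaded, they would dispatch roughly as follows; these results follow from the sketch itself and are not verified against what was eventually merged:

```julia
using LinearAlgebra

di = Diagonal(1:3)
ut = UpperTriangular(rand(1:9, 3, 3))

muladd(di, di, di) isa Diagonal         # A.diag .* B.diag .+ z.diag, stays Diagonal
muladd(ut, ut, di) isa UpperTriangular  # dense buffer from similar(parent(B), ...), rewrapped
muladd(di, I, 2I)  isa Diagonal         # the mixed Diagonal/UniformScaling method
muladd(I, I, I)    isa UniformScaling   # UniformScaling(A.λ * B.λ + z.λ)
```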
I think there may be a lot more patterns where |
Sparse matrices appear to be handled correctly, in that:

```julia
m = rand(3,3); sm = sparse(m)
v = rand(3); sv = sparse(v);

sm * sm .+ 1 isa SparseMatrixCSC
sm * sm .+ v isa SparseMatrixCSC
sm * sm .+ m isa SparseMatrixCSC

sm * sm + 1 # error
sm * sm + v # error
sm * sm + m isa Matrix
```

The snippet above should handle dense triangular cases. Acting with … Still hoping a little bit that it won't be necessary to go all the way down to StridedArray, but we may end up there. |
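Going "all the way down to StridedArray" might look roughly like the sketch below: only plain strided arrays opt in to the `mul!` path, and everything else keeps the old `x*y + z` behaviour. `strided_muladd` is a made-up name for illustration, not code from this PR:

```julia
using LinearAlgebra, SparseArrays

function strided_muladd(A::StridedMatrix, B::StridedVecOrMat, z::Union{Number, StridedVecOrMat})
    T = promote_type(eltype(A), eltype(B), eltype(z))
    C = similar(A, T, size(A, 1), size(B)[2:end]...)
    C .= z                      # z must broadcast into size(A*B)
    mul!(C, A, B, true, true)   # C = A*B + C, can hit BLAS for float element types
end
strided_muladd(A, B, z) = A*B + z   # fallback: anything non-strided keeps the old behaviour

A, B, z = rand(3, 4), rand(4, 5), rand(3)
strided_muladd(A, B, z) ≈ A*B .+ z                 # true: fast path via mul!
strided_muladd(sparse(A), sparse(B), zeros(3, 5))  # takes the fallback, as on Julia 1.5
```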
I don't quite understand. Right now in |
Sure, matching the existing choices for every possible trio of types probably can't be done with some simple rule. Besides the …

```julia
SymTridiagonal(m+m') * m isa Matrix
m * SymTridiagonal(m+m') isa SparseMatrixCSC
```

I hope nobody's relying on that. But more importantly, someone may be relying on this, which ought not to go via …

```julia
using StaticArrays
muladd(SA[1 2; 3 4], SA[5,6], SA[7,8])
```

Which argues for making it much more opt-in. |
They are the way they are because, as I said, we call |
Thanks, I hadn't seen that discussion. There's "why" as in what method, and "why" as in what's desired... and it sounds like, for my failed attempt at an exotic case nobody would want, someone wants the opposite. |
Not quite done, but this is a branch: … This supports triangular & diagonal matrices. It also supports |
Now at #38250. |
This adjusts #37065 to be much more cautious about what arrays it acts on: it calls `mul!` on `StridedArray`s, treats a few special types like `Diagonal`, `UpperTriangular`, and `UniformScaling`, and sends anything else to `muladd(A,y,z) = A*y .+ z`.

However, this broadcasting restricts the shape of `z`, mostly such that `A*y .= z` would work. That ensures you should get the same error from the `mul!(::StridedMatrix, ...)` method as from the fallback broadcasting one. Both allow `z` of lower dimension than the existing `muladd(x,y,z) = x*y+z`. But `x*y+z` also allows `z` to have trailing dimensions, as long as they are of size 1. I made the broadcasting method allow these too, which I think should make this non-breaking. (I presume this is rarely used, and thus not worth sending to the fast method.)

Structured matrices such as `UpperTriangular` should all go to `x*y+z`. Some combinations could be made more efficient, but it gets complicated. Only the case of 3 diagonals is handled.
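Assuming it behaves as described, the shape rules for `z` would play out roughly like this (a small illustration, not taken from the PR's tests):

```julia
using LinearAlgebra

A, y = rand(3, 4), rand(4, 5)    # A*y is 3×5

muladd(A, y, 1.0)                # scalar z
muladd(A, y, rand(3))            # length-3 z is added to every column, like A*y .+ z
muladd(A, y, rand(1, 5))         # 1×5 z is added to every row
muladd(A, y, rand(3, 5, 1))      # trailing size-1 dimension, kept so x*y + z still works
# muladd(A, y, rand(3, 5, 2))    # would throw a DimensionMismatch
```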
This extends `Base.muladd` to work on arrays, using `mul!`, which makes some combinations more efficient.

Motivated by things like FluxML/Flux.jl#1272, where it would be convenient to fuse these two, rather than fusing addition & function application `tanh.((W*x) .+ b)` with broadcasting.

Needs tests, and thinking about `muladd(::Adjoint, ...)` & `muladd(..., ::Number)` cases, if anyone thinks this is a good idea.
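To tie this back to the Flux motivation: with an array `muladd`, the product and the bias addition can be fused into a single `mul!` call, and only the activation is broadcast. A hedged sketch with made-up sizes:

```julia
using LinearAlgebra

W, x, b = rand(64, 32), rand(32), rand(64)

y1 = tanh.((W*x) .+ b)        # allocates W*x, then broadcasts the + and the tanh
y2 = tanh.(muladd(W, x, b))   # W*x + b done in one fused mul! step

y1 ≈ y2                       # true, up to floating-point rounding
```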