Skip to content

Commit

Permalink
Fix mapreduce on AdjOrTrans (#46605)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Karrasch <[email protected]>
Co-authored-by: Daniel Karrasch <[email protected]>
Co-authored-by: Martin Holters <[email protected]>
  • Loading branch information
3 people authored Sep 16, 2022
1 parent 174b893 commit 8c00e17
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 36 deletions.
14 changes: 12 additions & 2 deletions base/permuteddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,21 @@ end
P
end

function Base._mapreduce_dim(f, op, init::Base._InitialValue, A::PermutedDimsArray, dims::Colon)
const CommutativeOps = Union{typeof(+),typeof(Base.add_sum),typeof(min),typeof(max),typeof(Base._extrema_rf),typeof(|),typeof(&)}

function Base._mapreduce_dim(f, op::CommutativeOps, init::Base._InitialValue, A::PermutedDimsArray, dims::Colon)
Base._mapreduce_dim(f, op, init, parent(A), dims)
end
function Base._mapreduce_dim(f::typeof(identity), op::Union{typeof(Base.mul_prod),typeof(*)}, init::Base._InitialValue, A::PermutedDimsArray{<:Union{Real,Complex}}, dims::Colon)
Base._mapreduce_dim(f, op, init, parent(A), dims)
end

function Base.mapreducedim!(f, op, B::AbstractArray{T,N}, A::PermutedDimsArray{T,N,perm,iperm}) where {T,N,perm,iperm}
function Base.mapreducedim!(f, op::CommutativeOps, B::AbstractArray{T,N}, A::PermutedDimsArray{S,N,perm,iperm}) where {T,S,N,perm,iperm}
C = PermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
B
end
function Base.mapreducedim!(f::typeof(identity), op::Union{typeof(Base.mul_prod),typeof(*)}, B::AbstractArray{T,N}, A::PermutedDimsArray{<:Union{Real,Complex},N,perm,iperm}) where {T,N,perm,iperm}
C = PermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
B
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using Base: IndexLinear, promote_eltype, promote_op, promote_typeof,
@propagate_inbounds, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
Splat
using Base.Broadcast: Broadcasted, broadcasted
using Base.PermutedDimsArrays: CommutativeOps
using OpenBLAS_jll
using libblastrampoline_jll
import Libdl
Expand Down
44 changes: 29 additions & 15 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,22 +378,36 @@ Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...)


### reductions
# faster to sum the Array than to work through the wrapper
Base._mapreduce_dim(f, op, init::Base._InitialValue, A::Transpose, dims::Colon) =
transpose(Base._mapreduce_dim(_sandwich(transpose, f), _sandwich(transpose, op), init, parent(A), dims))
Base._mapreduce_dim(f, op, init::Base._InitialValue, A::Adjoint, dims::Colon) =
adjoint(Base._mapreduce_dim(_sandwich(adjoint, f), _sandwich(adjoint, op), init, parent(A), dims))
# faster to sum the Array than to work through the wrapper (but only in commutative reduction ops as in Base/permuteddimsarray.jl)
Base._mapreduce_dim(f, op::CommutativeOps, init::Base._InitialValue, A::Transpose, dims::Colon) =
Base._mapreduce_dim(ftranspose, op, init, parent(A), dims)
Base._mapreduce_dim(f, op::CommutativeOps, init::Base._InitialValue, A::Adjoint, dims::Colon) =
Base._mapreduce_dim(fadjoint, op, init, parent(A), dims)
# in prod, use fast path only in the commutative case to avoid surprises
Base._mapreduce_dim(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, init::Base._InitialValue, A::Transpose{<:Union{Real,Complex}}, dims::Colon) =
Base._mapreduce_dim(ftranspose, op, init, parent(A), dims)
Base._mapreduce_dim(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, init::Base._InitialValue, A::Adjoint{<:Union{Real,Complex}}, dims::Colon) =
Base._mapreduce_dim(fadjoint, op, init, parent(A), dims)
# count allows for optimization only if the parent array has Bool eltype
Base._count(::typeof(identity), A::Transpose{Bool}, ::Colon, init) = Base._count(identity, parent(A), :, init)
Base._count(::typeof(identity), A::Adjoint{Bool}, ::Colon, init) = Base._count(identity, parent(A), :, init)
Base._any(f, A::Transpose, ::Colon) = Base._any(ftranspose, parent(A), :)
Base._any(f, A::Adjoint, ::Colon) = Base._any(fadjoint, parent(A), :)
Base._all(f, A::Transpose, ::Colon) = Base._all(ftranspose, parent(A), :)
Base._all(f, A::Adjoint, ::Colon) = Base._all(fadjoint, parent(A), :)
# sum(A'; dims)
Base.mapreducedim!(f, op, B::AbstractArray, A::TransposeAbsMat) =
transpose(Base.mapreducedim!(_sandwich(transpose, f), _sandwich(transpose, op), transpose(B), parent(A)))
Base.mapreducedim!(f, op, B::AbstractArray, A::AdjointAbsMat) =
adjoint(Base.mapreducedim!(_sandwich(adjoint, f), _sandwich(adjoint, op), adjoint(B), parent(A)))

_sandwich(adj::Function, fun) = (xs...,) -> adj(fun(map(adj, xs)...))
for fun in [:identity, :add_sum, :mul_prod] #, :max, :min]
@eval _sandwich(::Function, ::typeof(Base.$fun)) = Base.$fun
end

Base.mapreducedim!(f, op::CommutativeOps, B::AbstractArray, A::TransposeAbsMat) =
(Base.mapreducedim!(ftranspose, op, switch_dim12(B), parent(A)); B)
Base.mapreducedim!(f, op::CommutativeOps, B::AbstractArray, A::AdjointAbsMat) =
(Base.mapreducedim!(fadjoint, op, switch_dim12(B), parent(A)); B)
Base.mapreducedim!(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, B::AbstractArray, A::TransposeAbsMat{<:Union{Real,Complex}}) =
(Base.mapreducedim!(ftranspose, op, switch_dim12(B), parent(A)); B)
Base.mapreducedim!(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, B::AbstractArray, A::AdjointAbsMat{<:Union{Real,Complex}}) =
(Base.mapreducedim!(fadjoint, op, switch_dim12(B), parent(A)); B)

switch_dim12(B::AbstractVector) = permutedims(B)
switch_dim12(B::AbstractArray{<:Any,0}) = B
switch_dim12(B::AbstractArray) = PermutedDimsArray(B, (2, 1, ntuple(Base.Fix1(+,2), ndims(B) - 2)...))

### linear algebra

Expand Down
64 changes: 46 additions & 18 deletions stdlib/LinearAlgebra/test/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,24 +588,52 @@ end
@test transpose(Int[]) * Int[] == 0
end

@testset "reductions: $adjtrans" for adjtrans in [transpose, adjoint]
mat = rand(ComplexF64, 3,5)
@test sum(adjtrans(mat)) sum(collect(adjtrans(mat)))
@test sum(adjtrans(mat), dims=1) sum(collect(adjtrans(mat)), dims=1)
@test sum(adjtrans(mat), dims=(1,2)) sum(collect(adjtrans(mat)), dims=(1,2))

@test sum(imag, adjtrans(mat)) sum(imag, collect(adjtrans(mat)))
@test sum(imag, adjtrans(mat), dims=1) sum(imag, collect(adjtrans(mat)), dims=1)

mat = [rand(ComplexF64,2,2) for _ in 1:3, _ in 1:5]
@test sum(adjtrans(mat)) sum(collect(adjtrans(mat)))
@test sum(adjtrans(mat), dims=1) sum(collect(adjtrans(mat)), dims=1)
@test sum(adjtrans(mat), dims=(1,2)) sum(collect(adjtrans(mat)), dims=(1,2))

@test sum(imag, adjtrans(mat)) sum(imag, collect(adjtrans(mat)))
@test sum(x -> x[1,2], adjtrans(mat)) sum(x -> x[1,2], collect(adjtrans(mat)))
@test sum(imag, adjtrans(mat), dims=1) sum(imag, collect(adjtrans(mat)), dims=1)
@test sum(x -> x[1,2], adjtrans(mat), dims=1) sum(x -> x[1,2], collect(adjtrans(mat)), dims=1)
@testset "reductions: $adjtrans" for adjtrans in (transpose, adjoint)
for (reduction, reduction!, op) in ((sum, sum!, +), (prod, prod!, *), (minimum, minimum!, min), (maximum, maximum!, max))
T = op in (max, min) ? Float64 : ComplexF64
mat = rand(T, 3,5)
rd1 = zeros(T, 1, 3)
rd2 = zeros(T, 5, 1)
rd3 = zeros(T, 1, 1)
@test reduction(adjtrans(mat)) reduction(copy(adjtrans(mat)))
@test reduction(adjtrans(mat), dims=1) reduction(copy(adjtrans(mat)), dims=1)
@test reduction(adjtrans(mat), dims=2) reduction(copy(adjtrans(mat)), dims=2)
@test reduction(adjtrans(mat), dims=(1,2)) reduction(copy(adjtrans(mat)), dims=(1,2))

@test reduction!(rd1, adjtrans(mat)) reduction!(rd1, copy(adjtrans(mat)))
@test reduction!(rd2, adjtrans(mat)) reduction!(rd2, copy(adjtrans(mat)))
@test reduction!(rd3, adjtrans(mat)) reduction!(rd3, copy(adjtrans(mat)))

@test reduction(imag, adjtrans(mat)) reduction(imag, copy(adjtrans(mat)))
@test reduction(imag, adjtrans(mat), dims=1) reduction(imag, copy(adjtrans(mat)), dims=1)
@test reduction(imag, adjtrans(mat), dims=2) reduction(imag, copy(adjtrans(mat)), dims=2)
@test reduction(imag, adjtrans(mat), dims=(1,2)) reduction(imag, copy(adjtrans(mat)), dims=(1,2))

@test Base.mapreducedim!(imag, op, rd1, adjtrans(mat)) Base.mapreducedim!(imag, op, rd1, copy(adjtrans(mat)))
@test Base.mapreducedim!(imag, op, rd2, adjtrans(mat)) Base.mapreducedim!(imag, op, rd2, copy(adjtrans(mat)))
@test Base.mapreducedim!(imag, op, rd3, adjtrans(mat)) Base.mapreducedim!(imag, op, rd3, copy(adjtrans(mat)))

op in (max, min) && continue
mat = [rand(T,2,2) for _ in 1:3, _ in 1:5]
rd1 = fill(zeros(T, 2, 2), 1, 3)
rd2 = fill(zeros(T, 2, 2), 5, 1)
rd3 = fill(zeros(T, 2, 2), 1, 1)
@test reduction(adjtrans(mat)) reduction(copy(adjtrans(mat)))
@test reduction(adjtrans(mat), dims=1) reduction(copy(adjtrans(mat)), dims=1)
@test reduction(adjtrans(mat), dims=2) reduction(copy(adjtrans(mat)), dims=2)
@test reduction(adjtrans(mat), dims=(1,2)) reduction(copy(adjtrans(mat)), dims=(1,2))

@test reduction(imag, adjtrans(mat)) reduction(imag, copy(adjtrans(mat)))
@test reduction(x -> x[1,2], adjtrans(mat)) reduction(x -> x[1,2], copy(adjtrans(mat)))
@test reduction(imag, adjtrans(mat), dims=1) reduction(imag, copy(adjtrans(mat)), dims=1)
@test reduction(x -> x[1,2], adjtrans(mat), dims=1) reduction(x -> x[1,2], copy(adjtrans(mat)), dims=1)
end
# see #46605
Ac = [1 2; 3 4]'
@test mapreduce(identity, (x, y) -> 10x+y, copy(Ac)) == mapreduce(identity, (x, y) -> 10x+y, Ac) == 1234
@test extrema([3,7,4]') == (3, 7)
@test mapreduce(x -> [x;;;], +, [1, 2, 3]') == sum(x -> [x;;;], [1, 2, 3]') == [6;;;]
@test mapreduce(string, *, [1 2; 3 4]') == mapreduce(string, *, copy([1 2; 3 4]')) == "1234"
end

end # module TestAdjointTranspose
6 changes: 5 additions & 1 deletion test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,14 +708,18 @@ end
ap = PermutedDimsArray(Array(a), (2,1,3))
@test strides(ap) == (3,1,12)

for A in [rand(1,2,3,4),rand(2,2,2,2),rand(5,6,5,6),rand(1,1,1,1)]
for A in [rand(1,2,3,4),rand(2,2,2,2),rand(5,6,5,6),rand(1,1,1,1), [rand(ComplexF64, 2,2) for _ in 1:2, _ in 1:3, _ in 1:2, _ in 1:4]]
perm = randperm(4)
@test isequal(A,permutedims(permutedims(A,perm),invperm(perm)))
@test isequal(A,permutedims(permutedims(A,invperm(perm)),perm))

@test sum(permutedims(A,perm)) sum(PermutedDimsArray(A,perm))
@test sum(permutedims(A,perm), dims=2) sum(PermutedDimsArray(A,perm), dims=2)
@test sum(permutedims(A,perm), dims=(2,4)) sum(PermutedDimsArray(A,perm), dims=(2,4))

@test prod(permutedims(A,perm)) prod(PermutedDimsArray(A,perm))
@test prod(permutedims(A,perm), dims=2) prod(PermutedDimsArray(A,perm), dims=2)
@test prod(permutedims(A,perm), dims=(2,4)) prod(PermutedDimsArray(A,perm), dims=(2,4))
end

m = [1 2; 3 4]
Expand Down

0 comments on commit 8c00e17

Please sign in to comment.