From 8c00e17908cd4d84e8366b2d44dc1af432c00d82 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 16 Sep 2022 05:43:59 -0400 Subject: [PATCH] Fix `mapreduce` on `AdjOrTrans` (#46605) Co-authored-by: Daniel Karrasch Co-authored-by: Daniel Karrasch Co-authored-by: Martin Holters --- base/permuteddimsarray.jl | 14 ++++- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 1 + stdlib/LinearAlgebra/src/adjtrans.jl | 44 ++++++++++------ stdlib/LinearAlgebra/test/adjtrans.jl | 64 ++++++++++++++++------- test/arrayops.jl | 6 ++- 5 files changed, 93 insertions(+), 36 deletions(-) diff --git a/base/permuteddimsarray.jl b/base/permuteddimsarray.jl index ea1863de8b708..a8523013434b3 100644 --- a/base/permuteddimsarray.jl +++ b/base/permuteddimsarray.jl @@ -275,11 +275,21 @@ end P end -function Base._mapreduce_dim(f, op, init::Base._InitialValue, A::PermutedDimsArray, dims::Colon) +const CommutativeOps = Union{typeof(+),typeof(Base.add_sum),typeof(min),typeof(max),typeof(Base._extrema_rf),typeof(|),typeof(&)} + +function Base._mapreduce_dim(f, op::CommutativeOps, init::Base._InitialValue, A::PermutedDimsArray, dims::Colon) + Base._mapreduce_dim(f, op, init, parent(A), dims) +end +function Base._mapreduce_dim(f::typeof(identity), op::Union{typeof(Base.mul_prod),typeof(*)}, init::Base._InitialValue, A::PermutedDimsArray{<:Union{Real,Complex}}, dims::Colon) Base._mapreduce_dim(f, op, init, parent(A), dims) end -function Base.mapreducedim!(f, op, B::AbstractArray{T,N}, A::PermutedDimsArray{T,N,perm,iperm}) where {T,N,perm,iperm} +function Base.mapreducedim!(f, op::CommutativeOps, B::AbstractArray{T,N}, A::PermutedDimsArray{S,N,perm,iperm}) where {T,S,N,perm,iperm} + C = PermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output + Base.mapreducedim!(f, op, C, parent(A)) + B +end +function Base.mapreducedim!(f::typeof(identity), op::Union{typeof(Base.mul_prod),typeof(*)}, B::AbstractArray{T,N}, A::PermutedDimsArray{<:Union{Real,Complex},N,perm,iperm}) where {T,N,perm,iperm} C = PermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output Base.mapreducedim!(f, op, C, parent(A)) B diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 95ebe1c0dfdc5..b107777e7ae85 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -20,6 +20,7 @@ using Base: IndexLinear, promote_eltype, promote_op, promote_typeof, @propagate_inbounds, reduce, typed_hvcat, typed_vcat, require_one_based_indexing, Splat using Base.Broadcast: Broadcasted, broadcasted +using Base.PermutedDimsArrays: CommutativeOps using OpenBLAS_jll using libblastrampoline_jll import Libdl diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index bcac7625259ab..1a67b7f69e24a 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -378,22 +378,36 @@ Broadcast.broadcast_preserving_zero_d(f, tvs::Union{Number,TransposeAbsVec}...) ### reductions -# faster to sum the Array than to work through the wrapper -Base._mapreduce_dim(f, op, init::Base._InitialValue, A::Transpose, dims::Colon) = - transpose(Base._mapreduce_dim(_sandwich(transpose, f), _sandwich(transpose, op), init, parent(A), dims)) -Base._mapreduce_dim(f, op, init::Base._InitialValue, A::Adjoint, dims::Colon) = - adjoint(Base._mapreduce_dim(_sandwich(adjoint, f), _sandwich(adjoint, op), init, parent(A), dims)) +# faster to sum the Array than to work through the wrapper (but only in commutative reduction ops as in Base/permuteddimsarray.jl) +Base._mapreduce_dim(f, op::CommutativeOps, init::Base._InitialValue, A::Transpose, dims::Colon) = + Base._mapreduce_dim(f∘transpose, op, init, parent(A), dims) +Base._mapreduce_dim(f, op::CommutativeOps, init::Base._InitialValue, A::Adjoint, dims::Colon) = + Base._mapreduce_dim(f∘adjoint, op, init, parent(A), dims) +# in prod, use fast path only in the commutative case to avoid surprises +Base._mapreduce_dim(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, init::Base._InitialValue, A::Transpose{<:Union{Real,Complex}}, dims::Colon) = + Base._mapreduce_dim(f∘transpose, op, init, parent(A), dims) +Base._mapreduce_dim(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, init::Base._InitialValue, A::Adjoint{<:Union{Real,Complex}}, dims::Colon) = + Base._mapreduce_dim(f∘adjoint, op, init, parent(A), dims) +# count allows for optimization only if the parent array has Bool eltype +Base._count(::typeof(identity), A::Transpose{Bool}, ::Colon, init) = Base._count(identity, parent(A), :, init) +Base._count(::typeof(identity), A::Adjoint{Bool}, ::Colon, init) = Base._count(identity, parent(A), :, init) +Base._any(f, A::Transpose, ::Colon) = Base._any(f∘transpose, parent(A), :) +Base._any(f, A::Adjoint, ::Colon) = Base._any(f∘adjoint, parent(A), :) +Base._all(f, A::Transpose, ::Colon) = Base._all(f∘transpose, parent(A), :) +Base._all(f, A::Adjoint, ::Colon) = Base._all(f∘adjoint, parent(A), :) # sum(A'; dims) -Base.mapreducedim!(f, op, B::AbstractArray, A::TransposeAbsMat) = - transpose(Base.mapreducedim!(_sandwich(transpose, f), _sandwich(transpose, op), transpose(B), parent(A))) -Base.mapreducedim!(f, op, B::AbstractArray, A::AdjointAbsMat) = - adjoint(Base.mapreducedim!(_sandwich(adjoint, f), _sandwich(adjoint, op), adjoint(B), parent(A))) - -_sandwich(adj::Function, fun) = (xs...,) -> adj(fun(map(adj, xs)...)) -for fun in [:identity, :add_sum, :mul_prod] #, :max, :min] - @eval _sandwich(::Function, ::typeof(Base.$fun)) = Base.$fun -end - +Base.mapreducedim!(f, op::CommutativeOps, B::AbstractArray, A::TransposeAbsMat) = + (Base.mapreducedim!(f∘transpose, op, switch_dim12(B), parent(A)); B) +Base.mapreducedim!(f, op::CommutativeOps, B::AbstractArray, A::AdjointAbsMat) = + (Base.mapreducedim!(f∘adjoint, op, switch_dim12(B), parent(A)); B) +Base.mapreducedim!(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, B::AbstractArray, A::TransposeAbsMat{<:Union{Real,Complex}}) = + (Base.mapreducedim!(f∘transpose, op, switch_dim12(B), parent(A)); B) +Base.mapreducedim!(f::typeof(identity), op::Union{typeof(*),typeof(Base.mul_prod)}, B::AbstractArray, A::AdjointAbsMat{<:Union{Real,Complex}}) = + (Base.mapreducedim!(f∘adjoint, op, switch_dim12(B), parent(A)); B) + +switch_dim12(B::AbstractVector) = permutedims(B) +switch_dim12(B::AbstractArray{<:Any,0}) = B +switch_dim12(B::AbstractArray) = PermutedDimsArray(B, (2, 1, ntuple(Base.Fix1(+,2), ndims(B) - 2)...)) ### linear algebra diff --git a/stdlib/LinearAlgebra/test/adjtrans.jl b/stdlib/LinearAlgebra/test/adjtrans.jl index 7b782d463768d..e96ea28531d37 100644 --- a/stdlib/LinearAlgebra/test/adjtrans.jl +++ b/stdlib/LinearAlgebra/test/adjtrans.jl @@ -588,24 +588,52 @@ end @test transpose(Int[]) * Int[] == 0 end -@testset "reductions: $adjtrans" for adjtrans in [transpose, adjoint] - mat = rand(ComplexF64, 3,5) - @test sum(adjtrans(mat)) ≈ sum(collect(adjtrans(mat))) - @test sum(adjtrans(mat), dims=1) ≈ sum(collect(adjtrans(mat)), dims=1) - @test sum(adjtrans(mat), dims=(1,2)) ≈ sum(collect(adjtrans(mat)), dims=(1,2)) - - @test sum(imag, adjtrans(mat)) ≈ sum(imag, collect(adjtrans(mat))) - @test sum(imag, adjtrans(mat), dims=1) ≈ sum(imag, collect(adjtrans(mat)), dims=1) - - mat = [rand(ComplexF64,2,2) for _ in 1:3, _ in 1:5] - @test sum(adjtrans(mat)) ≈ sum(collect(adjtrans(mat))) - @test sum(adjtrans(mat), dims=1) ≈ sum(collect(adjtrans(mat)), dims=1) - @test sum(adjtrans(mat), dims=(1,2)) ≈ sum(collect(adjtrans(mat)), dims=(1,2)) - - @test sum(imag, adjtrans(mat)) ≈ sum(imag, collect(adjtrans(mat))) - @test sum(x -> x[1,2], adjtrans(mat)) ≈ sum(x -> x[1,2], collect(adjtrans(mat))) - @test sum(imag, adjtrans(mat), dims=1) ≈ sum(imag, collect(adjtrans(mat)), dims=1) - @test sum(x -> x[1,2], adjtrans(mat), dims=1) ≈ sum(x -> x[1,2], collect(adjtrans(mat)), dims=1) +@testset "reductions: $adjtrans" for adjtrans in (transpose, adjoint) + for (reduction, reduction!, op) in ((sum, sum!, +), (prod, prod!, *), (minimum, minimum!, min), (maximum, maximum!, max)) + T = op in (max, min) ? Float64 : ComplexF64 + mat = rand(T, 3,5) + rd1 = zeros(T, 1, 3) + rd2 = zeros(T, 5, 1) + rd3 = zeros(T, 1, 1) + @test reduction(adjtrans(mat)) ≈ reduction(copy(adjtrans(mat))) + @test reduction(adjtrans(mat), dims=1) ≈ reduction(copy(adjtrans(mat)), dims=1) + @test reduction(adjtrans(mat), dims=2) ≈ reduction(copy(adjtrans(mat)), dims=2) + @test reduction(adjtrans(mat), dims=(1,2)) ≈ reduction(copy(adjtrans(mat)), dims=(1,2)) + + @test reduction!(rd1, adjtrans(mat)) ≈ reduction!(rd1, copy(adjtrans(mat))) + @test reduction!(rd2, adjtrans(mat)) ≈ reduction!(rd2, copy(adjtrans(mat))) + @test reduction!(rd3, adjtrans(mat)) ≈ reduction!(rd3, copy(adjtrans(mat))) + + @test reduction(imag, adjtrans(mat)) ≈ reduction(imag, copy(adjtrans(mat))) + @test reduction(imag, adjtrans(mat), dims=1) ≈ reduction(imag, copy(adjtrans(mat)), dims=1) + @test reduction(imag, adjtrans(mat), dims=2) ≈ reduction(imag, copy(adjtrans(mat)), dims=2) + @test reduction(imag, adjtrans(mat), dims=(1,2)) ≈ reduction(imag, copy(adjtrans(mat)), dims=(1,2)) + + @test Base.mapreducedim!(imag, op, rd1, adjtrans(mat)) ≈ Base.mapreducedim!(imag, op, rd1, copy(adjtrans(mat))) + @test Base.mapreducedim!(imag, op, rd2, adjtrans(mat)) ≈ Base.mapreducedim!(imag, op, rd2, copy(adjtrans(mat))) + @test Base.mapreducedim!(imag, op, rd3, adjtrans(mat)) ≈ Base.mapreducedim!(imag, op, rd3, copy(adjtrans(mat))) + + op in (max, min) && continue + mat = [rand(T,2,2) for _ in 1:3, _ in 1:5] + rd1 = fill(zeros(T, 2, 2), 1, 3) + rd2 = fill(zeros(T, 2, 2), 5, 1) + rd3 = fill(zeros(T, 2, 2), 1, 1) + @test reduction(adjtrans(mat)) ≈ reduction(copy(adjtrans(mat))) + @test reduction(adjtrans(mat), dims=1) ≈ reduction(copy(adjtrans(mat)), dims=1) + @test reduction(adjtrans(mat), dims=2) ≈ reduction(copy(adjtrans(mat)), dims=2) + @test reduction(adjtrans(mat), dims=(1,2)) ≈ reduction(copy(adjtrans(mat)), dims=(1,2)) + + @test reduction(imag, adjtrans(mat)) ≈ reduction(imag, copy(adjtrans(mat))) + @test reduction(x -> x[1,2], adjtrans(mat)) ≈ reduction(x -> x[1,2], copy(adjtrans(mat))) + @test reduction(imag, adjtrans(mat), dims=1) ≈ reduction(imag, copy(adjtrans(mat)), dims=1) + @test reduction(x -> x[1,2], adjtrans(mat), dims=1) ≈ reduction(x -> x[1,2], copy(adjtrans(mat)), dims=1) + end + # see #46605 + Ac = [1 2; 3 4]' + @test mapreduce(identity, (x, y) -> 10x+y, copy(Ac)) == mapreduce(identity, (x, y) -> 10x+y, Ac) == 1234 + @test extrema([3,7,4]') == (3, 7) + @test mapreduce(x -> [x;;;], +, [1, 2, 3]') == sum(x -> [x;;;], [1, 2, 3]') == [6;;;] + @test mapreduce(string, *, [1 2; 3 4]') == mapreduce(string, *, copy([1 2; 3 4]')) == "1234" end end # module TestAdjointTranspose diff --git a/test/arrayops.jl b/test/arrayops.jl index d05906bbc9f0f..45e7451c22a78 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -708,7 +708,7 @@ end ap = PermutedDimsArray(Array(a), (2,1,3)) @test strides(ap) == (3,1,12) - for A in [rand(1,2,3,4),rand(2,2,2,2),rand(5,6,5,6),rand(1,1,1,1)] + for A in [rand(1,2,3,4),rand(2,2,2,2),rand(5,6,5,6),rand(1,1,1,1), [rand(ComplexF64, 2,2) for _ in 1:2, _ in 1:3, _ in 1:2, _ in 1:4]] perm = randperm(4) @test isequal(A,permutedims(permutedims(A,perm),invperm(perm))) @test isequal(A,permutedims(permutedims(A,invperm(perm)),perm)) @@ -716,6 +716,10 @@ end @test sum(permutedims(A,perm)) ≈ sum(PermutedDimsArray(A,perm)) @test sum(permutedims(A,perm), dims=2) ≈ sum(PermutedDimsArray(A,perm), dims=2) @test sum(permutedims(A,perm), dims=(2,4)) ≈ sum(PermutedDimsArray(A,perm), dims=(2,4)) + + @test prod(permutedims(A,perm)) ≈ prod(PermutedDimsArray(A,perm)) + @test prod(permutedims(A,perm), dims=2) ≈ prod(PermutedDimsArray(A,perm), dims=2) + @test prod(permutedims(A,perm), dims=(2,4)) ≈ prod(PermutedDimsArray(A,perm), dims=(2,4)) end m = [1 2; 3 4]