From 423d393a34ed5281ec50c78a303a30a59e69520d Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 14 Apr 2021 10:22:27 +0200 Subject: [PATCH 1/2] Add IndexableMap --- src/LinearMaps.jl | 13 +++++++---- src/composition.jl | 4 ++++ src/conversion.jl | 3 ++- src/indexablemap.jl | 49 +++++++++++++++++++++++++++++++++++++++ src/kronecker.jl | 4 ++-- src/scaledmap.jl | 4 ++++ src/uniformscalingmap.jl | 41 ++++++++++++++++++++------------ test/indexablemap.jl | 0 test/runtests.jl | 2 ++ test/uniformscalingmap.jl | 4 ++-- test/wrappedmap.jl | 2 +- 11 files changed, 100 insertions(+), 26 deletions(-) create mode 100644 src/indexablemap.jl create mode 100644 test/indexablemap.jl diff --git a/src/LinearMaps.jl b/src/LinearMaps.jl index d3355e5d..8ee7526f 100644 --- a/src/LinearMaps.jl +++ b/src/LinearMaps.jl @@ -69,12 +69,12 @@ function check_dim_mul(C, A, B) end # conversion of AbstractVecOrMat to LinearMap -convert_to_lmaps_(A::AbstractVecOrMat) = LinearMap(A) -convert_to_lmaps_(A::LinearMap) = A +convert_to_lmap_(A::AbstractVecOrMat) = LinearMap(A) +convert_to_lmap_(A::LinearMap) = A convert_to_lmaps() = () -convert_to_lmaps(A) = (convert_to_lmaps_(A),) +convert_to_lmaps(A) = (convert_to_lmap_(A),) @inline convert_to_lmaps(A, B, Cs...) = - (convert_to_lmaps_(A), convert_to_lmaps_(B), convert_to_lmaps(Cs...)...) + (convert_to_lmap_(A), convert_to_lmap_(B), convert_to_lmaps(Cs...)...) # The (internal) multiplication logic is as follows: # - `*(A, x)` calls `mul!(y, A, x)` for appropriately-sized y @@ -256,6 +256,7 @@ include("functionmap.jl") # using a function as linear map include("blockmap.jl") # block linear maps include("kronecker.jl") # Kronecker product of linear maps include("fillmap.jl") # linear maps representing constantly filled matrices +include("indexablemap.jl") # indexable linear maps include("conversion.jl") # conversion of linear maps to matrices include("show.jl") # show methods for LinearMap objects @@ -293,7 +294,9 @@ For the function-based constructor, there is one more keyword argument: The default value is guessed by looking at the number of arguments of the first occurrence of `f` in the method table. """ -LinearMap(A::MapOrVecOrMat; kwargs...) = WrappedMap(A; kwargs...) +LinearMap(A::MapOrVecOrMat; getind=nothing, kwargs...) = _LinearMap(getind, A; kwargs...) +_LinearMap(::Nothing, A; kwargs...) = WrappedMap(A; kwargs...) +_LinearMap(getind, A; kwargs...) = IndexableMap(WrappedMap(A; kwargs...), getind) LinearMap(J::UniformScaling, M::Int) = UniformScalingMap(J.λ, M) LinearMap(f, M::Int; kwargs...) = LinearMap{Float64}(f, M; kwargs...) LinearMap(f, M::Int, N::Int; kwargs...) = LinearMap{Float64}(f, M, N; kwargs...) diff --git a/src/composition.jl b/src/composition.jl index b06a3c1a..c995c18e 100644 --- a/src/composition.jl +++ b/src/composition.jl @@ -119,6 +119,10 @@ end # needed for disambiguation Base.:(*)(A₁::ScaledMap, A₂::CompositeMap) = A₁.λ * (A₁.lmap * A₂) Base.:(*)(A₁::CompositeMap, A₂::ScaledMap) = (A₁ * A₂.lmap) * A₂.λ +Base.:(*)(J::UniformScalingMap, B::CompositeMap) = + size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*")) +Base.:(*)(A::CompositeMap, J::UniformScalingMap) = + size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*")) # special transposition behavior LinearAlgebra.transpose(A::CompositeMap{T}) where {T} = diff --git a/src/conversion.jl b/src/conversion.jl index 309150f4..1cd6abb2 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -48,7 +48,8 @@ SparseArrays.SparseMatrixCSC(A::LinearMap) = sparse(A) # ScaledMap Base.Matrix{T}(A::ScaledMap{<:Any, <:Any, <:VecOrMatMap}) where {T} = - convert(Matrix{T}, A.λ * A.lmap.lmap) + convert(Matrix{T}, A.λ * convert(AbstractMatrix, A.lmap)) +Base.convert(::Type{AbstractMatrix}, A::ScaledMap) = A.λ * convert(AbstractMatrix, A.lmap) SparseArrays.sparse(A::ScaledMap{<:Any, <:Any, <:VecOrMatMap}) = A.λ * sparse(A.lmap.lmap) diff --git a/src/indexablemap.jl b/src/indexablemap.jl new file mode 100644 index 00000000..980ce252 --- /dev/null +++ b/src/indexablemap.jl @@ -0,0 +1,49 @@ +struct IndexableMap{T,A<:LinearMap{T},F} <: LinearMap{T} + lmap::A + getind::F +end + +MulStyle(A::IndexableMap) = MulStyle(A.lmap) + +Base.size(A::IndexableMap) = size(A.lmap) +LinearAlgebra.issymmetric(A::IndexableMap) = issymmetric(A.lmap) +LinearAlgebra.ishermitian(A::IndexableMap) = ishermitian(A.lmap) +LinearAlgebra.isposdef(A::IndexableMap) = isposdef(A.lmap) + +Base.:(==)(A::IndexableMap, B::IndexableMap) = A.lmap == B.lmap + +Base.adjoint(A::IndexableMap) = IndexableMap(adjoint(A.lmap), (i,j) -> conj(A.getind(j,i))) +Base.transpose(A::IndexableMap) = IndexableMap(transpose(A.lmap), (i,j) -> A.getind(j,i)) +# rewrapping preserves indexability but redefines, e.g., symmetry properties +LinearMap(A::IndexableMap; getind=nothing, kwargs...) = IndexableMap(LinearMap(A.lmap; kwargs...), getind) +# addition/subtraction/scalar multiplication preserve indexability +Base.:(+)(A::IndexableMap, B::IndexableMap) = + IndexableMap(A.lmap + B.lmap, (i,j) -> A.getind(i,j) + B.getind(i,j)) +Base.:(-)(A::IndexableMap, B::IndexableMap) = + IndexableMap(A.lmap - B.lmap, (i,j) -> A.getind(i,j) - B.getind(i,j)) +for typ in (RealOrComplex, Number) + @eval begin + Base.:(*)(α::$typ, A::IndexableMap) = IndexableMap(α * A.lmap, (i,j) -> α*A.getind(i,j)) + Base.:(*)(A::IndexableMap, α::$typ) = IndexableMap(A.lmap * α, (i,j) -> A.getind(i,j)*α) + end +end +Base.:(*)(A::IndexableMap, J::UniformScalingMap) = + size(A, 2) == J.M ? A*J.λ : throw(DimensionMismatch("*")) +Base.:(*)(J::UniformScalingMap, A::IndexableMap) = + size(A, 1) == J.M ? J.λ*A : throw(DimensionMismatch("*")) + +Base.@propagate_inbounds Base.getindex(A::IndexableMap, ::Colon, ::Colon) = A.getind(1:size(A, 1), 1:size(A, 2)) +Base.@propagate_inbounds Base.getindex(A::IndexableMap, rows, ::Colon) = A.getind(rows, 1:size(A, 2)) +Base.@propagate_inbounds Base.getindex(A::IndexableMap, ::Colon, cols) = A.getind(1:size(A, 1), cols) +Base.@propagate_inbounds Base.getindex(A::IndexableMap, rows, cols) = A.getind(rows, cols) + +for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix)) + @eval begin + function _unsafe_mul!(y::$Out, A::IndexableMap, x::$In) + return _unsafe_mul!(y, A.lmap, x) + end + function _unsafe_mul!(y::$Out, A::IndexableMap, x::$In, α::Number, β::Number) + return _unsafe_mul!(y, A.lmap, x, α, β) + end + end +end diff --git a/src/kronecker.jl b/src/kronecker.jl index 4443b90f..c90c928a 100644 --- a/src/kronecker.jl +++ b/src/kronecker.jl @@ -97,7 +97,7 @@ Construct a lazy representation of the `k`-th Kronecker power ⊗(A, B, Cs...) = kron(convert_to_lmaps(A, B, Cs...)...) Base.:(^)(A::MapOrMatrix, ::KronPower{p}) where {p} = - kron(ntuple(n -> convert_to_lmaps_(A), Val(p))...) + kron(ntuple(n -> convert_to_lmap_(A), Val(p))...) Base.size(A::KroneckerMap) = map(*, size.(A.maps)...) @@ -282,7 +282,7 @@ where `A` can be a square `AbstractMatrix` or a `LinearMap`. ⊕(a, b, c...) = kronsum(a, b, c...) Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} = - kronsum(ntuple(n->convert_to_lmaps_(A), Val(p))...) + kronsum(ntuple(n->convert_to_lmap_(A), Val(p))...) Base.size(A::KroneckerSumMap, i) = prod(size.(A.maps, i)) Base.size(A::KroneckerSumMap) = (size(A, 1), size(A, 2)) diff --git a/src/scaledmap.jl b/src/scaledmap.jl index 94050292..f6d682c9 100644 --- a/src/scaledmap.jl +++ b/src/scaledmap.jl @@ -40,6 +40,10 @@ Base.:(*)(A::ScaledMap, α::Number) = A.lmap * (A.λ * α) Base.:(*)(α::RealOrComplex, A::ScaledMap) = (α * A.λ) * A.lmap Base.:(*)(A::ScaledMap, α::RealOrComplex) = (A.λ * α) * A.lmap Base.:(-)(A::LinearMap) = -1 * A +Base.:(*)(J::UniformScalingMap, B::ScaledMap) = + size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*")) +Base.:(*)(A::ScaledMap, J::UniformScalingMap) = + size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*")) # composition (not essential, but might save multiple scaling operations) Base.:(*)(A::ScaledMap, B::ScaledMap) = (A.λ * B.λ) * (A.lmap * B.lmap) diff --git a/src/uniformscalingmap.jl b/src/uniformscalingmap.jl index 5530f896..13e90b72 100644 --- a/src/uniformscalingmap.jl +++ b/src/uniformscalingmap.jl @@ -6,12 +6,16 @@ struct UniformScalingMap{T} <: LinearMap{T} return new{typeof(λ)}(λ, M) end end -UniformScalingMap(λ::Number, M::Int, N::Int) = - (M == N ? - UniformScalingMap(λ, M) : error("UniformScalingMap needs to be square")) -UniformScalingMap(λ::T, sz::Dims{2}) where {T} = - (sz[1] == sz[2] ? - UniformScalingMap(λ, sz[1]) : error("UniformScalingMap needs to be square")) +@deprecate UniformScalingMap(λ::Number, M::Int, N::Int) UniformScalingMap(λ, M) false +@deprecate UniformScalingMap(λ::Number, sz::Dims{2}) UniformScalingMap(λ, sz[1]) false +# the following methods are misleading: they introduce the opportunity to pass 2 dims, +# but throw if the dims are unequal! +# UniformScalingMap(λ::Number, M::Int, N::Int) = +# (M == N ? +# UniformScalingMap(λ, M) : error("UniformScalingMap needs to be square")) +# UniformScalingMap(λ::T, sz::Dims{2}) where {T} = +# (sz[1] == sz[2] ? +# UniformScalingMap(λ, sz[1]) : error("UniformScalingMap needs to be square")) MulStyle(::UniformScalingMap) = FiveArg() @@ -27,25 +31,32 @@ Base.:(==)(A::UniformScalingMap, B::UniformScalingMap) = (A.λ == B.λ && A.M == # special transposition behavior LinearAlgebra.transpose(A::UniformScalingMap) = A -LinearAlgebra.adjoint(A::UniformScalingMap) = UniformScalingMap(conj(A.λ), size(A)) +LinearAlgebra.adjoint(A::UniformScalingMap) = UniformScalingMap(conj(A.λ), A.M) # multiplication with scalar Base.:(*)(A::UniformScaling, B::LinearMap) = A.λ * B Base.:(*)(A::LinearMap, B::UniformScaling) = A * B.λ -Base.:(*)(α::Number, J::UniformScalingMap) = UniformScalingMap(α * J.λ, size(J)) -Base.:(*)(J::UniformScalingMap, α::Number) = UniformScalingMap(J.λ * α, size(J)) +Base.:(*)(α::Number, J::UniformScalingMap) = UniformScalingMap(α * J.λ, J.M) +Base.:(*)(J::UniformScalingMap, α::Number) = UniformScalingMap(J.λ * α, J.M) # needed for disambiguation -Base.:(*)(α::RealOrComplex, J::UniformScalingMap) = UniformScalingMap(α * J.λ, size(J)) -Base.:(*)(J::UniformScalingMap, α::RealOrComplex) = UniformScalingMap(J.λ * α, size(J)) +Base.:(*)(α::RealOrComplex, J::UniformScalingMap) = UniformScalingMap(α * J.λ, J.M) +Base.:(*)(J::UniformScalingMap, α::RealOrComplex) = UniformScalingMap(J.λ * α, J.M) # multiplication with vector Base.:(*)(J::UniformScalingMap, x::AbstractVector) = length(x) == J.M ? J.λ * x : throw(DimensionMismatch("*")) -# multiplication with matrix +# multiplication with map/matrix +Base.:(*)(J::UniformScalingMap, B::LinearMap) = + size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*")) Base.:(*)(J::UniformScalingMap, B::AbstractMatrix) = - size(B, 1) == J.M ? J.λ * LinearMap(B) : throw(DimensionMismatch("*")) -Base.:(*)(A::AbstractMatrix, J::UniformScalingMap) = - size(A, 2) == J.M ? LinearMap(A) * J.λ : throw(DimensionMismatch("*")) + size(B, 1) == J.M ? J.λ * convert_to_lmap_(B) : throw(DimensionMismatch("*")) +Base.:(*)(A::LinearMap, J::UniformScalingMap) = + size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*")) +Base.:(*)(A::VecOrMat, J::UniformScalingMap) = + size(A, 2) == J.M ? convert_to_lmap_(A) * J.λ : throw(DimensionMismatch("*")) +Base.:(*)(U::UniformScalingMap, J::UniformScalingMap) = + U.M == J.M ? UniformScalingMap(U.λ*J.λ, J.M) : throw(DimensionMismatch("*")) + # disambiguation Base.:(*)(xc::LinearAlgebra.AdjointAbsVec, J::UniformScalingMap) = xc * J.λ Base.:(*)(xt::LinearAlgebra.TransposeAbsVec, J::UniformScalingMap) = xt * J.λ diff --git a/test/indexablemap.jl b/test/indexablemap.jl new file mode 100644 index 00000000..e69de29b diff --git a/test/runtests.jl b/test/runtests.jl index 33b600a9..71e464de 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,3 +34,5 @@ include("conversion.jl") include("left.jl") include("fillmap.jl") + +include("indexablemap.jl") diff --git a/test/uniformscalingmap.jl b/test/uniformscalingmap.jl index 03e276c9..6f283030 100644 --- a/test/uniformscalingmap.jl +++ b/test/uniformscalingmap.jl @@ -11,8 +11,8 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools w = similar(v) Id = @inferred LinearMap(I, 10) @test occursin("10×10 LinearMaps.UniformScalingMap{Bool}", sprint((t, s) -> show(t, "text/plain", s), Id)) - @test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20) - @test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20)) + # @test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20) + # @test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20)) @test size(Id) == (10, 10) @test @inferred isreal(Id) @test @inferred issymmetric(Id) diff --git a/test/wrappedmap.jl b/test/wrappedmap.jl index 565742e7..458da408 100644 --- a/test/wrappedmap.jl +++ b/test/wrappedmap.jl @@ -10,7 +10,7 @@ using Test, LinearMaps, LinearAlgebra @test occursin("10×20 LinearMaps.WrappedMap{Float64}", sprint((t, s) -> show(t, "text/plain", s), L)) MA = @inferred LinearMap(SA) MB = @inferred LinearMap(SB) - @test eltype(Matrix{Complex{Float32}}(LinearMap(A))) <: Complex + @test eltype(Matrix{Complex{Float32}}(LinearMap(A))) == ComplexF32 @test size(L) == size(A) @test @inferred !issymmetric(L) @test @inferred issymmetric(MA) From 3f36f3d32e4c216b8454f9e97b4f1c3580f41465 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 14 Apr 2021 12:41:53 +0200 Subject: [PATCH 2/2] make adjoint and transpose "recursive" --- src/indexablemap.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/indexablemap.jl b/src/indexablemap.jl index 980ce252..fdbd4f6b 100644 --- a/src/indexablemap.jl +++ b/src/indexablemap.jl @@ -12,8 +12,8 @@ LinearAlgebra.isposdef(A::IndexableMap) = isposdef(A.lmap) Base.:(==)(A::IndexableMap, B::IndexableMap) = A.lmap == B.lmap -Base.adjoint(A::IndexableMap) = IndexableMap(adjoint(A.lmap), (i,j) -> conj(A.getind(j,i))) -Base.transpose(A::IndexableMap) = IndexableMap(transpose(A.lmap), (i,j) -> A.getind(j,i)) +Base.adjoint(A::IndexableMap) = IndexableMap(adjoint(A.lmap), (i,j) -> adjoint(A.getind(j,i))) +Base.transpose(A::IndexableMap) = IndexableMap(transpose(A.lmap), (i,j) -> transpose(A.getind(j,i))) # rewrapping preserves indexability but redefines, e.g., symmetry properties LinearMap(A::IndexableMap; getind=nothing, kwargs...) = IndexableMap(LinearMap(A.lmap; kwargs...), getind) # addition/subtraction/scalar multiplication preserve indexability