Skip to content

Commit

Permalink
Add IndexableMap
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch committed Apr 14, 2021
1 parent 032832e commit 423d393
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 26 deletions.
13 changes: 8 additions & 5 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ function check_dim_mul(C, A, B)
end

# conversion of AbstractVecOrMat to LinearMap
convert_to_lmaps_(A::AbstractVecOrMat) = LinearMap(A)
convert_to_lmaps_(A::LinearMap) = A
convert_to_lmap_(A::AbstractVecOrMat) = LinearMap(A)
convert_to_lmap_(A::LinearMap) = A
convert_to_lmaps() = ()
convert_to_lmaps(A) = (convert_to_lmaps_(A),)
convert_to_lmaps(A) = (convert_to_lmap_(A),)
@inline convert_to_lmaps(A, B, Cs...) =
(convert_to_lmaps_(A), convert_to_lmaps_(B), convert_to_lmaps(Cs...)...)
(convert_to_lmap_(A), convert_to_lmap_(B), convert_to_lmaps(Cs...)...)

# The (internal) multiplication logic is as follows:
# - `*(A, x)` calls `mul!(y, A, x)` for appropriately-sized y
Expand Down Expand Up @@ -256,6 +256,7 @@ include("functionmap.jl") # using a function as linear map
include("blockmap.jl") # block linear maps
include("kronecker.jl") # Kronecker product of linear maps
include("fillmap.jl") # linear maps representing constantly filled matrices
include("indexablemap.jl") # indexable linear maps
include("conversion.jl") # conversion of linear maps to matrices
include("show.jl") # show methods for LinearMap objects

Expand Down Expand Up @@ -293,7 +294,9 @@ For the function-based constructor, there is one more keyword argument:
The default value is guessed by looking at the number of arguments of the first
occurrence of `f` in the method table.
"""
LinearMap(A::MapOrVecOrMat; kwargs...) = WrappedMap(A; kwargs...)
LinearMap(A::MapOrVecOrMat; getind=nothing, kwargs...) = _LinearMap(getind, A; kwargs...)
_LinearMap(::Nothing, A; kwargs...) = WrappedMap(A; kwargs...)
_LinearMap(getind, A; kwargs...) = IndexableMap(WrappedMap(A; kwargs...), getind)
LinearMap(J::UniformScaling, M::Int) = UniformScalingMap(J.λ, M)
LinearMap(f, M::Int; kwargs...) = LinearMap{Float64}(f, M; kwargs...)
LinearMap(f, M::Int, N::Int; kwargs...) = LinearMap{Float64}(f, M, N; kwargs...)
Expand Down
4 changes: 4 additions & 0 deletions src/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ end
# needed for disambiguation
Base.:(*)(A₁::ScaledMap, A₂::CompositeMap) = A₁.λ * (A₁.lmap * A₂)
Base.:(*)(A₁::CompositeMap, A₂::ScaledMap) = (A₁ * A₂.lmap) * A₂.λ
Base.:(*)(J::UniformScalingMap, B::CompositeMap) =
size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*"))
Base.:(*)(A::CompositeMap, J::UniformScalingMap) =
size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*"))

# special transposition behavior
LinearAlgebra.transpose(A::CompositeMap{T}) where {T} =
Expand Down
3 changes: 2 additions & 1 deletion src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ SparseArrays.SparseMatrixCSC(A::LinearMap) = sparse(A)

# ScaledMap
Base.Matrix{T}(A::ScaledMap{<:Any, <:Any, <:VecOrMatMap}) where {T} =
convert(Matrix{T}, A.λ * A.lmap.lmap)
convert(Matrix{T}, A.λ * convert(AbstractMatrix, A.lmap))
Base.convert(::Type{AbstractMatrix}, A::ScaledMap) = A.λ * convert(AbstractMatrix, A.lmap)
SparseArrays.sparse(A::ScaledMap{<:Any, <:Any, <:VecOrMatMap}) =
A.λ * sparse(A.lmap.lmap)

Expand Down
49 changes: 49 additions & 0 deletions src/indexablemap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
struct IndexableMap{T,A<:LinearMap{T},F} <: LinearMap{T}
lmap::A
getind::F
end

MulStyle(A::IndexableMap) = MulStyle(A.lmap)

Base.size(A::IndexableMap) = size(A.lmap)
LinearAlgebra.issymmetric(A::IndexableMap) = issymmetric(A.lmap)
LinearAlgebra.ishermitian(A::IndexableMap) = ishermitian(A.lmap)
LinearAlgebra.isposdef(A::IndexableMap) = isposdef(A.lmap)

Base.:(==)(A::IndexableMap, B::IndexableMap) = A.lmap == B.lmap

Base.adjoint(A::IndexableMap) = IndexableMap(adjoint(A.lmap), (i,j) -> conj(A.getind(j,i)))
Base.transpose(A::IndexableMap) = IndexableMap(transpose(A.lmap), (i,j) -> A.getind(j,i))
# rewrapping preserves indexability but redefines, e.g., symmetry properties
LinearMap(A::IndexableMap; getind=nothing, kwargs...) = IndexableMap(LinearMap(A.lmap; kwargs...), getind)
# addition/subtraction/scalar multiplication preserve indexability
Base.:(+)(A::IndexableMap, B::IndexableMap) =
IndexableMap(A.lmap + B.lmap, (i,j) -> A.getind(i,j) + B.getind(i,j))
Base.:(-)(A::IndexableMap, B::IndexableMap) =
IndexableMap(A.lmap - B.lmap, (i,j) -> A.getind(i,j) - B.getind(i,j))
for typ in (RealOrComplex, Number)
@eval begin
Base.:(*)(α::$typ, A::IndexableMap) = IndexableMap* A.lmap, (i,j) -> α*A.getind(i,j))
Base.:(*)(A::IndexableMap, α::$typ) = IndexableMap(A.lmap * α, (i,j) -> A.getind(i,j)*α)
end
end
Base.:(*)(A::IndexableMap, J::UniformScalingMap) =
size(A, 2) == J.M ? A*J.λ : throw(DimensionMismatch("*"))
Base.:(*)(J::UniformScalingMap, A::IndexableMap) =
size(A, 1) == J.M ? J.λ*A : throw(DimensionMismatch("*"))

Base.@propagate_inbounds Base.getindex(A::IndexableMap, ::Colon, ::Colon) = A.getind(1:size(A, 1), 1:size(A, 2))
Base.@propagate_inbounds Base.getindex(A::IndexableMap, rows, ::Colon) = A.getind(rows, 1:size(A, 2))
Base.@propagate_inbounds Base.getindex(A::IndexableMap, ::Colon, cols) = A.getind(1:size(A, 1), cols)
Base.@propagate_inbounds Base.getindex(A::IndexableMap, rows, cols) = A.getind(rows, cols)

for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix))
@eval begin
function _unsafe_mul!(y::$Out, A::IndexableMap, x::$In)
return _unsafe_mul!(y, A.lmap, x)
end
function _unsafe_mul!(y::$Out, A::IndexableMap, x::$In, α::Number, β::Number)
return _unsafe_mul!(y, A.lmap, x, α, β)
end
end
end
4 changes: 2 additions & 2 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Construct a lazy representation of the `k`-th Kronecker power
(A, B, Cs...) = kron(convert_to_lmaps(A, B, Cs...)...)

Base.:(^)(A::MapOrMatrix, ::KronPower{p}) where {p} =
kron(ntuple(n -> convert_to_lmaps_(A), Val(p))...)
kron(ntuple(n -> convert_to_lmap_(A), Val(p))...)

Base.size(A::KroneckerMap) = map(*, size.(A.maps)...)

Expand Down Expand Up @@ -282,7 +282,7 @@ where `A` can be a square `AbstractMatrix` or a `LinearMap`.
(a, b, c...) = kronsum(a, b, c...)

Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} =
kronsum(ntuple(n->convert_to_lmaps_(A), Val(p))...)
kronsum(ntuple(n->convert_to_lmap_(A), Val(p))...)

Base.size(A::KroneckerSumMap, i) = prod(size.(A.maps, i))
Base.size(A::KroneckerSumMap) = (size(A, 1), size(A, 2))
Expand Down
4 changes: 4 additions & 0 deletions src/scaledmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ Base.:(*)(A::ScaledMap, α::Number) = A.lmap * (A.λ * α)
Base.:(*)(α::RealOrComplex, A::ScaledMap) =* A.λ) * A.lmap
Base.:(*)(A::ScaledMap, α::RealOrComplex) = (A.λ * α) * A.lmap
Base.:(-)(A::LinearMap) = -1 * A
Base.:(*)(J::UniformScalingMap, B::ScaledMap) =
size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*"))
Base.:(*)(A::ScaledMap, J::UniformScalingMap) =
size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*"))

# composition (not essential, but might save multiple scaling operations)
Base.:(*)(A::ScaledMap, B::ScaledMap) = (A.λ * B.λ) * (A.lmap * B.lmap)
Expand Down
41 changes: 26 additions & 15 deletions src/uniformscalingmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ struct UniformScalingMap{T} <: LinearMap{T}
return new{typeof(λ)}(λ, M)
end
end
UniformScalingMap::Number, M::Int, N::Int) =
(M == N ?
UniformScalingMap(λ, M) : error("UniformScalingMap needs to be square"))
UniformScalingMap::T, sz::Dims{2}) where {T} =
(sz[1] == sz[2] ?
UniformScalingMap(λ, sz[1]) : error("UniformScalingMap needs to be square"))
@deprecate UniformScalingMap::Number, M::Int, N::Int) UniformScalingMap(λ, M) false
@deprecate UniformScalingMap::Number, sz::Dims{2}) UniformScalingMap(λ, sz[1]) false
# the following methods are misleading: they introduce the opportunity to pass 2 dims,
# but throw if the dims are unequal!
# UniformScalingMap(λ::Number, M::Int, N::Int) =
# (M == N ?
# UniformScalingMap(λ, M) : error("UniformScalingMap needs to be square"))
# UniformScalingMap(λ::T, sz::Dims{2}) where {T} =
# (sz[1] == sz[2] ?
# UniformScalingMap(λ, sz[1]) : error("UniformScalingMap needs to be square"))

MulStyle(::UniformScalingMap) = FiveArg()

Expand All @@ -27,25 +31,32 @@ Base.:(==)(A::UniformScalingMap, B::UniformScalingMap) = (A.λ == B.λ && A.M ==

# special transposition behavior
LinearAlgebra.transpose(A::UniformScalingMap) = A
LinearAlgebra.adjoint(A::UniformScalingMap) = UniformScalingMap(conj(A.λ), size(A))
LinearAlgebra.adjoint(A::UniformScalingMap) = UniformScalingMap(conj(A.λ), A.M)

# multiplication with scalar
Base.:(*)(A::UniformScaling, B::LinearMap) = A.λ * B
Base.:(*)(A::LinearMap, B::UniformScaling) = A * B.λ
Base.:(*)(α::Number, J::UniformScalingMap) = UniformScalingMap* J.λ, size(J))
Base.:(*)(J::UniformScalingMap, α::Number) = UniformScalingMap(J.λ * α, size(J))
Base.:(*)(α::Number, J::UniformScalingMap) = UniformScalingMap* J.λ, J.M)
Base.:(*)(J::UniformScalingMap, α::Number) = UniformScalingMap(J.λ * α, J.M)
# needed for disambiguation
Base.:(*)(α::RealOrComplex, J::UniformScalingMap) = UniformScalingMap* J.λ, size(J))
Base.:(*)(J::UniformScalingMap, α::RealOrComplex) = UniformScalingMap(J.λ * α, size(J))
Base.:(*)(α::RealOrComplex, J::UniformScalingMap) = UniformScalingMap* J.λ, J.M)
Base.:(*)(J::UniformScalingMap, α::RealOrComplex) = UniformScalingMap(J.λ * α, J.M)

# multiplication with vector
Base.:(*)(J::UniformScalingMap, x::AbstractVector) =
length(x) == J.M ? J.λ * x : throw(DimensionMismatch("*"))
# multiplication with matrix
# multiplication with map/matrix
Base.:(*)(J::UniformScalingMap, B::LinearMap) =
size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*"))
Base.:(*)(J::UniformScalingMap, B::AbstractMatrix) =
size(B, 1) == J.M ? J.λ * LinearMap(B) : throw(DimensionMismatch("*"))
Base.:(*)(A::AbstractMatrix, J::UniformScalingMap) =
size(A, 2) == J.M ? LinearMap(A) * J.λ : throw(DimensionMismatch("*"))
size(B, 1) == J.M ? J.λ * convert_to_lmap_(B) : throw(DimensionMismatch("*"))
Base.:(*)(A::LinearMap, J::UniformScalingMap) =
size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*"))
Base.:(*)(A::VecOrMat, J::UniformScalingMap) =
size(A, 2) == J.M ? convert_to_lmap_(A) * J.λ : throw(DimensionMismatch("*"))
Base.:(*)(U::UniformScalingMap, J::UniformScalingMap) =
U.M == J.M ? UniformScalingMap(U.λ*J.λ, J.M) : throw(DimensionMismatch("*"))

# disambiguation
Base.:(*)(xc::LinearAlgebra.AdjointAbsVec, J::UniformScalingMap) = xc * J.λ
Base.:(*)(xt::LinearAlgebra.TransposeAbsVec, J::UniformScalingMap) = xt * J.λ
Expand Down
Empty file added test/indexablemap.jl
Empty file.
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ include("conversion.jl")
include("left.jl")

include("fillmap.jl")

include("indexablemap.jl")
4 changes: 2 additions & 2 deletions test/uniformscalingmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
w = similar(v)
Id = @inferred LinearMap(I, 10)
@test occursin("10×10 LinearMaps.UniformScalingMap{Bool}", sprint((t, s) -> show(t, "text/plain", s), Id))
@test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20)
@test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20))
# @test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20)
# @test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20))
@test size(Id) == (10, 10)
@test @inferred isreal(Id)
@test @inferred issymmetric(Id)
Expand Down
2 changes: 1 addition & 1 deletion test/wrappedmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Test, LinearMaps, LinearAlgebra
@test occursin("10×20 LinearMaps.WrappedMap{Float64}", sprint((t, s) -> show(t, "text/plain", s), L))
MA = @inferred LinearMap(SA)
MB = @inferred LinearMap(SB)
@test eltype(Matrix{Complex{Float32}}(LinearMap(A))) <: Complex
@test eltype(Matrix{Complex{Float32}}(LinearMap(A))) == ComplexF32
@test size(L) == size(A)
@test @inferred !issymmetric(L)
@test @inferred issymmetric(MA)
Expand Down

0 comments on commit 423d393

Please sign in to comment.