Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IndexableMap #145

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/LinearMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ function check_dim_mul(C, A, B)
end

# conversion of AbstractVecOrMat to LinearMap
convert_to_lmaps_(A::AbstractVecOrMat) = LinearMap(A)
convert_to_lmaps_(A::LinearMap) = A
convert_to_lmap_(A::AbstractVecOrMat) = LinearMap(A)
convert_to_lmap_(A::LinearMap) = A
convert_to_lmaps() = ()
convert_to_lmaps(A) = (convert_to_lmaps_(A),)
convert_to_lmaps(A) = (convert_to_lmap_(A),)
@inline convert_to_lmaps(A, B, Cs...) =
(convert_to_lmaps_(A), convert_to_lmaps_(B), convert_to_lmaps(Cs...)...)
(convert_to_lmap_(A), convert_to_lmap_(B), convert_to_lmaps(Cs...)...)

# The (internal) multiplication logic is as follows:
# - `*(A, x)` calls `mul!(y, A, x)` for appropriately-sized y
Expand Down Expand Up @@ -256,6 +256,7 @@ include("functionmap.jl") # using a function as linear map
include("blockmap.jl") # block linear maps
include("kronecker.jl") # Kronecker product of linear maps
include("fillmap.jl") # linear maps representing constantly filled matrices
include("indexablemap.jl") # indexable linear maps
include("conversion.jl") # conversion of linear maps to matrices
include("show.jl") # show methods for LinearMap objects

Expand Down Expand Up @@ -293,7 +294,9 @@ For the function-based constructor, there is one more keyword argument:
The default value is guessed by looking at the number of arguments of the first
occurrence of `f` in the method table.
"""
LinearMap(A::MapOrVecOrMat; kwargs...) = WrappedMap(A; kwargs...)
LinearMap(A::MapOrVecOrMat; getind=nothing, kwargs...) = _LinearMap(getind, A; kwargs...)
Copy link
Member

@JeffFessler JeffFessler Apr 14, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
LinearMap(A::MapOrVecOrMat; getind=nothing, kwargs...) = _LinearMap(getind, A; kwargs...)
LinearMap(A::MapOrVecOrMat; getind=Base.getindex, kwargs...) = _LinearMap(getind, A; kwargs...)

For a WrappedMap could the default getind simply be Base.getindex because presumably that is the expected behavior for a matrix or even for any AbstractArray that already has a getindex method.

_LinearMap(::Nothing, A; kwargs...) = WrappedMap(A; kwargs...)
_LinearMap(getind, A; kwargs...) = IndexableMap(WrappedMap(A; kwargs...), getind)
LinearMap(J::UniformScaling, M::Int) = UniformScalingMap(J.λ, M)
LinearMap(f, M::Int; kwargs...) = LinearMap{Float64}(f, M; kwargs...)
LinearMap(f, M::Int, N::Int; kwargs...) = LinearMap{Float64}(f, M, N; kwargs...)
Expand Down
4 changes: 4 additions & 0 deletions src/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ end
# needed for disambiguation
Base.:(*)(A₁::ScaledMap, A₂::CompositeMap) = A₁.λ * (A₁.lmap * A₂)
Base.:(*)(A₁::CompositeMap, A₂::ScaledMap) = (A₁ * A₂.lmap) * A₂.λ
Base.:(*)(J::UniformScalingMap, B::CompositeMap) =
size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*"))
Base.:(*)(A::CompositeMap, J::UniformScalingMap) =
size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*"))

# special transposition behavior
LinearAlgebra.transpose(A::CompositeMap{T}) where {T} =
Expand Down
3 changes: 2 additions & 1 deletion src/conversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ SparseArrays.SparseMatrixCSC(A::LinearMap) = sparse(A)

# ScaledMap
Base.Matrix{T}(A::ScaledMap{<:Any, <:Any, <:VecOrMatMap}) where {T} =
convert(Matrix{T}, A.λ * A.lmap.lmap)
convert(Matrix{T}, A.λ * convert(AbstractMatrix, A.lmap))
Base.convert(::Type{AbstractMatrix}, A::ScaledMap) = A.λ * convert(AbstractMatrix, A.lmap)
SparseArrays.sparse(A::ScaledMap{<:Any, <:Any, <:VecOrMatMap}) =
A.λ * sparse(A.lmap.lmap)

Expand Down
49 changes: 49 additions & 0 deletions src/indexablemap.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
struct IndexableMap{T,A<:LinearMap{T},F} <: LinearMap{T}
lmap::A
getind::F
end

MulStyle(A::IndexableMap) = MulStyle(A.lmap)

Base.size(A::IndexableMap) = size(A.lmap)
LinearAlgebra.issymmetric(A::IndexableMap) = issymmetric(A.lmap)
LinearAlgebra.ishermitian(A::IndexableMap) = ishermitian(A.lmap)
LinearAlgebra.isposdef(A::IndexableMap) = isposdef(A.lmap)

Base.:(==)(A::IndexableMap, B::IndexableMap) = A.lmap == B.lmap

Base.adjoint(A::IndexableMap) = IndexableMap(adjoint(A.lmap), (i,j) -> adjoint(A.getind(j,i)))
Base.transpose(A::IndexableMap) = IndexableMap(transpose(A.lmap), (i,j) -> transpose(A.getind(j,i)))
# rewrapping preserves indexability but redefines, e.g., symmetry properties
LinearMap(A::IndexableMap; getind=nothing, kwargs...) = IndexableMap(LinearMap(A.lmap; kwargs...), getind)
# addition/subtraction/scalar multiplication preserve indexability
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!
I wish I had a suggestion for how to handle a CompositeMap...

Base.:(+)(A::IndexableMap, B::IndexableMap) =
IndexableMap(A.lmap + B.lmap, (i,j) -> A.getind(i,j) + B.getind(i,j))
Base.:(-)(A::IndexableMap, B::IndexableMap) =
IndexableMap(A.lmap - B.lmap, (i,j) -> A.getind(i,j) - B.getind(i,j))
for typ in (RealOrComplex, Number)
@eval begin
Base.:(*)(α::$typ, A::IndexableMap) = IndexableMap(α * A.lmap, (i,j) -> α*A.getind(i,j))
Base.:(*)(A::IndexableMap, α::$typ) = IndexableMap(A.lmap * α, (i,j) -> A.getind(i,j)*α)
end
end
Base.:(*)(A::IndexableMap, J::UniformScalingMap) =
size(A, 2) == J.M ? A*J.λ : throw(DimensionMismatch("*"))
Base.:(*)(J::UniformScalingMap, A::IndexableMap) =
size(A, 1) == J.M ? J.λ*A : throw(DimensionMismatch("*"))

Base.@propagate_inbounds Base.getindex(A::IndexableMap, ::Colon, ::Colon) = A.getind(1:size(A, 1), 1:size(A, 2))
Base.@propagate_inbounds Base.getindex(A::IndexableMap, rows, ::Colon) = A.getind(rows, 1:size(A, 2))
Base.@propagate_inbounds Base.getindex(A::IndexableMap, ::Colon, cols) = A.getind(1:size(A, 1), cols)
Base.@propagate_inbounds Base.getindex(A::IndexableMap, rows, cols) = A.getind(rows, cols)

for (In, Out) in ((AbstractVector, AbstractVecOrMat), (AbstractMatrix, AbstractMatrix))
@eval begin
function _unsafe_mul!(y::$Out, A::IndexableMap, x::$In)
return _unsafe_mul!(y, A.lmap, x)
end
function _unsafe_mul!(y::$Out, A::IndexableMap, x::$In, α::Number, β::Number)
return _unsafe_mul!(y, A.lmap, x, α, β)
end
end
end
4 changes: 2 additions & 2 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Construct a lazy representation of the `k`-th Kronecker power
⊗(A, B, Cs...) = kron(convert_to_lmaps(A, B, Cs...)...)

Base.:(^)(A::MapOrMatrix, ::KronPower{p}) where {p} =
kron(ntuple(n -> convert_to_lmaps_(A), Val(p))...)
kron(ntuple(n -> convert_to_lmap_(A), Val(p))...)

Base.size(A::KroneckerMap) = map(*, size.(A.maps)...)

Expand Down Expand Up @@ -282,7 +282,7 @@ where `A` can be a square `AbstractMatrix` or a `LinearMap`.
⊕(a, b, c...) = kronsum(a, b, c...)

Base.:(^)(A::MapOrMatrix, ::KronSumPower{p}) where {p} =
kronsum(ntuple(n->convert_to_lmaps_(A), Val(p))...)
kronsum(ntuple(n->convert_to_lmap_(A), Val(p))...)

Base.size(A::KroneckerSumMap, i) = prod(size.(A.maps, i))
Base.size(A::KroneckerSumMap) = (size(A, 1), size(A, 2))
Expand Down
4 changes: 4 additions & 0 deletions src/scaledmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ Base.:(*)(A::ScaledMap, α::Number) = A.lmap * (A.λ * α)
Base.:(*)(α::RealOrComplex, A::ScaledMap) = (α * A.λ) * A.lmap
Base.:(*)(A::ScaledMap, α::RealOrComplex) = (A.λ * α) * A.lmap
Base.:(-)(A::LinearMap) = -1 * A
Base.:(*)(J::UniformScalingMap, B::ScaledMap) =
size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*"))
Base.:(*)(A::ScaledMap, J::UniformScalingMap) =
size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*"))

# composition (not essential, but might save multiple scaling operations)
Base.:(*)(A::ScaledMap, B::ScaledMap) = (A.λ * B.λ) * (A.lmap * B.lmap)
Expand Down
41 changes: 26 additions & 15 deletions src/uniformscalingmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ struct UniformScalingMap{T} <: LinearMap{T}
return new{typeof(λ)}(λ, M)
end
end
UniformScalingMap(λ::Number, M::Int, N::Int) =
(M == N ?
UniformScalingMap(λ, M) : error("UniformScalingMap needs to be square"))
UniformScalingMap(λ::T, sz::Dims{2}) where {T} =
(sz[1] == sz[2] ?
UniformScalingMap(λ, sz[1]) : error("UniformScalingMap needs to be square"))
@deprecate UniformScalingMap(λ::Number, M::Int, N::Int) UniformScalingMap(λ, M) false
@deprecate UniformScalingMap(λ::Number, sz::Dims{2}) UniformScalingMap(λ, sz[1]) false
# the following methods are misleading: they introduce the opportunity to pass 2 dims,
# but throw if the dims are unequal!
# UniformScalingMap(λ::Number, M::Int, N::Int) =
# (M == N ?
# UniformScalingMap(λ, M) : error("UniformScalingMap needs to be square"))
# UniformScalingMap(λ::T, sz::Dims{2}) where {T} =
# (sz[1] == sz[2] ?
# UniformScalingMap(λ, sz[1]) : error("UniformScalingMap needs to be square"))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Base.getindex(A::UniformScalingMap, args...) = A.λ * A.lmap.getind(args...)

Seems like something like this should be the natural default for this type.

MulStyle(::UniformScalingMap) = FiveArg()

Expand All @@ -27,25 +31,32 @@ Base.:(==)(A::UniformScalingMap, B::UniformScalingMap) = (A.λ == B.λ && A.M ==

# special transposition behavior
LinearAlgebra.transpose(A::UniformScalingMap) = A
LinearAlgebra.adjoint(A::UniformScalingMap) = UniformScalingMap(conj(A.λ), size(A))
LinearAlgebra.adjoint(A::UniformScalingMap) = UniformScalingMap(conj(A.λ), A.M)

# multiplication with scalar
Base.:(*)(A::UniformScaling, B::LinearMap) = A.λ * B
Base.:(*)(A::LinearMap, B::UniformScaling) = A * B.λ
Base.:(*)(α::Number, J::UniformScalingMap) = UniformScalingMap(α * J.λ, size(J))
Base.:(*)(J::UniformScalingMap, α::Number) = UniformScalingMap(J.λ * α, size(J))
Base.:(*)(α::Number, J::UniformScalingMap) = UniformScalingMap(α * J.λ, J.M)
Base.:(*)(J::UniformScalingMap, α::Number) = UniformScalingMap(J.λ * α, J.M)
# needed for disambiguation
Base.:(*)(α::RealOrComplex, J::UniformScalingMap) = UniformScalingMap(α * J.λ, size(J))
Base.:(*)(J::UniformScalingMap, α::RealOrComplex) = UniformScalingMap(J.λ * α, size(J))
Base.:(*)(α::RealOrComplex, J::UniformScalingMap) = UniformScalingMap(α * J.λ, J.M)
Base.:(*)(J::UniformScalingMap, α::RealOrComplex) = UniformScalingMap(J.λ * α, J.M)

# multiplication with vector
Base.:(*)(J::UniformScalingMap, x::AbstractVector) =
length(x) == J.M ? J.λ * x : throw(DimensionMismatch("*"))
# multiplication with matrix
# multiplication with map/matrix
Base.:(*)(J::UniformScalingMap, B::LinearMap) =
size(B, 1) == J.M ? J.λ * B : throw(DimensionMismatch("*"))
Base.:(*)(J::UniformScalingMap, B::AbstractMatrix) =
size(B, 1) == J.M ? J.λ * LinearMap(B) : throw(DimensionMismatch("*"))
Base.:(*)(A::AbstractMatrix, J::UniformScalingMap) =
size(A, 2) == J.M ? LinearMap(A) * J.λ : throw(DimensionMismatch("*"))
size(B, 1) == J.M ? J.λ * convert_to_lmap_(B) : throw(DimensionMismatch("*"))
Base.:(*)(A::LinearMap, J::UniformScalingMap) =
size(A, 2) == J.M ? A * J.λ : throw(DimensionMismatch("*"))
Base.:(*)(A::VecOrMat, J::UniformScalingMap) =
size(A, 2) == J.M ? convert_to_lmap_(A) * J.λ : throw(DimensionMismatch("*"))
Base.:(*)(U::UniformScalingMap, J::UniformScalingMap) =
U.M == J.M ? UniformScalingMap(U.λ*J.λ, J.M) : throw(DimensionMismatch("*"))

# disambiguation
Base.:(*)(xc::LinearAlgebra.AdjointAbsVec, J::UniformScalingMap) = xc * J.λ
Base.:(*)(xt::LinearAlgebra.TransposeAbsVec, J::UniformScalingMap) = xt * J.λ
Expand Down
Empty file added test/indexablemap.jl
Empty file.
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ include("conversion.jl")
include("left.jl")

include("fillmap.jl")

include("indexablemap.jl")
4 changes: 2 additions & 2 deletions test/uniformscalingmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ using Test, LinearMaps, LinearAlgebra, BenchmarkTools
w = similar(v)
Id = @inferred LinearMap(I, 10)
@test occursin("10×10 LinearMaps.UniformScalingMap{Bool}", sprint((t, s) -> show(t, "text/plain", s), Id))
@test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20)
@test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20))
# @test_throws ErrorException LinearMaps.UniformScalingMap(1, 10, 20)
# @test_throws ErrorException LinearMaps.UniformScalingMap(1, (10, 20))
@test size(Id) == (10, 10)
@test @inferred isreal(Id)
@test @inferred issymmetric(Id)
Expand Down
2 changes: 1 addition & 1 deletion test/wrappedmap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Test, LinearMaps, LinearAlgebra
@test occursin("10×20 LinearMaps.WrappedMap{Float64}", sprint((t, s) -> show(t, "text/plain", s), L))
MA = @inferred LinearMap(SA)
MB = @inferred LinearMap(SB)
@test eltype(Matrix{Complex{Float32}}(LinearMap(A))) <: Complex
@test eltype(Matrix{Complex{Float32}}(LinearMap(A))) == ComplexF32
@test size(L) == size(A)
@test @inferred !issymmetric(L)
@test @inferred issymmetric(MA)
Expand Down