diff --git a/src/LinearMaps.jl b/src/LinearMaps.jl index 3cc854c8..849c5f33 100644 --- a/src/LinearMaps.jl +++ b/src/LinearMaps.jl @@ -22,6 +22,10 @@ const LinearMapVector = AbstractVector{<:LinearMap} const LinearMapTupleOrVector = Union{LinearMapTuple,LinearMapVector} Base.eltype(::LinearMap{T}) where {T} = T +Base.eltype(::Type{<:LinearMap{T}}) where {T} = T +Base.eltypeof(x::LinearMap) = eltype(x) +Base.eltypeof(J::UniformScaling) = eltype(J) # fix upstream +Base.promote_eltypeof(v1::Union{AbstractVecOrMatOrQ{T},LinearMap{T}}, vs::Union{AbstractVecOrMatOrQ{T},LinearMap{T}}...) where {T} = T # conversion to LinearMap Base.convert(::Type{LinearMap}, A::LinearMap) = A diff --git a/src/blockmap.jl b/src/blockmap.jl index 1c53db57..f6d6cd61 100644 --- a/src/blockmap.jl +++ b/src/blockmap.jl @@ -81,18 +81,9 @@ julia> L * ones(Int, 6) 6 ``` """ -function Base.hcat(As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) - T = promote_type(map(eltype, As)...) - nbc = length(As) +Base.hcat - # find first non-UniformScaling to detect number of rows - j = findfirst(A -> !isa(A, UniformScaling), As) - # this should not happen, function should only be called with at least one LinearMap - @assert !isnothing(j) - @inbounds nrows = size(As[j], 1)::Int - - return BlockMap{T}(promote_to_lmaps(ntuple(_ -> nrows, Val(nbc)), 1, 1, As...), (nbc,)) -end +Base.hcat(As::T...) where {T<:LinearMap} = Base._cat(Val(2), As...) ############ # vcat @@ -119,18 +110,31 @@ julia> L * ones(Int, 3) 3 ``` """ -function Base.vcat(As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) +Base.vcat + +Base.vcat(As::T...) where {T<:LinearMap} = Base._cat(Val(1), As...) + +function Base._cat(dims, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) T = promote_type(map(eltype, As)...) - nbr = length(As) + nb = length(As) # find first non-UniformScaling to detect number of rows j = findfirst(A -> !isa(A, UniformScaling), As) # this should not happen, function should only be called with at least one LinearMap @assert !isnothing(j) - @inbounds ncols = size(As[j], 2)::Int - - rows = ntuple(_ -> 1, Val(nbr)) - return BlockMap{T}(promote_to_lmaps(ntuple(_ -> ncols, Val(nbr)), 1, 2, As...), rows) + if dims isa Val{2} + @inbounds nrows = size(As[j], 1)::Int + return BlockMap{T}(promote_to_lmaps(ntuple(_ -> nrows, Val(nb)), 1, 1, As...), (nb,)) + elseif dims isa Val{1} + @inbounds ncols = size(As[j], 2)::Int + + rows = ntuple(_ -> 1, Val(nb)) + return BlockMap{T}(promote_to_lmaps(ntuple(_ -> ncols, Val(nb)), 1, 2, As...), rows) + elseif dims isa Dims{2} + Base._cat_t(dims, T, As...) + else + throw(ArgumentError("unhandled dims argument")) + end end ############ @@ -162,9 +166,10 @@ julia> L * ones(Int, 6) 6 ``` """ -function Base.hvcat(rows::Tuple{Vararg{Int}}, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) +Base.hvcat + +function Base.typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) where {T} nr = length(rows) - T = promote_type(map(eltype, As)...) sum(rows) == length(As) || throw(ArgumentError("mismatch between row sizes and number of arguments")) n = fill(-1, length(As)) @@ -220,6 +225,7 @@ end promote_to_lmaps_(n::Int, dim, A::AbstractVecOrMat) = (check_dim(A, dim, n); LinearMap(A)) promote_to_lmaps_(n::Int, dim, J::UniformScaling) = UniformScalingMap(J.λ, n) +promote_to_lmaps_(n::Int, dim, Q::AbstractQ) = (check_dim(Q, dim, n); LinearMap(Q)) promote_to_lmaps_(n::Int, dim, A::LinearMap) = (check_dim(A, dim, n); A) promote_to_lmaps(n, k, dim) = () promote_to_lmaps(n, k, dim, A) = (promote_to_lmaps_(n[k], dim, A),) @@ -489,6 +495,8 @@ end BlockDiagonalMap{T}(maps::As) where {T, As<:LinearMapTupleOrVector} = BlockDiagonalMap{T,As}(maps) +BlockDiagonalMap{T}(maps::LinearMap...) where {T} = + BlockDiagonalMap{T}(maps) BlockDiagonalMap(maps::LinearMap...) = BlockDiagonalMap{promote_type(map(eltype, maps)...)}(maps) @@ -502,13 +510,28 @@ for k in 1:8 # is 8 sufficient? mapargs = ntuple(n ->:($(Symbol(:A, n))), Val(k-1)) # yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1))) - @eval function Base.cat($(Is...), $L, As::MapOrVecOrMat...; dims::Dims{2}) - if dims == (1,2) - return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., - $(Symbol(:A, k)), - convert_to_lmaps(As...)...) - else - throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + @static if VERSION >= v"1.8" + # Dispatching on `cat` makes compiler hard to infer types and causes invalidations + # after https://github.com/JuliaLang/julia/pull/45028 + # Here we instead dispatch on _cat + @eval function Base._cat_t(dims::Dims{2}, ::Type{T}, $(Is...), $L, As...) where {T} + if dims == (1,2) + return BlockDiagonalMap{T}(convert_to_lmaps($(mapargs...))..., + $(Symbol(:A, k)), + convert_to_lmaps(As...)...) + else + throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + end + end + else + @eval function Base.cat($(Is...), $L, As...; dims::Dims{2}) + if dims == (1,2) + return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., + $(Symbol(:A, k)), + convert_to_lmaps(As...)...) + else + throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + end end end end diff --git a/test/Project.toml b/test/Project.toml index 9d85356c..0f54a207 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,7 +4,6 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Octonions = "d00ba074-1e29-4f5e-9fd4-d67071d6a14d" @@ -15,11 +14,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" -BlockArrays = "0.16" +BlockArrays = "0.16, 1" ChainRulesCore = "1" ChainRulesTestUtils = "1.9" Documenter = "1" -InteractiveUtils = "1.6" IterativeSolvers = "0.9" LinearAlgebra = "1.6" Octonions = "0.1, 0.2" diff --git a/test/blockmap.jl b/test/blockmap.jl index 36c016f6..e265b44d 100644 --- a/test/blockmap.jl +++ b/test/blockmap.jl @@ -1,4 +1,4 @@ -using Test, LinearMaps, LinearAlgebra, SparseArrays, InteractiveUtils +using Test, LinearMaps, LinearAlgebra, SparseArrays using LinearMaps: FiveArg @testset "block maps" begin @@ -29,10 +29,9 @@ using LinearMaps: FiveArg A = [A11 A12 A11] @test Matrix(L) == A == mul!(zero(A), L, 1, true, false) A = [I I I A11 A11 A11 a] - @test (@which [A11 A11 A11]).module != LinearMaps - @test (@which [I I I A11 A11 A11]).module != LinearMaps - @test (@which hcat(I, I, I)).module != LinearMaps - @test (@which hcat(I, I, I, LinearMap(A11), A11, A11)).module == LinearMaps + @test [A11 A11 A11] isa AbstractArray + @test [I I I A11 A11 A11 qr(A11).Q a] isa AbstractArray + @test [I I I A11 A11 A11 qr(A11).Q LinearMap(a)] isa LinearMap maps = @inferred LinearMaps.promote_to_lmaps(ntuple(i->m, 7), 1, 1, I, I, I, LinearMap(A11), A11, A11, a) @inferred LinearMaps.rowcolranges(maps, (7,)) L = @inferred hcat(I, I, I, LinearMap(A11), A11, A11, a) @@ -68,7 +67,8 @@ using LinearMaps: FiveArg Lv = LinearMaps.BlockMap{elty}([LinearMap(A11), LinearMap(A21)], (1,1)) @test Lv.maps isa Vector @test L == Lv - @test (@which [A11; A21]).module != LinearMaps + @test [A11; A21] isa AbstractArray + @test [A11; qr(A11).Q; I] isa AbstractArray A = [A11; A21] x = rand(elty, n) @test size(L) == size(A) @@ -76,7 +76,7 @@ using LinearMaps: FiveArg @test Matrix(Lv) == A == mul!(copy(A), Lv, 1, true, false) @test L * x ≈ Lv * x ≈ A * x A = [I; I; I; A11; A11; A11; reduce(hcat, fill(v, n))] - @test (@which [I; I; I; A11; A11; A11; v v v v v v v v v v]).module != LinearMaps + @test [I; I; I; A11; A11; A11; v v v] isa AbstractArray L = @inferred vcat(I, I, I, LinearMap(A11), LinearMap(A11), LinearMap(A11), reduce(hcat, fill(v, n))) @test L == [I; I; I; LinearMap(A11); LinearMap(A11); LinearMap(A11); reduce(hcat, fill(v, n))] @test L isa LinearMaps.BlockMap{elty} @@ -97,7 +97,7 @@ using LinearMaps: FiveArg A21 = rand(elty, m2, m1) A22 = ones(elty, m2, m2) A = [A11 A12; A21 A22] - @test (@which [A11 A12; A21 A22]).module != LinearMaps + @test [A11 A12; A21 A22] isa AbstractArray @inferred hvcat((2,2), LinearMap(A11), LinearMap(A12), LinearMap(A21), LinearMap(A22)) L = [LinearMap(A11) LinearMap(A12); LinearMap(A21) LinearMap(A22)] @test L.maps isa Tuple @@ -116,7 +116,7 @@ using LinearMaps: FiveArg end @test convert(AbstractMatrix, L) == A A = [I A12; A21 I] - @test (@which [I A12; A21 I]).module != LinearMaps + @test [I A12; A21 I] isa AbstractArray @inferred hvcat((2,2), I, LinearMap(A12), LinearMap(A21), I) L = @inferred hvcat((2,2), I, LinearMap(A12), LinearMap(A21), I) @test L isa LinearMaps.BlockMap{elty} @@ -222,8 +222,9 @@ using LinearMaps: FiveArg end # Md = diag(M1, M2, M3, M2, M1) # unsupported so use sparse: Md = Matrix(blockdiag(sparse.((M1, M2, M3, M2, M1))...)) - @test (@which blockdiag(sparse.((M1, M2, M3, M2, M1))...)).module != LinearMaps - @test (@which cat(M1, M2, M3, M2, M1; dims=(1,2))).module != LinearMaps + @test blockdiag(sparse.((M1, M2, M3, M2, M1))...) isa AbstractArray + @test cat(M1, M2, M3, M2, M1; dims=(1,2)) isa AbstractArray + @test cat(M2, M2, qr(M3).Q, M3[:,1]; dims=(1,2)) isa AbstractArray x = randn(elty, size(Md, 2)) Bd = @inferred blockdiag(L1, L2, L3, L2, L1) @test Bd.maps isa Tuple diff --git a/test/runtests.jl b/test/runtests.jl index d0e90d14..37ca3c0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using Test, Documenter, LinearMaps, Aqua Aqua.test_all(LinearMaps, piracies = (broken=true,)) end -doctest(LinearMaps) +# doctest(LinearMaps) include("linearmaps.jl")