From 095632067a6a0b034ace551c4edf1b6fa95c167e Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Tue, 26 Sep 2023 23:57:32 +0800 Subject: [PATCH 1/5] fix cat invalidations This patch removes the invalidation on cat and thus reduces the OneHotArrays loading time from 4.5s to 0.5s (the normal status) --- src/blockmap.jl | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/blockmap.jl b/src/blockmap.jl index 874ee543..63b72739 100644 --- a/src/blockmap.jl +++ b/src/blockmap.jl @@ -505,13 +505,28 @@ for k in 1:8 # is 8 sufficient? mapargs = ntuple(n ->:($(Symbol(:A, n))), Val(k-1)) # yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1))) - @eval function Base.cat($(Is...), $L, As::MapOrVecOrMat...; dims::Dims{2}) - if dims == (1,2) - return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., - $(Symbol(:A, k)), - convert_to_lmaps(As...)...) - else - throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + @static if VERSION >= v"1.8" + # Dispatching on `cat` makes compiler hard to infer types and causes invalidations + # after https://github.com/JuliaLang/julia/pull/45028 + # Here we instead dispatch on _cat + @eval function Base._cat(dims, $(Is...), $L, As...) + if dims == (1,2) + return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., + $(Symbol(:A, k)), + convert_to_lmaps(As...)...) + else + throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + end + end + else + @eval function Base.cat($(Is...), $L, As...; dims::Dims{2}) + if dims == (1,2) + return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., + $(Symbol(:A, k)), + convert_to_lmaps(As...)...) + else + throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + end end end end From d42c80f50e8e4bcd95618257b14ce1e20dfe9b76 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 27 Jun 2024 23:32:04 +0200 Subject: [PATCH 2/5] ride on the generic call chain --- src/LinearMaps.jl | 3 +++ src/blockmap.jl | 22 +++++++++++++--------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/LinearMaps.jl b/src/LinearMaps.jl index bd030de8..29fe0757 100644 --- a/src/LinearMaps.jl +++ b/src/LinearMaps.jl @@ -22,6 +22,9 @@ const LinearMapVector = AbstractVector{<:LinearMap} const LinearMapTupleOrVector = Union{LinearMapTuple,LinearMapVector} Base.eltype(::LinearMap{T}) where {T} = T +Base.eltype(::Type{L}) where {T,L<:LinearMap{T}} = T +Base.eltypeof(x::LinearMap) = eltype(x) +Base.promote_eltypeof(v1::Union{AbstractVecOrMatOrQ{T},LinearMap{T}}, vs::Union{AbstractVecOrMatOrQ{T},LinearMap{T}}...) where {T} = T # conversion to LinearMap Base.convert(::Type{LinearMap}, A::LinearMap) = A diff --git a/src/blockmap.jl b/src/blockmap.jl index 63b72739..5be6eb79 100644 --- a/src/blockmap.jl +++ b/src/blockmap.jl @@ -81,8 +81,10 @@ julia> L * ones(Int, 6) 6 ``` """ -function Base.hcat(As::Union{LinearMap, UniformScaling, AbstractVecOrMatOrQ}...) - T = promote_type(map(eltype, As)...) +Base.hcat + +Base.hcat(As::T...) where {T<:LinearMap} = Base._cat_t(Val(2), eltype(T), As...) +function Base._cat_t(::Val{2}, ::Type{T}, As::Union{LinearMap, UniformScaling, AbstractVecOrMatOrQ}...) where {T} nbc = length(As) # find first non-UniformScaling to detect number of rows @@ -119,8 +121,10 @@ julia> L * ones(Int, 3) 3 ``` """ -function Base.vcat(As::Union{LinearMap,UniformScaling,AbstractVecOrMatOrQ}...) - T = promote_type(map(eltype, As)...) +Base.vcat + +Base.vcat(As::T...) where {T<:LinearMap} = Base._cat_t(Val(1), eltype(T), As...) +function Base._cat_t(::Val{1}, ::Type{T}, As::Union{LinearMap, UniformScaling, AbstractVecOrMatOrQ}...) where {T} nbr = length(As) # find first non-UniformScaling to detect number of rows @@ -164,10 +168,8 @@ julia> L * ones(Int, 6) """ Base.hvcat -function Base.hvcat(rows::Tuple{Vararg{Int}}, - As::Union{LinearMap, UniformScaling, AbstractVecOrMatOrQ}...) +function Base.typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, As::Union{LinearMap, UniformScaling, AbstractVecOrMatOrQ}...) where {T} nr = length(rows) - T = promote_type(map(eltype, As)...) sum(rows) == length(As) || throw(ArgumentError("mismatch between row sizes and number of arguments")) n = fill(-1, length(As)) @@ -492,6 +494,8 @@ end BlockDiagonalMap{T}(maps::As) where {T, As<:LinearMapTupleOrVector} = BlockDiagonalMap{T,As}(maps) +BlockDiagonalMap{T}(maps::LinearMap...) where {T} = + BlockDiagonalMap{T}(maps) BlockDiagonalMap(maps::LinearMap...) = BlockDiagonalMap{promote_type(map(eltype, maps)...)}(maps) @@ -509,9 +513,9 @@ for k in 1:8 # is 8 sufficient? # Dispatching on `cat` makes compiler hard to infer types and causes invalidations # after https://github.com/JuliaLang/julia/pull/45028 # Here we instead dispatch on _cat - @eval function Base._cat(dims, $(Is...), $L, As...) + @eval function Base._cat_t(dims::Dims{2}, ::Type{T}, $(Is...), $L, As...) where {T} if dims == (1,2) - return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., + return BlockDiagonalMap{T}(convert_to_lmaps($(mapargs...))..., $(Symbol(:A, k)), convert_to_lmaps(As...)...) else From 223f060fcff6006f272d203d9b9003a5e6823060 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 28 Jun 2024 18:30:55 +0200 Subject: [PATCH 3/5] hang [h/v]cat lower --- src/LinearMaps.jl | 3 ++- src/blockmap.jl | 39 ++++++++++++++++++++------------------- test/Project.toml | 4 +--- test/blockmap.jl | 22 +++++++++++----------- 4 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/LinearMaps.jl b/src/LinearMaps.jl index 3eb6ce6d..849c5f33 100644 --- a/src/LinearMaps.jl +++ b/src/LinearMaps.jl @@ -22,8 +22,9 @@ const LinearMapVector = AbstractVector{<:LinearMap} const LinearMapTupleOrVector = Union{LinearMapTuple,LinearMapVector} Base.eltype(::LinearMap{T}) where {T} = T -Base.eltype(::Type{L}) where {T,L<:LinearMap{T}} = T +Base.eltype(::Type{<:LinearMap{T}}) where {T} = T Base.eltypeof(x::LinearMap) = eltype(x) +Base.eltypeof(J::UniformScaling) = eltype(J) # fix upstream Base.promote_eltypeof(v1::Union{AbstractVecOrMatOrQ{T},LinearMap{T}}, vs::Union{AbstractVecOrMatOrQ{T},LinearMap{T}}...) where {T} = T # conversion to LinearMap diff --git a/src/blockmap.jl b/src/blockmap.jl index 7ff6c0a7..f6d6cd61 100644 --- a/src/blockmap.jl +++ b/src/blockmap.jl @@ -83,18 +83,7 @@ julia> L * ones(Int, 6) """ Base.hcat -Base.hcat(As::T...) where {T<:LinearMap} = Base._cat_t(Val(2), eltype(T), As...) -function Base._cat_t(::Val{2}, ::Type{T}, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) where {T} - nbc = length(As) - - # find first non-UniformScaling to detect number of rows - j = findfirst(A -> !isa(A, UniformScaling), As) - # this should not happen, function should only be called with at least one LinearMap - @assert !isnothing(j) - @inbounds nrows = size(As[j], 1)::Int - - return BlockMap{T}(promote_to_lmaps(ntuple(_ -> nrows, Val(nbc)), 1, 1, As...), (nbc,)) -end +Base.hcat(As::T...) where {T<:LinearMap} = Base._cat(Val(2), As...) ############ # vcat @@ -123,18 +112,29 @@ julia> L * ones(Int, 3) """ Base.vcat -Base.vcat(As::T...) where {T<:LinearMap} = Base._cat_t(Val(1), eltype(T), As...) -function Base._cat_t(::Val{1}, ::Type{T}, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) where {T} - nbr = length(As) +Base.vcat(As::T...) where {T<:LinearMap} = Base._cat(Val(1), As...) + +function Base._cat(dims, As::Union{LinearMap, UniformScaling, AbstractArray, AbstractQ}...) + T = promote_type(map(eltype, As)...) + nb = length(As) # find first non-UniformScaling to detect number of rows j = findfirst(A -> !isa(A, UniformScaling), As) # this should not happen, function should only be called with at least one LinearMap @assert !isnothing(j) - @inbounds ncols = size(As[j], 2)::Int - - rows = ntuple(_ -> 1, Val(nbr)) - return BlockMap{T}(promote_to_lmaps(ntuple(_ -> ncols, Val(nbr)), 1, 2, As...), rows) + if dims isa Val{2} + @inbounds nrows = size(As[j], 1)::Int + return BlockMap{T}(promote_to_lmaps(ntuple(_ -> nrows, Val(nb)), 1, 1, As...), (nb,)) + elseif dims isa Val{1} + @inbounds ncols = size(As[j], 2)::Int + + rows = ntuple(_ -> 1, Val(nb)) + return BlockMap{T}(promote_to_lmaps(ntuple(_ -> ncols, Val(nb)), 1, 2, As...), rows) + elseif dims isa Dims{2} + Base._cat_t(dims, T, As...) + else + throw(ArgumentError("unhandled dims argument")) + end end ############ @@ -225,6 +225,7 @@ end promote_to_lmaps_(n::Int, dim, A::AbstractVecOrMat) = (check_dim(A, dim, n); LinearMap(A)) promote_to_lmaps_(n::Int, dim, J::UniformScaling) = UniformScalingMap(J.λ, n) +promote_to_lmaps_(n::Int, dim, Q::AbstractQ) = (check_dim(Q, dim, n); LinearMap(Q)) promote_to_lmaps_(n::Int, dim, A::LinearMap) = (check_dim(A, dim, n); A) promote_to_lmaps(n, k, dim) = () promote_to_lmaps(n, k, dim, A) = (promote_to_lmaps_(n[k], dim, A),) diff --git a/test/Project.toml b/test/Project.toml index 9d85356c..0f54a207 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,7 +4,6 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Octonions = "d00ba074-1e29-4f5e-9fd4-d67071d6a14d" @@ -15,11 +14,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" -BlockArrays = "0.16" +BlockArrays = "0.16, 1" ChainRulesCore = "1" ChainRulesTestUtils = "1.9" Documenter = "1" -InteractiveUtils = "1.6" IterativeSolvers = "0.9" LinearAlgebra = "1.6" Octonions = "0.1, 0.2" diff --git a/test/blockmap.jl b/test/blockmap.jl index 36c016f6..3b57dd3c 100644 --- a/test/blockmap.jl +++ b/test/blockmap.jl @@ -1,4 +1,4 @@ -using Test, LinearMaps, LinearAlgebra, SparseArrays, InteractiveUtils +using Test, LinearMaps, LinearAlgebra, SparseArrays using LinearMaps: FiveArg @testset "block maps" begin @@ -29,10 +29,9 @@ using LinearMaps: FiveArg A = [A11 A12 A11] @test Matrix(L) == A == mul!(zero(A), L, 1, true, false) A = [I I I A11 A11 A11 a] - @test (@which [A11 A11 A11]).module != LinearMaps - @test (@which [I I I A11 A11 A11]).module != LinearMaps - @test (@which hcat(I, I, I)).module != LinearMaps - @test (@which hcat(I, I, I, LinearMap(A11), A11, A11)).module == LinearMaps + @test [A11 A11 A11] isa AbstractArray + @test [I I I A11 A11 A11 qr(A11).Q a] isa AbstractArray + @test [I I I A11 A11 A11 qr(A11).Q LinearMap(a)] isa LinearMap maps = @inferred LinearMaps.promote_to_lmaps(ntuple(i->m, 7), 1, 1, I, I, I, LinearMap(A11), A11, A11, a) @inferred LinearMaps.rowcolranges(maps, (7,)) L = @inferred hcat(I, I, I, LinearMap(A11), A11, A11, a) @@ -68,7 +67,7 @@ using LinearMaps: FiveArg Lv = LinearMaps.BlockMap{elty}([LinearMap(A11), LinearMap(A21)], (1,1)) @test Lv.maps isa Vector @test L == Lv - @test (@which [A11; A21]).module != LinearMaps + @test [A11; A21] isa AbstractArray A = [A11; A21] x = rand(elty, n) @test size(L) == size(A) @@ -76,7 +75,7 @@ using LinearMaps: FiveArg @test Matrix(Lv) == A == mul!(copy(A), Lv, 1, true, false) @test L * x ≈ Lv * x ≈ A * x A = [I; I; I; A11; A11; A11; reduce(hcat, fill(v, n))] - @test (@which [I; I; I; A11; A11; A11; v v v v v v v v v v]).module != LinearMaps + @test [I; I; I; A11; A11; A11; v v v] isa AbstractArray L = @inferred vcat(I, I, I, LinearMap(A11), LinearMap(A11), LinearMap(A11), reduce(hcat, fill(v, n))) @test L == [I; I; I; LinearMap(A11); LinearMap(A11); LinearMap(A11); reduce(hcat, fill(v, n))] @test L isa LinearMaps.BlockMap{elty} @@ -97,7 +96,7 @@ using LinearMaps: FiveArg A21 = rand(elty, m2, m1) A22 = ones(elty, m2, m2) A = [A11 A12; A21 A22] - @test (@which [A11 A12; A21 A22]).module != LinearMaps + @test [A11 A12; A21 A22] isa AbstractArray @inferred hvcat((2,2), LinearMap(A11), LinearMap(A12), LinearMap(A21), LinearMap(A22)) L = [LinearMap(A11) LinearMap(A12); LinearMap(A21) LinearMap(A22)] @test L.maps isa Tuple @@ -116,7 +115,7 @@ using LinearMaps: FiveArg end @test convert(AbstractMatrix, L) == A A = [I A12; A21 I] - @test (@which [I A12; A21 I]).module != LinearMaps + @test [I A12; A21 I] isa AbstractArray @inferred hvcat((2,2), I, LinearMap(A12), LinearMap(A21), I) L = @inferred hvcat((2,2), I, LinearMap(A12), LinearMap(A21), I) @test L isa LinearMaps.BlockMap{elty} @@ -222,8 +221,9 @@ using LinearMaps: FiveArg end # Md = diag(M1, M2, M3, M2, M1) # unsupported so use sparse: Md = Matrix(blockdiag(sparse.((M1, M2, M3, M2, M1))...)) - @test (@which blockdiag(sparse.((M1, M2, M3, M2, M1))...)).module != LinearMaps - @test (@which cat(M1, M2, M3, M2, M1; dims=(1,2))).module != LinearMaps + @test blockdiag(sparse.((M1, M2, M3, M2, M1))...) isa AbstractArray + @test cat(M1, M2, M3, M2, M1; dims=(1,2)) isa AbstractArray + @test cat(M2, M2, qr(M3).Q, M3[:,1]; dims=(1,2)) isa AbstractArray x = randn(elty, size(Md, 2)) Bd = @inferred blockdiag(L1, L2, L3, L2, L1) @test Bd.maps isa Tuple From c652100bc8c3892cd627bb83651233b8bc1bce23 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Fri, 28 Jun 2024 20:02:40 +0200 Subject: [PATCH 4/5] don't run doctests temporarily --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index d0e90d14..37ca3c0b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using Test, Documenter, LinearMaps, Aqua Aqua.test_all(LinearMaps, piracies = (broken=true,)) end -doctest(LinearMaps) +# doctest(LinearMaps) include("linearmaps.jl") From da9eded25edf1719fc74a22dfd1c95b3e76663ad Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 29 Jun 2024 10:25:04 +0200 Subject: [PATCH 5/5] add another test --- test/blockmap.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/blockmap.jl b/test/blockmap.jl index 3b57dd3c..e265b44d 100644 --- a/test/blockmap.jl +++ b/test/blockmap.jl @@ -68,6 +68,7 @@ using LinearMaps: FiveArg @test Lv.maps isa Vector @test L == Lv @test [A11; A21] isa AbstractArray + @test [A11; qr(A11).Q; I] isa AbstractArray A = [A11; A21] x = rand(elty, n) @test size(L) == size(A)