From 095632067a6a0b034ace551c4edf1b6fa95c167e Mon Sep 17 00:00:00 2001 From: Johnny Chen Date: Tue, 26 Sep 2023 23:57:32 +0800 Subject: [PATCH] fix cat invalidations This patch removes the invalidation on cat and thus reduces the OneHotArrays loading time from 4.5s to 0.5s (the normal status) --- src/blockmap.jl | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/blockmap.jl b/src/blockmap.jl index 874ee543..63b72739 100644 --- a/src/blockmap.jl +++ b/src/blockmap.jl @@ -505,13 +505,28 @@ for k in 1:8 # is 8 sufficient? mapargs = ntuple(n ->:($(Symbol(:A, n))), Val(k-1)) # yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1))) - @eval function Base.cat($(Is...), $L, As::MapOrVecOrMat...; dims::Dims{2}) - if dims == (1,2) - return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., - $(Symbol(:A, k)), - convert_to_lmaps(As...)...) - else - throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + @static if VERSION >= v"1.8" + # Dispatching on `cat` makes compiler hard to infer types and causes invalidations + # after https://github.com/JuliaLang/julia/pull/45028 + # Here we instead dispatch on _cat + @eval function Base._cat(dims, $(Is...), $L, As...) + if dims == (1,2) + return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., + $(Symbol(:A, k)), + convert_to_lmaps(As...)...) + else + throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + end + end + else + @eval function Base.cat($(Is...), $L, As...; dims::Dims{2}) + if dims == (1,2) + return BlockDiagonalMap(convert_to_lmaps($(mapargs...))..., + $(Symbol(:A, k)), + convert_to_lmaps(As...)...) + else + throw(ArgumentError("dims keyword in cat of LinearMaps must be (1,2)")) + end end end end