Skip to content

Commit

Permalink
Reduce UniformScalingMaps under Kronecker products, perf improvements (
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Mar 17, 2021
1 parent b53903d commit 209d6e2
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 25 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LinearMaps"
uuid = "7a12625a-238d-50fd-b39a-03d52299707e"
version = "3.2.3"
version = "3.2.4"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
50 changes: 27 additions & 23 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ Base.kron(A::KroneckerMap, B::KroneckerMap) =
Base.kron(A::ScaledMap, B::LinearMap) = A.λ * kron(A.lmap, B)
Base.kron(A::LinearMap, B::ScaledMap) = kron(A, B.lmap) * B.λ
Base.kron(A::ScaledMap, B::ScaledMap) = (A.λ * B.λ) * kron(A.lmap, B.lmap)
# reduce UniformScalingMaps
Base.kron(A::UniformScalingMap, B::UniformScalingMap) = UniformScalingMap(A.λ * B.λ, A.M * B.M)
# disambiguation
Base.kron(A::ScaledMap, B::KroneckerMap) = A.λ * kron(A.lmap, B)
Base.kron(A::KroneckerMap, B::ScaledMap) = kron(A, B.lmap) * B.λ
Expand Down Expand Up @@ -112,44 +114,46 @@ Base.:(==)(A::KroneckerMap, B::KroneckerMap) = (eltype(A) == eltype(B) && A.maps
# multiplication helper functions
#################

@inline function _kronmul!(y, B, x, At, T)
na, ma = size(At)
@inline function _kronmul!(y, B, x, A, T)
ma, na = size(A)
mb, nb = size(B)
X = reshape(x, (nb, na))
v = zeros(T, ma)
temp1 = similar(y, na)
temp2 = similar(y, nb)
@views @inbounds for i in 1:ma
v[i] = one(T)
_unsafe_mul!(temp1, At, v)
_unsafe_mul!(temp2, X, temp1)
_unsafe_mul!(y[((i-1)*mb+1):i*mb], B, temp2)
v[i] = zero(T)
Y = reshape(y, (mb, ma))
if B isa UniformScalingMap
_unsafe_mul!(Y, X, transpose(A))
lmul!(B.λ, y)
else
temp = similar(Y, (ma, nb))
_unsafe_mul!(temp, A, copy(transpose(X)))
_unsafe_mul!(Y, B, transpose(temp))
end
return y
end
@inline function _kronmul!(y, B, x, At::UniformScalingMap, _)
na, ma = size(At)
@inline function _kronmul!(y, B, x, A::UniformScalingMap, _)
ma, na = size(A)
mb, nb = size(B)
iszero(A.λ) && return fill!(y, zero(eltype(y)))
X = reshape(x, (nb, na))
Y = reshape(y, (mb, ma))
_unsafe_mul!(Y, B, X, At.λ, false)
_unsafe_mul!(Y, B, X)
!isone(A.λ) && rmul!(y, A.λ)
return y
end
@inline function _kronmul!(y, B, x, At::MatrixMap, _)
na, ma = size(At)
@inline function _kronmul!(y, B, x, A::MatrixMap, _)
ma, na = size(A)
mb, nb = size(B)
X = reshape(x, (nb, na))
Y = reshape(y, (mb, ma))
At = transpose(A.lmap)
if B isa UniformScalingMap
# the following is (maybe due to the reshape?) faster than
# _unsafe_mul!(Y, B * X, At.lmap)
_unsafe_mul!(Y, X, At.lmap)
# the following is (perhaps due to the reshape?) faster than
# _unsafe_mul!(Y, B * X, At)
_unsafe_mul!(Y, X, At)
lmul!(B.λ, y)
elseif nb*ma <= mb*na
_unsafe_mul!(Y, B, X * At.lmap)
_unsafe_mul!(Y, B, X * At)
else
_unsafe_mul!(Y, Matrix(B*X), At.lmap)
_unsafe_mul!(Y, Matrix(B*X), At)
end
return y
end
Expand All @@ -163,14 +167,14 @@ const KroneckerMap2{T} = KroneckerMap{T, <:Tuple{LinearMap, LinearMap}}
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap2, x::AbstractVector)
require_one_based_indexing(y)
A, B = L.maps
_kronmul!(y, B, x, transpose(A), eltype(L))
_kronmul!(y, B, x, A, eltype(L))
return y
end
function _unsafe_mul!(y::AbstractVecOrMat, L::KroneckerMap, x::AbstractVector)
require_one_based_indexing(y)
A = first(L.maps)
B = kron(Base.tail(L.maps)...)
_kronmul!(y, B, x, transpose(A), eltype(L))
_kronmul!(y, B, x, A, eltype(L))
return y
end
# mixed-product rule, prefer the right if possible
Expand Down
8 changes: 8 additions & 0 deletions test/composition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,12 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays
@test u1 == u2
@test w1 == w2
end
L1 = LinearMap(rand(2,3))
L2 = LinearMap(rand(4,2))
L3 = LinearMap(rand(3, 4))
L4 = LinearMap(rand(5, 3))
Ls = L4*L3*L2*L1
X = rand(size(Ls, 2), 10)
Y = similar(X, (size(Ls, 1), size(X, 2)))
@test mul!(Y, Ls, X) L4.lmap * L3.lmap * L2.lmap * L1.lmap * X
end
7 changes: 6 additions & 1 deletion test/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ using Test, LinearMaps, LinearAlgebra, SparseArrays

m = 3
A = rand(m, m)
F = LinearMap(x -> A*x, m, m)
S = sparse(I, m, m)
J = LinearMap(I, m)
@test kron(J, J) == LinearMap(I, m*m)
v = rand(m^3)
for (K, M) in (((A, J, J), kron(A, S, S)),
((J, A, J), kron(S, A, S)),
((J, J, A), kron(S, S, A)))
((J, J, A), kron(S, S, A)),
((F, J, J), kron(A, S, S)),
((J, F, J), kron(S, A, S)),
((J, J, F), kron(S, S, A)))
@test K * v M * v
@test Matrix(K) M
end
Expand Down

2 comments on commit 209d6e2

@dkarrasch
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32158

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.2.4 -m "<description of version>" 209d6e237f58a4c82887b3714dd36a876e0e1b59
git push origin v3.2.4

Please sign in to comment.