Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support transtype for DualLayout{ApplyLayout{typeof(*)}} #335

Merged
merged 11 commits into from
Jul 8, 2024
1 change: 1 addition & 0 deletions .github/workflows/downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ jobs:
- {repo: InfiniteArrays.jl, group: JuliaArrays}
- {repo: QuasiArrays.jl, group: JuliaApproximation}
- {repo: ContinuumArrays.jl, group: JuliaApproximation}
- {repo: ClassicalOrthogonalPolynomials.jl, group: JuliaApproximation}

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LazyArrays"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "2.1.4"
version = "2.1.5"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
49 changes: 32 additions & 17 deletions ext/LazyArraysBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using LazyArrays.ArrayLayouts, LazyArrays.FillArrays, LazyArrays.LazyArrays
import ArrayLayouts: colsupport, rowsupport, materialize!, MatMulVecAdd, MatMulMatAdd, DenseColumnMajor,
OnesLayout, AbstractFillLayout, mulreduce, inv_layout, _fill_lmul!, copyto!_layout, _copy_oftype,
layout_getindex
layout_getindex, transtype
import LazyArrays: sublayout, symmetriclayout, hermitianlayout, applylayout, cachedlayout, transposelayout,
LazyArrayStyle, ApplyArrayBroadcastStyle, AbstractInvLayout, AbstractLazyLayout, LazyLayouts,
AbstractPaddedLayout, PaddedLayout, PaddedRows, PaddedColumns, CachedArray, CachedMatrix, LazyLayout, BroadcastLayout, ApplyLayout,
Expand Down Expand Up @@ -132,22 +132,30 @@
# Vcat(Zeros{T}(max(0,j-u-1)), view(data, (kr .- j .+ (u+1)) ∩ axes(data,1), j))
# end

function similar(M::MulAdd{<:BandedLayouts,<:AbstractPaddedLayout}, ::Type{T}, axes::Tuple{Any}) where T
function similar(M::MulAdd{<:BandedLayouts,<:Union{PaddedColumns,PaddedLayout}}, ::Type{T}, axes::Tuple{Any}) where T

Check warning on line 135 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L135

Added line #L135 was not covered by tests
A,x = M.A,M.B
xf = paddeddata(x)
n = max(0,min(length(xf) + bandwidth(A,1),length(M)))
Vcat(Vector{T}(undef, n), Zeros{T}(size(A,1)-n))
end

function similar(M::MulAdd{<:BandedLayouts,<:AbstractPaddedLayout}, ::Type{T}, axes::Tuple{Any,Any}) where T
function similar(M::MulAdd{<:BandedLayouts,<:Union{PaddedColumns,PaddedLayout}}, ::Type{T}, axes::Tuple{Any,Any}) where T

Check warning on line 142 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L142

Added line #L142 was not covered by tests
A,x = M.A,M.B
xf = paddeddata(x)
m = max(0,min(size(xf,1) + bandwidth(A,1),size(M,1)))
n = size(xf,2)
PaddedArray(Matrix{T}(undef, m, n), size(A,1), size(x,2))
end

function materialize!(M::MatMulVecAdd{<:BandedLayouts,<:AbstractPaddedLayout,<:AbstractPaddedLayout})
function similar(M::MulAdd{<:DualLayout{<:PaddedRows}, <:BandedLayouts}, ::Type{T}, axes::Tuple{Any,Any}) where T
xt,A = M.A,M.B
trans = transtype(xt)
xf = paddeddata(trans(xt))
n = max(0,min(length(xf) + bandwidth(A,2),size(M,2)))
trans(Vcat(Vector{T}(undef, n), Zeros{T}(size(A,1)-n)))

Check warning on line 155 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L150-L155

Added lines #L150 - L155 were not covered by tests
end

function materialize!(M::MatMulVecAdd{<:BandedLayouts,<:Union{PaddedColumns,PaddedLayout},<:Union{PaddedColumns,PaddedLayout}})

Check warning on line 158 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L158

Added line #L158 was not covered by tests
α,A,x,β,y = M.α,M.A,M.B,M.β,M.C
length(y) == size(A,1) || throw(DimensionMismatch())
length(x) == size(A,2) || throw(DimensionMismatch())
Expand All @@ -160,7 +168,7 @@
y
end

function materialize!(M::MatMulMatAdd{<:BandedLayouts,<:AbstractPaddedLayout,<:AbstractPaddedLayout})
function materialize!(M::MatMulMatAdd{<:BandedLayouts,<:Union{PaddedColumns,PaddedLayout},<:Union{PaddedColumns,PaddedLayout}})

Check warning on line 171 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L171

Added line #L171 was not covered by tests
α,A,x,β,y = M.α,M.A,M.B,M.β,M.C
size(y) == (size(A,1),size(x,2)) || throw(DimensionMismatch())
size(x,1) == size(A,2) || throw(DimensionMismatch())
Expand Down Expand Up @@ -293,7 +301,7 @@
broadcasted(::LazyArrayStyle, ::typeof(/), A::BandedMatrix, c::Number) = _BandedMatrix(A.data ./ c, A.raxis, A.l, A.u)


copy(M::Mul{BroadcastBandedLayout{typeof(*)}, <:AbstractPaddedLayout}) = _broadcast_banded_padded_mul(arguments(BroadcastBandedLayout{typeof(*)}(), M.A), M.B)
copy(M::Mul{BroadcastBandedLayout{typeof(*)}, <:Union{PaddedColumns,PaddedLayout}}) = _broadcast_banded_padded_mul(arguments(BroadcastBandedLayout{typeof(*)}(), M.A), M.B)

Check warning on line 304 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L304

Added line #L304 was not covered by tests


###
Expand Down Expand Up @@ -585,6 +593,7 @@
copy(M::Mul{<:AbstractInvLayout, <:BandedLazyLayouts}) = simplify(M)



copy(L::Ldiv{<:BandedLazyLayouts}) = lazymaterialize(\, L.A, L.B)
copy(L::Ldiv{<:BandedLazyLayouts,<:AbstractLazyLayout}) = lazymaterialize(\, L.A, L.B)
copy(L::Ldiv{<:BandedLazyLayouts, Blay}) where Blay<:Union{AbstractStridedLayout,PaddedColumns} = copy(Ldiv{UnknownLayout,Blay}(L.A, L.B))
Expand Down Expand Up @@ -619,18 +628,24 @@
copy(M::Mul{BroadcastLayout{typeof(*)}, <:BroadcastBandedLayout}) = lazymaterialize(*, M.A, M.B)

## padded copy
mulreduce(M::Mul{<:BroadcastBandedLayout, <:Union{AbstractPaddedLayout,AbstractStridedLayout}}) = MulAdd(M)
mulreduce(M::Mul{ApplyBandedLayout{F}, D}) where {F,D<:Union{AbstractPaddedLayout,AbstractStridedLayout}} = Mul{ApplyLayout{F},D}(M.A, M.B)
mulreduce(M::Mul{<:BroadcastBandedLayout, <:Union{PaddedColumns,PaddedLayout,AbstractStridedLayout}}) = MulAdd(M)
mulreduce(M::Mul{ApplyBandedLayout{F}, D}) where {F,D<:Union{PaddedColumns,PaddedLayout,AbstractStridedLayout}} = Mul{ApplyLayout{F},D}(M.A, M.B)

Check warning on line 632 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L631-L632

Added lines #L631 - L632 were not covered by tests
# need to overload copy due to above
copy(M::Mul{<:BroadcastBandedLayout, <:Union{AbstractPaddedLayout,AbstractStridedLayout}}) = copy(mulreduce(M))
copy(M::Mul{<:AbstractInvLayout{<:BandedLazyLayouts}, <:Union{AbstractPaddedLayout,AbstractStridedLayout}}) = ArrayLayouts.ldiv(pinv(M.A), M.B)
copy(M::Mul{<:BandedLazyLayouts, <:Union{AbstractPaddedLayout,AbstractStridedLayout}}) = copy(mulreduce(M))
copy(M::Mul{<:Union{AbstractPaddedLayout,AbstractStridedLayout,DualLayout{<:PaddedRows}}, <:BandedLazyLayouts}) = copy(mulreduce(M))

simplifiable(M::Mul{<:AbstractInvLayout{<:BandedLazyLayouts}, <:Union{AbstractPaddedLayout,AbstractStridedLayout}}) = Val(true)
simplifiable(M::Mul{<:Union{AbstractPaddedLayout,AbstractStridedLayout}, <:BandedLazyLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLazyLayouts, <:Union{AbstractPaddedLayout,AbstractStridedLayout}}) = Val(true)
simplifiable(::Mul{<:BroadcastBandedLayout, <:Union{AbstractPaddedLayout,AbstractStridedLayout}}) = Val(true)
copy(M::Mul{<:BroadcastBandedLayout, <:Union{PaddedColumns,PaddedLayout}}) = copy(mulreduce(M))
copy(M::Mul{<:BroadcastBandedLayout, <:AbstractStridedLayout}) = copy(mulreduce(M))
copy(M::Mul{<:AbstractInvLayout{<:BandedLazyLayouts}, <:Union{PaddedColumns,PaddedLayout,AbstractStridedLayout}}) = ArrayLayouts.ldiv(pinv(M.A), M.B)
copy(M::Mul{<:BandedLayouts, <:Union{PaddedColumns,PaddedLayout}}) = copy(mulreduce(M))
copy(M::Mul{<:BandedLazyLayouts, <:Union{PaddedColumns,PaddedLayout}}) = copy(mulreduce(M))
copy(M::Mul{<:BandedLazyLayouts, <:AbstractStridedLayout}) = copy(mulreduce(M))
copy(M::Mul{<:Union{PaddedRows,PaddedLayout,DualLayout{<:PaddedRows}}, <:BandedLayouts}) = copy(mulreduce(M))
copy(M::Mul{<:Union{PaddedRows,PaddedLayout,DualLayout{<:PaddedRows}}, <:BandedLazyLayouts}) = copy(mulreduce(M))
copy(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLazyLayouts}) = copy(mulreduce(M))

Check warning on line 642 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L634-L642

Added lines #L634 - L642 were not covered by tests

simplifiable(M::Mul{<:AbstractInvLayout{<:BandedLayouts}, <:Union{PaddedColumns,PaddedLayout,AbstractStridedLayout}}) = Val(true)
simplifiable(M::Mul{<:Union{PaddedRows,PaddedLayout,DualLayout{<:PaddedRows}}, <:BandedLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLayouts, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLazyLayouts}) = Val(true)
simplifiable(M::Mul{<:BandedLazyLayouts, <:AbstractStridedLayout}) = Val(true)

Check warning on line 648 in ext/LazyArraysBandedMatricesExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LazyArraysBandedMatricesExt.jl#L644-L648

Added lines #L644 - L648 were not covered by tests

copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where Lay = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where {Lay<:AbstractLazyLayout} = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
Expand Down
2 changes: 1 addition & 1 deletion src/LazyArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import ArrayLayouts: AbstractQLayout, Dot, Dotu, Ldiv, Lmul, MatMulMatAdd, MatMu
hermitianlayout, layout_getindex, layout_replace_in_print_matrix, ldivaxes, materialize,
materialize!, mulreduce, reshapedlayout, rowsupport, scalarone, scalarzero, sub_materialize,
sublayout, symmetriclayout, symtridiagonallayout, transposelayout, triangulardata,
triangularlayout, tridiagonallayout, zero!
triangularlayout, tridiagonallayout, zero!, transtype

import FillArrays: AbstractFill, getindex_value

Expand Down
23 changes: 18 additions & 5 deletions src/linalg/add.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,24 @@

for op in (:+, :-)
@eval begin
simplify(M::Mul{Lay}) where Lay<:BroadcastLayout{typeof($op)} = broadcast($op, _broadcasted_mul(arguments(Lay(), M.A), M.B)...)
simplify(M::Mul{<:Any,Lay}) where Lay<:BroadcastLayout{typeof($op)} = broadcast($op, _broadcasted_mul(M.A, arguments(Lay(), M.B))...)
simplify(M::Mul{Lay,Lay}) where Lay<:BroadcastLayout{typeof($op)} = simplify(Mul{Lay,UnknownLayout}(M.A, M.B))
simplifiable(M::Mul{<:BroadcastLayout{typeof($op)}}) = Val(true)
simplifiable(M::Mul{<:Any,<:BroadcastLayout{typeof($op)}}) = Val(true)
simplifiable(M::Mul{<:BroadcastLayout{typeof($op)},<:BroadcastLayout{typeof($op)}}) = simplifiable(Mul{BroadcastLayout{typeof($op)},UnknownLayout}(M.A, M.B))
copy(M::Mul{Lay}) where Lay<:BroadcastLayout{typeof($op)} = broadcast($op, _broadcasted_mul(arguments(Lay(), M.A), M.B)...)
copy(M::Mul{Lay,<:LazyLayouts}) where Lay<:BroadcastLayout{typeof($op)} = broadcast($op, _broadcasted_mul(arguments(Lay(), M.A), M.B)...)
copy(M::Mul{<:Any,Lay}) where Lay<:BroadcastLayout{typeof($op)} = broadcast($op, _broadcasted_mul(M.A, arguments(Lay(), M.B))...)
copy(M::Mul{<:LazyLayouts,Lay}) where Lay<:BroadcastLayout{typeof($op)} = broadcast($op, _broadcasted_mul(M.A, arguments(Lay(), M.B))...)
copy(M::Mul{Lay,Lay}) where Lay<:BroadcastLayout{typeof($op)} = copy(Mul{Lay,UnknownLayout}(M.A, M.B))
simplify(M::Mul{Lay}) where Lay<:BroadcastLayout{typeof($op)} = copy(Mul{Lay,UnknownLayout}(M.A, M.B)) # TODO: remove, here for back-compat with QuasiArrays.jl
simplify(M::Mul{<:Any,Lay}) where Lay<:BroadcastLayout{typeof($op)} = copy(Mul{UnknownLayout,Lay}(M.A, M.B)) # TODO: remove, here for back-compat with QuasiArrays.jl
simplify(M::Mul{Lay,Lay}) where Lay<:BroadcastLayout{typeof($op)} = copy(M) # TODO: remove, here for back-compat with QuasiArrays.jl

Check warning on line 115 in src/linalg/add.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg/add.jl#L105-L115

Added lines #L105 - L115 were not covered by tests
end
end

simplify(M::Mul{Lay,BroadcastLayout{typeof(-)}}) where Lay<:BroadcastLayout{typeof(+)} = simplify(Mul{Lay,UnknownLayout}(M.A, M.B))
simplify(M::Mul{Lay,BroadcastLayout{typeof(+)}}) where Lay<:BroadcastLayout{typeof(-)} = simplify(Mul{Lay,UnknownLayout}(M.A, M.B))
simplifiable(M::Mul{Lay,BroadcastLayout{typeof(-)}}) where Lay<:BroadcastLayout{typeof(+)} = simplifiable(Mul{Lay,UnknownLayout}(M.A, M.B))
simplifiable(M::Mul{Lay,BroadcastLayout{typeof(+)}}) where Lay<:BroadcastLayout{typeof(-)} = simplifiable(Mul{Lay,UnknownLayout}(M.A, M.B))

Check warning on line 120 in src/linalg/add.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg/add.jl#L119-L120

Added lines #L119 - L120 were not covered by tests

copy(M::Mul{Lay,BroadcastLayout{typeof(-)}}) where Lay<:BroadcastLayout{typeof(+)} = copy(Mul{Lay,UnknownLayout}(M.A, M.B))
copy(M::Mul{Lay,BroadcastLayout{typeof(+)}}) where Lay<:BroadcastLayout{typeof(-)} = copy(Mul{Lay,UnknownLayout}(M.A, M.B))
simplify(M::Mul{Lay,BroadcastLayout{typeof(-)}}) where Lay<:BroadcastLayout{typeof(+)} = copy(Mul{Lay,UnknownLayout}(M.A, M.B)) # TODO: remove, here for back-compat with QuasiArrays.jl
simplify(M::Mul{Lay,BroadcastLayout{typeof(+)}}) where Lay<:BroadcastLayout{typeof(-)} = copy(Mul{Lay,UnknownLayout}(M.A, M.B)) # TODO: remove, here for back-compat with QuasiArrays.jl

Check warning on line 125 in src/linalg/add.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg/add.jl#L122-L125

Added lines #L122 - L125 were not covered by tests
1 change: 1 addition & 0 deletions src/linalg/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,7 @@
@inline copy(M::Mul{<:DualLayout,<:LazyLayouts,<:AbstractMatrix,<:AbstractVector}) = copy(Dot(M))

applylayout(::Type{typeof(*)}, ::DualLayout{Lay}, args...) where Lay = DualLayout{typeof(applylayout(typeof(*), Lay(), args...))}()
transtype(A::MulMatrix) = transtype(first(A.args))

Check warning on line 371 in src/linalg/mul.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg/mul.jl#L371

Added line #L371 was not covered by tests

#TODO: Why not all DiagonalLayout?
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
Expand Down
28 changes: 25 additions & 3 deletions src/padded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,33 @@
end

# avoid ambiguity in LazyBandedMatrices
copy(M::Mul{<:DiagonalLayout,<:AbstractPaddedLayout}) = copy(Lmul(M))
copy(M::Mul{<:Union{TriangularLayout{'U', 'N', <:AbstractLazyLayout}, TriangularLayout{'U', 'U', <:AbstractLazyLayout}}, <:AbstractPaddedLayout}) = copy(Lmul(M))
simplifiable(::Mul{<:Union{TriangularLayout{'U', 'N', <:AbstractLazyLayout}, TriangularLayout{'U', 'U', <:AbstractLazyLayout}}, <:AbstractPaddedLayout}) = Val(true)
copy(M::Mul{<:DiagonalLayout,<:Union{PaddedColumns,PaddedLayout}}) = copy(Lmul(M))
copy(M::Mul{<:Union{TriangularLayout{'U', 'N', <:AbstractLazyLayout}, TriangularLayout{'U', 'U', <:AbstractLazyLayout}}, <:Union{PaddedColumns,PaddedLayout}}) = copy(Lmul(M))
simplifiable(::Mul{<:Union{TriangularLayout{'U', 'N', <:AbstractLazyLayout}, TriangularLayout{'U', 'U', <:AbstractLazyLayout}}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)

Check warning on line 510 in src/padded.jl

View check run for this annotation

Codecov / codecov/patch

src/padded.jl#L508-L510

Added lines #L508 - L510 were not covered by tests


simplifiable(::Mul{<:DualLayout{<:AbstractLazyLayout}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
copy(M::Mul{<:DualLayout{<:AbstractLazyLayout}, <:Union{PaddedColumns,PaddedLayout}}) = copy(mulreduce(M))
simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)

Check warning on line 515 in src/padded.jl

View check run for this annotation

Codecov / codecov/patch

src/padded.jl#L513-L515

Added lines #L513 - L515 were not covered by tests

function simplifiable(M::Mul{<:DualLayout{<:PaddedRows}, <:LazyLayouts})
trans = transtype(M.A)
simplifiable(*, trans(M.B), trans(M.A))

Check warning on line 519 in src/padded.jl

View check run for this annotation

Codecov / codecov/patch

src/padded.jl#L517-L519

Added lines #L517 - L519 were not covered by tests
end
function copy(M::Mul{<:DualLayout{<:PaddedRows}, <:LazyLayouts})
trans = transtype(M.A)
trans(trans(M.B) * trans(M.A))

Check warning on line 523 in src/padded.jl

View check run for this annotation

Codecov / codecov/patch

src/padded.jl#L521-L523

Added lines #L521 - L523 were not covered by tests
end

for op in (:+, :-)
@eval begin
simplifiable(M::Mul{<:BroadcastLayout{typeof($op)},<:Union{PaddedColumns,PaddedLayout}}) = Val(true)
simplifiable(M::Mul{<:DualLayout{<:PaddedRows},<:BroadcastLayout{typeof($op)}}) = Val(true)
copy(M::Mul{Lay,<:Union{PaddedColumns,PaddedLayout}}) where Lay<:BroadcastLayout{typeof($op)} = copy(Mul{Lay,UnknownLayout}(M.A, M.B))
copy(M::Mul{<:DualLayout{<:PaddedRows},Lay}) where Lay<:BroadcastLayout{typeof($op)} = copy(Mul{UnknownLayout,Lay}(M.A, M.B))

Check warning on line 531 in src/padded.jl

View check run for this annotation

Codecov / codecov/patch

src/padded.jl#L528-L531

Added lines #L528 - L531 were not covered by tests
end
end


# Triangular columns

Expand Down
30 changes: 29 additions & 1 deletion test/addtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module AddTests

using LazyArrays, Test
using LinearAlgebra
import LazyArrays: Add, AddArray, MulAdd, materialize!, MemoryLayout, ApplyLayout
import LazyArrays: Add, AddArray, MulAdd, materialize!, MemoryLayout, ApplyLayout, simplifiable, simplify

@testset "Add/Subtract" begin
@testset "Add" begin
Expand Down Expand Up @@ -302,6 +302,34 @@ import LazyArrays: Add, AddArray, MulAdd, materialize!, MemoryLayout, ApplyLayou
@test B*C ≈ 3A * (-A)
@test C*B ≈ (-A) * 3A
end

@testset "simplifiable" begin
A = randn(5,5)
B = BroadcastArray(+, A, 2A)
C = BroadcastArray(-, A, 2A)
D = ApplyArray(exp, A)
@test simplifiable(*, A, B) == Val(true)
@test simplifiable(*, B, A) == Val(true)
@test simplifiable(*, A, C) == Val(true)
@test simplifiable(*, C, A) == Val(true)
@test simplifiable(*, B, C) == Val(true)
@test simplifiable(*, C, B) == Val(true)
@test simplifiable(*, B, B) == Val(true)
@test simplifiable(*, C, C) == Val(true)
@test simplifiable(*, B, D) == Val(true)
@test simplifiable(*, D, B) == Val(true)

@test A*B ≈ simplify(Mul(A,B)) ≈ A * Matrix(B)
@test B*A ≈ simplify(Mul(B,A)) ≈ Matrix(B)A
@test A*C ≈ simplify(Mul(A,C)) ≈ A * Matrix(C)
@test C*A ≈ simplify(Mul(C,A)) ≈ Matrix(C)A
@test B*C ≈ simplify(Mul(B,C)) ≈ B * Matrix(C)
@test C*B ≈ simplify(Mul(C,B)) ≈ Matrix(C)B
@test B*B ≈ simplify(Mul(B,B)) ≈ Matrix(B)B
@test C*C ≈ simplify(Mul(C,C)) ≈ Matrix(C)C
@test B*D ≈ simplify(Mul(B,D)) ≈ Matrix(B)*D
@test D*B ≈ simplify(Mul(D,B)) ≈ D*Matrix(B)
end
end
end # testset

Expand Down
44 changes: 44 additions & 0 deletions test/bandedtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,50 @@ LinearAlgebra.lmul!(β::Number, A::PseudoBandedMatrix) = (lmul!(β, A.data); A)
@test A[band(1)] == [A[i, i+1] for i in 1:7]
@test A[band(-1)] == [A[i+1, i] for i in 1:8]
end

@testset "Banded * Padded" begin
n = 10
A = _BandedMatrix(MyLazyArray(randn(3,n)),n,1,1)
B = brand(n,n,1,1)
C = BroadcastArray(+, B)
x = Vcat([1,2,3], Zeros(n-3))
y = randn(n)
@test A*x isa Vcat
@test B*x isa Vcat
@test C*x isa Vcat
@test simplifiable(*, A, x) == Val(true)
@test simplifiable(*, B, x) == Val(true)
@test simplifiable(*, C, x) == Val(true)

@test A*y isa Vector
@test C*y isa Vector
@test simplifiable(*, A, y) == Val(true)
@test simplifiable(*, C, y) == Val(true)

@test x'A isa Adjoint{<:Any,<:Vcat}
@test x'B isa Adjoint{<:Any,<:Vcat}
@test x'C isa Adjoint{<:Any,<:Vcat}
@test transpose(x)A isa Transpose{<:Any,<:Vcat}
@test transpose(x)B isa Transpose{<:Any,<:Vcat}
@test transpose(x)C isa Transpose{<:Any,<:Vcat}
@test simplifiable(*, x', A) == Val(true)
@test simplifiable(*, x', B) == Val(true)
@test simplifiable(*, x', C) == Val(true)
@test simplifiable(*, transpose(x), A) == Val(true)
@test simplifiable(*, transpose(x), B) == Val(true)
@test simplifiable(*, transpose(x), C) == Val(true)
@test x'A ≈ x'Matrix(A)
@test x'B ≈ x'Matrix(B)
@test x'C ≈ x'Matrix(C)

@test y'A isa Adjoint{<:Any,<:Vector}
@test y'C isa Adjoint{<:Any,<:Vector}
@test simplifiable(*, y', A) == Val(true)
@test simplifiable(*, y', B) == Val(true)
@test simplifiable(*, y', C) == Val(true)
@test y'A ≈ y'Matrix(A)
@test y'C ≈ y'Matrix(C)
end
end

end # module
5 changes: 5 additions & 0 deletions test/lazymultests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,11 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
@test (UpperTriangular(A)UpperTriangular(A)) * b isa Vector
@test (UpperTriangular(A)UpperTriangular(A)) * b ≈UpperTriangular(A.A)^2 * b
end

@testset "DualLayout{ApplyLayout{typeof(*)}} translayout" begin
@test ApplyArray(*, (1:2)', [1 2; 3 4]) * [1 2; 3 4] ≈ [37,54]'
@test ApplyArray(*, (1:2)', [1 2; 3 4]) * [1 2; 3 4] isa Adjoint
end
end

end # module
26 changes: 25 additions & 1 deletion test/paddedtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module PaddedTests

using LazyArrays, FillArrays, ArrayLayouts, Base64, Test
using StaticArrays
import LazyArrays: PaddedLayout, PaddedRows, PaddedColumns, LayoutVector, MemoryLayout, paddeddata, ApplyLayout, sub_materialize, CachedVector
import LazyArrays: PaddedLayout, PaddedRows, PaddedColumns, LayoutVector, MemoryLayout, paddeddata, ApplyLayout, sub_materialize, CachedVector, simplifiable
import ArrayLayouts: OnesLayout
import Base: setindex
using LinearAlgebra
Expand Down Expand Up @@ -396,6 +396,30 @@ paddeddata(a::PaddedPadded) = a
@test H[:,1:10] == [1 zeros(9)']
@test H[:,2:10] == zeros(9)'
end

@testset "Mul simplifiable" begin
a = Vcat(5, 1:7)
b = Vcat([1,2], Zeros(6))
@test a'b == b'a == Vector(a)'b
@test simplifiable(*, a', b) == Val(true)
@test simplifiable(*, b', a) == Val(true)

D = Diagonal(Fill(2,8))
@test D*b isa Vcat
@test simplifiable(*, D, b) == Val(true)

B = BroadcastArray(+, 1:8, (2:9)')
C = ApplyArray(exp, randn(8,8))
@test B'b == Matrix(B)'b
@test b'B == b'Matrix(B)
@test simplifiable(*, B', b) == Val(true)
@test simplifiable(*, b', B) == Val(true)
@test simplifiable(*, C', b) == Val(false)
@test simplifiable(*, b', C) == Val(false)

@test C'b ≈ Matrix(C)'b
@test b'C ≈ b'Matrix(C)
end
end

end # module
Loading