Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make better use of BlockedTuple in contract logic to track codomain and domain #33

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.10"
version = "0.1.11"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
4 changes: 2 additions & 2 deletions src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
end

function blockedperm(bt::AbstractBlockTuple)
return blockedperm(Val(length(bt)), blocks(bt)...)
function blockedperm(bt::AbstractBlockTuple; length::Union{Val,Nothing}=nothing)
return blockedperm(Val(Base.length(bt)), blocks(bt)...)

Check warning on line 78 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L77-L78

Added lines #L77 - L78 were not covered by tests
end

function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
Expand Down
6 changes: 3 additions & 3 deletions src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)

permblocks_dest = (perm_codomain_dest, perm_domain_dest)
biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...)
biperm_dest = blockedperm(permblocks_dest...)

Check warning on line 25 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L25

Added line #L25 was not covered by tests
permblocks1 = (perm_codomain1, perm_domain1)
biperm1 = blockedperm(filter(!isempty, permblocks1)...)
biperm1 = blockedperm(permblocks1...)

Check warning on line 27 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L27

Added line #L27 was not covered by tests
permblocks2 = (perm_codomain2, perm_domain2)
biperm2 = blockedperm(filter(!isempty, permblocks2)...)
biperm2 = blockedperm(permblocks2...)

Check warning on line 29 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L29

Added line #L29 was not covered by tests
return biperm_dest, biperm1, biperm2
end
12 changes: 6 additions & 6 deletions src/contract/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ default_contract_alg() = Matricize()
function contract!(
alg::Algorithm,
a_dest::AbstractArray,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation{2},
α::Number,
β::Number,
)
Expand Down Expand Up @@ -110,11 +110,11 @@ end

function contract(
alg::Algorithm,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation{2},
α::Number;
kwargs...,
)
Expand Down
95 changes: 7 additions & 88 deletions src/contract/contract_matricize/contract.jl
Original file line number Diff line number Diff line change
@@ -1,103 +1,22 @@
using LinearAlgebra: mul!

function contract!(
alg::Matricize,
::Matricize,
a_dest::AbstractArray,
biperm_dest::BlockedPermutation,
biperm_dest::AbstractBlockPermutation{2},
a1::AbstractArray,
biperm1::BlockedPermutation,
biperm1::AbstractBlockPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation,
biperm2::AbstractBlockPermutation{2},
α::Number,
β::Number,
)
a_dest_mat = fusedims(a_dest, biperm_dest)
a1_mat = fusedims(a1, biperm1)
a2_mat = fusedims(a2, biperm2)
_mul!(a_dest_mat, a1_mat, a2_mat, α, β)
@assert ndims(a1_mat) == 2
@assert ndims(a2_mat) == 2
mul!(a_dest_mat, a1_mat, a2_mat, α, β)

Check warning on line 19 in src/contract/contract_matricize/contract.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/contract_matricize/contract.jl#L17-L19

Added lines #L17 - L19 were not covered by tests
splitdims!(a_dest, a_dest_mat, biperm_dest)
return a_dest
end

# Matrix multiplication.
function _mul!(
a_dest::AbstractMatrix, a1::AbstractMatrix, a2::AbstractMatrix, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

# Inner product.
function _mul!(
a_dest::AbstractArray{<:Any,0},
a1::AbstractVector,
a2::AbstractVector,
α::Number,
β::Number,
)
a_dest[] = transpose(a1) * a2 * α + a_dest[] * β
return a_dest
end

# Vec-mat.
function _mul!(
a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number
)
mul!(transpose(a_dest), transpose(a1), a2, α, β)
return a_dest
end

# Mat-vec.
function _mul!(
a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end

# Outer product.
function _mul!(
a_dest::AbstractMatrix, a1::AbstractVector, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, transpose(a2), α, β)
return a_dest
end

# Array-scalar contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractVector,
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
α′ = a2[] * α
a_dest .= a1 .* α′ .+ a_dest .* β
return a_dest
end

# Scalar-array contraction.
function _mul!(
a_dest::AbstractVector,
a1::AbstractArray{<:Any,0},
a2::AbstractVector,
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest .= a1[] .* a2 .* α .+ a_dest .* β
return a_dest
end

# Scalar-scalar contraction.
function _mul!(
a_dest::AbstractArray{<:Any,0},
a1::AbstractArray{<:Any,0},
a2::AbstractArray{<:Any,0},
α::Number,
β::Number,
)
# Preserve the ordering in case of non-commutative algebra.
a_dest[] = a1[] * a2[] * α + a_dest[] * β
return a_dest
end
12 changes: 6 additions & 6 deletions src/factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
using ArrayLayouts: LayoutMatrix
using LinearAlgebra: LinearAlgebra, Diagonal

function qr(a::AbstractArray, biperm::BlockedPermutation{2})
function qr(a::AbstractArray, biperm::AbstractBlockPermutation{2})

Check warning on line 4 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L4

Added line #L4 was not covered by tests
a_matricized = fusedims(a, biperm)
# TODO: Make this more generic, allow choosing thin or full,
# make sure this works on GPU.
q_fact, r_matricized = LinearAlgebra.qr(a_matricized)
q_matricized = typeof(a_matricized)(q_fact)
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
axes_q = (axes_codomain..., axes(q_matricized, 2))
axes_r = (axes(r_matricized, 1), axes_domain...)
axes_q = tuplemortar((axes_codomain, (axes(q_matricized, 2),)))
axes_r = tuplemortar(((axes(r_matricized, 1),), axes_domain))

Check warning on line 12 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L11-L12

Added lines #L11 - L12 were not covered by tests
q = splitdims(q_matricized, axes_q)
r = splitdims(r_matricized, axes_r)
return q, r
Expand All @@ -22,15 +22,15 @@
)
end

function svd(a::AbstractArray, biperm::BlockedPermutation{2})
function svd(a::AbstractArray, biperm::AbstractBlockPermutation{2})

Check warning on line 25 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L25

Added line #L25 was not covered by tests
a_matricized = fusedims(a, biperm)
usv_matricized = LinearAlgebra.svd(a_matricized)
u_matricized = usv_matricized.U
s_diag = usv_matricized.S
v_matricized = usv_matricized.Vt
axes_codomain, axes_domain = blockpermute(axes(a), biperm)
axes_u = (axes_codomain..., axes(u_matricized, 2))
axes_v = (axes(v_matricized, 1), axes_domain...)
axes_u = tuplemortar((axes_codomain, (axes(u_matricized, 2),)))
axes_v = tuplemortar(((axes(v_matricized, 1),), axes_domain))

Check warning on line 33 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L32-L33

Added lines #L32 - L33 were not covered by tests
u = splitdims(u_matricized, axes_u)
# TODO: Use `DiagonalArrays.diagonal` to make it more general.
s = Diagonal(s_diag)
Expand Down
23 changes: 13 additions & 10 deletions src/fusedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
combine_fusion_styles(style1::FusionStyle, style2::FusionStyle) = ReshapeFusion()
combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles)
FusionStyle(axis::AbstractUnitRange) = ReshapeFusion()
FusionStyle(::Tuple{}) = ReshapeFusion()

Check warning on line 14 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L14

Added line #L14 was not covered by tests
function FusionStyle(axes::Tuple{Vararg{AbstractUnitRange}})
return combine_fusion_styles(FusionStyle.(axes)...)
end
Expand All @@ -33,7 +34,6 @@
return fusedims(FusionStyle(a), a, ax, axes...)
end

# Overload this version for fusion tensors, array maps, etc.
function fusedims(
a::AbstractArray,
axb::Tuple{Vararg{AbstractUnitRange}},
Expand All @@ -42,14 +42,6 @@
return fusedims(a, flatten_tuples((axb, axesblocks...))...)
end

# Fix ambiguity issue
fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a

# TODO: Is this needed? Maybe delete.
function fusedims(a::AbstractArray, permblocks...)
return fusedims(a, blockedperm(permblocks...; length=Val(ndims(a))))
end

function fuseaxes(
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation
)
Expand All @@ -67,7 +59,18 @@
return fusedims(a, axes_fused)
end

function fusedims(a::AbstractArray, blockedperm::BlockedPermutation)
# deal with zero-dim case
fusedims(a::AbstractArray{<:Any,0}, t::Tuple{}...) = reshape(a, ntuple(_ -> 1, length(t)))

Check warning on line 63 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L63

Added line #L63 was not covered by tests

function fusedims(a::AbstractArray, blockedperm::AbstractBlockPermutation)

Check warning on line 65 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L65

Added line #L65 was not covered by tests
# TBD define permutedims(::AbstractArray, ::AbstractBlockPermutation)
# TBD remove call to BlockedTrivialPermutation?
a_perm = _permutedims(a, Tuple(blockedperm))
return fusedims(a_perm, trivialperm(blockedperm))
end

# fusedims(ones((2,2,2,2)), (3, 1, 2), (4,))
# fusedims(ones((2,2,2,2)), (3, 1, 2), 4)
function fusedims(a::AbstractArray, permblocks...)
return fusedims(a, blockedperm(permblocks...; length=Val(ndims(a))))

Check warning on line 75 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L74-L75

Added lines #L74 - L75 were not covered by tests
end
53 changes: 28 additions & 25 deletions src/splitdims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,65 +4,68 @@
to_axis(n::Integer) = Base.OneTo(n)

function blockedaxes(a::AbstractArray, sizeblocks::Pair...)
axes_a = axes(a)
axes_split = tuple.(axes(a))
for (dim, sizeblock) in sizeblocks
# TODO: Handle conversion from length to range!
axes_split = Base.setindex(axes_split, to_axis.(sizeblock), dim)
end
return axes_split
return tuplemortar(axes_split)

Check warning on line 12 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L12

Added line #L12 was not covered by tests
end

# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2)
function splitdims(::ReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...)
function splitdims(::ReshapeFusion, a::AbstractArray, abt::BlockedTuple)

Check warning on line 15 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L15

Added line #L15 was not covered by tests
# TODO: Add `uncanonicalizedims`.
# TODO: Need `length` since `reshape` doesn't accept `axes`,
# maybe make a `reshape_axes` function.
return reshape(a, length.(axes)...)
return reshape(a, Tuple(length.(abt)))

Check warning on line 19 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L19

Added line #L19 was not covered by tests
end

# ambiguity for zero-dim
function splitdims(a::AbstractArray{<:Any,N}, abt::BlockedTuple{N,<:Any,Tuple{}}) where {N}
return splitdims(FusionStyle(a), a, abt)

Check warning on line 24 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
end

function splitdims(

Check warning on line 27 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L27

Added line #L27 was not covered by tests
a::AbstractArray{<:Any,N}, bt::BlockedTuple{N,<:Any,<:Tuple{Vararg{AbstractUnitRange}}}
) where {N}
return splitdims(FusionStyle(a), a, bt)

Check warning on line 30 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L30

Added line #L30 was not covered by tests
end

# splitdims(randn(4, 4), 1:2, 1:2, 1:2, 1:2)
function splitdims(a::AbstractArray, axes::AbstractUnitRange...)
return splitdims(FusionStyle(a), a, axes...)
return splitdims(a, tuple.(axes)...)

Check warning on line 35 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L35

Added line #L35 was not covered by tests
end

# splitdims(randn(4, 4), (1:2, 1:2), (1:2, 1:2))
function splitdims(a::AbstractArray, axesblocks::Tuple{Vararg{AbstractUnitRange}}...)
# TODO: Add `uncanonicalizedims`.
return splitdims(a, flatten_tuples(axesblocks)...)
return splitdims(a, tuplemortar(axesblocks))

Check warning on line 41 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L41

Added line #L41 was not covered by tests
end

# Fix ambiguity issue
splitdims(a::AbstractArray) = a

# splitdims(randn(4, 4), (2, 2), (2, 2))
function splitdims(a::AbstractArray, sizeblocks::Tuple{Vararg{Integer}}...)
return splitdims(a, map(x -> Base.OneTo.(x), sizeblocks)...)
return splitdims(a, tuplemortar(sizeblocks))

Check warning on line 49 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L49

Added line #L49 was not covered by tests
end

# splitdims(randn(4, 4), 2 => (1:2, 1:2))
function splitdims(a::AbstractArray, sizeblocks::Pair...)
return splitdims(a, blockedaxes(a, sizeblocks...)...)
# splitdims(randn(4, 4), tuplemortar(((2, 2), (2, 2))))
function splitdims(

Check warning on line 53 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L53

Added line #L53 was not covered by tests
a::AbstractArray{<:Any,N}, bt::BlockedTuple{N,<:Any,<:Tuple{Vararg{Integer}}}
) where {N}
return splitdims(a, to_axis.(bt))

Check warning on line 56 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L56

Added line #L56 was not covered by tests
end

# TODO: Is this needed?
function splitdims(
a::AbstractArray,
axes_dest::Tuple{Vararg{AbstractUnitRange}},
blockedperm::BlockedPermutation,
)
# TODO: Pass grouped axes.
a_dest_perm = splitdims(a, axes_dest...)
a_dest = _permutedims(a_dest_perm, invperm(Tuple(blockedperm)))
return a_dest
# splitdims(randn(4, 4), 2 => (1:2, 1:2))
function splitdims(a::AbstractArray, sizeblocks::Pair...)
return splitdims(a, blockedaxes(a, sizeblocks...))

Check warning on line 61 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L60-L61

Added lines #L60 - L61 were not covered by tests
end

function splitdims!(
a_dest::AbstractArray, a::AbstractArray, blockedperm::BlockedPermutation
a_dest::AbstractArray, a::AbstractArray, blockedperm::AbstractBlockPermutation
)
axes_dest = map(i -> axes(a_dest, i), Tuple(blockedperm))
# TODO: Pass grouped axes.
a_dest_perm = splitdims(a, axes_dest...)
axes_dest = map(i -> axes(a_dest, i), blockedperm)
a_dest_perm = splitdims(a, axes_dest)

Check warning on line 68 in src/splitdims.jl

View check run for this annotation

Codecov / codecov/patch

src/splitdims.jl#L67-L68

Added lines #L67 - L68 were not covered by tests
_permutedims!(a_dest, a_dest_perm, invperm(Tuple(blockedperm)))
return a_dest
end
Loading
Loading