-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhancement] Add singular value decomposition (#16)
- Loading branch information
Showing
10 changed files
with
448 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,14 @@ | ||
name = "BlockSparseArrays" | ||
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" | ||
authors = ["ITensor developers <[email protected]> and contributors"] | ||
version = "0.2.6" | ||
version = "0.2.7" | ||
|
||
[deps] | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" | ||
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" | ||
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" | ||
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" | ||
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5" | ||
|
@@ -31,7 +32,8 @@ Adapt = "4.1.1" | |
Aqua = "0.8.9" | ||
ArrayLayouts = "1.10.4" | ||
BlockArrays = "1.2.0" | ||
DerivableInterfaces = "0.3.7" | ||
DerivableInterfaces = "0.3.8" | ||
DiagonalArrays = "0.2.2" | ||
Dictionaries = "0.4.3" | ||
FillArrays = "1.13.0" | ||
GPUArraysCore = "0.1.0, 0.2" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Library | ||
|
||
```@autodocs | ||
Modules = [BlockSparseArrays] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,84 @@ | ||
const AbstractBlockSparseMatrix{T} = AbstractBlockSparseArray{T,2} | ||
|
||
# SVD is implemented by trying to | ||
# 1. Attempt to find a block-diagonal implementation by permuting | ||
# 2. Fallback to AbstractBlockArray implementation via BlockedArray | ||
|
||
function eigencopy_oftype(A::AbstractBlockSparseMatrix, T) | ||
if is_block_permutation_matrix(A) | ||
Acopy = similar(A, T) | ||
for bI in eachblockstoredindex(A) | ||
Acopy[bI] = eigencopy_oftype(@view!(A[bI]), T) | ||
end | ||
return Acopy | ||
else | ||
return BlockedMatrix{T}(A) | ||
end | ||
end | ||
|
||
function is_block_permutation_matrix(a::AbstractBlockSparseMatrix) | ||
return allunique(first ∘ Tuple, eachblockstoredindex(a)) && | ||
allunique(last ∘ Tuple, eachblockstoredindex(a)) | ||
end | ||
|
||
function _allocate_svd_output(A::AbstractBlockSparseMatrix, full::Bool, ::Algorithm) | ||
@assert !full "TODO" | ||
bm, bn = blocksize(A) | ||
bmn = min(bm, bn) | ||
|
||
brows = blocklengths(axes(A, 1)) | ||
bcols = blocklengths(axes(A, 2)) | ||
slengths = Vector{Int}(undef, bmn) | ||
|
||
# fill in values for blocks that are present | ||
bIs = collect(eachblockstoredindex(A)) | ||
browIs = Int.(first.(Tuple.(bIs))) | ||
bcolIs = Int.(last.(Tuple.(bIs))) | ||
for bI in eachblockstoredindex(A) | ||
row, col = Int.(Tuple(bI)) | ||
nrows = brows[row] | ||
ncols = bcols[col] | ||
slengths[col] = min(nrows, ncols) | ||
end | ||
|
||
# fill in values for blocks that aren't present, pairing them in order of occurence | ||
# this is a convention, which at least gives the expected results for blockdiagonal | ||
emptyrows = setdiff(1:bm, browIs) | ||
emptycols = setdiff(1:bn, bcolIs) | ||
for (row, col) in zip(emptyrows, emptycols) | ||
slengths[col] = min(brows[row], bcols[col]) | ||
end | ||
|
||
s_axis = blockedrange(slengths) | ||
U = similar(A, axes(A, 1), s_axis) | ||
S = similar(A, real(eltype(A)), s_axis) | ||
Vt = similar(A, s_axis, axes(A, 2)) | ||
|
||
# also fill in identities for blocks that aren't present | ||
for (row, col) in zip(emptyrows, emptycols) | ||
copyto!(@view!(U[Block(row, col)]), LinearAlgebra.I) | ||
copyto!(@view!(Vt[Block(col, col)]), LinearAlgebra.I) | ||
end | ||
|
||
return U, S, Vt | ||
end | ||
|
||
function svd(A::AbstractBlockSparseMatrix; kwargs...) | ||
return svd!(eigencopy_oftype(A, LinearAlgebra.eigtype(eltype(A))); kwargs...) | ||
end | ||
|
||
function svd!( | ||
A::AbstractBlockSparseMatrix; full::Bool=false, alg::Algorithm=default_svd_alg(A) | ||
) | ||
@assert is_block_permutation_matrix(A) "Cannot keep sparsity: use `svd` to convert to `BlockedMatrix" | ||
U, S, Vt = _allocate_svd_output(A, full, alg) | ||
for bI in eachblockstoredindex(A) | ||
bUSV = svd!(@view!(A[bI]); full, alg) | ||
brow, bcol = Tuple(bI) | ||
U[brow, bcol] = bUSV.U | ||
S[bcol] = bUSV.S | ||
Vt[bcol, bcol] = bUSV.Vt | ||
end | ||
|
||
return SVD(U, S, Vt) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# type alias for block-diagonal | ||
using LinearAlgebra: Diagonal | ||
using DiagonalArrays: DiagonalArrays, diagonal | ||
|
||
const BlockDiagonal{T,A,Axes,V<:AbstractVector{A}} = BlockSparseMatrix{ | ||
T,A,Diagonal{A,V},Axes | ||
} | ||
const BlockSparseDiagonal{T,A<:AbstractBlockSparseVector{T}} = Diagonal{T,A} | ||
|
||
@interface interface::BlockSparseArrayInterface function blocks(a::BlockSparseDiagonal) | ||
return Diagonal(Diagonal.(blocks(a.diag))) | ||
end | ||
|
||
function BlockDiagonal(blocks::AbstractVector{<:AbstractMatrix}) | ||
return BlockSparseArray( | ||
Diagonal(blocks), (blockedrange(size.(blocks, 1)), blockedrange(size.(blocks, 2))) | ||
) | ||
end | ||
|
||
function DiagonalArrays.diagonal(S::BlockSparseVector) | ||
D = similar(S, (axes(S, 1), axes(S, 1))) | ||
for bI in eachblockstoredindex(S) | ||
D[bI, bI] = diagonal(@view!(S[bI])) | ||
end | ||
return D | ||
end |
Oops, something went wrong.