Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

block tridiagonal matrices #21

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/BlockArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export Block, getblock, getblock!, setblock!, nblocks, blocksize, blockcheckboun

export BlockArray, BlockMatrix, BlockVector, BlockVecOrMat
export PseudoBlockArray, PseudoBlockMatrix, PseudoBlockVector, PseudoBlockVecOrMat
export BlockTridiagMatrix

import Base: @propagate_inbounds, Array
using Base.Cartesian
Expand All @@ -17,6 +18,7 @@ include("abstractblockarray.jl")
include("blocksizes.jl")
include("blockindices.jl")
include("blockarray.jl")
include("blocktridiag.jl")
include("pseudo_blockarray.jl")
include("show.jl")

Expand Down
243 changes: 243 additions & 0 deletions src/blocktridiag.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
#=======================
BlockTridiagMatrix
=======================#

"""
BlockTridiagMatrix{T, R <: AbstractMatrix{T}} <: AbstractBlockArray{T, N}

A `BlockTridiagMatrix` is a block tridiagonal matrix where each block is stored contiguously.
This means that insertions and retrieval of blocks
can be very fast and non allocating since no copying of data is needed.

In the type definition, `R` defines the array type that each block has,
for example `Matrix{Float64}.
"""
struct BlockTridiagMatrix{T, R <: AbstractMatrix{T}} <: AbstractBlockMatrix{T}
diagl::Vector{R}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative would be to store the blocks in a BandedMatrix from the BandedMatrices.jl package, which would then take care of the logic at the block level. But then I don't know if introducing a dependency on BandedMatrices.jl is a good idea.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if the name should be TridiagBlockMatrix?

lower::Vector{R}
upper::Vector{R}
block_sizes::BlockSizes{2}
end

# Auxilary outer constructors
function BlockTridiagMatrix{T, R <: AbstractArray{T}
}(diagl::Vector{R},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, these linebreaks before the } looks a bit weird and I haven't seen Julia code using it before.

lower::Vector{R},
upper::Vector{R},
block_sizes::BlockSizes{2})
return BlockTridiagMatrix{T, R}(diagl, lower, upper, block_sizes)
end

function BlockTridiagMatrix{T, R <: AbstractArray{T}
}(diagl::Vector{R},
lower::Vector{R},
upper::Vector{R},
block_sizes::Vararg{Vector{Int}, 2})
return BlockTridiagMatrix{T, R}(diagl, lower, upper,
BlockSizes(block_sizes...))
end


################
# Constructors #
################

"""
Constructs a `BlockTridiagMatrix` with uninitialized blocks from a block type `R`
with sizes defind by `block_sizes`.

```jldoctest
julia> BlockTridiagMatrix(Matrix{Float64}, [1,3], [2,2])
2×2-blocked 4×4 BlockArrays.BlockTridiagMatrix{Float64,2,Array{Float64,2}}:
#undef │ #undef #undef #undef │
--------┼--------------------------┼
#undef │ #undef #undef #undef │
#undef │ #undef #undef #undef │
--------┼--------------------------┼
#undef │ #undef #undef #undef │
```
"""
@inline function BlockTridiagMatrix{T, R <: AbstractMatrix{T}
}(::Type{R},
block_sizes::Vararg{Vector{Int}, 2})
BlockTridiagMatrix(R, BlockSizes(block_sizes...))
end

function BlockTridiagMatrix{T, R <: AbstractMatrix{T}}(::Type{R}, block_sizes::BlockSizes{2})
n_blocks = nblocks(block_sizes)
n_blocks[1] == n_blocks[2] || throw("expect same number of blocks in both dimensions")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably throw an ArgumentError. I also prefer to write erros in "must"-form e.g the number of blocks must be the same in both dimensions.

diagl = Vector{R}(n_blocks[1])
lower = Vector{R}(n_blocks[1]-1)
upper = Vector{R}(n_blocks[1]-1)
BlockTridiagMatrix{T,R}(diagl, lower, upper, block_sizes)
end

function BlockTridiagMatrix(arr::AbstractMatrix,
block_sizes::Vararg{Vector{Int}, 2})
for i in 1:2
if sum(block_sizes[i]) != size(arr, i)
throw(DimensionMismatch(
"block size for dimension $i: $(block_sizes[i])" *
"does not sum to the array size: $(size(arr, i))"))
end
end

_block_sizes = BlockSizes(block_sizes...)
bltrid_mat = BlockTridiagMatrix(typeof(arr), _block_sizes)
row_blocks, col_blocks = nblocks(bltrid_mat)
for brow in 1:row_blocks
for bcol in max(1,brow-1):min(brow+1,col_blocks)
indices = globalrange(_block_sizes, (brow,bcol))
setblock!(bltrid_mat, arr[indices...], brow, bcol)
end
end

return bltrid_mat
end

################################
# AbstractBlockArray Interface #
################################
@inline nblocks(bltrid_mat::BlockTridiagMatrix) = nblocks(bltrid_mat.block_sizes)
@inline blocksize(bltrid_mat::BlockTridiagMatrix, i::Int, j::Int) = blocksize(bltrid_mat.block_sizes, (i, j))

@inline function getblock(bltrid_mat::BlockTridiagMatrix, i::Int, j::Int)
@boundscheck blockcheckbounds(bltrid_mat, i, j)
if i==j
# for blocks on the diagonal,
# get the block from `diagl`
return bltrid_mat.diagl[i]
elseif i==j+1
# for blocks below the diagonal,
# get the block from `lower`
return bltrid_mat.lower[j]
elseif i+1==j
# for blocks above the diagonal,
# get the block from `upper`
return bltrid_mat.upper[i]
else
# otherwise return a freshly-baked
# matrix of zeros (with a warning
# because that's dumb)
warn(@sprintf("""The (%d,%d) block of a block tridiagonal matrix
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like warnings like this. It should either be an error or be allowed without warning. Potential wastefulness should go in the documentation of the type. You don't get a warning if you access a zero element of a sparse matrix.

is just zeros. It's wasteful to obtain this block.
""", i, j),
once=true,
key="blocktridiagonal_inefficient_getblock")
return zeros(eltype(bltrid_mat), blocksize(bltrid_mat, i, j))
end
end

@inline function Base.getindex(bltrid_mat::BlockTridiagMatrix, blockindex::BlockIndex{2})
block_i, block_j = blockindex.I
@boundscheck blockcheckbounds(bltrid_mat, block_i, block_j)
if abs(block_i-block_j) > 1
return zero(eltype(bltrid_mat))
end
@inbounds block = getblock(bltrid_mat, blockindex.I...)
@boundscheck checkbounds(block, blockindex.α...)
@inbounds v = block[blockindex.α...]
return v
end


###########################
# AbstractArray Interface #
###########################

@inline function Base.similar{T2}(bltrid_mat::BlockTridiagMatrix,
::Type{T2})
diagl = bltrid_mat.diagl
lower = bltrid_mat.lower
upper = bltrid_mat.upper
BlockTridiagMatrix(similar(diagl, Matrix{T2}),
similar(lower, Matrix{T2}),
similar(upper, Matrix{T2}),
copy(bltrid_mat.block_sizes))
end

function Base.size(arr::BlockTridiagMatrix)
return (arr.block_sizes[1][end]-1,
arr.block_sizes[2][end]-1)
end

@inline function Base.getindex(bltrid_mat::BlockTridiagMatrix, i::Vararg{Int, 2})
@boundscheck checkbounds(bltrid_mat, i...)
@inbounds v = bltrid_mat[global2blockindex(bltrid_mat.block_sizes, i)]
return v
end

@inline function Base.setindex!(bltrid_mat::BlockTridiagMatrix, v, i::Vararg{Int, 2})
@boundscheck checkbounds(bltrid_mat, i...)
@inbounds bltrid_mat[global2blockindex(bltrid_mat.block_sizes, i)] = v
return bltrid_mat
end

############
# Indexing #
############

function _check_setblock!(bltrid_mat::BlockTridiagMatrix, v, i::Int, j::Int)
if size(v) != blocksize(bltrid_mat, i, j)
throw(DimensionMismatch(string("tried to assign $(size(v)) array to ", blocksize(bltrid_mat, i, j), " block")))
end
end


@inline function setblock!(bltrid_mat::BlockTridiagMatrix, v, i::Int, j::Int)
@boundscheck blockcheckbounds(bltrid_mat, i, j)
@boundscheck _check_setblock!(bltrid_mat, v, i, j)
@inbounds begin
if i==j
# for blocks on the diagonal,
# get the block from `diagl`
bltrid_mat.diagl[i] = v
elseif i==j+1
# for blocks below the diagonal,
# get the block from `lower`
bltrid_mat.lower[j] = v
elseif i+1==j
# for blocks above the diagonal,
# get the block from `upper`
bltrid_mat.upper[i] = v
else
throw("tried to set zero block of BlockTridiagMatrix")
end
end
return bltrid_mat
end

@propagate_inbounds function Base.setindex!{T,N}(bltrid_mat::BlockTridiagMatrix{T, N}, v, block_index::BlockIndex{N})
getblock(bltrid_mat, block_index.I...)[block_index.α...] = v
end

########
# Misc #
########

function Base.Array{T,R}(bltrid_mat::BlockTridiagMatrix{T, R})
# TODO: This will fail for empty block array
block_sizes = bltrid_mat.block_sizes
row_blocks, col_blocks = nblocks(bltrid_mat)
arr = zeros(T, size(bltrid_mat))
for brow in 1:row_blocks
for bcol in max(1,brow-1):min(brow+1,col_blocks)
indices = globalrange(block_sizes, (brow,bcol))
arr[indices...] = getblock(bltrid_mat, brow, bcol)
end
end
return arr
end

function Base.copy!{T, R<:AbstractArray{T}, M<:AbstractArray{T}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't need to be done here but there should be a way where the code here doesnt have to be so similar to the one working for BlockArray.

}(bltrid_mat::BlockTridiagMatrix{T, R}, arr::M)
block_sizes = bltrid_mat.block_sizes
row_blocks, col_blocks = nblocks(bltrid_mat)
for brow in 1:row_blocks
for bcol in max(1,brow-1):min(brow+1,col_blocks)
indices = globalrange(block_sizes, (brow,bcol))
copy!(getblock(bltrid_mat, brow, bcol), view(arr, indices...))
end
end
return bltrid_mat
end