QR factorise your MPI-distributed matrix using Householder reflections, then solve your square or least-squares problem.
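For readers who want the underlying idea in serial form, here is a minimal sketch of an unblocked Householder least-squares solve. It is a textbook version and not MPIQR's implementation; the name householder_lstsq is hypothetical, and the sketch assumes a full-rank floating-point matrix with at least as many rows as columns.

```julia
using LinearAlgebra

# Hypothetical helper for illustration only: textbook unblocked Householder QR.
# Assumes size(A, 1) >= size(A, 2), full column rank, and a floating-point eltype.
function householder_lstsq(A::AbstractMatrix{T}, b::AbstractVector{T}) where {T<:Number}
    m, n = size(A)
    R = Matrix(A)  # working copy, reduced column by column to upper triangular form
    y = Vector(b)  # working copy, receives the same reflections as the matrix
    for k in 1:n
        v = R[k:m, k]                # column k from the diagonal downwards
        α = norm(v)
        iszero(α) && continue        # nothing to eliminate in this column
        phase = iszero(v[1]) ? one(T) : v[1] / abs(v[1])
        v[1] += phase * α            # this sign choice avoids cancellation
        v ./= norm(v)
        # apply the reflector H = I - 2vv' to the trailing columns and to the rhs
        R[k:m, k:n] .-= 2 .* v * (v' * R[k:m, k:n])
        y[k:m] .-= 2 .* v .* dot(v, y[k:m])
    end
    # back-substitute with the leading n-by-n upper triangle
    return UpperTriangular(R[1:n, 1:n]) \ y[1:n]
end

A = rand(ComplexF64, 8, 5)
b = rand(ComplexF64, 8)
x = householder_lstsq(A, b)
@show norm(A' * A * x .- A' * b)  # normal-equations residual, should be tiny
```

The distributed version has to perform the same reflections with the columns of the matrix spread across ranks, which is where the communication discussed below comes from.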
Currently this only works for Float32, Float64, ComplexF32 and ComplexF64 matrix and vector eltypes.
These isbits types make use of calls to threaded BLAS.gemm! and BLAS.axpy! on each rank for highest performance, which is on a par with threaded LAPACK speeds (with a comparable number of cores).
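As a rough illustration of the two BLAS kernels named above, the sketch below calls them directly from Julia's LinearAlgebra standard library. The array sizes and the choice of one BLAS thread per Julia thread are assumptions made for this example; how MPIQR itself arranges these calls is not shown here.

```julia
using LinearAlgebra

# Assumption for this sketch: give BLAS as many threads as Julia has.
BLAS.set_num_threads(Threads.nthreads())

T = ComplexF64
A = rand(T, 64, 32)
B = rand(T, 32, 16)
C = rand(T, 64, 16)
C0 = copy(C)

# gemm! overwrites C with α*A*B + β*C; here α = -1 and β = 1, so C ← C - A*B,
# which is the shape of update a Householder trailing-matrix step needs
BLAS.gemm!('N', 'N', T(-1), A, B, T(1), C)

x = rand(T, 64)
y = rand(T, 64)
y0 = copy(y)

# axpy! overwrites y with a*x + y
BLAS.axpy!(T(2), x, y)
```

Both calls work in place, so no temporaries are allocated, and both dispatch to the BLAS library only for the four element types listed above.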
There is a tunable parameter, blocksize, which enables larger calls to gemm! and reduces the number of MPI communications, albeit at the cost of bigger communications when they do happen, which themselves should be hidden well by compute.
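To make the trade-off concrete, here is a minimal sketch of a block-cyclic column layout, in which consecutive groups of blocksize columns are dealt out to the ranks in turn. The function blockcyclic_columns is hypothetical and written only for this illustration; the layout actually returned by MPIQR.localcolumns may differ.

```julia
# Hypothetical helper: the columns owned by `rank` (0-based, as in MPI) when
# `n` columns are dealt out to `nranks` ranks in blocks of `blocksize` columns.
function blockcyclic_columns(rank::Integer, n::Integer, blocksize::Integer, nranks::Integer)
    cols = Int[]
    # the first column of every block owned by this rank
    for start in (rank * blocksize + 1):(blocksize * nranks):n
        append!(cols, start:min(start + blocksize - 1, n))  # the last block may be short
    end
    return cols
end

n, nranks = 1024, 4
for blocksize in (2, 32)
    nblocks = cld(n, blocksize)  # number of column blocks in this layout
    ncols = length(blockcyclic_columns(0, n, blocksize, nranks))
    println("blocksize = $blocksize: $nblocks blocks, rank 0 owns $ncols columns")
end
```

Under the assumption that each block of columns is communicated once, a larger blocksize means fewer but larger messages, and each block also gives gemm! a wider panel to work on.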
An example:
using MPIQR
using LinearAlgebra, MPI, Distributed, MPIClusterManagers
using ProgressMeter # optional

MPI.Init() # MPI must be initialised before any other MPI call
const rnk = MPI.Comm_rank(MPI.COMM_WORLD)

function run(T=ComplexF64)
  # increase blocksize to improve usage of BLAS and decrease MPI comms
  blocksize = 2
  m, n = 2048, 1024

  # empty placeholders so that every rank has something to pass to bcast
  A0 = zeros(T, 0, 0)
  x1 = b0 = zeros(T, 0)
  if rnk == 0 # assemble and solve serially to compare with MPIQR later
    A0 = rand(T, m, n) # the original matrix
    b0 = rand(T, m) # the original rhs
    A1 = deepcopy(A0) # this will get mutated
    b1 = deepcopy(b0) # as will this
    x1 = qr!(A1) \ b1 # the serial least squares solution
    y1 = A0 * x1 # this is the matrix vector product, not the least squares solution
  end
  Aall = MPI.bcast(A0, 0, MPI.COMM_WORLD) # lhs matrix on all ranks
  ball = MPI.bcast(b0, 0, MPI.COMM_WORLD) # rhs vector on all ranks
  xall = MPI.bcast(x1, 0, MPI.COMM_WORLD) # solution vector on all ranks

  # get the columns of the matrix that will be local to this rank
  localcols = MPIQR.localcolumns(rnk, n, blocksize, MPI.Comm_size(MPI.COMM_WORLD))
  b = deepcopy(ball)

  # distribute the serial matrix onto the columns local to this rank
  A = MPIQR.MPIQRMatrix(deepcopy(Aall[:, localcols]), size(Aall); blocksize=blocksize)

  y2 = A * xall # make sure matrix vector multiplication works...
  if iszero(rnk) # ... and is correct.
    @assert y2 ≈ y1
  end

  # qr factorize A in-place and solve
  # qr! optionally accepts a progress meter
  x2 = qr!(A; progress=Progress(A; showspeed=true)) \ b

  if iszero(rnk) # now see if the answer is right...
    # both solutions should satisfy the normal equations A'A x = A'b
    @assert norm(Aall' * Aall * xall .- Aall' * ball) < 1e-8
    @show residual = norm(Aall' * Aall * x2 .- Aall' * ball)
  end
end

run()