# MPIQR.jl: QR factorisation distributed over MPI

QR factorise your MPI-distributed matrix using Householder reflections, then solve your square or least-squares problem.

Currently this only works for Float32, Float64, ComplexF32 and ComplexF64 matrix and vector eltypes. These isbits types make use of calls to threaded BLAS.gemm! and BLAS.axpy! on each rank for highest performance, which is on a par with threaded LAPACK getrf speeds (with a comparable number of cores).
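Since each rank drives threaded BLAS, the per-rank BLAS thread count matters. A minimal sketch of pinning it is below; the one-rank-per-node assumption is illustrative, not part of MPIQR:

```julia
using LinearAlgebra

# Assumption: one MPI rank per node, so BLAS may use every thread this
# Julia process was started with. With several ranks per node, divide
# the cores between them instead.
BLAS.set_num_threads(Threads.nthreads())
```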

There is a tunable parameter, blocksize, which enables larger calls to gemm! and reduces the number of MPI communications, albeit at the cost of larger messages when they do happen, which themselves should be well hidden by compute.
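To see how blocksize shapes the column distribution, MPIQR.localcolumns (also used in the example below) can be inspected directly. This is a hypothetical standalone check, assuming localcolumns is pure arithmetic over the rank, column count, blocksize, and communicator size, and that the layout is block-cyclic in runs of blocksize consecutive columns:

```julia
using MPIQR

# Inspect column ownership for n = 12 columns over 3 ranks.
# Assumption: block-cyclic layout in runs of `blocksize` columns.
n, nranks = 12, 3
for blocksize in (1, 2, 4)
  for rnk in 0:nranks-1
    cols = MPIQR.localcolumns(rnk, n, blocksize, nranks)
    println("blocksize=$blocksize, rank=$rnk: columns $(collect(cols))")
  end
end
```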

An example:

```julia
using MPIQR
using LinearAlgebra, MPI, Distributed, MPIClusterManagers
using ProgressMeter # optional

MPI.Init(;threadlevel=MPI.THREAD_SERIALIZED)
const rnk = MPI.Comm_rank(MPI.COMM_WORLD)

function run(T=ComplexF64)
  # increase blocksize to improve usage of BLAS and decrease MPI comms
  blocksize = 2
  m, n = 2048, 1024
  A0 = zeros(T, 0, 0)
  x1 = b0 = zeros(T, 0)
  if rnk == 0 # assemble and solve serially to compare with MPIQR later
    A0 = rand(T, m, n) # the original matrix
    b0 = rand(T, m) # the original lhs
    A1 = deepcopy(A0) # this will get mutated
    b1 = deepcopy(b0) # as will this
    x1 = qr!(A1) \ b1
    y1 = A0 * x1 # this is the matrix vector product, not the least squares solution
  end
  Aall = MPI.bcast(A0, 0, MPI.COMM_WORLD) # lhs matrix on all ranks
  ball = MPI.bcast(b0, 0, MPI.COMM_WORLD) # rhs vector on all ranks
  xall = MPI.bcast(x1, 0, MPI.COMM_WORLD) # solution vector on all ranks

  # get the columns of the matrix that will be local to this rank
  localcols = MPIQR.localcolumns(rnk, n, blocksize, MPI.Comm_size(MPI.COMM_WORLD))
  b = deepcopy(ball)

  # distribute the serial matrix onto the columns local to this rank
  A = MPIQR.MPIQRMatrix(deepcopy(Aall[:, localcols]), size(Aall); blocksize=blocksize)
  y2 = A * xall # make sure matrix vector multiplication works...
  if iszero(rnk) # ... and is correct.
    @assert y2 ≈ y1
  end

  # qr! optionally accepts a progress meter
  # qr factorize A in-place and solve
  x2 = qr!(A; progress=Progress(A; showspeed=true)) \ b

  if iszero(rnk) # now see if the answer is right...
    @assert norm(Aall' * Aall * xall .- Aall' * ball) < 1e-8
    @show residual = norm(Aall' * Aall * x2 .- Aall' * ball)
  end
end
run()

MPI.Finalize()
```
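Because every rank executes the whole script, it must be started under an MPI launcher with one process per rank. Assuming the example is saved as `example.jl` (a hypothetical filename) and MPI.jl's `mpiexecjl` wrapper is installed, a typical invocation is `mpiexecjl -n 4 julia --project example.jl`.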
