Skip to content

Commit

Permalink
AMReX Algebra: AlgVector and SpMatrix (#4259)
Browse files Browse the repository at this point in the history
This implements a distributed vector (AlgVector) and a distributed sparse
matrix (SpMatrix). Functionality is limited so far, but enough functions are
implemented to support a simple matrix- and vector-based GMRES solver with a
weighted Jacobi smoother.

This is still experimental and is not ready for users yet.
  • Loading branch information
WeiqunZhang authored Dec 9, 2024
1 parent 96db0a6 commit 50e25ec
Show file tree
Hide file tree
Showing 21 changed files with 1,883 additions and 28 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/dependencies/dependencies_hip.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ sudo apt-get install -y --no-install-recommends \
rocprofiler-dev \
rocrand-dev \
rocfft-dev \
rocprim-dev
rocprim-dev \
rocsparse-dev

# hiprand-dev is a new package that does not exist in old versions
sudo apt-get install -y --no-install-recommends hiprand-dev || true
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/dependencies/dependencies_nvcc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ sudo apt-get install -y \
cuda-nvml-dev-$VERSION_DASHED \
cuda-nvtx-$VERSION_DASHED \
libcufft-dev-$VERSION_DASHED \
libcurand-dev-$VERSION_DASHED
libcurand-dev-$VERSION_DASHED \
libcusparse-dev-$VERSION_DASHED
sudo ln -s cuda-$VERSION_DOTTED /usr/local/cuda
38 changes: 27 additions & 11 deletions Src/Base/AMReX_TableData.H
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

namespace amrex {

template <typename T>
template <typename T, typename IDX = int>
struct Table1D
{
T* AMREX_RESTRICT p = nullptr;
int begin = 1;
int end = 0;
IDX begin = 1;
IDX end = 0;

constexpr Table1D () noexcept = default;

Expand All @@ -33,7 +33,7 @@ struct Table1D
{}

AMREX_GPU_HOST_DEVICE
constexpr Table1D (T* a_p, int a_begin, int a_end) noexcept
constexpr Table1D (T* a_p, IDX a_begin, IDX a_end) noexcept
: p(a_p),
begin(a_begin),
end(a_end)
Expand All @@ -44,7 +44,7 @@ struct Table1D

template <class U=T, std::enable_if_t<!std::is_void_v<U>,int> = 0>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
U& operator() (int i) const noexcept {
U& operator() (IDX i) const noexcept {
#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
index_assert(i);
#endif
Expand All @@ -53,14 +53,30 @@ struct Table1D

#if defined(AMREX_DEBUG) || defined(AMREX_BOUND_CHECK)
AMREX_GPU_HOST_DEVICE inline
void index_assert (int i) const
void index_assert (IDX i) const
{
if (i < begin || i >= end) {
AMREX_IF_ON_DEVICE((
AMREX_DEVICE_PRINTF(" (%d) is out of bound (%d:%d)\n",
i, begin, end-1);
amrex::Abort();
))
if constexpr (std::is_same_v<IDX,int>) {
AMREX_IF_ON_DEVICE((
AMREX_DEVICE_PRINTF(" (%d) is out of bound (%d:%d)\n",
i, begin, end-1);
amrex::Abort();
))
} else if constexpr (std::is_same_v<IDX,long>) {
AMREX_IF_ON_DEVICE((
AMREX_DEVICE_PRINTF(" (%ld) is out of bound (%ld:%ld)\n",
i, begin, end-1);
amrex::Abort();
))
} else if constexpr (std::is_same_v<IDX,long long>) {
AMREX_IF_ON_DEVICE((
AMREX_DEVICE_PRINTF(" (%lld) is out of bound (%lld:%lld)\n",
i, begin, end-1);
amrex::Abort();
))
} else {
AMREX_IF_ON_DEVICE(( amrex::Abort(" Out of bound\n"); ))
}
AMREX_IF_ON_HOST((
std::stringstream ss;
ss << " (" << i << ") is out of bound ("
Expand Down
58 changes: 58 additions & 0 deletions Src/LinearSolvers/AMReX_AlgPartition.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#ifndef AMREX_ALG_PARTITION_H_
#define AMREX_ALG_PARTITION_H_
#include <AMReX_Config.H>

#include <AMReX_INT.H>
#include <AMReX_ParallelDescriptor.H>
#include <AMReX_Vector.H>

#include <memory>

namespace amrex {

// Row partition of a distributed algebraic object (vector / sparse matrix)
// across processes.
//
// The partition is stored as a vector of Long with one entry per process plus
// one (see Ref::m_row). The state lives in a Ref held through a
// std::shared_ptr, so copies of an AlgPartition refer to the same Ref.
class AlgPartition
{
public:
// Creates a partition that owns a default-constructed Ref.
AlgPartition ();
// Splits global_size rows over ParallelDescriptor::NProcs() processes as
// evenly as possible; the leading processes receive the leftover rows.
explicit AlgPartition (Long global_size);
// Uses the given vector as the partition data (copied). It is expected to
// have NProcs()+1 entries; this is checked with AMREX_ASSERT.
explicit AlgPartition (Vector<Long> const& rows);
// Same as above, but takes ownership of the vector's storage.
explicit AlgPartition (Vector<Long>&& rows) noexcept;

// The define() overloads mirror the constructors but modify the existing Ref
// in place. NOTE(review): because the Ref is shared, other AlgPartition
// objects holding the same Ref appear to be affected as well -- confirm that
// this is intended.
void define (Long global_size);
void define (Vector<Long> const& rows);
void define (Vector<Long>&& rows);

// True when the partition vector has no entries.
[[nodiscard]] bool empty () const { return m_ref->m_row.empty(); }

// Entry i of the partition vector (no range check is done here).
[[nodiscard]] Long operator[] (int i) const { return m_ref->m_row[i]; }
// Last entry of the partition vector.
// NOTE(review): calls back(), so this must not be used while empty() is true.
[[nodiscard]] Long numGlobalRows () const { return m_ref->m_row.back(); }
// Number of processes i for which entry i is less than entry i+1.
[[nodiscard]] int numActiveProcs () const { return m_ref->m_n_active_procs; }

// Read-only access to the underlying partition vector.
[[nodiscard]] Vector<Long> const& dataVector () const { return m_ref->m_row; }

// Two partitions compare equal when they hold the same Ref or when their
// partition vectors compare equal.
[[nodiscard]] bool operator== (AlgPartition const& rhs) const noexcept;
[[nodiscard]] bool operator!= (AlgPartition const& rhs) const noexcept;

private:
// Shared state behind an AlgPartition.
struct Ref
{
friend class AlgPartition;
Ref () = default;
explicit Ref (Long global_size);
explicit Ref (Vector<Long> const& rows);
explicit Ref (Vector<Long>&& rows);
void define (Long global_size);
void define (Vector<Long> const& rows);
void define (Vector<Long>&& rows);
// Recomputes m_n_active_procs from m_row.
void update_n_active_procs ();

Vector<Long> m_row; // size: nprocs + 1
// Count of processes i with m_row[i] < m_row[i+1].
int m_n_active_procs = 0;
};

// Shared handle to the partition state.
std::shared_ptr<Ref> m_ref;
};

}

#endif
102 changes: 102 additions & 0 deletions Src/LinearSolvers/AMReX_AlgPartition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#include <AMReX_AlgPartition.H>

namespace amrex {

AlgPartition::AlgPartition ()
: m_ref(std::make_shared<Ref>())
{}

AlgPartition::AlgPartition (Long global_size)
: m_ref(std::make_shared<Ref>(global_size))
{}

AlgPartition::AlgPartition (Vector<Long> const& rows)
: m_ref(std::make_shared<Ref>(rows))
{}

AlgPartition::AlgPartition (Vector<Long>&& rows) noexcept
: m_ref(std::make_shared<Ref>(std::move(rows)))
{}

// Recomputes the partition for `global_size` rows in the Ref this object
// points to. The Ref is modified in place rather than replaced.
void AlgPartition::define (Long global_size)
{
    Ref& state = *m_ref;
    state.define(global_size);
}

// Replaces the partition vector in the Ref this object points to with a copy
// of `rows`. The Ref is modified in place rather than replaced.
void AlgPartition::define (Vector<Long> const& rows)
{
    Ref& state = *m_ref;
    state.define(rows);
}

// Replaces the partition vector in the Ref this object points to by moving
// from `rows`. The Ref is modified in place rather than replaced.
void AlgPartition::define (Vector<Long>&& rows)
{
    Ref& state = *m_ref;
    state.define(std::move(rows));
}

// Equality: partitions are equal when they hold the very same Ref, or when
// their partition vectors compare equal element by element.
bool AlgPartition::operator== (AlgPartition const& rhs) const noexcept
{
    // Same Ref object: no need to look at the contents.
    if (m_ref == rhs.m_ref) {
        return true;
    }
    return m_ref->m_row == rhs.m_ref->m_row;
}

// Inequality, defined as the negation of operator==.
bool AlgPartition::operator!= (AlgPartition const& rhs) const noexcept
{
    bool const same = (*this == rhs);
    return !same;
}

// Constructs the state for `global_size` rows by delegating to define(Long).
AlgPartition::Ref::Ref (Long global_size)
{
    this->define(global_size);
}

// Constructs the state from a copy of `rows`. The copy and the refresh of
// the active-process count are both done by define(Vector<Long> const&).
AlgPartition::Ref::Ref (Vector<Long> const& rows)
{
    this->define(rows);
}

// Constructs the state by taking over the contents of `rows`. The move and
// the refresh of the active-process count are both done by
// define(Vector<Long>&&).
AlgPartition::Ref::Ref (Vector<Long>&& rows)
{
    this->define(std::move(rows));
}

// Fills m_row with a near-even split of `global_size` rows over all
// processes reported by ParallelDescriptor::NProcs().
//
// Each process gets global_size / nprocs rows, and the leftover rows are
// handed out one apiece to the lowest-numbered processes. Entry r is the
// running total of rows assigned to processes before r; the final entry is
// global_size.
void AlgPartition::Ref::define (Long global_size)
{
    Long const nranks   = static_cast<Long>(ParallelDescriptor::NProcs());
    Long const base     = global_size / nranks;
    Long const leftover = global_size - base*nranks;

    m_row.resize(nranks+1);
    for (Long rank = 0; rank < nranks; ++rank) {
        // Processes before `rank` that received an extra row: `rank` of them
        // while still inside the leftover range, `leftover` afterwards.
        Long const bonus = (rank < leftover) ? rank : leftover;
        m_row[rank] = rank*base + bonus;
    }
    m_row[nranks] = global_size;

    update_n_active_procs();
}

// Replaces the partition vector with a copy of `rows`, then refreshes the
// cached number of active processes.
void AlgPartition::Ref::define (Vector<Long> const& rows)
{
    // Plain copy assignment is kept so that passing m_row itself stays safe.
    this->m_row = rows;
    this->update_n_active_procs();
}

// Replaces the partition vector by moving from `rows`, then refreshes the
// cached number of active processes.
void AlgPartition::Ref::define (Vector<Long>&& rows)
{
    this->m_row = std::move(rows);
    this->update_n_active_procs();
}

// Recounts how many processes have a non-empty range, i.e. how many adjacent
// pairs in m_row are strictly increasing, and stores the result in
// m_n_active_procs.
void AlgPartition::Ref::update_n_active_procs ()
{
    AMREX_ASSERT(m_row.size() == ParallelDescriptor::NProcs()+1);

    int const nentries = int(m_row.size());
    int active = 0;
    // Walk the pairs (prev, next); with fewer than two entries there is
    // nothing to count.
    for (int next = 1; next < nentries; ++next) {
        int const prev = next - 1;
        if (m_row[prev] < m_row[next]) {
            ++active;
        }
    }
    m_n_active_procs = active;
}

}
Loading

0 comments on commit 50e25ec

Please sign in to comment.