Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend sparse linear algebra interface with multiply-add #240

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 4 additions & 39 deletions src/atlas/interpolation/method/Method.cc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup with new API

Original file line number Diff line number Diff line change
Expand Up @@ -168,65 +168,30 @@ void Method::interpolate_field_rank3(const Field& src, Field& tgt, const Matrix&

template <typename Value>
void Method::adjoint_interpolate_field_rank1(Field& src, const Field& tgt, const Matrix& W) const {
array::ArrayT<Value> tmp(src.shape());
auto backend = std::is_same<Value, float>::value ? sparse::backend::openmp() : sparse::Backend{linalg_backend_};

auto tmp_v = array::make_view<Value, 1>(tmp);
auto src_v = array::make_view<Value, 1>(src);
auto tgt_v = array::make_view<Value, 1>(tgt);

tmp_v.assign(0.);

if (std::is_same<Value, float>::value) {
sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());
}
else {
sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::Backend{linalg_backend_});
}


for (idx_t t = 0; t < tmp.shape(0); ++t) {
src_v(t) += tmp_v(t);
}
sparse_matrix_multiply_add(W, tgt_v, src_v, backend);
}

template <typename Value>
void Method::adjoint_interpolate_field_rank2(Field& src, const Field& tgt, const Matrix& W) const {
array::ArrayT<Value> tmp(src.shape());

auto tmp_v = array::make_view<Value, 2>(tmp);
auto src_v = array::make_view<Value, 2>(src);
auto tgt_v = array::make_view<Value, 2>(tgt);

tmp_v.assign(0.);

sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());

for (idx_t t = 0; t < tmp.shape(0); ++t) {
for (idx_t k = 0; k < tmp.shape(1); ++k) {
src_v(t, k) += tmp_v(t, k);
}
}
sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp());
}

template <typename Value>
void Method::adjoint_interpolate_field_rank3(Field& src, const Field& tgt, const Matrix& W) const {
array::ArrayT<Value> tmp(src.shape());

auto tmp_v = array::make_view<Value, 3>(tmp);
auto src_v = array::make_view<Value, 3>(src);
auto tgt_v = array::make_view<Value, 3>(tgt);

tmp_v.assign(0.);

sparse_matrix_multiply(W, tgt_v, tmp_v, sparse::backend::openmp());

for (idx_t t = 0; t < tmp.shape(0); ++t) {
for (idx_t j = 0; j < tmp.shape(1); ++j) {
for (idx_t k = 0; k < tmp.shape(2); ++k) {
src_v(t, j, k) += tmp_v(t, j, k);
}
}
}
sparse_matrix_multiply_add(W, tgt_v, src_v, sparse::backend::openmp());
}

void Method::check_compatibility(const Field& src, const Field& tgt, const Matrix& W) const {
Expand Down
43 changes: 40 additions & 3 deletions src/atlas/linalg/sparse/SparseMatrixMultiply.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing,
const Configuration& config);

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt);

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, const Configuration& config);

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing);

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing,
const Configuration& config);

class SparseMatrixMultiply {
public:
SparseMatrixMultiply() = default;
Expand All @@ -46,14 +59,34 @@ class SparseMatrixMultiply {

template <typename Matrix, typename SourceView, typename TargetView>
void operator()(const Matrix& matrix, const SourceView& src, TargetView& tgt) const {
sparse_matrix_multiply(matrix, src, tgt, backend());
multiply(matrix, src, tgt);
}

template <typename Matrix, typename SourceView, typename TargetView>
void operator()(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const {
multiply(matrix, src, tgt, indexing);
}

template <typename Matrix, typename SourceView, typename TargetView>
void multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt) const {
sparse_matrix_multiply(matrix, src, tgt, backend());
}

template <typename Matrix, typename SourceView, typename TargetView>
void multiply(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const {
sparse_matrix_multiply(matrix, src, tgt, indexing, backend());
}

template <typename Matrix, typename SourceView, typename TargetView>
void multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt) const {
sparse_matrix_multiply_add(matrix, src, tgt, backend());
}

template <typename Matrix, typename SourceView, typename TargetView>
void multiply_add(const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing) const {
sparse_matrix_multiply_add(matrix, src, tgt, indexing, backend());
}

const sparse::Backend& backend() const { return backend_; }

private:
Expand All @@ -65,8 +98,12 @@ namespace sparse {
// Template class which needs (full or partial) specialization for concrete template parameters
template <typename Backend, Indexing, int Rank, typename SourceValue, typename TargetValue>
struct SparseMatrixMultiply {
static void apply(const SparseMatrix&, const View<SourceValue, Rank>&, View<TargetValue, Rank>&,
const Configuration&) {
static void multiply(const SparseMatrix&, const View<SourceValue, Rank>&, View<TargetValue, Rank>&,
const Configuration&) {
throw_NotImplemented("SparseMatrixMultiply needs a template specialization with the implementation", Here());
}
static void multiply_add(const SparseMatrix&, const View<SourceValue, Rank>&, View<TargetValue, Rank>&,
const Configuration&) {
throw_NotImplemented("SparseMatrixMultiply needs a template specialization with the implementation", Here());
}
};
Expand Down
83 changes: 78 additions & 5 deletions src/atlas/linalg/sparse/SparseMatrixMultiply.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,25 @@ namespace {
template <typename Backend, Indexing indexing>
struct SparseMatrixMultiplyHelper {
template <typename SourceView, typename TargetView>
static void apply( const SparseMatrix& W, const SourceView& src, TargetView& tgt,
static void multiply( const SparseMatrix& W, const SourceView& src, TargetView& tgt,
const eckit::Configuration& config ) {
using SourceValue = const typename std::remove_const<typename SourceView::value_type>::type;
using TargetValue = typename std::remove_const<typename TargetView::value_type>::type;
constexpr int src_rank = introspection::rank<SourceView>();
constexpr int tgt_rank = introspection::rank<TargetView>();
static_assert( src_rank == tgt_rank, "src and tgt need same rank" );
SparseMatrixMultiply<Backend, indexing, src_rank, SourceValue, TargetValue>::multiply( W, src, tgt, config );
}

template <typename SourceView, typename TargetView>
static void multiply_add( const SparseMatrix& W, const SourceView& src, TargetView& tgt,
const eckit::Configuration& config ) {
using SourceValue = const typename std::remove_const<typename SourceView::value_type>::type;
using TargetValue = typename std::remove_const<typename TargetView::value_type>::type;
constexpr int src_rank = introspection::rank<SourceView>();
constexpr int tgt_rank = introspection::rank<TargetView>();
static_assert( src_rank == tgt_rank, "src and tgt need same rank" );
SparseMatrixMultiply<Backend, indexing, src_rank, SourceValue, TargetValue>::apply( W, src, tgt, config );
SparseMatrixMultiply<Backend, indexing, src_rank, SourceValue, TargetValue>::multiply_add( W, src, tgt, config );
}
};

Expand All @@ -53,14 +64,38 @@ void dispatch_sparse_matrix_multiply( const Matrix& matrix, const SourceView& sr
if ( introspection::layout_right( src ) || introspection::layout_right( tgt ) ) {
ATLAS_ASSERT( introspection::layout_right( src ) && introspection::layout_right( tgt ) );
// Override layout with known layout given by introspection
SparseMatrixMultiplyHelper<Backend, linalg::Indexing::layout_right>::apply( matrix, src_v, tgt_v, config );
SparseMatrixMultiplyHelper<Backend, linalg::Indexing::layout_right>::multiply( matrix, src_v, tgt_v, config );
}
else {
if( indexing == Indexing::layout_left ) {
SparseMatrixMultiplyHelper<Backend, Indexing::layout_left>::multiply( matrix, src_v, tgt_v, config );
}
else if( indexing == Indexing::layout_right ) {
SparseMatrixMultiplyHelper<Backend, Indexing::layout_right>::multiply( matrix, src_v, tgt_v, config );
}
else {
throw_NotImplemented( "indexing not implemented", Here() );
}
}
}

template <typename Backend, typename Matrix, typename SourceView, typename TargetView>
void dispatch_sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing,
const eckit::Configuration& config ) {
auto src_v = make_view( src );
auto tgt_v = make_view( tgt );

if ( introspection::layout_right( src ) || introspection::layout_right( tgt ) ) {
ATLAS_ASSERT( introspection::layout_right( src ) && introspection::layout_right( tgt ) );
// Override layout with known layout given by introspection
SparseMatrixMultiplyHelper<Backend, linalg::Indexing::layout_right>::multiply_add( matrix, src_v, tgt_v, config );
}
else {
if( indexing == Indexing::layout_left ) {
SparseMatrixMultiplyHelper<Backend, Indexing::layout_left>::apply( matrix, src_v, tgt_v, config );
SparseMatrixMultiplyHelper<Backend, Indexing::layout_left>::multiply_add( matrix, src_v, tgt_v, config );
}
else if( indexing == Indexing::layout_right ) {
SparseMatrixMultiplyHelper<Backend, Indexing::layout_right>::apply( matrix, src_v, tgt_v, config );
SparseMatrixMultiplyHelper<Backend, Indexing::layout_right>::multiply_add( matrix, src_v, tgt_v, config );
}
else {
throw_NotImplemented( "indexing not implemented", Here() );
Expand Down Expand Up @@ -108,6 +143,44 @@ void sparse_matrix_multiply( const Matrix& matrix, const SourceView& src, Target
sparse_matrix_multiply( matrix, src, tgt, Indexing::layout_left );
}

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing,
const eckit::Configuration& config ) {
std::string type = config.getString( "type", sparse::current_backend() );
if ( type == sparse::backend::openmp::type() ) {
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::openmp>( matrix, src, tgt, indexing, config );
}
else if ( type == sparse::backend::eckit_linalg::type() ) {
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::eckit_linalg>( matrix, src, tgt, indexing, config );
}
#if ATLAS_ECKIT_HAVE_ECKIT_585
else if( eckit::linalg::LinearAlgebraSparse::hasBackend(type) ) {
#else
else if( eckit::linalg::LinearAlgebra::hasBackend(type) ) {
#endif
sparse::dispatch_sparse_matrix_multiply_add<sparse::backend::eckit_linalg>( matrix, src, tgt, indexing, util::Config("backend",type) );
}
else {
throw_NotImplemented( "sparse_matrix_multiply_add cannot be performed with unsupported backend [" + type + "]",
Here() );
}
}

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, const eckit::Configuration& config ) {
sparse_matrix_multiply_add( matrix, src, tgt, Indexing::layout_left, config );
}

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt, Indexing indexing ) {
sparse_matrix_multiply_add( matrix, src, tgt, indexing, sparse::Backend() );
}

template <typename Matrix, typename SourceView, typename TargetView>
void sparse_matrix_multiply_add( const Matrix& matrix, const SourceView& src, TargetView& tgt ) {
sparse_matrix_multiply_add( matrix, src, tgt, Indexing::layout_left );
}

} // namespace linalg
} // namespace atlas

Expand Down
51 changes: 47 additions & 4 deletions src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "SparseMatrixMultiply_EckitLinalg.h"

#include "atlas/array.h"
#include "atlas/library/config.h"

#if ATLAS_ECKIT_HAVE_ECKIT_585
Expand Down Expand Up @@ -62,9 +63,15 @@ const eckit::linalg::LinearAlgebra& eckit_linalg_backend(const Configuration& co
}
#endif

template <typename Value, int Rank>
auto linalg_make_view(atlas::array::ArrayT<double>& array) {
auto v_array = array::make_view<Value, Rank>(array);
return atlas::linalg::make_view(v_array);
}

} // namespace

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::apply(
void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply(
const SparseMatrix& W, const View<const double, 1>& src, View<double, 1>& tgt, const Configuration& config) {
ATLAS_ASSERT(src.contiguous());
ATLAS_ASSERT(tgt.contiguous());
Expand All @@ -73,7 +80,7 @@ void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, cons
eckit_linalg_backend(config).spmv(W, v_src, v_tgt);
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::apply(
void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::multiply(
const SparseMatrix& W, const View<const double, 2>& src, View<double, 2>& tgt, const Configuration& config) {
ATLAS_ASSERT(src.contiguous());
ATLAS_ASSERT(tgt.contiguous());
Expand All @@ -84,9 +91,45 @@ void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, cons
eckit_linalg_backend(config).spmm(W, m_src, m_tgt);
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double>::apply(
void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double>::multiply(
const SparseMatrix& W, const View<const double, 1>& src, View<double, 1>& tgt, const Configuration& config) {
SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply(W, src, tgt,
config);
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply_add(
const SparseMatrix& W, const View<const double, 1>& src, View<double, 1>& tgt, const Configuration& config) {

array::ArrayT<double> tmp(src.shape(0));
auto v_tmp = linalg_make_view<double, 1>(tmp);
v_tmp.assign(0.);

SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply(W, src, v_tmp, config);

for (idx_t t = 0; t < tmp.shape(0); ++t) {
tgt(t) += v_tmp(t);
}
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::multiply_add(
const SparseMatrix& W, const View<const double, 2>& src, View<double, 2>& tgt, const Configuration& config) {

array::ArrayT<double> tmp(src.shape(0), src.shape(1));
auto v_tmp = linalg_make_view<double, 2>(tmp);
v_tmp.assign(0.);

SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double>::multiply(W, src, v_tmp, config);

for (idx_t t = 0; t < tmp.shape(0); ++t) {
for (idx_t k = 0; k < tmp.shape(1); ++k) {
tgt(t, k) += v_tmp(t, k);
}
}
}

void SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double>::multiply_add(
const SparseMatrix& W, const View<const double, 1>& src, View<double, 1>& tgt, const Configuration& config) {
SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::apply(W, src, tgt,
SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double>::multiply_add(W, src, tgt,
config);
}

Expand Down
12 changes: 9 additions & 3 deletions src/atlas/linalg/sparse/SparseMatrixMultiply_EckitLinalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,26 @@ namespace sparse {

template <>
struct SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 1, const double, double> {
static void apply(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
static void multiply(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
const Configuration&);
static void multiply_add(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
const Configuration&);
};

template <>
struct SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_right, 2, const double, double> {
static void apply(const SparseMatrix&, const View<const double, 2>& src, View<double, 2>& tgt,
static void multiply(const SparseMatrix&, const View<const double, 2>& src, View<double, 2>& tgt,
const Configuration&);
static void multiply_add(const SparseMatrix&, const View<const double, 2>& src, View<double, 2>& tgt,
const Configuration&);
};


template <>
struct SparseMatrixMultiply<backend::eckit_linalg, Indexing::layout_left, 1, const double, double> {
static void apply(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
static void multiply(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
const Configuration&);
static void multiply_add(const SparseMatrix&, const View<const double, 1>& src, View<double, 1>& tgt,
const Configuration&);
};

Expand Down
Loading
Loading