Skip to content

Commit

Permalink
saving on some gemms (#664)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkstoyanov authored Feb 14, 2024
1 parent ad8f34b commit 54a1cc2
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions src/pde/pde_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,26 +341,46 @@ class term
// Recombine the partial terms to form the new coefficient matrix.
//
// Writes the product of the partial-term coefficient matrices, evaluated at
// the level of `adapted_dim`, into the leading new_dof x new_dof sub-block of
// coefficients_, where new_dof = degree * 2^level.
//
// The product is formed without starting from an identity matrix, which saves
// one gemm per call:
//   - no partial terms : the sub-block is set to the identity
//   - one partial term : its coefficient matrix is copied as-is
//   - several terms    : the matrices are multiplied left-to-right in the
//                        order they are stored in partial_terms_
void rechain_coefficients(dimension<P> const &adapted_dim)
{
  int const level = adapted_dim.get_level();

  auto const new_dof = adapted_dim.get_degree() * fm::two_raised_to(level);
  expect(coefficients_.nrows() == coefficients_.ncols());

  if (partial_terms_.empty())
  {
    // no partial_terms? don't know if this can happen
    fk::matrix<P, mem_type::view>(coefficients_, 0, new_dof - 1, 0,
                                  new_dof - 1) = eye<P>(new_dof);
  }
  else if (partial_terms_.size() == 1)
  {
    // there's only one coefficient matrix, just copy it
    // (the copy itself is probably wasteful too)
    auto const &new_mat = partial_terms_[0].get_coefficients(level);

    // make sure the partial term has been built large enough to support
    // the new level/degree
    expect(new_mat.ncols() == new_dof);

    fk::matrix<P, mem_type::view>(coefficients_, 0, new_dof - 1, 0,
                                  new_dof - 1) = new_mat;
  }
  else
  {
    // multiplying the matrices requires two work matrices:
    // temp1 holds the cumulative product, temp2 receives the next product
    fk::matrix<P> temp1 = partial_terms_[0].get_coefficients(level);
    fk::matrix<P> temp2(temp1.nrows(), temp1.ncols());

    // make sure the partial term has been built large enough
    expect(temp1.ncols() == new_dof);

    for (size_t i = 1; i < partial_terms_.size(); i++)
    {
      auto const &pmat = partial_terms_[i].get_coefficients(level);

      // every partial term must support the new level/degree, not just
      // the first one
      expect(pmat.ncols() == new_dof);

      fm::gemm(temp1, pmat, temp2);
      // the product just written to temp2 becomes the cumulative matrix;
      // the old cumulative matrix is reused as scratch in the next pass
      std::swap(temp1, temp2);
    }

    fk::matrix<P, mem_type::view>(coefficients_, 0, new_dof - 1, 0,
                                  new_dof - 1) = temp1;
  }
}

// public but const data. no getters
Expand Down

0 comments on commit 54a1cc2

Please sign in to comment.