Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Waves simulation optimisation - part 3 #87

Merged
merged 11 commits into from
Nov 23, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Waves: optimise current amplitude calculation for FFT wave simulation
- Investigate optimisation options for the current amplitude calculation.
- vectorised assignment to fft workspace is not faster than a single loop.
- pre-calculate the fourier amplitude coefficients (reduce index lookup) - marginal gain?

Signed-off-by: Rhys Mainwaring <[email protected]>
srmainwaring committed Nov 22, 2022
commit 7b1e5989c2a7ceccfa351f7b44a6ce6fbea06440
150 changes: 112 additions & 38 deletions gz-waves/src/WaveSimulationFFT2.cc
Original file line number Diff line number Diff line change
@@ -528,67 +528,113 @@ namespace waves

// angular temporal frequency for time-dependent (from dispersion)
omega_k_vec_ = Eigen::sqrt(gravity_ * k.array());

// calculate zhat0
const Eigen::Ref<const Eigen::MatrixXd>& r = rho_vec_;
const Eigen::Ref<const Eigen::MatrixXd>& s = sigma_vec_;
const Eigen::Ref<const Eigen::MatrixXd>& psi_root = cap_psi_2s_root_vec_;

zhat0_rc_ = Eigen::MatrixXd::Zero(nx_, ny_);
zhat0_rs_ = Eigen::MatrixXd::Zero(nx_, ny_);
zhat0_ic_ = Eigen::MatrixXd::Zero(nx_, ny_);
zhat0_is_ = Eigen::MatrixXd::Zero(nx_, ny_);
for (int ikx = 1; ikx < nx_; ++ikx)
{
for (int iky = 1; iky < ny_; ++iky)
{
zhat0_rc_(ikx, iky) = + ( r(ikx, iky) * psi_root(ikx, iky) + r(nx_-ikx, ny_-iky) * psi_root(nx_-ikx, ny_-iky) );
zhat0_rs_(ikx, iky) = + ( s(ikx, iky) * psi_root(ikx, iky) + s(nx_-ikx, ny_-iky) * psi_root(nx_-ikx, ny_-iky) );
zhat0_is_(ikx, iky) = - ( r(ikx, iky) * psi_root(ikx, iky) - r(nx_-ikx, ny_-iky) * psi_root(nx_-ikx, ny_-iky) );
zhat0_ic_(ikx, iky) = + ( s(ikx, iky) * psi_root(ikx, iky) - s(nx_-ikx, ny_-iky) * psi_root(nx_-ikx, ny_-iky) );
}
}

for (int iky = 1; iky < ny_/2+1; ++iky)
{
int ikx = 0;
zhat0_rc_(ikx, iky) = + ( r(ikx, iky) * psi_root(ikx, iky) + r(ikx, ny_-iky) * psi_root(ikx, ny_-iky) );
zhat0_rs_(ikx, iky) = + ( s(ikx, iky) * psi_root(ikx, iky) + s(ikx, ny_-iky) * psi_root(ikx, ny_-iky) );
zhat0_is_(ikx, iky) = - ( r(ikx, iky) * psi_root(ikx, iky) - r(ikx, ny_-iky) * psi_root(ikx, ny_-iky) );
zhat0_ic_(ikx, iky) = + ( s(ikx, iky) * psi_root(ikx, iky) - s(ikx, ny_-iky) * psi_root(ikx, ny_-iky) );
}

for (int ikx = 1; ikx < nx_/2+1; ++ikx)
{
int iky = 0;
zhat0_rc_(ikx, iky) = + ( r(ikx, iky) * psi_root(ikx, iky) + r(nx_-ikx, iky) * psi_root(nx_-ikx, iky) );
zhat0_rs_(ikx, iky) = + ( s(ikx, iky) * psi_root(ikx, iky) + s(nx_-ikx, iky) * psi_root(nx_-ikx, iky) );
zhat0_is_(ikx, iky) = - ( r(ikx, iky) * psi_root(ikx, iky) - r(nx_-ikx, iky) * psi_root(nx_-ikx, iky) );
zhat0_ic_(ikx, iky) = + ( s(ikx, iky) * psi_root(ikx, iky) - s(nx_-ikx, iky) * psi_root(nx_-ikx, iky) );
}
}

// Compile-time switch for the code guarded by `#if VECTORISE_ZHAT_CALCS`:
// a non-zero value compiles the Eigen array-expression version, 0 compiles
// the `#else` loop version.
#define VECTORISE_ZHAT_CALCS 0

//////////////////////////////////////////////////
/// \brief Compute the Fourier amplitudes zhat for the given time and write
/// the derived quantities into fft_h_, fft_h_ikx_, fft_h_iky_, fft_sx_,
/// fft_sy_, fft_h_kxkx_, fft_h_kyky_ and fft_h_kxky_.
///
/// \param time value multiplied with omega_k_vec_ to form the argument of
/// the cos/sin factors applied to each amplitude.
///
/// NOTE(review): this listing appears to interleave the removed and the
/// added lines of a diff: zhat is declared twice, each complex(...) call is
/// followed by a second set of arguments that use the zhat0_* matrices, and
/// several assignments in the final loop are repeated. A raw hunk header is
/// also embedded before the inner loop body. Confirm against the real file
/// before relying on statement order here.
void WaveSimulationFFT2Impl::ComputeCurrentAmplitudesVectorised(
double time)
{
// alias
const Eigen::Ref<const Eigen::MatrixXd>& r = rho_vec_;
const Eigen::Ref<const Eigen::MatrixXd>& s = sigma_vec_;
const Eigen::Ref<const Eigen::MatrixXd>& psi_root = cap_psi_2s_root_vec_;

// time update: element-wise cos and sin of omega_k_vec_ * time
Eigen::MatrixXd wt = omega_k_vec_.array() * time;
Eigen::MatrixXd cos_omega_k = Eigen::cos(wt.array());
Eigen::MatrixXd sin_omega_k = Eigen::sin(wt.array());

// non-vectorised reference version
Eigen::MatrixXcd zhat = Eigen::MatrixXcd::Zero(nx_, ny_);
Eigen::MatrixXcdRowMajor zhat = Eigen::MatrixXcd::Zero(nx_, ny_);
// entries with ikx >= 1 and iky >= 1: each combines the value at (ikx, iky)
// with the value at the mirrored index (nx_-ikx, ny_-iky)
for (int ikx = 1; ikx < nx_; ++ikx)
{
for (int iky = 1; iky < ny_; ++iky)
{
zhat(ikx, iky) = complex(
+ ( r(ikx, iky) * psi_root(ikx, iky) + r(nx_-ikx, ny_-iky) * psi_root(nx_-ikx, ny_-iky) ) * cos_omega_k(ikx, iky)
+ ( s(ikx, iky) * psi_root(ikx, iky) + s(nx_-ikx, ny_-iky) * psi_root(nx_-ikx, ny_-iky) ) * sin_omega_k(ikx, iky),
- ( r(ikx, iky) * psi_root(ikx, iky) - r(nx_-ikx, ny_-iky) * psi_root(nx_-ikx, ny_-iky) ) * sin_omega_k(ikx, iky)
+ ( s(ikx, iky) * psi_root(ikx, iky) - s(nx_-ikx, ny_-iky) * psi_root(nx_-ikx, ny_-iky) ) * cos_omega_k(ikx, iky));
+ zhat0_rc_(ikx, iky) * cos_omega_k(ikx, iky)
+ zhat0_rs_(ikx, iky) * sin_omega_k(ikx, iky),
+ zhat0_is_(ikx, iky) * sin_omega_k(ikx, iky)
+ zhat0_ic_(ikx, iky) * cos_omega_k(ikx, iky));
}
}

// ikx == 0: compute iky in [1, ny_/2] and fill the mirrored entry
// (0, ny_-iky) with the complex conjugate
for (int iky = 1; iky < ny_/2+1; ++iky)
{
int ikx = 0;
zhat(ikx, iky) = complex(
+ ( r(ikx, iky) * psi_root(ikx, iky) + r(ikx, ny_-iky) * psi_root(ikx, ny_-iky) ) * cos_omega_k(ikx, iky)
+ ( s(ikx, iky) * psi_root(ikx, iky) + s(ikx, ny_-iky) * psi_root(ikx, ny_-iky) ) * sin_omega_k(ikx, iky),
- ( r(ikx, iky) * psi_root(ikx, iky) - r(ikx, ny_-iky) * psi_root(ikx, ny_-iky) ) * sin_omega_k(ikx, iky)
+ ( s(ikx, iky) * psi_root(ikx, iky) - s(ikx, ny_-iky) * psi_root(ikx, ny_-iky) ) * cos_omega_k(ikx, iky));
+ zhat0_rc_(ikx, iky) * cos_omega_k(ikx, iky)
+ zhat0_rs_(ikx, iky) * sin_omega_k(ikx, iky),
+ zhat0_is_(ikx, iky) * sin_omega_k(ikx, iky)
+ zhat0_ic_(ikx, iky) * cos_omega_k(ikx, iky));
zhat(ikx, ny_-iky) = std::conj(zhat(ikx, iky));
}

// iky == 0: compute ikx in [1, nx_/2] and fill the mirrored entry
// (nx_-ikx, 0) with the complex conjugate
for (int ikx = 1; ikx < nx_/2+1; ++ikx)
{
int iky = 0;
zhat(ikx, iky) = complex(
+ ( r(ikx, iky) * psi_root(ikx, iky) + r(nx_-ikx, iky) * psi_root(nx_-ikx, iky) ) * cos_omega_k(ikx, iky)
+ ( s(ikx, iky) * psi_root(ikx, iky) + s(nx_-ikx, iky) * psi_root(nx_-ikx, iky) ) * sin_omega_k(ikx, iky),
- ( r(ikx, iky) * psi_root(ikx, iky) - r(nx_-ikx, iky) * psi_root(nx_-ikx, iky) ) * sin_omega_k(ikx, iky)
+ ( s(ikx, iky) * psi_root(ikx, iky) - s(nx_-ikx, iky) * psi_root(nx_-ikx, iky) ) * cos_omega_k(ikx, iky));
+ zhat0_rc_(ikx, iky) * cos_omega_k(ikx, iky)
+ zhat0_rs_(ikx, iky) * sin_omega_k(ikx, iky),
+ zhat0_is_(ikx, iky) * sin_omega_k(ikx, iky)
+ zhat0_ic_(ikx, iky) * cos_omega_k(ikx, iky));
zhat(nx_-ikx, iky) = std::conj(zhat(ikx, iky));
}

// the (0, 0) entry is explicitly set to zero
zhat(0, 0) = complex(0.0, 0.0);

/// \todo: change zhat to 1D array and use directly
// zhat = zhat.reshaped<Eigen::RowMajor>();
// zhat = zhat.reshaped<Eigen::ColMajor>();

// write into fft_h_, fft_h_ikx_, fft_h_iky_, etc.
const complex iunit(0.0, 1.0);
const complex czero(0.0, 0.0);

/// \note array version is not faster than the loop.
#if VECTORISE_ZHAT_CALCS
{ // vectorised version: note: ook_ evaluates to zero when abs(k) < 1.0E-8
fft_h_ = zhat; //h
fft_h_ikx_ = zhat.array() * iunit * kx_.array(); //hikx
fft_h_iky_ = zhat.array() * iunit * ky_.array(); //hiky
fft_sx_ = zhat.array() * iunit * ook_.array() * kx_.array() * -1; //dx
fft_sy_ = zhat.array() * iunit * ook_.array() * ky_.array() * -1; //dy
fft_h_kxkx_ = zhat.array() * ook_.array() * kx2_.array(); //hkxkx
fft_h_kyky_ = zhat.array() * ook_.array() * ky2_.array(); //hkyky
fft_h_kxky_ = zhat.array() * ook_.array() * kx_.array() * ky_.array(); //hkxky
}
#else
// loop version
for (int ikx = 0; ikx < nx_; ++ikx)
{
double kx = kx_fft_[ikx];
@@ -598,48 +644,51 @@ namespace waves
double ky = ky_fft_[iky];
double ky2 = ky*ky;
double k = sqrt(kx2 + ky2);
double ook = 1.0 / k;

complex h = zhat(ikx, iky);
// elevation
complex h = zhat(ikx, iky);
complex hi = h * iunit;
complex hok = h * ook;
complex hiok = hi * ook;

// height (amplitude)
fft_h_(ikx, iky) = h;

// height derivatives
// elevation derivatives
complex hikx = hi * kx;
complex hiky = hi * ky;

fft_h_ikx_(ikx, iky) = hi * kx;
fft_h_iky_(ikx, iky) = hi * ky;
fft_h_(ikx, iky) = h;
fft_h_ikx_(ikx, iky) = hikx;
fft_h_iky_(ikx, iky) = hiky;

// displacement and derivatives
// when k is (numerically) zero the terms scaled by 1/k are stored as zero
// instead of being computed
if (std::abs(k) < 1.0E-8)
{
fft_sx_(ikx, iky) = czero;
fft_sy_(ikx, iky) = czero;
fft_sx_(ikx, iky) = czero;
fft_sy_(ikx, iky) = czero;
fft_h_kxkx_(ikx, iky) = czero;
fft_h_kyky_(ikx, iky) = czero;
fft_h_kxky_(ikx, iky) = czero;
}
else
{
complex dx = - hiok * kx;
complex dy = - hiok * ky;
// displacements
double ook = 1.0 / k;
complex hok = h * ook;
complex hiok = hi * ook;
complex dx = - hiok * kx;
complex dy = - hiok * ky;

// displacements derivatives
complex hkxkx = hok * kx2;
complex hkyky = hok * ky2;
complex hkxky = hok * kx * ky;

fft_sx_(ikx, iky) = dx;
fft_sy_(ikx, iky) = dy;
fft_sx_(ikx, iky) = dx;
fft_sy_(ikx, iky) = dy;
fft_h_kxkx_(ikx, iky) = hkxkx;
fft_h_kyky_(ikx, iky) = hkyky;
fft_h_kxky_(ikx, iky) = hkxky;
}
}
}
#endif
}

//////////////////////////////////////////////////
@@ -972,6 +1021,31 @@ namespace waves
ky_math_(iky) = ky;
ky_fft_((iky + ny_/2) % ny_) = ky;
}

#if VECTORISE_ZHAT_CALCS
// broadcast (fft) wavenumbers to arrays (aka meshgrid)
kx_ = Eigen::MatrixXd::Zero(nx_, ny_);
ky_ = Eigen::MatrixXd::Zero(nx_, ny_);
kx_.colwise() += kx_fft_;
ky_.rowwise() += ky_fft_.transpose();

// wavenumber and wave angle arrays
kx2_ = Eigen::pow(kx_.array(), 2.0);
ky2_ = Eigen::pow(ky_.array(), 2.0);
k_ = Eigen::sqrt(kx2_.array() + ky2_.array());
theta_ = ky_.binaryExpr(
kx_, [] (double y, double x) { return std::atan2(y, x);}
);

// array k_plus_ has no elements where abs(k_plus_) < 1.0E-8
k_plus_ = (Eigen::abs(k_.array()) < 1.0E-8).select(
Eigen::MatrixXd::Ones(nx_, ny_), k_);

// set 1/k to zero when abs(k) < 1.0E-8 as the quantities it multiplies
// have zero as the limit as k -> 0.
ook_ = (Eigen::abs(k_.array()) < 1.0E-8).select(
Eigen::MatrixXd::Zero(nx_, ny_), 1.0 / k_plus_.array());
#endif
}

//////////////////////////////////////////////////
14 changes: 14 additions & 0 deletions gz-waves/src/WaveSimulationFFT2Impl.hh
Original file line number Diff line number Diff line change
@@ -155,6 +155,12 @@ namespace waves
fftw_plan fft_plan0_, fft_plan1_, fft_plan2_, fft_plan3_;
fftw_plan fft_plan4_, fft_plan5_, fft_plan6_, fft_plan7_;

// precalculated amplitudes (t=0)
Eigen::MatrixXd zhat0_rc_;
Eigen::MatrixXd zhat0_rs_;
Eigen::MatrixXd zhat0_ic_;
Eigen::MatrixXd zhat0_is_;

/// \brief Flag to select whether to use vectorised calculations.
bool use_vectorised_{false};

@@ -217,6 +223,14 @@ namespace waves
Eigen::VectorXd ky_fft_;
Eigen::VectorXd kx_math_;
Eigen::VectorXd ky_math_;
Eigen::MatrixXd kx_;
Eigen::MatrixXd ky_;
Eigen::MatrixXd kx2_;
Eigen::MatrixXd ky2_;
Eigen::MatrixXd k_;
Eigen::MatrixXd k_plus_;
Eigen::MatrixXd theta_;
Eigen::MatrixXd ook_;

/// \brief Set to true to use a symmetric spreading function (standing waves).
bool use_symmetric_spreading_fn_{false};