Skip to content

Commit

Permalink
Split FFT operations from halo communication
Browse files Browse the repository at this point in the history
  • Loading branch information
jngrad committed Aug 1, 2024
1 parent 2f97dd8 commit 0263159
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/core/electrostatics/p3m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ void CoulombP3MImpl<FloatType, Architecture>::init_cpu_kernels() {

assert(p3m.fft);
p3m.local_mesh.calc_local_ca_mesh(p3m.params, local_geo, skin, elc_layer);
p3m.fft->init_halo();
p3m.fft->init_fft();
p3m.calc_differential_operator();

Expand Down Expand Up @@ -390,6 +391,7 @@ Utils::Vector9d CoulombP3MImpl<FloatType, Architecture>::long_range_pressure(

if (p3m.sum_q2 > 0.) {
charge_assign(particles);
p3m.fft->perform_scalar_halo_gather();
p3m.fft->perform_scalar_fwd_fft();

auto constexpr mesh_start = Utils::Vector3i::broadcast(0);
Expand Down Expand Up @@ -455,6 +457,7 @@ double CoulombP3MImpl<FloatType, Architecture>::long_range_kernel(
system.coulomb.impl->solver)) {
charge_assign(particles);
}
p3m.fft->perform_scalar_halo_gather();
p3m.fft->perform_scalar_fwd_fft();
}

Expand Down Expand Up @@ -513,6 +516,7 @@ double CoulombP3MImpl<FloatType, Architecture>::long_range_kernel(
not p3m.params.tuning and check_complex_residuals;
p3m.fft->check_complex_residuals = check_residuals;
p3m.fft->perform_vector_back_fft();
p3m.fft->perform_vector_halo_spread();
p3m.fft->check_complex_residuals = false;

auto const force_prefac = prefactor / volume;
Expand Down
4 changes: 4 additions & 0 deletions src/core/magnetostatics/dp3m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ void DipolarP3MImpl<FloatType, Architecture>::init_cpu_kernels() {

assert(dp3m.fft);
dp3m.local_mesh.calc_local_ca_mesh(dp3m.params, local_geo, verlet_skin, 0.);
dp3m.fft->init_halo();
dp3m.fft->init_fft();
dp3m.calc_differential_operator();

Expand Down Expand Up @@ -252,6 +253,7 @@ double DipolarP3MImpl<FloatType, Architecture>::long_range_kernel(

if (dp3m.sum_mu2 > 0.) {
dipole_assign(particles);
dp3m.fft->perform_vector_halo_gather();
dp3m.fft->perform_vector_fwd_fft();
}

Expand Down Expand Up @@ -353,6 +355,7 @@ double DipolarP3MImpl<FloatType, Architecture>::long_range_kernel(
++index;
});
dp3m.fft->perform_scalar_back_fft();
dp3m.fft->perform_scalar_halo_spread();
/* Assign force component from mesh to particle */
auto const d_rs = (d + dp3m.mesh.ks_pnum) % 3;
Utils::integral_parameter<int, AssignTorques, 1, 7>(
Expand Down Expand Up @@ -404,6 +407,7 @@ double DipolarP3MImpl<FloatType, Architecture>::long_range_kernel(
++index;
});
dp3m.fft->perform_vector_back_fft();
dp3m.fft->perform_vector_halo_spread();
/* Assign force component from mesh to particle */
auto const d_rs = (d + dp3m.mesh.ks_pnum) % 3;
Utils::integral_parameter<int, AssignForces, 1, 7>(
Expand Down
29 changes: 22 additions & 7 deletions src/core/p3m/FFTBackendLegacy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@ void FFTBackendLegacy<FloatType>::update_mesh_data() {
}
}

template <typename FloatType> void FFTBackendLegacy<FloatType>::init_fft() {
template <typename FloatType> void FFTBackendLegacy<FloatType>::init_halo() {
mesh_comm.resize(::comm_cart, local_mesh);
}

template <typename FloatType> void FFTBackendLegacy<FloatType>::init_fft() {
auto ca_mesh_size = fft->initialize_fft(
::comm_cart, local_mesh.dim, local_mesh.margin, params.mesh,
params.mesh_off, mesh.ks_pnum, ::communicator.node_grid);
Expand All @@ -79,31 +82,41 @@ template <typename FloatType> void FFTBackendLegacy<FloatType>::init_fft() {

template <typename FloatType>
void FFTBackendLegacy<FloatType>::perform_vector_back_fft() {
/* Back FFT force component mesh */
for (auto &rs_mesh_field : rs_mesh_fields) {
fft->backward_fft(::comm_cart, rs_mesh_field.data(),
check_complex_residuals);
}
/* redistribute force component mesh */
}

template <typename FloatType>
void FFTBackendLegacy<FloatType>::perform_vector_halo_spread() {
std::array<FloatType *, 3u> meshes = {{rs_mesh_fields[0u].data(),
rs_mesh_fields[1u].data(),
rs_mesh_fields[2u].data()}};
mesh_comm.spread_grid(::comm_cart, meshes, local_mesh.dim);
}

template <typename FloatType>
void FFTBackendLegacy<FloatType>::perform_scalar_fwd_fft() {
void FFTBackendLegacy<FloatType>::perform_scalar_halo_gather() {
mesh_comm.gather_grid(::comm_cart, rs_mesh.data(), local_mesh.dim);
}

template <typename FloatType>
void FFTBackendLegacy<FloatType>::perform_scalar_fwd_fft() {
fft->forward_fft(::comm_cart, rs_mesh.data());
update_mesh_data();
}

template <typename FloatType>
void FFTBackendLegacy<FloatType>::perform_vector_fwd_fft() {
void FFTBackendLegacy<FloatType>::perform_vector_halo_gather() {
std::array<FloatType *, 3u> meshes = {{rs_mesh_fields[0u].data(),
rs_mesh_fields[1u].data(),
rs_mesh_fields[2u].data()}};
mesh_comm.gather_grid(::comm_cart, meshes, local_mesh.dim);
}

template <typename FloatType>
void FFTBackendLegacy<FloatType>::perform_vector_fwd_fft() {
for (auto &rs_mesh_field : rs_mesh_fields) {
fft->forward_fft(::comm_cart, rs_mesh_field.data());
}
Expand All @@ -112,9 +125,11 @@ void FFTBackendLegacy<FloatType>::perform_vector_fwd_fft() {

template <typename FloatType>
void FFTBackendLegacy<FloatType>::perform_scalar_back_fft() {
/* Back FFT force component mesh */
fft->backward_fft(::comm_cart, rs_mesh.data(), check_complex_residuals);
/* redistribute force component mesh */
}

template <typename FloatType>
void FFTBackendLegacy<FloatType>::perform_scalar_halo_spread() {
mesh_comm.spread_grid(::comm_cart, rs_mesh.data(), local_mesh.dim);
}

Expand Down
5 changes: 5 additions & 0 deletions src/core/p3m/FFTBackendLegacy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ class FFTBackendLegacy : public FFTBackend<FloatType> {
FFTBackendLegacy(p3m_data_struct_fft<FloatType> &obj, bool dipolar);
~FFTBackendLegacy() override;
void init_fft() override;
void init_halo() override;
void perform_scalar_fwd_fft() override;
void perform_vector_fwd_fft() override;
void perform_scalar_back_fft() override;
void perform_vector_back_fft() override;
void perform_scalar_halo_gather() override;
void perform_vector_halo_gather() override;
void perform_scalar_halo_spread() override;
void perform_vector_halo_spread() override;
void update_mesh_data();

/**
Expand Down
10 changes: 10 additions & 0 deletions src/core/p3m/data_struct.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ template <typename FloatType> class FFTBackend {
virtual ~FFTBackend() = default;
/** @brief Initialize the FFT plans and buffers. */
virtual void init_fft() = 0;
/** @brief Initialize the halo buffers. */
virtual void init_halo() = 0;
/** @brief Carry out the forward FFT of the scalar mesh. */
virtual void perform_scalar_fwd_fft() = 0;
/** @brief Carry out the forward FFT of the vector meshes. */
Expand All @@ -112,6 +114,14 @@ template <typename FloatType> class FFTBackend {
virtual void perform_scalar_back_fft() = 0;
/** @brief Carry out the backward FFT of the vector meshes. */
virtual void perform_vector_back_fft() = 0;
/** @brief Update scalar mesh halo with data from neighbors (accumulation). */
virtual void perform_scalar_halo_gather() = 0;
/** @brief Update vector mesh halo with data from neighbors (accumulation). */
virtual void perform_vector_halo_gather() = 0;
/** @brief Update scalar mesh halo of all neighbors. */
virtual void perform_scalar_halo_spread() = 0;
/** @brief Update vector mesh halo of all neighbors. */
virtual void perform_vector_halo_spread() = 0;
/** @brief Get indices of the k-space data layout. */
virtual std::tuple<int, int, int> get_permutations() const = 0;
};
Expand Down

0 comments on commit 0263159

Please sign in to comment.