Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AnyFFT #1099

Merged
merged 12 commits into from
May 2, 2024
Merged
30 changes: 23 additions & 7 deletions docs/source/run/parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,31 @@ The default is to use the explicit solver. **We strongly recommend to use the ex
Which solver to use.
Possible values: ``explicit`` and ``predictor-corrector``.

* ``fields.poisson_solver`` (`string`) optional (default `FFTDirichlet`)
* ``fields.poisson_solver`` (`string`) optional (default CPU: `FFTDirichletDirect`, GPU: `FFTDirichletFast`)
Which Poisson solver to use for ``Psi``, ``Ez`` and ``Bz``. The ``predictor-corrector`` BxBy
solver also uses this poisson solver for ``Bx`` and ``By`` internally. Available solvers are
``FFTDirichlet``, ``FFTPeriodic`` and ``MGDirichlet``.
solver also uses this poisson solver for ``Bx`` and ``By`` internally. Available solvers are:

* ``hipace.use_small_dst`` (`bool`) optional (default `0` or `1`)
Whether to use a large R2C or a small C2R fft in the dst of the Poisson solver.
The small dst is quicker for simulations with :math:`\geq 511` transverse grid points.
The default is set accordingly.
* ``FFTDirichletDirect`` Use the discrete sine transformation that is directly implemented
by FFTW to solve the Poisson equation with Dirichlet boundary conditions.
This option is only available when compiling for CPUs with FFTW.
Preferred resolution: :math:`2^N-1`.

* ``FFTDirichletExpanded`` Perform the discrete sine transformation by symmetrically
expanding the field to twice its size.
Preferred resolution: :math:`2^N-1`.

* ``FFTDirichletFast`` Perform the discrete sine transformation using a fast sine transform
algorithm that uses FFTs of the same size as the fields.
Preferred resolution: :math:`2^N-1`.

* ``MGDirichlet`` Use the HiPACE++ multigrid solver to solve the Poisson equation with
Dirichlet boundary conditions.
Preferred resolution: :math:`2^N` and :math:`2^N-1`.

* ``FFTPeriodic`` Use FFTs to solve the Poisson equation with Periodic boundary conditions.
Note that this does not work with features that change the boundary values,
like mesh refinement or open boundaries.
Preferred resolution: :math:`2^N`.
Comment on lines +262 to +282
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice doc!


* ``fields.extended_solve`` (`bool`) optional (default `0`)
Extends the area of the FFT Poisson solver to the ghost cells. This can reduce artifacts
Expand Down
3 changes: 3 additions & 0 deletions src/Hipace.H
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ struct Hipace_early_init
* and Parser Constants */
Hipace_early_init (Hipace* instance);

/** Destructor for FFT cleanup */
~Hipace_early_init ();

/** Struct containing physical constants (which values depends on the unit system, determined
* at runtime): SI or normalized units. */
PhysConst m_phys_const;
Expand Down
7 changes: 7 additions & 0 deletions src/Hipace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "utils/GPUUtil.H"
#include "particles/pusher/GetAndSetPosition.H"
#include "mg_solver/HpMultiGrid.H"
#include "fields/fft_poisson_solver/fft/AnyFFT.H"

#include <AMReX_ParmParse.H>
#include <AMReX_IntVect.H>
Expand Down Expand Up @@ -53,6 +54,12 @@ Hipace_early_init::Hipace_early_init (Hipace* instance)
int max_level = 0;
queryWithParser(pp_amr, "max_level", max_level);
m_N_level = max_level + 1;
AnyFFT::setup();
}

Hipace_early_init::~Hipace_early_init ()
{
AnyFFT::cleanup();
}

Hipace&
Expand Down
2 changes: 1 addition & 1 deletion src/fields/Fields.H
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ private:
/** Vector over levels of all required fields to compute current slice */
amrex::Vector<amrex::MultiFab> m_slices;
/** Type of poisson solver to use */
std::string m_poisson_solver_str = "FFTDirichlet";
std::string m_poisson_solver_str = "";
/** Class to handle transverse FFT Poisson solver on 1 slice */
amrex::Vector<std::unique_ptr<FFTPoissonSolver>> m_poisson_solver;
/** Stores temporary values for z interpolation in Fields::Copy */
Expand Down
33 changes: 26 additions & 7 deletions src/fields/Fields.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
*/
#include "Fields.H"
#include "fft_poisson_solver/FFTPoissonSolverPeriodic.H"
#include "fft_poisson_solver/FFTPoissonSolverDirichlet.H"
#include "fft_poisson_solver/FFTPoissonSolverDirichletDirect.H"
#include "fft_poisson_solver/FFTPoissonSolverDirichletExpanded.H"
#include "fft_poisson_solver/FFTPoissonSolverDirichletFast.H"
#include "fft_poisson_solver/MGPoissonSolverDirichlet.H"
#include "Hipace.H"
#include "OpenBoundary.H"
Expand All @@ -29,6 +31,12 @@ Fields::Fields (const int nlev)
{
amrex::ParmParse ppf("fields");
DeprecatedInput("fields", "do_dirichlet_poisson", "poisson_solver", "");
// set default Poisson solver based on the platform
#ifdef AMREX_USE_GPU
m_poisson_solver_str = "FFTDirichletFast";
#else
m_poisson_solver_str = "FFTDirichletDirect";
#endif
queryWithParser(ppf, "poisson_solver", m_poisson_solver_str);
queryWithParser(ppf, "extended_solve", m_extended_solve);
queryWithParser(ppf, "open_boundary", m_open_boundary);
Expand Down Expand Up @@ -178,11 +186,21 @@ Fields::AllocData (
// The Poisson solver operates on transverse slices only.
// The constructor takes the BoxArray and the DistributionMap of a slice,
// so the FFTPlans are built on a slice.
if (m_poisson_solver_str == "FFTDirichlet"){
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichlet>(
new FFTPoissonSolverDirichlet(getSlices(lev).boxArray(),
getSlices(lev).DistributionMap(),
geom)) );
if (m_poisson_solver_str == "FFTDirichletDirect"){
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletDirect>(
new FFTPoissonSolverDirichletDirect(getSlices(lev).boxArray(),
getSlices(lev).DistributionMap(),
geom)) );
} else if (m_poisson_solver_str == "FFTDirichletExpanded"){
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletExpanded>(
new FFTPoissonSolverDirichletExpanded(getSlices(lev).boxArray(),
getSlices(lev).DistributionMap(),
geom)) );
} else if (m_poisson_solver_str == "FFTDirichletFast"){
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverDirichletFast>(
new FFTPoissonSolverDirichletFast(getSlices(lev).boxArray(),
getSlices(lev).DistributionMap(),
geom)) );
} else if (m_poisson_solver_str == "FFTPeriodic") {
m_poisson_solver.push_back(std::unique_ptr<FFTPoissonSolverPeriodic>(
new FFTPoissonSolverPeriodic(getSlices(lev).boxArray(),
Expand All @@ -195,7 +213,8 @@ Fields::AllocData (
geom)) );
} else {
amrex::Abort("Unknown poisson solver '" + m_poisson_solver_str +
"', must be 'FFTDirichlet', 'FFTPeriodic' or 'MGDirichlet'");
"', must be 'FFTDirichletDirect', 'FFTDirichletExpanded', 'FFTDirichletFast', " +
"'FFTPeriodic' or 'MGDirichlet'");
}

if (lev == 0 && m_insitu_period > 0) {
Expand Down
4 changes: 3 additions & 1 deletion src/fields/fft_poisson_solver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ target_sources(HiPACE
PRIVATE
FFTPoissonSolver.cpp
FFTPoissonSolverPeriodic.cpp
FFTPoissonSolverDirichlet.cpp
FFTPoissonSolverDirichletDirect.cpp
FFTPoissonSolverDirichletExpanded.cpp
FFTPoissonSolverDirichletFast.cpp
MGPoissonSolverDirichlet.cpp
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
* Authors: AlexanderSinn, MaxThevenet, Severin Diederichs
* License: BSD-3-Clause-LBNL
*/
#ifndef FFT_POISSON_SOLVER_DIRICHLET_H_
#define FFT_POISSON_SOLVER_DIRICHLET_H_
#ifndef FFT_POISSON_SOLVER_DIRICHLET_DIRECT_H_
#define FFT_POISSON_SOLVER_DIRICHLET_DIRECT_H_

#include "fields/fft_poisson_solver/fft/AnyDST.H"
#include "fields/fft_poisson_solver/fft/AnyFFT.H"
#include "FFTPoissonSolver.H"

#include <AMReX_MultiFab.H>
Expand All @@ -23,16 +23,16 @@
* 2. Call FFTPoissonSolver::SolvePoissonEquation(mf), which will solve Poisson equation with RHS
* in the staging area and return the LHS in mf.
*/
class FFTPoissonSolverDirichlet final : public FFTPoissonSolver
class FFTPoissonSolverDirichletDirect final : public FFTPoissonSolver
{
public:
/** Constructor */
FFTPoissonSolverDirichlet ( amrex::BoxArray const& a_realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm);
FFTPoissonSolverDirichletDirect ( amrex::BoxArray const& a_realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm);

/** virtual destructor */
virtual ~FFTPoissonSolverDirichlet () override final {}
virtual ~FFTPoissonSolverDirichletDirect () override final {}

/**
* \brief Define real space and spectral space boxes and multifabs, Dirichlet
Expand Down Expand Up @@ -63,8 +63,12 @@ private:
amrex::MultiFab m_tmpSpectralField;
/** Multifab eigenvalues, to solve Poisson equation with Dirichlet BC. */
amrex::MultiFab m_eigenvalue_matrix;
/** DST plans */
AnyDST::DSTplans m_plan;
/** forward DST plan */
AnyFFT m_forward_fft;
/** backward DST plan */
AnyFFT m_backward_fft;
/** work area for both DST plans */
amrex::Gpu::DeviceVector<char> m_fft_work_area;
};

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
*
* License: BSD-3-Clause-LBNL
*/
#include "FFTPoissonSolverDirichlet.H"
#include "fft/AnyDST.H"
#include "FFTPoissonSolverDirichletDirect.H"
#include "fft/AnyFFT.H"
#include "fields/Fields.H"
#include "utils/Constants.H"
#include "utils/GPUUtil.H"
#include "utils/HipaceProfilerWrapper.H"

FFTPoissonSolverDirichlet::FFTPoissonSolverDirichlet (
FFTPoissonSolverDirichletDirect::FFTPoissonSolverDirichletDirect (
amrex::BoxArray const& realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm )
Expand All @@ -22,10 +22,11 @@ FFTPoissonSolverDirichlet::FFTPoissonSolverDirichlet (
}

void
FFTPoissonSolverDirichlet::define (amrex::BoxArray const& a_realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm )
FFTPoissonSolverDirichletDirect::define (amrex::BoxArray const& a_realspace_ba,
amrex::DistributionMapping const& dm,
amrex::Geometry const& gm )
{
HIPACE_PROFILE("FFTPoissonSolverDirichletDirect::define()");
using namespace amrex::literals;

// If we are going to support parallel FFT, the constructor needs to take a communicator.
Expand All @@ -48,16 +49,18 @@ FFTPoissonSolverDirichlet::define (amrex::BoxArray const& a_realspace_ba,
"There should be only one box locally.");

const amrex::Box fft_box = m_stagingArea[0].box();
const amrex::IntVect fft_size = fft_box.length();
const int nx = fft_size[0];
const int ny = fft_size[1];
const auto dx = gm.CellSizeArray();
const amrex::Real dxsquared = dx[0]*dx[0];
const amrex::Real dysquared = dx[1]*dx[1];
const amrex::Real sine_x_factor = MathConst::pi / ( 2. * ( fft_box.length(0) + 1 ));
const amrex::Real sine_y_factor = MathConst::pi / ( 2. * ( fft_box.length(1) + 1 ));
const amrex::Real sine_x_factor = MathConst::pi / ( 2. * ( nx + 1 ));
const amrex::Real sine_y_factor = MathConst::pi / ( 2. * ( ny + 1 ));

// Normalization of FFTW's 'DST-I' discrete sine transform (FFTW_RODFT00)
// This normalization is used regardless of the sine transform library
const amrex::Real norm_fac = 0.5 / ( 2 * (( fft_box.length(0) + 1 )
*( fft_box.length(1) + 1 )));
const amrex::Real norm_fac = 0.5 / ( 2 * (( nx + 1 ) * ( ny + 1 )));

// Calculate the array of m_eigenvalue_matrix
for (amrex::MFIter mfi(m_eigenvalue_matrix, DfltMfi); mfi.isValid(); ++mfi ){
Expand All @@ -67,9 +70,9 @@ FFTPoissonSolverDirichlet::define (amrex::BoxArray const& a_realspace_ba,
fft_box, [=] AMREX_GPU_DEVICE (int i, int j, int /* k */) noexcept
{
/* fast poisson solver diagonal x coeffs */
amrex::Real sinex_sq = sin(( i - lo[0] + 1 ) * sine_x_factor) * sin(( i - lo[0] + 1 ) * sine_x_factor);
amrex::Real sinex_sq = std::sin(( i - lo[0] + 1 ) * sine_x_factor) * std::sin(( i - lo[0] + 1 ) * sine_x_factor);
/* fast poisson solver diagonal y coeffs */
amrex::Real siney_sq = sin(( j - lo[1] + 1 ) * sine_y_factor) * sin(( j - lo[1] + 1 ) * sine_y_factor);
amrex::Real siney_sq = std::sin(( j - lo[1] + 1 ) * sine_y_factor) * std::sin(( j - lo[1] + 1 ) * sine_y_factor);

if ((sinex_sq!=0) && (siney_sq!=0)) {
eigenvalue_matrix(i,j) = norm_fac / ( -4.0 * ( sinex_sq / dxsquared + siney_sq / dysquared ));
Expand All @@ -81,29 +84,25 @@ FFTPoissonSolverDirichlet::define (amrex::BoxArray const& a_realspace_ba,
}

// Allocate and initialize the FFT plans
m_plan = AnyDST::DSTplans(a_realspace_ba, dm);
// Loop over boxes and allocate the corresponding plan
// for each box owned by the local MPI proc
for ( amrex::MFIter mfi(m_stagingArea, DfltMfi); mfi.isValid(); ++mfi ){
// Note: the size of the real-space box and spectral-space box
// differ when using real-to-complex FFT. When initializing
// the FFT plan, the valid dimensions are those of the real-space box.
amrex::IntVect fft_size = fft_box.length();
m_plan[mfi] = AnyDST::CreatePlan(
fft_size, &m_stagingArea[mfi], &m_tmpSpectralField[mfi]);
}
std::size_t fwd_area = m_forward_fft.Initialize(FFTType::R2R_2D, fft_size[0], fft_size[1]);
std::size_t bkw_area = m_backward_fft.Initialize(FFTType::R2R_2D, fft_size[0], fft_size[1]);

// Allocate work area for both FFTs
m_fft_work_area.resize(std::max(fwd_area, bkw_area));

m_forward_fft.SetBuffers(m_stagingArea[0].dataPtr(), m_tmpSpectralField[0].dataPtr(),
m_fft_work_area.dataPtr());
m_backward_fft.SetBuffers(m_tmpSpectralField[0].dataPtr(), m_stagingArea[0].dataPtr(),
m_fft_work_area.dataPtr());
}


void
FFTPoissonSolverDirichlet::SolvePoissonEquation (amrex::MultiFab& lhs_mf)
FFTPoissonSolverDirichletDirect::SolvePoissonEquation (amrex::MultiFab& lhs_mf)
{
HIPACE_PROFILE("FFTPoissonSolverDirichlet::SolvePoissonEquation()");
HIPACE_PROFILE("FFTPoissonSolverDirichletDirect::SolvePoissonEquation()");

for ( amrex::MFIter mfi(m_stagingArea, DfltMfi); mfi.isValid(); ++mfi ){
// Perform Fourier transform from the staging area to `tmpSpectralField`
AnyDST::Execute(m_plan[mfi], AnyDST::direction::forward);
}
m_forward_fft.Execute();

#ifdef AMREX_USE_OMP
#pragma omp parallel if (amrex::Gpu::notInLaunchRegion())
Expand All @@ -120,10 +119,7 @@ FFTPoissonSolverDirichlet::SolvePoissonEquation (amrex::MultiFab& lhs_mf)
});
}

for ( amrex::MFIter mfi(m_stagingArea, DfltMfi); mfi.isValid(); ++mfi ){
// Perform Fourier transform from `tmpSpectralField` to the staging area
AnyDST::Execute(m_plan[mfi], AnyDST::direction::backward);
}
m_backward_fft.Execute();

#ifdef AMREX_USE_OMP
#pragma omp parallel if (amrex::Gpu::notInLaunchRegion())
Expand Down
Loading
Loading