Skip to content

Commit

Permalink
openpmd ParticleGroup/ParticleArray converters with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
st-walker committed Jun 7, 2021
1 parent 6b0c53e commit 55d485d
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 0 deletions.
60 changes: 60 additions & 0 deletions ocelot/adaptors/pmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Optional

import pmd_beamphysics as pmd
import numpy as np

from ocelot.cpbd.beam import ParticleArray


def particle_group_to_parray(pgroup: pmd.ParticleGroup) -> ParticleArray:
"""Construct an Ocelot ParticleArray from an openPMD-beamphysics ParticleGroup.
The particle type is assumed to be electrons.
:param pgroup: ParticleGroup from which to construct the ParticleArray
:return: ParticleArray corresponding to the provided ParticleGroup
:rtype: ParticleArray
"""
parray = ParticleArray(pgroup.n_particle)

# These are in eV, but so are px and py so this is OK.
reference_momentum = pgroup.avg("p")
reference_energy = pgroup.avg("energy")

rpart = parray.rparticles
rpart[0] = pgroup.x
rpart[1] = pgroup.px / reference_momentum
rpart[2] = pgroup.y
rpart[3] = pgroup.py / reference_momentum
rpart[4] = pgroup.z
rpart[5] = (pgroup.energy - reference_energy) / reference_momentum

parray.E = reference_energy * 1e-9 # Convert eV to GeV
parray.q_array = pgroup.weight

return parray


def load_pmd(filename: str) -> ParticleArray:
return particle_group_to_parray(pmd.ParticleGroup(h5=filename))


def particle_array_to_particle_group(parray: ParticleArray) -> pmd.ParticleGroup:
px = parray.px() * parray.p0 * 1e9 # to eV
py = parray.py() * parray.p0 * 1e9 # to eV
pz = parray.pz * parray.p0 * 1e9 # to eV

data = {
"x": parray.x(),
"px": px,
"y": parray.y(),
"py": py,
"z": parray.tau(),
"pz": pz,
"t": np.zeros_like(px),
"weight": parray.q_array,
"status": np.ones_like(px),
"species": "electron",
}

return pmd.ParticleGroup(data=data)
11 changes: 11 additions & 0 deletions ocelot/cpbd/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

from ocelot.adaptors.astra2ocelot import astraBeam2particleArray, particleArray2astraBeam
from ocelot.adaptors.csrtrack2ocelot import csrtrackBeam2particleArray, particleArray2csrtrackBeam

try:
from ocelot.adaptors.pmd import load_pmd, particle_array_to_particle_group
except ImportError:
pass

from ocelot.cpbd.beam import ParticleArray, Twiss, Beam


Expand Down Expand Up @@ -50,6 +56,8 @@ def load_particle_array(filename, print_params=False):
parray = astraBeam2particleArray(filename, print_params=False)
elif file_extension in [".fmt1"]:
parray = csrtrackBeam2particleArray(filename)
elif file_extension == ".h5":
parray = load_pmd(filename)
else:
raise Exception("Unknown format of the beam file: " + file_extension + " but must be *.ast, *fmt1 or *.npz ")

Expand Down Expand Up @@ -77,6 +85,9 @@ def save_particle_array(filename, p_array):
particleArray2astraBeam(p_array, filename)
elif file_extension == ".fmt1":
particleArray2csrtrackBeam(p_array, filename)
elif file_extension == ".h5":
particle_array_to_particle_group(p_array).write(filename)
else:
particle_array_to_particle_group(p_array).write(filename)
raise Exception("Unknown format of the beam file: " + file_extension + " but must be *.ast, *.fmt1 or *.npz")

77 changes: 77 additions & 0 deletions unit_tests/adaptors_test/test_pmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
import pmd_beamphysics as pmd
import pytest

from ocelot.adaptors import pmd as pmd_adaptor
from ocelot.common.globals import m_e_eV
from ocelot.cpbd.beam import ParticleArray


# Toy data for instantiating a ParticleGroup
PARTICLE_GROUP_DATA = {
"x": np.array([1e-6, 2e-6, 3e-6]),
"px": np.array([10, 20, 30]),
"y": np.array([1e-7, 2e-7, 3e-7]),
"py": np.array([15, 25, 35]),
"z": np.array([0.1, 0.2, 0.3]),
"pz": np.array([1e9, 1.01e9, 1.02e9]), # i.e. ~1 GeV/c
"t": np.array([1e-9, 2e-9, 3e-9]),
"weight": np.array([1e-15, 1e-15, 1e-15]),
"status": np.array([1, 1, 1]),
"species": "electron",
}


@pytest.fixture
def tmp_pmdh5(tmp_path):
"""tmp path to written pmd h5 file defined from above PARTICLE_GROUP_DATA"""
pg = pmd.ParticleGroup(data=PARTICLE_GROUP_DATA)
pmd_path = tmp_path / "pmd.h5"
pg.write(str(pmd_path))
yield pmd_path


@pytest.fixture
def pmd_parray(tmp_pmdh5):
"""Ocelot ParticleArray fixture from the above PARTICLE_GROUP_DATA."""
pgroup = pmd.ParticleGroup(h5=str(tmp_pmdh5))
yield pmd_adaptor.particle_group_to_parray(pgroup)


def compare_particle_group_with_array(pgroup, parray):
# Use the values from the pgroup for consistency as these are what
# particle_group_to_parray uses internally.
refmom = pgroup.avg("p")
refenergy = pgroup.avg("energy")

np.testing.assert_allclose(pgroup.x, parray.x())
np.testing.assert_allclose(pgroup.px / refmom, parray.px()),
np.testing.assert_allclose(pgroup.y, parray.y())
np.testing.assert_allclose(pgroup.py / refmom, parray.py())

np.testing.assert_allclose(pgroup.z, parray.tau())
dp = (pgroup.energy - refenergy) / refmom
np.testing.assert_allclose(dp, parray.p())

np.testing.assert_allclose(pgroup.weight, parray.q_array)


def test_particle_group_to_parray():
"""Convertiong of ParticleGroup to ParticleArray"""
# instantiate a ParticleGroup and make corresponding ParticleArray
pgroup = pmd.ParticleGroup(data=PARTICLE_GROUP_DATA)
parray = pmd_adaptor.particle_group_to_parray(pgroup)
compare_particle_group_with_array(pgroup, parray)


def test_load_pmd(tmp_pmdh5):
"""Loading of PMD files"""
parray = pmd_adaptor.load_pmd(str(tmp_pmdh5))
pgroup = pmd.ParticleGroup(data=PARTICLE_GROUP_DATA)
compare_particle_group_with_array(pgroup, parray)


def test_parray_to_particle_group(pmd_parray):
"""Conversion of ParticleArray to ParticleGroup"""
pgroup = pmd_adaptor.particle_array_to_particle_group(pmd_parray)
compare_particle_group_with_array(pgroup, pmd_parray)

0 comments on commit 55d485d

Please sign in to comment.