Skip to content

Commit

Permalink
Fix #6022 srw scale wavefront during sim (#6186)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeilman authored Sep 18, 2023
1 parent f9872e9 commit 8cc5637
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 47 deletions.
4 changes: 3 additions & 1 deletion sirepo/package_data/template/srw/parameters.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ try:
except:
pass

{% if in_server %}
import sirepo.template
{% endif %}
import srwl_bl
import srwlib
import srwlpy
Expand Down Expand Up @@ -431,7 +434,6 @@ def _rsopt_set_params({{ rsOptFuctionSignature() }}):

def epilogue():
{% if in_server %}
import sirepo.template
sirepo.template.run_epilogue('srw')
{% else %}
pass
Expand Down
186 changes: 140 additions & 46 deletions sirepo/template/srw.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@
from sirepo.template import template_common
import array
import copy
import glob
import math
import numpy as np
import os
import pickle
import pykern.pkjson
import re
import sirepo.job
import sirepo.mpi
import sirepo.sim_data
import sirepo.util
import srwl_bl
import srwlib
import srwlpy
import time
import traceback
import uti_io
Expand All @@ -36,6 +37,8 @@
PARSED_DATA_ATTR = "srwParsedData"

_CANVAS_MAX_SIZE = 65535
_LOG_DIR = "__srwl_logs__"
_MAX_REPORT_POINTS = 20000000
_MIN_CORES = 3

_OUTPUT_FOR_MODEL = PKDict(
Expand Down Expand Up @@ -141,9 +144,7 @@
)
_OUTPUT_FOR_MODEL.sourceIntensityReport.title = "E={sourcePhotonEnergy} eV"

_LOG_DIR = "__srwl_logs__"

_JSON_MESSAGE_EXPANSION = 20
_PREPROCESS_PREFIX = "preproc-"

_RSOPT_PARAMS = {
i
Expand Down Expand Up @@ -438,6 +439,7 @@ def extract_report_data(sim_in):
),
)
if out.dimensions == 3:
res.report = r
res = _remap_3d(res, allrange, out, dm[r])
return res

Expand Down Expand Up @@ -501,7 +503,7 @@ def sim_frame(frame_args):
_copy_frame_args_into_model(frame_args, r)
elif "beamlineAnimation" in r:
wid = int(re.search(r".*?(\d+)$", r)[1])
fn = _wavefront_pickle_filename(wid)
fn = _wavefront_pickle_filename(wid, is_processed=True)
with open(fn, "rb") as f:
wfr = pickle.load(f)
m = _copy_frame_args_into_model(frame_args, "watchpointReport")
Expand Down Expand Up @@ -669,6 +671,8 @@ def post_execution_processing(
success_exit,
**kwargs,
):
for f in glob.glob(str(run_dir.join(_PREPROCESS_PREFIX + "*"))):
os.remove(f)
if success_exit:
if _SIM_DATA.is_for_ml(compute_model):
f = _SIM_DATA.ML_OUTPUT
Expand Down Expand Up @@ -787,6 +791,59 @@ def process_undulator_definition(model):
return model


def process_watch(wid=0):
def _resize_wavefront(wfr):
mesh = wfr.mesh
nx, ny = _resize_mesh_dimensions(mesh.nx, mesh.ny)
pkdc("resized mesh: {}x{}", nx, ny)
# resize the electic fields in the wavefront mesh - note it modifies wfr
srwlpy.ResizeElecFieldMesh(
wfr,
srwlib.SRWLRadMesh(
_eStart=mesh.eStart,
_eFin=mesh.eFin,
_ne=mesh.ne,
_xStart=mesh.xStart,
_xFin=mesh.xFin,
_nx=nx,
_yStart=mesh.yStart,
_yFin=mesh.yFin,
_ny=ny,
_zStart=mesh.zStart,
_nvx=mesh.nvx,
_nvy=mesh.nvy,
_nvz=mesh.nvz,
_hvx=mesh.hvx,
_hvy=mesh.hvy,
_hvz=mesh.hvz,
_arSurf=mesh.arSurf,
),
[0, 1],
)
return wfr

def _op():
sim_in = simulation_db.read_json(template_common.INPUT_BASE_NAME)
report = sim_in.models[f"beamlineAnimation{wid}"]
p = _wavefront_pickle_filename(wid)
with open(p, "rb") as f:
wfr = pickle.load(f)
pkdc("original mesh: {}x{}", wfr.mesh.nx, wfr.mesh.ny)
if (
wfr.mesh.nx < _CANVAS_MAX_SIZE
and wfr.mesh.ny < _CANVAS_MAX_SIZE
and wfr.mesh.nx * wfr.mesh.ny <= _MAX_REPORT_POINTS
):
pkio.py_path(p).copy(
pkio.py_path(_wavefront_pickle_filename(wid, is_processed=True))
)
else:
with open(_wavefront_pickle_filename(wid, is_processed=True), "wb") as f:
pickle.dump(_resize_wavefront(wfr), f)

sirepo.mpi.restrict_op_to_first_rank(_op)


def python_source_for_model(data, model, qcall, plot_reports=True, **kwargs):
data.report = model or _SIM_DATA.SRW_RUN_ALL_MODEL
data.report = re.sub("beamlineAnimation0", "initialIntensityReport", data.report)
Expand Down Expand Up @@ -1143,7 +1200,6 @@ def _beamline_animation_percent_complete(run_dir, res):
res.outputInfo = [
PKDict(
modelKey="beamlineAnimation0",
filename=_wavefront_pickle_filename(0),
id=0,
),
]
Expand All @@ -1156,18 +1212,20 @@ def _beamline_animation_percent_complete(run_dir, res):
PKDict(
waitForData=True,
modelKey=f"beamlineAnimation{item.id}",
filename=_wavefront_pickle_filename(item.id),
id=item.id,
)
)
count = 0
for info in res.outputInfo:
try:
with open(info.filename, "rb") as f:
# TODO(pjm): instead look at last byte == pickle.STOP, see template_common.read_last_csv_line()
wfr = pickle.load(f)
count += 1
info.waitForData = False
with open(
pkio.py_path(_wavefront_pickle_filename(info.id, is_processed=True)),
"rb",
) as f:
f.seek(-1, os.SEEK_END)
if f.read(1) == pickle.STOP:
count += 1
info.waitForData = False
except Exception as e:
break
res.frameCount = count
Expand Down Expand Up @@ -1956,6 +2014,8 @@ def _generate_srw_main(data, plot_reports, beamline_info):
source_type = data.models.simulation.sourceType
run_all = report == _SIM_DATA.SRW_RUN_ALL_MODEL or is_for_rsopt
vp_var = "vp" if is_for_rsopt else "varParam"
prev_watch = 0
final_watch = None
content = [
f"v = srwl_bl.srwl_uti_parse_options(srwl_bl.srwl_uti_ext_options({vp_var}), use_sys_argv={plot_reports})",
]
Expand All @@ -1979,6 +2039,7 @@ def _generate_srw_main(data, plot_reports, beamline_info):
for n in beamline_info.names:
names.append(n)
if n in beamline_info.watches:
final_watch = n
is_last_watch = n == beamline_info.names[-1]
content.append("names = ['" + "','".join(names) + "']")
names = []
Expand All @@ -1989,6 +2050,11 @@ def _generate_srw_main(data, plot_reports, beamline_info):
content.append("op = set_optics(v, names, {})".format(is_last_watch))
if not is_last_watch:
content.append("srwl_bl.SRWLBeamline(_name=v.name).calc_all(v, op)")
content.append(
f"sirepo.template.import_module('srw').process_watch(wid={prev_watch})"
)
prev_watch = beamline_info.watches[n]

elif run_all or (
_SIM_DATA.srw_is_beamline_report(report) and len(data.models.beamline)
):
Expand Down Expand Up @@ -2057,6 +2123,13 @@ def _generate_srw_main(data, plot_reports, beamline_info):
if plot_reports:
content.append("v.tr_pl = 'xz'")
content.append("srwl_bl.SRWLBeamline(_name=v.name).calc_all(v, op)")
if report == "beamlineAnimation":
content.append(
f"sirepo.template.import_module('srw').process_watch(wid={prev_watch})"
)
content.append(
f"sirepo.template.import_module('srw').process_watch(wid={beamline_info.watches.get(final_watch, 0)})"
)
return "\n".join(
[f" {x}" for x in content] + [""] + ([] if is_for_rsopt else ["main()", ""])
)
Expand Down Expand Up @@ -2162,21 +2235,10 @@ def _process_rsopt_elements(els):


def _remap_3d(info, allrange, out, report):
x_range = [allrange[3], allrange[4], allrange[5]]
y_range = [allrange[6], allrange[7], allrange[8]]
ar2d = info.points
totLen = int(x_range[2] * y_range[2])
n = len(ar2d) if totLen > len(ar2d) else totLen
ar2d = np.reshape(ar2d[0:n], (int(y_range[2]), int(x_range[2])))

if report.get("usePlotRange", "0") == "1":
ar2d, x_range, y_range = _update_report_range(report, ar2d, x_range, y_range)
if report.get("useIntensityLimits", "0") == "1":
ar2d[ar2d < report.minIntensityLimit] = report.minIntensityLimit
ar2d[ar2d > report.maxIntensityLimit] = report.maxIntensityLimit
ar2d, x_range, y_range = _resize_report(report, ar2d, x_range, y_range)
if report.get("rotateAngle", 0):
ar2d, x_range, y_range = _rotate_report(report, ar2d, x_range, y_range, info)
ar2d, x_range, y_range = _reshape_3d(np.array(info.points), allrange, report)
rotate_angle = report.get("rotateAngle", 0)
if rotate_angle and info.title != "Power Density":
info.subtitle = info.subtitle + " Image Rotate {}^0".format(rotate_angle)
if out.units[2]:
out.labels[2] = "{} [{}]".format(out.labels[2], out.units[2])
if report.get("useIntensityLimits", "0") == "1":
Expand All @@ -2197,23 +2259,28 @@ def _remap_3d(info, allrange, out, report):
)


def _reshape_3d(ar1d, allrange, report):
x_range = [allrange[3], allrange[4], allrange[5]]
y_range = [allrange[6], allrange[7], allrange[8]]
totLen = int(x_range[2] * y_range[2])
n = len(ar1d) if totLen > len(ar1d) else totLen
ar2d = np.reshape(ar1d[0:n], (int(y_range[2]), int(x_range[2])))
if report.get("usePlotRange", "0") == "1":
ar2d, x_range, y_range = _update_report_range(report, ar2d, x_range, y_range)
if report.get("useIntensityLimits", "0") == "1":
ar2d[ar2d < report["minIntensityLimit"]] = report["minIntensityLimit"]
ar2d[ar2d > report["maxIntensityLimit"]] = report["maxIntensityLimit"]
ar2d, x_range, y_range = _resize_report(report, ar2d, x_range, y_range)
if report.get("rotateAngle", 0):
ar2d, x_range, y_range = _rotate_report(report, ar2d, x_range, y_range)
return ar2d, x_range, y_range


def _resize_report(report, ar2d, x_range, y_range):
width_pixels = int(report.get("intensityPlotsWidth", 0))
if not width_pixels:
# upper limit is browser's max html canvas size
width_pixels = _CANVAS_MAX_SIZE
# roughly 20x size increase for json
if ar2d.size * _JSON_MESSAGE_EXPANSION > sirepo.job.cfg().max_message_bytes:
max_width = int(
math.sqrt(sirepo.job.cfg().max_message_bytes / _JSON_MESSAGE_EXPANSION)
)
if max_width < width_pixels:
pkdc(
"auto scaling dimensions to fit message size. size: {}, max_width: {}",
ar2d.size,
max_width,
)
width_pixels = max_width
# rescale width and height to maximum of width_pixels
if width_pixels and (width_pixels < x_range[2] or width_pixels < y_range[2]):
from scipy import ndimage
Expand All @@ -2238,7 +2305,7 @@ def _resize_report(report, ar2d, x_range, y_range):
return ar2d, x_range, y_range


def _rotate_report(report, ar2d, x_range, y_range, info):
def _rotate_report(report, ar2d, x_range, y_range):
from scipy import ndimage

rotate_angle = report.rotateAngle
Expand Down Expand Up @@ -2272,8 +2339,6 @@ def _rotate_report(report, ar2d, x_range, y_range, info):

x_range[2] = ar2d.shape[1]
y_range[2] = ar2d.shape[0]
if info.title != "Power Density":
info.subtitle = info.subtitle + " Image Rotate {}^0".format(rotate_angle)
return ar2d, x_range, y_range


Expand Down Expand Up @@ -2427,6 +2492,34 @@ def _core_error(cores):
raise sirepo.util.UserAlert(f"cores={cores} when cores must be >= {_MIN_CORES}")


def _resize_mesh_dimensions(num_x, num_y):
def _max_size(v, v2):
return min(v, int(_MAX_REPORT_POINTS / v2))

nx = num_x
ny = num_y
_MIN_DIMENSION = int(_MAX_REPORT_POINTS / _CANVAS_MAX_SIZE)
if nx > _MIN_DIMENSION and ny > _MIN_DIMENSION and nx * ny > _MAX_REPORT_POINTS:
r = math.sqrt(_MAX_REPORT_POINTS / (nx * ny))
if r * nx <= _MIN_DIMENSION:
nx = _MIN_DIMENSION
ny = _max_size(ny, nx)
elif r * ny <= _MIN_DIMENSION:
ny = _MIN_DIMENSION
nx = _max_size(nx, ny)
elif r * nx >= _CANVAS_MAX_SIZE:
nx = _CANVAS_MAX_SIZE
ny = _max_size(ny, nx)
elif r * ny >= _CANVAS_MAX_SIZE:
ny = _CANVAS_MAX_SIZE
nx = _max_size(nx, ny)
else:
nx = int(r * nx)
ny = int(r * ny)

return min(nx, _CANVAS_MAX_SIZE), min(ny, _CANVAS_MAX_SIZE)


def _superscript(val):
return re.sub(r"\^2", "\u00B2", val)

Expand Down Expand Up @@ -2642,10 +2735,11 @@ def file_attrs_ok(attrs):
)


def _wavefront_pickle_filename(el_id):
if el_id:
return f"wid-{el_id}.pkl"
return "initial.pkl"
def _wavefront_pickle_filename(el_id, is_processed=False):
f = f"wid-{el_id}" if el_id else "initial"
if not is_processed:
f = _PREPROCESS_PREFIX + f
return f"{f}.pkl"


def _write_rsopt_files(data, run_dir, ctx):
Expand Down
34 changes: 34 additions & 0 deletions tests/template/srw_resize_3d_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
"""PyTest for :mod:`sirepo.template.srw`
:copyright: Copyright (c) 2023 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""


def test_srw_resize_3d():
from pykern import pkunit
from sirepo.template import srw

MAX = srw._CANVAS_MAX_SIZE

for t in (
[(100, 100), (100, 100)],
[(100, MAX + 1), (100, MAX)],
[(MAX + 1, 100), (MAX, 100)],
[(400, 400), (400, 400)],
[(MAX, MAX), (4472, 4472)],
[(MAX, MAX * 2), (3162, 6324)],
[(MAX * 2, MAX), (6324, 3162)],
[(MAX + 1, 300), (MAX, 300)],
[(300, MAX + 1), (300, MAX)],
[(400, MAX), (349, 57242)],
[(MAX, 400), (57242, 349)],
[(400, 1e6), (305, MAX)],
[(1e6, 400), (MAX, 305)],
[(MAX * 10, MAX), (14142, 1414)],
[(MAX, MAX * 10), (1414, 14142)],
):
x, y = srw._resize_mesh_dimensions(t[0][0], t[0][1])
assert x * y <= srw._MAX_REPORT_POINTS
pkunit.pkeq((x, y), t[1])

0 comments on commit 8cc5637

Please sign in to comment.