diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index f74514e4..30bde707 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -32,7 +32,7 @@ jobs: export DOCKER_BINARY="docker" echo "DOCKER_BINARY=${DOCKER_BINARY}" >> $GITHUB_ENV - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -87,8 +87,12 @@ jobs: set -vxeuo pipefail make -C docs/ html status=$? + + echo "::group::docker logs" echo "Sirepo ${DOCKER_BINARY} container id: ${SIREPO_DOCKER_CONTAINER_ID}" ${DOCKER_BINARY} logs ${SIREPO_DOCKER_CONTAINER_ID} + echo "::endgroup::" + if [ $status -gt 0 ]; then exit $status fi diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index dde2e821..4f31a48f 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -19,7 +19,7 @@ jobs: shell: bash -l {0} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index 007955e9..cd436569 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -17,7 +17,7 @@ jobs: id-token: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index bd0dab11..f1457a63 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -12,12 +12,17 @@ jobs: # pull requests are a duplicate of a branch push if within the same repo. if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.repository - name: Test sirepo-bluesky with ${{ matrix.docker-binary }} and Python ${{ matrix.python-version }} + name: Test sirepo-bluesky with py${{ matrix.python-version }} and databroker ${{ matrix.databroker }} (${{ matrix.docker-binary }}) runs-on: ubuntu-latest strategy: matrix: python-version: ["3.8", "3.9", "3.10"] - docker-binary: ["docker", "podman"] + databroker: ["v1", "v2"] + docker-binary: ["docker"] + include: + - python-version: "3.10" + docker-binary: "podman" + databroker: "v1" fail-fast: false defaults: @@ -33,8 +38,16 @@ jobs: export DOCKER_BINARY=${{ matrix.docker-binary }} echo "DOCKER_BINARY=${DOCKER_BINARY}" >> $GITHUB_ENV + if [ "${{ matrix.databroker }}" == "v1" ]; then + EXTRA_INSTALL="databroker==1.2.5" + elif [ "${{ matrix.databroker }}" == "v2" ]; then + EXTRA_INSTALL="databroker[all]>=2.0.0b tiled[all]" + fi + export EXTRA_INSTALL + echo "EXTRA_INSTALL=${EXTRA_INSTALL}" >> $GITHUB_ENV + - name: Checkout the code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Start MongoDB uses: supercharge/mongodb-github-action@1.6.0 @@ -63,15 +76,20 @@ jobs: run: | set -vxeo pipefail python -m pip install --upgrade pip wheel - python -m pip install -v . + python -m pip install -v . ${EXTRA_INSTALL} python -m pip install -r requirements-dev.txt python -m pip list - name: Copy databroker config file run: | set -vxeuo pipefail - mkdir -v -p ~/.config/databroker/ - cp -v examples/local.yml ~/.config/databroker/ + if [ "${{ matrix.databroker }}" == "v1" ]; then + mkdir -v -p ~/.config/databroker/ + cp -v examples/local.yml ~/.config/databroker/ + elif [ "${{ matrix.databroker }}" == "v2" ]; then + mkdir -v -p ~/.config/tiled/profiles/ + cp -v examples/local-tiled.yml ~/.config/tiled/profiles/local.yml + fi - name: Test with pytest run: | @@ -79,8 +97,12 @@ jobs: pytest -s -vv status=$? ${DOCKER_BINARY} ps -a + + echo "::group::docker logs" echo "Sirepo ${DOCKER_BINARY} container id: ${SIREPO_DOCKER_CONTAINER_ID}" ${DOCKER_BINARY} logs ${SIREPO_DOCKER_CONTAINER_ID} + echo "::endgroup::" + if [ $status -gt 0 ]; then exit $status fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dabf1796..db627e29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/ambv/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black language_version: python3.10 diff --git a/docs/source/notebooks/madx.ipynb b/docs/source/notebooks/madx.ipynb index 24ca93d5..59a79639 100644 --- a/docs/source/notebooks/madx.ipynb +++ b/docs/source/notebooks/madx.ipynb @@ -30,11 +30,11 @@ "\n", "import matplotlib.pyplot as plt\n", "\n", - "from sirepo_bluesky.sirepo_bluesky import SirepoBluesky\n", - "from sirepo_bluesky.madx_flyer import MADXFlyer\n", - "from sirepo_bluesky.sirepo_ophyd import create_classes\n", + "from sirepo_bluesky.common.sirepo_client import SirepoClient\n", + "from sirepo_bluesky.madx.madx_flyer import MADXFlyer\n", + "from sirepo_bluesky.common.create_classes import create_classes\n", "\n", - "connection = SirepoBluesky(\"http://localhost:8000\")\n", + "connection = SirepoClient(\"http://localhost:8000\")\n", "\n", "data, schema = connection.auth(\"madx\", \"00000002\")\n", "classes, objects = create_classes(\n", diff --git a/docs/source/notebooks/shadow.ipynb b/docs/source/notebooks/shadow.ipynb index ec9ef62d..e5fc1b84 100644 --- a/docs/source/notebooks/shadow.ipynb +++ b/docs/source/notebooks/shadow.ipynb @@ -40,10 +40,10 @@ "\n", "%run -i $prepare_re_env.__file__\n", "\n", - "from sirepo_bluesky.sirepo_bluesky import SirepoBluesky\n", - "from sirepo_bluesky.sirepo_ophyd import BeamStatisticsReport, create_classes\n", + "from sirepo_bluesky.common.sirepo_client import SirepoClient\n", + "from sirepo_bluesky.common.create_classes import create_classes\n", "\n", - "connection = SirepoBluesky(\"http://localhost:8000\")\n", + "connection = SirepoClient(\"http://localhost:8000\")\n", "\n", "data, schema = connection.auth(\"shadow\", sim_id=\"00000002\")\n", "classes, objects = create_classes(connection=connection)\n", @@ -110,10 +110,11 @@ "\n", "%run -i $prepare_re_env.__file__\n", "\n", - "from sirepo_bluesky.sirepo_bluesky import SirepoBluesky\n", - "from sirepo_bluesky.sirepo_ophyd import BeamStatisticsReport, create_classes\n", + "from sirepo_bluesky.common.sirepo_client import SirepoClient\n", + "from sirepo_bluesky.shadow.shadow_ophyd import BeamStatisticsReport\n", + "from sirepo_bluesky.common.create_classes import create_classes\n", "\n", - "connection = SirepoBluesky(\"http://localhost:8000\")\n", + "connection = SirepoClient(\"http://localhost:8000\")\n", "\n", "data, schema = connection.auth(\"shadow\", sim_id=\"00000002\")\n", "\n", diff --git a/docs/source/notebooks/srw.ipynb b/docs/source/notebooks/srw.ipynb index 237e2257..20d325d6 100644 --- a/docs/source/notebooks/srw.ipynb +++ b/docs/source/notebooks/srw.ipynb @@ -40,10 +40,10 @@ "\n", "%run -i $prepare_re_env.__file__\n", "\n", - "from sirepo_bluesky.sirepo_bluesky import SirepoBluesky\n", - "from sirepo_bluesky.sirepo_ophyd import create_classes\n", + "from sirepo_bluesky.common.sirepo_client import SirepoClient\n", + "from sirepo_bluesky.common.create_classes import create_classes\n", "\n", - "connection = SirepoBluesky(\"http://localhost:8000\")\n", + "connection = SirepoClient(\"http://localhost:8000\")\n", "\n", "data, schema = connection.auth(\"srw\", sim_id=\"00000002\")\n", "classes, objects = create_classes(connection=connection)\n", @@ -109,10 +109,10 @@ "\n", "%run -i $prepare_re_env.__file__\n", "\n", - "from sirepo_bluesky.sirepo_bluesky import SirepoBluesky\n", - "from sirepo_bluesky.sirepo_ophyd import create_classes\n", + "from sirepo_bluesky.common.sirepo_client import SirepoClient\n", + "from sirepo_bluesky.common.create_classes import create_classes\n", "\n", - "connection = SirepoBluesky(\"http://localhost:8000\")\n", + "connection = SirepoClient(\"http://localhost:8000\")\n", "\n", "data, schema = connection.auth(\"srw\", sim_id=\"00000002\")\n", "classes, objects = create_classes(connection=connection)\n", @@ -162,9 +162,12 @@ " aspect=False,\n", ")\n", "\n", + "h_dims = 1e6 * np.array(\n", + " w9.horizontal_extent_start.get(), w9.horizontal_extent_end.get()\n", + ")\n", + "v_dims = 1e6 * np.array(w9.vertical_extent_start.get(), w9.vertical_extent_start.get())\n", + "\n", "for ax, im in zip(grid, w9_image):\n", - " h_dims = 1e6 * w9.horizontal_extent.get()\n", - " v_dims = 1e6 * w9.vertical_extent.get()\n", " ax.imshow(\n", " im, interpolation=\"nearest\", aspect=\"auto\", extent=(*h_dims[:], *v_dims[:])\n", " )" @@ -192,10 +195,10 @@ "\n", "%run -i $prepare_re_env.__file__\n", "\n", - "from sirepo_bluesky.sirepo_bluesky import SirepoBluesky\n", - "from sirepo_bluesky.sirepo_ophyd import create_classes\n", + "from sirepo_bluesky.common.sirepo_client import SirepoClient\n", + "from sirepo_bluesky.common.create_classes import create_classes\n", "\n", - "connection = SirepoBluesky(\"http://localhost:8000\")\n", + "connection = SirepoClient(\"http://localhost:8000\")\n", "\n", "data, schema = connection.auth(\"srw\", sim_id=\"00000003\")\n", "classes, objects = create_classes(\n", diff --git a/examples/local-tiled.yml b/examples/local-tiled.yml new file mode 100644 index 00000000..668560a6 --- /dev/null +++ b/examples/local-tiled.yml @@ -0,0 +1,16 @@ +local: + direct: + authentication: + allow_anonymous_access: true + trees: + - tree: databroker.mongo_normalized:Tree.from_uri + path: / + args: + uri: mongodb://localhost:27017/datastore + asset_registry_uri: mongodb://localhost:27017/filestore + handler_registry: + srw: sirepo_bluesky.srw.srw_handler:SRWFileHandler + SIREPO_FLYER: sirepo_bluesky.srw.srw_handler:SRWFileHandler + SRW_HDF5: sirepo_bluesky.srw.srw_handler:SRWHDF5FileHandler + shadow: sirepo_bluesky.shadow.shadow_handler:ShadowFileHandler + madx: sirepo_bluesky.madx.madx_handler:MADXFileHandler diff --git a/pytest.ini b/pytest.ini index 90da2022..53e28685 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,6 @@ [pytest] markers = docker: marks tests as requiring Docker to be available/running (deselect with '-m "not docker"') + madx: sirepo/madx simulation code tests + shadow: sirepo/shadow simulation code tests + srw: sirepo/srw simulation code tests diff --git a/requirements.txt b/requirements.txt index 00e00240..d40e83d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ area-detector-handlers bluesky databroker +h5py inflection matplotlib numconv @@ -10,6 +11,7 @@ peakutils pymongo pyqt5>=5.9 requests +scikit-image shadow3>=23.1.4 srwpy>=4.0.0b1 tfs-pandas diff --git a/scripts/start_sirepo.sh b/scripts/start_sirepo.sh index b4a432aa..a81b7a45 100644 --- a/scripts/start_sirepo.sh +++ b/scripts/start_sirepo.sh @@ -82,6 +82,8 @@ cmd_start="${docker_binary} run ${arg} --init ${remove_container} --name sirepo -p 8000:8000 \ -v $SIREPO_SRDB_HOST_RO:/SIREPO_SRDB_ROOT:ro,z " +# TODO: parametrize host port number. + cmd_extra="" if [ ! -z "${SIREPO_SRDB_HOST}" -a ! -z "${SIREPO_SRDB_GUEST}" ]; then cmd_extra="-v ${SIREPO_SRDB_HOST}:${SIREPO_SRDB_GUEST}:rw,z " diff --git a/sirepo_bluesky/__init__.py b/sirepo_bluesky/__init__.py index 8bac5028..d2f3e054 100644 --- a/sirepo_bluesky/__init__.py +++ b/sirepo_bluesky/__init__.py @@ -1,26 +1,5 @@ -from ophyd import Signal - from ._version import get_versions from .utils import prepare_re_env # noqa: F401 __version__ = get_versions()["version"] del get_versions - - -class ExternalFileReference(Signal): - """ - A pure software Signal that describe()s an image in an external file. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def describe(self): - resource_document_data = super().describe() - resource_document_data[self.name].update( - dict( - external="FILESTORE:", - dtype="array", - ) - ) - return resource_document_data diff --git a/sirepo_bluesky/common/__init__.py b/sirepo_bluesky/common/__init__.py new file mode 100644 index 00000000..66d0bb29 --- /dev/null +++ b/sirepo_bluesky/common/__init__.py @@ -0,0 +1,88 @@ +import logging +from collections import deque + +from ophyd import Signal +from ophyd.sim import NullStatus + +logger = logging.getLogger("sirepo-bluesky") +# Note: the following handler could be created/added to the logger on the client side: +# import sys +# stream_handler = logging.StreamHandler(sys.stdout) +# logger.addHandler(stream_handler) + +RESERVED_OPHYD_TO_SIREPO_ATTRS = { # ophyd <-> sirepo + "position": "element_position", + "name": "element_name", + "class": "command_class", +} +RESERVED_SIREPO_TO_OPHYD_ATTRS = {v: k for k, v in RESERVED_OPHYD_TO_SIREPO_ATTRS.items()} + + +class ExternalFileReference(Signal): + """ + A pure software Signal that describe()s an image in an external file. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def describe(self): + resource_document_data = super().describe() + resource_document_data[self.name].update( + dict( + external="FILESTORE:", + dtype="array", + ) + ) + return resource_document_data + + +class SirepoSignal(Signal): + def __init__(self, sirepo_dict, sirepo_param, *args, **kwargs): + super().__init__(*args, **kwargs) + self._sirepo_dict = sirepo_dict + self._sirepo_param = sirepo_param + if sirepo_param in RESERVED_SIREPO_TO_OPHYD_ATTRS: + self._sirepo_param = RESERVED_SIREPO_TO_OPHYD_ATTRS[sirepo_param] + + def set(self, value, *, timeout=None, settle_time=None): + logger.debug(f"Setting value for {self.name} to {value}") + self._sirepo_dict[self._sirepo_param] = value + self._readback = value + return NullStatus() + + def put(self, *args, **kwargs): + self.set(*args, **kwargs).wait() + + +class ReadOnlyException(Exception): + ... + + +class SirepoSignalRO(SirepoSignal): + def set(self, *args, **kwargs): + raise ReadOnlyException("Cannot set/put the read-only signal.") + + +class BlueskyFlyer: + def __init__(self): + self.name = "bluesky_flyer" + self._asset_docs_cache = deque() + self._resource_uids = [] + self._datum_counter = None + self._datum_ids = [] + + def kickoff(self): + return NullStatus() + + def complete(self): + return NullStatus() + + def collect(self): + ... + + def collect_asset_docs(self): + items = list(self._asset_docs_cache) + self._asset_docs_cache.clear() + for item in items: + yield item diff --git a/sirepo_bluesky/common/base_classes.py b/sirepo_bluesky/common/base_classes.py new file mode 100644 index 00000000..f64f54ed --- /dev/null +++ b/sirepo_bluesky/common/base_classes.py @@ -0,0 +1,74 @@ +import hashlib +import json +from collections import deque + +from ophyd import Component as Cpt +from ophyd import Device, Signal +from ophyd.sim import NullStatus + +from sirepo_bluesky.common import ExternalFileReference + + +class DeviceWithJSONData(Device): + sirepo_data_json = Cpt(Signal, kind="normal", value="") + sirepo_data_hash = Cpt(Signal, kind="normal", value="") + duration = Cpt(Signal, kind="normal", value=-1.0) + + def trigger(self, *args, **kwargs): + super().trigger(*args, **kwargs) + + json_str = json.dumps(self.connection.data) + json_hash = hashlib.sha256(json_str.encode()).hexdigest() + self.sirepo_data_json.put(json_str) + self.sirepo_data_hash.put(json_hash) + + return NullStatus() + + +class SirepoWatchpointBase(DeviceWithJSONData): + image = Cpt(ExternalFileReference, kind="normal") + flux = Cpt(Signal, kind="hinted") + mean = Cpt(Signal, kind="normal") + x = Cpt(Signal, kind="normal") + y = Cpt(Signal, kind="normal") + fwhm_x = Cpt(Signal, kind="normal") + fwhm_y = Cpt(Signal, kind="normal") + photon_energy = Cpt(Signal, kind="normal") + horizontal_extent_start = Cpt(Signal) + horizontal_extent_end = Cpt(Signal) + vertical_extent_start = Cpt(Signal) + vertical_extent_end = Cpt(Signal) + + def __init__( + self, + *args, + root_dir="/tmp/sirepo-bluesky-data", + assets_dir=None, + result_file=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self._root_dir = root_dir + self._assets_dir = assets_dir + self._result_file = result_file + + self._asset_docs_cache = deque() + self._resource_document = None + self._datum_factory = None + + self._sim_type = self.connection.data["simulationType"] + if self._sim_type not in self._allowed_sim_types: + raise RuntimeError( + f"Unknown simulation type: {self._sim_type}\nAllowed simulation types: {self._allowed_sim_types}" + ) + + self._report = None + if hasattr(self, "id"): + self._report = f"watchpointReport{self.id._sirepo_dict['id']}" + + def collect_asset_docs(self): + items = list(self._asset_docs_cache) + self._asset_docs_cache.clear() + for item in items: + yield item diff --git a/sirepo_bluesky/common/create_classes.py b/sirepo_bluesky/common/create_classes.py new file mode 100644 index 00000000..30215d92 --- /dev/null +++ b/sirepo_bluesky/common/create_classes.py @@ -0,0 +1,192 @@ +import copy +from collections import namedtuple + +import inflection +from ophyd import Component as Cpt +from ophyd import Device + +from sirepo_bluesky.common import RESERVED_OPHYD_TO_SIREPO_ATTRS, SirepoSignal, logger +from sirepo_bluesky.shadow.shadow_ophyd import SirepoWatchpointShadow +from sirepo_bluesky.srw.srw_ophyd import ( + PropagationConfig, + SimplePropagationConfig, + SingleElectronSpectrumReport, + SirepoSignalCRL, + SirepoSignalCrystal, + SirepoSignalGrazingAngle, + SirepoWatchpointSRW, +) + + +def create_classes(connection, create_objects=True, extra_model_fields=[]): + classes = {} + objects = {} + data = copy.deepcopy(connection.data) + + sim_type = connection.sim_type + + if sim_type == "srw": + SirepoWatchpoint = SirepoWatchpointSRW + elif sim_type == "shadow": + SirepoWatchpoint = SirepoWatchpointShadow + + SimTypeConfig = namedtuple("SimTypeConfig", "element_location class_name_field") + + srw_config = SimTypeConfig("beamline", "title") + shadow_config = SimTypeConfig("beamline", "title") + madx_config = SimTypeConfig("elements", "element_name") + + config_dict = { + "srw": srw_config, + "shadow": shadow_config, + "madx": madx_config, + } + + model_fields = [config_dict[sim_type].element_location] + extra_model_fields + + data_models = {} + for model_field in model_fields: + if sim_type == "srw" and model_field in ["undulator", "intensityReport"]: + if model_field == "intensityReport": + title = "SingleElectronSpectrum" + else: + title = model_field + data["models"][model_field].update({"title": title, "type": model_field}) + data_models[model_field] = [data["models"][model_field]] + else: + data_models[model_field] = data["models"][model_field] + + for model_field, data_model in data_models.items(): + for i, el in enumerate(data_model): # 'el' is a dict, 'data_model' is a list of dicts + logger.debug(f"Processing {el}...") + + for ophyd_key, sirepo_key in RESERVED_OPHYD_TO_SIREPO_ATTRS.items(): + # We have to rename the reserved attribute names. Example error + # from ophyd: + # + # TypeError: The attribute name(s) {'position'} are part of the + # bluesky interface and cannot be used as component names. Choose + # a different name. + if ophyd_key in el: + el[sirepo_key] = el[ophyd_key] + el.pop(ophyd_key) + else: + pass + + class_name = el[config_dict[sim_type].class_name_field] + if model_field == "commands": + # Use command type and index in the model as class name to + # prevent overwriting any other elements or rpnVariables + # Examples of class names: beam0, select1, twiss7 + class_name = inflection.camelize(f"{el['_type']}{i}") + else: + class_name = inflection.camelize( + el[config_dict[sim_type].class_name_field].replace(" ", "_").replace(".", "").replace("-", "_") + ) + object_name = inflection.underscore(class_name) + + base_classes = (Device,) + extra_kwargs = {"connection": connection} + if "type" in el and el["type"] == "watch": + base_classes = (SirepoWatchpoint, Device) + elif "type" in el and el["type"] == "intensityReport": + base_classes = (SingleElectronSpectrumReport, Device) + + components = {} + for k, v in el.items(): + if ( + "type" in el + and el["type"] in ["sphericalMirror", "toroidalMirror", "ellipsoidMirror"] + and k == "grazingAngle" + ): + cpt_class = SirepoSignalGrazingAngle + elif "type" in el and el["type"] == "crl" and k not in ["absoluteFocusPosition", "focalDistance"]: + cpt_class = SirepoSignalCRL + elif ( + "type" in el + and el["type"] == "crystal" + and k + not in [ + "dSpacing", + "grazingAngle", + "nvx", + "nvy", + "nvz", + "outframevx", + "outframevy", + "outoptvx", + "outoptvy", + "outoptvz", + "psi0i", + "psi0r", + "psiHBi", + "psiHBr", + "psiHi", + "psiHr", + "tvx", + "tvy", + ] + ): + cpt_class = SirepoSignalCrystal + else: + cpt_class = SirepoSignal + + if "type" in el and el["type"] not in ["undulator", "intensityReport"]: + sirepo_dict = connection.data["models"][model_field][i] + elif sim_type == "madx" and model_field in ["rpnVariables", "commands"]: + sirepo_dict = connection.data["models"][model_field][i] + else: + sirepo_dict = connection.data["models"][model_field] + + components[k] = Cpt( + cpt_class, + value=(float(v) if type(v) is int else v), + sirepo_dict=sirepo_dict, + sirepo_param=k, + ) + components.update(**extra_kwargs) + + cls = type( + class_name, + base_classes, + components, + ) + + classes[object_name] = cls + if create_objects: + objects[object_name] = cls(name=object_name) + + if sim_type == "srw" and model_field == "beamline": + prop_params = connection.data["models"]["propagation"][str(el["id"])][0] + sirepo_propagation = [] + object_name += "_propagation" + for i in range(9): + sirepo_propagation.append( + SirepoSignal( + name=f"{object_name}_{SimplePropagationConfig._fields[i]}", + value=float(prop_params[i]), + sirepo_dict=prop_params, + sirepo_param=i, + ) + ) + if create_objects: + objects[object_name] = PropagationConfig(*sirepo_propagation[:]) + + if sim_type == "srw": + post_prop_params = connection.data["models"]["postPropagation"] + sirepo_propagation = [] + object_name = "post_propagation" + for i in range(9): + sirepo_propagation.append( + SirepoSignal( + name=f"{object_name}_{SimplePropagationConfig._fields[i]}", + value=float(post_prop_params[i]), + sirepo_dict=post_prop_params, + sirepo_param=i, + ) + ) + classes["propagation_parameters"] = PropagationConfig + if create_objects: + objects[object_name] = PropagationConfig(*sirepo_propagation[:]) + + return classes, objects diff --git a/sirepo_bluesky/common/misc.py b/sirepo_bluesky/common/misc.py new file mode 100644 index 00000000..f9ca7867 --- /dev/null +++ b/sirepo_bluesky/common/misc.py @@ -0,0 +1,53 @@ +from sirepo_bluesky.common.create_classes import create_classes +from sirepo_bluesky.common.sirepo_client import SirepoClient + + +def populate_beamline(sim_name, *args): + """ + Parameters + ---------- + *args : + For one beamline, ``connection, indices, new_positions``. + In general: + + .. code-block:: python + + connection1, indices1, new_positions1 + connection2, indices2, new_positions2 + ..., + connectionN, indicesN, new_positionsN + """ + if len(args) % 3 != 0: + raise ValueError( + "Incorrect signature, arguments must be of the signature: connection, indices, new_positions, ..." + ) + + connections = [] + indices_list = [] + new_positions_list = [] + + for i in range(0, len(args), 3): + connections.append(args[i]) + indices_list.append(args[i + 1]) + new_positions_list.append(args[i + 2]) + + emptysim = SirepoClient("http://localhost:8000") + emptysim.auth("srw", sim_id="emptysim") + new_beam = emptysim.copy_sim(sim_name=sim_name) + new_beamline = new_beam.data["models"]["beamline"] + new_propagation = new_beam.data["models"]["propagation"] + + curr_id = 0 + for connection, indices, new_positions in zip(connections, indices_list, new_positions_list): + old_beamline = connection.data["models"]["beamline"] + old_propagation = connection.data["models"]["propagation"] + for i, pos in zip(indices, new_positions): + new_beamline.append(old_beamline[i].copy()) + new_beamline[curr_id]["id"] = curr_id + new_beamline[curr_id]["position"] = pos + new_propagation[str(curr_id)] = old_propagation[str(old_beamline[i]["id"])].copy() + curr_id += 1 + + classes, objects = create_classes(new_beam) + + return new_beam, classes, objects diff --git a/sirepo_bluesky/sirepo_bluesky.py b/sirepo_bluesky/common/sirepo_client.py similarity index 94% rename from sirepo_bluesky/sirepo_bluesky.py rename to sirepo_bluesky/common/sirepo_client.py index 17e33f91..61e55913 100644 --- a/sirepo_bluesky/sirepo_bluesky.py +++ b/sirepo_bluesky/common/sirepo_client.py @@ -8,11 +8,11 @@ import requests -class SirepoBlueskyClientException(Exception): +class SirepoClientException(Exception): pass -class SirepoBluesky(object): +class SirepoClient(object): """ Invoke a remote sirepo simulation with custom arguments. @@ -26,7 +26,7 @@ class SirepoBluesky(object): # sim_id is the last section from the simulation url # e.g., '.../1tNWph0M' sim_id = '1tNWph0M' - sb = SirepoBluesky('http://localhost:8000') + sb = SirepoClient('http://localhost:8000') data, schema = sb.auth('srw', sim_id) # update the model values and choose the report data['models']['undulator']['verticalAmplitude'] = 0.95 @@ -75,7 +75,7 @@ def auth(self, sim_type, sim_id): self.cookies = None res = self._post_json("auth-bluesky-login", req) if not ("state" in res and res["state"] == "ok"): - raise SirepoBlueskyClientException(f"bluesky_auth failed: {res}") + raise SirepoClientException(f"bluesky_auth failed: {res}") self.sim_type = sim_type self.sim_id = sim_id self.schema = res["schema"] @@ -83,7 +83,7 @@ def auth(self, sim_type, sim_id): return self.data, self.schema def copy_sim(self, sim_name): - """Create a copy of the current simulation. Returns a new instance of SirepoBluesky.""" + """Create a copy of the current simulation. Returns a new instance of SirepoClient.""" if not self.sim_id: raise ValueError(f"sim_id is {self.sim_id!r}") res = self._post_json( @@ -95,7 +95,7 @@ def copy_sim(self, sim_name): "name": sim_name, }, ) - copy = SirepoBluesky(self.server, self.secret) + copy = SirepoClient(self.server, self.secret) copy.cookies = self.cookies copy.sim_type = self.sim_type copy.sim_id = res["models"]["simulation"]["simulationId"] @@ -116,7 +116,7 @@ def delete_copy(self): }, ) if not res["state"] == "ok": - raise SirepoBlueskyClientException(f"Could not delete simulation: {res}") + raise SirepoClientException(f"Could not delete simulation: {res}") self.sim_id = None def compute_crl_characteristics(self, crl_element): @@ -148,7 +148,7 @@ def compute_crystal_init(self, crystal_element): def compute_crystal_orientation(self, crystal_element): res_init = self.compute_crystal_init(crystal_element) if res_init.pop("state") != "completed": - raise SirepoBlueskyClientException("crystal_init returned error state") + raise SirepoClientException("crystal_init returned error state") res = self._post_json( "stateless-compute", { @@ -287,13 +287,13 @@ def run_simulation(self, max_status_calls=1000): time.sleep(res["nextRequestSeconds"]) res = self._post_json("run-status", res["nextRequest"]) if not state == "completed": - raise SirepoBlueskyClientException(f"simulation failed to complete: {state}") + raise SirepoClientException(f"simulation failed to complete: {state}") return res, time.monotonic() - start_time @staticmethod def _assert_success(response, url): if not response.status_code == requests.codes.ok: - raise SirepoBlueskyClientException(f"{url} request failed, status: {response.status_code}") + raise SirepoClientException(f"{url} request failed, status: {response.status_code}") def _post_json(self, url, payload): response = requests.post(f"{self.server}/{url}", json=payload, cookies=self.cookies) diff --git a/sirepo_bluesky/madx/__init__.py b/sirepo_bluesky/madx/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sirepo_bluesky/madx_flyer.py b/sirepo_bluesky/madx/madx_flyer.py similarity index 99% rename from sirepo_bluesky/madx_flyer.py rename to sirepo_bluesky/madx/madx_flyer.py index 08560efd..63176fd6 100644 --- a/sirepo_bluesky/madx_flyer.py +++ b/sirepo_bluesky/madx/madx_flyer.py @@ -7,8 +7,8 @@ from event_model import compose_resource from ophyd.sim import NullStatus, new_uid +from ..srw.srw_flyer import BlueskyFlyer from .madx_handler import read_madx_file -from .sirepo_flyer import BlueskyFlyer logger = logging.getLogger("sirepo-bluesky") diff --git a/sirepo_bluesky/madx_handler.py b/sirepo_bluesky/madx/madx_handler.py similarity index 100% rename from sirepo_bluesky/madx_handler.py rename to sirepo_bluesky/madx/madx_handler.py diff --git a/sirepo_bluesky/shadow/__init__.py b/sirepo_bluesky/shadow/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sirepo_bluesky/shadow_handler.py b/sirepo_bluesky/shadow/shadow_handler.py similarity index 86% rename from sirepo_bluesky/shadow_handler.py rename to sirepo_bluesky/shadow/shadow_handler.py index b169a177..4f3430d3 100644 --- a/sirepo_bluesky/shadow_handler.py +++ b/sirepo_bluesky/shadow/shadow_handler.py @@ -5,7 +5,7 @@ import Shadow.ShadowLibExtensions as sd import Shadow.ShadowTools -from . import utils +from sirepo_bluesky import utils def read_shadow_file_col(filename, parameter=30): @@ -65,8 +65,10 @@ def read_shadow_file_col(filename, parameter=30): "shape": data.shape, "mean": mean_value, "photon_energy": mean_value, - "horizontal_extent": [0, 1], - "vertical_extent": [0, 1], + "horizontal_extent_start": 0, + "horizontal_extent_end": 1, + "vertical_extent_start": 0, + "vertical_extent_end": 1, # 'labels': labels, # 'units': units, } @@ -87,12 +89,12 @@ def read_shadow_file(filename, histogram_bins=None): # This returns a list of N values (N=number of rays) photon_energy_list = Shadow.ShadowTools.getshcol(filename, col=11) # 11=Energy [eV] - data = data_dict["histogram"] + data = data_dict["histogram"].astype(float) photon_energy = np.mean(photon_energy_list) # convert to um - horizontal_extent = 1e3 * np.array(data_dict["xrange"][:2]) - vertical_extent = 1e3 * np.array(data_dict["yrange"][:2]) + horizontal_extent = 1e3 * np.array(data_dict["xrange"][:2]).astype(float) + vertical_extent = 1e3 * np.array(data_dict["yrange"][:2]).astype(float) ret = { "data": data, @@ -100,8 +102,10 @@ def read_shadow_file(filename, histogram_bins=None): "flux": data.sum(), "mean": data.mean(), "photon_energy": photon_energy, - "horizontal_extent": horizontal_extent, - "vertical_extent": vertical_extent, + "horizontal_extent_start": horizontal_extent[0], + "horizontal_extent_end": horizontal_extent[1], + "vertical_extent_start": vertical_extent[0], + "vertical_extent_end": vertical_extent[1], "units": "um", } diff --git a/sirepo_bluesky/shadow/shadow_ophyd.py b/sirepo_bluesky/shadow/shadow_ophyd.py new file mode 100644 index 00000000..23111fec --- /dev/null +++ b/sirepo_bluesky/shadow/shadow_ophyd.py @@ -0,0 +1,142 @@ +import datetime +import json +import time as ttime +from pathlib import Path + +from event_model import compose_resource +from ophyd import Component as Cpt +from ophyd import Signal +from ophyd.sim import NullStatus, new_uid + +from sirepo_bluesky.common import logger +from sirepo_bluesky.common.base_classes import DeviceWithJSONData, SirepoWatchpointBase +from sirepo_bluesky.shadow.shadow_handler import read_shadow_file + + +class SirepoWatchpointShadow(SirepoWatchpointBase): + def __init__( + self, + *args, + root_dir="/tmp/sirepo-bluesky-data", + assets_dir=None, + result_file=None, + **kwargs, + ): + self._allowed_sim_types = ("shadow",) + super().__init__(*args, root_dir=root_dir, assets_dir=assets_dir, result_file=result_file, **kwargs) + + def trigger(self, *args, **kwargs): + logger.debug(f"Custom trigger for {self.name}") + + self.connection.data["report"] = self._report + + date = datetime.datetime.now() + self._assets_dir = date.strftime("%Y/%m/%d") + self._result_file = f"{new_uid()}.dat" + + self._resource_document, self._datum_factory, _ = compose_resource( + start={"uid": "needed for compose_resource() but will be discarded"}, + spec=self.connection.data["simulationType"], + root=self._root_dir, + resource_path=str(Path(self._assets_dir) / Path(self._result_file)), + resource_kwargs={}, + ) + # now discard the start uid, a real one will be added later + self._resource_document.pop("run_start") + self._asset_docs_cache.append(("resource", self._resource_document)) + + sim_result_file = str( + Path(self._resource_document["root"]) / Path(self._resource_document["resource_path"]) + ) + + _, duration = self.connection.run_simulation() + self.duration.put(duration) + + datafile = self.connection.get_datafile(file_index=-1) + + with open(sim_result_file, "wb") as f: + f.write(datafile) + + conn_data = self.connection.data + nbins = conn_data["models"][self._report]["histogramBins"] + ret = read_shadow_file(sim_result_file, histogram_bins=nbins) + self._resource_document["resource_kwargs"]["histogram_bins"] = nbins + + def update_components(_data): + self.flux.put(_data["flux"]) + self.mean.put(_data["mean"]) + self.x.put(_data["x"]) + self.y.put(_data["y"]) + self.fwhm_x.put(_data["fwhm_x"]) + self.fwhm_y.put(_data["fwhm_y"]) + self.photon_energy.put(_data["photon_energy"]) + self.horizontal_extent_start.put(_data["horizontal_extent_start"]) + self.horizontal_extent_end.put(_data["horizontal_extent_end"]) + self.vertical_extent_start.put(_data["vertical_extent_start"]) + self.vertical_extent_end.put(_data["vertical_extent_end"]) + + update_components(ret) + + datum_document = self._datum_factory(datum_kwargs={}) + self._asset_docs_cache.append(("datum", datum_document)) + + self.image.put(datum_document["datum_id"]) + + self._resource_document = None + self._datum_factory = None + + logger.debug(f"\nReport for {self.name}: {self.connection.data['report']}\n") + + # We call the trigger on super at the end to update the sirepo_data_json + # and the corresponding hash after the simulation is run. + super().trigger(*args, **kwargs) + return NullStatus() + + def describe(self): + res = super().describe() + ny = nx = self.connection.data["models"][self._report]["histogramBins"] + res[self.image.name].update(dict(external="FILESTORE", shape=(ny, nx))) + return res + + +# This is for backwards compatibility +SirepoWatchpoint = SirepoWatchpointShadow + + +class BeamStatisticsReport(DeviceWithJSONData): + # NOTE: TES aperture changes don't seem to change the beam statistics + # report graph on the website? + + report = Cpt(Signal, value="", kind="normal") # values are always strings, not dictionaries + + def __init__(self, connection, *args, **kwargs): + super().__init__(*args, **kwargs) + self.connection = connection + self._report = "beamStatisticsReport" + + def trigger(self, *args, **kwargs): + logger.debug(f"Custom trigger for {self.name}") + + self.connection.data["report"] = self._report + + start_time = ttime.monotonic() + self.connection.run_simulation() + self.duration.put(ttime.monotonic() - start_time) + + datafile = self.connection.get_datafile(file_index=-1) + self.report.put(json.dumps(json.loads(datafile.decode()))) + + logger.debug(f"\nReport for {self.name}: {self.connection.data['report']}\n") + + # We call the trigger on super at the end to update the sirepo_data_json + # and the corresponding hash after the simulation is run. + super().trigger(*args, **kwargs) + return NullStatus() + + def stage(self): + super().stage() + self.report.put("") + + def unstage(self): + super().unstage() + self.report.put("") diff --git a/sirepo_bluesky/sirepo_ophyd.py b/sirepo_bluesky/sirepo_ophyd.py deleted file mode 100644 index 289fd2e7..00000000 --- a/sirepo_bluesky/sirepo_ophyd.py +++ /dev/null @@ -1,613 +0,0 @@ -import copy -import datetime -import hashlib -import json -import logging -import time -from collections import OrderedDict, deque, namedtuple -from pathlib import Path - -import inflection -from event_model import compose_resource -from ophyd import Component as Cpt -from ophyd import Device, Signal -from ophyd.sim import NullStatus, new_uid - -from sirepo_bluesky.sirepo_bluesky import SirepoBluesky - -from . import ExternalFileReference -from .shadow_handler import read_shadow_file -from .srw_handler import read_srw_file - -logger = logging.getLogger("sirepo-bluesky") -# Note: the following handler could be created/added to the logger on the client side: -# import sys -# stream_handler = logging.StreamHandler(sys.stdout) -# logger.addHandler(stream_handler) - -RESERVED_OPHYD_TO_SIREPO_ATTRS = { # ophyd <-> sirepo - "position": "element_position", - "name": "element_name", - "class": "command_class", -} -RESERVED_SIREPO_TO_OPHYD_ATTRS = {v: k for k, v in RESERVED_OPHYD_TO_SIREPO_ATTRS.items()} - - -class SirepoSignal(Signal): - def __init__(self, sirepo_dict, sirepo_param, *args, **kwargs): - super().__init__(*args, **kwargs) - self._sirepo_dict = sirepo_dict - self._sirepo_param = sirepo_param - if sirepo_param in RESERVED_SIREPO_TO_OPHYD_ATTRS: - self._sirepo_param = RESERVED_SIREPO_TO_OPHYD_ATTRS[sirepo_param] - - def set(self, value, *, timeout=None, settle_time=None): - logger.debug(f"Setting value for {self.name} to {value}") - self._sirepo_dict[self._sirepo_param] = value - self._readback = value - return NullStatus() - - def put(self, *args, **kwargs): - self.set(*args, **kwargs).wait() - - -class ReadOnlyException(Exception): - ... - - -class SirepoSignalRO(SirepoSignal): - def set(self, *args, **kwargs): - raise ReadOnlyException("Cannot set/put the read-only signal.") - - -class DeviceWithJSONData(Device): - sirepo_data_json = Cpt(Signal, kind="normal", value="") - sirepo_data_hash = Cpt(Signal, kind="normal", value="") - duration = Cpt(Signal, kind="normal", value=-1.0) - - def trigger(self, *args, **kwargs): - super().trigger(*args, **kwargs) - - json_str = json.dumps(self.connection.data) - json_hash = hashlib.sha256(json_str.encode()).hexdigest() - self.sirepo_data_json.put(json_str) - self.sirepo_data_hash.put(json_hash) - - return NullStatus() - - -class SirepoWatchpoint(DeviceWithJSONData): - image = Cpt(ExternalFileReference, kind="normal") - shape = Cpt(Signal) - flux = Cpt(Signal, kind="hinted") - mean = Cpt(Signal, kind="normal") - x = Cpt(Signal, kind="normal") - y = Cpt(Signal, kind="normal") - fwhm_x = Cpt(Signal, kind="normal") - fwhm_y = Cpt(Signal, kind="normal") - photon_energy = Cpt(Signal, kind="normal") - horizontal_extent = Cpt(Signal) - vertical_extent = Cpt(Signal) - - def __init__( - self, - *args, - root_dir="/tmp/sirepo-bluesky-data", - assets_dir=None, - result_file=None, - **kwargs, - ): - super().__init__(*args, **kwargs) - - self._root_dir = root_dir - self._assets_dir = assets_dir - self._result_file = result_file - - self._asset_docs_cache = deque() - self._resource_document = None - self._datum_factory = None - - sim_type = self.connection.data["simulationType"] - allowed_sim_types = ("srw", "shadow", "madx") - if sim_type not in allowed_sim_types: - raise RuntimeError( - f"Unknown simulation type: {sim_type}\nAllowed simulation types: {allowed_sim_types}" - ) - - def trigger(self, *args, **kwargs): - logger.debug(f"Custom trigger for {self.name}") - - date = datetime.datetime.now() - self._assets_dir = date.strftime("%Y/%m/%d") - self._result_file = f"{new_uid()}.dat" - - self._resource_document, self._datum_factory, _ = compose_resource( - start={"uid": "needed for compose_resource() but will be discarded"}, - spec=self.connection.data["simulationType"], - root=self._root_dir, - resource_path=str(Path(self._assets_dir) / Path(self._result_file)), - resource_kwargs={}, - ) - # now discard the start uid, a real one will be added later - self._resource_document.pop("run_start") - self._asset_docs_cache.append(("resource", self._resource_document)) - - sim_result_file = str( - Path(self._resource_document["root"]) / Path(self._resource_document["resource_path"]) - ) - - self.connection.data["report"] = f"watchpointReport{self.id._sirepo_dict['id']}" - - _, duration = self.connection.run_simulation() - self.duration.put(duration) - - datafile = self.connection.get_datafile(file_index=-1) - - with open(sim_result_file, "wb") as f: - f.write(datafile) - - conn_data = self.connection.data - sim_type = conn_data["simulationType"] - if sim_type == "srw": - ndim = 2 # this will always be a report with 2D data. - ret = read_srw_file(sim_result_file, ndim=ndim) - self._resource_document["resource_kwargs"]["ndim"] = ndim - elif sim_type == "shadow": - nbins = conn_data["models"][conn_data["report"]]["histogramBins"] - ret = read_shadow_file(sim_result_file, histogram_bins=nbins) - self._resource_document["resource_kwargs"]["histogram_bins"] = nbins - - def update_components(_data): - self.shape.put(_data["shape"]) - self.flux.put(_data["flux"]) - self.mean.put(_data["mean"]) - self.x.put(_data["x"]) - self.y.put(_data["y"]) - self.fwhm_x.put(_data["fwhm_x"]) - self.fwhm_y.put(_data["fwhm_y"]) - self.photon_energy.put(_data["photon_energy"]) - self.horizontal_extent.put(_data["horizontal_extent"]) - self.vertical_extent.put(_data["vertical_extent"]) - - update_components(ret) - - datum_document = self._datum_factory(datum_kwargs={}) - self._asset_docs_cache.append(("datum", datum_document)) - - self.image.put(datum_document["datum_id"]) - - self._resource_document = None - self._datum_factory = None - - logger.debug(f"\nReport for {self.name}: {self.connection.data['report']}\n") - - # We call the trigger on super at the end to update the sirepo_data_json - # and the corresponding hash after the simulation is run. - super().trigger(*args, **kwargs) - return NullStatus() - - def describe(self): - res = super().describe() - res[self.image.name].update(dict(external="FILESTORE")) - return res - - def unstage(self): - super().unstage() - self._resource_document = None - - def collect_asset_docs(self): - items = list(self._asset_docs_cache) - self._asset_docs_cache.clear() - for item in items: - yield item - - -class SingleElectronSpectrumReport(SirepoWatchpoint): - def trigger(self, *args, **kwargs): - logger.debug(f"Custom trigger for {self.name}") - - date = datetime.datetime.now() - self._assets_dir = date.strftime("%Y/%m/%d") - self._result_file = f"{new_uid()}.dat" - - self._resource_document, self._datum_factory, _ = compose_resource( - start={"uid": "needed for compose_resource() but will be discarded"}, - spec=self.connection.data["simulationType"], - root=self._root_dir, - resource_path=str(Path(self._assets_dir) / Path(self._result_file)), - resource_kwargs={}, - ) - # now discard the start uid, a real one will be added later - self._resource_document.pop("run_start") - self._asset_docs_cache.append(("resource", self._resource_document)) - - sim_result_file = str( - Path(self._resource_document["root"]) / Path(self._resource_document["resource_path"]) - ) - - self.connection.data["report"] = "intensityReport" - - start_time = time.monotonic() - self.connection.run_simulation() - self.duration.put(time.monotonic() - start_time) - - datafile = self.connection.get_datafile() - - with open(sim_result_file, "wb") as f: - f.write(datafile) - - conn_data = self.connection.data - sim_type = conn_data["simulationType"] - if sim_type == "srw": - ndim = 1 - ret = read_srw_file(sim_result_file, ndim=ndim) - self._resource_document["resource_kwargs"]["ndim"] = ndim - - def update_components(_data): - self.shape.put(_data["shape"]) - self.flux.put(_data["flux"]) - self.mean.put(_data["mean"]) - self.x.put(_data["x"]) - self.y.put(_data["y"]) - self.fwhm_x.put(_data["fwhm_x"]) - self.fwhm_y.put(_data["fwhm_y"]) - self.photon_energy.put(_data["photon_energy"]) - self.horizontal_extent.put(_data["horizontal_extent"]) - self.vertical_extent.put(_data["vertical_extent"]) - - update_components(ret) - - datum_document = self._datum_factory(datum_kwargs={}) - self._asset_docs_cache.append(("datum", datum_document)) - - self.image.put(datum_document["datum_id"]) - - self._resource_document = None - self._datum_factory = None - - logger.debug(f"\nReport for {self.name}: {self.connection.data['report']}\n") - - return NullStatus() - - -class BeamStatisticsReport(DeviceWithJSONData): - # NOTE: TES aperture changes don't seem to change the beam statistics - # report graph on the website? - - report = Cpt(Signal, value="", kind="normal") # values are always strings, not dictionaries - - def __init__(self, connection, *args, **kwargs): - super().__init__(*args, **kwargs) - self.connection = connection - - def trigger(self, *args, **kwargs): - logger.debug(f"Custom trigger for {self.name}") - - self.connection.data["report"] = "beamStatisticsReport" - - start_time = time.monotonic() - self.connection.run_simulation() - self.duration.put(time.monotonic() - start_time) - - datafile = self.connection.get_datafile(file_index=-1) - self.report.put(json.dumps(json.loads(datafile.decode()))) - - logger.debug(f"\nReport for {self.name}: {self.connection.data['report']}\n") - - # We call the trigger on super at the end to update the sirepo_data_json - # and the corresponding hash after the simulation is run. - super().trigger(*args, **kwargs) - return NullStatus() - - def stage(self): - super().stage() - self.report.put("") - - def unstage(self): - super().unstage() - self.report.put("") - - -class SirepoSignalGrazingAngle(SirepoSignal): - def set(self, value): - super().set(value) - ret = self.parent.connection.compute_grazing_orientation(self._sirepo_dict) - # State is added to the ret dict from compute_grazing_orientation and we - # want to make sure the vectors are updated properly every time the - # grazing angle is updated. - ret.pop("state") - # Update vector components - for cpt in [ - "normalVectorX", - "normalVectorY", - "normalVectorZ", - "tangentialVectorX", - "tangentialVectorY", - ]: - getattr(self.parent, cpt).put(ret[cpt]) - return NullStatus() - - -class SirepoSignalCRL(SirepoSignal): - def set(self, value): - super().set(value) - ret = self.parent.connection.compute_crl_characteristics(self._sirepo_dict) - # State is added to the ret dict from compute_crl_characteristics and we - # want to make sure the crl element is updated properly when parameters are changed. - ret.pop("state") - # Update crl element - for cpt in ["absoluteFocusPosition", "focalDistance"]: - getattr(self.parent, cpt).put(ret[cpt]) - return NullStatus() - - -class SirepoSignalCrystal(SirepoSignal): - def set(self, value): - super().set(value) - ret = self.parent.connection.compute_crystal_orientation(self._sirepo_dict) - # State is added to the ret dict from compute_crystal_orientation and we - # want to make sure the crystal element is updated properly when parameters are changed. - ret.pop("state") - # Update crystal element - for cpt in [ - "dSpacing", - "grazingAngle", - "nvx", - "nvy", - "nvz", - "outframevx", - "outframevy", - "outoptvx", - "outoptvy", - "outoptvz", - "psi0i", - "psi0r", - "psiHBi", - "psiHBr", - "psiHi", - "psiHr", - "tvx", - "tvy", - ]: - getattr(self.parent, cpt).put(ret[cpt]) - return NullStatus() - - -SimplePropagationConfig = namedtuple( - "PropagationConfig", - "resize_before resize_after precision propagator_type " - + "fourier_resize hrange_mod hres_mod vrange_mod vres_mod", -) - - -class PropagationConfig(SimplePropagationConfig): - read_attrs = list(SimplePropagationConfig._fields) - component_names = SimplePropagationConfig._fields - - def read(self): - read_attrs = self.read_attrs - propagation_read = OrderedDict() - for field in read_attrs: - propagation_read[field] = getattr(self, field).read() - return propagation_read - - -def create_classes(connection, create_objects=True, extra_model_fields=[]): - classes = {} - objects = {} - data = copy.deepcopy(connection.data) - - sim_type = connection.sim_type - - SimTypeConfig = namedtuple("SimTypeConfig", "element_location class_name_field") - - srw_config = SimTypeConfig("beamline", "title") - shadow_config = SimTypeConfig("beamline", "title") - madx_config = SimTypeConfig("elements", "element_name") - - config_dict = { - "srw": srw_config, - "shadow": shadow_config, - "madx": madx_config, - } - - model_fields = [config_dict[sim_type].element_location] + extra_model_fields - - data_models = {} - for model_field in model_fields: - if sim_type == "srw" and model_field in ["undulator", "intensityReport"]: - if model_field == "intensityReport": - title = "SingleElectronSpectrum" - else: - title = model_field - data["models"][model_field].update({"title": title, "type": model_field}) - data_models[model_field] = [data["models"][model_field]] - else: - data_models[model_field] = data["models"][model_field] - - for model_field, data_model in data_models.items(): - for i, el in enumerate(data_model): # 'el' is a dict, 'data_model' is a list of dicts - logger.debug(f"Processing {el}...") - - for ophyd_key, sirepo_key in RESERVED_OPHYD_TO_SIREPO_ATTRS.items(): - # We have to rename the reserved attribute names. Example error - # from ophyd: - # - # TypeError: The attribute name(s) {'position'} are part of the - # bluesky interface and cannot be used as component names. Choose - # a different name. - if ophyd_key in el: - el[sirepo_key] = el[ophyd_key] - el.pop(ophyd_key) - else: - pass - - class_name = el[config_dict[sim_type].class_name_field] - if model_field == "commands": - # Use command type and index in the model as class name to - # prevent overwriting any other elements or rpnVariables - # Examples of class names: beam0, select1, twiss7 - class_name = inflection.camelize(f"{el['_type']}{i}") - else: - class_name = inflection.camelize( - el[config_dict[sim_type].class_name_field].replace(" ", "_").replace(".", "").replace("-", "_") - ) - object_name = inflection.underscore(class_name) - - base_classes = (Device,) - extra_kwargs = {"connection": connection} - if "type" in el and el["type"] == "watch": - base_classes = (SirepoWatchpoint, Device) - elif "type" in el and el["type"] == "intensityReport": - base_classes = (SingleElectronSpectrumReport, Device) - - components = {} - for k, v in el.items(): - if ( - "type" in el - and el["type"] in ["sphericalMirror", "toroidalMirror", "ellipsoidMirror"] - and k == "grazingAngle" - ): - cpt_class = SirepoSignalGrazingAngle - elif "type" in el and el["type"] == "crl" and k not in ["absoluteFocusPosition", "focalDistance"]: - cpt_class = SirepoSignalCRL - elif ( - "type" in el - and el["type"] == "crystal" - and k - not in [ - "dSpacing", - "grazingAngle", - "nvx", - "nvy", - "nvz", - "outframevx", - "outframevy", - "outoptvx", - "outoptvy", - "outoptvz", - "psi0i", - "psi0r", - "psiHBi", - "psiHBr", - "psiHi", - "psiHr", - "tvx", - "tvy", - ] - ): - cpt_class = SirepoSignalCrystal - else: - # TODO: Cover the cases for mirror and crystal grazing angles - cpt_class = SirepoSignal - - if "type" in el and el["type"] not in ["undulator", "intensityReport"]: - sirepo_dict = connection.data["models"][model_field][i] - elif sim_type == "madx" and model_field in ["rpnVariables", "commands"]: - sirepo_dict = connection.data["models"][model_field][i] - else: - sirepo_dict = connection.data["models"][model_field] - - components[k] = Cpt( - cpt_class, - value=(float(v) if type(v) is int else v), - sirepo_dict=sirepo_dict, - sirepo_param=k, - ) - components.update(**extra_kwargs) - - cls = type( - class_name, - base_classes, - components, - ) - - classes[object_name] = cls - if create_objects: - objects[object_name] = cls(name=object_name) - - if sim_type == "srw" and model_field == "beamline": - prop_params = connection.data["models"]["propagation"][str(el["id"])][0] - sirepo_propagation = [] - object_name += "_propagation" - for i in range(9): - sirepo_propagation.append( - SirepoSignal( - name=f"{object_name}_{SimplePropagationConfig._fields[i]}", - value=float(prop_params[i]), - sirepo_dict=prop_params, - sirepo_param=i, - ) - ) - if create_objects: - objects[object_name] = PropagationConfig(*sirepo_propagation[:]) - - if sim_type == "srw": - post_prop_params = connection.data["models"]["postPropagation"] - sirepo_propagation = [] - object_name = "post_propagation" - for i in range(9): - sirepo_propagation.append( - SirepoSignal( - name=f"{object_name}_{SimplePropagationConfig._fields[i]}", - value=float(post_prop_params[i]), - sirepo_dict=post_prop_params, - sirepo_param=i, - ) - ) - classes["propagation_parameters"] = PropagationConfig - if create_objects: - objects[object_name] = PropagationConfig(*sirepo_propagation[:]) - - return classes, objects - - -def populate_beamline(sim_name, *args): - """ - Parameters - ---------- - *args : - For one beamline, ``connection, indices, new_positions``. - In general: - - .. code-block:: python - - connection1, indices1, new_positions1 - connection2, indices2, new_positions2 - ..., - connectionN, indicesN, new_positionsN - """ - if len(args) % 3 != 0: - raise ValueError( - "Incorrect signature, arguments must be of the signature: connection, indices, new_positions, ..." - ) - - connections = [] - indices_list = [] - new_positions_list = [] - - for i in range(0, len(args), 3): - connections.append(args[i]) - indices_list.append(args[i + 1]) - new_positions_list.append(args[i + 2]) - - emptysim = SirepoBluesky("http://localhost:8000") - emptysim.auth("srw", sim_id="emptysim") - new_beam = emptysim.copy_sim(sim_name=sim_name) - new_beamline = new_beam.data["models"]["beamline"] - new_propagation = new_beam.data["models"]["propagation"] - - curr_id = 0 - for connection, indices, new_positions in zip(connections, indices_list, new_positions_list): - old_beamline = connection.data["models"]["beamline"] - old_propagation = connection.data["models"]["propagation"] - for i, pos in zip(indices, new_positions): - new_beamline.append(old_beamline[i].copy()) - new_beamline[curr_id]["id"] = curr_id - new_beamline[curr_id]["position"] = pos - new_propagation[str(curr_id)] = old_propagation[str(old_beamline[i]["id"])].copy() - curr_id += 1 - - classes, objects = create_classes(new_beam) - - return new_beam, classes, objects diff --git a/sirepo_bluesky/srw/__init__.py b/sirepo_bluesky/srw/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sirepo_bluesky/sirepo_flyer.py b/sirepo_bluesky/srw/srw_flyer.py similarity index 95% rename from sirepo_bluesky/sirepo_flyer.py rename to sirepo_bluesky/srw/srw_flyer.py index 62dc3f38..21c67d5c 100644 --- a/sirepo_bluesky/sirepo_flyer.py +++ b/sirepo_bluesky/srw/srw_flyer.py @@ -2,42 +2,17 @@ import hashlib import os import time as ttime -from collections import deque from multiprocessing import Manager, Process from pathlib import Path from ophyd.sim import NullStatus, new_uid -from sirepo_bluesky.srw_handler import read_srw_file +from sirepo_bluesky.common import BlueskyFlyer +from sirepo_bluesky.common.sirepo_client import SirepoClient +from sirepo_bluesky.srw.srw_handler import read_srw_file -from .sirepo_bluesky import SirepoBluesky - -class BlueskyFlyer: - def __init__(self): - self.name = "bluesky_flyer" - self._asset_docs_cache = deque() - self._resource_uids = [] - self._datum_counter = None - self._datum_ids = [] - - def kickoff(self): - return NullStatus() - - def complete(self): - return NullStatus() - - def collect(self): - ... - - def collect_asset_docs(self): - items = list(self._asset_docs_cache) - self._asset_docs_cache.clear() - for item in items: - yield item - - -class SirepoFlyer(BlueskyFlyer): +class SRWFlyer(BlueskyFlyer): """ Multiprocessing "flyer" for Sirepo simulations @@ -192,7 +167,7 @@ def run_parallel(self, value): raise TypeError(f"invalid type: {type(value)}. Must be boolean") def kickoff(self): - sb = SirepoBluesky(self.server_name) + sb = SirepoClient(self.server_name) data, schema = sb.auth(self.sim_code, self.sim_id) self._copies = [] self._srw_files = [] diff --git a/sirepo_bluesky/srw_handler.py b/sirepo_bluesky/srw/srw_handler.py similarity index 61% rename from sirepo_bluesky/srw_handler.py rename to sirepo_bluesky/srw/srw_handler.py index 6d0563a2..b2e69d79 100644 --- a/sirepo_bluesky/srw_handler.py +++ b/sirepo_bluesky/srw/srw_handler.py @@ -1,7 +1,9 @@ +import h5py import numpy as np import srwpy.uti_plot_com as srw_io +from area_detector_handlers.handlers import HandlerBase -from . import utils +from sirepo_bluesky import utils def read_srw_file(filename, ndim=2): @@ -15,8 +17,8 @@ def read_srw_file(filename, ndim=2): else: raise ValueError(f"The value ndim={ndim} is not supported.") - horizontal_extent = np.array(ranges[3:5]) - vertical_extent = np.array(ranges[6:8]) + horizontal_extent = np.array(ranges[3:5]).astype(float) + vertical_extent = np.array(ranges[6:8]).astype(float) ret = { "data": data, @@ -24,8 +26,10 @@ def read_srw_file(filename, ndim=2): "flux": data.sum(), "mean": data.mean(), "photon_energy": photon_energy, - "horizontal_extent": horizontal_extent, - "vertical_extent": vertical_extent, + "horizontal_extent_start": horizontal_extent[0], + "horizontal_extent_end": horizontal_extent[1], + "vertical_extent_start": vertical_extent[0], + "vertical_extent_end": vertical_extent[1], "labels": labels, "units": units, } @@ -48,3 +52,15 @@ def __init__(self, filename, ndim=2): def __call__(self): d = read_srw_file(self._name, ndim=self._ndim) return d["data"] + + +class SRWHDF5FileHandler(HandlerBase): + specs = {"SRW_HDF5"} + + def __init__(self, filename): + self._name = filename + + def __call__(self, frame): + with h5py.File(self._name, "r") as f: + entry = f["/entry/image"] + return entry[frame] diff --git a/sirepo_bluesky/srw/srw_ophyd.py b/sirepo_bluesky/srw/srw_ophyd.py new file mode 100644 index 00000000..e04088b6 --- /dev/null +++ b/sirepo_bluesky/srw/srw_ophyd.py @@ -0,0 +1,309 @@ +import datetime +import itertools +import os +import time +from collections import OrderedDict, namedtuple +from pathlib import Path + +import h5py +import numpy as np +from event_model import compose_resource +from ophyd.sim import NullStatus, new_uid +from skimage.transform import resize + +from sirepo_bluesky.common import SirepoSignal, logger +from sirepo_bluesky.common.base_classes import SirepoWatchpointBase +from sirepo_bluesky.srw.srw_handler import read_srw_file + + +class SirepoWatchpointSRW(SirepoWatchpointBase): + def __init__( + self, + *args, + root_dir="/tmp/sirepo-bluesky-data", + assets_dir=None, + result_file=None, + image_shape=(1024, 1024), + **kwargs, + ): + self._allowed_sim_types = ("srw",) + self._image_shape = image_shape + super().__init__(*args, root_dir=root_dir, assets_dir=assets_dir, result_file=result_file, **kwargs) + + def stage(self): + super().stage() + date = datetime.datetime.now() + self._assets_dir = date.strftime("%Y/%m/%d") + data_file = f"{new_uid()}.h5" + + self._resource_document, self._datum_factory, _ = compose_resource( + start={"uid": "needed for compose_resource() but will be discarded"}, + spec=f'{self.connection.data["simulationType"]}_hdf5'.upper(), + root=self._root_dir, + resource_path=str(Path(self._assets_dir) / Path(data_file)), + resource_kwargs={}, + ) + + self._data_file = str( + Path(self._resource_document["root"]) / Path(self._resource_document["resource_path"]) + ) + + # now discard the start uid, a real one will be added later + self._resource_document.pop("run_start") + self._asset_docs_cache.append(("resource", self._resource_document)) + + self._h5file_desc = h5py.File(self._data_file, "x") + group = self._h5file_desc.create_group("/entry") + self._dataset = group.create_dataset( + "image", + data=np.full(fill_value=np.nan, shape=(1, *self._image_shape)), + maxshape=(None, *self._image_shape), + chunks=(1, *self._image_shape), + dtype="float64", + compression="lzf", + ) + self._counter = itertools.count() + + def describe(self): + res = super().describe() + + res[self.image.name].update(dict(external="FILESTORE", shape=self._image_shape)) + + return res + + def trigger(self, *args, **kwargs): + logger.debug(f"Custom trigger for {self.name}") + + self.connection.data["report"] = self._report + + current_frame = next(self._counter) + sim_result_file = f"{os.path.splitext(self._data_file)[0]}_{self._sim_type}_{current_frame:04d}.dat" + + _, duration = self.connection.run_simulation() + self.duration.put(duration) + + datafile = self.connection.get_datafile(file_index=-1) + + with open(sim_result_file, "wb") as f: + f.write(datafile) + + ndim = 2 # this will always be a report with 2D data. + ret = read_srw_file(sim_result_file, ndim=ndim) + # TODO: rename _image_shape to _target_image_shape? + data = resize(ret["data"], self._image_shape) + self._dataset.resize((current_frame + 1, *self._image_shape)) + + logger.debug(f"{self._dataset = }\n{self._dataset.shape = }") + + self._dataset[current_frame, :, :] = data + + def update_components(_data): + self.flux.put(_data["flux"]) + self.mean.put(_data["mean"]) + self.x.put(_data["x"]) + self.y.put(_data["y"]) + self.fwhm_x.put(_data["fwhm_x"]) + self.fwhm_y.put(_data["fwhm_y"]) + self.photon_energy.put(_data["photon_energy"]) + self.horizontal_extent_start.put(_data["horizontal_extent_start"]) + self.horizontal_extent_end.put(_data["horizontal_extent_end"]) + self.vertical_extent_start.put(_data["vertical_extent_start"]) + self.vertical_extent_end.put(_data["vertical_extent_end"]) + + # TODO: think about what should be passed - raw data from .dat files or the resized data? + update_components(ret) + + datum_document = self._datum_factory(datum_kwargs={"frame": current_frame}) + self._asset_docs_cache.append(("datum", datum_document)) + + self.image.put(datum_document["datum_id"]) + + logger.debug(f"\nReport for {self.name}: {self.connection.data['report']}\n") + + # We call the trigger on super at the end to update the sirepo_data_json + # and the corresponding hash after the simulation is run. + super().trigger(*args, **kwargs) + return NullStatus() + + def unstage(self): + super().unstage() + self._resource_document = None + self._datum_factory = None + del self._dataset + self._h5file_desc.close() + + +# This is for backwards compatibility +SirepoWatchpoint = SirepoWatchpointSRW + + +class SingleElectronSpectrumReport(SirepoWatchpointSRW): + horizontal_extent_start = None + horizontal_extent_end = None + vertical_extent_start = None + vertical_extent_end = None + x = None + y = None + fwhm_x = None + fwhm_y = None + photon_energy = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._report = "intensityReport" + self._image_shape = None # placeholder + + def stage(self): + pass + + def describe(self): + res = super().describe() + + num_points = int(self.connection.data["models"]["intensityReport"]["photonEnergyPointCount"]) + res[self.image.name].update(dict(shape=(num_points,))) + + return res + + def trigger(self, *args, **kwargs): + logger.debug(f"Custom trigger for {self.name}") + + self.connection.data["report"] = self._report + + date = datetime.datetime.now() + self._assets_dir = date.strftime("%Y/%m/%d") + self._result_file = f"{new_uid()}.dat" + + self._resource_document, self._datum_factory, _ = compose_resource( + start={"uid": "needed for compose_resource() but will be discarded"}, + spec=self.connection.data["simulationType"], + root=self._root_dir, + resource_path=str(Path(self._assets_dir) / Path(self._result_file)), + resource_kwargs={}, + ) + # now discard the start uid, a real one will be added later + self._resource_document.pop("run_start") + self._asset_docs_cache.append(("resource", self._resource_document)) + + sim_result_file = str( + Path(self._resource_document["root"]) / Path(self._resource_document["resource_path"]) + ) + + start_time = time.monotonic() + self.connection.run_simulation() + self.duration.put(time.monotonic() - start_time) + + datafile = self.connection.get_datafile() + + with open(sim_result_file, "wb") as f: + f.write(datafile) + + ndim = 1 + ret = read_srw_file(sim_result_file, ndim=ndim) + self._resource_document["resource_kwargs"]["ndim"] = ndim + + def update_components(_data): + self.flux.put(_data["flux"]) + self.mean.put(_data["mean"]) + + update_components(ret) + + datum_document = self._datum_factory(datum_kwargs={}) + self._asset_docs_cache.append(("datum", datum_document)) + + self.image.put(datum_document["datum_id"]) + + self._resource_document = None + self._datum_factory = None + + logger.debug(f"\nReport for {self.name}: {self.connection.data['report']}\n") + + return NullStatus() + + def unstage(self): + self._resource_document = None + self._datum_factory = None + + +class SirepoSignalGrazingAngle(SirepoSignal): + def set(self, value): + super().set(value) + ret = self.parent.connection.compute_grazing_orientation(self._sirepo_dict) + # State is added to the ret dict from compute_grazing_orientation and we + # want to make sure the vectors are updated properly every time the + # grazing angle is updated. + ret.pop("state") + # Update vector components + for cpt in [ + "normalVectorX", + "normalVectorY", + "normalVectorZ", + "tangentialVectorX", + "tangentialVectorY", + ]: + getattr(self.parent, cpt).put(ret[cpt]) + return NullStatus() + + +class SirepoSignalCRL(SirepoSignal): + def set(self, value): + super().set(value) + ret = self.parent.connection.compute_crl_characteristics(self._sirepo_dict) + # State is added to the ret dict from compute_crl_characteristics and we + # want to make sure the crl element is updated properly when parameters are changed. + ret.pop("state") + # Update crl element + for cpt in ["absoluteFocusPosition", "focalDistance"]: + getattr(self.parent, cpt).put(ret[cpt]) + return NullStatus() + + +class SirepoSignalCrystal(SirepoSignal): + def set(self, value): + super().set(value) + ret = self.parent.connection.compute_crystal_orientation(self._sirepo_dict) + # State is added to the ret dict from compute_crystal_orientation and we + # want to make sure the crystal element is updated properly when parameters are changed. + ret.pop("state") + # Update crystal element + for cpt in [ + "dSpacing", + "grazingAngle", + "nvx", + "nvy", + "nvz", + "outframevx", + "outframevy", + "outoptvx", + "outoptvy", + "outoptvz", + "psi0i", + "psi0r", + "psiHBi", + "psiHBr", + "psiHi", + "psiHr", + "tvx", + "tvy", + ]: + getattr(self.parent, cpt).put(ret[cpt]) + return NullStatus() + + +SimplePropagationConfig = namedtuple( + "PropagationConfig", + "resize_before resize_after precision propagator_type " + + "fourier_resize hrange_mod hres_mod vrange_mod vres_mod", +) + + +class PropagationConfig(SimplePropagationConfig): + read_attrs = list(SimplePropagationConfig._fields) + component_names = SimplePropagationConfig._fields + + def read(self): + read_attrs = self.read_attrs + propagation_read = OrderedDict() + for field in read_attrs: + propagation_read[field] = getattr(self, field).read() + return propagation_read diff --git a/sirepo_bluesky/tests/conftest.py b/sirepo_bluesky/tests/conftest.py index 995f3b5b..9a92e211 100644 --- a/sirepo_bluesky/tests/conftest.py +++ b/sirepo_bluesky/tests/conftest.py @@ -8,10 +8,12 @@ from databroker import Broker from ophyd.utils import make_dir_tree -from sirepo_bluesky.madx_handler import MADXFileHandler -from sirepo_bluesky.shadow_handler import ShadowFileHandler -from sirepo_bluesky.sirepo_bluesky import SirepoBluesky -from sirepo_bluesky.srw_handler import SRWFileHandler +from sirepo_bluesky.common.sirepo_client import SirepoClient +from sirepo_bluesky.madx.madx_handler import MADXFileHandler +from sirepo_bluesky.shadow.shadow_handler import ShadowFileHandler +from sirepo_bluesky.srw.srw_handler import SRWFileHandler, SRWHDF5FileHandler + +DEFAULT_SIREPO_URL = "http://localhost:8000" @pytest.fixture(scope="function") @@ -25,8 +27,9 @@ def db(): pass db.reg.register_handler("srw", SRWFileHandler, overwrite=True) - db.reg.register_handler("shadow", ShadowFileHandler, overwrite=True) db.reg.register_handler("SIREPO_FLYER", SRWFileHandler, overwrite=True) + db.reg.register_handler("SRW_HDF5", SRWHDF5FileHandler, overwrite=True) + db.reg.register_handler("shadow", ShadowFileHandler, overwrite=True) db.reg.register_handler("madx", MADXFileHandler, overwrite=True) return db @@ -67,76 +70,76 @@ def make_dirs(): @pytest.fixture(scope="function") def srw_empty_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("srw", "emptysim") return connection @pytest.fixture(scope="function") def srw_youngs_double_slit_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("srw", "00000000") return connection @pytest.fixture(scope="function") def srw_basic_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("srw", "00000001") return connection @pytest.fixture(scope="function") def srw_tes_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("srw", "00000002") return connection @pytest.fixture(scope="function") def srw_ari_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("srw", "00000003") return connection @pytest.fixture(scope="function") def srw_chx_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("srw", "HXV1JQ5c") return connection @pytest.fixture(scope="function") def shadow_basic_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("shadow", "00000001") return connection @pytest.fixture(scope="function") def shadow_tes_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("shadow", "00000002") return connection @pytest.fixture(scope="function") def madx_resr_storage_ring_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("madx", "00000000") return connection @pytest.fixture(scope="function") def madx_bl1_compton_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("madx", "00000001") return connection @pytest.fixture(scope="function") def madx_bl2_triplet_tdc_simulation(make_dirs): - connection = SirepoBluesky("http://localhost:8000") + connection = SirepoClient(DEFAULT_SIREPO_URL) data, _ = connection.auth("madx", "00000002") return connection diff --git a/sirepo_bluesky/tests/test_bl_elements_as_ophyd_objs.py b/sirepo_bluesky/tests/test_bl_elements_as_ophyd_objs.py deleted file mode 100644 index 2713d729..00000000 --- a/sirepo_bluesky/tests/test_bl_elements_as_ophyd_objs.py +++ /dev/null @@ -1,576 +0,0 @@ -import copy -import json -import os -import pprint - -import bluesky.plan_stubs as bps -import bluesky.plans as bp -import dictdiffer -import matplotlib.pyplot as plt -import numpy as np -import peakutils -import pytest -import tfs - -from sirepo_bluesky.madx_flyer import MADXFlyer -from sirepo_bluesky.sirepo_ophyd import BeamStatisticsReport, create_classes - - -def test_beamline_elements_as_ophyd_objects(srw_tes_simulation): - classes, objects = create_classes(connection=srw_tes_simulation) - - for name, obj in objects.items(): - pprint.pprint(obj.read()) - - globals().update(**objects) - - print(mono_crystal1.summary()) # noqa - pprint.pprint(mono_crystal1.read()) # noqa - - -def test_empty_simulation(srw_empty_simulation): - classes, objects = create_classes(connection=srw_empty_simulation) - globals().update(**objects) - - assert not srw_empty_simulation.data["models"]["beamline"] - objects.pop("post_propagation") - assert not objects - - -@pytest.mark.parametrize("method", ["set", "put"]) -def test_beamline_elements_set_put(srw_tes_simulation, method): - classes, objects = create_classes(connection=srw_tes_simulation) - globals().update(**objects) - - i = 0 - for k, v in objects.items(): - if "element_position" in v.component_names: - old_value = v.element_position.get() - old_sirepo_value = srw_tes_simulation.data["models"]["beamline"][i]["position"] - - getattr(v.element_position, method)(old_value + 100) - - new_value = v.element_position.get() - new_sirepo_value = srw_tes_simulation.data["models"]["beamline"][i]["position"] - - print( - f"\n Changed: {old_value} -> {new_value}\n Sirepo: {old_sirepo_value} -> {new_sirepo_value}\n" - ) - - assert old_value == old_sirepo_value - assert new_value == new_sirepo_value - assert new_value != old_value - assert abs(new_value - (old_value + 100)) < 1e-8 - i += 1 - - -@pytest.mark.parametrize("method", ["set", "put"]) -def test_crl_calculation(srw_chx_simulation, method): - classes, objects = create_classes(connection=srw_chx_simulation) - globals().update(**objects) - - params_before = copy.deepcopy(crl1.tipRadius._sirepo_dict) # noqa F821 - params_before.pop("tipRadius") - - getattr(crl1.tipRadius, method)(2000) # noqa F821 - - params_after = copy.deepcopy(crl1.tipRadius._sirepo_dict) # noqa F821 - params_after.pop("tipRadius") - - params_diff = list(dictdiffer.diff(params_before, params_after)) - assert len(params_diff) > 0 # should not be empty - - expected_values = { - "absoluteFocusPosition": -6.195573642892285, - "focalDistance": 237.666984823537, - } - - actual_values = { - "absoluteFocusPosition": crl1.absoluteFocusPosition.get(), # noqa F821 - "focalDistance": crl1.focalDistance.get(), # noqa F821 - } - - assert not list(dictdiffer.diff(expected_values, actual_values)) - - -@pytest.mark.parametrize("method", ["set", "put"]) -def test_crystal_calculation(srw_tes_simulation, method): - classes, objects = create_classes(connection=srw_tes_simulation) - globals().update(**objects) - - params_before = copy.deepcopy(mono_crystal1.energy._sirepo_dict) # noqa F821 - params_before.pop("energy") - - getattr(mono_crystal1.energy, method)(2000) # noqa F821 - - params_after = copy.deepcopy(mono_crystal1.energy._sirepo_dict) # noqa F821 - params_after.pop("energy") - - params_diff = list(dictdiffer.diff(params_before, params_after)) - assert len(params_diff) > 0 # should not be empty - - expected_values = { - "dSpacing": 3.1355713563754857, - "grazingAngle": 1419.9107955732711, - "nvx": 0, - "nvy": 0.15031366142760424, - "nvz": -0.9886383581412506, - "outframevx": 1.0, - "outframevy": 0.0, - "outoptvx": 0.0, - "outoptvy": 0.29721170287997256, - "outoptvz": -0.9548116063764552, - "psi0i": 6.530421915581681e-05, - "psi0r": -0.00020558072555357544, - "psiHBi": 4.559368494529194e-05, - "psiHBr": -0.00010207663788071082, - "psiHi": 4.559368494529194e-05, - "psiHr": -0.00010207663788071082, - "tvx": 0, - "tvy": 0.9886383581412506, - } - - actual_values = { - "dSpacing": mono_crystal1.dSpacing.get(), # noqa F821 - "grazingAngle": mono_crystal1.grazingAngle.get(), # noqa F821 - "nvx": mono_crystal1.nvx.get(), # noqa F821 - "nvy": mono_crystal1.nvy.get(), # noqa F821 - "nvz": mono_crystal1.nvz.get(), # noqa F821 - "outframevx": mono_crystal1.outframevx.get(), # noqa F821 - "outframevy": mono_crystal1.outframevy.get(), # noqa F821 - "outoptvx": mono_crystal1.outoptvx.get(), # noqa F821 - "outoptvy": mono_crystal1.outoptvy.get(), # noqa F821 - "outoptvz": mono_crystal1.outoptvz.get(), # noqa F821 - "psi0i": mono_crystal1.psi0i.get(), # noqa F821 - "psi0r": mono_crystal1.psi0r.get(), # noqa F821 - "psiHBi": mono_crystal1.psiHBi.get(), # noqa F821 - "psiHBr": mono_crystal1.psiHBr.get(), # noqa F821 - "psiHi": mono_crystal1.psiHi.get(), # noqa F821 - "psiHr": mono_crystal1.psiHr.get(), # noqa F821 - "tvx": mono_crystal1.tvx.get(), # noqa F821 - "tvy": mono_crystal1.tvy.get(), # noqa F821 - } - - assert not list(dictdiffer.diff(expected_values, actual_values)) - - -@pytest.mark.parametrize("method", ["set", "put"]) -def test_grazing_angle_calculation(srw_tes_simulation, method): - classes, objects = create_classes(connection=srw_tes_simulation) - globals().update(**objects) - - params_before = copy.deepcopy(toroid.grazingAngle._sirepo_dict) # noqa F821 - params_before.pop("grazingAngle") - - getattr(toroid.grazingAngle, method)(10) # noqa F821 - - params_after = copy.deepcopy(toroid.grazingAngle._sirepo_dict) # noqa F821 - params_after.pop("grazingAngle") - - params_diff = list(dictdiffer.diff(params_before, params_after)) - assert len(params_diff) > 0 # should not be empty - - expected_vector_values = { - "nvx": 0, - "nvy": 0.9999500004166653, - "nvz": -0.009999833334166664, - "tvx": 0, - "tvy": 0.009999833334166664, - } - - actual_vector_values = { - "nvx": toroid.normalVectorX.get(), # noqa F821 - "nvy": toroid.normalVectorY.get(), # noqa F821 - "nvz": toroid.normalVectorZ.get(), # noqa F821 - "tvx": toroid.tangentialVectorX.get(), # noqa F821 - "tvy": toroid.tangentialVectorY.get(), # noqa F821 - } - - assert not list(dictdiffer.diff(expected_vector_values, actual_vector_values)) - - -def test_beamline_elements_simple_connection(srw_basic_simulation): - classes, objects = create_classes(connection=srw_basic_simulation) - - for name, obj in objects.items(): - pprint.pprint(obj.read()) - - globals().update(**objects) - - print(watchpoint.summary()) # noqa F821 - pprint.pprint(watchpoint.read()) # noqa F821 - - -def test_srw_source_with_run_engine(RE, db, srw_ari_simulation, num_steps=5): - classes, objects = create_classes( - connection=srw_ari_simulation, - extra_model_fields=["undulator", "intensityReport"], - ) - globals().update(**objects) - - undulator.verticalAmplitude.kind = "hinted" # noqa F821 - - single_electron_spectrum.initialEnergy.get() # noqa F821 - single_electron_spectrum.initialEnergy.put(20) # noqa F821 - single_electron_spectrum.finalEnergy.put(1100) # noqa F821 - - assert srw_ari_simulation.data["models"]["intensityReport"]["initialEnergy"] == 20 - assert srw_ari_simulation.data["models"]["intensityReport"]["finalEnergy"] == 1100 - - (uid,) = RE( - bp.scan( - [single_electron_spectrum], # noqa F821 - undulator.verticalAmplitude, # noqa F821 - 0.2, - 1, - num_steps, - ) - ) # noqa F821 - - hdr = db[uid] - tbl = hdr.table() - print(tbl) - - ses_data = np.array(list(hdr.data("single_electron_spectrum_image"))) - ampl_data = np.array(list(hdr.data("undulator_verticalAmplitude"))) - # Check the shape of the image data is right: - assert ses_data.shape == (num_steps, 2000) - - resource_files = [] - for name, doc in hdr.documents(): - if name == "resource": - resource_files.append(os.path.basename(doc["resource_path"])) - - # Check that all resource files are unique: - assert len(set(resource_files)) == num_steps - - fig = plt.figure() - ax = fig.add_subplot() - for i in range(num_steps): - ax.plot(ses_data[i, :], label=f"vert. magn. fld. {ampl_data[i]:.3f}T") - peak = peakutils.indexes(ses_data[i, :]) - ax.scatter(peak, ses_data[i, peak]) - ax.grid() - ax.legend() - ax.set_title("Single-Electron Spectrum vs. Vertical Magnetic Field") - fig.savefig("ses-vs-ampl.png") - # plt.show() - - -def test_srw_propagation_with_run_engine(RE, db, srw_chx_simulation, num_steps=5): - classes, objects = create_classes(connection=srw_chx_simulation) - globals().update(**objects) - - post_propagation.hrange_mod.kind = "hinted" # noqa F821 - - (uid,) = RE(bp.scan([sample], post_propagation.hrange_mod, 0.1, 0.3, num_steps)) # noqa F821 - hdr = db[uid] - tbl = hdr.table(fill=True) - print(tbl) - - # Check that the duration for each step in the simulation is positive: - sim_durations = np.array(tbl["sample_duration"]) - assert (sim_durations > 0.0).all(), "Simulation steps did not properly run." - - sample_image = [] - for i in range(num_steps): - sample_image.append(np.array(list(hdr.data("sample_image"))[i])) - - # Check the shape of the image data is right and that hrange_mod was properly changed: - for i, hrange_mod in enumerate(np.linspace(0.1, 0.3, num_steps)): - assert json.loads(tbl["sample_sirepo_data_json"][i + 1])["models"]["postPropagation"][5] == hrange_mod - assert sample_image[i].shape == (294, int(hrange_mod * 1760)) - - -def test_shadow_with_run_engine(RE, db, shadow_tes_simulation, num_steps=5): - classes, objects = create_classes(connection=shadow_tes_simulation) - globals().update(**objects) - - aperture.horizontalSize.kind = "hinted" # noqa F821 - - (uid,) = RE(bp.scan([w9], aperture.horizontalSize, 0, 2, num_steps)) # noqa F821 - hdr = db[uid] - tbl = hdr.table(fill=True) - print(tbl) - - # Check that the duration for each step in the simulation is positive: - sim_durations = np.array(tbl["w9_duration"]) - assert (sim_durations > 0.0).all() - - w9_image = np.array(list(hdr.data("w9_image"))) - # Check the shape of the image data is right: - assert w9_image.shape == (num_steps, 100, 100) - - w9_mean_from_image = w9_image.mean(axis=(1, 2)) - w9_mean_from_table = np.array(tbl["w9_mean"]) - - # Check the number of elements correspond to a number of scan points: - assert len(w9_mean_from_table) == num_steps - - # Check that an average values of the first and last images are right: - assert np.allclose(w9_image[0].mean(), 0.0) - assert np.allclose(w9_image[-1].mean(), 0.255665516042795, atol=1e-3) - - # Check that the values from the table and averages from the image data are - # the same: - assert np.allclose(w9_mean_from_table, w9_mean_from_image) - - # Check that the averaged intensities from the table are ascending: - assert np.all(np.diff(w9_mean_from_table) > 0) - - resource_files = [] - for name, doc in hdr.documents(): - if name == "resource": - resource_files.append(os.path.basename(doc["resource_path"])) - - # Check that all resource files are unique: - assert len(set(resource_files)) == num_steps - - -def test_beam_statistics_report_only(RE, db, shadow_tes_simulation): - classes, objects = create_classes(connection=shadow_tes_simulation) - globals().update(**objects) - - bsr = BeamStatisticsReport(name="bsr", connection=shadow_tes_simulation) - - toroid.r_maj.kind = "hinted" # noqa F821 - - scan_range = (10_000, 50_000, 21) - - (uid,) = RE(bp.scan([bsr], toroid.r_maj, *scan_range)) # noqa F821 - hdr = db[uid] - tbl = hdr.table() - print(tbl) - - calc_durations = np.array(tbl["time"].diff(), dtype=float)[1:] / 1e9 - print(f"Calculated durations (seconds): {calc_durations}") - - # Check that the duration for each step in the simulation is non-zero: - cpt_durations = np.array(tbl["bsr_duration"]) - print(f"Durations from component (seconds): {cpt_durations}") - - assert (cpt_durations > 0.0).all() - assert (calc_durations > cpt_durations[1:]).all() - - fig = plt.figure() - ax = fig.add_subplot() - ax.plot(np.linspace(*scan_range)[1:], calc_durations) - ax.set_ylabel("Duration of simulations [s]") - ax.set_xlabel("Torus Major Radius [m]") - title = ( - f"Shadow TES simulation\n" - f"RE(bp.scan([bsr], toroid.r_maj, " - f"{', '.join([str(x) for x in scan_range])}))" - ) - ax.set_title(title) - fig.savefig("TES-Shadow-timing.png") - # plt.show() - - -def test_beam_statistics_report_and_watchpoint(RE, db, shadow_tes_simulation): - classes, objects = create_classes(connection=shadow_tes_simulation) - globals().update(**objects) - - bsr = BeamStatisticsReport(name="bsr", connection=shadow_tes_simulation) - - toroid.r_maj.kind = "hinted" # noqa F821 - - (uid,) = RE(bp.scan([bsr, w9], toroid.r_maj, 10000, 50000, 5)) # noqa F821 - hdr = db[uid] - tbl = hdr.table() - print(tbl) - - w9_data_1 = json.loads(tbl["w9_sirepo_data_json"][1]) - w9_data_5 = json.loads(tbl["w9_sirepo_data_json"][5]) - - bsr_data_1 = json.loads(tbl["bsr_sirepo_data_json"][1]) - bsr_data_5 = json.loads(tbl["bsr_sirepo_data_json"][5]) - - w9_diffs = list(dictdiffer.diff(w9_data_1, w9_data_5)) - assert w9_diffs == [("change", ["models", "beamline", 5, "r_maj"], (10000.0, 50000.0))] - - bsr_diffs = list(dictdiffer.diff(bsr_data_1, bsr_data_5)) - assert bsr_diffs == [("change", ["models", "beamline", 5, "r_maj"], (10000.0, 50000.0))] - - w9_bsr_diffs = list(dictdiffer.diff(w9_data_1, bsr_data_5)) - assert w9_bsr_diffs == [ - ("change", ["models", "beamline", 5, "r_maj"], (10000.0, 50000.0)), - ("change", "report", ("watchpointReport12", "beamStatisticsReport")), - ] - - -@pytest.mark.parametrize("method", ["set", "put"]) -def test_mad_x_elements_set_put(madx_resr_storage_ring_simulation, method): - connection = madx_resr_storage_ring_simulation - data = connection.data - classes, objects = create_classes(connection=connection) - globals().update(**objects) - - for i, (k, v) in enumerate(objects.items()): - old_value = v.l.get() # l is length - old_sirepo_value = data["models"]["elements"][i]["l"] - - getattr(v.l, method)(old_value + 10) - - new_value = v.l.get() - new_sirepo_value = data["models"]["elements"][i]["l"] - - print(f"\n Changed: {old_value} -> {new_value}\n Sirepo: {old_sirepo_value} -> {new_sirepo_value}\n") - - assert old_value == old_sirepo_value - assert new_value == new_sirepo_value - assert new_value != old_value - assert abs(new_value - (old_value + 10)) < 1e-8 - - -def test_mad_x_elements_simple_connection(madx_bl2_triplet_tdc_simulation): - connection = madx_bl2_triplet_tdc_simulation - classes, objects = create_classes(connection=connection) - for name, obj in objects.items(): - pprint.pprint(obj.read()) - - globals().update(**objects) - - print(bpm5.summary()) # noqa - pprint.pprint(bpm5.read()) # noqa - - -def test_madx_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation): - connection = madx_bl2_triplet_tdc_simulation - classes, objects = create_classes(connection=connection) - globals().update(**objects) - - madx_flyer = MADXFlyer( - connection=connection, - root_dir="/tmp/sirepo-bluesky-data", - report="elementAnimation250-20", - ) - - (uid,) = RE(bp.fly([madx_flyer])) # noqa F821 - hdr = db[uid] - tbl = hdr.table(stream_name="madx_flyer", fill=True) - print(tbl) - - resource_files = [] - for name, doc in hdr.documents(): - if name == "resource": - resource_files.append(os.path.join(doc["root"], doc["resource_path"])) - - # Check that we have only one resource madx file for all datum documents: - assert len(set(resource_files)) == 1 - - df = tfs.read(resource_files[0]) - for column in df.columns: - if column == "NAME": - assert (tbl[f"madx_flyer_{column}"].astype("string").values == df[column].values).all() - else: - assert np.allclose( - np.array(tbl[f"madx_flyer_{column}"]).astype(float), - np.array(df[column]), - ) - - -def test_madx_variables_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation): - connection = madx_bl2_triplet_tdc_simulation - data = connection.data - classes, objects = create_classes( - connection=connection, - extra_model_fields=["rpnVariables"], - ) - - globals().update(**objects) - - assert len(objects) == len(data["models"]["elements"]) + len(data["models"]["rpnVariables"]) - - madx_flyer = MADXFlyer( - connection=connection, - root_dir="/tmp/sirepo-bluesky-data", - report="elementAnimation250-20", - ) - - def madx_plan(parameter=ihq1, value=2.0): # noqa F821 - yield from bps.mv(parameter.value, value) - return (yield from bp.fly([madx_flyer])) - - (uid,) = RE(madx_plan()) # noqa F821 - hdr = db[uid] - tbl = hdr.table(stream_name="madx_flyer", fill=True) - print(tbl) - - expected_data_len = 151 - - assert len(tbl["madx_flyer_S"]) == expected_data_len - assert len(tbl["madx_flyer_BETX"]) == expected_data_len - assert len(tbl["madx_flyer_BETY"]) == expected_data_len - - -def test_madx_commands_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation): - connection = madx_bl2_triplet_tdc_simulation - data = connection.data - classes, objects = create_classes( - connection=connection, - extra_model_fields=["commands"], - ) - - globals().update(**objects) - pprint.pprint(classes, sort_dicts=False) - - assert len(objects) == len(data["models"]["elements"]) + len(data["models"]["commands"]) - - madx_flyer = MADXFlyer( - connection=connection, - root_dir="/tmp/sirepo-bluesky-data", - report="elementAnimation250-20", - ) - - def madx_plan(element=match8, value=1.0): # noqa F821 - yield from bps.mv(element.deltap, value) - return (yield from bp.fly([madx_flyer])) - - (uid,) = RE(madx_plan()) # noqa F821 - hdr = db[uid] - tbl = hdr.table(stream_name="madx_flyer", fill=True) - print(tbl) - - expected_data_len = 151 - - assert len(tbl["madx_flyer_S"]) == expected_data_len - assert len(tbl["madx_flyer_BETX"]) == expected_data_len - assert len(tbl["madx_flyer_BETY"]) == expected_data_len - - -def test_madx_variables_and_commands_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation): - connection = madx_bl2_triplet_tdc_simulation - data = connection.data - classes, objects = create_classes( - connection=connection, - extra_model_fields=["rpnVariables", "commands"], - ) - - globals().update(**objects) - - assert len(objects) == len(data["models"]["elements"]) + len(data["models"]["rpnVariables"]) + len( - data["models"]["commands"] - ) - - madx_flyer = MADXFlyer( - connection=connection, - root_dir="/tmp/sirepo-bluesky-data", - report="elementAnimation250-20", - ) - - def madx_plan(element=match8, parameter=ihq1, value=1.0): # noqa F821 - yield from bps.mv(element.deltap, value) - yield from bps.mv(parameter.value, value) - return (yield from bp.fly([madx_flyer])) - - (uid,) = RE(madx_plan()) # noqa F821 - hdr = db[uid] - tbl = hdr.table(stream_name="madx_flyer", fill=True) - print(tbl) - - expected_data_len = 151 - - assert len(tbl["madx_flyer_S"]) == expected_data_len - assert len(tbl["madx_flyer_BETX"]) == expected_data_len - assert len(tbl["madx_flyer_BETY"]) == expected_data_len diff --git a/sirepo_bluesky/tests/test_madx.py b/sirepo_bluesky/tests/test_madx.py new file mode 100644 index 00000000..1739517f --- /dev/null +++ b/sirepo_bluesky/tests/test_madx.py @@ -0,0 +1,194 @@ +import os +import pprint + +import bluesky.plan_stubs as bps +import bluesky.plans as bp +import numpy as np +import pytest +import tfs + +from sirepo_bluesky.common.create_classes import create_classes +from sirepo_bluesky.madx.madx_flyer import MADXFlyer + + +@pytest.mark.madx +@pytest.mark.parametrize("method", ["set", "put"]) +def test_mad_x_elements_set_put(madx_resr_storage_ring_simulation, method): + connection = madx_resr_storage_ring_simulation + data = connection.data + classes, objects = create_classes(connection=connection) + globals().update(**objects) + + for i, (k, v) in enumerate(objects.items()): + old_value = v.l.get() # l is length + old_sirepo_value = data["models"]["elements"][i]["l"] + + getattr(v.l, method)(old_value + 10) + + new_value = v.l.get() + new_sirepo_value = data["models"]["elements"][i]["l"] + + print(f"\n Changed: {old_value} -> {new_value}\n Sirepo: {old_sirepo_value} -> {new_sirepo_value}\n") + + assert old_value == old_sirepo_value + assert new_value == new_sirepo_value + assert new_value != old_value + assert abs(new_value - (old_value + 10)) < 1e-8 + + +@pytest.mark.madx +def test_mad_x_elements_simple_connection(madx_bl2_triplet_tdc_simulation): + connection = madx_bl2_triplet_tdc_simulation + classes, objects = create_classes(connection=connection) + for name, obj in objects.items(): + pprint.pprint(obj.read()) + + globals().update(**objects) + + print(bpm5.summary()) # noqa + pprint.pprint(bpm5.read()) # noqa + + +@pytest.mark.madx +def test_madx_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation): + connection = madx_bl2_triplet_tdc_simulation + classes, objects = create_classes(connection=connection) + globals().update(**objects) + + madx_flyer = MADXFlyer( + connection=connection, + root_dir="/tmp/sirepo-bluesky-data", + report="elementAnimation250-20", + ) + + (uid,) = RE(bp.fly([madx_flyer])) # noqa F821 + hdr = db[uid] + tbl = hdr.table(stream_name="madx_flyer", fill=True) + print(tbl) + + resource_files = [] + for name, doc in hdr.documents(): + if name == "resource": + resource_files.append(os.path.join(doc["root"], doc["resource_path"])) + + # Check that we have only one resource madx file for all datum documents: + assert len(set(resource_files)) == 1 + + df = tfs.read(resource_files[0]) + for column in df.columns: + if column == "NAME": + assert (tbl[f"madx_flyer_{column}"].astype("string").values == df[column].values).all() + else: + assert np.allclose( + np.array(tbl[f"madx_flyer_{column}"]).astype(float), + np.array(df[column]), + ) + + +@pytest.mark.madx +def test_madx_variables_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation): + connection = madx_bl2_triplet_tdc_simulation + data = connection.data + classes, objects = create_classes( + connection=connection, + extra_model_fields=["rpnVariables"], + ) + + globals().update(**objects) + + assert len(objects) == len(data["models"]["elements"]) + len(data["models"]["rpnVariables"]) + + madx_flyer = MADXFlyer( + connection=connection, + root_dir="/tmp/sirepo-bluesky-data", + report="elementAnimation250-20", + ) + + def madx_plan(parameter=ihq1, value=2.0): # noqa F821 + yield from bps.mv(parameter.value, value) + return (yield from bp.fly([madx_flyer])) + + (uid,) = RE(madx_plan()) # noqa F821 + hdr = db[uid] + tbl = hdr.table(stream_name="madx_flyer", fill=True) + print(tbl) + + expected_data_len = 151 + + assert len(tbl["madx_flyer_S"]) == expected_data_len + assert len(tbl["madx_flyer_BETX"]) == expected_data_len + assert len(tbl["madx_flyer_BETY"]) == expected_data_len + + +@pytest.mark.madx +def test_madx_commands_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation): + connection = madx_bl2_triplet_tdc_simulation + data = connection.data + classes, objects = create_classes( + connection=connection, + extra_model_fields=["commands"], + ) + + globals().update(**objects) + pprint.pprint(classes, sort_dicts=False) + + assert len(objects) == len(data["models"]["elements"]) + len(data["models"]["commands"]) + + madx_flyer = MADXFlyer( + connection=connection, + root_dir="/tmp/sirepo-bluesky-data", + report="elementAnimation250-20", + ) + + def madx_plan(element=match8, value=1.0): # noqa F821 + yield from bps.mv(element.deltap, value) + return (yield from bp.fly([madx_flyer])) + + (uid,) = RE(madx_plan()) # noqa F821 + hdr = db[uid] + tbl = hdr.table(stream_name="madx_flyer", fill=True) + print(tbl) + + expected_data_len = 151 + + assert len(tbl["madx_flyer_S"]) == expected_data_len + assert len(tbl["madx_flyer_BETX"]) == expected_data_len + assert len(tbl["madx_flyer_BETY"]) == expected_data_len + + +@pytest.mark.madx +def test_madx_variables_and_commands_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation): + connection = madx_bl2_triplet_tdc_simulation + data = connection.data + classes, objects = create_classes( + connection=connection, + extra_model_fields=["rpnVariables", "commands"], + ) + + globals().update(**objects) + + assert len(objects) == len(data["models"]["elements"]) + len(data["models"]["rpnVariables"]) + len( + data["models"]["commands"] + ) + + madx_flyer = MADXFlyer( + connection=connection, + root_dir="/tmp/sirepo-bluesky-data", + report="elementAnimation250-20", + ) + + def madx_plan(element=match8, parameter=ihq1, value=1.0): # noqa F821 + yield from bps.mv(element.deltap, value) + yield from bps.mv(parameter.value, value) + return (yield from bp.fly([madx_flyer])) + + (uid,) = RE(madx_plan()) # noqa F821 + hdr = db[uid] + tbl = hdr.table(stream_name="madx_flyer", fill=True) + print(tbl) + + expected_data_len = 151 + + assert len(tbl["madx_flyer_S"]) == expected_data_len + assert len(tbl["madx_flyer_BETX"]) == expected_data_len + assert len(tbl["madx_flyer_BETY"]) == expected_data_len diff --git a/sirepo_bluesky/tests/test_shadow.py b/sirepo_bluesky/tests/test_shadow.py new file mode 100644 index 00000000..a4af1a95 --- /dev/null +++ b/sirepo_bluesky/tests/test_shadow.py @@ -0,0 +1,131 @@ +import json +import os + +import bluesky.plans as bp +import dictdiffer +import matplotlib.pyplot as plt +import numpy as np +import pytest + +from sirepo_bluesky.common.create_classes import create_classes +from sirepo_bluesky.shadow.shadow_ophyd import BeamStatisticsReport + + +@pytest.mark.shadow +def test_shadow_with_run_engine(RE, db, shadow_tes_simulation, num_steps=5): + classes, objects = create_classes(connection=shadow_tes_simulation) + globals().update(**objects) + + aperture.horizontalSize.kind = "hinted" # noqa F821 + + (uid,) = RE(bp.scan([w9], aperture.horizontalSize, 0, 2, num_steps)) # noqa F821 + hdr = db[uid] + tbl = hdr.table(fill=True) + print(tbl) + + # Check that the duration for each step in the simulation is positive: + sim_durations = np.array(tbl["w9_duration"]) + assert (sim_durations > 0.0).all() + + w9_image = np.array(list(hdr.data("w9_image"))) + # Check the shape of the image data is right: + assert w9_image.shape == (num_steps, 100, 100) + + w9_mean_from_image = w9_image.mean(axis=(1, 2)) + w9_mean_from_table = np.array(tbl["w9_mean"]) + + # Check the number of elements correspond to a number of scan points: + assert len(w9_mean_from_table) == num_steps + + # Check that an average values of the first and last images are right: + assert np.allclose(w9_image[0].mean(), 0.0) + assert np.allclose(w9_image[-1].mean(), 0.255665516042795, atol=1e-3) + + # Check that the values from the table and averages from the image data are + # the same: + assert np.allclose(w9_mean_from_table, w9_mean_from_image) + + # Check that the averaged intensities from the table are ascending: + assert np.all(np.diff(w9_mean_from_table) > 0) + + resource_files = [] + for name, doc in hdr.documents(): + if name == "resource": + resource_files.append(os.path.basename(doc["resource_path"])) + + # Check that all resource files are unique: + assert len(set(resource_files)) == num_steps + + +@pytest.mark.shadow +def test_beam_statistics_report_only(RE, db, shadow_tes_simulation): + classes, objects = create_classes(connection=shadow_tes_simulation) + globals().update(**objects) + + bsr = BeamStatisticsReport(name="bsr", connection=shadow_tes_simulation) + + toroid.r_maj.kind = "hinted" # noqa F821 + + scan_range = (10_000, 50_000, 21) + + (uid,) = RE(bp.scan([bsr], toroid.r_maj, *scan_range)) # noqa F821 + hdr = db[uid] + tbl = hdr.table() + print(tbl) + + calc_durations = np.array(tbl["time"].diff(), dtype=float)[1:] / 1e9 + print(f"Calculated durations (seconds): {calc_durations}") + + # Check that the duration for each step in the simulation is non-zero: + cpt_durations = np.array(tbl["bsr_duration"]) + print(f"Durations from component (seconds): {cpt_durations}") + + assert (cpt_durations > 0.0).all() + assert (calc_durations > cpt_durations[1:]).all() + + fig = plt.figure() + ax = fig.add_subplot() + ax.plot(np.linspace(*scan_range)[1:], calc_durations) + ax.set_ylabel("Duration of simulations [s]") + ax.set_xlabel("Torus Major Radius [m]") + title = ( + f"Shadow TES simulation\n" + f"RE(bp.scan([bsr], toroid.r_maj, " + f"{', '.join([str(x) for x in scan_range])}))" + ) + ax.set_title(title) + fig.savefig("TES-Shadow-timing.png") + # plt.show() + + +@pytest.mark.shadow +def test_beam_statistics_report_and_watchpoint(RE, db, shadow_tes_simulation): + classes, objects = create_classes(connection=shadow_tes_simulation) + globals().update(**objects) + + bsr = BeamStatisticsReport(name="bsr", connection=shadow_tes_simulation) + + toroid.r_maj.kind = "hinted" # noqa F821 + + (uid,) = RE(bp.scan([bsr, w9], toroid.r_maj, 10000, 50000, 5)) # noqa F821 + hdr = db[uid] + tbl = hdr.table() + print(tbl) + + w9_data_1 = json.loads(tbl["w9_sirepo_data_json"][1]) + w9_data_5 = json.loads(tbl["w9_sirepo_data_json"][5]) + + bsr_data_1 = json.loads(tbl["bsr_sirepo_data_json"][1]) + bsr_data_5 = json.loads(tbl["bsr_sirepo_data_json"][5]) + + w9_diffs = list(dictdiffer.diff(w9_data_1, w9_data_5)) + assert w9_diffs == [("change", ["models", "beamline", 5, "r_maj"], (10000.0, 50000.0))] + + bsr_diffs = list(dictdiffer.diff(bsr_data_1, bsr_data_5)) + assert bsr_diffs == [("change", ["models", "beamline", 5, "r_maj"], (10000.0, 50000.0))] + + w9_bsr_diffs = list(dictdiffer.diff(w9_data_1, bsr_data_5)) + assert w9_bsr_diffs == [ + ("change", ["models", "beamline", 5, "r_maj"], (10000.0, 50000.0)), + ("change", "report", ("watchpointReport12", "beamStatisticsReport")), + ] diff --git a/sirepo_bluesky/tests/test_srw.py b/sirepo_bluesky/tests/test_srw.py new file mode 100644 index 00000000..489f52a7 --- /dev/null +++ b/sirepo_bluesky/tests/test_srw.py @@ -0,0 +1,314 @@ +import copy +import json +import os +import pprint + +import bluesky.plans as bp +import dictdiffer +import matplotlib.pyplot as plt +import numpy as np +import peakutils +import pytest + +from sirepo_bluesky.common.create_classes import create_classes + + +@pytest.mark.srw +def test_beamline_elements_as_ophyd_objects(srw_tes_simulation): + classes, objects = create_classes(connection=srw_tes_simulation) + + for name, obj in objects.items(): + pprint.pprint(obj.read()) + + globals().update(**objects) + + print(mono_crystal1.summary()) # noqa + pprint.pprint(mono_crystal1.read()) # noqa + + +@pytest.mark.srw +def test_empty_simulation(srw_empty_simulation): + classes, objects = create_classes(connection=srw_empty_simulation) + globals().update(**objects) + + assert not srw_empty_simulation.data["models"]["beamline"] + objects.pop("post_propagation") + assert not objects + + +@pytest.mark.srw +@pytest.mark.parametrize("method", ["set", "put"]) +def test_beamline_elements_set_put(srw_tes_simulation, method): + classes, objects = create_classes(connection=srw_tes_simulation) + globals().update(**objects) + + i = 0 + for k, v in objects.items(): + if "element_position" in v.component_names: + old_value = v.element_position.get() + old_sirepo_value = srw_tes_simulation.data["models"]["beamline"][i]["position"] + + getattr(v.element_position, method)(old_value + 100) + + new_value = v.element_position.get() + new_sirepo_value = srw_tes_simulation.data["models"]["beamline"][i]["position"] + + print( + f"\n Changed: {old_value} -> {new_value}\n Sirepo: {old_sirepo_value} -> {new_sirepo_value}\n" + ) + + assert old_value == old_sirepo_value + assert new_value == new_sirepo_value + assert new_value != old_value + assert abs(new_value - (old_value + 100)) < 1e-8 + i += 1 + + +@pytest.mark.srw +@pytest.mark.parametrize("method", ["set", "put"]) +def test_crl_calculation(srw_chx_simulation, method): + classes, objects = create_classes(connection=srw_chx_simulation) + globals().update(**objects) + + params_before = copy.deepcopy(crl1.tipRadius._sirepo_dict) # noqa F821 + params_before.pop("tipRadius") + + getattr(crl1.tipRadius, method)(2000) # noqa F821 + + params_after = copy.deepcopy(crl1.tipRadius._sirepo_dict) # noqa F821 + params_after.pop("tipRadius") + + params_diff = list(dictdiffer.diff(params_before, params_after)) + assert len(params_diff) > 0 # should not be empty + + expected_values = { + "absoluteFocusPosition": -6.195573642892285, + "focalDistance": 237.666984823537, + } + + actual_values = { + "absoluteFocusPosition": crl1.absoluteFocusPosition.get(), # noqa F821 + "focalDistance": crl1.focalDistance.get(), # noqa F821 + } + + assert not list(dictdiffer.diff(expected_values, actual_values)) + + +@pytest.mark.srw +@pytest.mark.parametrize("method", ["set", "put"]) +def test_crystal_calculation(srw_tes_simulation, method): + classes, objects = create_classes(connection=srw_tes_simulation) + globals().update(**objects) + + params_before = copy.deepcopy(mono_crystal1.energy._sirepo_dict) # noqa F821 + params_before.pop("energy") + + getattr(mono_crystal1.energy, method)(2000) # noqa F821 + + params_after = copy.deepcopy(mono_crystal1.energy._sirepo_dict) # noqa F821 + params_after.pop("energy") + + params_diff = list(dictdiffer.diff(params_before, params_after)) + assert len(params_diff) > 0 # should not be empty + + expected_values = { + "dSpacing": 3.1355713563754857, + "grazingAngle": 1419.9107955732711, + "nvx": 0, + "nvy": 0.15031366142760424, + "nvz": -0.9886383581412506, + "outframevx": 1.0, + "outframevy": 0.0, + "outoptvx": 0.0, + "outoptvy": 0.29721170287997256, + "outoptvz": -0.9548116063764552, + "psi0i": 6.530421915581681e-05, + "psi0r": -0.00020558072555357544, + "psiHBi": 4.559368494529194e-05, + "psiHBr": -0.00010207663788071082, + "psiHi": 4.559368494529194e-05, + "psiHr": -0.00010207663788071082, + "tvx": 0, + "tvy": 0.9886383581412506, + } + + actual_values = { + "dSpacing": mono_crystal1.dSpacing.get(), # noqa F821 + "grazingAngle": mono_crystal1.grazingAngle.get(), # noqa F821 + "nvx": mono_crystal1.nvx.get(), # noqa F821 + "nvy": mono_crystal1.nvy.get(), # noqa F821 + "nvz": mono_crystal1.nvz.get(), # noqa F821 + "outframevx": mono_crystal1.outframevx.get(), # noqa F821 + "outframevy": mono_crystal1.outframevy.get(), # noqa F821 + "outoptvx": mono_crystal1.outoptvx.get(), # noqa F821 + "outoptvy": mono_crystal1.outoptvy.get(), # noqa F821 + "outoptvz": mono_crystal1.outoptvz.get(), # noqa F821 + "psi0i": mono_crystal1.psi0i.get(), # noqa F821 + "psi0r": mono_crystal1.psi0r.get(), # noqa F821 + "psiHBi": mono_crystal1.psiHBi.get(), # noqa F821 + "psiHBr": mono_crystal1.psiHBr.get(), # noqa F821 + "psiHi": mono_crystal1.psiHi.get(), # noqa F821 + "psiHr": mono_crystal1.psiHr.get(), # noqa F821 + "tvx": mono_crystal1.tvx.get(), # noqa F821 + "tvy": mono_crystal1.tvy.get(), # noqa F821 + } + + assert not list(dictdiffer.diff(expected_values, actual_values)) + + +@pytest.mark.srw +@pytest.mark.parametrize("method", ["set", "put"]) +def test_grazing_angle_calculation(srw_tes_simulation, method): + classes, objects = create_classes(connection=srw_tes_simulation) + globals().update(**objects) + + params_before = copy.deepcopy(toroid.grazingAngle._sirepo_dict) # noqa F821 + params_before.pop("grazingAngle") + + getattr(toroid.grazingAngle, method)(10) # noqa F821 + + params_after = copy.deepcopy(toroid.grazingAngle._sirepo_dict) # noqa F821 + params_after.pop("grazingAngle") + + params_diff = list(dictdiffer.diff(params_before, params_after)) + assert len(params_diff) > 0 # should not be empty + + expected_vector_values = { + "nvx": 0, + "nvy": 0.9999500004166653, + "nvz": -0.009999833334166664, + "tvx": 0, + "tvy": 0.009999833334166664, + } + + actual_vector_values = { + "nvx": toroid.normalVectorX.get(), # noqa F821 + "nvy": toroid.normalVectorY.get(), # noqa F821 + "nvz": toroid.normalVectorZ.get(), # noqa F821 + "tvx": toroid.tangentialVectorX.get(), # noqa F821 + "tvy": toroid.tangentialVectorY.get(), # noqa F821 + } + + assert not list(dictdiffer.diff(expected_vector_values, actual_vector_values)) + + +@pytest.mark.srw +def test_beamline_elements_simple_connection(srw_basic_simulation): + classes, objects = create_classes(connection=srw_basic_simulation) + + for name, obj in objects.items(): + pprint.pprint(obj.read()) + + globals().update(**objects) + + print(watchpoint.summary()) # noqa F821 + pprint.pprint(watchpoint.read()) # noqa F821 + + +@pytest.mark.srw +def test_srw_source_with_run_engine(RE, db, srw_ari_simulation, num_steps=5): + classes, objects = create_classes( + connection=srw_ari_simulation, + extra_model_fields=["undulator", "intensityReport"], + ) + globals().update(**objects) + + undulator.verticalAmplitude.kind = "hinted" # noqa F821 + + single_electron_spectrum.initialEnergy.get() # noqa F821 + single_electron_spectrum.initialEnergy.put(20) # noqa F821 + single_electron_spectrum.finalEnergy.put(1100) # noqa F821 + + assert srw_ari_simulation.data["models"]["intensityReport"]["initialEnergy"] == 20 + assert srw_ari_simulation.data["models"]["intensityReport"]["finalEnergy"] == 1100 + + (uid,) = RE( + bp.scan( + [single_electron_spectrum], # noqa F821 + undulator.verticalAmplitude, # noqa F821 + 0.2, + 1, + num_steps, + ) + ) # noqa F821 + + hdr = db[uid] + tbl = hdr.table() + print(tbl) + + ses_data = np.array(list(hdr.data("single_electron_spectrum_image"))) + ampl_data = np.array(list(hdr.data("undulator_verticalAmplitude"))) + # Check the shape of the image data is right: + assert ses_data.shape == (num_steps, 2000) + + resource_files = [] + for name, doc in hdr.documents(): + if name == "resource": + resource_files.append(os.path.basename(doc["resource_path"])) + + # Check that all resource files are unique: + assert len(set(resource_files)) == num_steps + + fig = plt.figure() + ax = fig.add_subplot() + for i in range(num_steps): + ax.plot(ses_data[i, :], label=f"vert. magn. fld. {ampl_data[i]:.3f}T") + peak = peakutils.indexes(ses_data[i, :]) + ax.scatter(peak, ses_data[i, peak]) + ax.grid() + ax.legend() + ax.set_title("Single-Electron Spectrum vs. Vertical Magnetic Field") + fig.savefig("ses-vs-ampl.png") + # plt.show() + + +@pytest.mark.srw +def test_srw_propagation_with_run_engine(RE, db, srw_chx_simulation, num_steps=5): + classes, objects = create_classes(connection=srw_chx_simulation) + globals().update(**objects) + + post_propagation.hrange_mod.kind = "hinted" # noqa F821 + + (uid,) = RE(bp.scan([sample], post_propagation.hrange_mod, 0.1, 0.3, num_steps)) # noqa F821 + hdr = db[uid] + tbl = hdr.table(fill=True) + print(tbl) + + # Check that the duration for each step in the simulation is positive: + sim_durations = np.array(tbl["sample_duration"]) + assert (sim_durations > 0.0).all(), "Simulation steps did not properly run." + + sample_image = [] + for i in range(num_steps): + sample_image.append(np.array(list(hdr.data("sample_image"))[i])) + + # Check the shape of the image data is right and that hrange_mod was properly changed: + for i, hrange_mod in enumerate(np.linspace(0.1, 0.3, num_steps)): + assert json.loads(tbl["sample_sirepo_data_json"][i + 1])["models"]["postPropagation"][5] == hrange_mod + + +@pytest.mark.srw +def test_srw_tes_propagation_with_run_engine(RE, db, srw_tes_simulation, num_steps=5): + classes, objects = create_classes(connection=srw_tes_simulation) + globals().update(**objects) + + post_propagation.hrange_mod.kind = "hinted" # noqa F821 + + # TODO: update to look like docs: + # https://nsls-ii.github.io/sirepo-bluesky/notebooks/srw.html#SRW-Propagation-as-Ophyd-Objects + (uid,) = RE(bp.scan([w9], post_propagation.hrange_mod, 0.1, 0.3, num_steps)) # noqa F821 + hdr = db[uid] + tbl = hdr.table(fill=True) + print(tbl) + + # Check that the duration for each step in the simulation is positive: + sim_durations = np.array(tbl["w9_duration"]) + assert (sim_durations > 0.0).all(), "Simulation steps did not properly run." + + sample_image = [] + for i in range(num_steps): + sample_image.append(np.array(list(hdr.data("w9_image"))[i])) + + # Check the shape of the image data is right and that hrange_mod was properly changed: + for i, hrange_mod in enumerate(np.linspace(0.1, 0.3, num_steps)): + assert json.loads(tbl["w9_sirepo_data_json"][i + 1])["models"]["postPropagation"][5] == hrange_mod diff --git a/sirepo_bluesky/tests/test_sirepo_flyer.py b/sirepo_bluesky/tests/test_srw_flyer.py similarity index 94% rename from sirepo_bluesky/tests/test_sirepo_flyer.py rename to sirepo_bluesky/tests/test_srw_flyer.py index 90075693..1ea0bfe0 100644 --- a/sirepo_bluesky/tests/test_sirepo_flyer.py +++ b/sirepo_bluesky/tests/test_srw_flyer.py @@ -6,14 +6,14 @@ import vcr import sirepo_bluesky.tests -from sirepo_bluesky.sirepo_bluesky import SirepoBluesky -from sirepo_bluesky.sirepo_flyer import SirepoFlyer +from sirepo_bluesky.common.sirepo_client import SirepoClient +from sirepo_bluesky.srw.srw_flyer import SRWFlyer cassette_location = os.path.join(os.path.dirname(sirepo_bluesky.tests.__file__), "vcr_cassettes") def _test_smoke_sirepo(sim_id, server_name): - sb = SirepoBluesky(server_name) + sb = SirepoClient(server_name) data, schema = sb.auth("srw", sim_id) assert "beamline" in data["models"] @@ -54,7 +54,7 @@ def _test_sirepo_flyer(RE_no_plot, db, tmpdir, sim_id, server_name): } ) - sirepo_flyer = SirepoFlyer( + sirepo_flyer = SRWFlyer( sim_id=sim_id, server_name=server_name, root_dir=root_dir, diff --git a/sirepo_bluesky/tests/test_stateless_compute.py b/sirepo_bluesky/tests/test_stateless_compute.py index 789538ab..d3bd075a 100644 --- a/sirepo_bluesky/tests/test_stateless_compute.py +++ b/sirepo_bluesky/tests/test_stateless_compute.py @@ -9,7 +9,7 @@ import vcr import sirepo_bluesky.tests -from sirepo_bluesky.sirepo_ophyd import create_classes +from sirepo_bluesky.common.create_classes import create_classes from sirepo_bluesky.utils.json_yaml_converter import dict_to_file cassette_location = os.path.join(os.path.dirname(sirepo_bluesky.tests.__file__), "vcr_cassettes") diff --git a/sirepo_bluesky/utils/prepare_re_env.py b/sirepo_bluesky/utils/prepare_re_env.py index f3ae4e15..8fac703d 100644 --- a/sirepo_bluesky/utils/prepare_re_env.py +++ b/sirepo_bluesky/utils/prepare_re_env.py @@ -71,15 +71,20 @@ def register_handlers(db, handlers): ret = re_env(**kwargs_re) globals().update(**ret) - from sirepo_bluesky.srw_handler import SRWFileHandler + from sirepo_bluesky.srw.srw_handler import SRWFileHandler, SRWHDF5FileHandler if args.env_type == "stepper": - from sirepo_bluesky.shadow_handler import ShadowFileHandler - - handlers = {"srw": SRWFileHandler, "SIREPO_FLYER": SRWFileHandler, "shadow": ShadowFileHandler} + from sirepo_bluesky.shadow.shadow_handler import ShadowFileHandler + + handlers = { + "srw": SRWFileHandler, + "SRW_HDF5": SRWHDF5FileHandler, + "SIREPO_FLYER": SRWFileHandler, + "shadow": ShadowFileHandler, + } plt.ion() elif args.env_type == "flyer": - from sirepo_bluesky.madx_handler import MADXFileHandler + from sirepo_bluesky.madx.madx_handler import MADXFileHandler handlers = {"srw": SRWFileHandler, "SIREPO_FLYER": SRWFileHandler, "madx": MADXFileHandler} bec.disable_plots() # noqa: F821