From 0d6d991fabdb470889bd3966c7fa3ff49b46bf42 Mon Sep 17 00:00:00 2001 From: Tyler Hughes Date: Sun, 12 May 2024 09:22:57 -0400 Subject: [PATCH] fix component modeler to and from file after web container refactor --- docs/notebooks | 2 +- tests/test_plugins/test_component_modeler.py | 28 ++++++++++++++++++ tests/test_web/test_webapi.py | 12 ++++---- .../smatrix/component_modelers/base.py | 29 ++++++++++++++++++- tidy3d/web/api/container.py | 26 +++++++++++++++++ 5 files changed, 90 insertions(+), 7 deletions(-) diff --git a/docs/notebooks b/docs/notebooks index a4c0aab94..1c95bcca6 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit a4c0aab94dd906e77041124a230490daa21deb4b +Subproject commit 1c95bcca6f7c6db8d93065710be7fcfa206521c8 diff --git a/tests/test_plugins/test_component_modeler.py b/tests/test_plugins/test_component_modeler.py index 8a8faed28..35f811dfa 100644 --- a/tests/test_plugins/test_component_modeler.py +++ b/tests/test_plugins/test_component_modeler.py @@ -12,6 +12,18 @@ ) from tidy3d.exceptions import SetupError, Tidy3dKeyError from ..utils import run_emulated +from ..test_web.test_webapi import ( + mock_upload, + mock_metadata, + mock_get_info, + mock_start, + mock_monitor, + mock_download, + mock_load, + mock_job_status, + mock_load, + set_api_key, +) # Waveguide height wg_height = 0.22 @@ -386,3 +398,19 @@ def test_batch_filename(tmp_path): def test_import_smatrix_smatrix(): from tidy3d.plugins.smatrix.smatrix import Port, ComponentModeler # noqa: F401 + + +def test_to_from_file_batch(monkeypatch, tmp_path): + modeler = make_component_modeler(path_dir=str(tmp_path)) + s_matrix = run_component_modeler(monkeypatch, modeler) + + batch = td.web.Batch(simulations=dict()) + + modeler._cached_properties["batch"] = batch + + fname = str(tmp_path) + "/modeler.json" + + modeler.to_file(fname) + modeler2 = modeler.from_file(fname) + + assert modeler2.batch_cached == modeler2.batch == batch diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py index 870bbd9c6..064e61990 100644 --- a/tests/test_web/test_webapi.py +++ b/tests/test_web/test_webapi.py @@ -541,12 +541,14 @@ def test_batch(mock_webapi, mock_job_status, mock_load, tmp_path): fname = str(tmp_path / "batch.json") b.to_file(fname) - b = b.from_file(fname) + b2 = b.from_file(fname) - b.estimate_cost() - b.run(path_dir=str(tmp_path)) - _ = b.get_info() - assert b.real_cost() == FLEX_UNIT * len(sims) + assert all(j.task_id == j2.task_id for j, j2 in zip(b.jobs.values(), b2.jobs.values())) + + b2.estimate_cost() + b2.run(path_dir=str(tmp_path)) + _ = b2.get_info() + assert b2.real_cost() == FLEX_UNIT * len(sims) """ Async """ diff --git a/tidy3d/plugins/smatrix/component_modelers/base.py b/tidy3d/plugins/smatrix/component_modelers/base.py index 555ef39c0..dc6466c1e 100644 --- a/tidy3d/plugins/smatrix/component_modelers/base.py +++ b/tidy3d/plugins/smatrix/component_modelers/base.py @@ -3,6 +3,7 @@ from typing import Tuple, Dict, Union import os from abc import ABC, abstractmethod +import json import pydantic.v1 as pd import numpy as np @@ -91,6 +92,15 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel): "If not supplied, uses default for the current front end version.", ) + batch_cached: Batch = pd.Field( + None, + title="Batch (Cached)", + description="Optional field to specify ``batch``. Only used as a workaround internally " + "so that ``batch`` is written when ``.to_file()`` and then the proper batch is loaded " + "from ``.from_file()``. We recommend leaving unset as setting this field along with " + "fields that were not used to create the task will cause errors.", + ) + @pd.validator("simulation", always=True) def _sim_has_no_sources(cls, val): """Make sure simulation has no sources as they interfere with tool.""" @@ -109,12 +119,30 @@ def _task_name(port: Port, mode_index: int = None) -> str: def sim_dict(self) -> Dict[str, Simulation]: """Generate all the :class:`Simulation` objects for the S matrix calculation.""" + def json(self, **kwargs): + """Save component to dictionary. Add the ``batch`` if it has been cached.""" + + self_json = super().json(**kwargs) + + batch = self._cached_properties.get("batch") + + if not batch: + return self_json + + self_dict = json.loads(self_json) + self_dict["batch_cached"] = json.loads(batch.json()) + return json.dumps(self_dict) + @cached_property def batch(self) -> Batch: """Batch associated with this component modeler.""" + if self.batch_cached is not None: + return self.batch_cached + # first try loading the batch from file, if it exists batch_path = self._batch_path + if os.path.exists(batch_path): return Batch.from_file(fname=batch_path) @@ -175,7 +203,6 @@ def _construct_smatrix(self, batch_data: BatchData) -> DataArray: def run(self, path_dir: str = DEFAULT_DATA_DIR) -> DataArray: """Solves for the scattering matrix of the system.""" path_dir = self.get_path_dir(path_dir) - batch_data = self._run_sims(path_dir=path_dir) return self._construct_smatrix(batch_data=batch_data) diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index d5b31804a..d5bb45eef 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -494,6 +494,15 @@ class Batch(WebContainer): "number of threads available on the system.", ) + jobs_cached: Dict[TaskName, Job] = pd.Field( + None, + title="Jobs (Cached)", + description="Optional field to specify ``jobs``. Only used as a workaround internally " + "so that ``jobs`` is written when ``Batch.to_file()`` and then the proper task is loaded " + "from ``Batch.from_file()``. We recommend leaving unset as setting this field along with " + "fields that were not used to create the task will cause errors.", + ) + _job_type = Job @staticmethod @@ -547,6 +556,9 @@ def jobs(self) -> Dict[TaskName, Job]: To start the simulations running, must call :meth:`Batch.start` after uploaded. """ + if self.jobs_cached is not None: + return self.jobs_cached + # the type of job to upload (to generalize to subclasses) JobType = self._job_type self_dict = self.dict() @@ -569,6 +581,20 @@ def jobs(self) -> Dict[TaskName, Job]: jobs[task_name] = job return jobs + def json(self, **kwargs): + """Save ``Batch`` to dictionary. Add the ``jobs`` if they have been cached.""" + + self_json = super().json(**kwargs) + + jobs = self._cached_properties.get("jobs") + + if not jobs: + return self_json + + self_dict = json.loads(self_json) + self_dict["jobs_cached"] = {k: json.loads(j.json()) for k, j in jobs.items()} + return json.dumps(self_dict) + @property def num_jobs(self) -> int: """Number of jobs in the batch."""