Skip to content

Commit

Permalink
fix component modeler to and from file after web container refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerflex committed May 13, 2024
1 parent 6316bad commit 0d6d991
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks
28 changes: 28 additions & 0 deletions tests/test_plugins/test_component_modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
)
from tidy3d.exceptions import SetupError, Tidy3dKeyError
from ..utils import run_emulated
from ..test_web.test_webapi import (
mock_upload,
mock_metadata,
mock_get_info,
mock_start,
mock_monitor,
mock_download,
mock_load,
mock_job_status,
mock_load,
set_api_key,
)

# Waveguide height
wg_height = 0.22
Expand Down Expand Up @@ -386,3 +398,19 @@ def test_batch_filename(tmp_path):

def test_import_smatrix_smatrix():
from tidy3d.plugins.smatrix.smatrix import Port, ComponentModeler # noqa: F401


def test_to_from_file_batch(monkeypatch, tmp_path):
modeler = make_component_modeler(path_dir=str(tmp_path))
s_matrix = run_component_modeler(monkeypatch, modeler)

batch = td.web.Batch(simulations=dict())

modeler._cached_properties["batch"] = batch

fname = str(tmp_path) + "/modeler.json"

modeler.to_file(fname)
modeler2 = modeler.from_file(fname)

assert modeler2.batch_cached == modeler2.batch == batch
12 changes: 7 additions & 5 deletions tests/test_web/test_webapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,12 +541,14 @@ def test_batch(mock_webapi, mock_job_status, mock_load, tmp_path):
fname = str(tmp_path / "batch.json")

b.to_file(fname)
b = b.from_file(fname)
b2 = b.from_file(fname)

b.estimate_cost()
b.run(path_dir=str(tmp_path))
_ = b.get_info()
assert b.real_cost() == FLEX_UNIT * len(sims)
assert all(j.task_id == j2.task_id for j, j2 in zip(b.jobs.values(), b2.jobs.values()))

b2.estimate_cost()
b2.run(path_dir=str(tmp_path))
_ = b2.get_info()
assert b2.real_cost() == FLEX_UNIT * len(sims)


""" Async """
Expand Down
29 changes: 28 additions & 1 deletion tidy3d/plugins/smatrix/component_modelers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Tuple, Dict, Union
import os
from abc import ABC, abstractmethod
import json

import pydantic.v1 as pd
import numpy as np
Expand Down Expand Up @@ -91,6 +92,15 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel):
"If not supplied, uses default for the current front end version.",
)

batch_cached: Batch = pd.Field(
None,
title="Batch (Cached)",
description="Optional field to specify ``batch``. Only used as a workaround internally "
"so that ``batch`` is written when ``.to_file()`` and then the proper batch is loaded "
"from ``.from_file()``. We recommend leaving unset as setting this field along with "
"fields that were not used to create the task will cause errors.",
)

@pd.validator("simulation", always=True)
def _sim_has_no_sources(cls, val):
"""Make sure simulation has no sources as they interfere with tool."""
Expand All @@ -109,12 +119,30 @@ def _task_name(port: Port, mode_index: int = None) -> str:
def sim_dict(self) -> Dict[str, Simulation]:
"""Generate all the :class:`Simulation` objects for the S matrix calculation."""

def json(self, **kwargs):
"""Save component to dictionary. Add the ``batch`` if it has been cached."""

self_json = super().json(**kwargs)

batch = self._cached_properties.get("batch")

if not batch:
return self_json

self_dict = json.loads(self_json)
self_dict["batch_cached"] = json.loads(batch.json())
return json.dumps(self_dict)

@cached_property
def batch(self) -> Batch:
"""Batch associated with this component modeler."""

if self.batch_cached is not None:
return self.batch_cached

# first try loading the batch from file, if it exists
batch_path = self._batch_path

if os.path.exists(batch_path):
return Batch.from_file(fname=batch_path)

Expand Down Expand Up @@ -175,7 +203,6 @@ def _construct_smatrix(self, batch_data: BatchData) -> DataArray:
def run(self, path_dir: str = DEFAULT_DATA_DIR) -> DataArray:
"""Solves for the scattering matrix of the system."""
path_dir = self.get_path_dir(path_dir)

batch_data = self._run_sims(path_dir=path_dir)
return self._construct_smatrix(batch_data=batch_data)

Expand Down
26 changes: 26 additions & 0 deletions tidy3d/web/api/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,15 @@ class Batch(WebContainer):
"number of threads available on the system.",
)

jobs_cached: Dict[TaskName, Job] = pd.Field(
None,
title="Jobs (Cached)",
description="Optional field to specify ``jobs``. Only used as a workaround internally "
"so that ``jobs`` is written when ``Batch.to_file()`` and then the proper task is loaded "
"from ``Batch.from_file()``. We recommend leaving unset as setting this field along with "
"fields that were not used to create the task will cause errors.",
)

_job_type = Job

@staticmethod
Expand Down Expand Up @@ -547,6 +556,9 @@ def jobs(self) -> Dict[TaskName, Job]:
To start the simulations running, must call :meth:`Batch.start` after uploaded.
"""

if self.jobs_cached is not None:
return self.jobs_cached

# the type of job to upload (to generalize to subclasses)
JobType = self._job_type
self_dict = self.dict()
Expand All @@ -569,6 +581,20 @@ def jobs(self) -> Dict[TaskName, Job]:
jobs[task_name] = job
return jobs

def json(self, **kwargs):
"""Save ``Batch`` to dictionary. Add the ``jobs`` if they have been cached."""

self_json = super().json(**kwargs)

jobs = self._cached_properties.get("jobs")

if not jobs:
return self_json

self_dict = json.loads(self_json)
self_dict["jobs_cached"] = {k: json.loads(j.json()) for k, j in jobs.items()}
return json.dumps(self_dict)

@property
def num_jobs(self) -> int:
"""Number of jobs in the batch."""
Expand Down

0 comments on commit 0d6d991

Please sign in to comment.