Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix CustomSourceTime with times completely outside envelope definitio…
Browse files Browse the repository at this point in the history
…n time range
caseyflex committed Aug 13, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent eaf9bc6 commit 90c5441
Showing 4 changed files with 64 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- `DataArray` interpolation failure due to incorrect ordering of coordinates when interpolating with autograd tracers.
- Error in `CustomSourceTime` when evaluating at a list of times entirely outside of the range of the envelope definition times.

## [2.7.2] - 2024-08-07

12 changes: 12 additions & 0 deletions tests/test_components/test_source.py
Original file line number Diff line number Diff line change
@@ -307,8 +307,20 @@ def test_custom_source_time(log_capture):
atol=ATOL,
)

# all times out of range
_ = cst.amp_time([-1])
_ = cst.amp_time(-1)
assert np.allclose(cst.amp_time([2]), np.exp(-1j * 2 * np.pi * 2 * freq0), rtol=0, atol=ATOL)

assert_log_level(log_capture, None)

vals = td.components.data.data_array.TimeDataArray([1, 2], coords=dict(t=[-1, -0.5]))
dataset = td.components.data.dataset.TimeDataset(values=vals)
cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12)
source = td.PointDipole(center=(0, 0, 0), source_time=cst, polarization="Ex")
with AssertLogLevel(log_capture, "WARNING", contains_str="defined at times"):
sim = sim.updated_copy(sources=[source])

# test normalization warning
with AssertLogLevel(log_capture, "WARNING"):
sim = sim.updated_copy(normalize_index=0)
25 changes: 25 additions & 0 deletions tidy3d/components/simulation.py
Original file line number Diff line number Diff line change
@@ -3161,6 +3161,31 @@ def _post_init_validators(self) -> None:
self._validate_no_structures_pml()
self._validate_tfsf_nonuniform_grid()
self._validate_nonlinear_specs()
self._validate_custom_source_time()

def _validate_custom_source_time(self):
"""Warn if all simulation times are outside CustomSourceTime definition range."""
# skip this validation if tmesh can't be computed, for example because of unloaded
# custom media
try:
_ = self.tmesh
except pydantic.ValidationError:
return
for idx, source in enumerate(self.sources):
if isinstance(source.source_time, CustomSourceTime):
if source.source_time._all_outside_range(tmesh=self.tmesh):
data_times = source.source_time.data_times
mint = np.min(data_times)
maxt = np.max(data_times)
mintmesh = np.min(self.tmesh)
maxtmesh = np.max(self.tmesh)
log.warning(
f"'CustomSourceTime' at 'sources[{idx}]' is defined at "
"times which do not include any of the 'Simulation.tmesh'. "
f"'CustomSourceTime' is defined in the time range "
f"'({mint}, {maxt})'; 'Simulation.tmesh' covers the range "
f"'({mintmesh}, {maxtmesh})'"
)

def _validate_no_structures_pml(self) -> None:
"""Ensure no structures terminate / have bounds inside of PML."""
30 changes: 26 additions & 4 deletions tidy3d/components/source.py
Original file line number Diff line number Diff line change
@@ -352,6 +352,27 @@ def from_values(
source_time_dataset=source_time_dataset,
)

@property
def data_times(self) -> ArrayFloat1D:
"""Times of envelope definition."""
if self.source_time_dataset is None:
return []
data_times = self.source_time_dataset.values.coords["t"].values.squeeze()
return data_times

def _all_outside_range(self, tmesh: ArrayFloat1D) -> bool:
"""Whether all tmesh are outside range of definition."""
# make time a numpy array for uniform handling
data_times = self.data_times

# shift time
twidth = 1.0 / (2 * np.pi * self.fwidth)
time_shifted = tmesh - self.offset * twidth

mask = (time_shifted < min(data_times)) | (time_shifted > max(data_times))

return all(mask)

def amp_time(self, time: float) -> complex:
"""Complex-valued source amplitude as a function of time.
@@ -370,8 +391,8 @@ def amp_time(self, time: float) -> complex:
return None

# make time a numpy array for uniform handling
times = np.array([time] if isinstance(time, float) else time)
data_times = self.source_time_dataset.values.coords["t"].values.squeeze()
times = np.array([time] if isinstance(time, (int, float)) else time)
data_times = self.data_times

# shift time
twidth = 1.0 / (2 * np.pi * self.fwidth)
@@ -384,12 +405,13 @@ def amp_time(self, time: float) -> complex:
envelope = np.zeros(len(time_shifted), dtype=complex)
values = self.source_time_dataset.values
envelope[mask] = values.sel(t=time_shifted[mask], method="nearest").to_numpy()
envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy()
if not all(mask):
envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy()

# modulation, phase, amplitude
omega0 = 2 * np.pi * self.freq0
offset = np.exp(1j * self.phase)
oscillation = np.exp(-1j * omega0 * time)
oscillation = np.exp(-1j * omega0 * times)
amp = self.amplitude

return offset * oscillation * amp * envelope

0 comments on commit 90c5441

Please sign in to comment.