From 1c7671da18ada689b2232a63c1ec9e37dedd774d Mon Sep 17 00:00:00 2001 From: Andrew Fullard Date: Tue, 23 Jul 2024 16:44:28 -0400 Subject: [PATCH 1/2] Add solve method, fixes integrator settings issue --- tardis/simulation/base.py | 4 --- tardis/spectrum/base.py | 27 ++++++++++++++++--- tardis/spectrum/tests/test_spectrum_solver.py | 25 +++++++++++++++++ 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index f928e90018c..19fc3fd5729 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -140,7 +140,6 @@ def __init__( convergence_plots_kwargs, show_progress_bars, spectrum_solver, - integrator_settings, ): super(Simulation, self).__init__( iterations, simulation_state.no_of_shells @@ -159,7 +158,6 @@ def __init__( self.luminosity_nu_end = luminosity_nu_end self.luminosity_requested = luminosity_requested self.spectrum_solver = spectrum_solver - self.integrator_settings = integrator_settings self.show_progress_bars = show_progress_bars self.version = tardis.__version__ @@ -477,7 +475,6 @@ def run_final(self): self.iterate(self.last_no_of_packets, self.no_of_virtual_packets) # Set up spectrum solver integrator - self.spectrum_solver.integrator_settings = self.integrator_settings self.spectrum_solver._integrator = FormalIntegrator( self.simulation_state, self.plasma, self.transport ) @@ -771,6 +768,5 @@ def from_config( convergence_strategy=config.montecarlo.convergence_strategy, convergence_plots_kwargs=convergence_plots_kwargs, show_progress_bars=show_progress_bars, - integrator_settings=config.spectrum.integrated, spectrum_solver=spectrum_solver, ) diff --git a/tardis/spectrum/base.py b/tardis/spectrum/base.py index 8c6b4096674..ba817c5540b 100644 --- a/tardis/spectrum/base.py +++ b/tardis/spectrum/base.py @@ -22,14 +22,16 @@ class SpectrumSolver(HDFWriterMixin): hdf_name = "spectrum" - def __init__(self, transport_state, spectrum_frequency_grid): + def __init__( + self, transport_state, spectrum_frequency_grid, integrator_settings=None + ): self.transport_state = transport_state self.spectrum_frequency_grid = spectrum_frequency_grid self._montecarlo_virtual_luminosity = u.Quantity( np.zeros_like(self.spectrum_frequency_grid.value), "erg / s" ) # should be init with v_packets_energy_hist self._integrator = None - self.integrator_settings = None + self.integrator_settings = integrator_settings self._spectrum_integrated = None @property @@ -60,7 +62,7 @@ def spectrum_virtual_packets(self): @property def spectrum_integrated(self): - if self._spectrum_integrated is None: + if self._spectrum_integrated is None and self.integrator is not None: # This was changed from unpacking to specific attributes as compute # is not used in calculate_spectrum try: @@ -83,13 +85,15 @@ def spectrum_integrated(self): np.array([np.nan, np.nan]) * u.Hz, np.array([np.nan]) * u.erg / u.s, ) + else: + self._spectrum_integrated = None return self._spectrum_integrated @property def integrator(self): if self._integrator is None: warnings.warn( - "MontecarloTransport.integrator: " + "SpectrumSolver.integrator: " "The FormalIntegrator is not yet available." "Please run the montecarlo simulation at least once.", UserWarning, @@ -178,6 +182,20 @@ def calculate_reabsorbed_luminosity( luminosity_wavelength_filter ].sum() + def solve(self): + """Solve the spectra + + Returns + ------- + tuple(TARDISSpectrum) + Real, virtual and integrated spectra, if available + """ + return ( + self.spectrum_real_packets, + self.spectrum_virtual_packets, + self.spectrum_integrated, + ) + @classmethod def from_config(cls, config): spectrum_frequency_grid = quantity_linspace( @@ -189,4 +207,5 @@ def from_config(cls, config): return cls( transport_state=None, spectrum_frequency_grid=spectrum_frequency_grid, + integrator_settings=config.spectrum.integrated, ) diff --git a/tardis/spectrum/tests/test_spectrum_solver.py b/tardis/spectrum/tests/test_spectrum_solver.py index 001c6ac4230..6907cda2c49 100644 --- a/tardis/spectrum/tests/test_spectrum_solver.py +++ b/tardis/spectrum/tests/test_spectrum_solver.py @@ -87,3 +87,28 @@ def test_spectrum_real_packets_reabsorbed(self, simulation): result, luminosity, ) + + def test_solve(self, simulation): + transport_state = simulation.transport.transport_state + spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid + + solver = SpectrumSolver(transport_state, spectrum_frequency_grid) + result_real, result_virtual, result_integrated = solver.solve() + key_real = "simulation/spectrum_solver/spectrum_real_packets/luminosity" + expected_real = self.get_expected_data(key_real) + + luminosity_real = u.Quantity(expected_real, "erg /s") + + assert_quantity_allclose( + result_real.luminosity, + luminosity_real, + ) + + assert_quantity_allclose( + result_virtual.luminosity, + u.Quantity( + np.zeros_like(spectrum_frequency_grid.value)[:-1], "erg / s" + ), + ) + + assert result_integrated is None From 04ca7dccc87f1f890b38a91b39b4e6a5c9fd69ce Mon Sep 17 00:00:00 2001 From: Andrew Fullard Date: Wed, 24 Jul 2024 12:19:35 -0400 Subject: [PATCH 2/2] Add transport_state parameter to solve method --- tardis/spectrum/base.py | 9 ++++++++- tardis/spectrum/tests/test_spectrum_solver.py | 4 +++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tardis/spectrum/base.py b/tardis/spectrum/base.py index ba817c5540b..4ef1a913258 100644 --- a/tardis/spectrum/base.py +++ b/tardis/spectrum/base.py @@ -182,14 +182,21 @@ def calculate_reabsorbed_luminosity( luminosity_wavelength_filter ].sum() - def solve(self): + def solve(self, transport_state): """Solve the spectra + Parameters + ---------- + transport_state: MonteCarloTransportState + The transport state to be used to compute the spectra + Returns ------- tuple(TARDISSpectrum) Real, virtual and integrated spectra, if available """ + self.transport_state = transport_state + return ( self.spectrum_real_packets, self.spectrum_virtual_packets, diff --git a/tardis/spectrum/tests/test_spectrum_solver.py b/tardis/spectrum/tests/test_spectrum_solver.py index 6907cda2c49..bb82e5d822c 100644 --- a/tardis/spectrum/tests/test_spectrum_solver.py +++ b/tardis/spectrum/tests/test_spectrum_solver.py @@ -93,7 +93,9 @@ def test_solve(self, simulation): spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid solver = SpectrumSolver(transport_state, spectrum_frequency_grid) - result_real, result_virtual, result_integrated = solver.solve() + result_real, result_virtual, result_integrated = solver.solve( + transport_state + ) key_real = "simulation/spectrum_solver/spectrum_real_packets/luminosity" expected_real = self.get_expected_data(key_real)