From 30427ff00f444d30d74f7967a8ff8cb26c09ae7a Mon Sep 17 00:00:00 2001 From: Andrew Date: Fri, 9 Aug 2024 13:59:22 -0400 Subject: [PATCH] Minor refactor of the spectrum solver (#2759) * Factor out luminosity calculation, add setup solver method * Fixes docs * Further cleanup and tests * Fix formal integral benchmark --- benchmarks/spectrum_formal_integral.py | 4 +- .../update_and_conv/update_and_conv.ipynb | 3 +- tardis/simulation/base.py | 45 +++++++------ tardis/spectrum/base.py | 64 ++++++------------- tardis/spectrum/luminosity.py | 29 +++++++++ tardis/spectrum/tests/test_luminosity.py | 60 +++++++++++++++++ 6 files changed, 138 insertions(+), 67 deletions(-) create mode 100644 tardis/spectrum/luminosity.py create mode 100644 tardis/spectrum/tests/test_luminosity.py diff --git a/benchmarks/spectrum_formal_integral.py b/benchmarks/spectrum_formal_integral.py index 8c69ea71bce..f557a6dffbf 100644 --- a/benchmarks/spectrum_formal_integral.py +++ b/benchmarks/spectrum_formal_integral.py @@ -35,10 +35,10 @@ def time_intensity_black_body(self): # Benchmark for functions in FormalIntegrator class def time_FormalIntegrator_functions(self): self.FormalIntegrator.calculate_spectrum( - self.sim.spectrum_solver.spectrum_real_packets.frequency + self.sim.spectrum_solver.spectrum_frequency_grid ) self.FormalIntegrator.make_source_function() self.FormalIntegrator.generate_numba_objects() self.FormalIntegrator.formal_integral( - self.sim.spectrum_solver.spectrum_real_packets.frequency, 1000 + self.sim.spectrum_solver.spectrum_frequency_grid, 1000 ) diff --git a/docs/physics/update_and_conv/update_and_conv.ipynb b/docs/physics/update_and_conv/update_and_conv.ipynb index cbc5c0c9618..b6d9094c3cd 100644 --- a/docs/physics/update_and_conv/update_and_conv.ipynb +++ b/docs/physics/update_and_conv/update_and_conv.ipynb @@ -462,7 +462,8 @@ "#nu_lower = tardis_config.supernova.luminosity_wavelength_end.to(u.Hz, u.spectral)\n", "#nu_upper = tardis_config.supernova.luminosity_wavelength_start.to(u.Hz, u.spectral)\n", "\n", - "L_output = sim.spectrum_solver.calculate_emitted_luminosity(0,np.inf)\n", + "from tardis.spectrum.luminosity import calculate_filtered_luminosity\n", + "L_output = calculate_filtered_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n", "L_output" ] }, diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 599fe2b0caf..c4c6b3c00e6 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -20,6 +20,9 @@ from tardis.simulation.convergence import ConvergenceSolver from tardis.spectrum.base import SpectrumSolver from tardis.spectrum.formal_integral import FormalIntegrator +from tardis.spectrum.luminosity import ( + calculate_filtered_luminosity, +) from tardis.transport.montecarlo.base import MonteCarloTransportSolver from tardis.transport.montecarlo.configuration import montecarlo_globals from tardis.transport.montecarlo.estimators.continuum_radfield_properties import ( @@ -454,25 +457,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): show_progress_bars=self.show_progress_bars, ) - # Set up spectrum solver - self.spectrum_solver.transport_state = transport_state - self.spectrum_solver._montecarlo_virtual_luminosity.value[ - : - ] = v_packets_energy_hist - output_energy = ( self.transport.transport_state.packet_collection.output_energies ) if np.sum(output_energy < 0) == len(output_energy): logger.critical("No r-packet escaped through the outer boundary.") - emitted_luminosity = self.spectrum_solver.calculate_emitted_luminosity( - self.luminosity_nu_start, self.luminosity_nu_end + emitted_luminosity = calculate_filtered_luminosity( + transport_state.emitted_packet_nu, + transport_state.emitted_packet_luminosity, + self.luminosity_nu_start, + self.luminosity_nu_end, ) - reabsorbed_luminosity = ( - self.spectrum_solver.calculate_reabsorbed_luminosity( - self.luminosity_nu_start, self.luminosity_nu_end - ) + reabsorbed_luminosity = calculate_filtered_luminosity( + transport_state.reabsorbed_packet_nu, + transport_state.reabsorbed_packet_luminosity, + self.luminosity_nu_start, + self.luminosity_nu_end, ) if hasattr(self, "convergence_plots"): self.convergence_plots.fetch_data( @@ -493,7 +494,7 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): self.log_run_results(emitted_luminosity, reabsorbed_luminosity) self.iterations_executed += 1 - return emitted_luminosity + return emitted_luminosity, v_packets_energy_hist def run_convergence(self): """ @@ -508,7 +509,9 @@ def run_convergence(self): self.plasma.electron_densities, self.simulation_state.t_inner, ) - emitted_luminosity = self.iterate(self.no_of_packets) + emitted_luminosity, v_packets_energy_hist = self.iterate( + self.no_of_packets + ) self.converged = self.advance_state(emitted_luminosity) if hasattr(self, "convergence_plots"): self.convergence_plots.update() @@ -533,11 +536,17 @@ def run_final(self): self.plasma.electron_densities, self.simulation_state.t_inner, ) - self.iterate(self.last_no_of_packets, self.no_of_virtual_packets) + emitted_luminosity, v_packets_energy_hist = self.iterate( + self.last_no_of_packets, self.no_of_virtual_packets + ) - # Set up spectrum solver integrator - self.spectrum_solver._integrator = FormalIntegrator( - self.simulation_state, self.plasma, self.transport + # Set up spectrum solver integrator and virtual spectrum + self.spectrum_solver.setup_optional_spectra( + self.transport.transport_state, + v_packets_energy_hist, + FormalIntegrator( + self.simulation_state, self.plasma, self.transport + ), ) self.reshape_plasma_state_store(self.iterations_executed) diff --git a/tardis/spectrum/base.py b/tardis/spectrum/base.py index 4ef1a913258..2780e893783 100644 --- a/tardis/spectrum/base.py +++ b/tardis/spectrum/base.py @@ -34,6 +34,24 @@ def __init__( self.integrator_settings = integrator_settings self._spectrum_integrated = None + def setup_optional_spectra( + self, transport_state, virtual_packet_luminosity=None, integrator=None + ): + """Set up the solver to handle virtual and integrated spectra + + Parameters + ---------- + virtual_packet_luminosity : np.ndarray, optional + Virtual packet luminosity, unnormalized, by default None + integrator : FormalIntegrator, optional + Integrator to compute the integrated spectrum with, by default None + """ + self.transport_state = transport_state + self._montecarlo_virtual_luminosity = ( + virtual_packet_luminosity * u.erg / u.s + ) + self._integrator = integrator + @property def spectrum_real_packets(self): return TARDISSpectrum( @@ -136,52 +154,6 @@ def montecarlo_virtual_luminosity(self): / self.transport_state.time_of_simulation.value ) - def calculate_emitted_luminosity( - self, luminosity_nu_start, luminosity_nu_end - ): - """ - Calculate emitted luminosity. - - Parameters - ---------- - luminosity_nu_start : astropy.units.Quantity - luminosity_nu_end : astropy.units.Quantity - - Returns - ------- - astropy.units.Quantity - """ - luminosity_wavelength_filter = ( - self.transport_state.emitted_packet_nu > luminosity_nu_start - ) & (self.transport_state.emitted_packet_nu < luminosity_nu_end) - - return self.transport_state.emitted_packet_luminosity[ - luminosity_wavelength_filter - ].sum() - - def calculate_reabsorbed_luminosity( - self, luminosity_nu_start, luminosity_nu_end - ): - """ - Calculate reabsorbed luminosity. - - Parameters - ---------- - luminosity_nu_start : astropy.units.Quantity - luminosity_nu_end : astropy.units.Quantity - - Returns - ------- - astropy.units.Quantity - """ - luminosity_wavelength_filter = ( - self.transport_state.reabsorbed_packet_nu > luminosity_nu_start - ) & (self.transport_state.reabsorbed_packet_nu < luminosity_nu_end) - - return self.transport_state.reabsorbed_packet_luminosity[ - luminosity_wavelength_filter - ].sum() - def solve(self, transport_state): """Solve the spectra diff --git a/tardis/spectrum/luminosity.py b/tardis/spectrum/luminosity.py new file mode 100644 index 00000000000..6e7f214d50c --- /dev/null +++ b/tardis/spectrum/luminosity.py @@ -0,0 +1,29 @@ +import astropy.units as u +import numpy as np + + +def calculate_filtered_luminosity( + packet_nu, + packet_luminosity, + luminosity_nu_start=0 * u.Hz, + luminosity_nu_end=np.inf * u.Hz, +): + """ + Calculate total luminosity within a filter range. + + Parameters + ---------- + packet_nu : astropy.units.Quantity + packet_luminosity : astropy.units.Quantity + luminosity_nu_start : astropy.units.Quantity + luminosity_nu_end : astropy.units.Quantity + + Returns + ------- + astropy.units.Quantity + """ + luminosity_wavelength_filter = (packet_nu > luminosity_nu_start) & ( + packet_nu < luminosity_nu_end + ) + + return packet_luminosity[luminosity_wavelength_filter].sum() diff --git a/tardis/spectrum/tests/test_luminosity.py b/tardis/spectrum/tests/test_luminosity.py new file mode 100644 index 00000000000..9acad9bf6f8 --- /dev/null +++ b/tardis/spectrum/tests/test_luminosity.py @@ -0,0 +1,60 @@ +import astropy.units as u +import numpy as np +import pytest + +from tardis.spectrum.luminosity import ( + calculate_filtered_luminosity, +) + + +@pytest.mark.parametrize( + "packet_nu, packet_luminosity, luminosity_nu_start, luminosity_nu_end, expected", + [ + # All frequencies within the range + ( + np.array([1, 2, 3]) * u.Hz, + np.array([10, 20, 30]) * u.erg / u.s, + 0 * u.Hz, + 4 * u.Hz, + 60 * u.erg / u.s, + ), + # All frequencies outside the range + ( + np.array([1, 2, 3]) * u.Hz, + np.array([10, 20, 30]) * u.erg / u.s, + 4 * u.Hz, + 5 * u.Hz, + 0 * u.erg / u.s, + ), + # Mix of frequencies within and outside the range + ( + np.array([1, 2, 3, 4]) * u.Hz, + np.array([10, 20, 30, 40]) * u.erg / u.s, + 2 * u.Hz, + 4 * u.Hz, + 30 * u.erg / u.s, + ), + # Edge case: Frequencies exactly on the boundary + ( + np.array([1, 2, 3, 4]) * u.Hz, + np.array([10, 20, 30, 40]) * u.erg / u.s, + 2 * u.Hz, + 3 * u.Hz, + 0 * u.erg / u.s, + ), + ], +) +def test_calculate_filtered_luminosity( + packet_nu, + packet_luminosity, + luminosity_nu_start, + luminosity_nu_end, + expected, +): + result = calculate_filtered_luminosity( + packet_nu, + packet_luminosity, + luminosity_nu_start, + luminosity_nu_end, + ) + assert result == expected