diff --git a/docs/physics/update_and_conv/update_and_conv.ipynb b/docs/physics/update_and_conv/update_and_conv.ipynb index da9db07d389..b6d9094c3cd 100644 --- a/docs/physics/update_and_conv/update_and_conv.ipynb +++ b/docs/physics/update_and_conv/update_and_conv.ipynb @@ -462,8 +462,8 @@ "#nu_lower = tardis_config.supernova.luminosity_wavelength_end.to(u.Hz, u.spectral)\n", "#nu_upper = tardis_config.supernova.luminosity_wavelength_start.to(u.Hz, u.spectral)\n", "\n", - "from tardis.spectrum.luminosity import calculate_emitted_luminosity\n", - "L_output = calculate_emitted_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n", + "from tardis.spectrum.luminosity import calculate_filtered_luminosity\n", + "L_output = calculate_filtered_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n", "L_output" ] }, diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 9059c28127e..c4c6b3c00e6 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -21,8 +21,7 @@ from tardis.spectrum.base import SpectrumSolver from tardis.spectrum.formal_integral import FormalIntegrator from tardis.spectrum.luminosity import ( - calculate_emitted_luminosity, - calculate_reabsorbed_luminosity, + calculate_filtered_luminosity, ) from tardis.transport.montecarlo.base import MonteCarloTransportSolver from tardis.transport.montecarlo.configuration import montecarlo_globals @@ -464,13 +463,13 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): if np.sum(output_energy < 0) == len(output_energy): logger.critical("No r-packet escaped through the outer boundary.") - emitted_luminosity = calculate_emitted_luminosity( + emitted_luminosity = calculate_filtered_luminosity( transport_state.emitted_packet_nu, transport_state.emitted_packet_luminosity, self.luminosity_nu_start, self.luminosity_nu_end, ) - reabsorbed_luminosity = calculate_reabsorbed_luminosity( + reabsorbed_luminosity = calculate_filtered_luminosity( transport_state.reabsorbed_packet_nu, transport_state.reabsorbed_packet_luminosity, self.luminosity_nu_start, diff --git a/tardis/spectrum/luminosity.py b/tardis/spectrum/luminosity.py index 830b2650238..6e7f214d50c 100644 --- a/tardis/spectrum/luminosity.py +++ b/tardis/spectrum/luminosity.py @@ -2,19 +2,19 @@ import numpy as np -def calculate_emitted_luminosity( - emitted_packet_nu, - emitted_packet_luminosity, +def calculate_filtered_luminosity( + packet_nu, + packet_luminosity, luminosity_nu_start=0 * u.Hz, luminosity_nu_end=np.inf * u.Hz, ): """ - Calculate emitted luminosity. + Calculate total luminosity within a filter range. Parameters ---------- - emitted_packet_nu : - emitted_packet_luminosity : + packet_nu : astropy.units.Quantity + packet_luminosity : astropy.units.Quantity luminosity_nu_start : astropy.units.Quantity luminosity_nu_end : astropy.units.Quantity @@ -22,35 +22,8 @@ def calculate_emitted_luminosity( ------- astropy.units.Quantity """ - luminosity_wavelength_filter = (emitted_packet_nu > luminosity_nu_start) & ( - emitted_packet_nu < luminosity_nu_end + luminosity_wavelength_filter = (packet_nu > luminosity_nu_start) & ( + packet_nu < luminosity_nu_end ) - return emitted_packet_luminosity[luminosity_wavelength_filter].sum() - - -def calculate_reabsorbed_luminosity( - reabsorbed_packet_nu, - reabsorbed_packet_luminosity, - luminosity_nu_start=0 * u.Hz, - luminosity_nu_end=np.inf * u.Hz, -): - """ - Calculate reabsorbed luminosity. - - Parameters - ---------- - reabsorbed_packet_nu : - reabsorbed_packet_luminosity : - luminosity_nu_start : astropy.units.Quantity - luminosity_nu_end : astropy.units.Quantity - - Returns - ------- - astropy.units.Quantity - """ - luminosity_wavelength_filter = ( - reabsorbed_packet_nu > luminosity_nu_start - ) & (reabsorbed_packet_nu < luminosity_nu_end) - - return reabsorbed_packet_luminosity[luminosity_wavelength_filter].sum() + return packet_luminosity[luminosity_wavelength_filter].sum() diff --git a/tardis/spectrum/tests/test_luminosity.py b/tardis/spectrum/tests/test_luminosity.py new file mode 100644 index 00000000000..9acad9bf6f8 --- /dev/null +++ b/tardis/spectrum/tests/test_luminosity.py @@ -0,0 +1,60 @@ +import astropy.units as u +import numpy as np +import pytest + +from tardis.spectrum.luminosity import ( + calculate_filtered_luminosity, +) + + +@pytest.mark.parametrize( + "packet_nu, packet_luminosity, luminosity_nu_start, luminosity_nu_end, expected", + [ + # All frequencies within the range + ( + np.array([1, 2, 3]) * u.Hz, + np.array([10, 20, 30]) * u.erg / u.s, + 0 * u.Hz, + 4 * u.Hz, + 60 * u.erg / u.s, + ), + # All frequencies outside the range + ( + np.array([1, 2, 3]) * u.Hz, + np.array([10, 20, 30]) * u.erg / u.s, + 4 * u.Hz, + 5 * u.Hz, + 0 * u.erg / u.s, + ), + # Mix of frequencies within and outside the range + ( + np.array([1, 2, 3, 4]) * u.Hz, + np.array([10, 20, 30, 40]) * u.erg / u.s, + 2 * u.Hz, + 4 * u.Hz, + 30 * u.erg / u.s, + ), + # Edge case: Frequencies exactly on the boundary + ( + np.array([1, 2, 3, 4]) * u.Hz, + np.array([10, 20, 30, 40]) * u.erg / u.s, + 2 * u.Hz, + 3 * u.Hz, + 0 * u.erg / u.s, + ), + ], +) +def test_calculate_filtered_luminosity( + packet_nu, + packet_luminosity, + luminosity_nu_start, + luminosity_nu_end, + expected, +): + result = calculate_filtered_luminosity( + packet_nu, + packet_luminosity, + luminosity_nu_start, + luminosity_nu_end, + ) + assert result == expected