Skip to content

Commit

Permalink
Further cleanup and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed Aug 5, 2024
1 parent dba6d35 commit 273bf66
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 42 deletions.
4 changes: 2 additions & 2 deletions docs/physics/update_and_conv/update_and_conv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,8 @@
"#nu_lower = tardis_config.supernova.luminosity_wavelength_end.to(u.Hz, u.spectral)\n",
"#nu_upper = tardis_config.supernova.luminosity_wavelength_start.to(u.Hz, u.spectral)\n",
"\n",
"from tardis.spectrum.luminosity import calculate_emitted_luminosity\n",
"L_output = calculate_emitted_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n",
"from tardis.spectrum.luminosity import calculate_filtered_luminosity\n",
"L_output = calculate_filtered_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n",
"L_output"
]
},
Expand Down
7 changes: 3 additions & 4 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
from tardis.spectrum.base import SpectrumSolver
from tardis.spectrum.formal_integral import FormalIntegrator
from tardis.spectrum.luminosity import (
calculate_emitted_luminosity,
calculate_reabsorbed_luminosity,
calculate_filtered_luminosity,
)
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.transport.montecarlo.configuration import montecarlo_globals
Expand Down Expand Up @@ -464,13 +463,13 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0):
if np.sum(output_energy < 0) == len(output_energy):
logger.critical("No r-packet escaped through the outer boundary.")

emitted_luminosity = calculate_emitted_luminosity(
emitted_luminosity = calculate_filtered_luminosity(
transport_state.emitted_packet_nu,
transport_state.emitted_packet_luminosity,
self.luminosity_nu_start,
self.luminosity_nu_end,
)
reabsorbed_luminosity = calculate_reabsorbed_luminosity(
reabsorbed_luminosity = calculate_filtered_luminosity(
transport_state.reabsorbed_packet_nu,
transport_state.reabsorbed_packet_luminosity,
self.luminosity_nu_start,
Expand Down
45 changes: 9 additions & 36 deletions tardis/spectrum/luminosity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,28 @@
import numpy as np


def calculate_emitted_luminosity(
emitted_packet_nu,
emitted_packet_luminosity,
def calculate_filtered_luminosity(
packet_nu,
packet_luminosity,
luminosity_nu_start=0 * u.Hz,
luminosity_nu_end=np.inf * u.Hz,
):
"""
Calculate emitted luminosity.
Calculate total luminosity within a filter range.
Parameters
----------
emitted_packet_nu :
emitted_packet_luminosity :
packet_nu : astropy.units.Quantity
packet_luminosity : astropy.units.Quantity
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity
Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (emitted_packet_nu > luminosity_nu_start) & (
emitted_packet_nu < luminosity_nu_end
luminosity_wavelength_filter = (packet_nu > luminosity_nu_start) & (
packet_nu < luminosity_nu_end
)

return emitted_packet_luminosity[luminosity_wavelength_filter].sum()


def calculate_reabsorbed_luminosity(
reabsorbed_packet_nu,
reabsorbed_packet_luminosity,
luminosity_nu_start=0 * u.Hz,
luminosity_nu_end=np.inf * u.Hz,
):
"""
Calculate reabsorbed luminosity.
Parameters
----------
reabsorbed_packet_nu :
reabsorbed_packet_luminosity :
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity
Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (
reabsorbed_packet_nu > luminosity_nu_start
) & (reabsorbed_packet_nu < luminosity_nu_end)

return reabsorbed_packet_luminosity[luminosity_wavelength_filter].sum()
return packet_luminosity[luminosity_wavelength_filter].sum()
60 changes: 60 additions & 0 deletions tardis/spectrum/tests/test_luminosity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import astropy.units as u
import numpy as np
import pytest

from tardis.spectrum.luminosity import (
calculate_filtered_luminosity,
)


@pytest.mark.parametrize(
"packet_nu, packet_luminosity, luminosity_nu_start, luminosity_nu_end, expected",
[
# All frequencies within the range
(
np.array([1, 2, 3]) * u.Hz,
np.array([10, 20, 30]) * u.erg / u.s,
0 * u.Hz,
4 * u.Hz,
60 * u.erg / u.s,
),
# All frequencies outside the range
(
np.array([1, 2, 3]) * u.Hz,
np.array([10, 20, 30]) * u.erg / u.s,
4 * u.Hz,
5 * u.Hz,
0 * u.erg / u.s,
),
# Mix of frequencies within and outside the range
(
np.array([1, 2, 3, 4]) * u.Hz,
np.array([10, 20, 30, 40]) * u.erg / u.s,
2 * u.Hz,
4 * u.Hz,
30 * u.erg / u.s,
),
# Edge case: Frequencies exactly on the boundary
(
np.array([1, 2, 3, 4]) * u.Hz,
np.array([10, 20, 30, 40]) * u.erg / u.s,
2 * u.Hz,
3 * u.Hz,
0 * u.erg / u.s,
),
],
)
def test_calculate_filtered_luminosity(
packet_nu,
packet_luminosity,
luminosity_nu_start,
luminosity_nu_end,
expected,
):
result = calculate_filtered_luminosity(
packet_nu,
packet_luminosity,
luminosity_nu_start,
luminosity_nu_end,
)
assert result == expected

0 comments on commit 273bf66

Please sign in to comment.