Skip to content

Commit

Permalink
Factor out luminosity calculation, add setup solver method
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed Aug 5, 2024
1 parent e1aa887 commit 4401047
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 64 deletions.
46 changes: 28 additions & 18 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from tardis.simulation.convergence import ConvergenceSolver
from tardis.spectrum.base import SpectrumSolver
from tardis.spectrum.formal_integral import FormalIntegrator
from tardis.spectrum.luminosity import (
calculate_emitted_luminosity,
calculate_reabsorbed_luminosity,
)
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.estimators.continuum_radfield_properties import (
Expand Down Expand Up @@ -454,25 +458,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0):
show_progress_bars=self.show_progress_bars,
)

# Set up spectrum solver
self.spectrum_solver.transport_state = transport_state
self.spectrum_solver._montecarlo_virtual_luminosity.value[
:
] = v_packets_energy_hist

output_energy = (
self.transport.transport_state.packet_collection.output_energies
)
if np.sum(output_energy < 0) == len(output_energy):
logger.critical("No r-packet escaped through the outer boundary.")

emitted_luminosity = self.spectrum_solver.calculate_emitted_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
emitted_luminosity = calculate_emitted_luminosity(
transport_state.emitted_packet_nu,
transport_state.emitted_packet_luminosity,
self.luminosity_nu_start,
self.luminosity_nu_end,
)
reabsorbed_luminosity = (
self.spectrum_solver.calculate_reabsorbed_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
)
reabsorbed_luminosity = calculate_reabsorbed_luminosity(
transport_state.reabsorbed_packet_nu,
transport_state.reabsorbed_packet_luminosity,
self.luminosity_nu_start,
self.luminosity_nu_end,
)
if hasattr(self, "convergence_plots"):
self.convergence_plots.fetch_data(
Expand All @@ -493,7 +495,7 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0):

self.log_run_results(emitted_luminosity, reabsorbed_luminosity)
self.iterations_executed += 1
return emitted_luminosity
return emitted_luminosity, v_packets_energy_hist

def run_convergence(self):
"""
Expand All @@ -508,7 +510,9 @@ def run_convergence(self):
self.plasma.electron_densities,
self.simulation_state.t_inner,
)
emitted_luminosity = self.iterate(self.no_of_packets)
emitted_luminosity, v_packets_energy_hist = self.iterate(
self.no_of_packets
)
self.converged = self.advance_state(emitted_luminosity)
if hasattr(self, "convergence_plots"):
self.convergence_plots.update()
Expand All @@ -533,11 +537,17 @@ def run_final(self):
self.plasma.electron_densities,
self.simulation_state.t_inner,
)
self.iterate(self.last_no_of_packets, self.no_of_virtual_packets)
emitted_luminosity, v_packets_energy_hist = self.iterate(
self.last_no_of_packets, self.no_of_virtual_packets
)

# Set up spectrum solver integrator
self.spectrum_solver._integrator = FormalIntegrator(
self.simulation_state, self.plasma, self.transport
# Set up spectrum solver integrator and virtual spectrum
self.spectrum_solver.setup_optional_spectra(
self.transport.transport_state,
v_packets_energy_hist,
FormalIntegrator(
self.simulation_state, self.plasma, self.transport
),
)

self.reshape_plasma_state_store(self.iterations_executed)
Expand Down
64 changes: 18 additions & 46 deletions tardis/spectrum/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ def __init__(
self.integrator_settings = integrator_settings
self._spectrum_integrated = None

def setup_optional_spectra(
self, transport_state, virtual_packet_luminosity=None, integrator=None
):
"""Set up the solver to handle virtual and integrated spectra
Parameters
----------
virtual_packet_luminosity : np.ndarray, optional
Virtual packet luminosity, unnormalized, by default None
integrator : FormalIntegrator, optional
Integrator to compute the integrated spectrum with, by default None
"""
self.transport_state = transport_state
self._montecarlo_virtual_luminosity = (
virtual_packet_luminosity * u.erg / u.s
)
self._integrator = integrator

@property
def spectrum_real_packets(self):
return TARDISSpectrum(
Expand Down Expand Up @@ -136,52 +154,6 @@ def montecarlo_virtual_luminosity(self):
/ self.transport_state.time_of_simulation.value
)

def calculate_emitted_luminosity(
self, luminosity_nu_start, luminosity_nu_end
):
"""
Calculate emitted luminosity.
Parameters
----------
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity
Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (
self.transport_state.emitted_packet_nu > luminosity_nu_start
) & (self.transport_state.emitted_packet_nu < luminosity_nu_end)

return self.transport_state.emitted_packet_luminosity[
luminosity_wavelength_filter
].sum()

def calculate_reabsorbed_luminosity(
self, luminosity_nu_start, luminosity_nu_end
):
"""
Calculate reabsorbed luminosity.
Parameters
----------
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity
Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (
self.transport_state.reabsorbed_packet_nu > luminosity_nu_start
) & (self.transport_state.reabsorbed_packet_nu < luminosity_nu_end)

return self.transport_state.reabsorbed_packet_luminosity[
luminosity_wavelength_filter
].sum()

def solve(self, transport_state):
"""Solve the spectra
Expand Down
56 changes: 56 additions & 0 deletions tardis/spectrum/luminosity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import astropy.units as u
import numpy as np


def calculate_emitted_luminosity(
emitted_packet_nu,
emitted_packet_luminosity,
luminosity_nu_start=0 * u.Hz,
luminosity_nu_end=np.inf * u.Hz,
):
"""
Calculate emitted luminosity.
Parameters
----------
emitted_packet_nu :
emitted_packet_luminosity :
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity
Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (emitted_packet_nu > luminosity_nu_start) & (
emitted_packet_nu < luminosity_nu_end
)

return emitted_packet_luminosity[luminosity_wavelength_filter].sum()


def calculate_reabsorbed_luminosity(
reabsorbed_packet_nu,
reabsorbed_packet_luminosity,
luminosity_nu_start=0 * u.Hz,
luminosity_nu_end=np.inf * u.Hz,
):
"""
Calculate reabsorbed luminosity.
Parameters
----------
reabsorbed_packet_nu :
reabsorbed_packet_luminosity :
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity
Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (
reabsorbed_packet_nu > luminosity_nu_start
) & (reabsorbed_packet_nu < luminosity_nu_end)

return reabsorbed_packet_luminosity[luminosity_wavelength_filter].sum()

0 comments on commit 4401047

Please sign in to comment.