Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spectrum refactor #2689

Merged
merged 20 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/transport_montecarlo_numba_formal_integral_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from asv_runner.benchmarks.mark import parameterize
from numba import config

import tardis.transport.montecarlo.formal_integral as formal_integral
import tardis.spectrum.formal_integral as formal_integral
from benchmarks.benchmark_base import BenchmarkBase
from tardis import constants as c
from tardis.model.geometry.radial1d import NumbaRadial1DGeometry
Expand Down
8 changes: 4 additions & 4 deletions docs/io/optional/how_to_custom_source.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@
"outputs": [],
"source": [
"%matplotlib inline\n",
"plt.plot(mdl.transport.transport_state.spectrum_virtual.wavelength,\n",
" mdl.transport.transport_state.spectrum_virtual.luminosity_density_lambda,\n",
"plt.plot(mdl.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
" mdl.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
" color='red', label='truncated blackbody (custom packet source)')\n",
"plt.plot(mdl_norm.transport.transport_state.spectrum_virtual.wavelength,\n",
" mdl_norm.transport.transport_state.spectrum_virtual.luminosity_density_lambda,\n",
"plt.plot(mdl_norm.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
" mdl_norm.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
" color='blue', label='normal blackbody (default packet source)')\n",
"plt.xlabel('$\\lambda [\\AA]$')\n",
"plt.ylabel('$L_\\lambda$ [erg/s/$\\AA$]')\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/physics/spectrum/basic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"source": [
"from tardis.io.configuration.config_reader import Configuration\n",
"from tardis.simulation import Simulation\n",
"from tardis.spectrum import TARDISSpectrum\n",
"from tardis.spectrum.spectrum import TARDISSpectrum\n",
"from tardis.io.atom_data.util import download_atom_data\n",
"from astropy import units as u\n",
"import numpy as np\n",
Expand Down Expand Up @@ -460,7 +460,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions docs/physics/update_and_conv/update_and_conv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@
"#nu_lower = tardis_config.supernova.luminosity_wavelength_end.to(u.Hz, u.spectral)\n",
"#nu_upper = tardis_config.supernova.luminosity_wavelength_start.to(u.Hz, u.spectral)\n",
"\n",
"L_output = transport.transport_state.calculate_emitted_luminosity(0,np.inf)\n",
"L_output = sim.spectrum_solver.calculate_emitted_luminosity(0,np.inf)\n",
"L_output"
]
},
Expand Down Expand Up @@ -517,7 +517,7 @@
},
"outputs": [],
"source": [
"sim.advance_state()"
"sim.advance_state(emitted_luminosity=L_output)"
]
},
{
Expand Down Expand Up @@ -596,7 +596,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.3"
},
"vscode": {
"interpreter": {
Expand Down
15 changes: 11 additions & 4 deletions docs/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@
"metadata": {},
"outputs": [],
"source": [
"spectrum = sim.transport.transport_state.spectrum\n",
"spectrum_virtual = sim.transport.transport_state.spectrum_virtual\n",
"spectrum_integrated = sim.transport.transport_state.spectrum_integrated"
"spectrum = sim.spectrum_solver.spectrum_real_packets\n",
"spectrum_virtual = sim.spectrum_solver.spectrum_virtual_packets\n",
"spectrum_integrated = sim.spectrum_solver.spectrum_integrated"
]
},
{
Expand All @@ -170,6 +170,13 @@
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -191,7 +198,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
3 changes: 2 additions & 1 deletion tardis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def config_montecarlo_1e5_verysimple(example_configuration_dir):
def simulation_verysimple(config_verysimple, atomic_dataset):
atomic_data = deepcopy(atomic_dataset)
sim = Simulation.from_config(config_verysimple, atom_data=atomic_data)
sim.iterate(4000)
sim.last_no_of_packets = 4000
sim.run_final()
return sim


Expand Down
53 changes: 37 additions & 16 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from tardis.io.util import HDFWriterMixin
from tardis.model import SimulationState
from tardis.plasma.standard_plasmas import assemble_plasma
from tardis.spectrum.formal_integral import FormalIntegrator
from tardis.simulation.convergence import ConvergenceSolver
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots
from tardis.spectrum.base import SpectrumSolver

# Adding logging support
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -117,6 +119,7 @@
"iterations_t_rad",
"iterations_electron_densities",
"iterations_t_inner",
"spectrum_solver",
]
hdf_name = "simulation"

Expand All @@ -136,6 +139,8 @@
show_convergence_plots,
convergence_plots_kwargs,
show_progress_bars,
spectrum_solver,
integrator_settings,
):
super(Simulation, self).__init__(
iterations, simulation_state.no_of_shells
Expand All @@ -153,6 +158,8 @@
self.luminosity_nu_start = luminosity_nu_start
self.luminosity_nu_end = luminosity_nu_end
self.luminosity_requested = luminosity_requested
self.spectrum_solver = spectrum_solver
self.integrator_settings = integrator_settings

Check warning on line 162 in tardis/simulation/base.py

View check run for this annotation

Codecov / codecov/patch

tardis/simulation/base.py#L161-L162

Added lines #L161 - L162 were not covered by tests
self.show_progress_bars = show_progress_bars
self.version = tardis.__version__

Expand Down Expand Up @@ -205,14 +212,12 @@
)

def estimate_t_inner(
self, input_t_inner, luminosity_requested, t_inner_update_exponent=-0.5
self,
input_t_inner,
luminosity_requested,
emitted_luminosity,
t_inner_update_exponent=-0.5,
):
emitted_luminosity = (
self.transport.transport_state.calculate_emitted_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
)
)

luminosity_ratios = (
(emitted_luminosity / luminosity_requested).to(1).value
)
Expand Down Expand Up @@ -255,7 +260,7 @@
self.consecutive_converges_count = 0
return False

def advance_state(self):
def advance_state(self, emitted_luminosity):
"""
Advances the state of the model and the plasma for the next
iteration of the simulation. Returns True if the convergence criteria
Expand All @@ -272,6 +277,7 @@
estimated_t_inner = self.estimate_t_inner(
self.simulation_state.t_inner,
self.luminosity_requested,
emitted_luminosity,
t_inner_update_exponent=self.convergence_strategy.t_inner_update_exponent,
)

Expand Down Expand Up @@ -381,27 +387,31 @@
iteration=self.iterations_executed,
)

self.transport.run(
v_packets_energy_hist = self.transport.run(

Check warning on line 390 in tardis/simulation/base.py

View check run for this annotation

Codecov / codecov/patch

tardis/simulation/base.py#L390

Added line #L390 was not covered by tests
transport_state,
time_explosion=self.simulation_state.time_explosion,
iteration=self.iterations_executed,
total_iterations=self.iterations,
show_progress_bars=self.show_progress_bars,
)

# Set up spectrum solver
self.spectrum_solver.transport_state = transport_state
self.spectrum_solver._montecarlo_virtual_luminosity.value[

Check warning on line 400 in tardis/simulation/base.py

View check run for this annotation

Codecov / codecov/patch

tardis/simulation/base.py#L399-L400

Added lines #L399 - L400 were not covered by tests
andrewfullard marked this conversation as resolved.
Show resolved Hide resolved
:
] = v_packets_energy_hist

output_energy = (
self.transport.transport_state.packet_collection.output_energies
)
if np.sum(output_energy < 0) == len(output_energy):
logger.critical("No r-packet escaped through the outer boundary.")

emitted_luminosity = (
self.transport.transport_state.calculate_emitted_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
)
emitted_luminosity = self.spectrum_solver.calculate_emitted_luminosity(

Check warning on line 410 in tardis/simulation/base.py

View check run for this annotation

Codecov / codecov/patch

tardis/simulation/base.py#L410

Added line #L410 was not covered by tests
self.luminosity_nu_start, self.luminosity_nu_end
)
reabsorbed_luminosity = (
self.transport.transport_state.calculate_reabsorbed_luminosity(
self.spectrum_solver.calculate_reabsorbed_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
)
)
Expand All @@ -424,6 +434,7 @@

self.log_run_results(emitted_luminosity, reabsorbed_luminosity)
self.iterations_executed += 1
return emitted_luminosity

Check warning on line 437 in tardis/simulation/base.py

View check run for this annotation

Codecov / codecov/patch

tardis/simulation/base.py#L437

Added line #L437 was not covered by tests

def run_convergence(self):
"""
Expand All @@ -438,8 +449,8 @@
self.plasma.electron_densities,
self.simulation_state.t_inner,
)
self.iterate(self.no_of_packets)
self.converged = self.advance_state()
emitted_luminosity = self.iterate(self.no_of_packets)
self.converged = self.advance_state(emitted_luminosity)

Check warning on line 453 in tardis/simulation/base.py

View check run for this annotation

Codecov / codecov/patch

tardis/simulation/base.py#L452-L453

Added lines #L452 - L453 were not covered by tests
if hasattr(self, "convergence_plots"):
self.convergence_plots.update()
self._call_back()
Expand All @@ -465,6 +476,12 @@
)
self.iterate(self.last_no_of_packets, self.no_of_virtual_packets)

# Set up spectrum solver integrator
self.spectrum_solver.integrator_settings = self.integrator_settings
self.spectrum_solver._integrator = FormalIntegrator(

Check warning on line 481 in tardis/simulation/base.py

View check run for this annotation

Codecov / codecov/patch

tardis/simulation/base.py#L480-L481

Added lines #L480 - L481 were not covered by tests
self.simulation_state, self.plasma, self.transport
)

self.reshape_plasma_state_store(self.iterations_executed)
if hasattr(self, "convergence_plots"):
self.convergence_plots.fetch_data(
Expand Down Expand Up @@ -737,6 +754,8 @@
last_no_of_packets = config.montecarlo.no_of_packets
last_no_of_packets = int(last_no_of_packets)

spectrum_solver = SpectrumSolver.from_config(config)

Check warning on line 757 in tardis/simulation/base.py

View check run for this annotation

Codecov / codecov/patch

tardis/simulation/base.py#L757

Added line #L757 was not covered by tests

return cls(
iterations=config.montecarlo.iterations,
simulation_state=simulation_state,
Expand All @@ -752,4 +771,6 @@
convergence_strategy=config.montecarlo.convergence_strategy,
convergence_plots_kwargs=convergence_plots_kwargs,
show_progress_bars=show_progress_bars,
integrator_settings=config.spectrum.integrated,
spectrum_solver=spectrum_solver,
)
Empty file added tardis/spectrum/__init__.py
Empty file.
Loading
Loading