diff --git a/docs/io/output/rpacket_tracking.ipynb b/docs/io/output/rpacket_tracking.ipynb index fc4b1dc002e..b2a1127f00f 100644 --- a/docs/io/output/rpacket_tracking.ipynb +++ b/docs/io/output/rpacket_tracking.ipynb @@ -202,7 +202,7 @@ "metadata": {}, "outputs": [], "source": [ - "type(sim.transport.rpacket_tracker)" + "type(sim.transport.transport_state.rpacket_tracker)" ] }, { @@ -243,7 +243,7 @@ "metadata": {}, "outputs": [], "source": [ - "len(sim.transport.rpacket_tracker)" + "len(sim.transport.transport_state.rpacket_tracker)" ] }, { @@ -281,7 +281,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.transport.rpacket_tracker[10].index" + "sim.transport.transport_state.rpacket_tracker[10].index" ] }, { @@ -300,7 +300,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.transport.rpacket_tracker[10].seed" + "sim.transport.transport_state.rpacket_tracker[10].seed" ] }, { @@ -319,7 +319,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.transport.rpacket_tracker[10].status" + "sim.transport.transport_state.rpacket_tracker[10].status" ] }, { @@ -347,7 +347,7 @@ "metadata": {}, "outputs": [], "source": [ - "len(sim.transport.rpacket_tracker[10].shell_id)" + "len(sim.transport.transport_state.rpacket_tracker[10].shell_id)" ] }, { @@ -403,7 +403,7 @@ "metadata": {}, "outputs": [], "source": [ - "type(sim.transport.rpacket_tracker_df)" + "type(sim.transport.transport_state.rpacket_tracker_df)" ] }, { @@ -461,7 +461,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.transport.rpacket_tracker_df" + "sim.transport.transport_state.rpacket_tracker_df" ] }, { @@ -490,7 +490,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.transport.rpacket_tracker_df.loc[10]" + "sim.transport.transport_state.rpacket_tracker_df.loc[10]" ] }, { @@ -509,7 +509,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.transport.rpacket_tracker_df.loc[10][\"energy\"]" + "sim.transport.transport_state.rpacket_tracker_df.loc[10][\"energy\"]" ] }, { @@ -528,7 +528,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.transport.rpacket_tracker_df.loc[10,5][\"energy\"]" + "sim.transport.transport_state.rpacket_tracker_df.loc[10,5][\"energy\"]" ] }, { @@ -538,7 +538,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.transport.rpacket_tracker_df.loc[10][\"energy\"][5]" + "sim.transport.transport_state.rpacket_tracker_df.loc[10][\"energy\"][5]" ] }, { @@ -567,7 +567,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.11.5" }, "vscode": { "interpreter": { diff --git a/docs/io/visualization/montecarlo_packet_visualization.ipynb b/docs/io/visualization/montecarlo_packet_visualization.ipynb index fe46b67b745..1c15aac1657 100644 --- a/docs/io/visualization/montecarlo_packet_visualization.ipynb +++ b/docs/io/visualization/montecarlo_packet_visualization.ipynb @@ -89,7 +89,7 @@ "outputs": [], "source": [ "# accessing the rpacket_tracker dataframe\n", - "sim.transport.rpacket_tracker_df" + "sim.transport.transport_state.rpacket_tracker_df" ] }, { @@ -191,11 +191,11 @@ "source": [ "# function for getting coordinates of all packets\n", "def get_coordinates_multiple_packets_ints(r_packet_tracker,time):\n", - " thetas = np.linspace(0, 2*math.pi,sim.transport.rpacket_tracker_df.index[-1][0]+2)\n", + " thetas = np.linspace(0, 2*math.pi,sim.transport.transport_state.rpacket_tracker_df.index[-1][0]+2)\n", " x = []\n", " y = []\n", " inters = []\n", - " for i in range(sim.transport.rpacket_tracker_df.index[-1][0]+1):\n", + " for i in range(sim.transport.transport_state.rpacket_tracker_df.index[-1][0]+1):\n", " xs,ys,ints = get_x_y_ints_with_theta_init(r_packet_tracker.loc[i][\"r\"],r_packet_tracker.loc[i][\"mu\"],time,r_packet_tracker.loc[i][\"interaction_type\"],'n',thetas[i])\n", " x.append(xs)\n", " y.append(ys)\n", @@ -236,7 +236,7 @@ "# animated plot\n", "\n", "v_shells = sim.simulation_state.radius.value * 1e-5 / sim.simulation_state.time_explosion.value\n", - "xs,ys,ints = get_coordinates_multiple_packets_ints(sim.transport.rpacket_tracker_df,sim.simulation_state.time_explosion.value)\n", + "xs,ys,ints = get_coordinates_multiple_packets_ints(sim.transport.transport_state.rpacket_tracker_df,sim.simulation_state.time_explosion.value)\n", "xs,ys,ints,max_size = get_equal_array_size(xs,ys,ints)\n", "\n", "fig = go.Figure()\n", @@ -273,7 +273,7 @@ "\n", "#Add the packet trajectory\n", "\n", - "df = sim.transport.rpacket_tracker_df\n", + "df = sim.transport.transport_state.rpacket_tracker_df\n", "interaction_from_num = {0: \"No Interaction\", 1: \"EScattering\", 2: \"Line\"}\n", "interaction_color_from_num = {0: \"darkslategrey\", 1: \"#3366FF\", 2: \"#FF3300\"}\n", "interaction_opacity_from_num = {0: 0, 1: 1, 2: 1}\n", diff --git a/tardis/analysis.py b/tardis/analysis.py index 56c4eb1bbd5..5c1ceef1e26 100644 --- a/tardis/analysis.py +++ b/tardis/analysis.py @@ -3,25 +3,26 @@ """ import re -import os -from astropy import units as u -from tardis import constants import numpy as np import pandas as pd +from astropy import units as u + +from tardis import constants INVALID_ION_ERROR_MSG = "Atomic number, ion_number pair not present in model" -class LastLineInteraction(object): +class LastLineInteraction: @classmethod def from_simulation(cls, simulation, packet_filter_mode="packet_out_nu"): + transport_state = simulation.transport.transport_state return cls( - simulation.transport.last_line_interaction_in_id, - simulation.transport.last_line_interaction_out_id, - simulation.transport.last_line_interaction_shell_id, - simulation.transport.transport_state.packet_collection.output_nus, - simulation.transport.last_interaction_in_nu, + transport_state.last_line_interaction_in_id, + transport_state.last_line_interaction_out_id, + transport_state.last_line_interaction_shell_id, + transport_state.packet_collection.output_nus, + transport_state.last_interaction_in_nu, simulation.plasma.atomic_data.lines, packet_filter_mode, ) @@ -260,7 +261,7 @@ def onpress(event): fig.canvas.mpl_connect("on_press", onpress) -class TARDISHistory(object): +class TARDISHistory: """ Records the history of the model """ diff --git a/tardis/conftest.py b/tardis/conftest.py index 5c45a307db9..42530086111 100644 --- a/tardis/conftest.py +++ b/tardis/conftest.py @@ -232,3 +232,14 @@ def simulation_verysimple(config_verysimple, atomic_dataset): sim = Simulation.from_config(config_verysimple, atom_data=atomic_data) sim.iterate(4000) return sim + + +@pytest.fixture(scope="session") +def simulation_verysimple_vpacket_tracking(config_verysimple, atomic_dataset): + atomic_data = deepcopy(atomic_dataset) + sim = Simulation.from_config( + config_verysimple, atom_data=atomic_data, virtual_packet_logging=True + ) + sim.last_no_of_packets = 4000 + sim.run_final() + return sim diff --git a/tardis/grid/tests/test_grid.py b/tardis/grid/tests/test_grid.py index e6f84281acc..4a793c7e9a4 100644 --- a/tardis/grid/tests/test_grid.py +++ b/tardis/grid/tests/test_grid.py @@ -4,8 +4,11 @@ import pandas as pd import tardis + +from pathlib import Path import tardis.grid as grid + DATA_PATH = Path(tardis.__path__[0]) / "grid" / "tests" / "data" diff --git a/tardis/model/parse_input.py b/tardis/model/parse_input.py index 061aecb3633..01914fb8e6a 100644 --- a/tardis/model/parse_input.py +++ b/tardis/model/parse_input.py @@ -598,6 +598,14 @@ def initialize_packet_source(config, geometry, packet_source): ValueError If both t_inner and luminosity_requested are None. """ + if config.montecarlo.enable_full_relativity: + packet_source = BlackBodySimpleSourceRelativistic( + base_seed=config.montecarlo.seed, + time_explosion=config.supernova.time_explosion, + ) + else: + packet_source = BlackBodySimpleSource(base_seed=config.montecarlo.seed) + luminosity_requested = config.supernova.luminosity_requested if config.plasma.initial_t_inner > 0.0 * u.K: packet_source.radius = geometry.r_inner[0] diff --git a/tardis/montecarlo/base.py b/tardis/montecarlo/base.py index e2df4666d67..2faf4e8d000 100644 --- a/tardis/montecarlo/base.py +++ b/tardis/montecarlo/base.py @@ -1,28 +1,37 @@ import logging -import warnings +import numpy as np from astropy import units as u -from tardis import constants as const -from numba import set_num_threads -from numba import cuda +from numba import cuda, set_num_threads - -from tardis.util.base import quantity_linspace +from tardis import constants as const +from tardis.io.logger import montecarlo_tracking as mc_tracker from tardis.io.util import HDFWriterMixin -from tardis.montecarlo import packet_source as source -from tardis.montecarlo.montecarlo_numba.formal_integral import FormalIntegrator +from tardis.montecarlo import montecarlo_configuration +from tardis.montecarlo.montecarlo_configuration import ( + configuration_initialize, +) +from tardis.montecarlo.montecarlo_numba import ( + montecarlo_main_loop, + numba_config, +) from tardis.montecarlo.montecarlo_numba.estimators import initialize_estimators -from tardis.montecarlo import montecarlo_configuration as mc_config_module -from tardis.montecarlo.montecarlo_state import MonteCarloTransportState - -from tardis.montecarlo.montecarlo_numba import montecarlo_radial1d +from tardis.montecarlo.montecarlo_numba.formal_integral import FormalIntegrator from tardis.montecarlo.montecarlo_numba.numba_interface import ( - configuration_initialize, + NumbaModel, + opacity_state_initialize, +) +from tardis.montecarlo.montecarlo_numba.r_packet import ( + rpacket_trackers_to_dataframe, +) +from tardis.montecarlo.montecarlo_transport_state import ( + MonteCarloTransportState, +) +from tardis.util.base import ( + quantity_linspace, + refresh_packet_pbar, + update_iterations_pbar, ) -from tardis.montecarlo.montecarlo_numba import numba_config -from tardis.io.logger import montecarlo_tracking as mc_tracker - -import numpy as np logger = logging.getLogger(__name__) @@ -30,30 +39,11 @@ # TODO: refactor this into more parts class MonteCarloTransportSolver(HDFWriterMixin): """ - This class is designed as an interface between the Python part and the - montecarlo C-part + This class modifies the MonteCarloTransportState to solve the radiative + transfer problem. """ - hdf_properties = [ - "transport_state", - "last_interaction_in_nu", - "last_interaction_type", - "last_line_interaction_in_id", - "last_line_interaction_out_id", - "last_line_interaction_shell_id", - ] - - vpacket_hdf_properties = [ - "virt_packet_nus", - "virt_packet_energies", - "virt_packet_initial_rs", - "virt_packet_initial_mus", - "virt_packet_last_interaction_in_nu", - "virt_packet_last_interaction_type", - "virt_packet_last_line_interaction_in_id", - "virt_packet_last_line_interaction_out_id", - "virt_packet_last_line_interaction_shell_id", - ] + hdf_properties = ["transport_state"] hdf_name = "transport" @@ -70,11 +60,11 @@ def __init__( v_packet_settings, spectrum_method, packet_source, - enable_virtual_packet_logging, + enable_virtual_packet_logging=False, + enable_rpacket_tracking=False, nthreads=1, debug_packets=False, logger_buffer=1, - tracking_rpacket=False, use_gpu=False, ): # inject different packets @@ -84,7 +74,6 @@ def __init__( self.enable_reflective_inner_boundary = enable_reflective_inner_boundary self.inner_boundary_albedo = inner_boundary_albedo self.enable_full_relativity = enable_full_relativity - numba_config.ENABLE_FULL_RELATIVITY = enable_full_relativity self.line_interaction_type = line_interaction_type self.integrator_settings = integrator_settings self.v_packet_settings = v_packet_settings @@ -93,18 +82,8 @@ def __init__( self.use_gpu = use_gpu - self.virt_logging = enable_virtual_packet_logging - - # Length 2 for initialization - will be removed in next PR - self.virt_packet_last_interaction_type = np.ones(2) * -1 - self.virt_packet_last_interaction_in_nu = np.ones(2) * -1.0 - self.virt_packet_last_line_interaction_in_id = np.ones(2) * -1 - self.virt_packet_last_line_interaction_out_id = np.ones(2) * -1 - self.virt_packet_last_line_interaction_shell_id = np.ones(2) * -1 - self.virt_packet_nus = np.ones(2) * -1.0 - self.virt_packet_energies = np.ones(2) * -1.0 - self.virt_packet_initial_rs = np.ones(2) * -1.0 - self.virt_packet_initial_mus = np.ones(2) * -1.0 + self.enable_vpacket_tracking = enable_virtual_packet_logging + self.enable_rpacket_tracking = enable_rpacket_tracking self.packet_source = packet_source @@ -118,39 +97,54 @@ def __init__( mc_tracker.DEBUG_MODE = debug_packets mc_tracker.BUFFER = logger_buffer - mc_config_module.RPACKET_TRACKING = tracking_rpacket - if self.spectrum_method == "integrated": self.optional_hdf_properties.append("spectrum_integrated") - def _initialize_packets(self, no_of_packets, iteration): - # the iteration (passed as seed_offset) is added each time to preserve randomness - # across different simulations with the same temperature, - # for example. + def initialize_transport_state( + self, + simulation_state, + plasma, + no_of_packets, + no_of_virtual_packets=0, + iteration=0, + ): + if not plasma.continuum_interaction_species.empty: + gamma_shape = plasma.gamma.shape + else: + gamma_shape = (0, 0) - # Create packets - self.packet_collection = self.packet_source.create_packets( + packet_collection = self.packet_source.create_packets( no_of_packets, seed_offset=iteration ) + estimators = initialize_estimators( + plasma.tau_sobolevs.shape, gamma_shape + ) - self.last_line_interaction_in_id = -1 * np.ones( - no_of_packets, dtype=np.int64 + geometry_state = simulation_state.geometry.to_numba() + opacity_state = opacity_state_initialize( + plasma, self.line_interaction_type ) - self.last_line_interaction_out_id = -1 * np.ones( - no_of_packets, dtype=np.int64 + transport_state = MonteCarloTransportState( + packet_collection, + estimators, + spectrum_frequency=self.spectrum_frequency, + geometry_state=geometry_state, + opacity_state=opacity_state, ) - self.last_line_interaction_shell_id = -1 * np.ones( - no_of_packets, dtype=np.int64 + + transport_state.enable_full_relativity = self.enable_full_relativity + transport_state.integrator_settings = self.integrator_settings + transport_state._integrator = FormalIntegrator( + simulation_state, plasma, self ) - self.last_interaction_type = -1 * np.ones(no_of_packets, dtype=np.int64) - self.last_interaction_in_nu = np.zeros(no_of_packets, dtype=np.float64) + configuration_initialize(self, no_of_virtual_packets) + + return transport_state def run( self, - simulation_state, - plasma, - no_of_packets, - no_of_virtual_packets=0, + transport_state, + time_explosion, iteration=0, total_iterations=0, show_progress_bars=True, @@ -171,46 +165,66 @@ def run( ------- None """ - set_num_threads(self.nthreads) - - if not plasma.continuum_interaction_species.empty: - gamma_shape = plasma.gamma.shape - else: - gamma_shape = (0, 0) - - # Initializing estimator array - estimators = initialize_estimators( - plasma.tau_sobolevs.shape, gamma_shape + self.transport_state = transport_state + + numba_model = NumbaModel(time_explosion.to("s").value) + + number_of_vpackets = montecarlo_configuration.NUMBER_OF_VPACKETS + + ( + v_packets_energy_hist, + last_interaction_tracker, + vpacket_tracker, + rpacket_trackers, + ) = montecarlo_main_loop( + transport_state.packet_collection, + transport_state.geometry_state, + numba_model, + transport_state.opacity_state, + transport_state.estimators, + transport_state.spectrum_frequency.value, + number_of_vpackets, + iteration=iteration, + show_progress_bars=show_progress_bars, + total_iterations=total_iterations, + enable_virtual_packet_logging=self.enable_vpacket_tracking, ) - self._initialize_packets(no_of_packets, iteration) - - self.transport_state = MonteCarloTransportState( - self.packet_collection, - estimators, - simulation_state.volume.cgs.copy(), - spectrum_frequency=self.spectrum_frequency, - geometry_state=simulation_state.geometry.to_numba(), + transport_state._montecarlo_virtual_luminosity.value[ + : + ] = v_packets_energy_hist + transport_state.last_interaction_type = last_interaction_tracker.types + transport_state.last_interaction_in_nu = last_interaction_tracker.in_nus + transport_state.last_line_interaction_in_id = ( + last_interaction_tracker.in_ids ) - self.transport_state.enable_full_relativity = ( - self.enable_full_relativity + transport_state.last_line_interaction_out_id = ( + last_interaction_tracker.out_ids ) - self.transport_state.integrator_settings = self.integrator_settings - self.transport_state._integrator = FormalIntegrator( - simulation_state, plasma, self + transport_state.last_line_interaction_shell_id = ( + last_interaction_tracker.shell_ids ) - configuration_initialize(self, no_of_virtual_packets) - montecarlo_radial1d( - simulation_state, - plasma, - iteration, - self.packet_collection, - self.transport_state.estimators, - total_iterations, - show_progress_bars, - self, + if montecarlo_configuration.ENABLE_VPACKET_TRACKING and ( + number_of_vpackets > 0 + ): + transport_state.vpacket_tracker = vpacket_tracker + + update_iterations_pbar(1) + refresh_packet_pbar() + # Condition for Checking if RPacket Tracking is enabled + if montecarlo_configuration.ENABLE_RPACKET_TRACKING: + transport_state.rpacket_tracker = rpacket_trackers + + if self.transport_state.rpacket_tracker is not None: + self.transport_state.rpacket_tracker_df = ( + rpacket_trackers_to_dataframe( + self.transport_state.rpacket_tracker + ) + ) + transport_state.virt_logging = ( + montecarlo_configuration.ENABLE_VPACKET_TRACKING ) def legacy_return(self): @@ -219,10 +233,10 @@ def legacy_return(self): self.transport_state.packet_collection.output_energies, self.transport_state.estimators.j_estimator, self.transport_state.estimators.nu_bar_estimator, - self.last_line_interaction_in_id, - self.last_line_interaction_out_id, - self.last_interaction_type, - self.last_line_interaction_shell_id, + self.transport_state.last_line_interaction_in_id, + self.transport_state.last_line_interaction_out_id, + self.transport_state.last_interaction_type, + self.transport_state.last_line_interaction_shell_id, ) def get_line_interaction_id(self, line_interaction_type): @@ -247,7 +261,7 @@ def from_config( MontecarloTransport """ if config.plasma.disable_electron_scattering: - logger.warn( + logger.warning( "Disabling electron scattering - this is not physical." "Likely bug in formal integral - " "will not give same results." @@ -284,11 +298,11 @@ def from_config( valid values are 'GPU', 'CPU', and 'Automatic'.""" ) - mc_config_module.disable_line_scattering = ( + montecarlo_configuration.DISABLE_LINE_SCATTERING = ( config.plasma.disable_line_scattering ) - mc_config_module.INITIAL_TRACKING_ARRAY_LENGTH = ( + montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH = ( config.montecarlo.tracking.initial_array_length ) @@ -310,7 +324,7 @@ def from_config( config.spectrum.virtual.virtual_packet_logging | enable_virtual_packet_logging ), + enable_rpacket_tracking=config.montecarlo.tracking.track_rpacket, nthreads=config.montecarlo.nthreads, - tracking_rpacket=config.montecarlo.tracking.track_rpacket, use_gpu=use_gpu, ) diff --git a/tardis/montecarlo/montecarlo_configuration.py b/tardis/montecarlo/montecarlo_configuration.py index aee97013ea6..6e813298f57 100644 --- a/tardis/montecarlo/montecarlo_configuration.py +++ b/tardis/montecarlo/montecarlo_configuration.py @@ -1,18 +1,74 @@ +from astropy import units as u from tardis import constants as const +from tardis.montecarlo import ( + montecarlo_configuration as montecarlo_configuration, +) +from tardis.montecarlo.montecarlo_numba.numba_interface import ( + LineInteractionType, +) + +ENABLE_FULL_RELATIVITY = False +TEMPORARY_V_PACKET_BINS = 0 +NUMBER_OF_VPACKETS = 0 +MONTECARLO_SEED = 0 +LINE_INTERACTION_TYPE = None +PACKET_SEEDS = [] +DISABLE_ELECTRON_SCATTERING = False +DISABLE_LINE_SCATTERING = False +SURVIVAL_PROBABILITY = 0.0 +VPACKET_TAU_RUSSIAN = 10.0 -full_relativity = True -temporary_v_packet_bins = 0 -number_of_vpackets = 0 -montecarlo_seed = 0 -line_interaction_type = None -packet_seeds = [] -disable_electron_scattering = False -disable_line_scattering = False -survival_probability = 0.0 -tau_russian = 10.0 INITIAL_TRACKING_ARRAY_LENGTH = None LEGACY_MODE_ENABLED = False -VPACKET_LOGGING = False -RPACKET_TRACKING = False + +ENABLE_RPACKET_TRACKING = False CONTINUUM_PROCESSES_ENABLED = False + + +VPACKET_SPAWN_START_FREQUENCY = 0 +VPACKET_SPAWN_END_FREQUENCY = 1e200 +ENABLE_VPACKET_TRACKING = False + + +def configuration_initialize(transport, number_of_vpackets): + if transport.line_interaction_type == "macroatom": + montecarlo_configuration.LINE_INTERACTION_TYPE = ( + LineInteractionType.MACROATOM + ) + elif transport.line_interaction_type == "downbranch": + montecarlo_configuration.LINE_INTERACTION_TYPE = ( + LineInteractionType.DOWNBRANCH + ) + elif transport.line_interaction_type == "scatter": + montecarlo_configuration.LINE_INTERACTION_TYPE = ( + LineInteractionType.SCATTER + ) + else: + raise ValueError( + f'Line interaction type must be one of "macroatom",' + f'"downbranch", or "scatter" but is ' + f"{transport.line_interaction_type}" + ) + montecarlo_configuration.NUMBER_OF_VPACKETS = number_of_vpackets + montecarlo_configuration.TEMPORARY_V_PACKET_BINS = number_of_vpackets + montecarlo_configuration.ENABLE_FULL_RELATIVITY = ( + transport.enable_full_relativity + ) + montecarlo_configuration.MONTECARLO_SEED = transport.packet_source.base_seed + montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY = ( + transport.virtual_spectrum_spawn_range.end.to( + u.Hz, equivalencies=u.spectral() + ).value + ) + montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY = ( + transport.virtual_spectrum_spawn_range.start.to( + u.Hz, equivalencies=u.spectral() + ).value + ) + montecarlo_configuration.ENABLE_VPACKET_TRACKING = ( + transport.enable_vpacket_tracking + ) + montecarlo_configuration.ENABLE_RPACKET_TRACKING = ( + transport.enable_rpacket_tracking + ) diff --git a/tardis/montecarlo/montecarlo_numba/__init__.py b/tardis/montecarlo/montecarlo_numba/__init__.py index b7f1bb07f19..fb30d055475 100644 --- a/tardis/montecarlo/montecarlo_numba/__init__.py +++ b/tardis/montecarlo/montecarlo_numba/__init__.py @@ -13,7 +13,7 @@ } from tardis.montecarlo.montecarlo_numba.r_packet import RPacket -from tardis.montecarlo.montecarlo_numba.base import montecarlo_radial1d +from tardis.montecarlo.montecarlo_numba.base import montecarlo_main_loop from tardis.montecarlo.montecarlo_numba.packet_collections import ( PacketCollection, ) diff --git a/tardis/montecarlo/montecarlo_numba/base.py b/tardis/montecarlo/montecarlo_numba/base.py index f69f3546654..455423fbe88 100644 --- a/tardis/montecarlo/montecarlo_numba/base.py +++ b/tardis/montecarlo/montecarlo_numba/base.py @@ -1,174 +1,75 @@ -from numba import prange, njit, objmode -from numba.np.ufunc.parallel import get_thread_id, get_num_threads - import numpy as np +from numba import njit, objmode, prange +from numba.np.ufunc.parallel import get_num_threads, get_thread_id +from numba.typed import List + +from tardis.montecarlo import montecarlo_configuration +from tardis.montecarlo.montecarlo_numba import njit_dict from tardis.montecarlo.montecarlo_numba.estimators import Estimators +from tardis.montecarlo.montecarlo_numba.numba_interface import ( + NumbaModel, + RPacketTracker, +) from tardis.montecarlo.montecarlo_numba.packet_collections import ( VPacketCollection, + consolidate_vpacket_tracker, + initialize_last_interaction_tracker, ) - - from tardis.montecarlo.montecarlo_numba.r_packet import ( - RPacket, PacketStatus, + RPacket, ) - -from tardis.montecarlo.montecarlo_numba.numba_interface import ( - RPacketTracker, - NumbaModel, - opacity_state_initialize, -) - -from tardis.montecarlo import ( - montecarlo_configuration as montecarlo_configuration, -) - from tardis.montecarlo.montecarlo_numba.single_packet_loop import ( single_packet_loop, ) -from tardis.montecarlo.montecarlo_numba import njit_dict -from numba.typed import List -from tardis.util.base import ( - update_iterations_pbar, - update_packet_pbar, - refresh_packet_pbar, -) - - -def montecarlo_radial1d( - simulation_state, - plasma, - iteration, - packet_collection, - estimators, - total_iterations, - show_progress_bars, - transport, -): - numba_radial_1d_geometry = transport.transport_state.geometry_state - numba_model = NumbaModel( - simulation_state.time_explosion.to("s").value, - ) - opacity_state = opacity_state_initialize( - plasma, transport.line_interaction_type - ) - - number_of_vpackets = montecarlo_configuration.number_of_vpackets - ( - v_packets_energy_hist, - last_interaction_type, - last_interaction_in_nu, - last_line_interaction_in_id, - last_line_interaction_out_id, - last_line_interaction_shell_id, - virt_packet_nus, - virt_packet_energies, - virt_packet_initial_mus, - virt_packet_initial_rs, - virt_packet_last_interaction_in_nu, - virt_packet_last_interaction_type, - virt_packet_last_line_interaction_in_id, - virt_packet_last_line_interaction_out_id, - virt_packet_last_line_interaction_shell_id, - rpacket_trackers, - ) = montecarlo_main_loop( - packet_collection, - numba_radial_1d_geometry, - numba_model, - opacity_state, - estimators, - transport.spectrum_frequency.value, - number_of_vpackets, - montecarlo_configuration.VPACKET_LOGGING, - iteration=iteration, - show_progress_bars=show_progress_bars, - total_iterations=total_iterations, - ) - - transport.transport_state._montecarlo_virtual_luminosity.value[ - : - ] = v_packets_energy_hist - transport.last_interaction_type = last_interaction_type - transport.last_interaction_in_nu = last_interaction_in_nu - transport.last_line_interaction_in_id = last_line_interaction_in_id - transport.last_line_interaction_out_id = last_line_interaction_out_id - transport.last_line_interaction_shell_id = last_line_interaction_shell_id - - if montecarlo_configuration.VPACKET_LOGGING and number_of_vpackets > 0: - transport.virt_packet_nus = np.concatenate(virt_packet_nus).ravel() - transport.virt_packet_energies = np.concatenate( - virt_packet_energies - ).ravel() - transport.virt_packet_initial_mus = np.concatenate( - virt_packet_initial_mus - ).ravel() - transport.virt_packet_initial_rs = np.concatenate( - virt_packet_initial_rs - ).ravel() - transport.virt_packet_last_interaction_in_nu = np.concatenate( - virt_packet_last_interaction_in_nu - ).ravel() - transport.virt_packet_last_interaction_type = np.concatenate( - virt_packet_last_interaction_type - ).ravel() - transport.virt_packet_last_line_interaction_in_id = np.concatenate( - virt_packet_last_line_interaction_in_id - ).ravel() - transport.virt_packet_last_line_interaction_out_id = np.concatenate( - virt_packet_last_line_interaction_out_id - ).ravel() - transport.virt_packet_last_line_interaction_shell_id = np.concatenate( - virt_packet_last_line_interaction_shell_id - ).ravel() - update_iterations_pbar(1) - refresh_packet_pbar() - # Condition for Checking if RPacket Tracking is enabled - if montecarlo_configuration.RPACKET_TRACKING: - transport.rpacket_tracker = rpacket_trackers +from tardis.util.base import update_packet_pbar @njit(**njit_dict) def montecarlo_main_loop( packet_collection, - numba_radial_1d_geometry, + geometry_state, numba_model, opacity_state, estimators, spectrum_frequency, number_of_vpackets, - virtual_packet_logging, iteration, show_progress_bars, total_iterations, + enable_virtual_packet_logging, ): - """ - This is the main loop of the MonteCarlo routine that generates packets + """This is the main loop of the MonteCarlo routine that generates packets and sends them through the ejecta. + Parameters ---------- packet_collection : PacketCollection - numba_radial_1d_geometry : NumbaRadial1DGeometry + Real packet collection + geometry_state : GeometryState + Simulation geometry numba_model : NumbaModel opacity_state : OpacityState - estimators : NumbaEstimators - spectrum_frequency : astropy.units.Quantity - frequency binspas + estimators : Estimators + spectrum_frequency : astropy.units.Quantity + Frequency bins number_of_vpackets : int VPackets released per interaction - packet_seeds : numpy.array - virtual_packet_logging : bool - Option to enable virtual packet logging. + iteration : int + Current iteration number + show_progress_bars : bool + Display progress bars + total_iterations : int + Maximum number of iterations + enable_virtual_packet_logging : bool + Enable virtual packet tracking """ no_of_packets = len(packet_collection.initial_nus) - output_nus = np.empty(no_of_packets, dtype=np.float64) - output_energies = np.empty(no_of_packets, dtype=np.float64) - last_interaction_in_nus = np.empty(no_of_packets, dtype=np.float64) - last_interaction_types = -1 * np.ones(no_of_packets, dtype=np.int64) - last_line_interaction_in_ids = -np.ones(no_of_packets, dtype=np.int64) - last_line_interaction_out_ids = -np.ones(no_of_packets, dtype=np.int64) - last_line_interaction_shell_ids = -np.ones(no_of_packets, dtype=np.int64) + last_interaction_tracker = initialize_last_interaction_tracker( + no_of_packets + ) v_packets_energy_hist = np.zeros_like(spectrum_frequency) delta_nu = spectrum_frequency[1] - spectrum_frequency[0] @@ -177,15 +78,15 @@ def montecarlo_main_loop( vpacket_collections = List() # Configuring the Tracking for R_Packets rpacket_trackers = List() - for i in range(len(output_nus)): + for i in range(no_of_packets): vpacket_collections.append( VPacketCollection( i, spectrum_frequency, - montecarlo_configuration.v_packet_spawn_start_frequency, - montecarlo_configuration.v_packet_spawn_end_frequency, + montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY, + montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY, number_of_vpackets, - montecarlo_configuration.temporary_v_packet_bins, + montecarlo_configuration.TEMPORARY_V_PACKET_BINS, ) ) rpacket_trackers.append(RPacketTracker()) @@ -195,9 +96,10 @@ def montecarlo_main_loop( n_threads = get_num_threads() estimator_list = List() - for i in range(n_threads): # betting get tid goes from 0 to num threads - # Note that get_thread_id() returns values from 0 to n_threads-1, - # so we iterate from 0 to n_threads-1 to create the estimator_list + # betting get thread_id goes from 0 to num threads + # Note that get_thread_id() returns values from 0 to n_threads-1, + # so we iterate from 0 to n_threads-1 to create the estimator_list + for i in range(n_threads): estimator_list.append( Estimators( np.copy(estimators.j_estimator), @@ -211,20 +113,11 @@ def montecarlo_main_loop( np.copy(estimators.photo_ion_estimator_statistics), ) ) - # Arrays for vpacket logging - virt_packet_nus = [] - virt_packet_energies = [] - virt_packet_initial_mus = [] - virt_packet_initial_rs = [] - virt_packet_last_interaction_in_nu = [] - virt_packet_last_interaction_type = [] - virt_packet_last_line_interaction_in_id = [] - virt_packet_last_line_interaction_out_id = [] - virt_packet_last_line_interaction_shell_id = [] - for i in prange(len(output_nus)): - tid = get_thread_id() + + for i in prange(no_of_packets): + thread_id = get_thread_id() if show_progress_bars: - if tid == main_thread_id: + if thread_id == main_thread_id: with objmode: update_amount = 1 * n_threads update_packet_pbar( @@ -242,136 +135,75 @@ def montecarlo_main_loop( packet_collection.packet_seeds[i], i, ) + # Seed the random number generator np.random.seed(r_packet.seed) - local_estimators = estimator_list[tid] + + # Get the local estimators for this thread + local_estimators = estimator_list[thread_id] + + # Get the local v_packet_collection for this thread vpacket_collection = vpacket_collections[i] + + # RPacket Tracker for this thread rpacket_tracker = rpacket_trackers[i] loop = single_packet_loop( r_packet, - numba_radial_1d_geometry, + geometry_state, numba_model, opacity_state, local_estimators, vpacket_collection, rpacket_tracker, ) + packet_collection.output_nus[i] = r_packet.nu - output_nus[i] = r_packet.nu - last_interaction_in_nus[i] = r_packet.last_interaction_in_nu - last_line_interaction_in_ids[i] = r_packet.last_line_interaction_in_id - last_line_interaction_out_ids[i] = r_packet.last_line_interaction_out_id - last_line_interaction_shell_ids[ - i - ] = r_packet.last_line_interaction_shell_id + last_interaction_tracker.update_last_interaction(r_packet, i) if r_packet.status == PacketStatus.REABSORBED: - output_energies[i] = -r_packet.energy - last_interaction_types[i] = r_packet.last_interaction_type + packet_collection.output_energies[i] = -r_packet.energy + last_interaction_tracker.types[i] = r_packet.last_interaction_type elif r_packet.status == PacketStatus.EMITTED: - output_energies[i] = r_packet.energy - last_interaction_types[i] = r_packet.last_interaction_type + packet_collection.output_energies[i] = r_packet.energy + last_interaction_tracker.types[i] = r_packet.last_interaction_type - vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx] - vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx] - vpackets_initial_mu = vpacket_collection.initial_mus[ - : vpacket_collection.idx - ] - vpackets_initial_r = vpacket_collection.initial_rs[ - : vpacket_collection.idx - ] + vpacket_collection.finalize_arrays() v_packets_idx = np.floor( - (vpackets_nu - spectrum_frequency[0]) / delta_nu + (vpacket_collection.nus - spectrum_frequency[0]) / delta_nu ).astype(np.int64) for j, idx in enumerate(v_packets_idx): - if (vpackets_nu[j] < spectrum_frequency[0]) or ( - vpackets_nu[j] > spectrum_frequency[-1] + if (vpacket_collection.nus[j] < spectrum_frequency[0]) or ( + vpacket_collection.nus[j] > spectrum_frequency[-1] ): continue - v_packets_energy_hist[idx] += vpackets_energy[j] + v_packets_energy_hist[idx] += vpacket_collection.energies[j] for sub_estimator in estimator_list: estimators.increment(sub_estimator) - if virtual_packet_logging: - for vpacket_collection in vpacket_collections: - vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx] - vpackets_energy = vpacket_collection.energies[ - : vpacket_collection.idx - ] - vpackets_initial_mu = vpacket_collection.initial_mus[ - : vpacket_collection.idx - ] - vpackets_initial_r = vpacket_collection.initial_rs[ - : vpacket_collection.idx - ] - virt_packet_nus.append(np.ascontiguousarray(vpackets_nu)) - virt_packet_energies.append(np.ascontiguousarray(vpackets_energy)) - virt_packet_initial_mus.append( - np.ascontiguousarray(vpackets_initial_mu) - ) - virt_packet_initial_rs.append( - np.ascontiguousarray(vpackets_initial_r) - ) - virt_packet_last_interaction_in_nu.append( - np.ascontiguousarray( - vpacket_collection.last_interaction_in_nu[ - : vpacket_collection.idx - ] - ) - ) - virt_packet_last_interaction_type.append( - np.ascontiguousarray( - vpacket_collection.last_interaction_type[ - : vpacket_collection.idx - ] - ) - ) - virt_packet_last_line_interaction_in_id.append( - np.ascontiguousarray( - vpacket_collection.last_interaction_in_id[ - : vpacket_collection.idx - ] - ) - ) - virt_packet_last_line_interaction_out_id.append( - np.ascontiguousarray( - vpacket_collection.last_interaction_out_id[ - : vpacket_collection.idx - ] - ) - ) - virt_packet_last_line_interaction_shell_id.append( - np.ascontiguousarray( - vpacket_collection.last_interaction_shell_id[ - : vpacket_collection.idx - ] - ) - ) + if enable_virtual_packet_logging: + vpacket_tracker = consolidate_vpacket_tracker( + vpacket_collections, spectrum_frequency + ) + else: + vpacket_tracker = VPacketCollection( + -1, + spectrum_frequency, + montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY, + montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY, + -1, + 1, + ) - if montecarlo_configuration.RPACKET_TRACKING: + if montecarlo_configuration.ENABLE_RPACKET_TRACKING: for rpacket_tracker in rpacket_trackers: rpacket_tracker.finalize_array() - packet_collection.output_energies[:] = output_energies[:] - packet_collection.output_nus[:] = output_nus[:] return ( v_packets_energy_hist, - last_interaction_types, - last_interaction_in_nus, - last_line_interaction_in_ids, - last_line_interaction_out_ids, - last_line_interaction_shell_ids, - virt_packet_nus, - virt_packet_energies, - virt_packet_initial_mus, - virt_packet_initial_rs, - virt_packet_last_interaction_in_nu, - virt_packet_last_interaction_type, - virt_packet_last_line_interaction_in_id, - virt_packet_last_line_interaction_out_id, - virt_packet_last_line_interaction_shell_id, + last_interaction_tracker, + vpacket_tracker, rpacket_trackers, ) diff --git a/tardis/montecarlo/montecarlo_numba/estimators.py b/tardis/montecarlo/montecarlo_numba/estimators.py index 3bae2ec640a..5e3ddf6514f 100644 --- a/tardis/montecarlo/montecarlo_numba/estimators.py +++ b/tardis/montecarlo/montecarlo_numba/estimators.py @@ -5,7 +5,7 @@ from numba import njit, float64, int64 from numba.experimental import jitclass -from tardis.montecarlo.montecarlo_numba import numba_config as nc +from tardis.montecarlo import montecarlo_configuration as nc from tardis.montecarlo.montecarlo_numba.numba_config import H, KB from tardis.montecarlo.montecarlo_numba import ( diff --git a/tardis/montecarlo/montecarlo_numba/formal_integral.py b/tardis/montecarlo/montecarlo_numba/formal_integral.py index c1339989865..27aa54d283e 100644 --- a/tardis/montecarlo/montecarlo_numba/formal_integral.py +++ b/tardis/montecarlo/montecarlo_numba/formal_integral.py @@ -11,7 +11,9 @@ from tardis.montecarlo.montecarlo_numba.numba_config import SIGMA_THOMSON -from tardis.montecarlo import montecarlo_configuration as mc_config_module +from tardis.montecarlo import ( + montecarlo_configuration as montecarlo_configuration, +) from tardis.montecarlo.montecarlo_numba import njit_dict, njit_dict_no_parallel from tardis.montecarlo.montecarlo_numba.numba_interface import ( opacity_state_initialize, @@ -359,7 +361,7 @@ def raise_or_return(message): 'and line_interaction_type == "macroatom"' ) - if mc_config_module.CONTINUUM_PROCESSES_ENABLED: + if montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED: return raise_or_return( "The FormalIntegrator currently does not work for " "continuum interactions." @@ -416,10 +418,9 @@ def make_source_function(self): ------- Numpy array containing ( 1 - exp(-tau_ul) ) S_ul ordered by wavelength of the transition u -> l """ - simulation_state = self.simulation_state transport = self.transport - mct_state = transport.transport_state + montecarlo_transport_state = transport.transport_state # macro_ref = self.atomic_data.macro_atom_references macro_ref = self.atomic_data.macro_atom_references @@ -440,12 +441,14 @@ def make_source_function(self): destination_level_idx = ma_int_data.destination_level_idx.values Edotlu_norm_factor = 1 / ( - mct_state.packet_collection.time_of_simulation + montecarlo_transport_state.packet_collection.time_of_simulation * simulation_state.volume ) exptau = 1 - np.exp(-self.original_plasma.tau_sobolevs) Edotlu = ( - Edotlu_norm_factor * exptau * mct_state.estimators.Edotlu_estimator + Edotlu_norm_factor + * exptau + * montecarlo_transport_state.estimators.Edotlu_estimator ) # The following may be achieved by calling the appropriate plasma @@ -457,7 +460,7 @@ def make_source_function(self): / ( 4 * np.pi - * mct_state.time_of_simulation + * montecarlo_transport_state.time_of_simulation * simulation_state.volume ) ) @@ -536,8 +539,12 @@ def make_source_function(self): att_S_ul, Jredlu, Jbluelu, e_dot_u ) else: - transport.r_inner_i = mct_state.geometry_state.r_inner - transport.r_outer_i = mct_state.geometry_state.r_outer + transport.r_inner_i = ( + montecarlo_transport_state.geometry_state.r_inner + ) + transport.r_outer_i = ( + montecarlo_transport_state.geometry_state.r_outer + ) transport.tau_sobolevs_integ = ( self.original_plasma.tau_sobolevs.values ) diff --git a/tardis/montecarlo/montecarlo_numba/interaction.py b/tardis/montecarlo/montecarlo_numba/interaction.py index 277ddafa57d..c8968c7ce75 100644 --- a/tardis/montecarlo/montecarlo_numba/interaction.py +++ b/tardis/montecarlo/montecarlo_numba/interaction.py @@ -320,7 +320,7 @@ def free_free_emission(r_packet, time_explosion, opacity_state): current_line_id = get_current_line_id(comov_nu, opacity_state.line_list_nu) r_packet.next_line_id = current_line_id - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: r_packet.mu = angle_aberration_CMF_to_LF( r_packet, time_explosion, r_packet.mu ) @@ -350,7 +350,7 @@ def bound_free_emission(r_packet, time_explosion, opacity_state, continuum_id): current_line_id = get_current_line_id(comov_nu, opacity_state.line_list_nu) r_packet.next_line_id = current_line_id - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: r_packet.mu = angle_aberration_CMF_to_LF( r_packet, time_explosion, r_packet.mu ) @@ -384,7 +384,7 @@ def thomson_scatter(r_packet, time_explosion): r_packet.nu = comov_nu * inverse_new_doppler_factor r_packet.energy = comov_energy * inverse_new_doppler_factor - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: r_packet.mu = angle_aberration_CMF_to_LF( r_packet, time_explosion, r_packet.mu ) @@ -462,7 +462,7 @@ def line_emission(r_packet, emission_line_id, time_explosion, opacity_state): r_packet.next_line_id = emission_line_id + 1 nu_line = opacity_state.line_list_nu[emission_line_id] - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: r_packet.mu = angle_aberration_CMF_to_LF( r_packet, time_explosion, r_packet.mu ) diff --git a/tardis/montecarlo/montecarlo_numba/numba_config.py b/tardis/montecarlo/montecarlo_numba/numba_config.py index 4bceebf975c..fb21f5421ae 100644 --- a/tardis/montecarlo/montecarlo_numba/numba_config.py +++ b/tardis/montecarlo/montecarlo_numba/numba_config.py @@ -6,5 +6,3 @@ MISS_DISTANCE = 1e99 KB = const.k_B.cgs.value H = const.h.cgs.value - -ENABLE_FULL_RELATIVITY = False diff --git a/tardis/montecarlo/montecarlo_numba/numba_interface.py b/tardis/montecarlo/montecarlo_numba/numba_interface.py index 9b8e850bfa7..6d310565622 100644 --- a/tardis/montecarlo/montecarlo_numba/numba_interface.py +++ b/tardis/montecarlo/montecarlo_numba/numba_interface.py @@ -4,7 +4,6 @@ from numba.experimental import jitclass import numpy as np -from astropy import units as u from tardis import constants as const from tardis.montecarlo import ( @@ -151,7 +150,7 @@ def opacity_state_initialize(plasma, line_interaction_type): tau_sobolev = np.ascontiguousarray( plasma.tau_sobolevs.values.copy(), dtype=np.float64 ) - if montecarlo_configuration.disable_line_scattering: + if montecarlo_configuration.DISABLE_LINE_SCATTERING: tau_sobolev *= 0 if line_interaction_type == "scatter": @@ -377,43 +376,6 @@ def finalize_array(self): self.interaction_type = self.interaction_type[: self.num_interactions] -def configuration_initialize(transport, number_of_vpackets): - if transport.line_interaction_type == "macroatom": - montecarlo_configuration.line_interaction_type = ( - LineInteractionType.MACROATOM - ) - elif transport.line_interaction_type == "downbranch": - montecarlo_configuration.line_interaction_type = ( - LineInteractionType.DOWNBRANCH - ) - elif transport.line_interaction_type == "scatter": - montecarlo_configuration.line_interaction_type = ( - LineInteractionType.SCATTER - ) - else: - raise ValueError( - f'Line interaction type must be one of "macroatom",' - f'"downbranch", or "scatter" but is ' - f"{transport.line_interaction_type}" - ) - montecarlo_configuration.number_of_vpackets = number_of_vpackets - montecarlo_configuration.temporary_v_packet_bins = number_of_vpackets - montecarlo_configuration.full_relativity = transport.enable_full_relativity - montecarlo_configuration.montecarlo_seed = transport.packet_source.base_seed - montecarlo_configuration.v_packet_spawn_start_frequency = ( - transport.virtual_spectrum_spawn_range.end.to( - u.Hz, equivalencies=u.spectral() - ).value - ) - montecarlo_configuration.v_packet_spawn_end_frequency = ( - transport.virtual_spectrum_spawn_range.start.to( - u.Hz, equivalencies=u.spectral() - ).value - ) - montecarlo_configuration.VPACKET_LOGGING = transport.virt_logging - - -# class TrackRPacket(object): class LineInteractionType(IntEnum): SCATTER = 0 DOWNBRANCH = 1 diff --git a/tardis/montecarlo/montecarlo_numba/packet_collections.py b/tardis/montecarlo/montecarlo_numba/packet_collections.py index 244125686ed..0a30889962f 100644 --- a/tardis/montecarlo/montecarlo_numba/packet_collections.py +++ b/tardis/montecarlo/montecarlo_numba/packet_collections.py @@ -1,6 +1,11 @@ import numpy as np +from numba import float64, int64, njit from numba.experimental import jitclass -from numba import float64, int64 + +from tardis.montecarlo import montecarlo_configuration +from tardis.montecarlo.montecarlo_numba import ( + njit_dict_no_parallel, +) packet_collection_spec = [ ("initial_radii", float64[:]), @@ -41,8 +46,60 @@ def __init__( ) +@njit(**njit_dict_no_parallel) +def initialize_last_interaction_tracker(no_of_packets): + last_line_interaction_in_ids = -1 * np.ones(no_of_packets, dtype=np.int64) + last_line_interaction_out_ids = -1 * np.ones(no_of_packets, dtype=np.int64) + last_line_interaction_shell_ids = -1 * np.ones( + no_of_packets, dtype=np.int64 + ) + last_interaction_types = -1 * np.ones(no_of_packets, dtype=np.int64) + last_interaction_in_nus = np.zeros(no_of_packets, dtype=np.float64) + + return LastInteractionTracker( + last_interaction_types, + last_interaction_in_nus, + last_line_interaction_in_ids, + last_line_interaction_out_ids, + last_line_interaction_shell_ids, + ) + + +last_interaction_tracker_spec = [ + ("types", int64[:]), + ("in_nus", float64[:]), + ("in_ids", int64[:]), + ("out_ids", int64[:]), + ("shell_ids", int64[:]), +] + + +@jitclass(last_interaction_tracker_spec) +class LastInteractionTracker: + def __init__( + self, + types, + in_nus, + in_ids, + out_ids, + shell_ids, + ): + self.types = types + self.in_nus = in_nus + self.in_ids = in_ids + self.out_ids = out_ids + self.shell_ids = shell_ids + + def update_last_interaction(self, r_packet, i): + self.types[i] = r_packet.last_interaction_type + self.in_nus[i] = r_packet.last_interaction_in_nu + self.in_ids[i] = r_packet.last_line_interaction_in_id + self.out_ids[i] = r_packet.last_line_interaction_out_id + self.shell_ids[i] = r_packet.last_line_interaction_shell_id + + vpacket_collection_spec = [ - ("rpacket_index", int64), + ("source_rpacket_index", int64), ("spectrum_frequency", float64[:]), ("v_packet_spawn_start_frequency", float64), ("v_packet_spawn_end_frequency", float64), @@ -62,10 +119,10 @@ def __init__( @jitclass(vpacket_collection_spec) -class VPacketCollection(object): +class VPacketCollection: def __init__( self, - rpacket_index, + source_rpacket_index, spectrum_frequency, v_packet_spawn_start_frequency, v_packet_spawn_end_frequency, @@ -96,10 +153,10 @@ def __init__( temporary_v_packet_bins, dtype=np.int64 ) self.idx = 0 - self.rpacket_index = rpacket_index + self.source_rpacket_index = source_rpacket_index self.length = temporary_v_packet_bins - def set_properties( + def add_packet( self, nu, energy, @@ -111,6 +168,35 @@ def set_properties( last_interaction_out_id, last_interaction_shell_id, ): + """ + Add a packet to the vpacket collection and potentially resizing the arrays. + + Parameters + ---------- + nu : float + Frequency of the packet. + energy : float + Energy of the packet. + initial_mu : float + Initial mu of the packet. + initial_r : float + Initial r of the packet. + last_interaction_in_nu : float + Frequency of the last interaction of the packet. + last_interaction_type : int + Type of the last interaction of the packet. + last_interaction_in_id : int + ID of the last interaction in the packet. + last_interaction_out_id : int + ID of the last interaction out of the packet. + last_interaction_shell_id : int + ID of the last interaction shell of the packet. + + Returns + ------- + None + + """ if self.idx >= self.length: temp_length = self.length * 2 + self.number_of_vpackets temp_nus = np.empty(temp_length, dtype=np.float64) @@ -168,3 +254,95 @@ def set_properties( self.last_interaction_out_id[self.idx] = last_interaction_out_id self.last_interaction_shell_id[self.idx] = last_interaction_shell_id self.idx += 1 + + def finalize_arrays(self): + """ + Finalize the arrays by truncating them based on the current index. + + Returns + ------- + None + + """ + self.nus = self.nus[: self.idx] + self.energies = self.energies[: self.idx] + self.initial_mus = self.initial_mus[: self.idx] + self.initial_rs = self.initial_rs[: self.idx] + self.last_interaction_in_nu = self.last_interaction_in_nu[: self.idx] + self.last_interaction_type = self.last_interaction_type[: self.idx] + self.last_interaction_in_id = self.last_interaction_in_id[: self.idx] + self.last_interaction_out_id = self.last_interaction_out_id[: self.idx] + self.last_interaction_shell_id = self.last_interaction_shell_id[ + : self.idx + ] + + +@njit(**njit_dict_no_parallel) +def consolidate_vpacket_tracker(vpacket_collections, spectrum_frequency): + """ + Consolidate the vpacket trackers from multiple collections into a single vpacket tracker. + + Parameters + ---------- + vpacket_collections : List[VPacketCollection] + List of vpacket collections to consolidate. + spectrum_frequency : ndarray + Array of spectrum frequencies. + + Returns + ------- + VPacketCollection + Consolidated vpacket tracker. + + """ + vpacket_tracker_length = 0 + for vpacket_collection in vpacket_collections: + vpacket_tracker_length += vpacket_collection.idx + + vpacket_tracker = VPacketCollection( + -1, + spectrum_frequency, + montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY, + montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY, + -1, + vpacket_tracker_length, + ) + current_start_vpacket_tracker_idx = 0 + for vpacket_collection in vpacket_collections: + current_end_vpacket_tracker_idx = ( + current_start_vpacket_tracker_idx + vpacket_collection.idx + ) + vpacket_tracker.nus[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.nus + vpacket_tracker.energies[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.energies + vpacket_tracker.initial_mus[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.initial_mus + vpacket_tracker.initial_rs[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.initial_rs + vpacket_tracker.last_interaction_in_nu[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.last_interaction_in_nu + + vpacket_tracker.last_interaction_type[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.last_interaction_type + + vpacket_tracker.last_interaction_in_id[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.last_interaction_in_id + + vpacket_tracker.last_interaction_out_id[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.last_interaction_out_id + + vpacket_tracker.last_interaction_shell_id[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.last_interaction_shell_id + + current_start_vpacket_tracker_idx = current_end_vpacket_tracker_idx + return vpacket_tracker diff --git a/tardis/montecarlo/montecarlo_numba/single_packet_loop.py b/tardis/montecarlo/montecarlo_numba/single_packet_loop.py index 71d319590ec..00efd249b78 100644 --- a/tardis/montecarlo/montecarlo_numba/single_packet_loop.py +++ b/tardis/montecarlo/montecarlo_numba/single_packet_loop.py @@ -64,9 +64,9 @@ def single_packet_loop( This function does not return anything but changes the r_packet object and if virtual packets are requested - also updates the vpacket_collection """ - line_interaction_type = montecarlo_configuration.line_interaction_type + line_interaction_type = montecarlo_configuration.LINE_INTERACTION_TYPE - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: set_packet_props_full_relativity(r_packet, numba_model) else: set_packet_props_partial_relativity(r_packet, numba_model) @@ -80,7 +80,7 @@ def single_packet_loop( opacity_state, ) - if montecarlo_configuration.RPACKET_TRACKING: + if montecarlo_configuration.ENABLE_RPACKET_TRACKING: rpacket_tracker.track(r_packet) # this part of the code is temporary and will be better incorporated @@ -107,7 +107,7 @@ def single_packet_loop( chi_continuum = chi_e + chi_bf_tot + chi_ff escat_prob = chi_e / chi_continuum # probability of e-scatter - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: chi_continuum *= doppler_factor distance, interaction_type, delta_shell = trace_packet( r_packet, @@ -132,7 +132,7 @@ def single_packet_loop( else: escat_prob = 1.0 chi_continuum = chi_e - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: chi_continuum *= doppler_factor distance, interaction_type, delta_shell = trace_packet( r_packet, @@ -215,7 +215,7 @@ def single_packet_loop( ) else: pass - if montecarlo_configuration.RPACKET_TRACKING: + if montecarlo_configuration.ENABLE_RPACKET_TRACKING: rpacket_tracker.track(r_packet) diff --git a/tardis/montecarlo/montecarlo_numba/tests/conftest.py b/tardis/montecarlo/montecarlo_numba/tests/conftest.py index 996f3b0d1d4..b5a68c0876f 100644 --- a/tardis/montecarlo/montecarlo_numba/tests/conftest.py +++ b/tardis/montecarlo/montecarlo_numba/tests/conftest.py @@ -70,7 +70,7 @@ def verysimple_vpacket_collection(nb_simulation_verysimple): nb_simulation_verysimple.transport.spectrum_frequency.value ) return VPacketCollection( - rpacket_index=0, + source_rpacket_index=0, spectrum_frequency=spectrum_frequency, number_of_vpackets=0, v_packet_spawn_start_frequency=0, @@ -85,7 +85,7 @@ def verysimple_3vpacket_collection(nb_simulation_verysimple): nb_simulation_verysimple.transport.spectrum_frequency.value ) return VPacketCollection( - rpacket_index=0, + source_rpacket_index=0, spectrum_frequency=spectrum_frequency, number_of_vpackets=3, v_packet_spawn_start_frequency=0, @@ -96,7 +96,7 @@ def verysimple_3vpacket_collection(nb_simulation_verysimple): @pytest.fixture(scope="package") def verysimple_packet_collection(nb_simulation_verysimple): - return nb_simulation_verysimple.transport.packet_collection + return nb_simulation_verysimple.transport.transport_state.packet_collection @pytest.fixture(scope="function") diff --git a/tardis/montecarlo/montecarlo_numba/tests/test_base.py b/tardis/montecarlo/montecarlo_numba/tests/test_base.py index 1523821f213..5741412b7cb 100644 --- a/tardis/montecarlo/montecarlo_numba/tests/test_base.py +++ b/tardis/montecarlo/montecarlo_numba/tests/test_base.py @@ -90,6 +90,8 @@ def test_montecarlo_main_loop_vpacket_log( montecarlo_main_loop_simulation.run_convergence() montecarlo_main_loop_simulation.run_final() + assert montecarlo_configuration.ENABLE_VPACKET_TRACKING == True + expected_hdf_store = regression_data.sync_hdf_store( montecarlo_main_loop_simulation ) @@ -116,8 +118,8 @@ def test_montecarlo_main_loop_vpacket_log( actual_nu = transport_state.packet_collection.output_nus actual_nu_bar_estimator = transport_state.estimators.nu_bar_estimator actual_j_estimator = transport_state.estimators.j_estimator - actual_vpacket_log_nus = transport.virt_packet_nus - actual_vpacket_log_energies = transport.virt_packet_energies + actual_vpacket_log_nus = transport_state.vpacket_tracker.nus + actual_vpacket_log_energies = transport_state.vpacket_tracker.energies expected_hdf_store.close() # Compare diff --git a/tardis/montecarlo/montecarlo_numba/tests/test_cuda_formal_integral.py b/tardis/montecarlo/montecarlo_numba/tests/test_cuda_formal_integral.py index 679b25491c5..d4c21767cbe 100644 --- a/tardis/montecarlo/montecarlo_numba/tests/test_cuda_formal_integral.py +++ b/tardis/montecarlo/montecarlo_numba/tests/test_cuda_formal_integral.py @@ -11,7 +11,12 @@ FormalIntegrator, NumbaFormalIntegrator, ) -from tardis.montecarlo.montecarlo_numba.numba_interface import NumbaModel +from tardis.montecarlo.montecarlo_numba.numba_interface import ( + NumbaModel, +) + +from tardis.montecarlo.base import MonteCarloTransportSolver + # Test cases must also take into account use of a GPU to run. If there is no GPU then the test cases will fail. GPUs_available = cuda.is_available() @@ -379,8 +384,8 @@ def test_full_formal_integral( L_cuda = formal_integrator_cuda.integrator.formal_integral( formal_integrator_cuda.simulation_state.t_inner, - sim.transport.spectrum.frequency, - sim.transport.spectrum.frequency.shape[0], + sim.transport.transport_state.spectrum.frequency, + sim.transport.transport_state.spectrum.frequency.shape[0], att_S_ul_cuda, Jred_lu_cuda, Jblue_lu_cuda, @@ -391,8 +396,8 @@ def test_full_formal_integral( L_numba = formal_integrator_numba.integrator.formal_integral( formal_integrator_numba.simulation_state.t_inner, - sim.transport.spectrum.frequency, - sim.transport.spectrum.frequency.shape[0], + sim.transport.transport_state.spectrum.frequency, + sim.transport.transport_state.spectrum.frequency.shape[0], att_S_ul_numba, Jred_lu_numba, Jblue_lu_numba, diff --git a/tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py b/tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py index b11b9315912..ef6492efdc1 100644 --- a/tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py +++ b/tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py @@ -59,7 +59,7 @@ def test_configuration_initialize(): assert False -def test_VPacketCollection_set_properties(verysimple_3vpacket_collection): +def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): assert verysimple_3vpacket_collection.length == 0 nus = [3.0e15, 0.0, 1e15, 1e5] @@ -95,7 +95,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection): last_interaction_out_ids, last_interaction_shell_ids, ): - verysimple_3vpacket_collection.set_properties( + verysimple_3vpacket_collection.add_packet( nu, energy, initial_mu, diff --git a/tardis/montecarlo/montecarlo_numba/tests/test_packet.py b/tardis/montecarlo/montecarlo_numba/tests/test_packet.py index bdf02a5dcbe..9a5e8b148e1 100644 --- a/tardis/montecarlo/montecarlo_numba/tests/test_packet.py +++ b/tardis/montecarlo/montecarlo_numba/tests/test_packet.py @@ -1,28 +1,25 @@ -import pytest import numpy as np -import tardis.montecarlo.montecarlo_numba.estimators +import pytest +import tardis.montecarlo.montecarlo_configuration as numba_config +import tardis.montecarlo.montecarlo_numba.estimators +import tardis.montecarlo.montecarlo_numba.numba_interface as numba_interface +import tardis.montecarlo.montecarlo_numba.opacities as opacities import tardis.montecarlo.montecarlo_numba.r_packet as r_packet -import tardis.transport.geometry.calculate_distances as calculate_distances +import tardis.montecarlo.montecarlo_numba.utils as utils import tardis.transport.frame_transformations as frame_transformations -import tardis.montecarlo.montecarlo_numba.opacities as opacities +import tardis.transport.geometry.calculate_distances as calculate_distances import tardis.transport.r_packet_transport as r_packet_transport -from tardis.montecarlo.montecarlo_numba.estimators import update_line_estimators -import tardis.montecarlo.montecarlo_numba.utils as utils - -import tardis.montecarlo.montecarlo_numba.numba_interface as numba_interface -from tardis.model.geometry.radial1d import NumbaRadial1DGeometry from tardis import constants as const - -import tardis.montecarlo.montecarlo_numba.numba_config as numba_config - +from tardis.model.geometry.radial1d import NumbaRadial1DGeometry +from tardis.montecarlo.montecarlo_numba.estimators import update_line_estimators C_SPEED_OF_LIGHT = const.c.to("cm/s").value SIGMA_THOMSON = const.sigma_T.to("cm^2").value from numpy.testing import ( - assert_almost_equal, assert_allclose, + assert_almost_equal, ) diff --git a/tardis/montecarlo/montecarlo_numba/tests/test_r_packet.py b/tardis/montecarlo/montecarlo_numba/tests/test_r_packet.py index 21014dcd664..1964dc48edc 100644 --- a/tardis/montecarlo/montecarlo_numba/tests/test_r_packet.py +++ b/tardis/montecarlo/montecarlo_numba/tests/test_r_packet.py @@ -1,17 +1,16 @@ -import pytest -import pandas as pd -import os -import numpy.testing as npt -import numpy as np from copy import deepcopy -from tardis.base import run_tardis -from tardis.montecarlo.montecarlo_numba.r_packet import ( - rpacket_trackers_to_dataframe, -) +import numpy as np +import numpy.testing as npt +import pytest + +from tardis.base import run_tardis from tardis.montecarlo import ( montecarlo_configuration as montecarlo_configuration, ) +from tardis.montecarlo.montecarlo_numba.r_packet import ( + rpacket_trackers_to_dataframe, +) @pytest.fixture(scope="module") @@ -34,15 +33,16 @@ def simulation_rpacket_tracking_enabled(config_verysimple, atomic_dataset): def test_rpacket_trackers_to_dataframe(simulation_rpacket_tracking_enabled): sim = simulation_rpacket_tracking_enabled - rtracker_df = rpacket_trackers_to_dataframe(sim.transport.rpacket_tracker) + transport_state = sim.transport.transport_state + rtracker_df = rpacket_trackers_to_dataframe(transport_state.rpacket_tracker) # check df shape and column names assert rtracker_df.shape == ( - sum([len(tracker.r) for tracker in sim.transport.rpacket_tracker]), + sum([len(tracker.r) for tracker in transport_state.rpacket_tracker]), 8, ) npt.assert_array_equal( - sim.transport.rpacket_tracker_df.columns.values, + transport_state.rpacket_tracker_df.columns.values, np.array( [ "status", @@ -59,7 +59,7 @@ def test_rpacket_trackers_to_dataframe(simulation_rpacket_tracking_enabled): # check all data with rpacket_tracker expected_rtrackers = [] - for rpacket in sim.transport.rpacket_tracker: + for rpacket in transport_state.rpacket_tracker: for rpacket_step_no in range(len(rpacket.r)): expected_rtrackers.append( [ diff --git a/tardis/montecarlo/montecarlo_numba/vpacket.py b/tardis/montecarlo/montecarlo_numba/vpacket.py index 76f8c25deac..627cbd75e64 100644 --- a/tardis/montecarlo/montecarlo_numba/vpacket.py +++ b/tardis/montecarlo/montecarlo_numba/vpacket.py @@ -115,7 +115,7 @@ def trace_vpacket_within_shell( else: chi_continuum = chi_e - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: chi_continuum *= doppler_factor tau_continuum = chi_continuum * distance_boundary @@ -189,15 +189,15 @@ def trace_vpacket( v_packet, delta_shell, len(numba_radial_1d_geometry.r_inner) ) - if tau_trace_combined > montecarlo_configuration.tau_russian: + if tau_trace_combined > montecarlo_configuration.VPACKET_TAU_RUSSIAN: event_random = np.random.random() - if event_random > montecarlo_configuration.survival_probability: + if event_random > montecarlo_configuration.SURVIVAL_PROBABILITY: v_packet.energy = 0.0 v_packet.status = PacketStatus.EMITTED else: v_packet.energy = ( v_packet.energy - / montecarlo_configuration.survival_probability + / montecarlo_configuration.SURVIVAL_PROBABILITY * math.exp(-tau_trace_combined) ) tau_trace_combined = 0.0 @@ -258,7 +258,7 @@ def trace_vpacket_volley( r_inner_over_r = numba_radial_1d_geometry.r_inner[0] / r_packet.r mu_min = -math.sqrt(1 - r_inner_over_r * r_inner_over_r) v_packet_on_inner_boundary = False - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: mu_min = angle_aberration_LF_to_CMF( r_packet, numba_model.time_explosion, mu_min ) @@ -266,7 +266,7 @@ def trace_vpacket_volley( v_packet_on_inner_boundary = True mu_min = 0.0 - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: inv_c = 1 / C_SPEED_OF_LIGHT inv_t = 1 / numba_model.time_explosion beta_inner = numba_radial_1d_geometry.r_inner[0] * inv_t * inv_c @@ -279,7 +279,7 @@ def trace_vpacket_volley( v_packet_mu = mu_min + i * mu_bin + np.random.random() * mu_bin if v_packet_on_inner_boundary: # The weights are described in K&S 2014 - if not montecarlo_configuration.full_relativity: + if not montecarlo_configuration.ENABLE_FULL_RELATIVITY: weight = 2 * v_packet_mu / no_of_vpackets else: weight = ( @@ -293,7 +293,7 @@ def trace_vpacket_volley( weight = (1 - mu_min) / (2 * no_of_vpackets) # C code: next line, angle_aberration_CMF_to_LF( & virt_packet, storage); - if montecarlo_configuration.full_relativity: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: v_packet_mu = angle_aberration_CMF_to_LF( r_packet, numba_model.time_explosion, v_packet_mu ) @@ -328,7 +328,7 @@ def trace_vpacket_volley( v_packet.energy *= math.exp(-tau_vpacket) - vpacket_collection.set_properties( + vpacket_collection.add_packet( v_packet.nu, v_packet.energy, v_packet_mu, diff --git a/tardis/montecarlo/montecarlo_state.py b/tardis/montecarlo/montecarlo_transport_state.py similarity index 70% rename from tardis/montecarlo/montecarlo_state.py rename to tardis/montecarlo/montecarlo_transport_state.py index 82f7c47d4ca..e6565d09aaa 100644 --- a/tardis/montecarlo/montecarlo_state.py +++ b/tardis/montecarlo/montecarlo_transport_state.py @@ -33,21 +33,47 @@ class MonteCarloTransportState(HDFWriterMixin): "spectrum_reabsorbed", "time_of_simulation", "emitted_packet_mask", + "last_interaction_type", + "last_interaction_in_nu", + "last_line_interaction_out_id", + "last_line_interaction_in_id", + "last_line_interaction_shell_id", + ] + + vpacket_hdf_properties = [ + "virt_packet_nus", + "virt_packet_energies", + "virt_packet_initial_rs", + "virt_packet_initial_mus", + "virt_packet_last_interaction_in_nu", + "virt_packet_last_interaction_type", + "virt_packet_last_line_interaction_in_id", + "virt_packet_last_line_interaction_out_id", + "virt_packet_last_line_interaction_shell_id", ] hdf_name = "transport_state" + last_interaction_type = None + last_interaction_in_nu = None + last_line_interaction_out_id = None + last_line_interaction_in_id = None + last_line_interaction_shell_id = None + + virt_logging = False + def __init__( self, packet_collection, estimators, - volume, spectrum_frequency, geometry_state, + opacity_state, + rpacket_tracker=None, + vpacket_tracker=None, ): self.packet_collection = packet_collection self.estimators = estimators - self.volume = volume self.spectrum_frequency = spectrum_frequency self._montecarlo_virtual_luminosity = u.Quantity( np.zeros_like(self.spectrum_frequency.value), "erg / s" @@ -57,6 +83,9 @@ def __init__( self._spectrum_integrated = None self.enable_full_relativity = False self.geometry_state = geometry_state + self.opacity_state = opacity_state + self.rpacket_tracker = rpacket_tracker + self.vpacket_tracker = vpacket_tracker def calculate_radiationfield_properties(self): """ @@ -274,12 +303,12 @@ def calculate_reabsorbed_luminosity( ].sum() @property - def virtual_packet_nu(self): + def virt_packet_nus(self): try: - return u.Quantity(self.virt_packet_nus, u.Hz) + return u.Quantity(self.vpacket_tracker.nus, u.Hz) except AttributeError: warnings.warn( - "MontecarloTransport.virtual_packet_nu:" + "MontecarloTransport.virt_packet_nus:" "Set 'virtual_packet_logging: True' in the configuration file" "to access this property" "It should be added under 'virtual' property of 'spectrum' property", @@ -288,12 +317,12 @@ def virtual_packet_nu(self): return None @property - def virtual_packet_energy(self): + def virt_packet_energies(self): try: - return u.Quantity(self.virt_packet_energies, u.erg) + return u.Quantity(self.vpacket_tracker.energies, u.erg) except AttributeError: warnings.warn( - "MontecarloTransport.virtual_packet_energy:" + "MontecarloTransport.virt_packet_energies:" "Set 'virtual_packet_logging: True' in the configuration file" "to access this property" "It should be added under 'virtual' property of 'spectrum' property", @@ -304,8 +333,9 @@ def virtual_packet_energy(self): @property def virtual_packet_luminosity(self): try: - return self.virtual_packet_energy / ( - self.packet_collection.time_of_simulation * u.s + return ( + self.virt_packet_energies + / self.packet_collection.time_of_simulation ) except TypeError: warnings.warn( @@ -318,19 +348,26 @@ def virtual_packet_luminosity(self): return None @property - def montecarlo_virtual_luminosity(self): - return ( - self._montecarlo_virtual_luminosity[:-1] - / self.packet_collection.time_of_simulation - ) + def virt_packet_initial_rs(self): + try: + return u.Quantity(self.vpacket_tracker.initial_rs, u.erg) + except AttributeError: + warnings.warn( + "MontecarloTransport.virt_packet_initial_rs:" + "Set 'virtual_packet_logging: True' in the configuration file" + "to access this property" + "It should be added under 'virtual' property of 'spectrum' property", + UserWarning, + ) + return None @property - def virtual_packet_nu(self): + def virt_packet_initial_mus(self): try: - return u.Quantity(self.virt_packet_nus, u.Hz) + return u.Quantity(self.vpacket_tracker.initial_mus, u.erg) except AttributeError: warnings.warn( - "MontecarloTransport.virtual_packet_nu:" + "MontecarloTransport.virt_packet_initial_mus:" "Set 'virtual_packet_logging: True' in the configuration file" "to access this property" "It should be added under 'virtual' property of 'spectrum' property", @@ -339,12 +376,14 @@ def virtual_packet_nu(self): return None @property - def virtual_packet_energy(self): + def virt_packet_last_interaction_in_nu(self): try: - return u.Quantity(self.virt_packet_energies, u.erg) + return u.Quantity( + self.vpacket_tracker.last_interaction_in_nu, u.erg + ) except AttributeError: warnings.warn( - "MontecarloTransport.virtual_packet_energy:" + "MontecarloTransport.virt_packet_last_interaction_in_nu:" "Set 'virtual_packet_logging: True' in the configuration file" "to access this property" "It should be added under 'virtual' property of 'spectrum' property", @@ -353,15 +392,60 @@ def virtual_packet_energy(self): return None @property - def virtual_packet_luminosity(self): + def virt_packet_last_interaction_type(self): try: - return ( - self.virtual_packet_energy - / self.packet_collection.time_of_simulation + return u.Quantity(self.vpacket_tracker.last_interaction_type, u.erg) + except AttributeError: + warnings.warn( + "MontecarloTransport.virt_packet_last_interaction_type:" + "Set 'virtual_packet_logging: True' in the configuration file" + "to access this property" + "It should be added under 'virtual' property of 'spectrum' property", + UserWarning, ) - except TypeError: + return None + + @property + def virt_packet_last_line_interaction_in_id(self): + try: + return u.Quantity( + self.vpacket_tracker.last_interaction_in_id, u.erg + ) + except AttributeError: warnings.warn( - "MontecarloTransport.virtual_packet_luminosity:" + "MontecarloTransport.virt_packet_last_line_interaction_in_id:" + "Set 'virtual_packet_logging: True' in the configuration file" + "to access this property" + "It should be added under 'virtual' property of 'spectrum' property", + UserWarning, + ) + return None + + @property + def virt_packet_last_line_interaction_out_id(self): + try: + return u.Quantity( + self.vpacket_tracker.last_interaction_out_id, u.erg + ) + except AttributeError: + warnings.warn( + "MontecarloTransport.virt_packet_last_line_interaction_out_id:" + "Set 'virtual_packet_logging: True' in the configuration file" + "to access this property" + "It should be added under 'virtual' property of 'spectrum' property", + UserWarning, + ) + return None + + @property + def virt_packet_last_line_interaction_shell_id(self): + try: + return u.Quantity( + self.vpacket_tracker.last_interaction_shell_id, u.erg + ) + except AttributeError: + warnings.warn( + "MontecarloTransport.virt_packet_last_line_interaction_shell_id:" "Set 'virtual_packet_logging: True' in the configuration file" "to access this property" "It should be added under 'virtual' property of 'spectrum' property", diff --git a/tardis/montecarlo/tests/test_base.py b/tardis/montecarlo/tests/test_base.py index d7cab27c2e8..9c7ec3bffbe 100644 --- a/tardis/montecarlo/tests/test_base.py +++ b/tardis/montecarlo/tests/test_base.py @@ -12,27 +12,24 @@ @pytest.fixture(scope="module", autouse=True) -def to_hdf_buffer(hdf_file_path, simulation_verysimple): - simulation_verysimple.transport.to_hdf( +def to_hdf_buffer(hdf_file_path, simulation_verysimple_vpacket_tracking): + simulation_verysimple_vpacket_tracking.transport.to_hdf( hdf_file_path, name="transport", overwrite=True ) - simulation_verysimple.transport.transport_state.to_hdf( + simulation_verysimple_vpacket_tracking.transport.transport_state.to_hdf( hdf_file_path, name="transport_state", overwrite=True ) -transport_properties = [ - "last_interaction_in_nu", - "last_interaction_type", - "last_line_interaction_in_id", - "last_line_interaction_out_id", - "last_line_interaction_shell_id", -] +transport_properties = [None] +@pytest.mark.xfail(reason="No HDF properties being written currently") @pytest.mark.parametrize("attr", transport_properties) -def test_hdf_transport(hdf_file_path, simulation_verysimple, attr): - actual = getattr(simulation_verysimple.transport, attr) +def test_hdf_transport( + hdf_file_path, simulation_verysimple_vpacket_tracking, attr +): + actual = getattr(simulation_verysimple_vpacket_tracking.transport, attr) if hasattr(actual, "cgs"): actual = actual.cgs.value path = f"transport/{attr}" @@ -47,12 +44,37 @@ def test_hdf_transport(hdf_file_path, simulation_verysimple, attr): "j_estimator", "montecarlo_virtual_luminosity", "packet_luminosity", + # These are nested properties that should be tested differently + # "spectrum", + # "spectrum_virtual", + # "spectrum_reabsorbed", + # This is a scalar and should be tested differently + # "time_of_simulation", + "emitted_packet_mask", + "last_interaction_type", + "last_interaction_in_nu", + "last_line_interaction_out_id", + "last_line_interaction_in_id", + "last_line_interaction_shell_id", + "virt_packet_nus", + "virt_packet_energies", + "virt_packet_initial_rs", + "virt_packet_initial_mus", + "virt_packet_last_interaction_in_nu", + "virt_packet_last_interaction_type", + "virt_packet_last_line_interaction_in_id", + "virt_packet_last_line_interaction_out_id", + "virt_packet_last_line_interaction_shell_id", ] @pytest.mark.parametrize("attr", transport_state_properties) -def test_hdf_transport_state(hdf_file_path, simulation_verysimple, attr): - actual = getattr(simulation_verysimple.transport.transport_state, attr) +def test_hdf_transport_state( + hdf_file_path, simulation_verysimple_vpacket_tracking, attr +): + actual = getattr( + simulation_verysimple_vpacket_tracking.transport.transport_state, attr + ) if hasattr(actual, "cgs"): actual = actual.cgs.value path = f"transport_state/{attr}" diff --git a/tardis/montecarlo/tests/test_montecarlo.py b/tardis/montecarlo/tests/test_montecarlo.py index c3cc904b6b3..5aa5f4eedc5 100644 --- a/tardis/montecarlo/tests/test_montecarlo.py +++ b/tardis/montecarlo/tests/test_montecarlo.py @@ -459,7 +459,7 @@ def test_move_packet(packet_params, expected_params, full_relativity): packet.energy = packet_params["energy"] packet.r = packet_params["r"] # model.full_relativity = full_relativity - mc.full_relativity = full_relativity + mc.ENABLE_FULL_RELATIVITY = full_relativity doppler_factor = get_doppler_factor(packet.r, packet.mu, time_explosion) numba_estimator = Estimators( @@ -478,7 +478,7 @@ def test_move_packet(packet_params, expected_params, full_relativity): expected_j *= doppler_factor expected_nubar *= doppler_factor - mc.full_relativity = False + mc.ENABLE_FULL_RELATIVITY = False assert_allclose( numba_estimator.j_estimator[packet.current_shell_id], @@ -564,8 +564,8 @@ def test_move_packet(packet_params, expected_params, full_relativity): ) def test_frame_transformations(mu, r, inv_t_exp, full_relativity): packet = r_packet.RPacket(r=r, mu=mu, energy=0.9, nu=0.4) - mc.full_relativity = bool(full_relativity) - mc.full_relativity = full_relativity + mc.ENABLE_FULL_RELATIVITY = bool(full_relativity) + mc.ENABLE_FULL_RELATIVITY = full_relativity inverse_doppler_factor = r_packet.get_inverse_doppler_factor( r, mu, 1 / inv_t_exp @@ -573,7 +573,7 @@ def test_frame_transformations(mu, r, inv_t_exp, full_relativity): r_packet.angle_aberration_CMF_to_LF(packet, 1 / inv_t_exp, packet.mu) doppler_factor = get_doppler_factor(r, mu, 1 / inv_t_exp) - mc.full_relativity = False + mc.ENABLE_FULL_RELATIVITY = False assert_almost_equal(doppler_factor * inverse_doppler_factor, 1.0) @@ -591,12 +591,12 @@ def test_frame_transformations(mu, r, inv_t_exp, full_relativity): ) def test_angle_transformation_invariance(mu, r, inv_t_exp): packet = r_packet.RPacket(r, mu, 0.4, 0.9) - mc.full_relativity = True + mc.ENABLE_FULL_RELATIVITY = True mu1 = angle_aberration_CMF_to_LF(packet, 1 / inv_t_exp, mu) mu_obtained = angle_aberration_LF_to_CMF(packet, 1 / inv_t_exp, mu1) - mc.full_relativity = False + mc.ENABLE_FULL_RELATIVITY = False assert_almost_equal(mu_obtained, mu) @@ -624,7 +624,7 @@ def test_compute_distance2line_relativistic( transport.j_blue_estimator, transport.Edotlu_estimator, ) - mc.full_relativity = bool(full_relativity) + mc.ENABLE_FULL_RELATIVITY = bool(full_relativity) doppler_factor = get_doppler_factor(r, mu, t_exp) comov_nu = packet.nu * doppler_factor @@ -635,7 +635,7 @@ def test_compute_distance2line_relativistic( doppler_factor = get_doppler_factor(r, mu, t_exp) comov_nu = packet.nu * doppler_factor - mc.full_relativity = False + mc.ENABLE_FULL_RELATIVITY = False assert_allclose(comov_nu, nu_line, rtol=1e-14) diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 2cc940874de..43ab0be8345 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -15,11 +15,10 @@ from tardis.io.util import HDFWriterMixin from tardis.model import SimulationState from tardis.model.parse_input import initialize_packet_source -from tardis.montecarlo import montecarlo_configuration as mc_config_module -from tardis.montecarlo.base import MonteCarloTransportSolver -from tardis.montecarlo.montecarlo_numba.r_packet import ( - rpacket_trackers_to_dataframe, +from tardis.montecarlo import ( + montecarlo_configuration as montecarlo_configuration, ) +from tardis.montecarlo.base import MonteCarloTransportSolver from tardis.plasma.standard_plasmas import assemble_plasma from tardis.util.base import is_notebook from tardis.visualization import ConvergencePlots @@ -193,7 +192,7 @@ def __init__( self._callbacks = OrderedDict() self._cb_next_id = 0 - mc_config_module.CONTINUUM_PROCESSES_ENABLED = ( + montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED = ( not self.plasma.continuum_interaction_species.empty ) @@ -392,15 +391,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): logger.info( f"\n\tStarting iteration {(self.iterations_executed + 1):d} of {self.iterations:d}" ) - self.transport.run( + + transport_state = self.transport.initialize_transport_state( self.simulation_state, self.plasma, no_of_packets, no_of_virtual_packets=no_of_virtual_packets, iteration=self.iterations_executed, + ) + + self.transport.run( + transport_state, + time_explosion=self.simulation_state.time_explosion, + iteration=self.iterations_executed, total_iterations=self.iterations, show_progress_bars=self.show_progress_bars, ) + output_energy = ( self.transport.transport_state.packet_collection.output_energies ) @@ -489,11 +496,6 @@ def run_final(self): last=True, ) - if self.transport.rpacket_tracker: - self.transport.rpacket_tracker_df = rpacket_trackers_to_dataframe( - self.transport.rpacket_tracker - ) - self._call_back() def log_plasma_state( diff --git a/tardis/tests/test_tardis_full.py b/tardis/tests/test_tardis_full.py index 55a57370707..5ababcbb6bc 100644 --- a/tardis/tests/test_tardis_full.py +++ b/tardis/tests/test_tardis_full.py @@ -97,35 +97,3 @@ def test_virtual_spectrum(self, transport, refdata): assert_quantity_allclose( transport.transport_state.spectrum_virtual.luminosity, luminosity ) - - def test_transport_properties(self, transport): - """ - Tests whether a number of transport attributes exist and also verifies - their types - - Currently, transport attributes needed to call the model routine to_hdf5 - are checked. - """ - - virt_type = np.ndarray - - props_required_by_modeltohdf5 = dict( - [ - ("virt_packet_last_interaction_type", virt_type), - ("virt_packet_last_line_interaction_in_id", virt_type), - ("virt_packet_last_line_interaction_out_id", virt_type), - ("virt_packet_last_line_interaction_shell_id", virt_type), - ("virt_packet_last_interaction_in_nu", virt_type), - ("virt_packet_nus", virt_type), - ("virt_packet_energies", virt_type), - ] - ) - - required_props = props_required_by_modeltohdf5.copy() - - for prop, prop_type in required_props.items(): - actual = getattr(transport, prop) - assert type(actual) == prop_type, ( - f"wrong type of attribute '{prop}':" - f"expected {prop_type}, found {type(actual)}" - ) diff --git a/tardis/transport/frame_transformations.py b/tardis/transport/frame_transformations.py index e6245b278a2..07bb9dc5b14 100644 --- a/tardis/transport/frame_transformations.py +++ b/tardis/transport/frame_transformations.py @@ -6,7 +6,7 @@ njit_dict_no_parallel, ) -from tardis.montecarlo.montecarlo_numba import numba_config as nc +from tardis.montecarlo import montecarlo_configuration as nc from tardis.montecarlo.montecarlo_numba.numba_config import C_SPEED_OF_LIGHT diff --git a/tardis/transport/geometry/calculate_distances.py b/tardis/transport/geometry/calculate_distances.py index 44f35a506e9..204b42da3f2 100644 --- a/tardis/transport/geometry/calculate_distances.py +++ b/tardis/transport/geometry/calculate_distances.py @@ -6,7 +6,7 @@ njit_dict_no_parallel, ) -import tardis.montecarlo.montecarlo_numba.numba_config as nc +import tardis.montecarlo.montecarlo_configuration as nc from tardis.montecarlo.montecarlo_numba.numba_config import ( C_SPEED_OF_LIGHT, MISS_DISTANCE, diff --git a/tardis/transport/r_packet_transport.py b/tardis/transport/r_packet_transport.py index 19763c67edb..2520733c14d 100644 --- a/tardis/transport/r_packet_transport.py +++ b/tardis/transport/r_packet_transport.py @@ -15,7 +15,6 @@ from tardis.transport.frame_transformations import ( get_doppler_factor, ) -import tardis.montecarlo.montecarlo_numba.numba_config as nc from tardis.montecarlo.montecarlo_numba.opacities import calculate_tau_electron from tardis.montecarlo.montecarlo_numba.r_packet import ( InteractionType, @@ -136,7 +135,7 @@ def trace_packet( if ( tau_trace_combined > tau_event - and not montecarlo_configuration.disable_line_scattering + and not montecarlo_configuration.DISABLE_LINE_SCATTERING ): interaction_type = InteractionType.LINE # Line r_packet.last_interaction_in_nu = r_packet.nu @@ -209,7 +208,7 @@ def move_r_packet(r_packet, distance, time_explosion, numba_estimator): comov_energy = r_packet.energy * doppler_factor # Account for length contraction - if nc.ENABLE_FULL_RELATIVITY: + if montecarlo_configuration.ENABLE_FULL_RELATIVITY: distance *= doppler_factor update_base_estimators( diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index c16c13a63a4..8490a58b990 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -173,14 +173,16 @@ def from_simulation(cls, sim, packets_mode): if packets_mode == "virtual": return cls( - last_interaction_type=sim.transport.virt_packet_last_interaction_type, - last_line_interaction_in_id=sim.transport.virt_packet_last_line_interaction_in_id, - last_line_interaction_out_id=sim.transport.virt_packet_last_line_interaction_out_id, - last_line_interaction_in_nu=sim.transport.virt_packet_last_interaction_in_nu, + last_interaction_type=transport_state.vpacket_tracker.last_interaction_type, + last_line_interaction_in_id=transport_state.vpacket_tracker.last_interaction_in_id, + last_line_interaction_out_id=transport_state.vpacket_tracker.last_interaction_out_id, + last_line_interaction_in_nu=transport_state.vpacket_tracker.last_interaction_in_nu, lines_df=lines_df, - packet_nus=u.Quantity(sim.transport.virt_packet_nus, "Hz"), + packet_nus=u.Quantity( + transport_state.vpacket_tracker.nus, "Hz" + ), packet_energies=u.Quantity( - sim.transport.virt_packet_energies, "erg" + transport_state.vpacket_tracker.energies, "erg" ), r_inner=r_inner, spectrum_delta_frequency=transport_state.spectrum_virtual.delta_frequency, @@ -194,17 +196,18 @@ def from_simulation(cls, sim, packets_mode): elif packets_mode == "real": # Packets-specific properties need to be only for those packets # which got emitted + transport_state = sim.transport.transport_state return cls( - last_interaction_type=sim.transport.last_interaction_type[ + last_interaction_type=transport_state.last_interaction_type[ transport_state.emitted_packet_mask ], - last_line_interaction_in_id=sim.transport.last_line_interaction_in_id[ + last_line_interaction_in_id=transport_state.last_line_interaction_in_id[ transport_state.emitted_packet_mask ], - last_line_interaction_out_id=sim.transport.last_line_interaction_out_id[ + last_line_interaction_out_id=transport_state.last_line_interaction_out_id[ transport_state.emitted_packet_mask ], - last_line_interaction_in_nu=sim.transport.last_interaction_in_nu[ + last_line_interaction_in_nu=transport_state.last_interaction_in_nu[ transport_state.emitted_packet_mask ], lines_df=lines_df, @@ -263,53 +266,55 @@ def from_hdf(cls, hdf_fpath, packets_mode): if packets_mode == "virtual": return cls( last_interaction_type=hdf[ - "/simulation/transport/virt_packet_last_interaction_type" + "/simulation/transport/transport_state/virt_packet_last_interaction_type" ], last_line_interaction_in_id=hdf[ - "/simulation/transport/virt_packet_last_line_interaction_in_id" + "/simulation/transport/transport_state/virt_packet_last_line_interaction_in_id" ], last_line_interaction_out_id=hdf[ - "/simulation/transport/virt_packet_last_line_interaction_out_id" + "/simulation/transport/transport_state/virt_packet_last_line_interaction_out_id" ], last_line_interaction_in_nu=u.Quantity( hdf[ - "/simulation/transport/virt_packet_last_interaction_in_nu" + "/simulation/transport/transport_state/virt_packet_last_interaction_in_nu" ].to_numpy(), "Hz", ), lines_df=lines_df, packet_nus=u.Quantity( - hdf["/simulation/transport/virt_packet_nus"].to_numpy(), + hdf[ + "/simulation/transport/transport_state/virt_packet_nus" + ].to_numpy(), "Hz", ), packet_energies=u.Quantity( hdf[ - "/simulation/transport/virt_packet_energies" + "/simulation/transport/transport_state/virt_packet_energies" ].to_numpy(), "erg", ), r_inner=r_inner, spectrum_delta_frequency=u.Quantity( hdf[ - "/simulation/transport/spectrum_virtual/scalars" + "/simulation/transport/transport_state/spectrum_virtual/scalars" ].delta_frequency, "Hz", ), spectrum_frequency_bins=u.Quantity( hdf[ - "/simulation/transport/spectrum_virtual/_frequency" + "/simulation/transport/transport_state/spectrum_virtual/_frequency" ].to_numpy(), "Hz", ), spectrum_luminosity_density_lambda=u.Quantity( hdf[ - "/simulation/transport/spectrum_virtual/luminosity_density_lambda" + "/simulation/transport/transport_state/spectrum_virtual/luminosity_density_lambda" ].to_numpy(), "erg / s cm", # luminosity_density_lambda is saved in hdf in CGS ).to("erg / s AA"), spectrum_wavelength=u.Quantity( hdf[ - "/simulation/transport/spectrum_virtual/wavelength" + "/simulation/transport/transport_state/spectrum_virtual/wavelength" ].to_numpy(), "cm", # wavelength is saved in hdf in CGS ).to("AA"), @@ -319,61 +324,61 @@ def from_hdf(cls, hdf_fpath, packets_mode): elif packets_mode == "real": emitted_packet_mask = hdf[ - "/simulation/transport/emitted_packet_mask" + "/simulation/transport/transport_state/emitted_packet_mask" ].to_numpy() return cls( # First convert series read from hdf to array before masking # to eliminate index info which creates problems otherwise last_interaction_type=hdf[ - "/simulation/transport/last_interaction_type" + "/simulation/transport/transport_state/last_interaction_type" ].to_numpy()[emitted_packet_mask], last_line_interaction_in_id=hdf[ - "/simulation/transport/last_line_interaction_in_id" + "/simulation/transport/transport_state/last_line_interaction_in_id" ].to_numpy()[emitted_packet_mask], last_line_interaction_out_id=hdf[ - "/simulation/transport/last_line_interaction_out_id" + "/simulation/transport/transport_state/last_line_interaction_out_id" ].to_numpy()[emitted_packet_mask], last_line_interaction_in_nu=u.Quantity( hdf[ - "/simulation/transport/last_interaction_in_nu" + "/simulation/transport/transport_state/last_interaction_in_nu" ].to_numpy()[emitted_packet_mask], "Hz", ), lines_df=lines_df, packet_nus=u.Quantity( - hdf["/simulation/transport/output_nu"].to_numpy()[ - emitted_packet_mask - ], + hdf[ + "/simulation/transport/transport_state/output_nu" + ].to_numpy()[emitted_packet_mask], "Hz", ), packet_energies=u.Quantity( - hdf["/simulation/transport/output_energy"].to_numpy()[ - emitted_packet_mask - ], + hdf[ + "/simulation/transport/transport_state/output_energy" + ].to_numpy()[emitted_packet_mask], "erg", ), r_inner=r_inner, spectrum_delta_frequency=u.Quantity( hdf[ - "/simulation/transport/spectrum/scalars" + "/simulation/transport/transport_state/spectrum/scalars" ].delta_frequency, "Hz", ), spectrum_frequency_bins=u.Quantity( hdf[ - "/simulation/transport/spectrum/_frequency" + "/simulation/transport/transport_state/spectrum/_frequency" ].to_numpy(), "Hz", ), spectrum_luminosity_density_lambda=u.Quantity( hdf[ - "/simulation/transport/spectrum/luminosity_density_lambda" + "/simulation/transport/transport_state/spectrum/luminosity_density_lambda" ].to_numpy(), "erg / s cm", ).to("erg / s AA"), spectrum_wavelength=u.Quantity( hdf[ - "/simulation/transport/spectrum/wavelength" + "/simulation/transport/transport_state/spectrum/wavelength" ].to_numpy(), "cm", ).to("AA"), @@ -421,7 +426,7 @@ def from_simulation(cls, sim): ------- SDECPlotter """ - if sim.transport.virt_logging: + if sim.transport.transport_state.virt_logging: return cls( { "virtual": SDECData.from_simulation(sim, "virtual"), diff --git a/tardis/visualization/tools/tests/test_sdec_plot.py b/tardis/visualization/tools/tests/test_sdec_plot.py index 975d5aa03fb..a92454a89fc 100644 --- a/tardis/visualization/tools/tests/test_sdec_plot.py +++ b/tardis/visualization/tools/tests/test_sdec_plot.py @@ -1,16 +1,17 @@ """Tests for SDEC Plots.""" -from tardis.base import run_tardis -import pytest -import pandas as pd -import numpy as np import os from copy import deepcopy -from tardis.visualization.tools.sdec_plot import SDECData, SDECPlotter + import astropy.units as u +import numpy as np +import pandas as pd +import pytest +import tables from matplotlib.collections import PolyCollection from matplotlib.lines import Line2D -import tables -import re + +from tardis.base import run_tardis +from tardis.visualization.tools.sdec_plot import SDECPlotter def make_valid_name(testid): @@ -158,6 +159,7 @@ def test_parse_species_list(self, request, plotter, species): plotter : tardis.visualization.tools.sdec_plot.SDECPlotter species : list """ + # THIS NEEDS TO BE RUN FIRST. NOT INDEPENDENT TESTS plotter._parse_species_list(species) subgroup_name = make_valid_name(request.node.callspec.id) if request.config.getoption("--generate-reference"):