diff --git a/tardis/io/schemas/montecarlo.yml b/tardis/io/schemas/montecarlo.yml index 0e6df3c7988..87e6773b5c3 100644 --- a/tardis/io/schemas/montecarlo.yml +++ b/tardis/io/schemas/montecarlo.yml @@ -64,6 +64,15 @@ properties: default: false description: Enables a more complete treatment of relativitic effects. This includes angle aberration as well as use of the fully general Doppler formula. + montecarlo_tracking: + type: object + default: {} + properties: + r_packet_tracking: + type: boolean + default: false + description: Allows for Tracking the properties of the RPackets in Single Packet Loop + description: Sets up tracking for Montecarlo debug_packets: type: boolean default: false diff --git a/tardis/montecarlo/base.py b/tardis/montecarlo/base.py index d1fa24e1f30..a09ac811d34 100644 --- a/tardis/montecarlo/base.py +++ b/tardis/montecarlo/base.py @@ -100,6 +100,7 @@ def __init__( debug_packets=False, logger_buffer=1, single_packet_seed=None, + track_r_packet=False, ): self.seed = seed @@ -134,10 +135,14 @@ def __init__( self.virt_packet_initial_rs = np.ones(2) * -1.0 self.virt_packet_initial_mus = np.ones(2) * -1.0 + self.r_packet_tracking = np.zeros(1) + # set up logger based on config mc_logger.DEBUG_MODE = debug_packets mc_logger.BUFFER = logger_buffer + mc_config_module.RPACKET_TRACKING = track_r_packet + if self.spectrum_method == "integrated": self.optional_hdf_properties.append("spectrum_integrated") @@ -610,4 +615,5 @@ def from_config( config.spectrum.virtual.virtual_packet_logging | virtual_packet_logging ), + track_r_packet=config.montecarlo.montecarlo_tracking.r_packet_tracking, ) diff --git a/tardis/montecarlo/montecarlo_configuration.py b/tardis/montecarlo/montecarlo_configuration.py index 438d9b15f19..00718dd11bb 100644 --- a/tardis/montecarlo/montecarlo_configuration.py +++ b/tardis/montecarlo/montecarlo_configuration.py @@ -13,3 +13,4 @@ tau_russian = 10.0 LEGACY_MODE_ENABLED = False VPACKET_LOGGING = False +RPACKET_TRACKING = False diff --git a/tardis/montecarlo/montecarlo_numba/base.py b/tardis/montecarlo/montecarlo_numba/base.py index ba107d6292a..c921f45c99f 100644 --- a/tardis/montecarlo/montecarlo_numba/base.py +++ b/tardis/montecarlo/montecarlo_numba/base.py @@ -11,6 +11,7 @@ from tardis.montecarlo.montecarlo_numba.numba_interface import ( PacketCollection, VPacketCollection, + RPacketCollection, NumbaModel, numba_plasma_initialize, Estimators, @@ -68,6 +69,7 @@ def montecarlo_radial1d(model, plasma, runner): virt_packet_last_interaction_type, virt_packet_last_line_interaction_in_id, virt_packet_last_line_interaction_out_id, + r_packet_tracker, ) = montecarlo_main_loop( packet_collection, numba_model, @@ -110,6 +112,9 @@ def montecarlo_radial1d(model, plasma, runner): np.array(virt_packet_last_line_interaction_out_id) ).ravel() + # Condition for Checking if R Packet Tracking is enabled + runner.r_packet_tracking = r_packet_tracker + @njit(**njit_dict) def montecarlo_main_loop( @@ -177,6 +182,11 @@ def montecarlo_main_loop( virt_packet_last_line_interaction_in_id = [] virt_packet_last_line_interaction_out_id = [] + # Tracking for R_Packet + r_packet_tracker = List() + for i in range(len(output_nus)): + r_packet_tracker.append(RPacketCollection()) + for i in prange(len(output_nus)): if montecarlo_configuration.single_packet_seed != -1: seed = packet_seeds[montecarlo_configuration.single_packet_seed] @@ -193,9 +203,15 @@ def montecarlo_main_loop( i, ) vpacket_collection = vpacket_collections[i] + r_packet_track = r_packet_tracker[i] - loop = single_packet_loop( - r_packet, numba_model, numba_plasma, estimators, vpacket_collection + single_packet_loop( + r_packet, + numba_model, + numba_plasma, + estimators, + vpacket_collection, + r_packet_track, ) # if loop and 'stop' in loop: # raise MonteCarloException @@ -214,8 +230,12 @@ def montecarlo_main_loop( vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx] vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx] - vpackets_initial_mu = vpacket_collection.initial_mus[: vpacket_collection.idx] - vpackets_initial_r = vpacket_collection.initial_rs[: vpacket_collection.idx] + vpackets_initial_mu = vpacket_collection.initial_mus[ + : vpacket_collection.idx + ] + vpackets_initial_r = vpacket_collection.initial_rs[ + : vpacket_collection.idx + ] v_packets_idx = np.floor( (vpackets_nu - spectrum_frequency[0]) / delta_nu @@ -233,17 +253,29 @@ def montecarlo_main_loop( if montecarlo_configuration.VPACKET_LOGGING: for vpacket_collection in vpacket_collections: vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx] - vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx] - vpackets_initial_mu = vpacket_collection.initial_mus[: vpacket_collection.idx] - vpackets_initial_r = vpacket_collection.initial_rs[: vpacket_collection.idx] + vpackets_energy = vpacket_collection.energies[ + : vpacket_collection.idx + ] + vpackets_initial_mu = vpacket_collection.initial_mus[ + : vpacket_collection.idx + ] + vpackets_initial_r = vpacket_collection.initial_rs[ + : vpacket_collection.idx + ] virt_packet_nus.append(np.ascontiguousarray(vpackets_nu)) virt_packet_energies.append(np.ascontiguousarray(vpackets_energy)) - virt_packet_initial_mus.append(np.ascontiguousarray(vpackets_initial_mu)) - virt_packet_initial_rs.append(np.ascontiguousarray(vpackets_initial_r)) - virt_packet_last_interaction_in_nu.append(np.ascontiguousarray( - vpacket_collection.last_interaction_in_nu[ - : vpacket_collection.idx - ]) + virt_packet_initial_mus.append( + np.ascontiguousarray(vpackets_initial_mu) + ) + virt_packet_initial_rs.append( + np.ascontiguousarray(vpackets_initial_r) + ) + virt_packet_last_interaction_in_nu.append( + np.ascontiguousarray( + vpacket_collection.last_interaction_in_nu[ + : vpacket_collection.idx + ] + ) ) virt_packet_last_interaction_type.append( np.ascontiguousarray( @@ -284,4 +316,5 @@ def montecarlo_main_loop( virt_packet_last_interaction_type, virt_packet_last_line_interaction_in_id, virt_packet_last_line_interaction_out_id, + r_packet_tracker, ) diff --git a/tardis/montecarlo/montecarlo_numba/numba_interface.py b/tardis/montecarlo/montecarlo_numba/numba_interface.py index 8a3bbf32622..670821ee580 100644 --- a/tardis/montecarlo/montecarlo_numba/numba_interface.py +++ b/tardis/montecarlo/montecarlo_numba/numba_interface.py @@ -304,6 +304,88 @@ def set_properties( self.idx += 1 +rpacket_collection_spec = [ + ("seed", int64[:]), + ("index", int64[:]), + ("status", int64[:]), + ("r", float64[:]), + ("nu", float64[:]), + ("mu", float64[:]), + ("energy", float64[:]), + ("current_shell_id", int64[:]), + ("distance", float64[:]), + ("last_interaction_type", int64[:]), + ("last_interaction_in_nu", float64[:]), + ("last_line_interaction_in_id", int64[:]), + ("last_line_interaction_out_id", int64[:]), +] + + +@jitclass(rpacket_collection_spec) +class RPacketCollection(object): + def __init__(self): + self.seed = np.zeros(1, dtype=np.int64) + self.index = np.zeros(1, dtype=np.int64) + self.status = np.zeros(1, dtype=np.int64) + self.r = np.zeros(1, dtype=np.float64) + self.nu = np.zeros(1, dtype=np.float64) + self.mu = np.zeros(1, dtype=np.float64) + self.energy = np.zeros(1, dtype=np.float64) + self.current_shell_id = np.zeros(1, dtype=np.int64) + self.distance = np.zeros(1, dtype=np.float64) + self.last_interaction_type = -1 * np.ones(1, dtype=np.int64) + self.last_interaction_in_nu = np.zeros(1, dtype=np.float64) + self.last_line_interaction_in_id = -1 * np.ones(1, dtype=np.int64) + self.last_line_interaction_out_id = -1 * np.ones(1, dtype=np.int64) + + def set_properties(self, r_packet, distance): + self.seed = np.concatenate( + (self.seed, np.array([r_packet.seed])) + ).ravel() + self.index = np.concatenate( + (self.index, np.array([r_packet.index])) + ).ravel() + self.status = np.concatenate( + (self.status, np.array([r_packet.status])) + ).ravel() + self.r = np.concatenate((self.r, np.array([r_packet.r]))).ravel() + self.nu = np.concatenate((self.nu, np.array([r_packet.nu]))).ravel() + self.mu = np.concatenate((self.mu, np.array([r_packet.mu]))).ravel() + self.energy = np.concatenate( + (self.energy, np.array([r_packet.energy])) + ).ravel() + self.current_shell_id = np.concatenate( + (self.current_shell_id, np.array([r_packet.current_shell_id])) + ).ravel() + self.distance = np.concatenate( + (self.distance, np.array([distance])) + ).ravel() + self.last_interaction_type = np.concatenate( + ( + self.last_interaction_type, + np.array([r_packet.last_interaction_type]), + ) + ).ravel() + self.last_interaction_in_nu = np.concatenate( + ( + self.last_interaction_in_nu, + np.array([r_packet.last_interaction_in_nu]), + ) + ).ravel() + self.last_line_interaction_in_id = np.concatenate( + ( + self.last_line_interaction_in_id, + np.array([r_packet.last_line_interaction_in_id]), + ) + ).ravel() + self.last_line_interaction_out_id = np.concatenate( + ( + self.last_line_interaction_out_id, + np.array([r_packet.last_line_interaction_out_id]), + ) + ).ravel() + + estimators_spec = [ ("j_estimator", float64[:]), ("nu_bar_estimator", float64[:]), diff --git a/tardis/montecarlo/montecarlo_numba/r_packet.py b/tardis/montecarlo/montecarlo_numba/r_packet.py index f1956df70b9..e838c41ca62 100644 --- a/tardis/montecarlo/montecarlo_numba/r_packet.py +++ b/tardis/montecarlo/montecarlo_numba/r_packet.py @@ -313,3 +313,8 @@ def move_packet_across_shell_boundary(packet, delta_shell, no_of_shells): packet.status = PacketStatus.REABSORBED else: packet.current_shell_id = next_shell_id + + +@njit(**njit_dict_no_parallel) +def track_r_packet(r_packet, r_packet_track, distance): + r_packet_track.set_properties(r_packet, distance) diff --git a/tardis/montecarlo/montecarlo_numba/single_packet_loop.py b/tardis/montecarlo/montecarlo_numba/single_packet_loop.py index c14321c8079..0b344194f52 100644 --- a/tardis/montecarlo/montecarlo_numba/single_packet_loop.py +++ b/tardis/montecarlo/montecarlo_numba/single_packet_loop.py @@ -6,6 +6,7 @@ trace_packet, move_packet_across_shell_boundary, move_r_packet, + track_r_packet, ) from tardis.montecarlo.montecarlo_numba.utils import MonteCarloException @@ -31,13 +32,15 @@ C_SPEED_OF_LIGHT = const.c.to("cm/s").value -from tardis.io.logger.montecarlo_logger import log_decorator -from tardis.io.logger import montecarlo_logger as mc_logger -# @log_decorator @njit def single_packet_loop( - r_packet, numba_model, numba_plasma, estimators, vpacket_collection + r_packet, + numba_model, + numba_plasma, + estimators, + vpacket_collection, + r_packet_track, ): """ Parameters @@ -67,12 +70,8 @@ def single_packet_loop( r_packet, vpacket_collection, numba_model, numba_plasma ) - if mc_logger.DEBUG_MODE: - r_packet_track_nu = [r_packet.nu] - r_packet_track_mu = [r_packet.mu] - r_packet_track_r = [r_packet.r] - r_packet_track_interaction = [InteractionType.BOUNDARY] - r_packet_track_distance = [0.0] + if montecarlo_configuration.RPACKET_TRACKING: + track_r_packet(r_packet, r_packet_track, distance=0) while r_packet.status == PacketStatus.IN_PROCESS: distance, interaction_type, delta_shell = trace_packet( @@ -87,6 +86,9 @@ def single_packet_loop( r_packet, delta_shell, len(numba_model.r_inner) ) + if montecarlo_configuration.RPACKET_TRACKING: + track_r_packet(r_packet, r_packet_track, distance) + elif interaction_type == InteractionType.LINE: r_packet.last_interaction_type = 2 @@ -103,6 +105,9 @@ def single_packet_loop( r_packet, vpacket_collection, numba_model, numba_plasma ) + if montecarlo_configuration.RPACKET_TRACKING: + track_r_packet(r_packet, r_packet_track, distance) + elif interaction_type == InteractionType.ESCATTERING: r_packet.last_interaction_type = 1 @@ -114,21 +119,9 @@ def single_packet_loop( trace_vpacket_volley( r_packet, vpacket_collection, numba_model, numba_plasma ) - if mc_logger.DEBUG_MODE: - r_packet_track_nu.append(r_packet.nu) - r_packet_track_mu.append(r_packet.mu) - r_packet_track_r.append(r_packet.r) - r_packet_track_interaction.append(interaction_type) - r_packet_track_distance.append(distance) - - if mc_logger.DEBUG_MODE: - return ( - r_packet_track_nu, - r_packet_track_mu, - r_packet_track_r, - r_packet_track_interaction, - r_packet_track_distance, - ) + + if montecarlo_configuration.RPACKET_TRACKING: + track_r_packet(r_packet, r_packet_track, distance) # check where else initialize line ID happens!