Skip to content

Commit

Permalink
Added functionality to track properties for r_packets, configured fro…
Browse files Browse the repository at this point in the history
…m YAML
  • Loading branch information
DhruvSondhi committed Jul 26, 2021
1 parent 8903c13 commit b6c989d
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 38 deletions.
9 changes: 9 additions & 0 deletions tardis/io/schemas/montecarlo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ properties:
default: false
description: Enables a more complete treatment of relativitic effects. This includes
angle aberration as well as use of the fully general Doppler formula.
montecarlo_tracking:
type: object
default: {}
properties:
r_packet_tracking:
type: boolean
default: false
description: Allows for Tracking the properties of the RPackets in Single Packet Loop
description: Sets up tracking for Montecarlo
debug_packets:
type: boolean
default: false
Expand Down
6 changes: 6 additions & 0 deletions tardis/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
debug_packets=False,
logger_buffer=1,
single_packet_seed=None,
track_r_packet=False,
):

self.seed = seed
Expand Down Expand Up @@ -134,10 +135,14 @@ def __init__(
self.virt_packet_initial_rs = np.ones(2) * -1.0
self.virt_packet_initial_mus = np.ones(2) * -1.0

self.r_packet_tracking = np.zeros(1)

# set up logger based on config
mc_logger.DEBUG_MODE = debug_packets
mc_logger.BUFFER = logger_buffer

mc_config_module.RPACKET_TRACKING = track_r_packet

if self.spectrum_method == "integrated":
self.optional_hdf_properties.append("spectrum_integrated")

Expand Down Expand Up @@ -610,4 +615,5 @@ def from_config(
config.spectrum.virtual.virtual_packet_logging
| virtual_packet_logging
),
track_r_packet=config.montecarlo.montecarlo_tracking.r_packet_tracking,
)
1 change: 1 addition & 0 deletions tardis/montecarlo/montecarlo_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
tau_russian = 10.0
LEGACY_MODE_ENABLED = False
VPACKET_LOGGING = False
RPACKET_TRACKING = False
59 changes: 46 additions & 13 deletions tardis/montecarlo/montecarlo_numba/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tardis.montecarlo.montecarlo_numba.numba_interface import (
PacketCollection,
VPacketCollection,
RPacketCollection,
NumbaModel,
numba_plasma_initialize,
Estimators,
Expand Down Expand Up @@ -68,6 +69,7 @@ def montecarlo_radial1d(model, plasma, runner):
virt_packet_last_interaction_type,
virt_packet_last_line_interaction_in_id,
virt_packet_last_line_interaction_out_id,
r_packet_tracker,
) = montecarlo_main_loop(
packet_collection,
numba_model,
Expand Down Expand Up @@ -110,6 +112,9 @@ def montecarlo_radial1d(model, plasma, runner):
np.array(virt_packet_last_line_interaction_out_id)
).ravel()

# Condition for Checking if R Packet Tracking is enabled
runner.r_packet_tracking = r_packet_tracker


@njit(**njit_dict)
def montecarlo_main_loop(
Expand Down Expand Up @@ -177,6 +182,11 @@ def montecarlo_main_loop(
virt_packet_last_line_interaction_in_id = []
virt_packet_last_line_interaction_out_id = []

# Tracking for R_Packet
r_packet_tracker = List()
for i in range(len(output_nus)):
r_packet_tracker.append(RPacketCollection())

for i in prange(len(output_nus)):
if montecarlo_configuration.single_packet_seed != -1:
seed = packet_seeds[montecarlo_configuration.single_packet_seed]
Expand All @@ -193,9 +203,15 @@ def montecarlo_main_loop(
i,
)
vpacket_collection = vpacket_collections[i]
r_packet_track = r_packet_tracker[i]

loop = single_packet_loop(
r_packet, numba_model, numba_plasma, estimators, vpacket_collection
single_packet_loop(
r_packet,
numba_model,
numba_plasma,
estimators,
vpacket_collection,
r_packet_track,
)
# if loop and 'stop' in loop:
# raise MonteCarloException
Expand All @@ -214,8 +230,12 @@ def montecarlo_main_loop(

vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx]
vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx]
vpackets_initial_mu = vpacket_collection.initial_mus[: vpacket_collection.idx]
vpackets_initial_r = vpacket_collection.initial_rs[: vpacket_collection.idx]
vpackets_initial_mu = vpacket_collection.initial_mus[
: vpacket_collection.idx
]
vpackets_initial_r = vpacket_collection.initial_rs[
: vpacket_collection.idx
]

v_packets_idx = np.floor(
(vpackets_nu - spectrum_frequency[0]) / delta_nu
Expand All @@ -233,17 +253,29 @@ def montecarlo_main_loop(
if montecarlo_configuration.VPACKET_LOGGING:
for vpacket_collection in vpacket_collections:
vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx]
vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx]
vpackets_initial_mu = vpacket_collection.initial_mus[: vpacket_collection.idx]
vpackets_initial_r = vpacket_collection.initial_rs[: vpacket_collection.idx]
vpackets_energy = vpacket_collection.energies[
: vpacket_collection.idx
]
vpackets_initial_mu = vpacket_collection.initial_mus[
: vpacket_collection.idx
]
vpackets_initial_r = vpacket_collection.initial_rs[
: vpacket_collection.idx
]
virt_packet_nus.append(np.ascontiguousarray(vpackets_nu))
virt_packet_energies.append(np.ascontiguousarray(vpackets_energy))
virt_packet_initial_mus.append(np.ascontiguousarray(vpackets_initial_mu))
virt_packet_initial_rs.append(np.ascontiguousarray(vpackets_initial_r))
virt_packet_last_interaction_in_nu.append(np.ascontiguousarray(
vpacket_collection.last_interaction_in_nu[
: vpacket_collection.idx
])
virt_packet_initial_mus.append(
np.ascontiguousarray(vpackets_initial_mu)
)
virt_packet_initial_rs.append(
np.ascontiguousarray(vpackets_initial_r)
)
virt_packet_last_interaction_in_nu.append(
np.ascontiguousarray(
vpacket_collection.last_interaction_in_nu[
: vpacket_collection.idx
]
)
)
virt_packet_last_interaction_type.append(
np.ascontiguousarray(
Expand Down Expand Up @@ -284,4 +316,5 @@ def montecarlo_main_loop(
virt_packet_last_interaction_type,
virt_packet_last_line_interaction_in_id,
virt_packet_last_line_interaction_out_id,
r_packet_tracker,
)
82 changes: 82 additions & 0 deletions tardis/montecarlo/montecarlo_numba/numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,88 @@ def set_properties(
self.idx += 1


rpacket_collection_spec = [
("seed", int64[:]),
("index", int64[:]),
("status", int64[:]),
("r", float64[:]),
("nu", float64[:]),
("mu", float64[:]),
("energy", float64[:]),
("current_shell_id", int64[:]),
("distance", float64[:]),
("last_interaction_type", int64[:]),
("last_interaction_in_nu", float64[:]),
("last_line_interaction_in_id", int64[:]),
("last_line_interaction_out_id", int64[:]),
]


@jitclass(rpacket_collection_spec)
class RPacketCollection(object):
def __init__(self):
self.seed = np.zeros(1, dtype=np.int64)
self.index = np.zeros(1, dtype=np.int64)
self.status = np.zeros(1, dtype=np.int64)
self.r = np.zeros(1, dtype=np.float64)
self.nu = np.zeros(1, dtype=np.float64)
self.mu = np.zeros(1, dtype=np.float64)
self.energy = np.zeros(1, dtype=np.float64)
self.current_shell_id = np.zeros(1, dtype=np.int64)
self.distance = np.zeros(1, dtype=np.float64)
self.last_interaction_type = -1 * np.ones(1, dtype=np.int64)
self.last_interaction_in_nu = np.zeros(1, dtype=np.float64)
self.last_line_interaction_in_id = -1 * np.ones(1, dtype=np.int64)
self.last_line_interaction_out_id = -1 * np.ones(1, dtype=np.int64)

def set_properties(self, r_packet, distance):
self.seed = np.concatenate(
(self.seed, np.array([r_packet.seed]))
).ravel()
self.index = np.concatenate(
(self.index, np.array([r_packet.index]))
).ravel()
self.status = np.concatenate(
(self.status, np.array([r_packet.status]))
).ravel()
self.r = np.concatenate((self.r, np.array([r_packet.r]))).ravel()
self.nu = np.concatenate((self.nu, np.array([r_packet.nu]))).ravel()
self.mu = np.concatenate((self.mu, np.array([r_packet.mu]))).ravel()
self.energy = np.concatenate(
(self.energy, np.array([r_packet.energy]))
).ravel()
self.current_shell_id = np.concatenate(
(self.current_shell_id, np.array([r_packet.current_shell_id]))
).ravel()
self.distance = np.concatenate(
(self.distance, np.array([distance]))
).ravel()
self.last_interaction_type = np.concatenate(
(
self.last_interaction_type,
np.array([r_packet.last_interaction_type]),
)
).ravel()
self.last_interaction_in_nu = np.concatenate(
(
self.last_interaction_in_nu,
np.array([r_packet.last_interaction_in_nu]),
)
).ravel()
self.last_line_interaction_in_id = np.concatenate(
(
self.last_line_interaction_in_id,
np.array([r_packet.last_line_interaction_in_id]),
)
).ravel()
self.last_line_interaction_out_id = np.concatenate(
(
self.last_line_interaction_out_id,
np.array([r_packet.last_line_interaction_out_id]),
)
).ravel()


estimators_spec = [
("j_estimator", float64[:]),
("nu_bar_estimator", float64[:]),
Expand Down
5 changes: 5 additions & 0 deletions tardis/montecarlo/montecarlo_numba/r_packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,8 @@ def move_packet_across_shell_boundary(packet, delta_shell, no_of_shells):
packet.status = PacketStatus.REABSORBED
else:
packet.current_shell_id = next_shell_id


@njit(**njit_dict_no_parallel)
def track_r_packet(r_packet, r_packet_track, distance):
r_packet_track.set_properties(r_packet, distance)
43 changes: 18 additions & 25 deletions tardis/montecarlo/montecarlo_numba/single_packet_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
trace_packet,
move_packet_across_shell_boundary,
move_r_packet,
track_r_packet,
)

from tardis.montecarlo.montecarlo_numba.utils import MonteCarloException
Expand All @@ -31,13 +32,15 @@

C_SPEED_OF_LIGHT = const.c.to("cm/s").value

from tardis.io.logger.montecarlo_logger import log_decorator
from tardis.io.logger import montecarlo_logger as mc_logger

# @log_decorator
@njit
def single_packet_loop(
r_packet, numba_model, numba_plasma, estimators, vpacket_collection
r_packet,
numba_model,
numba_plasma,
estimators,
vpacket_collection,
r_packet_track,
):
"""
Parameters
Expand Down Expand Up @@ -67,12 +70,8 @@ def single_packet_loop(
r_packet, vpacket_collection, numba_model, numba_plasma
)

if mc_logger.DEBUG_MODE:
r_packet_track_nu = [r_packet.nu]
r_packet_track_mu = [r_packet.mu]
r_packet_track_r = [r_packet.r]
r_packet_track_interaction = [InteractionType.BOUNDARY]
r_packet_track_distance = [0.0]
if montecarlo_configuration.RPACKET_TRACKING:
track_r_packet(r_packet, r_packet_track, distance=0)

while r_packet.status == PacketStatus.IN_PROCESS:
distance, interaction_type, delta_shell = trace_packet(
Expand All @@ -87,6 +86,9 @@ def single_packet_loop(
r_packet, delta_shell, len(numba_model.r_inner)
)

if montecarlo_configuration.RPACKET_TRACKING:
track_r_packet(r_packet, r_packet_track, distance)

elif interaction_type == InteractionType.LINE:
r_packet.last_interaction_type = 2

Expand All @@ -103,6 +105,9 @@ def single_packet_loop(
r_packet, vpacket_collection, numba_model, numba_plasma
)

if montecarlo_configuration.RPACKET_TRACKING:
track_r_packet(r_packet, r_packet_track, distance)

elif interaction_type == InteractionType.ESCATTERING:
r_packet.last_interaction_type = 1

Expand All @@ -114,21 +119,9 @@ def single_packet_loop(
trace_vpacket_volley(
r_packet, vpacket_collection, numba_model, numba_plasma
)
if mc_logger.DEBUG_MODE:
r_packet_track_nu.append(r_packet.nu)
r_packet_track_mu.append(r_packet.mu)
r_packet_track_r.append(r_packet.r)
r_packet_track_interaction.append(interaction_type)
r_packet_track_distance.append(distance)

if mc_logger.DEBUG_MODE:
return (
r_packet_track_nu,
r_packet_track_mu,
r_packet_track_r,
r_packet_track_interaction,
r_packet_track_distance,
)

if montecarlo_configuration.RPACKET_TRACKING:
track_r_packet(r_packet, r_packet_track, distance)

# check where else initialize line ID happens!

Expand Down

0 comments on commit b6c989d

Please sign in to comment.