Skip to content

Commit

Permalink
Track boundary interaction (#2736)
Browse files Browse the repository at this point in the history
* Import Numba List

* Resolving conflicts

* Resolving conflicts

* Rebase

* Remove Unused import

* Add boundary interaction to packet_trackers

* Fix errors

* Add track_boundary functionality

* Add ENUM PacketEjectaStatus

* Add Enum

* Trigger Build

* Move OUTSIDE_EJECTA to PacketStatus ENUM

* Add Track Line Interaction

* Initializae self.line_interaction_array_length

* Fix Typo

* Rename interaction_id to event_id

* Add doc strings to the functions

* Add more doc strings

* Remove Line Interaction from this PR

* Remove Track Line Interaction From this PR

* Use Extend array function

* Fix Typo

* Remove length attribute from the class as it was clustering the code

* Remove line interaction from this PR

* Add tests

* Remove print statement

* Update test

* Trigger Build

* Rename track_boundary_interaction to boundary_interactions_track

* Rename from boundary_interactions_track to get_boundary_data

* Rename from get_boundary_data to get_boundary_interaction

* Rename function name in RPacketLastInteractionTracker to make it same as RPacketTracker

* Change the way RPacketTracker is imported

* Use RPacketLastInteractionTracker since it is set by default

* Remove ENUM

* Max column 60
  • Loading branch information
Sumit112192 authored Aug 12, 2024
1 parent 18dcec8 commit 63eb762
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 29 deletions.
25 changes: 17 additions & 8 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from tardis.transport.montecarlo.estimators import radfield_mc_estimators
from tardis.transport.montecarlo.numba_interface import opacity_state_initialize
from tardis.transport.montecarlo.packet_collections import VPacketCollection
from tardis.transport.montecarlo.packet_trackers import RPacketTracker


class BenchmarkBase:
Expand Down Expand Up @@ -62,8 +61,7 @@ def config_rpacket_tracking(self):
@functools.cached_property
def tardis_ref_path(self):
ref_data_path = Path(
Path(__file__).parent.parent,
"tardis-refdata"
Path(__file__).parent.parent, "tardis-refdata"
).resolve()
return ref_data_path

Expand Down Expand Up @@ -124,7 +122,9 @@ def packet(self):

@functools.cached_property
def verysimple_packet_collection(self):
return self.nb_simulation_verysimple.transport.transport_state.packet_collection
return (
self.nb_simulation_verysimple.transport.transport_state.packet_collection
)

@functools.cached_property
def nb_simulation_verysimple(self):
Expand Down Expand Up @@ -154,11 +154,15 @@ def verysimple_enable_full_relativity(self):

@functools.cached_property
def verysimple_tau_russian(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
)

@functools.cached_property
def verysimple_survival_probability(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
)

@functools.cached_property
def static_packet(self):
Expand All @@ -173,7 +177,9 @@ def static_packet(self):

@functools.cached_property
def verysimple_3vpacket_collection(self):
spectrum_frequency_grid = self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
spectrum_frequency_grid = (
self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
)
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency_grid=spectrum_frequency_grid,
Expand All @@ -195,7 +201,10 @@ def montecarlo_configuration(self):

@functools.cached_property
def rpacket_tracker(self):
return RPacketTracker(0)
# Do not use RPacketTracker or RPacketLastInteraction directly
# Use it by importing packet_trackers
# functions with name track_* function is used by ASV
return packet_trackers.RPacketLastInteractionTracker()

@functools.cached_property
def transport_state(self):
Expand Down
109 changes: 89 additions & 20 deletions tardis/transport/montecarlo/packet_trackers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from numba import float64, int64, njit
from numba import float64, int64, njit, from_dtype
from numba.experimental import jitclass
from numba.typed import List
import numpy as np
import pandas as pd


boundary_interaction_dtype = np.dtype(
[
("event_id", "int64"),
("current_shell_id", "int64"),
("next_shell_id", "int64"),
]
)


rpacket_tracker_spec = [
("length", int64),
("seed", int64),
("index", int64),
("status", int64[:]),
Expand All @@ -15,7 +24,10 @@
("energy", float64[:]),
("shell_id", int64[:]),
("interaction_type", int64[:]),
("boundary_interaction", from_dtype(boundary_interaction_dtype)[:]),
("num_interactions", int64),
("boundary_interactions_index", int64),
("event_id", int64),
("extend_factor", int64),
]

Expand Down Expand Up @@ -53,17 +65,25 @@ class RPacketTracker(object):
"""

def __init__(self, length):
self.length = length
"""
Initialize the variables with default value
"""
self.seed = np.int64(0)
self.index = np.int64(0)
self.status = np.empty(self.length, dtype=np.int64)
self.r = np.empty(self.length, dtype=np.float64)
self.nu = np.empty(self.length, dtype=np.float64)
self.mu = np.empty(self.length, dtype=np.float64)
self.energy = np.empty(self.length, dtype=np.float64)
self.shell_id = np.empty(self.length, dtype=np.int64)
self.interaction_type = np.empty(self.length, dtype=np.int64)
self.status = np.empty(length, dtype=np.int64)
self.r = np.empty(length, dtype=np.float64)
self.nu = np.empty(length, dtype=np.float64)
self.mu = np.empty(length, dtype=np.float64)
self.energy = np.empty(length, dtype=np.float64)
self.shell_id = np.empty(length, dtype=np.int64)
self.interaction_type = np.empty(length, dtype=np.int64)
self.boundary_interaction = np.empty(
length,
dtype=boundary_interaction_dtype,
)
self.num_interactions = 0
self.boundary_interactions_index = 0
self.event_id = 1
self.extend_factor = 2

def extend_array(self, array, array_length):
Expand All @@ -74,17 +94,19 @@ def extend_array(self, array, array_length):
return temp_array

def track(self, r_packet):
if self.num_interactions >= self.length:
self.status = self.extend_array(self.status, self.length)
self.r = self.extend_array(self.r, self.length)
self.nu = self.extend_array(self.nu, self.length)
self.mu = self.extend_array(self.mu, self.length)
self.energy = self.extend_array(self.energy, self.length)
self.shell_id = self.extend_array(self.shell_id, self.length)
"""
Track important properties of RPacket
"""
if self.num_interactions >= self.status.size:
self.status = self.extend_array(self.status, self.status.size)
self.r = self.extend_array(self.r, self.r.size)
self.nu = self.extend_array(self.nu, self.nu.size)
self.mu = self.extend_array(self.mu, self.mu.size)
self.energy = self.extend_array(self.energy, self.energy.size)
self.shell_id = self.extend_array(self.shell_id, self.shell_id.size)
self.interaction_type = self.extend_array(
self.interaction_type, self.length
self.interaction_type, self.interaction_type.size
)
self.length = self.length * self.extend_factor

self.index = r_packet.index
self.seed = r_packet.seed
Expand All @@ -99,14 +121,46 @@ def track(self, r_packet):
] = r_packet.last_interaction_type
self.num_interactions += 1

def track_boundary_interaction(self, current_shell_id, next_shell_id):
"""
Track boundary interaction properties
"""
if self.boundary_interactions_index >= self.boundary_interaction.size:
self.boundary_interaction = self.extend_array(
self.boundary_interaction,
self.boundary_interaction.size,
)

self.boundary_interaction[self.boundary_interactions_index][
"event_id"
] = self.event_id
self.event_id += 1

self.boundary_interaction[self.boundary_interactions_index][
"current_shell_id"
] = current_shell_id

self.boundary_interaction[self.boundary_interactions_index][
"next_shell_id"
] = next_shell_id

self.boundary_interactions_index += 1

def finalize_array(self):
"""
Change the size of the array from length ( or multiple of length ) to
the actual number of interactions
"""
self.status = self.status[: self.num_interactions]
self.r = self.r[: self.num_interactions]
self.nu = self.nu[: self.num_interactions]
self.mu = self.mu[: self.num_interactions]
self.energy = self.energy[: self.num_interactions]
self.shell_id = self.shell_id[: self.num_interactions]
self.interaction_type = self.interaction_type[: self.num_interactions]
self.boundary_interaction = self.boundary_interaction[
: self.boundary_interactions_index
]


def rpacket_trackers_to_dataframe(rpacket_trackers):
Expand Down Expand Up @@ -186,6 +240,9 @@ class RPacketLastInteractionTracker(object):
"""

def __init__(self):
"""
Initialize properties with default values
"""
self.index = -1
self.r = -1.0
self.nu = 0.0
Expand All @@ -194,15 +251,27 @@ def __init__(self):
self.interaction_type = -1

def track(self, r_packet):
"""
Track properties of RPacket and override the previous values
"""
self.index = r_packet.index
self.r = r_packet.r
self.nu = r_packet.nu
self.energy = r_packet.energy
self.shell_id = r_packet.current_shell_id
self.interaction_type = r_packet.last_interaction_type

# To make it compatible with RPacketTracker
def finalize_array(self):
"""
Added to make RPacketLastInteractionTracker compatible with RPacketTracker
"""
pass

# To make it compatible with RPacketTracker
def track_boundary_interaction(self, current_shell_id, next_shell_id):
"""
Added to make RPacketLastInteractionTracker compatible with RPacketTracker
"""
pass


Expand Down
8 changes: 7 additions & 1 deletion tardis/transport/montecarlo/single_packet_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def single_packet_loop(
# If continuum processes: update continuum estimators

if interaction_type == InteractionType.BOUNDARY:
rpacket_tracker.track_boundary_interaction(
r_packet.current_shell_id,
r_packet.current_shell_id + delta_shell,
)
move_r_packet(
r_packet,
distance,
Expand All @@ -166,7 +170,9 @@ def single_packet_loop(
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
)
move_packet_across_shell_boundary(
r_packet, delta_shell, len(numba_radial_1d_geometry.r_inner)
r_packet,
delta_shell,
len(numba_radial_1d_geometry.r_inner),
)

elif interaction_type == InteractionType.LINE:
Expand Down
24 changes: 24 additions & 0 deletions tardis/transport/montecarlo/tests/test_rpacket_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,30 @@ def test_rpacket_tracker_properties(expected, obtained, request):
npt.assert_allclose(expected, obtained)


def test_boundary_interactions(rpacket_tracker, regression_data):
no_of_packets = len(rpacket_tracker)

# Hard coding the number of columns
# Based on the largest size of boundary_interaction array (60)
obtained_boundary_interaction = np.full(
(no_of_packets, 64),
[-1],
dtype=rpacket_tracker[0].boundary_interaction.dtype,
)

for i, tracker in enumerate(rpacket_tracker):
obtained_boundary_interaction[
i, : tracker.boundary_interaction.size
] = tracker.boundary_interaction

expected_boundary_interaction = regression_data.sync_ndarray(
obtained_boundary_interaction
)
npt.assert_array_equal(
obtained_boundary_interaction, expected_boundary_interaction
)


def test_rpacket_trackers_to_dataframe(simulation_rpacket_tracking):
transport_state = simulation_rpacket_tracking.transport.transport_state
rtracker_df = rpacket_trackers_to_dataframe(transport_state.rpacket_tracker)
Expand Down

0 comments on commit 63eb762

Please sign in to comment.