Skip to content

Commit

Permalink
refactor to compile time constants
Browse files Browse the repository at this point in the history
  • Loading branch information
wkerzendorf committed Jul 10, 2024
1 parent 467decd commit 830c53c
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 101 deletions.
3 changes: 2 additions & 1 deletion tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tardis.model import SimulationState
from tardis.plasma.standard_plasmas import assemble_plasma
from tardis.simulation.convergence import ConvergenceSolver
from tardis.transport.montecarlo import montecarlo_configuration
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots
Expand Down Expand Up @@ -199,7 +200,7 @@ def __init__(
self._callbacks = OrderedDict()
self._cb_next_id = 0

self.transport.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED = (
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED = (
not self.plasma.continuum_interaction_species.empty
)

Expand Down
26 changes: 10 additions & 16 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)
from tardis.transport.montecarlo.formal_integral import FormalIntegrator
from tardis.transport.montecarlo.montecarlo_configuration import (
MonteCarloConfiguration,
configuration_initialize,
)
from tardis.transport.montecarlo.montecarlo_transport_state import (
Expand All @@ -33,6 +32,8 @@
update_iterations_pbar,
)

from tardis.transport.montecarlo import montecarlo_configuration

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -62,7 +63,6 @@ def __init__(
debug_packets=False,
logger_buffer=1,
use_gpu=False,
montecarlo_configuration=None,
):
# inject different packets
self.spectrum_frequency = spectrum_frequency
Expand All @@ -77,8 +77,6 @@ def __init__(

self.enable_vpacket_tracking = enable_virtual_packet_logging
self.enable_rpacket_tracking = enable_rpacket_tracking
self.montecarlo_configuration = montecarlo_configuration

self.packet_source = packet_source

# Setting up the Tracking array for storing all the RPacketTracker instances
Expand Down Expand Up @@ -115,8 +113,8 @@ def initialize_transport_state(
opacity_state = opacity_state_initialize(
plasma,
self.line_interaction_type,
self.montecarlo_configuration.DISABLE_LINE_SCATTERING,
self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_configuration.DISABLE_LINE_SCATTERING,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)
transport_state = MonteCarloTransportState(
packet_collection,
Expand All @@ -127,14 +125,14 @@ def initialize_transport_state(
)

transport_state.enable_full_relativity = (
self.montecarlo_configuration.ENABLE_FULL_RELATIVITY
montecarlo_configuration.ENABLE_FULL_RELATIVITY
)
transport_state.integrator_settings = self.integrator_settings
transport_state._integrator = FormalIntegrator(
simulation_state, plasma, self
)
configuration_initialize(
self.montecarlo_configuration, self, no_of_virtual_packets
montecarlo_configuration, self, no_of_virtual_packets
)

return transport_state
Expand Down Expand Up @@ -166,7 +164,7 @@ def run(
set_num_threads(self.nthreads)
self.transport_state = transport_state

number_of_vpackets = self.montecarlo_configuration.NUMBER_OF_VPACKETS
number_of_vpackets = montecarlo_configuration.NUMBER_OF_VPACKETS

(
v_packets_energy_hist,
Expand All @@ -178,7 +176,6 @@ def run(
transport_state.geometry_state,
time_explosion.cgs.value,
transport_state.opacity_state,
self.montecarlo_configuration,
transport_state.radfield_mc_estimators,
transport_state.spectrum_frequency.value,
number_of_vpackets,
Expand All @@ -202,15 +199,15 @@ def run(
last_interaction_tracker.shell_ids
)

if self.montecarlo_configuration.ENABLE_VPACKET_TRACKING and (
if montecarlo_configuration.ENABLE_VPACKET_TRACKING and (
number_of_vpackets > 0
):
transport_state.vpacket_tracker = vpacket_tracker

update_iterations_pbar(1)
refresh_packet_pbar()
# Condition for Checking if RPacket Tracking is enabled
if self.montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
transport_state.rpacket_tracker = rpacket_trackers

if self.transport_state.rpacket_tracker is not None:
Expand All @@ -220,7 +217,7 @@ def run(
)
)
transport_state.virt_logging = (
self.montecarlo_configuration.ENABLE_VPACKET_TRACKING
montecarlo_configuration.ENABLE_VPACKET_TRACKING
)

@classmethod
Expand Down Expand Up @@ -275,8 +272,6 @@ def from_config(
valid values are 'GPU', 'CPU', and 'Automatic'."""
)

montecarlo_configuration = MonteCarloConfiguration()

montecarlo_configuration.DISABLE_LINE_SCATTERING = (
config.plasma.disable_line_scattering
)
Expand Down Expand Up @@ -306,5 +301,4 @@ def from_config(
enable_rpacket_tracking=config.montecarlo.tracking.track_rpacket,
nthreads=config.montecarlo.nthreads,
use_gpu=use_gpu,
montecarlo_configuration=montecarlo_configuration,
)
55 changes: 26 additions & 29 deletions tardis/transport/montecarlo/formal_integral.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
import warnings

import numpy as np
import pandas as pd
import scipy.sparse as sp
import scipy.sparse.linalg as linalg
from scipy.interpolate import interp1d
from astropy import units as u
from tardis import constants as const
from numba import njit, char, float64, int64, typeof, byte, prange
from numba.experimental import jitclass

from numba import njit, prange
from scipy.interpolate import interp1d

from tardis import constants as const
from tardis.opacities.opacity_state import (
OpacityState,
opacity_state_initialize,
)
from tardis.transport.montecarlo.numba_config import SIGMA_THOMSON
from tardis.transport.montecarlo import njit_dict, njit_dict_no_parallel
from tardis.transport.montecarlo.numba_interface import (
opacity_state_initialize,
OpacityState,
from tardis.spectrum import TARDISSpectrum
from tardis.transport.montecarlo import (
montecarlo_configuration,
njit_dict,
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.formal_integral_cuda import (
CudaFormalIntegrator,
)

from tardis.spectrum import TARDISSpectrum
from tardis.transport.montecarlo.numba_config import SIGMA_THOMSON
from tardis.transport.montecarlo.numba_interface import (
opacity_state_initialize,
)

C_INV = 3.33564e-11
M_PI = np.arccos(-1)
Expand Down Expand Up @@ -60,8 +60,7 @@ def numba_formal_integral(
intensities at each p-ray multiplied by p
frequency x p-ray grid
"""

# todo: add all the original todos
# TODO: add all the original todos
# Initialize the output which is shared among threads
L = np.zeros(inu_size, dtype=np.float64)
# global read-only values
Expand Down Expand Up @@ -214,7 +213,7 @@ def numba_formal_integral(


# @jitclass(integrator_spec)
class NumbaFormalIntegrator(object):
class NumbaFormalIntegrator:
"""
Helper class for performing the formal integral
with numba.
Expand Down Expand Up @@ -257,7 +256,7 @@ def formal_integral(
)


class FormalIntegrator(object):
class FormalIntegrator:
"""
Class containing the formal integrator.
Expand All @@ -280,24 +279,21 @@ def __init__(self, simulation_state, plasma, transport, points=1000):
self.simulation_state = simulation_state
self.transport = transport
self.points = points
if transport:
self.montecarlo_configuration = (
self.transport.montecarlo_configuration
)
if plasma:
self.plasma = opacity_state_initialize(
plasma,
transport.line_interaction_type,
self.montecarlo_configuration.DISABLE_LINE_SCATTERING,
self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_configuration.DISABLE_LINE_SCATTERING,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)
self.atomic_data = plasma.atomic_data
self.original_plasma = plasma
self.levels_index = plasma.levels

def generate_numba_objects(self):
"""instantiate the numba interface objects
needed for computing the formal integral"""
needed for computing the formal integral
"""
from tardis.model.geometry.radial1d import NumbaRadial1DGeometry

self.numba_radial_1d_geometry = NumbaRadial1DGeometry(
Expand All @@ -311,8 +307,8 @@ def generate_numba_objects(self):
self.opacity_state = opacity_state_initialize(
self.original_plasma,
self.transport.line_interaction_type,
self.montecarlo_configuration.DISABLE_LINE_SCATTERING,
self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_configuration.DISABLE_LINE_SCATTERING,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)
if self.transport.use_gpu:
self.integrator = CudaFormalIntegrator(
Expand Down Expand Up @@ -354,7 +350,7 @@ def raise_or_return(message):
"FormalIntegrator."
)

if not self.transport.line_interaction_type in [
if self.transport.line_interaction_type not in [
"downbranch",
"macroatom",
]:
Expand All @@ -364,7 +360,7 @@ def raise_or_return(message):
'and line_interaction_type == "macroatom"'
)

if self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED:
if montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED:
return raise_or_return(
"The FormalIntegrator currently does not work for "
"continuum interactions."
Expand Down Expand Up @@ -614,7 +610,8 @@ def interpolate_integrator_quantities(

def formal_integral(self, nu, N):
"""Do the formal integral with the numba
routines"""
routines
"""
# TODO: get rid of storage later on

res = self.make_source_function()
Expand Down
78 changes: 26 additions & 52 deletions tardis/transport/montecarlo/montecarlo_configuration.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,30 @@
from astropy import units as u
from numba import float64, int64, boolean
from numba.experimental import jitclass
import numpy as np
from astropy import units as u

from tardis.transport.montecarlo.numba_interface import (
LineInteractionType,
)

numba_config_spec = [
("ENABLE_FULL_RELATIVITY", boolean),
("TEMPORARY_V_PACKET_BINS", int64),
("NUMBER_OF_VPACKETS", int64),
("MONTECARLO_SEED", int64),
("LINE_INTERACTION_TYPE", int64),
("PACKET_SEEDS", int64[:]),
("DISABLE_ELECTRON_SCATTERING", boolean),
("DISABLE_LINE_SCATTERING", boolean),
("SURVIVAL_PROBABILITY", float64),
("VPACKET_TAU_RUSSIAN", float64),
("INITIAL_TRACKING_ARRAY_LENGTH", int64),
("LEGACY_MODE_ENABLED", boolean),
("ENABLE_RPACKET_TRACKING", boolean),
("CONTINUUM_PROCESSES_ENABLED", boolean),
("VPACKET_SPAWN_START_FREQUENCY", float64),
("VPACKET_SPAWN_END_FREQUENCY", float64),
("ENABLE_VPACKET_TRACKING", boolean),
]


@jitclass(numba_config_spec)
class MonteCarloConfiguration(object):
def __init__(self):
self.ENABLE_FULL_RELATIVITY = False
self.TEMPORARY_V_PACKET_BINS = 0
self.NUMBER_OF_VPACKETS = 0
self.MONTECARLO_SEED = 0
self.LINE_INTERACTION_TYPE = 0
self.PACKET_SEEDS = np.empty(1, dtype=np.int64)
self.DISABLE_ELECTRON_SCATTERING = False
self.DISABLE_LINE_SCATTERING = False
self.SURVIVAL_PROBABILITY = 0.0
self.VPACKET_TAU_RUSSIAN = 10.0
ENABLE_FULL_RELATIVITY = False
TEMPORARY_V_PACKET_BINS = 0
NUMBER_OF_VPACKETS = 0
MONTECARLO_SEED = 0
LINE_INTERACTION_TYPE = 0
PACKET_SEEDS = np.empty(1, dtype=np.int64)
DISABLE_ELECTRON_SCATTERING = False
DISABLE_LINE_SCATTERING = False
SURVIVAL_PROBABILITY = 0.0
VPACKET_TAU_RUSSIAN = 10.0

self.INITIAL_TRACKING_ARRAY_LENGTH = 0
self.LEGACY_MODE_ENABLED = False
INITIAL_TRACKING_ARRAY_LENGTH = 0
LEGACY_MODE_ENABLED = False

self.ENABLE_RPACKET_TRACKING = False
self.CONTINUUM_PROCESSES_ENABLED = False
ENABLE_RPACKET_TRACKING = False
CONTINUUM_PROCESSES_ENABLED = False

self.VPACKET_SPAWN_START_FREQUENCY = 0
self.VPACKET_SPAWN_END_FREQUENCY = 1e200
self.ENABLE_VPACKET_TRACKING = False
VPACKET_SPAWN_START_FREQUENCY = 0
VPACKET_SPAWN_END_FREQUENCY = 1e200
ENABLE_VPACKET_TRACKING = False


def configuration_initialize(config, transport, number_of_vpackets):
Expand All @@ -66,19 +40,19 @@ def configuration_initialize(config, transport, number_of_vpackets):
f'"downbranch", or "scatter" but is '
f"{transport.line_interaction_type}"
)
config.NUMBER_OF_VPACKETS = number_of_vpackets
config.TEMPORARY_V_PACKET_BINS = number_of_vpackets
config.ENABLE_FULL_RELATIVITY = transport.enable_full_relativity
config.MONTECARLO_SEED = transport.packet_source.base_seed
config.VPACKET_SPAWN_START_FREQUENCY = (
NUMBER_OF_VPACKETS = number_of_vpackets
TEMPORARY_V_PACKET_BINS = number_of_vpackets
ENABLE_FULL_RELATIVITY = transport.enable_full_relativity
MONTECARLO_SEED = transport.packet_source.base_seed
VPACKET_SPAWN_START_FREQUENCY = (
transport.virtual_spectrum_spawn_range.end.to(
u.Hz, equivalencies=u.spectral()
).value
)
config.VPACKET_SPAWN_END_FREQUENCY = (
VPACKET_SPAWN_END_FREQUENCY = (
transport.virtual_spectrum_spawn_range.start.to(
u.Hz, equivalencies=u.spectral()
).value
)
config.ENABLE_VPACKET_TRACKING = transport.enable_vpacket_tracking
config.ENABLE_RPACKET_TRACKING = transport.enable_rpacket_tracking
ENABLE_VPACKET_TRACKING = transport.enable_vpacket_tracking
ENABLE_RPACKET_TRACKING = transport.enable_rpacket_tracking
Loading

0 comments on commit 830c53c

Please sign in to comment.