Skip to content

Commit

Permalink
refactor compile time constants
Browse files Browse the repository at this point in the history
  • Loading branch information
wkerzendorf committed Jul 11, 2024
1 parent a73f54a commit b3b470a
Show file tree
Hide file tree
Showing 16 changed files with 38 additions and 104 deletions.
2 changes: 1 addition & 1 deletion benchmarks/transport_montecarlo_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import tardis.transport.montecarlo.interaction as interaction
from benchmarks.benchmark_base import BenchmarkBase
from tardis.transport.montecarlo.numba_interface import (
from tardis.transport.montecarlo.numba_config import (
LineInteractionType,
)

Expand Down
2 changes: 1 addition & 1 deletion tardis/opacities/opacities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
SIGMA_THOMSON,
)

Expand Down
4 changes: 3 additions & 1 deletion tardis/transport/frame_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.numba_config import C_SPEED_OF_LIGHT
from tardis.transport.montecarlo.configuration.constants import (
C_SPEED_OF_LIGHT,
)


@njit(**njit_dict_no_parallel)
Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/geometry/calculate_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
C_SPEED_OF_LIGHT,
MISS_DISTANCE,
SIGMA_THOMSON,
Expand Down
40 changes: 24 additions & 16 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
from tardis.io.util import HDFWriterMixin
from tardis.transport.montecarlo import (
montecarlo_main_loop,
numba_config,
)
from tardis.transport.montecarlo.configuration import (
constants,
montecarlo_globals,
)
from tardis.transport.montecarlo.configuration.base import (
configuration_initialize,
MonteCarloConfiguration,
)
from tardis.transport.montecarlo.estimators.radfield_mc_estimators import (
initialize_estimator_statistics,
)
from tardis.transport.montecarlo.formal_integral import FormalIntegrator
from tardis.transport.montecarlo.montecarlo_configuration import (
configuration_initialize,
)
from tardis.transport.montecarlo.montecarlo_transport_state import (
MonteCarloTransportState,
)
Expand All @@ -32,8 +36,6 @@
update_iterations_pbar,
)

from tardis.transport.montecarlo import montecarlo_globals

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -63,6 +65,7 @@ def __init__(
debug_packets=False,
logger_buffer=1,
use_gpu=False,
montecarlo_configuration=None,
):
# inject different packets
self.spectrum_frequency = spectrum_frequency
Expand All @@ -77,6 +80,8 @@ def __init__(

self.enable_vpacket_tracking = enable_virtual_packet_logging
self.enable_rpacket_tracking = enable_rpacket_tracking
self.montecarlo_configuration = montecarlo_configuration

self.packet_source = packet_source

# Setting up the Tracking array for storing all the RPacketTracker instances
Expand Down Expand Up @@ -128,11 +133,11 @@ def initialize_transport_state(
montecarlo_globals.ENABLE_FULL_RELATIVITY
)
transport_state.integrator_settings = self.integrator_settings
transport_state.integrator = FormalIntegrator(
transport_state._integrator = FormalIntegrator(
simulation_state, plasma, self
)
configuration_initialize(
montecarlo_configuration, self, no_of_virtual_packets
self.montecarlo_configuration, self, no_of_virtual_packets
)

return transport_state
Expand Down Expand Up @@ -164,7 +169,7 @@ def run(
set_num_threads(self.nthreads)
self.transport_state = transport_state

number_of_vpackets = montecarlo_configuration.NUMBER_OF_VPACKETS
number_of_vpackets = self.montecarlo_configuration.NUMBER_OF_VPACKETS

(
v_packets_energy_hist,
Expand Down Expand Up @@ -199,15 +204,15 @@ def run(
last_interaction_tracker.shell_ids
)

if montecarlo_configuration.ENABLE_VPACKET_TRACKING and (
if montecarlo_globals.ENABLE_VPACKET_TRACKING and (
number_of_vpackets > 0
):
transport_state.vpacket_tracker = vpacket_tracker

update_iterations_pbar(1)
refresh_packet_pbar()
# Condition for Checking if RPacket Tracking is enabled
if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if montecarlo_globals.ENABLE_RPACKET_TRACKING:
transport_state.rpacket_tracker = rpacket_trackers

if self.transport_state.rpacket_tracker is not None:
Expand All @@ -217,7 +222,7 @@ def run(
)
)
transport_state.virt_logging = (
montecarlo_configuration.ENABLE_VPACKET_TRACKING
montecarlo_globals.ENABLE_VPACKET_TRACKING
)

@classmethod
Expand All @@ -242,10 +247,10 @@ def from_config(
"Likely bug in formal integral - "
"will not give same results."
)
numba_config.SIGMA_THOMSON = 1e-200
constants.SIGMA_THOMSON = 1e-200
else:
logger.debug("Electron scattering switched on")
numba_config.SIGMA_THOMSON = const.sigma_T.to("cm^2").value
constants.SIGMA_THOMSON = const.sigma_T.to("cm^2").value

spectrum_frequency = quantity_linspace(
config.spectrum.stop.to("Hz", u.spectral()),
Expand All @@ -272,14 +277,16 @@ def from_config(
valid values are 'GPU', 'CPU', and 'Automatic'."""
)

montecarlo_configuration.DISABLE_LINE_SCATTERING = (
montecarlo_globals.DISABLE_LINE_SCATTERING = (
config.plasma.disable_line_scattering
)

montecarlo_configuration.DISABLE_ELECTRON_SCATTERING = (
montecarlo_globals.DISABLE_ELECTRON_SCATTERING = (
config.plasma.disable_electron_scattering
)

montecarlo_configuration = MonteCarloConfiguration()

montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH = (
config.montecarlo.tracking.initial_array_length
)
Expand All @@ -301,4 +308,5 @@ def from_config(
enable_rpacket_tracking=config.montecarlo.tracking.track_rpacket,
nthreads=config.montecarlo.nthreads,
use_gpu=use_gpu,
montecarlo_configuration=montecarlo_configuration,
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.numba_config import KB, H
from tardis.transport.montecarlo.configuration.constants import KB, H
from tardis.transport.frame_transformations import (
calc_packet_energy,
calc_packet_energy_full_relativity,
Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tardis.transport.montecarlo.formal_integral_cuda import (
CudaFormalIntegrator,
)
from tardis.transport.montecarlo.numba_config import SIGMA_THOMSON
from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON
from tardis.transport.montecarlo.numba_interface import (
opacity_state_initialize,
)
Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
MacroAtomTransitionType,
macro_atom,
)
from tardis.transport.montecarlo.numba_interface import (
from tardis.transport.montecarlo.numba_config import (
LineInteractionType,
)
from tardis.transport.montecarlo.r_packet import (
Expand Down
52 changes: 0 additions & 52 deletions tardis/transport/montecarlo/montecarlo_configuration.py

This file was deleted.

9 changes: 0 additions & 9 deletions tardis/transport/montecarlo/montecarlo_globals.py

This file was deleted.

3 changes: 2 additions & 1 deletion tardis/transport/montecarlo/montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from numba.np.ufunc.parallel import get_num_threads, get_thread_id
from numba.typed import List

from tardis.transport.montecarlo import montecarlo_configuration, njit_dict
from tardis.transport.montecarlo import njit_dict
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.packet_collections import (
VPacketCollection,
consolidate_vpacket_tracker,
Expand Down
8 changes: 0 additions & 8 deletions tardis/transport/montecarlo/numba_config.py

This file was deleted.

8 changes: 0 additions & 8 deletions tardis/transport/montecarlo/numba_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from enum import IntEnum

from numba import float64, int64
from numba.experimental import jitclass
import numpy as np
Expand Down Expand Up @@ -248,9 +246,3 @@ def opacity_state_initialize(
photo_ion_activation_idx,
k_packet_idx,
)


class LineInteractionType(IntEnum):
SCATTER = 0
DOWNBRANCH = 1
MACROATOM = 2
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/r_packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tardis.transport.frame_transformations import (
get_doppler_factor,
)
from tardis.transport.montecarlo import numba_config as nc
from tardis.transport.montecarlo.configuration import constants as nc
from tardis.transport.montecarlo import njit_dict_no_parallel


Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/tests/test_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy.testing as npt
import numpy as np
import tardis.transport.montecarlo.interaction as interaction
from tardis.transport.montecarlo.numba_interface import (
from tardis.transport.montecarlo.numba_config import (
LineInteractionType,
)

Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/vpacket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
calculate_distance_line,
)
from tardis.transport.montecarlo import njit_dict_no_parallel
from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
C_SPEED_OF_LIGHT,
SIGMA_THOMSON,
)
Expand Down

0 comments on commit b3b470a

Please sign in to comment.