Skip to content

Commit

Permalink
Merge branch 'tardis-sn:master' into gsoc-benchmark-first-objective-iarv
Browse files Browse the repository at this point in the history
  • Loading branch information
airvzxf authored Apr 20, 2024
2 parents 67a73c9 + e25468d commit 3642a13
Show file tree
Hide file tree
Showing 33 changed files with 641 additions and 369 deletions.
1 change: 1 addition & 0 deletions tardis/io/model_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ def transport_from_hdf(fname):
nthreads=d["nthreads"],
enable_virtual_packet_logging=d["virt_logging"],
use_gpu=d["use_gpu"],
montecarlo_configuration=d["montecarlo_configuration"],
)

new_transport.Edotlu_estimator = d["Edotlu_estimator"]
Expand Down
12 changes: 8 additions & 4 deletions tardis/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def no_of_raw_shells(self):
return self.geometry.no_of_shells

@classmethod
def from_config(cls, config, atom_data):
def from_config(cls, config, atom_data, legacy_mode_enabled=False):
"""
Create a new SimulationState instance from a Configuration object.
Expand Down Expand Up @@ -269,7 +269,9 @@ def from_config(cls, config, atom_data):
atom_data.atom_data.mass.copy(),
)

packet_source = parse_packet_source(config, geometry)
packet_source = parse_packet_source(
config, geometry, legacy_mode_enabled
)
radiation_field_state = parse_radiation_field_state(
config,
t_radiative,
Expand All @@ -288,7 +290,7 @@ def from_config(cls, config, atom_data):
)

@classmethod
def from_csvy(cls, config, atom_data=None):
def from_csvy(cls, config, atom_data=None, legacy_mode_enabled=False):
"""
Create a new SimulationState instance from a Configuration object.
Expand Down Expand Up @@ -366,7 +368,9 @@ def from_csvy(cls, config, atom_data=None):
geometry,
)

packet_source = parse_packet_source(config, geometry)
packet_source = parse_packet_source(
config, geometry, legacy_mode_enabled
)

radiation_field_state = parse_csvy_radiation_field_state(
config, csvy_model_config, csvy_model_data, geometry, packet_source
Expand Down
22 changes: 17 additions & 5 deletions tardis/model/parse_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,9 @@ def parse_radiation_field_state(
)


def initialize_packet_source(config, geometry, packet_source):
def initialize_packet_source(
config, geometry, packet_source, legacy_mode_enabled
):
"""
Initialize the packet source based on config and geometry
Expand All @@ -613,9 +615,13 @@ def initialize_packet_source(config, geometry, packet_source):
packet_source = BlackBodySimpleSourceRelativistic(
base_seed=config.montecarlo.seed,
time_explosion=config.supernova.time_explosion,
legacy_mode_enabled=legacy_mode_enabled,
)
else:
packet_source = BlackBodySimpleSource(base_seed=config.montecarlo.seed)
packet_source = BlackBodySimpleSource(
base_seed=config.montecarlo.seed,
legacy_mode_enabled=legacy_mode_enabled,
)

luminosity_requested = config.supernova.luminosity_requested
if config.plasma.initial_t_inner > 0.0 * u.K:
Expand All @@ -635,7 +641,7 @@ def initialize_packet_source(config, geometry, packet_source):
return packet_source


def parse_packet_source(config, geometry):
def parse_packet_source(config, geometry, legacy_mode_enabled):
"""
Parse the packet source based on the given configuration and geometry.
Expand All @@ -655,11 +661,17 @@ def parse_packet_source(config, geometry):
packet_source = BlackBodySimpleSourceRelativistic(
base_seed=config.montecarlo.seed,
time_explosion=config.supernova.time_explosion,
legacy_mode_enabled=legacy_mode_enabled,
)
else:
packet_source = BlackBodySimpleSource(base_seed=config.montecarlo.seed)
packet_source = BlackBodySimpleSource(
base_seed=config.montecarlo.seed,
legacy_mode_enabled=legacy_mode_enabled,
)

return initialize_packet_source(config, geometry, packet_source)
return initialize_packet_source(
config, geometry, packet_source, legacy_mode_enabled
)


def parse_csvy_radiation_field_state(
Expand Down
25 changes: 18 additions & 7 deletions tardis/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from tardis import constants as const
from tardis.io.logger import montecarlo_tracking as mc_tracker
from tardis.io.util import HDFWriterMixin
from tardis.montecarlo import montecarlo_configuration
from tardis.montecarlo.estimators.radfield_mc_estimators import (
initialize_estimator_statistics,
)
from tardis.montecarlo.montecarlo_configuration import (
MonteCarloConfiguration,
configuration_initialize,
)
from tardis.montecarlo.montecarlo_numba import (
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(
debug_packets=False,
logger_buffer=1,
use_gpu=False,
montecarlo_configuration=None,
):
# inject different packets
self.disable_electron_scattering = disable_electron_scattering
Expand All @@ -86,6 +87,7 @@ def __init__(

self.enable_vpacket_tracking = enable_virtual_packet_logging
self.enable_rpacket_tracking = enable_rpacket_tracking
self.montecarlo_configuration = montecarlo_configuration

self.packet_source = packet_source

Expand Down Expand Up @@ -124,7 +126,10 @@ def initialize_transport_state(

geometry_state = simulation_state.geometry.to_numba()
opacity_state = opacity_state_initialize(
plasma, self.line_interaction_type
plasma,
self.line_interaction_type,
self.montecarlo_configuration.DISABLE_LINE_SCATTERING,
self.montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
)
transport_state = MonteCarloTransportState(
packet_collection,
Expand All @@ -139,7 +144,9 @@ def initialize_transport_state(
transport_state._integrator = FormalIntegrator(
simulation_state, plasma, self
)
configuration_initialize(self, no_of_virtual_packets)
configuration_initialize(
self.montecarlo_configuration, self, no_of_virtual_packets
)

return transport_state

Expand Down Expand Up @@ -172,7 +179,7 @@ def run(

numba_model = NumbaModel(time_explosion.to("s").value)

number_of_vpackets = montecarlo_configuration.NUMBER_OF_VPACKETS
number_of_vpackets = self.montecarlo_configuration.NUMBER_OF_VPACKETS

(
v_packets_energy_hist,
Expand All @@ -184,6 +191,7 @@ def run(
transport_state.geometry_state,
numba_model,
transport_state.opacity_state,
self.montecarlo_configuration,
transport_state.radfield_mc_estimators,
transport_state.spectrum_frequency.value,
number_of_vpackets,
Expand All @@ -208,15 +216,15 @@ def run(
last_interaction_tracker.shell_ids
)

if montecarlo_configuration.ENABLE_VPACKET_TRACKING and (
if self.montecarlo_configuration.ENABLE_VPACKET_TRACKING and (
number_of_vpackets > 0
):
transport_state.vpacket_tracker = vpacket_tracker

update_iterations_pbar(1)
refresh_packet_pbar()
# Condition for Checking if RPacket Tracking is enabled
if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if self.montecarlo_configuration.ENABLE_RPACKET_TRACKING:
transport_state.rpacket_tracker = rpacket_trackers

if self.transport_state.rpacket_tracker is not None:
Expand All @@ -226,7 +234,7 @@ def run(
)
)
transport_state.virt_logging = (
montecarlo_configuration.ENABLE_VPACKET_TRACKING
self.montecarlo_configuration.ENABLE_VPACKET_TRACKING
)

def legacy_return(self):
Expand Down Expand Up @@ -300,6 +308,8 @@ def from_config(
valid values are 'GPU', 'CPU', and 'Automatic'."""
)

montecarlo_configuration = MonteCarloConfiguration()

montecarlo_configuration.DISABLE_LINE_SCATTERING = (
config.plasma.disable_line_scattering
)
Expand Down Expand Up @@ -329,4 +339,5 @@ def from_config(
enable_rpacket_tracking=config.montecarlo.tracking.track_rpacket,
nthreads=config.montecarlo.nthreads,
use_gpu=use_gpu,
montecarlo_configuration=montecarlo_configuration,
)
120 changes: 120 additions & 0 deletions tardis/montecarlo/estimators/radfield_estimator_calcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from math import exp

from numba import njit

from tardis.montecarlo.montecarlo_numba import (
njit_dict_no_parallel,
)
from tardis.montecarlo.montecarlo_numba.numba_config import KB, H
from tardis.transport.frame_transformations import (
calc_packet_energy,
calc_packet_energy_full_relativity,
)


@njit(**njit_dict_no_parallel)
def update_base_estimators(
r_packet, distance, estimator_state, comov_nu, comov_energy
):
"""
Updating the estimators
"""
estimator_state.j_estimator[r_packet.current_shell_id] += (
comov_energy * distance
)
estimator_state.nu_bar_estimator[r_packet.current_shell_id] += (
comov_energy * distance * comov_nu
)


@njit(**njit_dict_no_parallel)
def update_bound_free_estimators(
comov_nu,
comov_energy,
shell_id,
distance,
estimator_state,
t_electron,
x_sect_bfs,
current_continua,
bf_threshold_list_nu,
):
"""
Update the estimators for bound-free processes.
Parameters
----------
comov_nu : float
comov_energy : float
shell_id : int
distance : float
numba_estimator : tardis.montecarlo.montecarlo_numba.numba_interface.Estimators
t_electron : float
Electron temperature in the current cell.
x_sect_bfs : numpy.ndarray, dtype float
Photoionization cross-sections of all bound-free continua for
which absorption is possible for frequency `comov_nu`.
current_continua : numpy.ndarray, dtype int
Continuum ids for which absorption is possible for frequency `comov_nu`.
bf_threshold_list_nu : numpy.ndarray, dtype float
Threshold frequencies for photoionization sorted by decreasing frequency.
"""
# TODO: Add full relativity mode
boltzmann_factor = exp(-(H * comov_nu) / (KB * t_electron))
for i, current_continuum in enumerate(current_continua):
photo_ion_rate_estimator_increment = (
comov_energy * distance * x_sect_bfs[i] / comov_nu
)
estimator_state.photo_ion_estimator[
current_continuum, shell_id
] += photo_ion_rate_estimator_increment
estimator_state.stim_recomb_estimator[current_continuum, shell_id] += (
photo_ion_rate_estimator_increment * boltzmann_factor
)
estimator_state.photo_ion_estimator_statistics[
current_continuum, shell_id
] += 1

nu_th = bf_threshold_list_nu[current_continuum]
bf_heating_estimator_increment = (
comov_energy * distance * x_sect_bfs[i] * (1 - nu_th / comov_nu)
)
estimator_state.bf_heating_estimator[
current_continuum, shell_id
] += bf_heating_estimator_increment
estimator_state.stim_recomb_cooling_estimator[
current_continuum, shell_id
] += (bf_heating_estimator_increment * boltzmann_factor)


@njit(**njit_dict_no_parallel)
def update_line_estimators(
radfield_mc_estimators,
r_packet,
cur_line_id,
distance_trace,
time_explosion,
enable_full_relativity,
):
"""
Function to update the line estimators
Parameters
----------
estimators : tardis.montecarlo.montecarlo_numba.numba_interface.Estimators
r_packet : tardis.montecarlo.montecarlo_numba.r_packet.RPacket
cur_line_id : int
distance_trace : float
time_explosion : float
"""
if not enable_full_relativity:
energy = calc_packet_energy(r_packet, distance_trace, time_explosion)
else:
energy = calc_packet_energy_full_relativity(r_packet)

radfield_mc_estimators.j_blue_estimator[
cur_line_id, r_packet.current_shell_id
] += (energy / r_packet.nu)
radfield_mc_estimators.Edotlu_estimator[
cur_line_id, r_packet.current_shell_id
] += energy
Loading

0 comments on commit 3642a13

Please sign in to comment.