Skip to content

Commit

Permalink
Spectrum frequency var rename
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed Jul 22, 2024
1 parent 7d03b2e commit 72c8220
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 53 deletions.
6 changes: 2 additions & 4 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,10 @@ def set_seed(value):

@property
def verysimple_3vpacket_collection(self):
spectrum_frequency = (
self.nb_simulation_verysimple.transport.spectrum_frequency.value
)
spectrum_frequency_grid = self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency=spectrum_frequency,
spectrum_frequency_grid=spectrum_frequency_grid,
number_of_vpackets=3,
v_packet_spawn_start_frequency=0,
v_packet_spawn_end_frequency=np.inf,
Expand Down
8 changes: 5 additions & 3 deletions benchmarks/transport_montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
"""

from benchmarks.benchmark_base import BenchmarkBase
from tardis.transport.montecarlo.montecarlo_main_loop import montecarlo_main_loop
from tardis.transport.montecarlo.montecarlo_main_loop import (
montecarlo_main_loop,
)


class BenchmarkTransportMontecarloMainLoop(BenchmarkBase):
Expand All @@ -19,9 +21,9 @@ def time_montecarlo_main_loop(self):
self.transport_state.opacity_state,
self.montecarlo_configuration,
self.transport_state.radfield_mc_estimators,
self.transport_state.spectrum_frequency.value,
self.transport_state.spectrum_frequency_grid.value,
self.montecarlo_configuration.NUMBER_OF_VPACKETS,
iteration=0,
show_progress_bars=False,
total_iterations=0
total_iterations=0,
)
22 changes: 11 additions & 11 deletions tardis/spectrum/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class SpectrumSolver(HDFWriterMixin):

hdf_name = "spectrum"

def __init__(self, transport_state, spectrum_frequency):
def __init__(self, transport_state, spectrum_frequency_grid):
self.transport_state = transport_state
self.spectrum_frequency = spectrum_frequency
self.spectrum_frequency_grid = spectrum_frequency_grid
self._montecarlo_virtual_luminosity = u.Quantity(
np.zeros_like(self.spectrum_frequency.value), "erg / s"
np.zeros_like(self.spectrum_frequency_grid.value), "erg / s"
) # should be init with v_packets_energy_hist
self._integrator = None
self.integrator_settings = None
Expand All @@ -35,13 +35,13 @@ def __init__(self, transport_state, spectrum_frequency):
@property
def spectrum_real_packets(self):
return TARDISSpectrum(
self.spectrum_frequency, self.montecarlo_emitted_luminosity
self.spectrum_frequency_grid, self.montecarlo_emitted_luminosity
)

@property
def spectrum_real_packets_reabsorbed(self):
return TARDISSpectrum(
self.spectrum_frequency, self.montecarlo_reabsorbed_luminosity
self.spectrum_frequency_grid, self.montecarlo_reabsorbed_luminosity
)

@property
Expand All @@ -55,7 +55,7 @@ def spectrum_virtual_packets(self):
)

return TARDISSpectrum(
self.spectrum_frequency, self.montecarlo_virtual_luminosity
self.spectrum_frequency_grid, self.montecarlo_virtual_luminosity
)

@property
Expand All @@ -65,7 +65,7 @@ def spectrum_integrated(self):
# is not used in calculate_spectrum
try:
self._spectrum_integrated = self.integrator.calculate_spectrum(
self.spectrum_frequency[:-1],
self.spectrum_frequency_grid[:-1],
points=self.integrator_settings.points,
interpolate_shells=self.integrator_settings.interpolate_shells,
)
Expand Down Expand Up @@ -109,7 +109,7 @@ def montecarlo_reabsorbed_luminosity(self):
np.histogram(
self.transport_state.reabsorbed_packet_nu,
weights=self.transport_state.reabsorbed_packet_luminosity,
bins=self.spectrum_frequency,
bins=self.spectrum_frequency_grid,
)[0],
"erg / s",
)
Expand All @@ -120,7 +120,7 @@ def montecarlo_emitted_luminosity(self):
np.histogram(
self.transport_state.emitted_packet_nu,
weights=self.transport_state.emitted_packet_luminosity,
bins=self.spectrum_frequency,
bins=self.spectrum_frequency_grid,
)[0],
"erg / s",
)
Expand Down Expand Up @@ -180,13 +180,13 @@ def calculate_reabsorbed_luminosity(

@classmethod
def from_config(cls, config):
spectrum_frequency = quantity_linspace(
spectrum_frequency_grid = quantity_linspace(
config.spectrum.stop.to("Hz", u.spectral()),
config.spectrum.start.to("Hz", u.spectral()),
num=config.spectrum.num + 1,
)

return cls(
transport_state=None,
spectrum_frequency=spectrum_frequency,
spectrum_frequency_grid=spectrum_frequency_grid,
)
16 changes: 8 additions & 8 deletions tardis/spectrum/tests/test_spectrum_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,26 @@ def get_expected_data(self, key: str):

def test_initialization(self, simulation):
transport_state = simulation.transport.transport_state
spectrum_frequency = simulation.transport.spectrum_frequency
spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid

solver = SpectrumSolver(transport_state, spectrum_frequency)
solver = SpectrumSolver(transport_state, spectrum_frequency_grid)
assert solver.transport_state == transport_state
assert np.array_equal(
solver.spectrum_frequency.value, spectrum_frequency.value
solver.spectrum_frequency_grid.value, spectrum_frequency_grid.value
)
assert np.array_equal(
solver._montecarlo_virtual_luminosity.value,
np.zeros_like(spectrum_frequency.value),
np.zeros_like(spectrum_frequency_grid.value),
)
assert solver._integrator is None
assert solver.integrator_settings is None
assert solver._spectrum_integrated is None

def test_spectrum_real_packets(self, simulation):
transport_state = simulation.transport.transport_state
spectrum_frequency = simulation.transport.spectrum_frequency
spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid

solver = SpectrumSolver(transport_state, spectrum_frequency)
solver = SpectrumSolver(transport_state, spectrum_frequency_grid)
result = solver.spectrum_real_packets.luminosity
key = "simulation/spectrum_solver/spectrum_real_packets/luminosity"
expected = self.get_expected_data(key)
Expand All @@ -74,9 +74,9 @@ def test_spectrum_real_packets(self, simulation):

def test_spectrum_real_packets_reabsorbed(self, simulation):
transport_state = simulation.transport.transport_state
spectrum_frequency = simulation.transport.spectrum_frequency
spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid

solver = SpectrumSolver(transport_state, spectrum_frequency)
solver = SpectrumSolver(transport_state, spectrum_frequency_grid)
result = solver.spectrum_real_packets_reabsorbed.luminosity
key = "simulation/spectrum_solver/spectrum_real_packets_reabsorbed/luminosity"
expected = self.get_expected_data(key)
Expand Down
10 changes: 5 additions & 5 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MonteCarloTransportSolver(HDFWriterMixin):

def __init__(
self,
spectrum_frequency,
spectrum_frequency_grid,
virtual_spectrum_spawn_range,
enable_full_relativity,
line_interaction_type,
Expand All @@ -63,7 +63,7 @@ def __init__(
montecarlo_configuration=None,
):
# inject different packets
self.spectrum_frequency = spectrum_frequency
self.spectrum_frequency_grid = spectrum_frequency_grid
self.virtual_spectrum_spawn_range = virtual_spectrum_spawn_range
self.enable_full_relativity = enable_full_relativity
self.line_interaction_type = line_interaction_type
Expand Down Expand Up @@ -171,7 +171,7 @@ def run(
transport_state.opacity_state,
self.montecarlo_configuration,
transport_state.radfield_mc_estimators,
self.spectrum_frequency.value,
self.spectrum_frequency_grid.value,
number_of_vpackets,
iteration=iteration,
show_progress_bars=show_progress_bars,
Expand Down Expand Up @@ -241,7 +241,7 @@ def from_config(
logger.debug("Electron scattering switched on")
constants.SIGMA_THOMSON = const.sigma_T.to("cm^2").value

spectrum_frequency = quantity_linspace(
spectrum_frequency_grid = quantity_linspace(
config.spectrum.stop.to("Hz", u.spectral()),
config.spectrum.start.to("Hz", u.spectral()),
num=config.spectrum.num + 1,
Expand Down Expand Up @@ -281,7 +281,7 @@ def from_config(
)

return cls(
spectrum_frequency=spectrum_frequency,
spectrum_frequency_grid=spectrum_frequency_grid,
virtual_spectrum_spawn_range=config.montecarlo.virtual_spectrum_spawn_range,
enable_full_relativity=config.montecarlo.enable_full_relativity,
line_interaction_type=config.plasma.line_interaction_type,
Expand Down
20 changes: 10 additions & 10 deletions tardis/transport/montecarlo/montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def montecarlo_main_loop(
opacity_state,
montecarlo_configuration,
estimators,
spectrum_frequency,
spectrum_frequency_grid,
number_of_vpackets,
iteration,
show_progress_bars,
Expand All @@ -49,7 +49,7 @@ def montecarlo_main_loop(
Time in seconds
opacity_state : OpacityState
estimators : Estimators
spectrum_frequency : astropy.units.Quantity
spectrum_frequency_grid : astropy.units.Quantity
Frequency bins
number_of_vpackets : int
VPackets released per interaction
Expand All @@ -66,8 +66,8 @@ def montecarlo_main_loop(
no_of_packets
)

v_packets_energy_hist = np.zeros_like(spectrum_frequency)
delta_nu = spectrum_frequency[1] - spectrum_frequency[0]
v_packets_energy_hist = np.zeros_like(spectrum_frequency_grid)
delta_nu = spectrum_frequency_grid[1] - spectrum_frequency_grid[0]

# Pre-allocate a list of vpacket collections for later storage
vpacket_collections = List()
Expand All @@ -77,7 +77,7 @@ def montecarlo_main_loop(
vpacket_collections.append(
VPacketCollection(
i,
spectrum_frequency,
spectrum_frequency_grid,
montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY,
montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY,
number_of_vpackets,
Expand Down Expand Up @@ -156,12 +156,12 @@ def montecarlo_main_loop(
vpacket_collection.finalize_arrays()

v_packets_idx = np.floor(
(vpacket_collection.nus - spectrum_frequency[0]) / delta_nu
(vpacket_collection.nus - spectrum_frequency_grid[0]) / delta_nu
).astype(np.int64)

for j, idx in enumerate(v_packets_idx):
if (vpacket_collection.nus[j] < spectrum_frequency[0]) or (
vpacket_collection.nus[j] > spectrum_frequency[-1]
if (vpacket_collection.nus[j] < spectrum_frequency_grid[0]) or (
vpacket_collection.nus[j] > spectrum_frequency_grid[-1]
):
continue
v_packets_energy_hist[idx] += vpacket_collection.energies[j]
Expand All @@ -172,14 +172,14 @@ def montecarlo_main_loop(
if montecarlo_configuration.ENABLE_VPACKET_TRACKING:
vpacket_tracker = consolidate_vpacket_tracker(
vpacket_collections,
spectrum_frequency,
spectrum_frequency_grid,
montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY,
montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY,
)
else:
vpacket_tracker = VPacketCollection(
-1,
spectrum_frequency,
spectrum_frequency_grid,
montecarlo_configuration.VPACKET_SPAWN_START_FREQUENCY,
montecarlo_configuration.VPACKET_SPAWN_END_FREQUENCY,
-1,
Expand Down
12 changes: 6 additions & 6 deletions tardis/transport/montecarlo/packet_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def update_last_interaction(self, r_packet, i):

vpacket_collection_spec = [
("source_rpacket_index", int64),
("spectrum_frequency", float64[:]),
("spectrum_frequency_grid", float64[:]),
("v_packet_spawn_start_frequency", float64),
("v_packet_spawn_end_frequency", float64),
("nus", float64[:]),
Expand All @@ -129,13 +129,13 @@ class VPacketCollection:
def __init__(
self,
source_rpacket_index,
spectrum_frequency,
spectrum_frequency_grid,
v_packet_spawn_start_frequency,
v_packet_spawn_end_frequency,
number_of_vpackets,
temporary_v_packet_bins,
):
self.spectrum_frequency = spectrum_frequency
self.spectrum_frequency_grid = spectrum_frequency_grid
self.v_packet_spawn_start_frequency = v_packet_spawn_start_frequency
self.v_packet_spawn_end_frequency = v_packet_spawn_end_frequency
self.nus = np.empty(temporary_v_packet_bins, dtype=np.float64)
Expand Down Expand Up @@ -298,7 +298,7 @@ def finalize_arrays(self):

@njit(**njit_dict_no_parallel)
def consolidate_vpacket_tracker(
vpacket_collections, spectrum_frequency, start_frequency, end_frequency
vpacket_collections, spectrum_frequency_grid, start_frequency, end_frequency
):
"""
Consolidate the vpacket trackers from multiple collections into a single vpacket tracker.
Expand All @@ -307,7 +307,7 @@ def consolidate_vpacket_tracker(
----------
vpacket_collections : List[VPacketCollection]
List of vpacket collections to consolidate.
spectrum_frequency : ndarray
spectrum_frequency_grid : ndarray
Array of spectrum frequencies.
Returns
Expand All @@ -322,7 +322,7 @@ def consolidate_vpacket_tracker(

vpacket_tracker = VPacketCollection(
-1,
spectrum_frequency,
spectrum_frequency_grid,
start_frequency,
end_frequency,
-1,
Expand Down
12 changes: 6 additions & 6 deletions tardis/transport/montecarlo/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def verysimple_estimators(nb_simulation_verysimple):

@pytest.fixture(scope="package")
def verysimple_vpacket_collection(nb_simulation_verysimple):
spectrum_frequency = (
nb_simulation_verysimple.transport.spectrum_frequency.value
spectrum_frequency_grid = (
nb_simulation_verysimple.transport.spectrum_frequency_grid.value
)
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency=spectrum_frequency,
spectrum_frequency_grid=spectrum_frequency_grid,
number_of_vpackets=0,
v_packet_spawn_start_frequency=0,
v_packet_spawn_end_frequency=np.inf,
Expand All @@ -82,12 +82,12 @@ def verysimple_vpacket_collection(nb_simulation_verysimple):

@pytest.fixture(scope="package")
def verysimple_3vpacket_collection(nb_simulation_verysimple):
spectrum_frequency = (
nb_simulation_verysimple.transport.spectrum_frequency.value
spectrum_frequency_grid = (
nb_simulation_verysimple.transport.spectrum_frequency_grid.value
)
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency=spectrum_frequency,
spectrum_frequency_grid=spectrum_frequency_grid,
number_of_vpackets=3,
v_packet_spawn_start_frequency=0,
v_packet_spawn_end_frequency=np.inf,
Expand Down

0 comments on commit 72c8220

Please sign in to comment.