From eaa8105060024ecdc07844334eb42d37f4e5ad25 Mon Sep 17 00:00:00 2001 From: Andrew Fullard Date: Mon, 1 Jul 2024 14:22:15 -0400 Subject: [PATCH] Fix vis tests --- tardis/simulation/base.py | 26 ++++++++++++------- tardis/spectrum/base.py | 7 +++-- tardis/spectrum/spectrum.py | 4 +-- .../tests/test_tardis_full_formal_integral.py | 7 +++-- tardis/transport/montecarlo/base.py | 2 ++ .../transport/montecarlo/tests/test_base.py | 4 +-- tardis/visualization/widgets/line_info.py | 10 +++---- 7 files changed, 37 insertions(+), 23 deletions(-) diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 97cf102a714..49318aadd88 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -135,7 +135,7 @@ def __init__( show_convergence_plots, convergence_plots_kwargs, show_progress_bars, - integrator_settings + integrator_settings, ): super(Simulation, self).__init__( iterations, simulation_state.no_of_shells @@ -207,7 +207,11 @@ def __init__( ) def estimate_t_inner( - self, input_t_inner, luminosity_requested, emitted_luminosity, t_inner_update_exponent=-0.5 + self, + input_t_inner, + luminosity_requested, + emitted_luminosity, + t_inner_update_exponent=-0.5, ): luminosity_ratios = ( (emitted_luminosity / luminosity_requested).to(1).value @@ -378,7 +382,7 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): iteration=self.iterations_executed, ) - self.transport.run( + v_packets_energy_hist = self.transport.run( transport_state, time_explosion=self.simulation_state.time_explosion, iteration=self.iterations_executed, @@ -386,10 +390,16 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): show_progress_bars=self.show_progress_bars, ) - self.spectrum_solver = SpectrumSolver(transport_state, self.transport.spectrum_frequency) + self.spectrum_solver = SpectrumSolver( + transport_state, self.transport.spectrum_frequency + ) self.spectrum_solver.integrator_settings = self.integrator_settings + self.spectrum_solver._montecarlo_virtual_luminosity.value[ + : + ] = v_packets_energy_hist + self.spectrum_solver._integrator = FormalIntegrator( self.simulation_state, self.plasma, self.transport ) @@ -400,10 +410,8 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): if np.sum(output_energy < 0) == len(output_energy): logger.critical("No r-packet escaped through the outer boundary.") - emitted_luminosity = ( - self.spectrum_solver.calculate_emitted_luminosity( - self.luminosity_nu_start, self.luminosity_nu_end - ) + emitted_luminosity = self.spectrum_solver.calculate_emitted_luminosity( + self.luminosity_nu_start, self.luminosity_nu_end ) reabsorbed_luminosity = ( self.spectrum_solver.calculate_reabsorbed_luminosity( @@ -760,5 +768,5 @@ def from_config( convergence_strategy=config.montecarlo.convergence_strategy, convergence_plots_kwargs=convergence_plots_kwargs, show_progress_bars=show_progress_bars, - integrator_settings=config.spectrum.integrated + integrator_settings=config.spectrum.integrated, ) diff --git a/tardis/spectrum/base.py b/tardis/spectrum/base.py index 8eed00eb0e9..c985fb1023f 100644 --- a/tardis/spectrum/base.py +++ b/tardis/spectrum/base.py @@ -27,7 +27,7 @@ def __init__(self, transport_state, spectrum_frequency): self.spectrum_frequency = spectrum_frequency self._montecarlo_virtual_luminosity = u.Quantity( np.zeros_like(self.spectrum_frequency.value), "erg / s" - ) # should be init with v_packets_energy_hist + ) # should be init with v_packets_energy_hist self._integrator = None self.integrator_settings = None self._spectrum_integrated = None @@ -186,4 +186,7 @@ def from_config(cls, config, v_packets_energy_hist): num=config.spectrum.num + 1, ) - return cls(transport_state=None, spectrum_frequency=spectrum_frequency,) + return cls( + transport_state=None, + spectrum_frequency=spectrum_frequency, + ) diff --git a/tardis/spectrum/spectrum.py b/tardis/spectrum/spectrum.py index 2caa10fe269..79b3e832650 100644 --- a/tardis/spectrum/spectrum.py +++ b/tardis/spectrum/spectrum.py @@ -40,9 +40,7 @@ def __init__(self, _frequency, luminosity): self.luminosity = luminosity.to("erg / s") l_nu_unit = u.def_unit(r"erg s^{-1} Hz^{-1}", u.Unit("erg/(s Hz)")) - l_lambda_unit = u.def_unit( - r"erg s^{-1} \AA^{-1}", u.Unit("erg/(s AA)") - ) + l_lambda_unit = u.def_unit(r"erg s^{-1} \AA^{-1}", u.Unit("erg/(s AA)")) self.frequency = self._frequency[:-1] self.delta_frequency = self._frequency[1] - self._frequency[0] diff --git a/tardis/tests/test_tardis_full_formal_integral.py b/tardis/tests/test_tardis_full_formal_integral.py index aa2690851fc..2ec0de124ad 100644 --- a/tardis/tests/test_tardis_full_formal_integral.py +++ b/tardis/tests/test_tardis_full_formal_integral.py @@ -60,7 +60,9 @@ def simulation( return simulation else: simulation.spectrum_solver.hdf_properties = ["spectrum"] - simulation.spectrum_solver.to_hdf(tardis_ref_data, "", self.name, overwrite=True) + simulation.spectrum_solver.to_hdf( + tardis_ref_data, "", self.name, overwrite=True + ) simulation.transport.hdf_properties = ["transport_state"] simulation.transport.to_hdf( tardis_ref_data, "", self.name, overwrite=True @@ -97,5 +99,6 @@ def test_spectrum_integrated(self, simulation, refdata): ) assert_quantity_allclose( - simulation.spectrum_solver.spectrum_integrated.luminosity, luminosity + simulation.spectrum_solver.spectrum_integrated.luminosity, + luminosity, ) diff --git a/tardis/transport/montecarlo/base.py b/tardis/transport/montecarlo/base.py index 236aba3c1a0..dd417e3d6d1 100644 --- a/tardis/transport/montecarlo/base.py +++ b/tardis/transport/montecarlo/base.py @@ -212,6 +212,8 @@ def run( self.montecarlo_configuration.ENABLE_VPACKET_TRACKING ) + return v_packets_energy_hist + @classmethod def from_config( cls, config, packet_source, enable_virtual_packet_logging=False diff --git a/tardis/transport/montecarlo/tests/test_base.py b/tardis/transport/montecarlo/tests/test_base.py index 69fe62366d0..c35f94cfbe6 100644 --- a/tardis/transport/montecarlo/tests/test_base.py +++ b/tardis/transport/montecarlo/tests/test_base.py @@ -42,8 +42,8 @@ def test_hdf_transport( "output_energy", "nu_bar_estimator", "j_estimator", - #"montecarlo_virtual_luminosity", - #"packet_luminosity", + # "montecarlo_virtual_luminosity", + # "packet_luminosity", # These are nested properties that should be tested differently # "spectrum", # "spectrum_virtual", diff --git a/tardis/visualization/widgets/line_info.py b/tardis/visualization/widgets/line_info.py index e605e4e1679..e880bac7cec 100644 --- a/tardis/visualization/widgets/line_info.py +++ b/tardis/visualization/widgets/line_info.py @@ -128,7 +128,7 @@ def from_simulation(cls, sim): ------- LineInfoWidget object """ - transport_state = sim.transport.transport_state + spectrum_solver = sim.spectrum_solver return cls( lines_data=sim.plasma.lines.reset_index().set_index("line_id"), line_interaction_analysis={ @@ -137,12 +137,12 @@ def from_simulation(cls, sim): ) for filter_mode in cls.FILTER_MODES }, - spectrum_wavelength=transport_state.spectrum.wavelength, - spectrum_luminosity_density_lambda=transport_state.spectrum.luminosity_density_lambda.to( + spectrum_wavelength=spectrum_solver.spectrum.wavelength, + spectrum_luminosity_density_lambda=spectrum_solver.spectrum.luminosity_density_lambda.to( "erg/(s AA)" ), - virt_spectrum_wavelength=transport_state.spectrum_virtual.wavelength, - virt_spectrum_luminosity_density_lambda=transport_state.spectrum_virtual.luminosity_density_lambda.to( + virt_spectrum_wavelength=spectrum_solver.spectrum_virtual.wavelength, + virt_spectrum_luminosity_density_lambda=spectrum_solver.spectrum_virtual.luminosity_density_lambda.to( "erg/(s AA)" ), )