diff --git a/tardis/workflows/simple_tardis_workflow.py b/tardis/workflows/simple_tardis_workflow.py index c7d0649eb11..5ed88dd2f6e 100644 --- a/tardis/workflows/simple_tardis_workflow.py +++ b/tardis/workflows/simple_tardis_workflow.py @@ -36,35 +36,30 @@ def __init__(self, configuration): super().__init__(configuration, self.log_level, self.specific_log_level) atom_data = parse_atom_data(configuration) - self.line_interaction_type = configuration.plasma.line_interaction_type - # set up states and solvers self.simulation_state = SimulationState.from_config( configuration, atom_data=atom_data, ) - self.opacity_state = None - self.opacity_solver = OpacitySolver( - self.line_interaction_type, - configuration.plasma.disable_line_scattering, - ) - - self.macro_atom_state = None - if self.line_interaction_type in ( - "downbranch", - "macroatom", - ): - self.macro_atom_solver = MacroAtomSolver() - else: - self.macro_atom_solver = None - self.plasma_solver = assemble_plasma( configuration, self.simulation_state, atom_data=atom_data, ) + line_interaction_type = configuration.plasma.line_interaction_type + + self.opacity_solver = OpacitySolver( + line_interaction_type, + configuration.plasma.disable_line_scattering, + ) + + if line_interaction_type == "scatter": + self.macro_atom_solver = None + else: + self.macro_atom_solver = MacroAtomSolver() + self.transport_solver = MonteCarloTransportSolver.from_config( configuration, packet_source=self.simulation_state.packet_source, @@ -326,11 +321,43 @@ def solve_plasma(self, estimated_radfield_properties): self.plasma_solver.update(**update_properties) - def solve_montecarlo(self, no_of_real_packets, no_of_virtual_packets=0): + def solve_opacity(self): + """Solves the opacity state and any associated objects + + Returns + ------- + dict + opacity_state : tardis.opacities.opacity_state.OpacityState + State of the line opacities + macro_atom_state : tardis.opacities.macro_atom.macro_atom_state.MacroAtomState or None + State of the macro atom + """ + opacity_state = self.opacity_solver.solve(self.plasma_solver) + + if self.macro_atom_solver is None: + macro_atom_state = None + else: + macro_atom_state = self.macro_atom_solver.solve( + self.plasma_solver, + self.plasma_solver.atomic_data, + opacity_state.tau_sobolev, + self.plasma_solver.stimulated_emission_factor, + ) + + return { + "opacity_state": opacity_state, + "macro_atom_state": macro_atom_state, + } + + def solve_montecarlo( + self, opacity_states, no_of_real_packets, no_of_virtual_packets=0 + ): """Solve the MonteCarlo process Parameters ---------- + opacity_states : dict + Opacity and (optionally) Macro Atom states. no_of_real_packets : int Number of real packets to simulate no_of_virtual_packets : int, optional @@ -343,10 +370,13 @@ def solve_montecarlo(self, no_of_real_packets, no_of_virtual_packets=0): ndarray Array of unnormalized virtual packet energies in each frequency bin """ + opacity_state = opacity_states["opacity_state"] + macro_atom_state = opacity_states["macro_atom_state"] + transport_state = self.transport_solver.initialize_transport_state( self.simulation_state, - self.opacity_state, - self.macro_atom_state, + opacity_state, + macro_atom_state, self.plasma_solver, no_of_real_packets, no_of_virtual_packets=no_of_virtual_packets, @@ -405,18 +435,10 @@ def run(self): f"\n\tStarting iteration {(self.completed_iterations + 1):d} of {self.total_iterations:d}" ) - self.opacity_state = self.opacity_solver.solve(self.plasma_solver) - - if self.macro_atom_solver is not None: - self.macro_atom_state = self.macro_atom_solver.solve( - self.plasma_solver, - self.plasma_solver.atomic_data, - self.opacity_state.tau_sobolev, - self.plasma_solver.stimulated_emission_factor, - ) + opacity_states = self.solve_opacity() transport_state, virtual_packet_energies = self.solve_montecarlo( - self.real_packet_count + opacity_states, self.real_packet_count ) ( @@ -441,7 +463,9 @@ def run(self): "\n\tITERATIONS HAVE NOT CONVERGED, starting final iteration" ) transport_state, virtual_packet_energies = self.solve_montecarlo( - self.final_iteration_packet_count, self.virtual_packet_count + opacity_states, + self.final_iteration_packet_count, + self.virtual_packet_count, ) self.initialize_spectrum_solver(