Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed May 8, 2024
1 parent ddc3d2c commit 8ddf424
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 19 deletions.
4 changes: 3 additions & 1 deletion tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def initialize_transport_state(
opacity_state=opacity_state,
)

transport_state.enable_full_relativity = self.montecarlo_configuration.ENABLE_FULL_RELATIVITY
transport_state.enable_full_relativity = (
self.montecarlo_configuration.ENABLE_FULL_RELATIVITY
)
transport_state.integrator_settings = self.integrator_settings
transport_state._integrator = FormalIntegrator(
simulation_state, plasma, self
Expand Down
4 changes: 1 addition & 3 deletions tardis/transport/montecarlo/formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ def numba_formal_integral(
# calculate e-scattering optical depth to grid cell boundary

Jkkp = 0.5 * (Jred_lu[pJred_lu] + Jblue_lu[pJblue_lu])
zend = (
time_explosion / C_INV * (1.0 - nu_end / nu)
) # check
zend = time_explosion / C_INV * (1.0 - nu_end / nu) # check
escat_contrib += (
(zend - zstart) * escat_op * (Jkkp - I_nu[p_idx])
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ def time_explosion():
not GPUs_available, reason="No GPU is available to test CUDA function"
)
@pytest.mark.parametrize(["p", "p_loc"], [(0.0, 0), (0.5, 1), (1.0, 2)])
def test_calculate_z_cuda(
formal_integral_geometry, time_explosion, p, p_loc
):
def test_calculate_z_cuda(formal_integral_geometry, time_explosion, p, p_loc):
"""
Initializes the test of the cuda version
against the numba implementation of the
Expand Down
14 changes: 4 additions & 10 deletions tardis/transport/montecarlo/tests/test_numba_formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def formal_integral_geometry(request):

@pytest.fixture(scope="function")
def time_explosion():
return 1 / c.c.cgs.value
return 1 / c.c.cgs.value


@pytest.mark.parametrize("p", [0.0, 0.5, 1.0])
Expand All @@ -89,16 +89,12 @@ def test_calculate_z(formal_integral_geometry, time_explosion, p):


@pytest.mark.parametrize("p", [0, 0.5, 1])
def test_populate_z_photosphere(
formal_integral_geometry, time_explosion, p
):
def test_populate_z_photosphere(formal_integral_geometry, time_explosion, p):
"""
Test the case where p < r[0]
That means we 'hit' all shells from inside to outside.
"""
integrator = formal_integral.FormalIntegrator(
time_explosion, None, None
)
integrator = formal_integral.FormalIntegrator(time_explosion, None, None)
func = formal_integral.populate_z
size = len(formal_integral_geometry.r_outer)
r_inner = formal_integral_geometry.r_inner
Expand All @@ -121,9 +117,7 @@ def test_populate_z_shells(formal_integral_geometry, time_explosion, p):
"""
Test the case where p > r[0]
"""
integrator = formal_integral.FormalIntegrator(
time_explosion, None, None
)
integrator = formal_integral.FormalIntegrator(time_explosion, None, None)
func = formal_integral.populate_z

size = len(formal_integral_geometry.r_inner)
Expand Down
4 changes: 3 additions & 1 deletion tardis/transport/montecarlo/tests/test_packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ def test_trace_packet(
set_seed_fixture,
):
set_seed_fixture(1963)
packet.initialize_line_id(verysimple_opacity_state, verysimple_time_explosion)
packet.initialize_line_id(
verysimple_opacity_state, verysimple_time_explosion
)
distance, interaction_type, delta_shell = r_packet_transport.trace_packet(
packet,
verysimple_time_explosion,
Expand Down
4 changes: 3 additions & 1 deletion tardis/transport/montecarlo/tests/test_vpacket.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def test_trace_vpacket_volley(
# Set seed because of RNG in trace_vpacket
np.random.seed(1)

packet.initialize_line_id(verysimple_opacity_state, verysimple_time_explosion)
packet.initialize_line_id(
verysimple_opacity_state, verysimple_time_explosion
)

vpacket.trace_vpacket_volley(
packet,
Expand Down

0 comments on commit 8ddf424

Please sign in to comment.