Skip to content

Commit

Permalink
Ruff tardis/transport (#2829)
Browse files Browse the repository at this point in the history
* ruff autofix tardis/transport safe fixes

* ruff autofix tardis/transport unsafe fixes

* fix ruff unsafe rule PIE790

* black 2 files

---------

Co-authored-by: Andrew Fullard <[email protected]>
  • Loading branch information
atharva-2001 and andrewfullard authored Sep 25, 2024
1 parent 59c5d28 commit de4d872
Show file tree
Hide file tree
Showing 32 changed files with 83 additions and 125 deletions.
1 change: 0 additions & 1 deletion tardis/transport/frame_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.configuration.constants import C_SPEED_OF_LIGHT


Expand Down
9 changes: 1 addition & 8 deletions tardis/transport/geometry/calculate_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,13 @@
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)

from tardis.transport.montecarlo.configuration.constants import (
C_SPEED_OF_LIGHT,
CLOSE_LINE_THRESHOLD,
MISS_DISTANCE,
SIGMA_THOMSON,
CLOSE_LINE_THRESHOLD,
)

from tardis.transport.montecarlo.utils import MonteCarloException
from tardis.transport.montecarlo.r_packet import (
print_r_packet_properties,
)


@njit(**njit_dict_no_parallel)
Expand All @@ -35,7 +30,6 @@ def calculate_distance_boundary(r, mu, r_inner, r_outer):
r_outer : float
outer radius of current shell
"""

delta_shell = 0
if mu > 0.0:
# direction outward
Expand Down Expand Up @@ -88,7 +82,6 @@ def calculate_distance_line(
Returns
-------
"""

nu = r_packet.nu

if is_last_line:
Expand Down
6 changes: 3 additions & 3 deletions tardis/transport/montecarlo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Faciliating the MonteCarlo iterations.
Faciliating the MonteCarlo iterations.
During a simulation run, a number of MonteCarlo iterations specified
in the configuration are run using the numba compiler.
Most of the iterations are used to calculate the steady-state plasma
Most of the iterations are used to calculate the steady-state plasma
properties and with the last iteration, the spectrum is determined.
"""

Expand All @@ -21,7 +21,7 @@
"parallel": False,
}

from tardis.transport.montecarlo.r_packet import RPacket
from tardis.transport.montecarlo.packet_collections import (
PacketCollection,
)
from tardis.transport.montecarlo.r_packet import RPacket
8 changes: 4 additions & 4 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from tardis import constants as const
from tardis.io.logger import montecarlo_tracking as mc_tracker
from tardis.io.util import HDFWriterMixin
from tardis.opacities.opacity_state import (
opacity_state_to_numba,
)
from tardis.transport.montecarlo.configuration.base import (
MonteCarloConfiguration,
configuration_initialize,
Expand All @@ -23,12 +26,9 @@
from tardis.transport.montecarlo.montecarlo_transport_state import (
MonteCarloTransportState,
)
from tardis.opacities.opacity_state import (
opacity_state_to_numba,
)
from tardis.transport.montecarlo.packet_trackers import (
generate_rpacket_tracker_list,
generate_rpacket_last_interaction_tracker_list,
generate_rpacket_tracker_list,
rpacket_trackers_to_dataframe,
)
from tardis.util.base import (
Expand Down
7 changes: 3 additions & 4 deletions tardis/transport/montecarlo/configuration/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
from astropy import units as u
from numba import float64, int64, boolean
from numba import boolean, float64, int64
from numba.experimental import jitclass
import numpy as np

from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo import montecarlo_main_loop
from tardis.transport.montecarlo.numba_interface import (
LineInteractionType,
)
Expand All @@ -29,7 +28,7 @@


@jitclass(numba_config_spec)
class MonteCarloConfiguration(object):
class MonteCarloConfiguration:
def __init__(self):
self.ENABLE_FULL_RELATIVITY = False
self.TEMPORARY_V_PACKET_BINS = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from numba import njit

from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.configuration.constants import KB, H
from tardis.transport.frame_transformations import (
calc_packet_energy,
calc_packet_energy_full_relativity,
)
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.configuration.constants import KB, H


@njit(**njit_dict_no_parallel)
Expand Down
10 changes: 5 additions & 5 deletions tardis/transport/montecarlo/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

import tardis.transport.montecarlo.configuration.montecarlo_globals as montecarlo_globals
from tardis import constants as const
from tardis.transport.frame_transformations import (
angle_aberration_CMF_to_LF,
get_doppler_factor,
get_inverse_doppler_factor,
)
from tardis.transport.montecarlo import njit_dict_no_parallel
from tardis.transport.montecarlo.macro_atom import (
MacroAtomTransitionType,
Expand All @@ -15,11 +20,6 @@
PacketStatus,
)
from tardis.transport.montecarlo.utils import get_random_mu
from tardis.transport.frame_transformations import (
angle_aberration_CMF_to_LF,
get_doppler_factor,
get_inverse_doppler_factor,
)

K_B = const.k_B.cgs.value
H = const.h.cgs.value
Expand Down
5 changes: 3 additions & 2 deletions tardis/transport/montecarlo/macro_atom.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
from enum import IntEnum

import numpy as np
from numba import njit
from tardis.transport.montecarlo import njit_dict, njit_dict_no_parallel

from tardis.transport.montecarlo import njit_dict_no_parallel


class MacroAtomError(ValueError):
Expand Down
11 changes: 6 additions & 5 deletions tardis/transport/montecarlo/nonhomologous_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def velocity_dvdr(r_packet, geometry):
geometry: Geometry
Returns
-----------
-------
v: float, current velocity
frac: float, dv/dr for current shell
"""
Expand All @@ -37,16 +37,16 @@ def tau_sobolev_factor(r_packet, geometry):
The angle and velocity dependent Tau Sobolev factor component. Is called when ENABLE_NONHOMOLOGOUS_EXPANSION is set to True.
Note: to get Tau Sobolev, this needs to be multiplied by tau_sobolevs found from plasma
Parameters
----------
r_packet: RPacket
geometry: Geometry
Returns
-----------
-------
factor = 1.0 / ((1 - mu * mu) * v / r + mu * mu * dvdr)
"""

v, dvdr = velocity_dvdr(r_packet, geometry)
r = r_packet.r
mu = r_packet.mu
Expand All @@ -61,11 +61,12 @@ def quartic_roots(a, b, c, d, e, threshold):
Uses: https://en.wikipedia.org/wiki/Quartic_function#General_formula_for_roots
Parameters
-----------
----------
a, b, c, d, e: coefficients of the equations ax^4 + bx^3 + cx^2 + dx + e = 0, float
threshold: lower needed limit on roots, float
Returns
-----------
-------
roots: real positive roots of ax^4 + bx^3 + cx^2 + dx + e = 0
"""
Expand Down
7 changes: 3 additions & 4 deletions tardis/transport/montecarlo/numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


@jitclass(opacity_state_spec)
class OpacityState(object):
class OpacityState:
def __init__(
self,
electron_density,
Expand Down Expand Up @@ -80,7 +80,6 @@ def __init__(
transition_line_id : numpy.ndarray
bf_threshold_list_nu : numpy.ndarray
"""

self.electron_density = electron_density
self.t_electrons = t_electrons
self.line_list_nu = line_list_nu
Expand Down Expand Up @@ -118,7 +117,8 @@ def __getitem__(self, i: slice):
Args:
i (slice): shell slice. Will fail if slice is int since class only supports array types
Returns:
Returns
-------
OpacityState : a shallow copy of the current instance
"""
# NOTE: This currently will not work with continuum processes since it does not slice those arrays
Expand Down Expand Up @@ -161,7 +161,6 @@ def opacity_state_initialize(
plasma : tardis.plasma.BasePlasma
line_interaction_type : enum
"""

electron_densities = plasma.electron_densities.values
t_electrons = plasma.t_electrons
line_list_nu = plasma.atomic_data.lines.nu.values
Expand Down
9 changes: 4 additions & 5 deletions tardis/transport/montecarlo/packet_source.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import abc

import numpy as np
import numexpr as ne
import numpy as np
from astropy import units as u

from tardis import constants as const
from tardis.io.util import HDFWriterMixin
from tardis.transport.montecarlo.packet_collections import (
PacketCollection,
)
from tardis.io.util import HDFWriterMixin
from astropy import units as u


class BasePacketSource(abc.ABC):
Expand Down Expand Up @@ -239,7 +240,6 @@ def create_packet_mus(self, no_of_packets):
Directions for packets
numpy.ndarray
"""

# For testing purposes
if self.legacy_mode_enabled:
return np.sqrt(np.random.random(no_of_packets))
Expand Down Expand Up @@ -269,7 +269,6 @@ def set_temperature_from_luminosity(self, luminosity: u.Quantity):
Parameters
----------
luminosity : u.Quantity
"""
Expand Down
17 changes: 8 additions & 9 deletions tardis/transport/montecarlo/packet_trackers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from numba import float64, int64, njit, from_dtype
from numba.experimental import jitclass
from numba.typed import List
import numpy as np
import pandas as pd

from numba import float64, from_dtype, int64, njit
from numba.experimental import jitclass
from numba.typed import List

boundary_interaction_dtype = np.dtype(
[
Expand Down Expand Up @@ -33,9 +32,10 @@


@jitclass(rpacket_tracker_spec)
class RPacketTracker(object):
class RPacketTracker:
"""
Numba JITCLASS for storing the information for each interaction a RPacket instance undergoes.
Parameters
----------
length : int
Expand Down Expand Up @@ -200,7 +200,7 @@ def rpacket_trackers_to_dataframe(rpacket_trackers):
rpacket_tracker_ndarray[column_name][
prev_index:cur_index
] = getattr(rpacket_tracker, column_name)
index_array[0][prev_index:cur_index] = getattr(rpacket_tracker, "index")
index_array[0][prev_index:cur_index] = rpacket_tracker.index
index_array[1][prev_index:cur_index] = range(cur_index - prev_index)
return pd.DataFrame(
rpacket_tracker_ndarray,
Expand All @@ -220,9 +220,10 @@ def rpacket_trackers_to_dataframe(rpacket_trackers):


@jitclass(rpacket_last_interaction_tracker_spec)
class RPacketLastInteractionTracker(object):
class RPacketLastInteractionTracker:
"""
Numba JITCLASS for storing the last interaction the RPacket undergoes.
Parameters
----------
index : int
Expand Down Expand Up @@ -265,14 +266,12 @@ def finalize_array(self):
"""
Added to make RPacketLastInteractionTracker compatible with RPacketTracker
"""
pass

# To make it compatible with RPacketTracker
def track_boundary_interaction(self, current_shell_id, next_shell_id):
"""
Added to make RPacketLastInteractionTracker compatible with RPacketTracker
"""
pass


@njit
Expand Down
12 changes: 5 additions & 7 deletions tardis/transport/montecarlo/r_packet.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from enum import IntEnum

import numpy as np
import pandas as pd
from numba import int64, float64, njit, objmode
from numba import float64, int64, njit, objmode
from numba.experimental import jitclass

from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)
from tardis.transport.frame_transformations import (
get_doppler_factor,
)
from tardis.transport.montecarlo import njit_dict_no_parallel
from tardis.transport.montecarlo import (
njit_dict_no_parallel,
)


class InteractionType(IntEnum):
Expand Down Expand Up @@ -48,7 +46,7 @@ class PacketStatus(IntEnum):


@jitclass(rpacket_spec)
class RPacket(object):
class RPacket:
def __init__(self, r, mu, nu, energy, seed, index=0):
self.r = r
self.mu = mu
Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/single_packet_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
get_doppler_factor,
get_inverse_doppler_factor,
)
from tardis.transport.montecarlo.r_packet import RPacket
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.estimators.radfield_estimator_calcs import (
update_bound_free_estimators,
Expand All @@ -22,6 +21,7 @@
from tardis.transport.montecarlo.r_packet import (
InteractionType,
PacketStatus,
RPacket,
)
from tardis.transport.montecarlo.r_packet_transport import (
move_packet_across_shell_boundary,
Expand Down
4 changes: 0 additions & 4 deletions tardis/transport/montecarlo/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import os
import pandas as pd
import numpy as np
import pytest
from astropy import units as u
from numpy.testing import assert_almost_equal
from pathlib import Path

###
# Save and Load
Expand Down
Loading

0 comments on commit de4d872

Please sign in to comment.