Skip to content

Commit

Permalink
Applied numba where easiest
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewfullard committed Dec 16, 2021
1 parent 41eb63f commit 342c9c5
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 2 deletions.
6 changes: 5 additions & 1 deletion tardis/energy_input/calculate_opacity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import astropy.units as u
from numba import njit

from tardis import constants as const
from tardis.energy_input.util import kappa_calculation
Expand All @@ -10,7 +11,8 @@
M_P = const.m_p.to(u.g).value
SIGMA_T = const.sigma_T.cgs.value

# TODO: add units for completeness

@njit
def compton_opacity_calculation(energy, ejecta_density):
"""Calculate the Compton scattering opacity for a given energy
(Rybicki & Lightman, 1979)
Expand Down Expand Up @@ -54,6 +56,7 @@ def compton_opacity_calculation(energy, ejecta_density):
return ejecta_density / (M_P * 2) * sigma_KN


@njit
def photoabsorption_opacity_calculation(
energy, ejecta_density, iron_group_fraction
):
Expand Down Expand Up @@ -93,6 +96,7 @@ def photoabsorption_opacity_calculation(
return si_opacity + fe_opacity


@njit
def pair_creation_opacity_calculation(
energy, ejecta_density, iron_group_fraction
):
Expand Down
3 changes: 3 additions & 0 deletions tardis/energy_input/energy_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
from nuclear.ejecta import Ejecta
from numba import njit


def decay_nuclides(shell_mass, initial_composition, epoch):
Expand All @@ -10,6 +11,7 @@ def decay_nuclides(shell_mass, initial_composition, epoch):
return new_fractions


@njit
def create_energy_cdf(energy, intensity):
"""Creates a CDF of given intensities
Expand Down Expand Up @@ -37,6 +39,7 @@ def create_energy_cdf(energy, intensity):
return energy, cdf


@njit
def sample_energy_distribution(energy_sorted, cdf):
"""Randomly samples a CDF of energies
Expand Down
1 change: 1 addition & 0 deletions tardis/energy_input/gamma_ray_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from nuclear.io.nndc import get_decay_radiation_database, store_decay_radiation
import pandas as pd
import astropy.units as u
from numba import njit

from tardis.energy_input.util import (
solve_quadratic_equation,
Expand Down
7 changes: 6 additions & 1 deletion tardis/energy_input/gamma_ray_interactions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import numpy as np
from numba import njit

from tardis.energy_input.util import (
kappa_calculation,
Expand All @@ -15,6 +16,7 @@
from tardis.energy_input.GXPhoton import GXPhotonStatus


@njit
def get_compton_angle(energy):
"""
Computes the compton angle from the Klein-Nishina equation.
Expand Down Expand Up @@ -65,7 +67,9 @@ def compton_scatter(photon, compton_angle):
Photon phi direction
"""
# transform original direction vector to cartesian coordinates
original_direction = normalize_vector(photon.direction.cartesian_coords)
original_direction = normalize_vector(
np.array(photon.direction.cartesian_coords)
)
# compute an arbitrary perpendicular vector to the original direction
orthogonal_vector = get_perpendicular_vector(original_direction)
# determine a random vector with compton_angle to the original direction
Expand Down Expand Up @@ -126,6 +130,7 @@ def pair_creation(photon):
return photon, backward_ray


@njit
def scatter_type(compton_opacity, photoabsorption_opacity, total_opacity):
"""
Determines the scattering type based on process opacities
Expand Down
12 changes: 12 additions & 0 deletions tardis/energy_input/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import astropy.units as u
import tardis.constants as const
import numpy as np
from numba import njit

R_ELECTRON_SQUARED = (const.a0.cgs.value * const.alpha.cgs.value ** 2.0) ** 2.0
ELECTRON_MASS_ENERGY_KEV = (const.m_e * const.c ** 2.0).to("keV").value
Expand Down Expand Up @@ -47,20 +48,23 @@ def cartesian_coords(self):
return x, y, z


@njit
def spherical_to_cartesian(r, theta, phi):
x = r * np.cos(phi) * np.sin(theta)
y = r * np.sin(phi) * np.sin(theta)
z = r * np.cos(theta)
return x, y, z


@njit
def cartesian_to_spherical(x, y, z):
r = np.sqrt(x ** 2 + y ** 2 + z ** 2)
theta = np.arccos(z / r)
phi = np.arctan2(y, x)
return r, theta, phi


@njit
def kappa_calculation(energy):
"""
Calculates kappa for various other calculations
Expand All @@ -79,6 +83,7 @@ def kappa_calculation(energy):
return energy / ELECTRON_MASS_ENERGY_KEV


@njit
def euler_rodrigues(theta, direction):
"""
Calculates the Euler-Rodrigues rotation matrix
Expand Down Expand Up @@ -116,6 +121,7 @@ def euler_rodrigues(theta, direction):
)


@njit
def solve_quadratic_equation(x, y, z, x_dir, y_dir, z_dir, radius_velocity):
"""
Solves the quadratic equation for the distance to the shell boundary
Expand Down Expand Up @@ -145,6 +151,7 @@ def solve_quadratic_equation(x, y, z, x_dir, y_dir, z_dir, radius_velocity):
return solution_1, solution_2


@njit
def klein_nishina(energy, theta_C):
"""
Calculates the Klein-Nishina equation
Expand Down Expand Up @@ -181,6 +188,7 @@ def klein_nishina(energy, theta_C):
)


@njit
def compton_theta_distribution(energy, sample_resolution=100):
"""
Calculates the cumulative distribution function of theta angles
Expand All @@ -205,6 +213,7 @@ def compton_theta_distribution(energy, sample_resolution=100):
return theta_angles, norm_theta_distribution


@njit
def get_random_theta_photon():
"""Get a random theta direction between 0 and pi
Returns
Expand All @@ -215,6 +224,7 @@ def get_random_theta_photon():
return np.arccos(1.0 - 2.0 * np.random.random())


@njit
def get_random_phi_photon():
"""Get a random phi direction between 0 and 2 * pi
Expand Down Expand Up @@ -252,6 +262,7 @@ def convert_half_life_to_astropy_units(half_life_string):
return half_life_with_unit.to(u.s)


@njit
def normalize_vector(vector):
"""
Normalizes a vector in cartesian coordinates
Expand All @@ -269,6 +280,7 @@ def normalize_vector(vector):
return vector / np.linalg.norm(vector)


@njit
def get_perpendicular_vector(original_direction):
"""
Computes a vector which is perpendicular to the input vector
Expand Down

0 comments on commit 342c9c5

Please sign in to comment.