Skip to content

Commit

Permalink
Fast formal integral (#2731)
Browse files Browse the repository at this point in the history
* Added faster double-grid formal integral method

* Removed the old  cuda formal integral method

* Ran Black

* Edited some comments

* Updated docstrings to be more descriptive, added in better comments on cuda device objecst

* removed comment about boundary condition (verified by Jack as correct treatment)

* Removed unneccessary branching

* Ran black

* small implimentation detail with indexing

* added back the z_end

* Update tardis/spectrum/formal_integral_cuda.py

Co-authored-by: Jing Lu <[email protected]>

* Fixed logic issue with how the first iteration is handled

* ran black

---------

Co-authored-by: Andrew <[email protected]>
Co-authored-by: Jing Lu <[email protected]>
  • Loading branch information
3 people authored Jul 23, 2024
1 parent 02985a6 commit 38a35c3
Showing 1 changed file with 24 additions and 33 deletions.
57 changes: 24 additions & 33 deletions tardis/spectrum/formal_integral_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ def cuda_formal_integral(
pJred_lu = int(offset + idx_nu_start)
pJblue_lu = int(offset + idx_nu_start)

first = 1

# loop over all interactions
for i in range(size_z - 1):
escat_op = electron_density[int(shell_id_thread[i])] * SIGMA_THOMSON
Expand All @@ -175,46 +177,35 @@ def cuda_formal_integral(
) # +1 is the offset as the original is from z[1:]

nu_end_idx = line_search_cuda(line_list_nu, nu_end, len(line_list_nu))
zend = time_explosion / C_INV * (1.0 - line_list_nu[pline] / nu)
escat_contrib += (
(zend - zstart)
* escat_op
* (Jblue_lu[pJblue_lu] - I_nu_thread[p_idx])
)
pJred_lu += 1
I_nu_thread[p_idx] += escat_contrib
# // Lucy 1999, Eq 26
I_nu_thread[p_idx] *= exp_tau[pexp_tau]
I_nu_thread[p_idx] += att_S_ul[patt_S_ul]

# // reset e-scattering opacity
escat_contrib = 0.0
zstart = zend

pline += 1
pexp_tau += 1
patt_S_ul += 1
pJblue_lu += 1

for _ in range(1, max(nu_end_idx - pline, 1)):
for _ in range(max(nu_end_idx - pline, 0)):
# calculate e-scattering optical depth to next resonance point
zend = time_explosion / C_INV * (1.0 - line_list_nu[pline] / nu)

# Account for e-scattering, c.f. Eqs 27, 28 in Lucy 1999
Jkkp = 0.5 * (Jred_lu[pJred_lu] + Jblue_lu[pJblue_lu])
escat_contrib += (
(zend - zstart) * escat_op * (Jkkp - I_nu_thread[p_idx])
)
# this introduces the necessary offset of one element between
# pJblue_lu and pJred_lu
pJred_lu += 1
zend = (
time_explosion / C_INV * (1.0 - line_list_nu[pline] / nu)
) # check

if first == 1:
escat_contrib += (
(zend - zstart)
* escat_op
* (Jblue_lu[pJblue_lu] - I_nu_thread[p_idx])
)
first = 0
else:
# Account for e-scattering, c.f. Eqs 27, 28 in Lucy 1999
Jkkp = 0.5 * (Jred_lu[pJred_lu] + Jblue_lu[pJblue_lu])
escat_contrib += (
(zend - zstart) * escat_op * (Jkkp - I_nu_thread[p_idx])
)
# this introduces the necessary ffset of one element between
# pJblue_lu and pJred_lu
pJred_lu += 1
I_nu_thread[p_idx] += escat_contrib
# // Lucy 1999, Eq 26
I_nu_thread[p_idx] *= exp_tau[pexp_tau]
I_nu_thread[p_idx] += att_S_ul[patt_S_ul]

# // reset e-scattering opacity
escat_contrib = 0.0
escat_contrib = 0
zstart = zend

pline += 1
Expand Down

0 comments on commit 38a35c3

Please sign in to comment.