Skip to content

Commit

Permalink
GPU Options (#1919)
Browse files Browse the repository at this point in the history
* Added parsing method for formal_integral

* Remove random spaces

* Removed space, specified arguments

* remove cuda import
  • Loading branch information
KevinCawley authored Mar 3, 2022
1 parent dadc137 commit f541e73
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
15 changes: 15 additions & 0 deletions tardis/io/schemas/spectrum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ properties:
integral quantities are interpolated. For -1 no interpolation
is used. The default is to use twice the number of computational
shells but at least 80.
compute:
type: string
default: "CPU"
description: Which method the formal_integral will be computed with.
It defaults to the Numba version, but it can be run on NVIDIA Cuda
GPUs. GPU will make it only run on a NVIDIA Cuda GPU, so if one is
not available it will raise an error. Automatic first tries to find
an acceptable GPU, and if none is found it will run on the CPU.
properties:
type:
enum:
- "CPU"
- "GPU"
- "Automatic"

virtual:
type: object
default: {}
Expand Down
32 changes: 31 additions & 1 deletion tardis/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astropy import units as u
from tardis import constants as const
from numba import set_num_threads
from numba import cuda

from scipy.special import zeta
from tardis.montecarlo.spectrum import TARDISSpectrum
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
logger_buffer=1,
single_packet_seed=None,
tracking_rpacket=False,
use_gpu=False,
):

self.seed = seed
Expand All @@ -124,6 +126,7 @@ def __init__(
self.seed = seed
self._integrator = None
self._spectrum_integrated = None
self.use_gpu=use_gpu

self.virt_logging = virtual_packet_logging
self.virt_packet_last_interaction_type = np.ones(2) * -1
Expand Down Expand Up @@ -266,8 +269,11 @@ def spectrum_virtual(self):
@property
def spectrum_integrated(self):
if self._spectrum_integrated is None:
#This was changed from unpacking to specific attributes as compute
#is not used in calculate_spectrum
self._spectrum_integrated = self.integrator.calculate_spectrum(
self.spectrum_frequency[:-1], **self.integrator_settings
self.spectrum_frequency[:-1], points=self.integrator_settings.points,
interpolate_shells=self.integrator_settings.interpolate_shells,
)
return self._spectrum_integrated

Expand Down Expand Up @@ -635,6 +641,29 @@ def from_config(
config.spectrum.start.to("Hz", u.spectral()),
num=config.spectrum.num + 1,
)
running_mode=config.spectrum.integrated.compute.upper()

if running_mode == "GPU":
if cuda.is_available():
use_gpu = True
else:
raise ValueError(
"""The GPU option was selected for the formal_integral,
but no CUDA GPU is available."""
)
elif running_mode == "AUTOMATIC":
if cuda.is_available():
use_gpu = True
else:
use_gpu = False
elif running_mode == "CPU":
use_gpu = False
else:
raise ValueError(
"""An invalid option for compute was passed. The three
valid values are 'GPU', 'CPU', and 'Automatic'."""
)

mc_config_module.disable_line_scattering = (
config.plasma.disable_line_scattering
)
Expand Down Expand Up @@ -664,4 +693,5 @@ def from_config(
| virtual_packet_logging
),
tracking_rpacket=config.montecarlo.tracking.track_rpacket,
use_gpu=use_gpu,
)
3 changes: 1 addition & 2 deletions tardis/montecarlo/montecarlo_numba/formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
NumbaModel,
NumbaPlasma,
)
from numba import cuda
from tardis.montecarlo.montecarlo_numba.formal_integral_cuda import CudaFormalIntegrator

from tardis.montecarlo.spectrum import TARDISSpectrum
Expand Down Expand Up @@ -287,7 +286,7 @@ def generate_numba_objects(self):
self.numba_plasma = numba_plasma_initialize(
self.original_plasma, self.runner.line_interaction_type
)
if cuda.is_available():
if self.runner.use_gpu:
self.integrator = CudaFormalIntegrator(
self.numba_model, self.numba_plasma, self.points
)
Expand Down

0 comments on commit f541e73

Please sign in to comment.