diff --git a/examples/rasters/wv_indices.py b/examples/rasters/wv_indices.py index 8e88da5..131c42d 100755 --- a/examples/rasters/wv_indices.py +++ b/examples/rasters/wv_indices.py @@ -2,7 +2,10 @@ from terragpu import engine from terragpu.array.raster import Raster from terragpu.indices.wv_indices import add_indices +from terragpu.engine import array_module, df_module +xp = array_module() +xf = df_module() def main(filename, bands): @@ -38,10 +41,11 @@ def main(filename, bands): ] # Start dask cluster - dask scheduler must be started from main - engine.configure_dask( - device='gpu', - n_workers=4, - local_directory='/lscratch/jacaraba') + if xp.__name__ == 'cupy': + engine.configure_dask( + device='gpu', + n_workers=4, + local_directory='/lscratch/jacaraba') # Execute main function and calculate indices main(filename, bands) diff --git a/requirements/environment_cpu.yaml b/requirements/environment_cpu.yaml index 8756dfc..51cd91c 100755 --- a/requirements/environment_cpu.yaml +++ b/requirements/environment_cpu.yaml @@ -33,3 +33,4 @@ dependencies: - pdoc3 - flake8 - coverage + - numba diff --git a/requirements/environment_gpu.yml b/requirements/environment_gpu.yml index 7025379..3467d0e 100755 --- a/requirements/environment_gpu.yml +++ b/requirements/environment_gpu.yml @@ -28,4 +28,5 @@ dependencies: - tqdm - pdoc3 - flake8 - - coverage \ No newline at end of file + - coverage + - numba diff --git a/terragpu/engine.py b/terragpu/engine.py index 0090aaa..22fbe9e 100755 --- a/terragpu/engine.py +++ b/terragpu/engine.py @@ -4,8 +4,14 @@ import pandas as pd import xarray as xr from types import ModuleType -from dask_cuda import LocalCUDACluster -from dask.distributed import Client, LocalCluster + +try: + from dask_cuda import LocalCUDACluster + from dask.distributed import Client, LocalCluster + HAS_GPU = True +except ModuleNotFoundError: + logging.info("Not importing Dask CUDA libraries") + HAS_GPU = False _warn_array_module_once = False _warn_df_module_once = False