diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index dece7ce..a777cd8 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -16,6 +16,7 @@ TileCommunicator, TilePartitioner, ) +from ndsl.optional_imports import cupy as cp def _get_factories( @@ -74,7 +75,9 @@ def _get_factories( grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, comm) stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) - quantity_factory = QuantityFactory(sizer, np) + quantity_factory = QuantityFactory( + sizer, cp if stencil_config.is_gpu_backend else np + ) return stencil_factory, quantity_factory