From f99914a3c47b038fb5f145126ceda9e3bfdd7cd4 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 27 Dec 2024 09:46:18 -0500 Subject: [PATCH] Make sure the correct allocator backend is used for Quantities --- ndsl/boilerplate.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ndsl/boilerplate.py b/ndsl/boilerplate.py index dece7ce..a777cd8 100644 --- a/ndsl/boilerplate.py +++ b/ndsl/boilerplate.py @@ -16,6 +16,7 @@ TileCommunicator, TilePartitioner, ) +from ndsl.optional_imports import cupy as cp def _get_factories( @@ -74,7 +75,9 @@ def _get_factories( grid_indexing = GridIndexing.from_sizer_and_communicator(sizer, comm) stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) - quantity_factory = QuantityFactory(sizer, np) + quantity_factory = QuantityFactory( + sizer, cp if stencil_config.is_gpu_backend else np + ) return stencil_factory, quantity_factory