From 466128a97f00965374b99eca05e24fdd617fc8af Mon Sep 17 00:00:00 2001 From: DropD Date: Wed, 16 Oct 2024 14:56:00 +0200 Subject: [PATCH] fix dace orchestrator code path requiring `exchange_obj` --- .../model/common/orchestration/decorator.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/model/common/src/icon4py/model/common/orchestration/decorator.py b/model/common/src/icon4py/model/common/orchestration/decorator.py index 10a2b8be74..aa4603e4cc 100644 --- a/model/common/src/icon4py/model/common/orchestration/decorator.py +++ b/model/common/src/icon4py/model/common/orchestration/decorator.py @@ -65,15 +65,13 @@ @overload -def orchestrate(func: Callable[P, R], *, method: bool | None = None) -> Callable[P, R]: - ... +def orchestrate(func: Callable[P, R], *, method: bool | None = None) -> Callable[P, R]: ... @overload def orchestrate( func: None = None, *, method: bool | None = None -) -> Callable[[Callable[P, R]], Callable[P, R]]: - ... +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... def orchestrate( @@ -539,35 +537,37 @@ def dace_specific_kwargs( The additional kwargs are the connectivity tables (runtime tables) and the GHEX C++ pointers. """ - return { + dace_kwargs = { # connectivity tables at runtime - **{ - connectivity_identifier(k): v.table - for k, v in offset_providers.items() - if hasattr(v, "table") - }, - # GHEX C++ ptrs - "__context_ptr": expose_cpp_ptr(exchange_obj._context) - if not isinstance(exchange_obj, decomposition.SingleNodeExchange) - else 0, - "__comm_ptr": expose_cpp_ptr(exchange_obj._comm) - if not isinstance(exchange_obj, decomposition.SingleNodeExchange) - else 0, - **{ - f"__pattern_{dim.value}Dim_ptr": expose_cpp_ptr(exchange_obj._patterns[dim]) + connectivity_identifier(k): v.table + for k, v in offset_providers.items() + if hasattr(v, "table") + } + if exchange_obj: + dace_kwargs |= { + # GHEX C++ ptrs + "__context_ptr": expose_cpp_ptr(exchange_obj._context) if not isinstance(exchange_obj, decomposition.SingleNodeExchange) - else 0 - for dim in dims.global_dimensions.values() - }, - **{ - f"__domain_descriptor_{dim.value}Dim_ptr": expose_cpp_ptr( - exchange_obj._domain_descriptors[dim].__wrapped__ - ) + else 0, + "__comm_ptr": expose_cpp_ptr(exchange_obj._comm) if not isinstance(exchange_obj, decomposition.SingleNodeExchange) - else 0 - for dim in dims.global_dimensions.values() - }, - } + else 0, + **{ + f"__pattern_{dim.value}Dim_ptr": expose_cpp_ptr(exchange_obj._patterns[dim]) + if not isinstance(exchange_obj, decomposition.SingleNodeExchange) + else 0 + for dim in dims.global_dimensions.values() + }, + **{ + f"__domain_descriptor_{dim.value}Dim_ptr": expose_cpp_ptr( + exchange_obj._domain_descriptors[dim].__wrapped__ + ) + if not isinstance(exchange_obj, decomposition.SingleNodeExchange) + else 0 + for dim in dims.global_dimensions.values() + }, + } + return dace_kwargs def modified_orig_annotations( dace_annotations: dict[str, Any], fuse_func_orig_annotations: dict[str, Any]