diff --git a/conftest.py b/conftest.py index 99736e738e1..f6251fceba5 100644 --- a/conftest.py +++ b/conftest.py @@ -191,15 +191,14 @@ def cuda_param(val): return pytest.param(val, marks=[pytest.mark.llvm, pytest.mark.cuda]) @pytest.helpers.register -def get_func_execution(func, func_mode, *, writeback:bool=True): +def get_func_execution(func, func_mode): if func_mode == 'LLVM': ex = pnlvm.execution.FuncExecution(func) # Calling writeback here will replace parameter values # with numpy instances that share memory with the binary # structure used by the compiled function - if writeback: - ex.writeback_state_to_pnl() + ex.writeback_state_to_pnl() return ex.execute @@ -209,8 +208,7 @@ def get_func_execution(func, func_mode, *, writeback:bool=True): # Calling writeback here will replace parameter values # with numpy instances that share memory with the binary # structure used by the compiled function - if writeback: - ex.writeback_state_to_pnl() + ex.writeback_state_to_pnl() return ex.cuda_execute @@ -222,9 +220,25 @@ def get_func_execution(func, func_mode, *, writeback:bool=True): @pytest.helpers.register def get_mech_execution(mech, mech_mode): if mech_mode == 'LLVM': - return pnlvm.execution.MechExecution(mech).execute + ex = pnlvm.execution.MechExecution(mech) + + # Calling writeback here will replace parameter values + # with numpy instances that share memory with the binary + # structure used by the compiled function + ex.writeback_state_to_pnl() + + return ex.execute + elif mech_mode == 'PTX': - return pnlvm.execution.MechExecution(mech).cuda_execute + ex = pnlvm.execution.MechExecution(mech) + + # Calling writeback here will replace parameter values + # with numpy instances that share memory with the binary + # structure used by the compiled function + ex.writeback_state_to_pnl() + + return ex.cuda_execute + elif mech_mode == 'Python': def mech_wrapper(x): mech.execute(x) diff --git a/psyneulink/core/llvm/execution.py b/psyneulink/core/llvm/execution.py index 55838dd845d..9af8f8e86d5 100644 --- a/psyneulink/core/llvm/execution.py +++ b/psyneulink/core/llvm/execution.py @@ -158,6 +158,21 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal ids=ids, condition=condition) else: + # TODO: Reconstruct Python RandomState + if attribute == "random_state": + continue + + # TODO: Reconstruct Python memory storage + if attribute == "ring_memory": + continue + + # "old_val" is a helper storage in compiled RecurrentTransferMechanism + # to workaround the fact that compiled projections do no pull values + # from their source output ports + # recurrent projection of RTM is not a PNL parameter. + if attribute in {"old_val", "recurrent_projection"}: + continue + # Handle PNL parameters pnl_param = getattr(component.parameters, attribute) pnl_value = pnl_param.get(context=context) @@ -183,10 +198,6 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal # Writeback parameter value if the condition matches elif condition(pnl_param): - # TODO: Reconstruct Python RandomState - if attribute == "random_state": - continue - # Replace empty structures with None if ctypes.sizeof(compiled_attribute_param) == 0: value = None @@ -202,7 +213,7 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal if hasattr(old_value, 'shape'): value = value.reshape(old_value.shape) - pnl_param.set(value, context=context) + pnl_param.set(value, context=context, override=True) class CUDAExecution(Execution): diff --git a/tests/functions/test_memory.py b/tests/functions/test_memory.py index b1b7bf64f13..7c1dbbbc19c 100644 --- a/tests/functions/test_memory.py +++ b/tests/functions/test_memory.py @@ -144,9 +144,7 @@ def test_basic(func, variable, params, expected, benchmark, func_mode): if variable is philox_var: f.parameters.random_state.set(_SeededPhilox([module_seed])) - # Do not allow writeback. "ring_memory" used by DictionaryMemory is a - # custom structure, not a PNL parameter - EX = pytest.helpers.get_func_execution(f, func_mode, writeback=False) + EX = pytest.helpers.get_func_execution(f, func_mode) EX(variable)