Skip to content

Commit

Permalink
llvm, tests: Enable writeback of state for all compiled mechanism tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed Apr 4, 2024
1 parent 0b82f41 commit a63de1d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
20 changes: 18 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,25 @@ def get_func_execution(func, func_mode):
@pytest.helpers.register
def get_mech_execution(mech, mech_mode):
if mech_mode == 'LLVM':
return pnlvm.execution.MechExecution(mech).execute
ex = pnlvm.execution.MechExecution(mech)

# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
ex.writeback_state_to_pnl()

return ex.execute

elif mech_mode == 'PTX':
return pnlvm.execution.MechExecution(mech).cuda_execute
ex = pnlvm.execution.MechExecution(mech)

# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
ex.writeback_state_to_pnl()

return ex.cuda_execute

elif mech_mode == 'Python':
def mech_wrapper(x):
mech.execute(x)
Expand Down
2 changes: 1 addition & 1 deletion psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal
if hasattr(old_value, 'shape'):
value = value.reshape(old_value.shape)

pnl_param.set(value, context=context)
pnl_param.set(value, context=context, override=True)


class CUDAExecution(Execution):
Expand Down

0 comments on commit a63de1d

Please sign in to comment.