Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llvm: Enable state writeback on all compiled Functions and Mechanisms #2938

Merged
merged 4 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,14 @@ def cuda_param(val):
return pytest.param(val, marks=[pytest.mark.llvm, pytest.mark.cuda])

@pytest.helpers.register
def get_func_execution(func, func_mode, *, writeback:bool=True):
def get_func_execution(func, func_mode):
if func_mode == 'LLVM':
ex = pnlvm.execution.FuncExecution(func)

# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
if writeback:
ex.writeback_state_to_pnl()
ex.writeback_state_to_pnl()

return ex.execute

Expand All @@ -209,8 +208,7 @@ def get_func_execution(func, func_mode, *, writeback:bool=True):
# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
if writeback:
ex.writeback_state_to_pnl()
ex.writeback_state_to_pnl()

return ex.cuda_execute

Expand All @@ -222,9 +220,25 @@ def get_func_execution(func, func_mode, *, writeback:bool=True):
@pytest.helpers.register
def get_mech_execution(mech, mech_mode):
if mech_mode == 'LLVM':
return pnlvm.execution.MechExecution(mech).execute
ex = pnlvm.execution.MechExecution(mech)

# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
ex.writeback_state_to_pnl()

return ex.execute

elif mech_mode == 'PTX':
return pnlvm.execution.MechExecution(mech).cuda_execute
ex = pnlvm.execution.MechExecution(mech)

# Calling writeback here will replace parameter values
# with numpy instances that share memory with the binary
# structure used by the compiled function
ex.writeback_state_to_pnl()

return ex.cuda_execute

elif mech_mode == 'Python':
def mech_wrapper(x):
mech.execute(x)
Expand Down
21 changes: 16 additions & 5 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,21 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal
ids=ids,
condition=condition)
else:
# TODO: Reconstruct Python RandomState
if attribute == "random_state":
continue

# TODO: Reconstruct Python memory storage
if attribute == "ring_memory":
continue

# "old_val" is helper storage in the compiled RecurrentTransferMechanism
# to work around the fact that compiled projections do not pull values
# from their source output ports.
# The recurrent projection of RTM is not a PNL parameter.
if attribute in {"old_val", "recurrent_projection"}:
continue

# Handle PNL parameters
pnl_param = getattr(component.parameters, attribute)
pnl_value = pnl_param.get(context=context)
Expand All @@ -183,10 +198,6 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal
# Writeback parameter value if the condition matches
elif condition(pnl_param):

# TODO: Reconstruct Python RandomState
if attribute == "random_state":
continue

# Replace empty structures with None
if ctypes.sizeof(compiled_attribute_param) == 0:
value = None
Expand All @@ -202,7 +213,7 @@ def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Cal
if hasattr(old_value, 'shape'):
value = value.reshape(old_value.shape)

pnl_param.set(value, context=context)
pnl_param.set(value, context=context, override=True)


class CUDAExecution(Execution):
Expand Down
4 changes: 1 addition & 3 deletions tests/functions/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,7 @@ def test_basic(func, variable, params, expected, benchmark, func_mode):
if variable is philox_var:
f.parameters.random_state.set(_SeededPhilox([module_seed]))

# Do not allow writeback. "ring_memory" used by DictionaryMemory is a
# custom structure, not a PNL parameter
EX = pytest.helpers.get_func_execution(f, func_mode, writeback=False)
EX = pytest.helpers.get_func_execution(f, func_mode)

EX(variable)

Expand Down
Loading