Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llvm: Move learnable matrices from RO params to RW params #2933

Merged
merged 6 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def get_func_execution(func, func_mode, *, writeback:bool=True):
# with numpy instances that share memory with the binary
# structure used by the compiled function
if writeback:
ex.writeback_params_to_pnl()
ex.writeback_state_to_pnl()

return ex.execute

Expand All @@ -210,7 +210,7 @@ def get_func_execution(func, func_mode, *, writeback:bool=True):
# with numpy instances that share memory with the binary
# structure used by the compiled function
if writeback:
ex.writeback_params_to_pnl()
ex.writeback_state_to_pnl()

return ex.cuda_execute

Expand Down
76 changes: 46 additions & 30 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,30 +1292,55 @@ def __deepcopy__(self, memo):
# ------------------------------------------------------------------------------------------------------------------
# Compilation support
# ------------------------------------------------------------------------------------------------------------------
def _is_compilable_param(self, p):

# User only parameters are not compiled.
if p.read_only and p.getter is not None:
return False

# Shared and aliased parameters are for user conveniecne and not compiled.
if isinstance(p, (ParameterAlias, SharedParameter)):
return False

# TODO this should use default value
val = p.get()

# Strings, builtins, functions, and methods are not compilable
return not isinstance(val, (str,
type(max),
type(np.sum),
type(make_parameter_property),
type(self._get_compilation_params)))


def _get_compilation_state(self):
# FIXME: MAGIC LIST, Use stateful tag for this
whitelist = {"previous_time", "previous_value", "previous_v",
"previous_w", "random_state",
"input_ports", "output_ports",
"adjustment_cost", "intensity_cost", "duration_cost",
"intensity"}

# Prune subcomponents (which are enabled by type rather than a list)
# that should be omitted
blacklist = { "objective_mechanism", "agent_rep", "projections", "shadow_inputs"}

# Only mechanisms use "value" state, can execute 'until finished',
# and need to track executions
# Mechanisms;
# * use "value" state
# * can execute 'until finished'
# * need to track number of executions
if hasattr(self, 'ports'):
whitelist.update({"value", "num_executions_before_finished",
"num_executions", "is_finished_flag"})

# If both the mechanism and its functoin use random_state it's DDM
# with integrator function. The mechanism's random_state is not used.
# If both the mechanism and its function use random_state.
# it's DDM with integrator function.
# The mechanism's random_state is not used.
if hasattr(self.parameters, 'random_state') and hasattr(self.function.parameters, 'random_state'):
whitelist.remove('random_state')


# Only mechanisms and compositions need 'num_executions'
# Compositions need to track number of executions
if hasattr(self, 'nodes'):
whitelist.add("num_executions")

Expand All @@ -1341,11 +1366,15 @@ def _get_compilation_state(self):
if hasattr(self.parameters, 'duplicate_keys'):
blacklist.add("previous_value")

# Matrices of learnable projections are stateful
if getattr(self, 'owner', None) and getattr(self.owner, 'learnable', False):
whitelist.add('matrix')

def _is_compilation_state(p):
# FIXME: This should use defaults instead of 'p.get'
return p.name not in blacklist and \
not isinstance(p, (ParameterAlias, SharedParameter)) and \
(p.name in whitelist or isinstance(p.get(), Component))
(p.name in whitelist or isinstance(p.get(), Component)) and \
self._is_compilable_param(p)

return filter(_is_compilation_state, self.parameters)

Expand All @@ -1362,9 +1391,10 @@ def llvm_state_ids(self):

def _get_state_initializer(self, context):
def _convert(p):
# FIXME: This should use defaults instead of 'p.get'
x = p.get(context)
if isinstance(x, np.random.RandomState):
if p.name == 'matrix': # Flatten matrix
val = tuple(np.asfarray(x).flatten())
elif isinstance(x, np.random.RandomState):
# Skip first element of random state (id string)
val = pnlvm._tupleize((*x.get_state()[1:], x.used_seed[0]))
elif isinstance(x, np.random.Generator):
Expand Down Expand Up @@ -1432,11 +1462,11 @@ def _get_compilation_params(self):
"learning_results", "learning_signal", "learning_signals",
"error_matrix", "error_signal", "activation_input",
"activation_output", "error_sources", "covariates_sources",
"target", "sample",
"target", "sample", "learning_function"
}
# Mechanism's need few extra entries:
# * matrix -- is never used directly, and is flatened below
# * integration rate -- shape mismatch with param port input
# * integration_rate -- shape mismatch with param port input
# * initializer -- only present on DDM and never used
# * search_space -- duplicated between OCM and its function
if hasattr(self, 'ports'):
Expand Down Expand Up @@ -1466,26 +1496,12 @@ def _get_compilation_params(self):
if cost_functions.DURATION not in cost_functions:
blacklist.add('duration_cost_fct')

def _is_compilation_param(p):
def _is_user_only_param(p):
if p.read_only and p.getter is not None:
return True
if isinstance(p, (ParameterAlias, SharedParameter)):
return True

return False

# Matrices of learnable projections are stateful
if getattr(self, 'owner', None) and getattr(self.owner, 'learnable', False):
blacklist.add('matrix')

if p.name not in blacklist and not _is_user_only_param(p):
# FIXME: this should use defaults
val = p.get()
# Check if the value type is valid for compilation
return not isinstance(val, (str, ComponentsMeta,
type(max),
type(np.sum),
type(_is_compilation_param),
type(self._get_compilation_params)))
return False
def _is_compilation_param(p):
return p.name not in blacklist and self._is_compilable_param(p)

return filter(_is_compilation_param, self.parameters)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3886,7 +3886,7 @@ def instantiate_matrix(self, specification, context=None):
return np.array(specification)


def _gen_llvm_function_body(self, ctx, builder, params, _, arg_in, arg_out, *, tags:frozenset):
def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
# Restrict to 1d arrays
if self.defaults.variable.ndim != 1:
warnings.warn("Shape mismatch: {} (in {}) got 2D input: {}".format(
Expand All @@ -3899,7 +3899,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, _, arg_in, arg_out, *, t
pnlvm.PNLCompilerWarning)
arg_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])

matrix = ctx.get_param_or_state_ptr(builder, self, MATRIX, param_struct_ptr=params)
matrix = ctx.get_param_or_state_ptr(builder, self, MATRIX, param_struct_ptr=params, state_struct_ptr=state)
normalize = ctx.get_param_or_state_ptr(builder, self, NORMALIZE, param_struct_ptr=params)

# Convert array pointer to pointer to the fist element
Expand Down
6 changes: 2 additions & 4 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11251,10 +11251,8 @@ def run(
self.parameters.results._set(results, context)

if self._is_learning(context):
# copies back matrix to pnl from param struct (after learning)
_comp_ex.writeback_params_to_pnl(params=_comp_ex._param_struct,
ids="llvm_param_ids",
condition=lambda p: p.name == "matrix")
# copies back matrix to pnl from state struct after learning
_comp_ex.writeback_state_to_pnl(condition=lambda p: p.name == "matrix")

self._propagate_most_recent_context(context)

Expand Down
2 changes: 2 additions & 0 deletions psyneulink/core/llvm/builder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def _state_struct(p):
return self.get_state_struct_type(val)
if isinstance(val, ContentAddressableList):
return ir.LiteralStructType(self.get_state_struct_type(x) for x in val)
if p.name == 'matrix': # Flatten matrix
val = np.asfarray(val).flatten()
struct = self.convert_python_struct_to_llvm_ir(val)
return ir.ArrayType(struct, p.history_min_length + 1)

Expand Down
15 changes: 6 additions & 9 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,13 @@ def _get_compilation_param(self, name, init_method, arg):
return struct


def writeback_params_to_pnl(self, params=None, ids:Optional[str]=None, condition:Callable=lambda p: True):
def writeback_state_to_pnl(self, condition:Callable=lambda p: True):

assert (params is None) == (ids is None), "Either both 'params' and 'ids' have to be set or neither"

if params is None:
# Default to stateful params
params = self._state_struct
ids = "llvm_state_ids"

self._copy_params_to_pnl(self._execution_contexts[0], self._obj, params, ids, condition)
self._copy_params_to_pnl(self._execution_contexts[0],
self._obj,
self._state_struct,
"llvm_state_ids",
condition)


def _copy_params_to_pnl(self, context, component, params, ids:str, condition:Callable):
Expand Down
Loading