Skip to content

Commit

Permalink
llvm: Remove workaround, rename node assembly (#3016)
Browse files Browse the repository at this point in the history
Remove workaround in used parameters check. Add a comment that
unused parameters should be added to the blocklist in Component.
Rename node_wrapper -> node assembly to avoid confusion with wrapped nodes.
  • Loading branch information
jvesely authored Jul 27, 2024
2 parents 010dbc4 + f8d1f43 commit d622a29
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 32 deletions.
25 changes: 13 additions & 12 deletions psyneulink/core/llvm/builder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def module_count():
'mt_rand_init', 'philox_rand_init'))


class _node_wrapper():
class _node_assembly():
def __init__(self, composition, node):
self._comp = weakref.proxy(composition)
self._node = node
Expand All @@ -61,7 +61,7 @@ def __repr__(self):
return "Node wrapper for node '{}' in composition '{}'".format(self._node, self._comp)

def _gen_llvm_function(self, *, ctx, tags:frozenset):
return codegen.gen_node_wrapper(ctx, self._comp, self._node, tags=tags)
return codegen.gen_node_assembly(ctx, self._comp, self._node, tags=tags)

def _comp_cached(func):
@functools.wraps(func)
Expand Down Expand Up @@ -349,6 +349,13 @@ def get_state_space(self, builder, component, state_ptr, param):
return helpers.get_state_space(builder, component, state_ptr, param_name)

def check_used_params(self, component, *, tags:frozenset):
"""
This function checks that parameters included in the compiled structures are used in compiled code.
If the assertion in this function triggers the parameter name should be added to the parameter
block list in the Component class.
"""

# Skip the check if the parameter use is not tracked. Some components (like node wrappers)
# don't even have parameters.
if component not in self._component_state_use and component not in self._component_param_use:
Expand Down Expand Up @@ -378,12 +385,6 @@ def check_used_params(self, component, *, tags:frozenset):
if hasattr(component, 'evaluate_agent_rep'):
used_param_ids.add('num_trials_per_estimate')

if hasattr(component, 'adapt_scale'):
used_param_ids.add('threshold')
used_param_ids.add('adapt_scale')
used_param_ids.add('adapt_base')
used_param_ids.add('adapt_entropy_weighting')

unused_param_ids = component_param_ids - used_param_ids - initializers
unused_state_ids = component_state_ids - used_state_ids

Expand Down Expand Up @@ -504,12 +505,12 @@ def get_data_struct_type(self, component):

return ir.LiteralStructType([])

def get_node_wrapper(self, composition, node):
cache = getattr(composition, '_wrapped_nodes', None)
def get_node_assembly(self, composition, node):
cache = getattr(composition, '_node_assemblies', None)
if cache is None:
cache = weakref.WeakKeyDictionary()
setattr(composition, '_wrapped_nodes', cache)
return cache.setdefault(node, _node_wrapper(composition, node))
setattr(composition, '_node_assemblies', cache)
return cache.setdefault(node, _node_assembly(composition, node))

def convert_python_struct_to_llvm_ir(self, t):
self._stats["types_converted"] += 1
Expand Down
30 changes: 15 additions & 15 deletions psyneulink/core/llvm/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,9 @@ def find_max(builder, x):
return res


def gen_node_wrapper(ctx, composition, node, *, tags:frozenset):
assert "node_wrapper" in tags
func_tags = tags.difference({"node_wrapper"})
def gen_node_assembly(ctx, composition, node, *, tags:frozenset):
assert "node_assembly" in tags
func_tags = tags.difference({"node_assembly"})

node_function = ctx.import_llvm_function(node, tags=func_tags)
# FIXME: This is a hack
Expand Down Expand Up @@ -782,14 +782,14 @@ def _gen_composition_exec_context(ctx, composition, *, tags:frozenset, suffix=""
params = builder.alloca(const_params.type, name="const_params_loc")
builder.store(const_params, params)

node_tags = tags.union({"node_wrapper"})
node_tags = tags.union({"node_assembly"})
# Call input CIM
input_cim_w = ctx.get_node_wrapper(composition, composition.input_CIM)
input_cim_w = ctx.get_node_assembly(composition, composition.input_CIM)
input_cim_f = ctx.import_llvm_function(input_cim_w, tags=node_tags)
builder.call(input_cim_f, [state, params, comp_in, data, data])

# Call parameter CIM
param_cim_w = ctx.get_node_wrapper(composition, composition.parameter_CIM)
param_cim_w = ctx.get_node_assembly(composition, composition.parameter_CIM)
param_cim_f = ctx.import_llvm_function(param_cim_w, tags=node_tags)
builder.call(param_cim_f, [state, params, comp_in, data, data])

Expand All @@ -803,7 +803,7 @@ def _gen_composition_exec_context(ctx, composition, *, tags:frozenset, suffix=""

def gen_composition_exec(ctx, composition, *, tags:frozenset):
simulation = "simulation" in tags
node_tags = tags.union({"node_wrapper"})
node_tags = tags.union({"node_assembly"})

with _gen_composition_exec_context(ctx, composition, tags=tags) as (builder, data, params, cond_gen):
state, _, comp_in, _, cond = builder.function.args
Expand All @@ -823,7 +823,7 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
is_finished_callbacks = {}
for node in composition.nodes:
args = [state, params, comp_in, data, output_storage]
wrapper = ctx.get_node_wrapper(composition, node)
wrapper = ctx.get_node_assembly(composition, node)
is_finished_callbacks[node] = (wrapper, args)


Expand Down Expand Up @@ -851,14 +851,14 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
num_exec_locs,
nodes_states)
with builder.if_then(reinit_cond):
node_w = ctx.get_node_wrapper(composition, node)
node_w = ctx.get_node_assembly(composition, node)
node_reinit_f = ctx.import_llvm_function(node_w, tags=node_tags.union({"reset"}))
builder.call(node_reinit_f, [state, params, comp_in, data, data])

# Run controller if it's enabled in 'BEFORE' mode
if simulation is False and composition.enable_controller and composition.controller_mode == BEFORE:
assert composition.controller is not None
controller_w = ctx.get_node_wrapper(composition, composition.controller)
controller_w = ctx.get_node_assembly(composition, composition.controller)
controller_f = ctx.import_llvm_function(controller_w, tags=node_tags)
builder.call(controller_f, [state, params, comp_in, data, data])

Expand Down Expand Up @@ -929,7 +929,7 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
run_set_node_ptr = builder.gep(run_set_ptr, [zero, ctx.int32_ty(idx)])
node_cond = builder.load(run_set_node_ptr, name="node_" + node.name + "_should_run")
with builder.if_then(node_cond):
node_w = ctx.get_node_wrapper(composition, node)
node_w = ctx.get_node_assembly(composition, node)
node_f = ctx.import_llvm_function(node_w, tags=node_tags)
builder.block.name = "invoke_" + node_f.name
# Wrappers do proper indexing of all structures
Expand Down Expand Up @@ -984,12 +984,12 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
if simulation is False and composition.enable_controller and \
composition.controller_mode == AFTER:
assert composition.controller is not None
controller_w = ctx.get_node_wrapper(composition, composition.controller)
controller_w = ctx.get_node_assembly(composition, composition.controller)
controller_f = ctx.import_llvm_function(controller_w, tags=node_tags)
builder.call(controller_f, [state, params, comp_in, data, data])

# Call output CIM
output_cim_w = ctx.get_node_wrapper(composition, composition.output_CIM)
output_cim_w = ctx.get_node_assembly(composition, composition.output_CIM)
output_cim_f = ctx.import_llvm_function(output_cim_w, tags=node_tags)
builder.block.name = "invoke_" + output_cim_f.name
builder.call(output_cim_f, [state, params, comp_in, data, data])
Expand Down Expand Up @@ -1180,9 +1180,9 @@ def gen_autodiffcomp_exec(ctx, composition, *, tags:frozenset):
pytorch_func = ctx.import_llvm_function(pytorch_model, tags=tags)
builder.call(pytorch_func, [state, params, data])

node_tags = tags.union({"node_wrapper"})
node_tags = tags.union({"node_assembly"})
# Call output CIM
output_cim_w = ctx.get_node_wrapper(composition, composition.output_CIM)
output_cim_w = ctx.get_node_assembly(composition, composition.output_CIM)
output_cim_f = ctx.import_llvm_function(output_cim_w, tags=node_tags)
builder.call(output_cim_f, [state, params, comp_in, data, data])

Expand Down
4 changes: 2 additions & 2 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,9 @@ def _bin_func_multirun(self):

def _set_bin_node(self, node):
assert node in self._composition._all_nodes
wrapper = builder_context.LLVMBuilderContext.get_current().get_node_wrapper(self._composition, node)
wrapper = builder_context.LLVMBuilderContext.get_current().get_node_assembly(self._composition, node)
self.__bin_func = pnlvm.LLVMBinaryFunction.from_obj(
wrapper, tags=self.__tags.union({"node_wrapper"}))
wrapper, tags=self.__tags.union({"node_assembly"}))

@property
def _conditions(self):
Expand Down
6 changes: 3 additions & 3 deletions psyneulink/core/llvm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def generate_sched_condition(self, builder, condition, cond_ptr, node,
# The first argument is the target node
assert len(condition.args) == 1
target = is_finished_callbacks[condition.args[0]]
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_wrapper"}))
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
return builder.call(is_finished_f, target[1])

elif isinstance(condition, WhenFinishedAny):
Expand All @@ -715,7 +715,7 @@ def generate_sched_condition(self, builder, condition, cond_ptr, node,
run_cond = self.ctx.bool_ty(0)
for node in condition.args:
target = is_finished_callbacks[node]
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_wrapper"}))
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
node_is_finished = builder.call(is_finished_f, target[1])

run_cond = builder.or_(run_cond, node_is_finished)
Expand All @@ -728,7 +728,7 @@ def generate_sched_condition(self, builder, condition, cond_ptr, node,
run_cond = self.ctx.bool_ty(1)
for node in condition.args:
target = is_finished_callbacks[node]
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_wrapper"}))
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
node_is_finished = builder.call(is_finished_f, target[1])

run_cond = builder.and_(run_cond, node_is_finished)
Expand Down

0 comments on commit d622a29

Please sign in to comment.