Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llvm: Remove workaround, rename node assembly #3016

Merged
merged 2 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions psyneulink/core/llvm/builder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def module_count():
'mt_rand_init', 'philox_rand_init'))


class _node_wrapper():
class _node_assembly():
def __init__(self, composition, node):
self._comp = weakref.proxy(composition)
self._node = node
Expand All @@ -61,7 +61,7 @@ def __repr__(self):
return "Node wrapper for node '{}' in composition '{}'".format(self._node, self._comp)

def _gen_llvm_function(self, *, ctx, tags:frozenset):
return codegen.gen_node_wrapper(ctx, self._comp, self._node, tags=tags)
return codegen.gen_node_assembly(ctx, self._comp, self._node, tags=tags)

def _comp_cached(func):
@functools.wraps(func)
Expand Down Expand Up @@ -349,6 +349,13 @@ def get_state_space(self, builder, component, state_ptr, param):
return helpers.get_state_space(builder, component, state_ptr, param_name)

def check_used_params(self, component, *, tags:frozenset):
"""
This function checks that parameters included in the compiled structures are used in compiled code.

If the assertion in this function triggers the parameter name should be added to the parameter
block list in the Component class.
"""

# Skip the check if the parameter use is not tracked. Some components (like node wrappers)
# don't even have parameters.
if component not in self._component_state_use and component not in self._component_param_use:
Expand Down Expand Up @@ -378,12 +385,6 @@ def check_used_params(self, component, *, tags:frozenset):
if hasattr(component, 'evaluate_agent_rep'):
used_param_ids.add('num_trials_per_estimate')

if hasattr(component, 'adapt_scale'):
used_param_ids.add('threshold')
used_param_ids.add('adapt_scale')
used_param_ids.add('adapt_base')
used_param_ids.add('adapt_entropy_weighting')

unused_param_ids = component_param_ids - used_param_ids - initializers
unused_state_ids = component_state_ids - used_state_ids

Expand Down Expand Up @@ -504,12 +505,12 @@ def get_data_struct_type(self, component):

return ir.LiteralStructType([])

def get_node_wrapper(self, composition, node):
cache = getattr(composition, '_wrapped_nodes', None)
def get_node_assembly(self, composition, node):
cache = getattr(composition, '_node_assemblies', None)
if cache is None:
cache = weakref.WeakKeyDictionary()
setattr(composition, '_wrapped_nodes', cache)
return cache.setdefault(node, _node_wrapper(composition, node))
setattr(composition, '_node_assemblies', cache)
return cache.setdefault(node, _node_assembly(composition, node))

def convert_python_struct_to_llvm_ir(self, t):
self._stats["types_converted"] += 1
Expand Down
30 changes: 15 additions & 15 deletions psyneulink/core/llvm/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,9 @@ def find_max(builder, x):
return res


def gen_node_wrapper(ctx, composition, node, *, tags:frozenset):
assert "node_wrapper" in tags
func_tags = tags.difference({"node_wrapper"})
def gen_node_assembly(ctx, composition, node, *, tags:frozenset):
assert "node_assembly" in tags
func_tags = tags.difference({"node_assembly"})

node_function = ctx.import_llvm_function(node, tags=func_tags)
# FIXME: This is a hack
Expand Down Expand Up @@ -782,14 +782,14 @@ def _gen_composition_exec_context(ctx, composition, *, tags:frozenset, suffix=""
params = builder.alloca(const_params.type, name="const_params_loc")
builder.store(const_params, params)

node_tags = tags.union({"node_wrapper"})
node_tags = tags.union({"node_assembly"})
# Call input CIM
input_cim_w = ctx.get_node_wrapper(composition, composition.input_CIM)
input_cim_w = ctx.get_node_assembly(composition, composition.input_CIM)
input_cim_f = ctx.import_llvm_function(input_cim_w, tags=node_tags)
builder.call(input_cim_f, [state, params, comp_in, data, data])

# Call parameter CIM
param_cim_w = ctx.get_node_wrapper(composition, composition.parameter_CIM)
param_cim_w = ctx.get_node_assembly(composition, composition.parameter_CIM)
param_cim_f = ctx.import_llvm_function(param_cim_w, tags=node_tags)
builder.call(param_cim_f, [state, params, comp_in, data, data])

Expand All @@ -803,7 +803,7 @@ def _gen_composition_exec_context(ctx, composition, *, tags:frozenset, suffix=""

def gen_composition_exec(ctx, composition, *, tags:frozenset):
simulation = "simulation" in tags
node_tags = tags.union({"node_wrapper"})
node_tags = tags.union({"node_assembly"})

with _gen_composition_exec_context(ctx, composition, tags=tags) as (builder, data, params, cond_gen):
state, _, comp_in, _, cond = builder.function.args
Expand All @@ -823,7 +823,7 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
is_finished_callbacks = {}
for node in composition.nodes:
args = [state, params, comp_in, data, output_storage]
wrapper = ctx.get_node_wrapper(composition, node)
wrapper = ctx.get_node_assembly(composition, node)
is_finished_callbacks[node] = (wrapper, args)


Expand Down Expand Up @@ -851,14 +851,14 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
num_exec_locs,
nodes_states)
with builder.if_then(reinit_cond):
node_w = ctx.get_node_wrapper(composition, node)
node_w = ctx.get_node_assembly(composition, node)
node_reinit_f = ctx.import_llvm_function(node_w, tags=node_tags.union({"reset"}))
builder.call(node_reinit_f, [state, params, comp_in, data, data])

# Run controller if it's enabled in 'BEFORE' mode
if simulation is False and composition.enable_controller and composition.controller_mode == BEFORE:
assert composition.controller is not None
controller_w = ctx.get_node_wrapper(composition, composition.controller)
controller_w = ctx.get_node_assembly(composition, composition.controller)
controller_f = ctx.import_llvm_function(controller_w, tags=node_tags)
builder.call(controller_f, [state, params, comp_in, data, data])

Expand Down Expand Up @@ -929,7 +929,7 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
run_set_node_ptr = builder.gep(run_set_ptr, [zero, ctx.int32_ty(idx)])
node_cond = builder.load(run_set_node_ptr, name="node_" + node.name + "_should_run")
with builder.if_then(node_cond):
node_w = ctx.get_node_wrapper(composition, node)
node_w = ctx.get_node_assembly(composition, node)
node_f = ctx.import_llvm_function(node_w, tags=node_tags)
builder.block.name = "invoke_" + node_f.name
# Wrappers do proper indexing of all structures
Expand Down Expand Up @@ -984,12 +984,12 @@ def gen_composition_exec(ctx, composition, *, tags:frozenset):
if simulation is False and composition.enable_controller and \
composition.controller_mode == AFTER:
assert composition.controller is not None
controller_w = ctx.get_node_wrapper(composition, composition.controller)
controller_w = ctx.get_node_assembly(composition, composition.controller)
controller_f = ctx.import_llvm_function(controller_w, tags=node_tags)
builder.call(controller_f, [state, params, comp_in, data, data])

# Call output CIM
output_cim_w = ctx.get_node_wrapper(composition, composition.output_CIM)
output_cim_w = ctx.get_node_assembly(composition, composition.output_CIM)
output_cim_f = ctx.import_llvm_function(output_cim_w, tags=node_tags)
builder.block.name = "invoke_" + output_cim_f.name
builder.call(output_cim_f, [state, params, comp_in, data, data])
Expand Down Expand Up @@ -1180,9 +1180,9 @@ def gen_autodiffcomp_exec(ctx, composition, *, tags:frozenset):
pytorch_func = ctx.import_llvm_function(pytorch_model, tags=tags)
builder.call(pytorch_func, [state, params, data])

node_tags = tags.union({"node_wrapper"})
node_tags = tags.union({"node_assembly"})
# Call output CIM
output_cim_w = ctx.get_node_wrapper(composition, composition.output_CIM)
output_cim_w = ctx.get_node_assembly(composition, composition.output_CIM)
output_cim_f = ctx.import_llvm_function(output_cim_w, tags=node_tags)
builder.call(output_cim_f, [state, params, comp_in, data, data])

Expand Down
4 changes: 2 additions & 2 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,9 @@ def _bin_func_multirun(self):

def _set_bin_node(self, node):
assert node in self._composition._all_nodes
wrapper = builder_context.LLVMBuilderContext.get_current().get_node_wrapper(self._composition, node)
wrapper = builder_context.LLVMBuilderContext.get_current().get_node_assembly(self._composition, node)
self.__bin_func = pnlvm.LLVMBinaryFunction.from_obj(
wrapper, tags=self.__tags.union({"node_wrapper"}))
wrapper, tags=self.__tags.union({"node_assembly"}))

@property
def _conditions(self):
Expand Down
6 changes: 3 additions & 3 deletions psyneulink/core/llvm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def generate_sched_condition(self, builder, condition, cond_ptr, node,
# The first argument is the target node
assert len(condition.args) == 1
target = is_finished_callbacks[condition.args[0]]
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_wrapper"}))
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
return builder.call(is_finished_f, target[1])

elif isinstance(condition, WhenFinishedAny):
Expand All @@ -715,7 +715,7 @@ def generate_sched_condition(self, builder, condition, cond_ptr, node,
run_cond = self.ctx.bool_ty(0)
for node in condition.args:
target = is_finished_callbacks[node]
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_wrapper"}))
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
node_is_finished = builder.call(is_finished_f, target[1])

run_cond = builder.or_(run_cond, node_is_finished)
Expand All @@ -728,7 +728,7 @@ def generate_sched_condition(self, builder, condition, cond_ptr, node,
run_cond = self.ctx.bool_ty(1)
for node in condition.args:
target = is_finished_callbacks[node]
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_wrapper"}))
is_finished_f = self.ctx.import_llvm_function(target[0], tags=frozenset({"is_finished", "node_assembly"}))
node_is_finished = builder.call(is_finished_f, target[1])

run_cond = builder.and_(run_cond, node_is_finished)
Expand Down
Loading