Skip to content

Commit

Permalink
llvm, Function, Mechanism: Track used parameters and state
Browse files Browse the repository at this point in the history
Add a single function to return pointer to parameter/state substructure
from component structures.
The new function tracks requested substructures and asserts that all
substructures were requests when generating component's LLVM code.

get_state_space is extended with tracking and moved to builder context.

Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed Dec 2, 2023
1 parent 06a9495 commit 354b065
Show file tree
Hide file tree
Showing 15 changed files with 265 additions and 249 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class Parameters(Function_Base.Parameters):
variable = Parameter(np.array([0]), read_only=True, pnl_internal=True, constructor_argument='default_variable')

def _gen_llvm_load_param(self, ctx, builder, params, param_name, index, default):
param_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, param_name)
param_ptr = ctx.get_param_or_state_ptr(builder, self, param_name, param_struct_ptr=params)
param_type = param_ptr.type.pointee
if isinstance(param_type, pnlvm.ir.LiteralStructType):
assert len(param_type) == 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def _function(self,

def _gen_llvm_function_body(self, ctx, builder, params, state, _, arg_out, *, tags:frozenset):
random_state = ctx.get_random_state_ptr(builder, self, state, params)
mean_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, "mean")
std_dev_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, "standard_deviation")
mean_ptr = ctx.get_param_or_state_ptr(builder, self, DIST_MEAN, param_struct_ptr=params)
std_dev_ptr = ctx.get_param_or_state_ptr(builder, self, STANDARD_DEVIATION, param_struct_ptr=params)
ret_val_ptr = builder.alloca(ctx.float_ty)
norm_rand_f = ctx.get_normal_dist_function_by_state(random_state)
builder.call(norm_rand_f, [random_state, ret_val_ptr])
Expand Down Expand Up @@ -634,8 +634,8 @@ def _function(self,

def _gen_llvm_function_body(self, ctx, builder, params, state, _, arg_out, *, tags:frozenset):
random_state = ctx.get_random_state_ptr(builder, self, state, params)
low_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, LOW)
high_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, HIGH)
low_ptr = ctx.get_param_or_state_ptr(builder, self, LOW, param_struct_ptr=params)
high_ptr = ctx.get_param_or_state_ptr(builder, self, HIGH, param_struct_ptr=params)
ret_val_ptr = builder.alloca(ctx.float_ty)
norm_rand_f = ctx.get_uniform_dist_function_by_state(random_state)
builder.call(norm_rand_f, [random_state, ret_val_ptr])
Expand Down Expand Up @@ -1421,7 +1421,7 @@ def csch(x):
def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):

def load_scalar_param(name):
param_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, name)
param_ptr = ctx.get_param_or_state_ptr(builder, self, name, param_struct_ptr=params)
return pnlvm.helpers.load_extract_scalar_array_one(builder, param_ptr)

attentional_drift_rate = load_scalar_param(DRIFT_RATE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _update_default_variable(self, new_default_variable, context):
def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
# Dot product
dot_out = builder.alloca(arg_in.type.pointee)
matrix = pnlvm.helpers.get_param_ptr(builder, self, params, MATRIX)
matrix = ctx.get_param_or_state_ptr(builder, self, MATRIX, param_struct_ptr=params, state_struct_ptr=state)

# Convert array pointer to pointer to the fist element
matrix = builder.gep(matrix, [ctx.int32_ty(0), ctx.int32_ty(0)])
Expand All @@ -423,15 +423,18 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
#FIXME: implement this
assert False, "Support for transfer functions is not implemented"
else:
# Check that transfer_fct is absent from the compiled parameter
# structure or represented by an empty structure
assert "transfer_fct" not in self.llvm_param_ids or ctx.get_param_or_state_ptr(builder, self, "transfer_fct", param_struct_ptr=params).type.pointee.elements == ()

trans_out = builder.gep(metric_in, [ctx.int32_ty(0), ctx.int32_ty(1)])
builder.store(builder.load(dot_out), trans_out)

# Copy original variable
builder.store(builder.load(arg_in), builder.gep(metric_in, [ctx.int32_ty(0), ctx.int32_ty(0)]))

# Distance Function
metric_params = pnlvm.helpers.get_param_ptr(builder, self, params, "metric_fct")
metric_state = pnlvm.helpers.get_state_ptr(builder, self, state, "metric_fct")
metric_params, metric_state = ctx.get_param_or_state_ptr(builder, self, "metric_fct", param_struct_ptr=params, state_struct_ptr=state)
metric_out = arg_out
builder.call(metric_fun, [metric_params, metric_state, metric_in, metric_out])
return builder
Expand Down Expand Up @@ -1120,7 +1123,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, _, arg_in, arg_out, *, t
arg_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0),
ctx.int32_ty(0)])

normalize_ptr = pnlvm.helpers.get_param_ptr(builder, self, params, NORMALIZE)
normalize_ptr = ctx.get_param_or_state_ptr(builder, self, NORMALIZE, param_struct_ptr=params)
normalize = builder.load(normalize_ptr)
normalize_b = builder.fcmp_ordered("!=", normalize, normalize.type(0))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1748,8 +1748,7 @@ def _gen_llvm_select_min_function(self, *, ctx:pnlvm.LLVMBuilderContext, tags:fr
samples_ptr.attributes.remove('nonnull')

random_state = ctx.get_random_state_ptr(builder, self, state, params)
select_random_ptr = pnlvm.helpers.get_param_ptr(builder, self, params,
self.parameters.select_randomly_from_optimal_values.name)
select_random_ptr = ctx.get_param_or_state_ptr(builder, self, self.parameters.select_randomly_from_optimal_values, param_struct_ptr=params)

select_random_val = builder.load(select_random_ptr)
select_random = builder.fcmp_ordered("!=", select_random_val,
Expand Down Expand Up @@ -1806,8 +1805,7 @@ def _gen_llvm_select_min_function(self, *, ctx:pnlvm.LLVMBuilderContext, tags:fr
gen_samples = builder.icmp_signed("==", samples_ptr, samples_ptr.type(None))
with builder.if_else(gen_samples) as (b_true, b_false):
with b_true:
search_space = pnlvm.helpers.get_param_ptr(builder, self, params,
self.parameters.search_space.name)
search_space = ctx.get_param_or_state_ptr(builder, self, self.parameters.search_space.name, param_struct_ptr=params)
pnlvm.helpers.create_sample(b, min_sample_ptr, search_space, min_idx)
with b_false:
sample_ptr = builder.gep(samples_ptr, [min_idx])
Expand Down Expand Up @@ -1862,10 +1860,8 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
extra_args = [comp_input, comp_args[2], num_inputs]
else:
obj_func = ctx.import_llvm_function(self.objective_function)
obj_state_ptr = pnlvm.helpers.get_state_ptr(builder, self, state,
"objective_function")
obj_param_ptr = pnlvm.helpers.get_param_ptr(builder, self, params,
"objective_function")
obj_param_ptr, obj_state_ptr = ctx.get_param_or_state_ptr(builder, self, "objective_function",
param_struct_ptr=params, state_struct_ptr=state)
extra_args = []

sample_t = obj_func.args[2].type.pointee
Expand All @@ -1875,8 +1871,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
sample_ptr = builder.alloca(sample_t)
value_ptr = builder.alloca(value_t)

search_space_ptr = pnlvm.helpers.get_param_ptr(builder, self, params,
self.parameters.search_space.name)
search_space_ptr = ctx.get_param_or_state_ptr(builder, self, self.parameters.search_space, param_struct_ptr=params)

opt_count_ptr = builder.alloca(ctx.float_ty)
builder.store(opt_count_ptr.type.pointee(0), opt_count_ptr)
Expand Down
Loading

0 comments on commit 354b065

Please sign in to comment.