Skip to content

Commit

Permalink
llvm: Cleanup codegen of mechanisms that modify function return value (
Browse files Browse the repository at this point in the history
…#2421)

LCControlMechanism and DDM both modify function return values before using them as mechanism values.
Overloading _gen_llvm_mechanism_functions instead of _gen_llvm_invoke_function gives direct
access to modulated mechanism parameters without accessing the LLVM function arguments.

Use helper functions instead of explicitly checking types when possible.
Drop no longer needed shape workarounds.
  • Loading branch information
jvesely authored May 31, 2022
2 parents af25311 + 827ff8a commit a90236a
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2538,10 +2538,6 @@ def _gen_llvm_integrate(self, builder, index, ctx, vi, vo, params, state):
builder.call(rand_f, [random_state, rand_val_ptr])
rand_val = builder.load(rand_val_ptr)

if isinstance(rate.type, pnlvm.ir.ArrayType):
assert len(rate.type) == 1
rate = builder.extract_value(rate, 0)

# Get state pointers
prev_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, "previous_value")
prev_time_ptr = pnlvm.helpers.get_state_ptr(builder, self, state, "previous_time")
Expand All @@ -2550,10 +2546,8 @@ def _gen_llvm_integrate(self, builder, index, ctx, vi, vo, params, state):
# + np.sqrt(time_step_size * noise) * random_state.normal()
prev_val_ptr = builder.gep(prev_ptr, [ctx.int32_ty(0), index])
prev_val = builder.load(prev_val_ptr)

val = builder.load(builder.gep(vi, [ctx.int32_ty(0), index]))
if isinstance(val.type, pnlvm.ir.ArrayType):
assert len(val.type) == 1
val = builder.extract_value(val, 0)
val = builder.fmul(val, rate)
val = builder.fmul(val, time_step_size)
val = builder.fadd(val, prev_val)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,6 @@
import numpy as np
import typecheck as tc

from psyneulink.core import llvm as pnlvm
from psyneulink.core.components.functions.function import Function_Base, is_function_type
from psyneulink.core.components.functions.nonstateful.transferfunctions import Identity
from psyneulink.core.components.functions.nonstateful.combinationfunctions import Concatenate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1544,13 +1544,11 @@ def _gen_llvm_is_finished_cond(self, ctx, builder, params, state):
return builder.fcmp_ordered("!=", is_finished_flag,
is_finished_flag.type(0))

# If modulated, termination threshold is single element array
if isinstance(threshold_ptr.type.pointee, pnlvm.ir.ArrayType):
assert len(threshold_ptr.type.pointee) == 1
threshold_ptr = builder.gep(threshold_ptr, [ctx.int32_ty(0),
ctx.int32_ty(0)])
# If modulated, termination threshold is single element array.
# Otherwise, it is scalar
threshold = pnlvm.helpers.load_extract_scalar_array_one(builder,
threshold_ptr)

threshold = builder.load(threshold_ptr)
cmp_val_ptr = builder.alloca(threshold.type, name="is_finished_value")
if self.termination_measure is max:
assert self._termination_measure_num_items_expected == 1
Expand Down
16 changes: 5 additions & 11 deletions psyneulink/core/llvm/builder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,8 @@ def create_llvm_function(self, args, component, name=None, *, return_type=ir.Voi
a.attributes.add('nonnull')

metadata = self.get_debug_location(llvm_func, component)
if metadata is not None:
scope = dict(metadata.operands)["scope"]
llvm_func.set_metadata("dbg", scope)
scope = dict(metadata.operands)["scope"]
llvm_func.set_metadata("dbg", scope)

# Create entry block
block = llvm_func.append_basic_block(name="entry")
Expand Down Expand Up @@ -263,12 +262,9 @@ def get_random_state_ptr(self, builder, component, state, params):
used_seed = builder.load(used_seed_ptr)

seed_ptr = helpers.get_param_ptr(builder, component, params, "seed")
if isinstance(seed_ptr.type.pointee, ir.ArrayType):
# Modulated params are usually single element arrays
seed_ptr = builder.gep(seed_ptr, [self.int32_ty(0), self.int32_ty(0)])
new_seed = builder.load(seed_ptr)
new_seed = pnlvm.helpers.load_extract_scalar_array_one(builder, seed_ptr)
# FIXME: The seed should ideally be integer already.
# However, it can be modulated and we don't support,
# However, it can be modulated and we don't support
# passing integer values as computed results.
new_seed = builder.fptoui(new_seed, used_seed.type)

Expand Down Expand Up @@ -327,9 +323,7 @@ def get_debug_location(func: ir.Function, component):

@staticmethod
def update_debug_loc_position(di_loc: ir.DIValue, line:int, column:int):
subprogram_operand = di_loc.operands[2]
assert subprogram_operand[0] == 'scope'
di_func = subprogram_operand[1]
di_func = dict(di_loc.operands)["scope"]

return di_loc.parent.add_debug_info("DILocation", {
"line": line, "column": column, "scope": di_func,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -834,12 +834,11 @@ def _execute(

return gain_t, output_values[0], output_values[1], output_values[2]

def _gen_llvm_invoke_function(self, ctx, builder, function, params, state,
variable, out, *, tags:frozenset):
assert function is self.function
mf_out, builder = super()._gen_llvm_invoke_function(ctx, builder, function,
params, state, variable,
None, tags=tags)
def _gen_llvm_mechanism_functions(self, ctx, builder, m_base_params, m_params, m_state, m_in,
m_val, ip_output, *, tags:frozenset):
mf_out, builder = super()._gen_llvm_mechanism_functions(ctx, builder, m_base_params,
m_params, m_state, m_in,
None, ip_output, tags=tags)

# prepend gain type (matches output[1] type)
gain_ty = mf_out.type.pointee.elements[1]
Expand All @@ -849,49 +848,40 @@ def _gen_llvm_invoke_function(self, ctx, builder, function, params, state,

# allocate a new output location if the type doesn't match the one
# provided by the caller.
if mech_out_ty != out.type.pointee:
out = builder.alloca(mech_out_ty, name="mechanism_out")
if mech_out_ty != m_val.type.pointee:
m_val = builder.alloca(mech_out_ty, name="mechanism_out")

# Load mechanism parameters
params = builder.function.args[0]
scaling_factor_ptr = pnlvm.helpers.get_param_ptr(builder, self, params,
scaling_factor_ptr = pnlvm.helpers.get_param_ptr(builder, self, m_params,
"scaling_factor_gain")
base_factor_ptr = pnlvm.helpers.get_param_ptr(builder, self, params,
base_factor_ptr = pnlvm.helpers.get_param_ptr(builder, self, m_params,
"base_level_gain")
# If modulated, scaling factor is a single element array
if isinstance(scaling_factor_ptr.type.pointee, pnlvm.ir.ArrayType):
assert len(scaling_factor_ptr.type.pointee) == 1
scaling_factor_ptr = builder.gep(scaling_factor_ptr,
[ctx.int32_ty(0), ctx.int32_ty(0)])
# If modulated, base factor is a single element array
if isinstance(base_factor_ptr.type.pointee, pnlvm.ir.ArrayType):
assert len(base_factor_ptr.type.pointee) == 1
base_factor_ptr = builder.gep(base_factor_ptr,
[ctx.int32_ty(0), ctx.int32_ty(0)])
scaling_factor = builder.load(scaling_factor_ptr)
base_factor = builder.load(base_factor_ptr)

# Apply to the entire vector
# If modulated, parameters are single element array
scaling_factor = pnlvm.helpers.load_extract_scalar_array_one(builder, scaling_factor_ptr)
base_factor = pnlvm.helpers.load_extract_scalar_array_one(builder, base_factor_ptr)

# Apply to the entire first subvector
vi = builder.gep(mf_out, [ctx.int32_ty(0), ctx.int32_ty(1)])
vo = builder.gep(out, [ctx.int32_ty(0), ctx.int32_ty(0)])
vo = builder.gep(m_val, [ctx.int32_ty(0), ctx.int32_ty(0)])

with pnlvm.helpers.array_ptr_loop(builder, vi, "LC_gain") as (b1, index):
in_ptr = b1.gep(vi, [ctx.int32_ty(0), index])
out_ptr = b1.gep(vo, [ctx.int32_ty(0), index])

val = b1.load(in_ptr)
val = b1.fmul(val, scaling_factor)
val = b1.fadd(val, base_factor)

out_ptr = b1.gep(vo, [ctx.int32_ty(0), index])
b1.store(val, out_ptr)

# copy the main function return value
for i, _ in enumerate(mf_out.type.pointee.elements):
ptr = builder.gep(mf_out, [ctx.int32_ty(0), ctx.int32_ty(i)])
out_ptr = builder.gep(out, [ctx.int32_ty(0), ctx.int32_ty(i + 1)])
out_ptr = builder.gep(m_val, [ctx.int32_ty(0), ctx.int32_ty(i + 1)])
val = builder.load(ptr)
builder.store(val, out_ptr)

return out, builder
return m_val, builder

# 5/8/20: ELIMINATE SYSTEM
# SEEMS TO STILL BE USED BY SOME MODELS; DELETE WHEN THOSE ARE UPDATED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1101,25 +1101,19 @@ def _execute(
return return_value

def _gen_llvm_invoke_function(self, ctx, builder, function, params, state,
variable, out, *, tags:frozenset):

mf_out, builder = super()._gen_llvm_invoke_function(ctx, builder, function,
params, state, variable,
None, tags=tags)
mech_out = out
variable, m_val, *, tags:frozenset):

if isinstance(self.function, IntegratorFunction):
# Integrator version of the DDM mechanism converts the
# second element to a 2d array
builder.store(builder.load(builder.gep(mf_out, [ctx.int32_ty(0),
ctx.int32_ty(0)])),
builder.gep(mech_out, [ctx.int32_ty(0),
ctx.int32_ty(0)]))
builder.store(builder.load(builder.gep(mf_out, [ctx.int32_ty(0),
ctx.int32_ty(1)])),
builder.gep(mech_out, [ctx.int32_ty(0),
ctx.int32_ty(1)]))
# Integrator based DDM works like other mechanisms
return super()._gen_llvm_invoke_function(ctx, builder, function,
params, state, variable,
m_val, tags=tags)

elif isinstance(self.function, DriftDiffusionAnalytical):
mf_out, builder = super()._gen_llvm_invoke_function(ctx, builder, function,
params, state, variable,
None, tags=tags)
# The order and number of returned values is different for DDA
for res_idx, idx in enumerate((self.RESPONSE_TIME_INDEX,
self.PROBABILITY_LOWER_THRESHOLD_INDEX,
self.RT_CORRECT_MEAN_INDEX,
Expand All @@ -1129,47 +1123,68 @@ def _gen_llvm_invoke_function(self, ctx, builder, function, params, state,
self.RT_INCORRECT_VARIANCE_INDEX,
self.RT_INCORRECT_SKEW_INDEX)):
src = builder.gep(mf_out, [ctx.int32_ty(0), ctx.int32_ty(res_idx)])
dst = builder.gep(mech_out, [ctx.int32_ty(0), ctx.int32_ty(idx)])
dst = builder.gep(m_val, [ctx.int32_ty(0), ctx.int32_ty(idx)])
builder.store(builder.load(src), dst)

# Handle upper threshold probability
src = builder.gep(mf_out, [ctx.int32_ty(0), ctx.int32_ty(1),
ctx.int32_ty(0)])
dst = builder.gep(mech_out, [ctx.int32_ty(0),
ctx.int32_ty(self.PROBABILITY_UPPER_THRESHOLD_INDEX),
ctx.int32_ty(0)])
# Handle upper threshold probability (1 - Lower Threshold)
src = builder.gep(m_val, [ctx.int32_ty(0),
ctx.int32_ty(self.PROBABILITY_LOWER_THRESHOLD_INDEX),
ctx.int32_ty(0)])
dst = builder.gep(m_val, [ctx.int32_ty(0),
ctx.int32_ty(self.PROBABILITY_UPPER_THRESHOLD_INDEX),
ctx.int32_ty(0)])
prob_lower_thr = builder.load(src)
prob_upper_thr = builder.fsub(prob_lower_thr.type(1),
prob_lower_thr)
prob_upper_thr = builder.fsub(prob_lower_thr.type(1), prob_lower_thr)
builder.store(prob_upper_thr, dst)

# Load function threshold
# Store threshold as decision variable output
# this will be used by the mechanism to return the right decision
threshold_ptr = pnlvm.helpers.get_param_ptr(builder, self.function,
params, THRESHOLD)
threshold = pnlvm.helpers.load_extract_scalar_array_one(builder,
threshold_ptr)
# Load mechanism state to generate random numbers
mech_params = builder.function.args[0]
mech_state = builder.function.args[1]
random_state = ctx.get_random_state_ptr(builder, self, mech_state, mech_params)
threshold = pnlvm.helpers.load_extract_scalar_array_one(builder, threshold_ptr)
decision_ptr = builder.gep(m_val, [ctx.int32_ty(0),
ctx.int32_ty(self.DECISION_VARIABLE_INDEX),
ctx.int32_ty(0)])
builder.store(threshold, decision_ptr)
else:
assert False, "Unknown mode in compiled DDM!"

return m_val, builder

def _gen_llvm_mechanism_functions(self, ctx, builder, m_base_params, m_params, m_state, m_in,
m_val, ip_output, *, tags:frozenset):

mf_out, builder = super()._gen_llvm_mechanism_functions(ctx, builder, m_base_params,
m_params, m_state, m_in, m_val,
ip_output, tags=tags)
assert mf_out is m_val

if isinstance(self.function, DriftDiffusionAnalytical):
random_state = ctx.get_random_state_ptr(builder, self, m_state, m_params)
random_f = ctx.get_uniform_dist_function_by_state(random_state)
random_val_ptr = builder.alloca(random_f.args[1].type.pointee, name="random_out")
builder.call(random_f, [random_state, random_val_ptr])
random_val = builder.load(random_val_ptr)

# Convert ER to decision variable:
dst = builder.gep(mech_out, [ctx.int32_ty(0),
ctx.int32_ty(self.DECISION_VARIABLE_INDEX),
ctx.int32_ty(0)])
prob_lthr_ptr = builder.gep(m_val, [ctx.int32_ty(0),
ctx.int32_ty(self.PROBABILITY_LOWER_THRESHOLD_INDEX),
ctx.int32_ty(0)])
prob_lower_thr = builder.load(prob_lthr_ptr)
thr_cmp = builder.fcmp_ordered("<", random_val, prob_lower_thr)

# The correct (modulated) threshold value is passed as
# decision variable output
decision_ptr = builder.gep(m_val, [ctx.int32_ty(0),
ctx.int32_ty(self.DECISION_VARIABLE_INDEX),
ctx.int32_ty(0)])
threshold = builder.load(decision_ptr)
neg_threshold = pnlvm.helpers.fneg(builder, threshold)
res = builder.select(thr_cmp, neg_threshold, threshold)

builder.store(res, dst)
else:
assert False, "Unknown mode in compiled DDM!"
builder.store(res, decision_ptr)

return mech_out, builder
return m_val, builder

@handle_external_context(fallback_most_recent=True)
def reset(self, *args, force=False, context=None, **kwargs):
Expand Down

0 comments on commit a90236a

Please sign in to comment.