Skip to content

Commit

Permalink
llvm, ports/ControlSignal: Always call 'combine_costs' when calculati…
Browse files Browse the repository at this point in the history
…ng costs (#2146)

Matches Python semantics.
Add input shape workaround for Reduce function.
  • Loading branch information
jvesely authored Oct 13, 2021
2 parents e99c010 + b181aec commit 6c5cc54
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,13 @@ def _function(self,

return self.convert_output_type(result)

def _get_input_struct_type(self, ctx):
# FIXME: Workaround a special case of simple array.
# It should just pass through to modifiers, which matches what
# single element 2d array does
default_var = np.atleast_2d(self.defaults.variable)
return ctx.convert_python_struct_to_llvm_ir(default_var)

def _gen_llvm_combine(self, builder, index, ctx, vi, vo, params):
scale = self._gen_llvm_load_param(ctx, builder, params, SCALE, index, 1.0)
offset = self._gen_llvm_load_param(ctx, builder, params, OFFSET, index, -0.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@

# FIX: EVCControlMechanism IS IMPORTED HERE TO DEAL WITH COST FUNCTIONS THAT ARE DEFINED IN EVCControlMechanism
# SHOULD THEY BE LIMITED TO EVC??
from psyneulink.core import llvm as pnlvm
from psyneulink.core.components.functions.nonstateful.combinationfunctions import Reduce
from psyneulink.core.components.functions.function import is_function_type
from psyneulink.core.components.functions.stateful.integratorfunctions import SimpleIntegrator
Expand Down Expand Up @@ -1078,3 +1079,72 @@ def compute_costs(self, intensity, context=None):
all_costs = [intensity_cost, adjustment_cost, duration_cost]
combined_cost = self.combine_costs_function(all_costs, context=context)
return max(0.0, combined_cost)

def _gen_llvm_function(self, *, ctx:pnlvm.LLVMBuilderContext,
extra_args=[], tags:frozenset):
if "costs" in tags:
assert len(extra_args) == 0
return self._gen_llvm_costs(ctx=ctx, tags=tags)

return super()._gen_llvm_function(ctx=ctx, extra_args=extra_args, tags=tags)

def _gen_llvm_costs(self, *, ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
args = [ctx.get_param_struct_type(self).as_pointer(),
ctx.get_state_struct_type(self).as_pointer(),
ctx.get_input_struct_type(self).as_pointer()]

assert "costs" in tags
builder = ctx.create_llvm_function(args, self, str(self) + "_costs",
tags=tags,
return_type=ctx.float_ty)

params, state, arg_in = builder.function.args

func_params = pnlvm.helpers.get_param_ptr(builder, self, params,
"function")
func_state = pnlvm.helpers.get_state_ptr(builder, self, state,
"function")
# FIXME: Add support for other cost types
assert self.cost_options == CostFunctions.INTENSITY

cfunc = ctx.import_llvm_function(self.function.combine_costs_fct)
cfunc_in = builder.alloca(cfunc.args[2].type.pointee)

# Set to 0 be default
builder.store(cfunc_in.type.pointee(None), cfunc_in)

cost_funcs = 0
if self.cost_options & CostFunctions.INTENSITY:
ifunc = ctx.import_llvm_function(self.function.intensity_cost_fct)

ifunc_params = pnlvm.helpers.get_param_ptr(builder, self.function,
func_params,
"intensity_cost_fct")
ifunc_state = pnlvm.helpers.get_state_ptr(builder, self.function,
func_state,
"intensity_cost_fct")
# Port input is always struct { data input, modulations }
ifunc_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)])
# point output to the proper slot in comb func input
ifunc_comb_slot = builder.gep(cfunc_in, [ctx.int32_ty(0), ctx.int32_ty(cost_funcs)])

builder.call(ifunc, [ifunc_params, ifunc_state, ifunc_in, ifunc_comb_slot])

cost_funcs += 1


# Call combination function
cfunc_params = pnlvm.helpers.get_param_ptr(builder, self.function,
func_params,
"combine_costs_fct")
cfunc_state = pnlvm.helpers.get_state_ptr(builder, self.function,
func_state,
"combine_costs_fct")
cfunc_out = builder.alloca(cfunc.args[3].type.pointee)
builder.call(cfunc, [cfunc_params, cfunc_state, cfunc_in, cfunc_out])


ret_val = pnlvm.helpers.load_extract_scalar_array_one(builder, cfunc_out)
builder.ret(ret_val)

return builder.function
51 changes: 0 additions & 51 deletions psyneulink/core/components/ports/outputport.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,10 +620,8 @@
import types
import warnings

from psyneulink.core import llvm as pnlvm
from psyneulink.core.components.component import Component, ComponentError
from psyneulink.core.components.functions.function import Function
from psyneulink.core.components.functions.nonstateful.transferfunctions import CostFunctions
from psyneulink.core.components.ports.port import Port_Base, _instantiate_port_list, port_type_keywords
from psyneulink.core.globals.context import ContextFlags, handle_external_context
from psyneulink.core.globals.keywords import \
Expand Down Expand Up @@ -1291,55 +1289,6 @@ def _dict_summary(self):
}
}

def _gen_llvm_function(self, *, ctx:pnlvm.LLVMBuilderContext,
extra_args=[], tags:frozenset):
if "costs" in tags:
assert len(extra_args) == 0
return self._gen_llvm_costs(ctx=ctx, tags=tags)

return super()._gen_llvm_function(ctx=ctx, extra_args=extra_args, tags=tags)

def _gen_llvm_costs(self, *, ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
args = [ctx.get_param_struct_type(self).as_pointer(),
ctx.get_state_struct_type(self).as_pointer(),
ctx.get_input_struct_type(self).as_pointer()]

assert "costs" in tags
builder = ctx.create_llvm_function(args, self, str(self) + "_costs",
tags=tags,
return_type=ctx.float_ty)

params, state, arg_in = builder.function.args

# FIXME: Add support for other cost types
assert self.cost_options == CostFunctions.INTENSITY

ifunc = ctx.import_llvm_function(self.function.intensity_cost_fct)

func_params = pnlvm.helpers.get_param_ptr(builder, self, params,
"function")
func_state = pnlvm.helpers.get_state_ptr(builder, self, state,
"function")
ifunc_params = pnlvm.helpers.get_param_ptr(builder, self.function,
func_params,
"intensity_cost_fct")
ifunc_state = pnlvm.helpers.get_state_ptr(builder, self.function,
func_state,
"intensity_cost_fct")
ifunc_out = builder.alloca(ifunc.args[3].type.pointee)
# Port input is always struct
ifunc_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)])

builder.call(ifunc, [ifunc_params, ifunc_state, ifunc_in, ifunc_out])


# Cost function output is 1 element array
ret_ptr = builder.gep(ifunc_out, [ctx.int32_ty(0), ctx.int32_ty(0)])
ret_val = builder.load(ret_ptr)
builder.ret(ret_val)

return builder.function


def _instantiate_output_ports(owner, output_ports=None, context=None):
"""Call Port._instantiate_port_list() to instantiate ContentAddressableList of OutputPort(s)
Expand Down

0 comments on commit 6c5cc54

Please sign in to comment.