llvm/execution: Use 'evaluate' function params when constructing simulation input

The original code assumed the 'run' variant would be called.
This is not the case for parallel evaluate, which only needs the
'run, simulation' variant, so the assumption resulted in redundant
code generation and compiler calls.

Instead, use the compiled 'evaluate' function, which provides an
argument of the same binary type at a different argument position.

Tested by asserting that Python-{LLVM,PTX} tests only generate
'run, simulate' variant of the composition function.

Signed-off-by: Jan Vesely <[email protected]>
jvesely committed Aug 17, 2022
1 parent d81df22 commit 1d81c49
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions psyneulink/core/llvm/execution.py
@@ -563,8 +563,11 @@ def cuda_execute(self, inputs):

     # Methods used to accelerate "Run"

-    def _get_run_input_struct(self, inputs, num_input_sets):
-        input_type = self._bin_run_func.byref_arg_types[3]
+    def _get_run_input_struct(self, inputs, num_input_sets, arg=3):
+        # Callers that override input arg, should ensure that _bin_func is not None
+        bin_f = self._bin_run_func if arg == 3 else self._bin_func
+
+        input_type = bin_f.byref_arg_types[arg]
         c_input = (input_type * num_input_sets) * len(self._execution_contexts)
         if len(self._execution_contexts) == 1:
             inputs = [inputs]
@@ -694,8 +697,8 @@ def _prepare_evaluate(self, inputs, num_input_sets, num_evaluations):
         ct_comp_state = self._get_compilation_param('_eval_state', '_get_state_initializer', 1)
         ct_comp_data = self._get_compilation_param('_eval_data', '_get_data_initializer', 6)

-        # Construct input variable
-        ct_inputs = self._get_run_input_struct(inputs, num_input_sets)
+        # Construct input variable, the 5th parameter of the evaluate function
+        ct_inputs = self._get_run_input_struct(inputs, num_input_sets, 5)

         # Output ctype
         out_ty = bin_func.byref_arg_types[4] * num_evaluations
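
The change above picks the input type from whichever compiled function is available and then builds a nested ctypes array from it. A minimal standalone sketch of that idea follows. It is not the PsyNeuLink implementation: `CompiledFunctionStub`, `InputSet`, `Placeholder`, and `make_input_array_type` are hypothetical names, and the two-double input layout is an assumption made only for illustration.

```python
import ctypes


class CompiledFunctionStub:
    """Hypothetical stand-in for a compiled-function wrapper.

    Only the byref_arg_types list is modeled here.
    """

    def __init__(self, byref_arg_types):
        self.byref_arg_types = byref_arg_types


# Assumed layout for illustration: one input set is two doubles.
InputSet = ctypes.c_double * 2
# Filler type for the arguments this sketch does not care about.
Placeholder = ctypes.c_int

# The 'run' stub carries the input type at index 3, the 'evaluate' stub at index 5.
bin_run_func = CompiledFunctionStub([Placeholder] * 3 + [InputSet])
bin_func = CompiledFunctionStub([Placeholder] * 5 + [InputSet])


def make_input_array_type(num_input_sets, num_contexts, arg=3):
    # Same selection rule as in the diff: index 3 means the 'run' function,
    # any other index means the 'evaluate' function.
    bin_f = bin_run_func if arg == 3 else bin_func
    input_type = bin_f.byref_arg_types[arg]
    # One array of input sets per execution context.
    return (input_type * num_input_sets) * num_contexts


run_ty = make_input_array_type(4, 1)
eval_ty = make_input_array_type(4, 1, arg=5)

# Instantiate a zero-initialized buffer and write one element:
# context 0, input set 2, second double.
buf = eval_ty()
buf[0][2][1] = 1.5
```

Both calls yield an array type of the same size and shape, which is why the 'evaluate' function can stand in for the 'run' function when only the input layout is needed.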
