From ab6e91106e9de641dcf62e46880d0a859c022c42 Mon Sep 17 00:00:00 2001
From: Jan Vesely
Date: Wed, 17 Aug 2022 14:59:32 -0400
Subject: [PATCH] llvm/execution: Use 'evaluate' function params when constructing simulation input (#2467)

The original code assumed the 'run' variant would be called.
This is not the case for parallel evaluate, which only needs the
'run, simulation' variant, resulting in redundant code generation
and compiler calls.
Instead, use the 'evaluate' compiled function, which provides the same
binary type of the argument at a different offset.

Tested by asserting that Python-{LLVM,PTX} tests only generate the
'run, simulation' variant of the composition function.

Signed-off-by: Jan Vesely
---
 psyneulink/core/llvm/execution.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/psyneulink/core/llvm/execution.py b/psyneulink/core/llvm/execution.py
index 1075c6b4d24..2500d160997 100644
--- a/psyneulink/core/llvm/execution.py
+++ b/psyneulink/core/llvm/execution.py
@@ -563,8 +563,11 @@ def cuda_execute(self, inputs):
 
     # Methods used to accelerate "Run"
 
-    def _get_run_input_struct(self, inputs, num_input_sets):
-        input_type = self._bin_run_func.byref_arg_types[3]
+    def _get_run_input_struct(self, inputs, num_input_sets, arg=3):
+        # Callers that override input arg, should ensure that _bin_func is not None
+        bin_f = self._bin_run_func if arg == 3 else self._bin_func
+
+        input_type = bin_f.byref_arg_types[arg]
         c_input = (input_type * num_input_sets) * len(self._execution_contexts)
         if len(self._execution_contexts) == 1:
             inputs = [inputs]
@@ -694,8 +697,8 @@ def _prepare_evaluate(self, inputs, num_input_sets, num_evaluations):
         ct_comp_state = self._get_compilation_param('_eval_state', '_get_state_initializer', 1)
         ct_comp_data = self._get_compilation_param('_eval_data', '_get_data_initializer', 6)
 
-        # Construct input variable
-        ct_inputs = self._get_run_input_struct(inputs, num_input_sets)
+        # Construct input variable, the 5th parameter of the evaluate function
+        ct_inputs = self._get_run_input_struct(inputs, num_input_sets, 5)
 
         # Output ctype
         out_ty = bin_func.byref_arg_types[4] * num_evaluations