Skip to content

Commit

Permalink
llvm/Functions/GridSearch: Use Numpy structures for fixed sizes argum…
Browse files Browse the repository at this point in the history
…ents

function params and state are still left as ctypes as they can't be
easily reinitialized

Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed Aug 4, 2024
1 parent 6a394b2 commit 1c367c6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2103,22 +2103,25 @@ def _function(self,
# select_min params are:
# params, state, min_sample_ptr, sample_ptr, min_value_ptr, value_ptr, opt_count_ptr, count
min_tags = frozenset({"select_min", "evaluate_type_objective"})
bin_func = pnlvm.LLVMBinaryFunction.from_obj(self, tags=min_tags)
bin_func = pnlvm.LLVMBinaryFunction.from_obj(self, tags=min_tags, numpy_args=(2, 4, 6))

ct_param = bin_func.byref_arg_types[0](*self._get_param_initializer(context))
ct_state = bin_func.byref_arg_types[1](*self._get_state_initializer(context))
ct_opt_sample = bin_func.byref_arg_types[2](float("NaN"))
ct_alloc = None # NULL for samples
ct_opt_value = bin_func.byref_arg_types[4]()
ct_opt_count = bin_func.byref_arg_types[6](0)
ct_start = bin_func.c_func.argtypes[7](0)
ct_stop = bin_func.c_func.argtypes[8](num_values)

bin_func(ct_param, ct_state, ct_opt_sample, ct_alloc, ct_opt_value,
ct_values, ct_opt_count, ct_start, ct_stop)

optimal_value = ct_opt_value.value
optimal_sample = np.ctypeslib.as_array(ct_opt_sample)
optimal_sample = bin_func.np_buffer_for_arg(2)
optimal_value = bin_func.np_buffer_for_arg(4)
number_of_optimal_values = bin_func.np_buffer_for_arg(6, fill_value=0)

bin_func(ct_param,
ct_state,
optimal_sample,
None, # samples. NULL, it's generated by the function.
optimal_value,
ct_values,
number_of_optimal_values,
bin_func.c_func.argtypes[7](0), # start
bin_func.c_func.argtypes[8](num_values)) # stop

# Convert outputs to Numpy/Python
all_values = np.ctypeslib.as_array(ct_values)

# Python version
Expand Down
4 changes: 2 additions & 2 deletions psyneulink/core/llvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,13 @@ def cuda_wrap_call(self, *args, **kwargs):
wrap_args = (jit_engine.pycuda.driver.InOut(a) if isinstance(a, np.ndarray) else a for a in args)
self.cuda_call(*wrap_args, **kwargs)

def np_buffer_for_arg(self, arg_num, *, extra_dimensions=()):
def np_buffer_for_arg(self, arg_num, *, extra_dimensions=(), fill_value=np.nan):

out_base = self.np_params[arg_num].base
out_shape = extra_dimensions + self.np_params[arg_num].shape

# fill the buffer with NaN poison
return np.full(out_shape, np.nan, dtype=out_base)
return np.full(out_shape, fill_value, dtype=out_base)

@staticmethod
@functools.lru_cache(maxsize=32)
Expand Down

0 comments on commit 1c367c6

Please sign in to comment.