diff --git a/psyneulink/core/components/component.py b/psyneulink/core/components/component.py index 8beb4fc7c6..23d1cfbf8a 100644 --- a/psyneulink/core/components/component.py +++ b/psyneulink/core/components/component.py @@ -528,14 +528,14 @@ Context, ContextError, ContextFlags, INITIALIZATION_STATUS_FLAGS, _get_time, handle_external_context from psyneulink.core.globals.mdf import MDFSerializable from psyneulink.core.globals.keywords import \ - CONTEXT, CONTROL_PROJECTION, DEFERRED_INITIALIZATION, EXECUTE_UNTIL_FINISHED, \ + CONTEXT, CONTROL_PROJECTION, DEFERRED_INITIALIZATION, DETERMINISTIC, EXECUTE_UNTIL_FINISHED, \ FUNCTION, FUNCTION_PARAMS, INIT_FULL_EXECUTE_METHOD, INPUT_PORTS, \ LEARNING, LEARNING_PROJECTION, MATRIX, MAX_EXECUTIONS_BEFORE_FINISHED, \ MODEL_SPEC_ID_PSYNEULINK, MODEL_SPEC_ID_METADATA, \ MODEL_SPEC_ID_INPUT_PORTS, MODEL_SPEC_ID_OUTPUT_PORTS, \ MODEL_SPEC_ID_MDF_VARIABLE, \ MODULATORY_SPEC_KEYWORDS, NAME, OUTPUT_PORTS, OWNER, PARAMS, PREFS_ARG, \ - RESET_STATEFUL_FUNCTION_WHEN, INPUT_SHAPES, VALUE, VARIABLE, SHARED_COMPONENT_TYPES + RANDOM, RESET_STATEFUL_FUNCTION_WHEN, INPUT_SHAPES, VALUE, VARIABLE, SHARED_COMPONENT_TYPES from psyneulink.core.globals.log import LogCondition from psyneulink.core.globals.parameters import \ Defaults, SharedParameter, Parameter, ParameterAlias, ParameterError, ParametersBase, check_user_specified, copy_parameter_value, is_array_like @@ -1391,6 +1391,9 @@ def _get_compilation_state(self): if cost_functions.DURATION not in cost_functions: blacklist.add('duration_cost_fct') + if getattr(self, "mode", None) == DETERMINISTIC and getattr(self, "tie", None) != RANDOM: + whitelist.remove('random_state') + # Drop previous_value from MemoryFunctions if hasattr(self.parameters, 'duplicate_keys'): blacklist.add("previous_value") @@ -1508,13 +1511,20 @@ def _get_compilation_params(self): "retain_torch_trained_outputs", "retain_torch_targets", "retain_torch_losses", "torch_trained_outputs", "torch_targets", "torch_losses", # should be added to relevant _gen_llvm_function... when aug: - # OneHot: - 'abs_val', 'indicator', # SoftMax: 'mask_threshold', 'adapt_scale', 'adapt_base', 'adapt_entropy_weighting', # LCAMechanism "mask" } + + # OneHot: + # * runtime abs_val and indicator are only used in deterministic mode. + # * random_state and seed are only used in RANDOM tie resolution. + if getattr(self, "mode", None) != DETERMINISTIC: + blacklist.update(['abs_val', 'indicator']) + elif getattr(self, "tie", None) != RANDOM: + blacklist.add("seed") + # Mechanisms need a few extra entries: # * matrix -- is never used directly, and is flattened below # * integration_rate -- shape mismatch with param port input diff --git a/psyneulink/core/components/functions/nonstateful/selectionfunctions.py b/psyneulink/core/components/functions/nonstateful/selectionfunctions.py index fb73085de6..defd01e050 100644 --- a/psyneulink/core/components/functions/nonstateful/selectionfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/selectionfunctions.py @@ -187,7 +187,7 @@ class OneHot(SelectionFunction): First (possibly only) item specifies a template for the array to be transformed; if `mode` is *PROB* then a 2nd item must be included that is a probability distribution with same length as 1st item.
- mode : DETERMINISITC, PROB, PROB_INDICATOR, + mode : DETERMINISTiC, PROB, PROB_INDICATOR, ARG_MAX, ARG_MAX_ABS, ARG_MAX_INDICATOR, ARG_MAX_ABS_INDICATOR, ARG_MIN, ARG_MIN_ABS, ARG_MIN_INDICATOR, ARG_MIN_ABS_INDICATOR, MAX_VAL, MAX_ABS_VAL, MAX_INDICATOR, MAX_ABS_INDICATOR, @@ -237,7 +237,7 @@ class OneHot(SelectionFunction): distribution, each element of which specifies the probability for selecting the corresponding element of the 1st item. - mode : DETERMINISITC, PROB, PROB_INDICATOR, + mode : DETERMINISTIC, PROB, PROB_INDICATOR, ARG_MAX, ARG_MAX_ABS, ARG_MAX_INDICATOR, ARG_MAX_ABS_INDICATOR, ARG_MIN, ARG_MIN_ABS, ARG_MIN_INDICATOR, ARG_MIN_ABS_INDICATOR, MAX_VAL, MAX_ABS_VAL, MAX_INDICATOR, MAX_ABS_INDICATOR, @@ -421,86 +421,25 @@ def _validate_params(self, request_set, target_set=None, context=None): f"cannot be specified.") def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset): - best_idx_ptr = builder.alloca(ctx.int32_ty) - builder.store(best_idx_ptr.type.pointee(0), best_idx_ptr) - if self.mode in {PROB, PROB_INDICATOR}: + sum_ptr = builder.alloca(ctx.float_ty) builder.store(sum_ptr.type.pointee(-0.0), sum_ptr) - random_draw_ptr = builder.alloca(ctx.float_ty) rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params) rng_f = ctx.get_uniform_dist_function_by_state(rand_state_ptr) + random_draw_ptr = builder.alloca(rng_f.args[-1].type.pointee) builder.call(rng_f, [rand_state_ptr, random_draw_ptr]) random_draw = builder.load(random_draw_ptr) prob_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)]) arg_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)]) - with pnlvm.helpers.array_ptr_loop(builder, arg_in, "search") as (b1, idx): - best_idx = b1.load(best_idx_ptr) - best_ptr = b1.gep(arg_in, [ctx.int32_ty(0), best_idx]) - - current_ptr = b1.gep(arg_in, [ctx.int32_ty(0), idx]) - current = b1.load(current_ptr) - - if self.mode not in {PROB, PROB_INDICATOR}: - fabs = ctx.get_builtin("fabs", [current.type]) - - is_first = b1.icmp_unsigned("==", idx, idx.type(0)) - - # Allow the first element to win the comparison - prev_best = b1.select(is_first, best_ptr.type.pointee(float("NaN")), b1.load(best_ptr)) - - if self.mode == ARG_MAX: - cmp_op = ">" - cmp_prev = prev_best - cmp_curr = current - val = current - - elif self.mode == ARG_MAX_ABS: - cmp_op = ">" - cmp_prev = b1.call(fabs, [prev_best]) - cmp_curr = b1.call(fabs, [current]) - val = b1.call(fabs, [current]) - - elif self.mode == ARG_MAX_INDICATOR: - cmp_op = ">" - cmp_prev = prev_best - cmp_curr = current - val = current.type(1.0) - - elif self.mode == ARG_MAX_ABS_INDICATOR: - cmp_op = ">" - cmp_prev = b1.call(fabs, [prev_best]) - cmp_curr = b1.call(fabs, [current]) - val = current.type(1.0) - - elif self.mode == ARG_MIN: - cmp_op = "<" - cmp_prev = prev_best - cmp_curr = current - val = current - - elif self.mode == ARG_MIN_ABS: - cmp_op = "<" - cmp_prev = b1.call(fabs, [prev_best]) - cmp_curr = b1.call(fabs, [current]) - val = b1.call(fabs, [current]) - - elif self.mode == ARG_MIN_INDICATOR: - cmp_op = "<" - cmp_prev = prev_best - cmp_curr = current - val = current.type(1.0) - - elif self.mode == ARG_MIN_ABS_INDICATOR: - cmp_op = "<" - cmp_prev = b1.call(fabs, [prev_best]) - cmp_curr = b1.call(fabs, [current]) - val = current.type(1.0) - - elif self.mode in {PROB, PROB_INDICATOR}: + with pnlvm.helpers.array_ptr_loop(builder, arg_in, "search") as (b1, idx): + + current_ptr = b1.gep(arg_in, [ctx.int32_ty(0), idx]) + current = b1.load(current_ptr) + # Update 
prefix sum current_prob_ptr = b1.gep(prob_in, [ctx.int32_ty(0), idx]) sum_old = b1.load(sum_ptr) @@ -511,27 +450,125 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, new_above = b1.fcmp_ordered("<", random_draw, sum_new) cond = b1.and_(new_above, old_below) - cmp_prev = current.type(1.0) - cmp_curr = b1.select(cond, cmp_prev, cmp_prev.type(0.0)) - cmp_op = "==" if self.mode == PROB: val = current else: val = current.type(1.0) - else: - assert False, "Unsupported mode in LLVM: {} for OneHot Function".format(self.mode) - prev_res_ptr = b1.gep(arg_out, [ctx.int32_ty(0), best_idx]) - cur_res_ptr = b1.gep(arg_out, [ctx.int32_ty(0), idx]) + write_val = b1.select(cond, val, val.type(0.0)) + cur_res_ptr = b1.gep(arg_out, [ctx.int32_ty(0), idx]) + builder.store(write_val, cur_res_ptr) + + return builder + + elif self.mode == DETERMINISTIC: + direction = self.direction + tie = self.tie + abs_val_ptr = ctx.get_param_or_state_ptr(builder, self, self.parameters.abs_val, param_struct_ptr=params) + indicator_ptr = ctx.get_param_or_state_ptr(builder, self, self.parameters.indicator, param_struct_ptr=params) + + abs_val = builder.load(abs_val_ptr) + is_abs_val = builder.fcmp_unordered("!=", abs_val, abs_val.type(0)) + + indicator = builder.load(indicator_ptr) + is_indicator = builder.fcmp_unordered("!=", indicator, indicator.type(0)) + + else: + direction, abs_val, indicator, tie = self._parse_mode(self.mode) + is_abs_val = ctx.bool_ty(abs_val) + is_indicator = ctx.bool_ty(indicator) + + num_extremes_ptr = builder.alloca(ctx.int32_ty) + builder.store(num_extremes_ptr.type.pointee(0), num_extremes_ptr) + + extreme_val_ptr = builder.alloca(ctx.float_ty) + builder.store(extreme_val_ptr.type.pointee(float("NaN")), extreme_val_ptr) + + fabs_f = ctx.get_builtin("fabs", [extreme_val_ptr.type.pointee]) + + with pnlvm.helpers.recursive_iterate_arrays(ctx, builder, arg_in, loop_id="count_extremes") as (loop_builder, current_ptr): + + current = loop_builder.load(current_ptr) + current_abs = loop_builder.call(fabs_f, [current]) + current = builder.select(is_abs_val, current_abs, current) + + old_extreme = loop_builder.load(extreme_val_ptr) + cmp_op = ">" if direction == MAX else "<" + is_new_extreme = loop_builder.fcmp_unordered(cmp_op, current, old_extreme) + + with loop_builder.if_then(is_new_extreme): + loop_builder.store(current, extreme_val_ptr) + loop_builder.store(num_extremes_ptr.type.pointee(1), num_extremes_ptr) + + is_old_extreme = loop_builder.fcmp_ordered("==", current, old_extreme) + with loop_builder.if_then(is_old_extreme): + extreme_count = loop_builder.load(num_extremes_ptr) + extreme_count = loop_builder.add(extreme_count, extreme_count.type(1)) + loop_builder.store(extreme_count, num_extremes_ptr) + + + if tie == FIRST: + extreme_start = num_extremes_ptr.type.pointee(0) + extreme_stop = num_extremes_ptr.type.pointee(1) + + elif tie == LAST: + extreme_stop = builder.load(num_extremes_ptr) + extreme_start = builder.sub(extreme_stop, extreme_stop.type(1)) + + elif tie == ALL: + extreme_start = num_extremes_ptr.type.pointee(0) + extreme_stop = builder.load(num_extremes_ptr) + + elif tie == RANDOM: + rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params) + rand_f = ctx.get_rand_int_function_by_state(rand_state_ptr) + random_draw_ptr = builder.alloca(rand_f.args[-1].type.pointee) + num_extremes = builder.load(num_extremes_ptr) + + builder.call(rand_f, [rand_state_ptr, ctx.int32_ty(0), num_extremes, random_draw_ptr]) + + extreme_start = 
builder.load(random_draw_ptr) + extreme_start = builder.trunc(extreme_start, ctx.int32_ty) + extreme_stop = builder.add(extreme_start, extreme_start.type(1)) - # Make sure other elements are zeroed - builder.store(cur_res_ptr.type.pointee(0), cur_res_ptr) + else: + assert False, "Unknown tie resolution: {}".format(tie) + + + extreme_val = builder.load(extreme_val_ptr) + extreme_write_val = builder.select(is_indicator, extreme_val.type(1), extreme_val) + next_extreme_ptr = builder.alloca(num_extremes_ptr.type.pointee) + builder.store(next_extreme_ptr.type.pointee(0), next_extreme_ptr) + + pnlvm.helpers.printf(ctx, + builder, + "{} replacing extreme values of %e from <%u,%u) out of %u\n".format(self.name), + extreme_val, + extreme_start, + extreme_stop, + builder.load(num_extremes_ptr), + tags={"one_hot"}) + + with pnlvm.helpers.recursive_iterate_arrays(ctx, builder, arg_in, arg_out, loop_id="mark_extremes") as (loop_builder, current_ptr, out_ptr): + current = loop_builder.load(current_ptr) + current_abs = loop_builder.call(fabs_f, [current]) + current = builder.select(is_abs_val, current_abs, current) + + is_extreme = loop_builder.fcmp_ordered("==", current, extreme_val) + current_extreme_idx = loop_builder.load(next_extreme_ptr) + + with loop_builder.if_then(is_extreme): + next_extreme_idx = loop_builder.add(current_extreme_idx, current_extreme_idx.type(1)) + loop_builder.store(next_extreme_idx, next_extreme_ptr) - cmp_res = builder.fcmp_unordered(cmp_op, cmp_curr, cmp_prev) - with builder.if_then(cmp_res): - builder.store(prev_res_ptr.type.pointee(0), prev_res_ptr) - builder.store(val, cur_res_ptr) - builder.store(idx, best_idx_ptr) + is_after_start = loop_builder.icmp_unsigned(">=", current_extreme_idx, extreme_start) + is_before_stop = loop_builder.icmp_unsigned("<", current_extreme_idx, extreme_stop) + + should_write_extreme = loop_builder.and_(is_extreme, is_after_start) + should_write_extreme = loop_builder.and_(should_write_extreme, is_before_stop) + + write_value = loop_builder.select(should_write_extreme, extreme_write_val, extreme_write_val.type(0)) + loop_builder.store(write_value, out_ptr) return builder @@ -641,6 +678,9 @@ def _parse_mode(self, mode): indicator = True tie = ALL + else: + assert False, f"Unknown mode: {mode}" + return direction, abs_val, indicator, tie def _function(self, @@ -693,65 +733,62 @@ def _function(self, random_value = random_state.uniform() chosen_item = next(element for element in cum_sum if element > random_value) chosen_in_cum_sum = np.where(cum_sum == chosen_item, 1, 0) - if mode is PROB: + if mode == PROB: result = v * chosen_in_cum_sum else: result = np.ones_like(v) * chosen_in_cum_sum + # chosen_item = np.random.choice(v, 1, p=prob_dist) # one_hot_indicator = np.where(v == chosen_item, 1, 0) # return v * one_hot_indicator return result - elif mode is not DETERMINISTIC: + elif mode != DETERMINISTIC: direction, abs_val, indicator, tie = self._parse_mode(mode) - # if np.array(variable).ndim != 1: - # raise FunctionError(f"If {MODE} for {self.__class__.__name__} {Function.__name__} is not set to " - # f"'PROB' or 'PROB_INDICATOR', variable must be a 1d array: {variable}.") array = variable - max = None - min = None - if abs_val is True: - array = np.absolute(array) + if abs_val: + array = np.absolute(variable) if direction == MAX: - max = np.max(array) - if max == -np.inf: - warnings.warn(f"Array passed to 
{self.name} of {self.owner.name} is all -inf.") + + elif direction == MIN: + extreme_val = np.min(array) + if extreme_val == np.inf: + warnings.warn(f"Array passed to {self.name} of {self.owner.name} is all inf.") + else: - min = np.min(array) - if min == np.inf: - warnings.warn(f"Array passed to {self.name} of {self.owner.name} " - f"is all inf.") + assert False, f"Unknown direction: '{direction}'." - extreme_val = max if direction == MAX else min + extreme_indices = np.where(array == extreme_val) + + num_indices = len(extreme_indices[0]) + assert all(len(idx) == num_indices for idx in extreme_indices) + + if tie == FIRST: + selected_idx = 0 + + elif tie == LAST: + selected_idx = -1 + + elif tie == RANDOM: + random_state = self._get_current_parameter_value("random_state", context) + selected_idx = random_state.randint(num_indices) + + elif tie == ALL: + selected_idx = slice(num_indices) - if tie == ALL: - if direction == MAX: - result = np.where(array == max, max, -np.inf) - else: - result = np.where(array == min, min, np.inf) else: - if tie == FIRST: - index = np.min(np.where(array == extreme_val)) - elif tie == LAST: - index = np.max(np.where(array == extreme_val)) - elif tie == RANDOM: - index = np.random.choice(np.where(array == extreme_val)) - else: - assert False, f"PROGRAM ERROR: Unrecognized value for 'tie' in OneHot function: '{tie}'." - result = np.zeros_like(array) - result[index] = extreme_val - - if indicator is True: - result = np.where(result == extreme_val, 1, result) - if max is not None: - result = np.where(result == -np.inf, 0, result) - if min is not None: - result = np.where(result == np.inf, 0, result) + assert False, f"PROGRAM ERROR: Unrecognized value for 'tie' in OneHot function: '{tie}'." + + + set_indices = tuple(index[selected_idx] for index in extreme_indices) + + result = np.zeros_like(variable) + result[set_indices] = 1 if indicator else extreme_val return self.convert_output_type(result) diff --git a/psyneulink/core/components/functions/stateful/memoryfunctions.py b/psyneulink/core/components/functions/stateful/memoryfunctions.py index cd97cdb418..7466f968ab 100644 --- a/psyneulink/core/components/functions/stateful/memoryfunctions.py +++ b/psyneulink/core/components/functions/stateful/memoryfunctions.py @@ -2372,6 +2372,7 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, max_entries = len(vals_ptr.type.pointee) entries = builder.load(count_ptr) entries = pnlvm.helpers.uint_min(builder, entries, max_entries) + # The call to random function needs to be after check to match python with builder.if_then(retr_rand): rand_ptr = builder.alloca(ctx.float_ty) @@ -2385,53 +2386,41 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, with builder.if_then(retr, likely=True): # Determine distances distance_f = ctx.import_llvm_function(self.distance_function) - distance_params, distance_state = ctx.get_param_or_state_ptr(builder, self, "distance_function", param_struct_ptr=params, state_struct_ptr=state) + distance_params, distance_state = ctx.get_param_or_state_ptr(builder, + self, + "distance_function", + param_struct_ptr=params, + state_struct_ptr=state) distance_arg_in = builder.alloca(distance_f.args[2].type.pointee) - builder.store(builder.load(var_key_ptr), - builder.gep(distance_arg_in, [ctx.int32_ty(0), - ctx.int32_ty(0)])) + builder.store(builder.load(var_key_ptr), builder.gep(distance_arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)])) selection_arg_in = 
builder.alloca(pnlvm.ir.ArrayType(distance_f.args[3].type.pointee, max_entries)) with pnlvm.helpers.for_loop_zero_inc(builder, entries, "distance_loop") as (b, idx): compare_ptr = b.gep(keys_ptr, [ctx.int32_ty(0), idx]) - b.store(b.load(compare_ptr), - b.gep(distance_arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)])) + b.store(b.load(compare_ptr), b.gep(distance_arg_in, [ctx.int32_ty(0), ctx.int32_ty(1)])) distance_arg_out = b.gep(selection_arg_in, [ctx.int32_ty(0), idx]) - b.call(distance_f, [distance_params, distance_state, - distance_arg_in, distance_arg_out]) - - # MODIFIED 10/13/24 NEW: - # IMPLEMENTATION NOTE: - # REPLACE MIN_VAL with ARG_MIN and MIN_INDICATOR with ARG_MIN_INDICATOR - # until the MIN_XXX args are implemented in LLVM - # since, at present, the tests don't seem to distinguish between these (i.e., return of multiple values; - # should add tests that do so once MIN_VAL and related args are implemented in LLVM) - if isinstance(self.selection_function, OneHot): - mode = self.selection_function.mode - if mode == MIN_VAL: - self.selection_function.mode = ARG_MIN - elif mode == MIN_INDICATOR: - self.selection_function.mode = ARG_MIN_INDICATOR - # MODIFIED 10/13/24 END + b.call(distance_f, [distance_params, distance_state, distance_arg_in, distance_arg_out]) + selection_f = ctx.import_llvm_function(self.selection_function) - selection_params, selection_state = ctx.get_param_or_state_ptr(builder, self, "selection_function", param_struct_ptr=params, state_struct_ptr=state) + selection_params, selection_state = ctx.get_param_or_state_ptr(builder, + self, + "selection_function", + param_struct_ptr=params, + state_struct_ptr=state) selection_arg_out = builder.alloca(selection_f.args[3].type.pointee) - builder.call(selection_f, [selection_params, selection_state, - selection_arg_in, selection_arg_out]) + builder.call(selection_f, [selection_params, selection_state, selection_arg_in, selection_arg_out]) # Find the selected index selected_idx_ptr = builder.alloca(ctx.int32_ty) builder.store(ctx.int32_ty(0), selected_idx_ptr) - with pnlvm.helpers.for_loop_zero_inc(builder, entries, "distance_loop") as (b,idx): + with pnlvm.helpers.for_loop_zero_inc(builder, entries, "selection_loop") as (b, idx): selection_val = b.load(b.gep(selection_arg_out, [ctx.int32_ty(0), idx])) non_zero = b.fcmp_ordered('!=', selection_val, selection_val.type(0)) with b.if_then(non_zero): b.store(idx, selected_idx_ptr) selected_idx = builder.load(selected_idx_ptr) - selected_key = builder.load(builder.gep(keys_ptr, [ctx.int32_ty(0), - selected_idx])) - selected_val = builder.load(builder.gep(vals_ptr, [ctx.int32_ty(0), - selected_idx])) + selected_key = builder.load(builder.gep(keys_ptr, [ctx.int32_ty(0), selected_idx])) + selected_val = builder.load(builder.gep(vals_ptr, [ctx.int32_ty(0), selected_idx])) builder.store(selected_key, out_key_ptr) builder.store(selected_val, out_val_ptr) diff --git a/psyneulink/core/llvm/builder_context.py b/psyneulink/core/llvm/builder_context.py index 0dcb6bae85..7fcd4224cd 100644 --- a/psyneulink/core/llvm/builder_context.py +++ b/psyneulink/core/llvm/builder_context.py @@ -210,29 +210,46 @@ def init_builtins(self): if "time_stat" in debug_env: print("Time to setup PNL builtins: {}".format(finish - start)) + def get_rand_int_function_by_state(self, state): + if len(state.type.pointee) == 5: + return self.import_llvm_function("__pnl_builtin_mt_rand_int32_bounded") + + elif len(state.type.pointee) == 7: + # we have different versions based on selected FP precision + return 
self.import_llvm_function("__pnl_builtin_philox_rand_int32_bounded") + + else: + assert False, "Unknown PRNG type!" + def get_uniform_dist_function_by_state(self, state): if len(state.type.pointee) == 5: return self.import_llvm_function("__pnl_builtin_mt_rand_double") + elif len(state.type.pointee) == 7: # we have different versions based on selected FP precision return self.import_llvm_function("__pnl_builtin_philox_rand_{}".format(str(self.float_ty))) + else: assert False, "Unknown PRNG type!" def get_binomial_dist_function_by_state(self, state): if len(state.type.pointee) == 5: return self.import_llvm_function("__pnl_builtin_mt_rand_binomial") + elif len(state.type.pointee) == 7: return self.import_llvm_function("__pnl_builtin_philox_rand_binomial") + else: assert False, "Unknown PRNG type!" def get_normal_dist_function_by_state(self, state): if len(state.type.pointee) == 5: return self.import_llvm_function("__pnl_builtin_mt_rand_normal") + elif len(state.type.pointee) == 7: # Normal exists only for self.float_ty return self.import_llvm_function("__pnl_builtin_philox_rand_normal") + else: assert False, "Unknown PRNG type!" diff --git a/psyneulink/core/llvm/builtins.py b/psyneulink/core/llvm/builtins.py index 20920ccf59..f965c819ad 100644 --- a/psyneulink/core/llvm/builtins.py +++ b/psyneulink/core/llvm/builtins.py @@ -17,8 +17,10 @@ def _setup_builtin_func_builder(ctx, name, args, *, return_type=ir.VoidType()): - builder = ctx.create_llvm_function(args, None, _BUILTIN_PREFIX + name, - return_type=return_type) + if not name.startswith(_BUILTIN_PREFIX): + name = _BUILTIN_PREFIX + name + + builder = ctx.create_llvm_function(args, None, name, return_type=return_type) # Add noalias attribute for a in builder.function.args: @@ -656,10 +658,11 @@ def _setup_mt_rand_init(ctx, state_ty, init_scalar): return builder.function -def _setup_mt_rand_integer(ctx, state_ty): +def _setup_mt_rand_int32(ctx, state_ty): int64_ty = ir.IntType(64) + # Generate random number generator function. 
- # It produces random 32bit numberin a 64bit word + # It produces random 32bit number in a 64bit word builder = _setup_builtin_func_builder(ctx, "mt_rand_int32", (state_ty.as_pointer(), int64_ty.as_pointer())) state, out = builder.function.args @@ -758,6 +761,42 @@ def _setup_mt_rand_integer(ctx, state_ty): return builder.function +def _setup_rand_bounded_int32(ctx, state_ty, gen_int32): + + out_ty = gen_int32.args[1].type.pointee + builder = _setup_builtin_func_builder(ctx, gen_int32.name + "_bounded", (state_ty.as_pointer(), ctx.int32_ty, ctx.int32_ty, out_ty.as_pointer())) + state, lower, upper, out_ptr = builder.function.args + + rand_range_excl = builder.sub(upper, lower) + rand_range_excl = builder.zext(rand_range_excl, out_ty) + + range_leading_zeros = builder.ctlz(rand_range_excl, ctx.bool_ty(1)) + mask = builder.lshr(range_leading_zeros.type(-1), range_leading_zeros) + + loop_block = builder.append_basic_block("bounded_loop_block") + out_block = builder.append_basic_block("bounded_out_block") + + builder.branch(loop_block) + + # Loop: + # do: + # r = random() & mask + # while r >= limit + builder.position_at_end(loop_block) + + builder.call(gen_int32, [state, out_ptr]) + val = builder.load(out_ptr) + val = builder.and_(val, mask) + + is_above_limit = builder.icmp_unsigned(">=", val, rand_range_excl) + builder.cbranch(is_above_limit, loop_block, out_block) + + builder.position_at_end(out_block) + offset = builder.zext(lower, val.type) + result = builder.add(val, offset) + builder.store(result, out_ptr) + builder.ret_void() + def _setup_mt_rand_float(ctx, state_ty, gen_int): """ Mersenne Twister double prcision random number generation. @@ -892,8 +931,9 @@ def setup_mersenne_twister(ctx): init_scalar = _setup_mt_rand_init_scalar(ctx, state_ty) _setup_mt_rand_init(ctx, state_ty, init_scalar) - gen_int = _setup_mt_rand_integer(ctx, state_ty) - gen_float = _setup_mt_rand_float(ctx, state_ty, gen_int) + gen_int32 = _setup_mt_rand_int32(ctx, state_ty) + _setup_rand_bounded_int32(ctx, state_ty, gen_int32) + gen_float = _setup_mt_rand_float(ctx, state_ty, gen_int32) _setup_mt_rand_normal(ctx, state_ty, gen_float) _setup_rand_binomial(ctx, state_ty, gen_float, prefix="mt") @@ -1138,6 +1178,88 @@ def _setup_philox_rand_int32(ctx, state_ty, gen_int64): return builder.function +def _setup_rand_lemire_int32(ctx, state_ty, gen_int32): + """ + Uses Lemire's algorithm - https://arxiv.org/abs/1805.10941 + As implemented in Numpy to match Numpy results. + """ + + out_ty = gen_int32.args[1].type.pointee + builder = _setup_builtin_func_builder(ctx, gen_int32.name + "_bounded", (state_ty.as_pointer(), out_ty, out_ty, out_ty.as_pointer())) + state, lower, upper, out_ptr = builder.function.args + + rand_range_excl = builder.sub(upper, lower) + rand_range_excl_64 = builder.zext(rand_range_excl, ir.IntType(64)) + rand_range = builder.sub(rand_range_excl, rand_range_excl.type(1)) + + + builder.call(gen_int32, [state, out_ptr]) + val = builder.load(out_ptr) + + is_full_range = builder.icmp_unsigned("==", rand_range, rand_range.type(0xffffffff)) + with builder.if_then(is_full_range): + builder.ret_void() + + val64 = builder.zext(val, rand_range_excl_64.type) + m = builder.mul(val64, rand_range_excl_64) + + # Store current result as output. It will be overwritten below if needed. 
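+    # Lemire's multiply-shift step: for a 32-bit draw v, (v * rng_excl) >> 32 lies in [0, rng_excl) and is uniform once biased draws are rejected; the low 32 bits of the product ("leftover" below) are what detect the bias.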
+ out_val = builder.lshr(m, m.type(32)) + out_val = builder.trunc(out_val, out_ptr.type.pointee) + out_val = builder.add(out_val, lower) + builder.store(out_val, out_ptr) + + leftover = builder.and_(m, m.type(0xffffffff)) + + is_good = builder.icmp_unsigned(">=", leftover, rand_range_excl_64) + with builder.if_then(is_good): + builder.ret_void() + + # Apply rejection sampling + leftover_ptr = builder.alloca(leftover.type) + builder.store(leftover, leftover_ptr) + + # threshold = (UINT32_MAX - rng) % rng_excl, as in Numpy's bounded Lemire sampler + rand_range_64 = builder.zext(rand_range, ir.IntType(64)) + threshold = builder.sub(rand_range_64.type(0xffffffff), rand_range_64) + threshold = builder.urem(threshold, rand_range_excl_64) + + cond_block = builder.append_basic_block("bounded_cond_block") + loop_block = builder.append_basic_block("bounded_loop_block") + out_block = builder.append_basic_block("bounded_out_block") + + builder.branch(cond_block) + + # Condition: leftover < threshold + builder.position_at_end(cond_block) + leftover = builder.load(leftover_ptr) + do_next = builder.icmp_unsigned("<", leftover, threshold) + builder.cbranch(do_next, loop_block, out_block) + + # Loop block: + # m = ((uint64_t)next_uint32(bitgen_state)) * rng_excl; + # leftover = m & 0xffffffff + # result = m >> 32 + builder.position_at_end(loop_block) + builder.call(gen_int32, [state, out_ptr]) + + val = builder.load(out_ptr) + val64 = builder.zext(val, rand_range_excl_64.type) + m = builder.mul(val64, rand_range_excl_64) + + leftover = builder.and_(m, m.type(0xffffffff)) + builder.store(leftover, leftover_ptr) + + out_val = builder.lshr(m, m.type(32)) + out_val = builder.trunc(out_val, out_ptr.type.pointee) + out_val = builder.add(out_val, lower) + builder.store(out_val, out_ptr) + builder.branch(cond_block) + + + builder.position_at_end(out_block) + builder.ret_void() + + def _setup_philox_rand_double(ctx, state_ty, gen_int64): # Generate random float number generator function double_ty = ir.DoubleType() @@ -2087,6 +2209,7 @@ def setup_philox(ctx): _setup_rand_binomial(ctx, state_ty, gen_double, prefix="philox") gen_int32 = _setup_philox_rand_int32(ctx, state_ty, gen_int64) + _setup_rand_lemire_int32(ctx, state_ty, gen_int32) gen_float = _setup_philox_rand_float(ctx, state_ty, gen_int32) _setup_philox_rand_normal(ctx, state_ty, gen_float, gen_int32, _wi_float_data, _ki_i32_data, _fi_float_data) _setup_rand_binomial(ctx, state_ty, gen_float, prefix="philox") diff --git a/psyneulink/core/llvm/helpers.py b/psyneulink/core/llvm/helpers.py index c3dc3336bf..7d7b7df10a 100644 --- a/psyneulink/core/llvm/helpers.py +++ b/psyneulink/core/llvm/helpers.py @@ -377,23 +377,23 @@ def array_from_shape(shape, element_ty): array_ty = ir.ArrayType(array_ty, dim) return array_ty -def recursive_iterate_arrays(ctx, builder, u, *args): +@contextmanager +def recursive_iterate_arrays(ctx, builder, *args, loop_id="recursive_iteration"): """Recursively iterates over all elements in scalar arrays of the same shape""" - assert isinstance(u.type.pointee, ir.ArrayType), "Can only iterate over arrays!" + + assert len(args) > 0, "Need at least one array to iterate over!" + assert all(isinstance(arr.type.pointee, ir.ArrayType) for arr in args), "Can only iterate over arrays!" + + u = args[0] assert all(len(u.type.pointee) == len(v.type.pointee) for v in args), "Tried to iterate over differing lengths!" 
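+    # Each recursion level emits one IR loop over the outermost dimension; the innermost level yields the body builder together with matching element pointers into every argument array.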
- with array_ptr_loop(builder, u, "recursive_iteration") as (b, idx): - u_ptr = b.gep(u, [ctx.int32_ty(0), idx]) - arg_ptrs = (b.gep(v, [ctx.int32_ty(0), idx]) for v in args) - if is_scalar(u_ptr): - yield (u_ptr, *arg_ptrs) - else: - yield from recursive_iterate_arrays(ctx, b, u_ptr, *arg_ptrs) -# TODO: Remove this function. Can be replaced by `recursive_iterate_arrays` -def call_elementwise_operation(ctx, builder, x, operation, output_ptr): - """Recurse through an array structure and call operation on each scalar element of the structure. Store result in output_ptr""" - for (inp_ptr, out_ptr) in recursive_iterate_arrays(ctx, builder, x, output_ptr): - builder.store(operation(ctx, builder, builder.load(inp_ptr)), out_ptr) + with array_ptr_loop(builder, u, loop_id) as (b, idx): + arg_ptrs = tuple(b.gep(arr, [ctx.int32_ty(0), idx]) for arr in args) + if is_scalar(arg_ptrs[0]): + yield (b, *arg_ptrs) + else: + with recursive_iterate_arrays(ctx, b, *arg_ptrs) as (b, *nested_args): + yield (b, *nested_args) def printf(ctx, builder, fmt, *args, tags:set): diff --git a/tests/composition/test_emcomposition.py b/tests/composition/test_emcomposition.py index a46411dea3..d2af70bee9 100644 --- a/tests/composition/test_emcomposition.py +++ b/tests/composition/test_emcomposition.py @@ -221,34 +221,37 @@ def test_memory_fill(start, memory_fill): elif repeat and repeat < memory_capacity: # Multi-entry specification and repeat = number entries; remainder test_memory_fill(start=repeat, memory_fill=memory_fill) - def test_softmax_choice(self): - for softmax_choice in [pnl.WEIGHTED_AVG, pnl.ARG_MAX, pnl.PROBABILISTIC]: - em = EMComposition(memory_template=[[[1,.1,.1]], [[1,.1,.1]], [[.1,.1,1]]], - softmax_choice=softmax_choice, - enable_learning=False) - result = em.run(inputs={em.query_input_nodes[0]:[[1,0,0]]}) - if softmax_choice == pnl.WEIGHTED_AVG: - np.testing.assert_allclose(result, [[0.93016008, 0.1, 0.16983992]]) - if softmax_choice == pnl.ARG_MAX: - np.testing.assert_allclose(result, [[1, .1, .1]]) - if softmax_choice == pnl.PROBABILISTIC: # NOTE: actual stochasticity not tested here - np.testing.assert_allclose(result, [[1, .1, .1]]) - - em = EMComposition(memory_template=[[[1,.1,.1]], [[.1,1,.1]], [[.1,.1,1]]]) - for softmax_choice in [pnl.ARG_MAX, pnl.PROBABILISTIC]: - with pytest.raises(EMCompositionError) as error_text: - em.parameters.softmax_choice.set(softmax_choice) - em.learn() - assert (f"The ARG_MAX and PROBABILISTIC options for the 'softmax_choice' arg " - f"of '{em.name}' cannot be used during learning; change to WEIGHTED_AVG." in str(error_text.value)) - - for softmax_choice in [pnl.ARG_MAX, pnl.PROBABILISTIC]: - with pytest.warns(UserWarning) as warning: - em = EMComposition(softmax_choice=softmax_choice, enable_learning=True) - warning_msg = (f"The 'softmax_choice' arg of '{em.name}' is set to '{softmax_choice}' with " - f"'enable_learning' set to True (or a list); this will generate an error if its " - f"'learn' method is called. 
Set 'softmax_choice' to WEIGHTED_AVG before learning.") - assert warning_msg in str(warning[0].message) + @pytest.mark.parametrize("softmax_choice, expected", + [(pnl.WEIGHTED_AVG, [[0.93016008, 0.1, 0.16983992]]), + (pnl.ARG_MAX, [[1, .1, .1]]), + (pnl.PROBABILISTIC, [[1, .1, .1]]), # NOTE: actual stochasticity not tested here + ]) + def test_softmax_choice(self, softmax_choice, expected): + em = EMComposition(memory_template=[[[1,.1,.1]], [[1,.1,.1]], [[.1,.1,1]]], + softmax_choice=softmax_choice, + enable_learning=False) + result = em.run(inputs={em.query_input_nodes[0]:[[1,0,0]]}) + + np.testing.assert_allclose(result, expected) + + @pytest.mark.parametrize("softmax_choice", [pnl.ARG_MAX, pnl.PROBABILISTIC]) + def test_softmax_choice_error(self, softmax_choice): + em = EMComposition(memory_template=[[[1, .1, .1]], [[.1, 1, .1]], [[.1, .1, 1]]]) + msg = (f"The ARG_MAX and PROBABILISTIC options for the 'softmax_choice' arg " + f"of '{em.name}' cannot be used during learning; change to WEIGHTED_AVG.") + + with pytest.raises(EMCompositionError, match=msg): + em.parameters.softmax_choice.set(softmax_choice) + em.learn() + + @pytest.mark.parametrize("softmax_choice", [pnl.ARG_MAX, pnl.PROBABILISTIC]) + def test_softmax_choice_warn(self, softmax_choice): + warning_msg = (f"The 'softmax_choice' arg of '.*' is set to '{softmax_choice}' with " + f"'enable_learning' set to True \\(or a list\\); this will generate an error if its " + f"'learn' method is called. Set 'softmax_choice' to WEIGHTED_AVG before learning.") + + with pytest.warns(UserWarning, match=warning_msg): + EMComposition(softmax_choice=softmax_choice, enable_learning=True) @pytest.mark.pytorch diff --git a/tests/functions/test_selection.py b/tests/functions/test_selection.py index aa238af2bf..dea0ab9e06 100644 --- a/tests/functions/test_selection.py +++ b/tests/functions/test_selection.py @@ -1,7 +1,7 @@ import numpy as np import pytest -import psyneulink.core.components.functions.nonstateful.selectionfunctions as Functions +import psyneulink as pnl import psyneulink.core.globals.keywords as kw from psyneulink.core.globals.utilities import _SeededPhilox @@ -23,34 +23,34 @@ llvm_res['fp32'][expected_philox_ind] = (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) test_data = [ - pytest.param(Functions.OneHot, test_var, {'mode':kw.ARG_MAX}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot ARG_MAX"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.ARG_MAX_ABS}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot ARG MAX_ABS"), - pytest.param(Functions.OneHot, -test_var, {'mode':kw.ARG_MAX_ABS}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot ARG MAX_ABS Neg"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.ARG_MAX_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot ARG_MAX_INDICATOR"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.ARG_MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot ARG_MAX_ABS_INDICATOR"), - pytest.param(Functions.OneHot, -test_var, {'mode':kw.ARG_MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot ARG_MAX_ABS_INDICATOR Neg"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.ARG_MIN}, (0., 0., 0., 0., 0., 0., 0., 0., 0, -0.23311696), id="OneHot ARG_MIN"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.ARG_MIN_ABS}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS"), - pytest.param(Functions.OneHot, -test_var, {'mode':kw.ARG_MIN_ABS}, (0., 0., 0., 0.08976637, 0., 0., 0., 
0., 0., 0.), id="OneHot ARG_MIN_ABS Neg"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.ARG_MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), id="OneHot ARG_MIN_INDICATOR"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.ARG_MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS_INDICATOR"), - pytest.param(Functions.OneHot, -test_var, {'mode':kw.ARG_MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS_INDICATOR Neg"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.MAX_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_VAL"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_VAL"), - pytest.param(Functions.OneHot, -test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_VAL Neg"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.MAX_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_INDICATOR"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_INDICATOR"), - pytest.param(Functions.OneHot, -test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_INDICATOR Neg"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.MIN_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0., -0.23311696), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_VAL"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_VAL"), - pytest.param(Functions.OneHot, -test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_VAL Neg"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_INDICATOR"), - pytest.param(Functions.OneHot, test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_INDICATOR"), - pytest.param(Functions.OneHot, -test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_INDICATOR Neg"), - pytest.param(Functions.OneHot, [test_var, test_prob], {'mode':kw.PROB}, (0., 0., 0., 0.08976636599379373, 0., 0., 0., 0., 0., 0.), id="OneHot PROB"), - pytest.param(Functions.OneHot, [test_var, test_prob], {'mode':kw.PROB_INDICATOR}, (0., 0., 0., 1., 0., 0., 0., 0., 0., 0.), id="OneHot PROB_INDICATOR"), - pytest.param(Functions.OneHot, [test_var, test_philox], {'mode':kw.PROB}, expected_philox_prob, id="OneHot PROB Philox"), - pytest.param(Functions.OneHot, [test_var, test_philox], {'mode':kw.PROB_INDICATOR}, expected_philox_ind, id="OneHot PROB_INDICATOR Philox"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MAX}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot ARG_MAX"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MAX_ABS}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot 
ARG MAX_ABS"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.ARG_MAX_ABS}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot ARG MAX_ABS Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MAX_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot ARG_MAX_INDICATOR"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot ARG_MAX_ABS_INDICATOR"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.ARG_MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot ARG_MAX_ABS_INDICATOR Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MIN}, (0., 0., 0., 0., 0., 0., 0., 0., 0, -0.23311696), id="OneHot ARG_MIN"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MIN_ABS}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.ARG_MIN_ABS}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), id="OneHot ARG_MIN_INDICATOR"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS_INDICATOR"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.ARG_MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS_INDICATOR Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_VAL"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_ABS_VAL"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_ABS_VAL Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_INDICATOR"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_ABS_INDICATOR"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_ABS_INDICATOR Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0., -0.23311696), id="OneHot MIN_VAL"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_VAL"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_VAL Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), id="OneHot MIN_INDICATOR"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_INDICATOR"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_INDICATOR Neg"), + pytest.param(pnl.OneHot, [test_var, test_prob], {'mode':kw.PROB}, (0., 0., 0., 0.08976636599379373, 0., 0., 0., 0., 0., 0.), id="OneHot PROB"), + pytest.param(pnl.OneHot, [test_var, test_prob], {'mode':kw.PROB_INDICATOR}, (0., 0., 0., 1., 0., 0., 0., 0., 0., 0.), id="OneHot PROB_INDICATOR"), + pytest.param(pnl.OneHot, [test_var, test_philox], {'mode':kw.PROB}, expected_philox_prob, id="OneHot PROB Philox"), + 
pytest.param(pnl.OneHot, [test_var, test_philox], {'mode':kw.PROB_INDICATOR}, expected_philox_ind, id="OneHot PROB_INDICATOR Philox"), ] @@ -77,3 +77,198 @@ def test_basic(func, variable, params, expected, benchmark, func_mode): res = benchmark(EX, variable) np.testing.assert_allclose(res, expected) + + +test_var3 = np.append(np.append(test_var, test_var), test_var) +test_var_2d = np.atleast_2d(test_var) +test_var3_2d = np.append(np.append(test_var_2d, test_var_2d, axis=0), test_var_2d, axis=0) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("variable, direction, abs_val, tie, expected", +[ + # simple + *[(test_var, kw.MAX, "absolute", tie, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(test_var, kw.MAX, "original", tie, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(test_var, kw.MIN, "absolute", tie, [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(test_var, kw.MIN, "original", tie, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + + # negated + *[(-test_var, kw.MAX, "absolute", tie, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(-test_var, kw.MAX, "original", tie, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2331169623484446]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(-test_var, kw.MIN, "absolute", tie, [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(-test_var, kw.MIN, "original", tie, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.9273255210020586, 0.0]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + + # 2d + *[(test_var_2d, kw.MAX, "absolute", tie, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(test_var_2d, kw.MAX, "original", tie, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(test_var_2d, kw.MIN, "absolute", tie, [[0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(test_var_2d, kw.MIN, "original", tie, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446]]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + + # 2d negated + *[(-test_var_2d, kw.MAX, "absolute", tie, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(-test_var_2d, kw.MAX, "original", tie, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2331169623484446]]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(-test_var_2d, kw.MIN, "absolute", tie, [[0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + *[(-test_var_2d, kw.MIN, "original", tie, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.9273255210020586, 0.0]]) for tie in [kw.FIRST,kw.LAST,kw.RANDOM,kw.ALL]], + + # multiple extreme values + *[(test_var3, kw.MAX, abs_val, kw.FIRST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + for abs_val in ("absolute", "original")], + *[(test_var3, kw.MAX, abs_val, kw.LAST, [0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]) + for abs_val in ("absolute", "original")], + *[(test_var3, kw.MAX, abs_val, kw.RANDOM, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + for abs_val in ("absolute", "original")], + *[(test_var3, kw.MAX, abs_val, kw.ALL, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]) + for abs_val in ("absolute", "original")], + + (test_var3, kw.MIN, "absolute", kw.FIRST, [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (test_var3, kw.MIN, "absolute", kw.LAST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (test_var3, kw.MIN, "absolute", kw.RANDOM, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (test_var3, kw.MIN, "absolute", kw.ALL, [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + + (test_var3, kw.MIN, "original", kw.FIRST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (test_var3, kw.MIN, "original", kw.LAST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446]), + (test_var3, kw.MIN, "original", kw.RANDOM, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (test_var3, kw.MIN, "original", kw.ALL, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446]), + + # multiple extreme values negated + (-test_var3, kw.MAX, "absolute", kw.FIRST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MAX, "absolute", kw.LAST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]), + (-test_var3, kw.MAX, "absolute", kw.RANDOM, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MAX, "absolute", kw.ALL, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]), + + (-test_var3, kw.MAX, "original", kw.FIRST, [0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.2331169623484446, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MAX, "original", kw.LAST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2331169623484446]), + (-test_var3, kw.MAX, "original", kw.RANDOM, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2331169623484446, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MAX, "original", kw.ALL, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2331169623484446, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2331169623484446, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2331169623484446]), + + (-test_var3, kw.MIN, "absolute", kw.FIRST, [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MIN, "absolute", kw.LAST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MIN, "absolute", kw.RANDOM, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MIN, "absolute", kw.ALL, [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + + (-test_var3, kw.MIN, "original", kw.FIRST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MIN, "original", kw.LAST, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.9273255210020586, 0.0]), + (-test_var3, kw.MIN, "original", kw.RANDOM, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + (-test_var3, kw.MIN, "original", kw.ALL, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.9273255210020586, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.9273255210020586, 0.0]), + + # multiple extreme values 2d + *[(test_var3_2d, kw.MAX, abs_val, kw.FIRST, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]) + for abs_val in ("absolute", "original")], + *[(test_var3_2d, kw.MAX, abs_val, kw.LAST, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]]) + for abs_val in ("absolute", "original")], + *[(test_var3_2d, kw.MAX, abs_val, kw.RANDOM, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]) + for abs_val in ("absolute", "original")], + *[(test_var3_2d, kw.MAX, abs_val, kw.ALL, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.9273255210020586, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9273255210020586, 0.0]]) + for abs_val in ("absolute", "original")], + + (test_var3_2d, kw.MIN, "absolute", kw.FIRST, [[0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]), + (test_var3_2d, kw.MIN, "absolute", kw.LAST, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]), + (test_var3_2d, kw.MIN, "absolute", kw.RANDOM, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]), + (test_var3_2d, kw.MIN, "absolute", kw.ALL, [[0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.08976636599379373, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]), + + (test_var3_2d, kw.MIN, "original", kw.FIRST, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]), + (test_var3_2d, kw.MIN, "original", kw.LAST, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446]]), + (test_var3_2d, kw.MIN, "original", kw.RANDOM, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]), + (test_var3_2d, kw.MIN, "original", kw.ALL, [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.2331169623484446]]), +], ids=lambda x: x if isinstance(x, str) else str(getattr(x, 'shape', '')) ) +@pytest.mark.parametrize("indicator", ["indicator", "value"]) +def test_one_hot_mode_deterministic(benchmark, variable, tie, indicator, direction, abs_val, expected, func_mode): + + f = pnl.OneHot(default_variable=np.zeros_like(variable), + mode=kw.DETERMINISTIC, + tie=tie, + indicator=indicator=="indicator", + abs_val=abs_val=="absolute", + direction=direction, + seed=5) # seed to select middle of the 3 ties + + EX = pytest.helpers.get_func_execution(f, func_mode) + + EX(variable) + res = benchmark(EX, variable) + + if indicator == "indicator": + expected = np.where(np.asarray(expected) != 0, np.ones_like(expected), expected) + + np.testing.assert_allclose(res, expected) diff --git a/tests/llvm/test_builtins_mt_random.py b/tests/llvm/test_builtins_mt_random.py index f19f01e78e..37a1730c9d 100644 --- a/tests/llvm/test_builtins_mt_random.py +++ b/tests/llvm/test_builtins_mt_random.py @@ -6,12 +6,70 @@ SEED = 0 +@pytest.mark.benchmark(group="Mersenne Twister bounded integer PRNG") +@pytest.mark.parametrize('mode', ['numpy', + pytest.param('LLVM', marks=pytest.mark.llvm), + pytest.helpers.cuda_param('PTX')]) +@pytest.mark.parametrize("bounds, expected", + [((0xffffffff,), [3626764237, 1654615998, 3255389356, 3823568514, 1806341205]), + ((14,), [13, 12, 2, 5, 4]), + ((0,14), [13, 12, 2, 5, 4]), + ((5,0xffff), [2002, 28611, 19633, 1671, 37978]), + ], ids=lambda x: str(x) if 
diff --git a/tests/llvm/test_builtins_philox_random.py b/tests/llvm/test_builtins_philox_random.py
index 56b6485b75..8553e6054e 100644
--- a/tests/llvm/test_builtins_philox_random.py
+++ b/tests/llvm/test_builtins_philox_random.py
@@ -16,7 +16,6 @@
     (0xfeedcafe, [14360762734736817955, 5188080951818105836, 1417692977344505657, 15919241602363537044, 11006348070701344872, 12539562470140893435]),
 ])
 def test_random_int64(benchmark, mode, seed, expected):
-    res = []
     if mode == 'numpy':
         state = np.random.Philox([np.int64(seed).astype(np.uint64)])
         prng = np.random.Generator(state)
@@ -56,16 +55,69 @@ def f():
     # Get >4 samples to force regeneration of Philox buffer
     res = [f(), f(), f(), f(), f(), f()]
-    np.testing.assert_allclose(res, expected)
+    np.testing.assert_array_equal(res, expected)
     benchmark(f)
 
+@pytest.mark.benchmark(group="Philox integer PRNG")
+@pytest.mark.parametrize('mode', ['numpy',
+                                  pytest.param('LLVM', marks=pytest.mark.llvm),
+                                  pytest.helpers.cuda_param('PTX')])
+@pytest.mark.parametrize("bounds, expected",
+                         [((0xffffffff,), [582496168, 60417457, 4027530180, 1107101888, 1659784451, 2025357888]),
+                          ((15,), [2, 0, 14, 3, 5, 7]),
+                          ((0,15), [2, 0, 14, 3, 5, 7]),
+                          ((5,0xffff), [8892, 926, 61454, 16896, 25328, 30906]),
+                         ], ids=lambda x: str(x) if len(x) != 6 else "")
+def test_random_int32_bounded(benchmark, mode, bounds, expected):
+    if mode == 'numpy':
+        state = np.random.Philox([SEED])
+        prng = np.random.Generator(state)
+
+        def f():
+            return prng.integers(*bounds, dtype=np.uint32, endpoint=False)
+
+    elif mode == 'LLVM':
+        init_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_init')
+        state = init_fun.np_buffer_for_arg(0)
+        init_fun(state, SEED)
+
+        gen_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_int32_bounded')
+
+        def f():
+            lower, upper = bounds if len(bounds) == 2 else (0, bounds[0])
+            out = gen_fun.np_buffer_for_arg(3)
+            gen_fun(state, lower, upper, out)
+            return out
+
+    elif mode == 'PTX':
+        init_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_init')
+        state_size = init_fun.np_buffer_for_arg(0).nbytes
+        gpu_state = pnlvm.jit_engine.pycuda.driver.mem_alloc(state_size)
+        init_fun.cuda_call(gpu_state, np.int64(SEED))
+
+        gen_fun = pnlvm.LLVMBinaryFunction.get('__pnl_builtin_philox_rand_int32_bounded')
+        out = gen_fun.np_buffer_for_arg(3)
+        gpu_out = pnlvm.jit_engine.pycuda.driver.Out(out)
+
+        def f():
+            lower, upper = bounds if len(bounds) == 2 else (0, bounds[0])
+            gen_fun.cuda_call(gpu_state, np.uint32(lower), np.uint32(upper), gpu_out)
+            return out.copy()
+
+    else:
+        assert False, "Unknown mode: {}".format(mode)
+
+    # Get >4 samples to force regeneration of Philox buffer
+    res = [f(), f(), f(), f(), f(), f()]
+    np.testing.assert_array_equal(res, expected)
+    benchmark(f)
+
 @pytest.mark.benchmark(group="Philox integer PRNG")
 @pytest.mark.parametrize('mode', ['numpy',
                                   pytest.param('LLVM', marks=pytest.mark.llvm),
                                   pytest.helpers.cuda_param('PTX')])
 def test_random_int32(benchmark, mode):
-    res = []
     if mode == 'numpy':
         state = np.random.Philox([SEED])
         prng = np.random.Generator(state)
@@ -105,7 +157,7 @@ def f():
     # Get >4 samples to force regeneration of Philox buffer
     res = [f(), f(), f(), f(), f(), f()]
-    np.testing.assert_allclose(res, [582496169, 60417458, 4027530181, 1107101889, 1659784452, 2025357889])
+    np.testing.assert_array_equal(res, [582496169, 60417458, 4027530181, 1107101889, 1659784452, 2025357889])
     benchmark(f)
 
@@ -114,7 +166,6 @@ def f():
                                   pytest.param('LLVM', marks=pytest.mark.llvm),
                                   pytest.helpers.cuda_param('PTX')])
 def test_random_double(benchmark, mode):
-    res = []
     if mode == 'numpy':
         state = np.random.Philox([SEED])
         prng = np.random.Generator(state)
@@ -161,7 +212,6 @@ def f():
                                   pytest.param('LLVM', marks=pytest.mark.llvm),
                                   pytest.helpers.cuda_param('PTX')])
 def test_random_float(benchmark, mode):
-    res = []
    if mode == 'numpy':
         state = np.random.Philox([SEED])
         prng = np.random.Generator(state)
@@ -207,8 +257,7 @@
 @pytest.mark.parametrize('mode', ['numpy',
                                   pytest.param('LLVM', marks=pytest.mark.llvm),
                                   pytest.helpers.cuda_param('PTX')])
-@pytest.mark.parametrize('fp_type', [pnlvm.ir.DoubleType(), pnlvm.ir.FloatType()],
-                         ids=str)
+@pytest.mark.parametrize('fp_type', [pnlvm.ir.DoubleType(), pnlvm.ir.FloatType()], ids=str)
 def test_random_normal(benchmark, mode, fp_type):
     if mode != 'numpy':
         # Instantiate builder context with the desired type
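A note on the assert_allclose to assert_array_equal switch in both PRNG test files: for integer outputs this is a real strengthening, not style. assert_allclose's default rtol of 1e-7 scales with the magnitude of the expected value, so for 32-bit-sized outputs it silently tolerates absolute errors of several hundred. Illustrative check:

    import numpy as np

    # Passes: tolerance = rtol * |desired| ~= 363 at this magnitude,
    # so an off-by-100 "wrong" PRNG output would go unnoticed.
    np.testing.assert_allclose([3626764337], [3626764237])

    # Exact comparison catches it.
    try:
        np.testing.assert_array_equal([3626764337], [3626764237])
    except AssertionError:
        print("caught off-by-100")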
diff --git a/tests/llvm/test_helpers.py b/tests/llvm/test_helpers.py
index fa43c7fd31..6653dfd408 100644
--- a/tests/llvm/test_helpers.py
+++ b/tests/llvm/test_helpers.py
@@ -447,54 +447,15 @@ def test_helper_numerical(mode, op, var, expected, fp_type):
     np.testing.assert_allclose(res, expected)
 
 
-@pytest.mark.llvm
-@pytest.mark.parametrize('mode', ['CPU', pytest.helpers.cuda_param('PTX')])
-@pytest.mark.parametrize('var,expected', [
-    (np.asfarray([1,2,3]), np.asfarray([2,3,4])),
-    (np.asfarray([[1,2],[3,4]]), np.asfarray([[2,3],[4,5]])),
-], ids=["vector", "matrix"])
-def test_helper_elementwise_op(mode, var, expected):
-    with pnlvm.LLVMBuilderContext.get_current() as ctx:
-        arr_ptr_ty = ctx.convert_python_struct_to_llvm_ir(var).as_pointer()
-
-        func_ty = ir.FunctionType(ir.VoidType(), [arr_ptr_ty, arr_ptr_ty])
-
-        custom_name = ctx.get_unique_name("elementwise_op")
-        function = ir.Function(ctx.module, func_ty, name=custom_name)
-        inp, out = function.args
-        block = function.append_basic_block(name="entry")
-        builder = ir.IRBuilder(block)
-
-        pnlvm.helpers.call_elementwise_operation(ctx, builder, inp,
-                                                 lambda ctx, builder, x: builder.fadd(x.type(1.0), x), out)
-        builder.ret_void()
-
-    bin_f = pnlvm.LLVMBinaryFunction.get(custom_name)
-
-    vec = np.asfarray(var, dtype=bin_f.np_arg_dtypes[0].base)
-    res = bin_f.np_buffer_for_arg(1)
-
-    if mode == 'CPU':
-        bin_f(vec, res)
-    else:
-        bin_f.cuda_wrap_call(vec, res)
-
-    assert np.array_equal(res, expected)
 
 @pytest.mark.llvm
 @pytest.mark.parametrize('mode', ['CPU', pytest.helpers.cuda_param('PTX')])
 @pytest.mark.parametrize('var1,var2,expected', [
     (np.array([1.,2.,3.]), np.array([1.,2.,3.]), np.array([2.,4.,6.])),
     (np.array([1.,2.,3.]), np.array([0.,1.,2.]), np.array([1.,3.,5.])),
-    (np.array([[1.,2.,3.],
-               [4.,5.,6.],
-               [7.,8.,9.]]),
-     np.array([[10.,11.,12.],
-               [13.,14.,15.],
-               [16.,17.,18.]]),
-     np.array([[11.,13.,15.],
-               [17.,19.,21.],
-               [23.,25.,27.]])),
+    (np.array([[1.,2.,3.], [4.,5.,6.], [7.,8.,9.]]),
+     np.array([[10.,11.,12.], [13.,14.,15.], [16.,17.,18.]]),
+     np.array([[11.,13.,15.], [17.,19.,21.], [23.,25.,27.]])),
 ])
 def test_helper_recursive_iterate_arrays(mode, var1, var2, expected):
     with pnlvm.LLVMBuilderContext.get_current() as ctx:
@@ -508,10 +469,10 @@ def test_helper_recursive_iterate_arrays(mode, var1, var2, expected):
         block = function.append_basic_block(name="entry")
         builder = ir.IRBuilder(block)
 
-        for (a_ptr, b_ptr, o_ptr) in pnlvm.helpers.recursive_iterate_arrays(ctx, builder, u, v, out):
-            a = builder.load(a_ptr)
-            b = builder.load(b_ptr)
-            builder.store(builder.fadd(a,b), o_ptr)
+        with pnlvm.helpers.recursive_iterate_arrays(ctx, builder, u, v, out) as (loop_builder, a_ptr, b_ptr, o_ptr):
+            a = loop_builder.load(a_ptr)
+            b = loop_builder.load(b_ptr)
+            loop_builder.store(loop_builder.fadd(a, b), o_ptr)
         builder.ret_void()
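The final hunk tracks an API change to pnlvm.helpers.recursive_iterate_arrays: instead of a generator that the caller drives with the enclosing builder, it is now a context manager that additionally yields the builder positioned inside the generated loop body, and all loop-body IR must be emitted through that builder. Side by side, as a sketch based solely on the two versions of the test above:

    # Before: generator form; loop-body IR was emitted via the outer 'builder'.
    for (a_ptr, b_ptr, o_ptr) in pnlvm.helpers.recursive_iterate_arrays(ctx, builder, u, v, out):
        builder.store(builder.fadd(builder.load(a_ptr), builder.load(b_ptr)), o_ptr)

    # After: context-manager form; 'loop_builder' is positioned inside the loop,
    # so the element loads and the store go through it, while the outer 'builder'
    # stays positioned after the loop (e.g. for the final ret_void()).
    with pnlvm.helpers.recursive_iterate_arrays(ctx, builder, u, v, out) as (loop_builder, a_ptr, b_ptr, o_ptr):
        loop_builder.store(loop_builder.fadd(loop_builder.load(a_ptr), loop_builder.load(b_ptr)), o_ptr)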