diff --git a/psyneulink/core/components/functions/nonstateful/selectionfunctions.py b/psyneulink/core/components/functions/nonstateful/selectionfunctions.py index e1d0fdcf3c..ef7dbd63a7 100644 --- a/psyneulink/core/components/functions/nonstateful/selectionfunctions.py +++ b/psyneulink/core/components/functions/nonstateful/selectionfunctions.py @@ -479,85 +479,93 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, elif self.mode == DETERMINISTIC: assert False, "DETERMINISTIC mode not supported" - best_idx_ptr = builder.alloca(ctx.int32_ty) - builder.store(best_idx_ptr.type.pointee(0), best_idx_ptr) - - with pnlvm.helpers.array_ptr_loop(builder, arg_in, "search") as (b1, idx): - best_idx = b1.load(best_idx_ptr) - best_ptr = b1.gep(arg_in, [ctx.int32_ty(0), best_idx]) - - current_ptr = b1.gep(arg_in, [ctx.int32_ty(0), idx]) - current = b1.load(current_ptr) - - fabs = ctx.get_builtin("fabs", [current.type]) - - is_first = b1.icmp_unsigned("==", idx, idx.type(0)) - - # Allow the first element to win the comparison - prev_best = b1.select(is_first, best_ptr.type.pointee(float("NaN")), b1.load(best_ptr)) - - if self.mode == ARG_MAX: - cmp_op = ">" - cmp_prev = prev_best - cmp_curr = current - val = current - - elif self.mode == ARG_MAX_ABS: - cmp_op = ">" - cmp_prev = b1.call(fabs, [prev_best]) - cmp_curr = b1.call(fabs, [current]) - val = b1.call(fabs, [current]) - - elif self.mode == ARG_MAX_INDICATOR: - cmp_op = ">" - cmp_prev = prev_best - cmp_curr = current - val = current.type(1.0) - - elif self.mode == ARG_MAX_ABS_INDICATOR: - cmp_op = ">" - cmp_prev = b1.call(fabs, [prev_best]) - cmp_curr = b1.call(fabs, [current]) - val = current.type(1.0) - - elif self.mode == ARG_MIN: - cmp_op = "<" - cmp_prev = prev_best - cmp_curr = current - val = current - - elif self.mode == ARG_MIN_ABS: - cmp_op = "<" - cmp_prev = b1.call(fabs, [prev_best]) - cmp_curr = b1.call(fabs, [current]) - val = b1.call(fabs, [current]) - - elif self.mode == ARG_MIN_INDICATOR: - cmp_op = "<" - cmp_prev = prev_best - cmp_curr = current - val = current.type(1.0) - - elif self.mode == ARG_MIN_ABS_INDICATOR: - cmp_op = "<" - cmp_prev = b1.call(fabs, [prev_best]) - cmp_curr = b1.call(fabs, [current]) - val = current.type(1.0) + else: + direction, abs_val, indicator, tie = self._parse_mode(self.mode) + is_abs_val = ctx.bool_ty(abs_val) + is_indicator = ctx.bool_ty(indicator) - else: - assert False, "Unsupported mode in LLVM: {} for OneHot Function".format(self.mode) + num_extremes_ptr = builder.alloca(ctx.int32_ty) + builder.store(num_extremes_ptr.type.pointee(0), num_extremes_ptr) + + extreme_val_ptr = builder.alloca(ctx.float_ty) + builder.store(extreme_val_ptr.type.pointee(float("NaN")), extreme_val_ptr) + + fabs_f = ctx.get_builtin("fabs", [extreme_val_ptr.type.pointee]) + + with pnlvm.helpers.array_ptr_loop(builder, arg_in, "count_extremes") as (loop_builder, idx): + + current_ptr = loop_builder.gep(arg_in, [ctx.int32_ty(0), idx]) + current = loop_builder.load(current_ptr) + current_abs = loop_builder.call(fabs_f, [current]) + current = builder.select(is_abs_val, current_abs, current) + + old_extreme = loop_builder.load(extreme_val_ptr) + cmp_op = ">" if direction == MAX else "<" + is_new_extreme = loop_builder.fcmp_unordered(cmp_op, current, old_extreme) + + with loop_builder.if_then(is_new_extreme): + loop_builder.store(current, extreme_val_ptr) + loop_builder.store(num_extremes_ptr.type.pointee(1), num_extremes_ptr) + + is_old_extreme = loop_builder.fcmp_ordered("==", current, old_extreme) + with loop_builder.if_then(is_old_extreme): + extreme_count = loop_builder.load(num_extremes_ptr) + extreme_count = loop_builder.add(extreme_count, extreme_count.type(1)) + loop_builder.store(extreme_count, num_extremes_ptr) + + + if tie == FIRST: + extreme_start = num_extremes_ptr.type.pointee(0) + extreme_stop = num_extremes_ptr.type.pointee(1) + + elif tie == LAST: + extreme_stop = builder.load(num_extremes_ptr) + extreme_start = builder.sub(output_top, output_stop(1)) + + elif tie == ALL: + extreme_start = num_extremes_ptr.type.pointee(0) + extreme_stop = builder.load(num_extremes_ptr) + + else: + assert False + + + extreme_val = builder.load(extreme_val_ptr) + extreme_write_val = builder.select(is_indicator, extreme_val.type(1), extreme_val) + next_extreme_ptr = builder.alloca(num_extremes_ptr.type.pointee) + builder.store(next_extreme_ptr.type.pointee(0), next_extreme_ptr) + + pnlvm.helpers.printf(ctx, + builder, + "{} replacing extreme values of %e from <%u,%u) out of %u\n".format(self.name), + extreme_val, + extreme_start, + extreme_stop, + builder.load(num_extremes_ptr), + tags={"one_hot"}) + + with pnlvm.helpers.array_ptr_loop(builder, arg_in, "mark_extremes") as (loop_builder, idx): + current_ptr = loop_builder.gep(arg_in, [ctx.int32_ty(0), idx]) + current = loop_builder.load(current_ptr) + current_abs = loop_builder.call(fabs_f, [current]) + current = builder.select(is_abs_val, current_abs, current) + + is_extreme = loop_builder.fcmp_ordered("==", current, extreme_val) + current_extreme_idx = loop_builder.load(next_extreme_ptr) + + with loop_builder.if_then(is_extreme): + next_extreme_idx = loop_builder.add(current_extreme_idx, current_extreme_idx.type(1)) + loop_builder.store(next_extreme_idx, next_extreme_ptr) - prev_res_ptr = b1.gep(arg_out, [ctx.int32_ty(0), best_idx]) - cur_res_ptr = b1.gep(arg_out, [ctx.int32_ty(0), idx]) + is_after_start = loop_builder.icmp_unsigned(">=", current_extreme_idx, extreme_start) + is_before_stop = loop_builder.icmp_unsigned("<", current_extreme_idx, extreme_stop) - # Make sure other elements are zeroed - builder.store(cur_res_ptr.type.pointee(0), cur_res_ptr) + should_write_extreme = loop_builder.and_(is_extreme, is_after_start) + should_write_extreme = loop_builder.and_(should_write_extreme, is_before_stop) - cmp_res = builder.fcmp_unordered(cmp_op, cmp_curr, cmp_prev) - with builder.if_then(cmp_res): - builder.store(prev_res_ptr.type.pointee(0), prev_res_ptr) - builder.store(val, cur_res_ptr) - builder.store(idx, best_idx_ptr) + write_value = loop_builder.select(should_write_extreme, extreme_write_val, extreme_write_val.type(0)) + out_ptr = loop_builder.gep(arg_out, [ctx.int32_ty(0), idx]) + loop_builder.store(write_value, out_ptr) return builder diff --git a/tests/functions/test_selection.py b/tests/functions/test_selection.py index f74f5fa1d4..ee39fa5814 100644 --- a/tests/functions/test_selection.py +++ b/tests/functions/test_selection.py @@ -35,18 +35,18 @@ pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), id="OneHot ARG_MIN_INDICATOR"), pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS_INDICATOR"), pytest.param(pnl.OneHot, -test_var, {'mode':kw.ARG_MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS_INDICATOR Neg"), - pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_VAL"), - pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_VAL"), - pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_VAL Neg"), - pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_INDICATOR"), - pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_INDICATOR"), - pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_INDICATOR Neg"), - pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0., -0.23311696), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_VAL"), - pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_VAL"), - pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_VAL Neg"), - pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_INDICATOR"), - pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_INDICATOR"), - pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_INDICATOR Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_VAL"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_ABS_VAL"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_ABS_VAL Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_INDICATOR"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_ABS_INDICATOR"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_ABS_INDICATOR Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0., -0.23311696), id="OneHot MIN_VAL"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_VAL"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_VAL Neg"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), id="OneHot MIN_INDICATOR"), + pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_INDICATOR"), + pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_INDICATOR Neg"), pytest.param(pnl.OneHot, [test_var, test_prob], {'mode':kw.PROB}, (0., 0., 0., 0.08976636599379373, 0., 0., 0., 0., 0., 0.), id="OneHot PROB"), pytest.param(pnl.OneHot, [test_var, test_prob], {'mode':kw.PROB_INDICATOR}, (0., 0., 0., 1., 0., 0., 0., 0., 0., 0.), id="OneHot PROB_INDICATOR"), pytest.param(pnl.OneHot, [test_var, test_philox], {'mode':kw.PROB}, expected_philox_prob, id="OneHot PROB Philox"),