Skip to content

Commit

Permalink
llvm/OneHot: Refactor to match Python behaviour for modes != DETERMIN…
Browse files Browse the repository at this point in the history
…ISTIC

Allow output of multiple extremes.
Enable tests.

Still TODO:
 * mode==DETERMINISTIC
 * tie==RANDOM (only used with mode==DETERMINISTIC)
 * 2d arguments

Signed-off-by: Jan Vesely <[email protected]>
  • Loading branch information
jvesely committed Nov 14, 2024
1 parent 0b8ddf7 commit af4a9d0
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 87 deletions.
158 changes: 83 additions & 75 deletions psyneulink/core/components/functions/nonstateful/selectionfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,85 +479,93 @@ def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out,
elif self.mode == DETERMINISTIC:
assert False, "DETERMINISTIC mode not supported"

best_idx_ptr = builder.alloca(ctx.int32_ty)
builder.store(best_idx_ptr.type.pointee(0), best_idx_ptr)

with pnlvm.helpers.array_ptr_loop(builder, arg_in, "search") as (b1, idx):
best_idx = b1.load(best_idx_ptr)
best_ptr = b1.gep(arg_in, [ctx.int32_ty(0), best_idx])

current_ptr = b1.gep(arg_in, [ctx.int32_ty(0), idx])
current = b1.load(current_ptr)

fabs = ctx.get_builtin("fabs", [current.type])

is_first = b1.icmp_unsigned("==", idx, idx.type(0))

# Allow the first element to win the comparison
prev_best = b1.select(is_first, best_ptr.type.pointee(float("NaN")), b1.load(best_ptr))

if self.mode == ARG_MAX:
cmp_op = ">"
cmp_prev = prev_best
cmp_curr = current
val = current

elif self.mode == ARG_MAX_ABS:
cmp_op = ">"
cmp_prev = b1.call(fabs, [prev_best])
cmp_curr = b1.call(fabs, [current])
val = b1.call(fabs, [current])

elif self.mode == ARG_MAX_INDICATOR:
cmp_op = ">"
cmp_prev = prev_best
cmp_curr = current
val = current.type(1.0)

elif self.mode == ARG_MAX_ABS_INDICATOR:
cmp_op = ">"
cmp_prev = b1.call(fabs, [prev_best])
cmp_curr = b1.call(fabs, [current])
val = current.type(1.0)

elif self.mode == ARG_MIN:
cmp_op = "<"
cmp_prev = prev_best
cmp_curr = current
val = current

elif self.mode == ARG_MIN_ABS:
cmp_op = "<"
cmp_prev = b1.call(fabs, [prev_best])
cmp_curr = b1.call(fabs, [current])
val = b1.call(fabs, [current])

elif self.mode == ARG_MIN_INDICATOR:
cmp_op = "<"
cmp_prev = prev_best
cmp_curr = current
val = current.type(1.0)

elif self.mode == ARG_MIN_ABS_INDICATOR:
cmp_op = "<"
cmp_prev = b1.call(fabs, [prev_best])
cmp_curr = b1.call(fabs, [current])
val = current.type(1.0)
else:
direction, abs_val, indicator, tie = self._parse_mode(self.mode)
is_abs_val = ctx.bool_ty(abs_val)
is_indicator = ctx.bool_ty(indicator)

else:
assert False, "Unsupported mode in LLVM: {} for OneHot Function".format(self.mode)
num_extremes_ptr = builder.alloca(ctx.int32_ty)
builder.store(num_extremes_ptr.type.pointee(0), num_extremes_ptr)

extreme_val_ptr = builder.alloca(ctx.float_ty)
builder.store(extreme_val_ptr.type.pointee(float("NaN")), extreme_val_ptr)

fabs_f = ctx.get_builtin("fabs", [extreme_val_ptr.type.pointee])

with pnlvm.helpers.array_ptr_loop(builder, arg_in, "count_extremes") as (loop_builder, idx):

current_ptr = loop_builder.gep(arg_in, [ctx.int32_ty(0), idx])
current = loop_builder.load(current_ptr)
current_abs = loop_builder.call(fabs_f, [current])
current = builder.select(is_abs_val, current_abs, current)

old_extreme = loop_builder.load(extreme_val_ptr)
cmp_op = ">" if direction == MAX else "<"
is_new_extreme = loop_builder.fcmp_unordered(cmp_op, current, old_extreme)

with loop_builder.if_then(is_new_extreme):
loop_builder.store(current, extreme_val_ptr)
loop_builder.store(num_extremes_ptr.type.pointee(1), num_extremes_ptr)

is_old_extreme = loop_builder.fcmp_ordered("==", current, old_extreme)
with loop_builder.if_then(is_old_extreme):
extreme_count = loop_builder.load(num_extremes_ptr)
extreme_count = loop_builder.add(extreme_count, extreme_count.type(1))
loop_builder.store(extreme_count, num_extremes_ptr)


if tie == FIRST:
extreme_start = num_extremes_ptr.type.pointee(0)
extreme_stop = num_extremes_ptr.type.pointee(1)

elif tie == LAST:
extreme_stop = builder.load(num_extremes_ptr)
extreme_start = builder.sub(output_top, output_stop(1))

elif tie == ALL:
extreme_start = num_extremes_ptr.type.pointee(0)
extreme_stop = builder.load(num_extremes_ptr)

else:
assert False


extreme_val = builder.load(extreme_val_ptr)
extreme_write_val = builder.select(is_indicator, extreme_val.type(1), extreme_val)
next_extreme_ptr = builder.alloca(num_extremes_ptr.type.pointee)
builder.store(next_extreme_ptr.type.pointee(0), next_extreme_ptr)

pnlvm.helpers.printf(ctx,
builder,
"{} replacing extreme values of %e from <%u,%u) out of %u\n".format(self.name),
extreme_val,
extreme_start,
extreme_stop,
builder.load(num_extremes_ptr),
tags={"one_hot"})

with pnlvm.helpers.array_ptr_loop(builder, arg_in, "mark_extremes") as (loop_builder, idx):
current_ptr = loop_builder.gep(arg_in, [ctx.int32_ty(0), idx])
current = loop_builder.load(current_ptr)
current_abs = loop_builder.call(fabs_f, [current])
current = builder.select(is_abs_val, current_abs, current)

is_extreme = loop_builder.fcmp_ordered("==", current, extreme_val)
current_extreme_idx = loop_builder.load(next_extreme_ptr)

with loop_builder.if_then(is_extreme):
next_extreme_idx = loop_builder.add(current_extreme_idx, current_extreme_idx.type(1))
loop_builder.store(next_extreme_idx, next_extreme_ptr)

prev_res_ptr = b1.gep(arg_out, [ctx.int32_ty(0), best_idx])
cur_res_ptr = b1.gep(arg_out, [ctx.int32_ty(0), idx])
is_after_start = loop_builder.icmp_unsigned(">=", current_extreme_idx, extreme_start)
is_before_stop = loop_builder.icmp_unsigned("<", current_extreme_idx, extreme_stop)

# Make sure other elements are zeroed
builder.store(cur_res_ptr.type.pointee(0), cur_res_ptr)
should_write_extreme = loop_builder.and_(is_extreme, is_after_start)
should_write_extreme = loop_builder.and_(should_write_extreme, is_before_stop)

cmp_res = builder.fcmp_unordered(cmp_op, cmp_curr, cmp_prev)
with builder.if_then(cmp_res):
builder.store(prev_res_ptr.type.pointee(0), prev_res_ptr)
builder.store(val, cur_res_ptr)
builder.store(idx, best_idx_ptr)
write_value = loop_builder.select(should_write_extreme, extreme_write_val, extreme_write_val.type(0))
out_ptr = loop_builder.gep(arg_out, [ctx.int32_ty(0), idx])
loop_builder.store(write_value, out_ptr)

return builder

Expand Down
24 changes: 12 additions & 12 deletions tests/functions/test_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@
pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), id="OneHot ARG_MIN_INDICATOR"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.ARG_MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS_INDICATOR"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.ARG_MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot ARG_MIN_ABS_INDICATOR Neg"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_VAL"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_VAL"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_VAL Neg"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_INDICATOR"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_INDICATOR"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MAX_ABS_INDICATOR Neg"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0., -0.23311696), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_VAL"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_VAL"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_VAL Neg"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_INDICATOR"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_INDICATOR"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), marks=pytest.mark.llvm_not_implemented, id="OneHot MIN_ABS_INDICATOR Neg"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_VAL"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_ABS_VAL"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0.92732552, 0.), id="OneHot MAX_ABS_VAL Neg"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_INDICATOR"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_ABS_INDICATOR"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.MAX_ABS_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 1., 0.), id="OneHot MAX_ABS_INDICATOR Neg"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_VAL}, (0., 0., 0., 0., 0., 0., 0., 0., 0., -0.23311696), id="OneHot MIN_VAL"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_VAL"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_VAL}, (0., 0., 0., 0.08976637, 0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_VAL Neg"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_INDICATOR}, (0., 0., 0., 0., 0., 0., 0., 0., 0., 1.), id="OneHot MIN_INDICATOR"),
pytest.param(pnl.OneHot, test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_INDICATOR"),
pytest.param(pnl.OneHot, -test_var, {'mode':kw.MIN_ABS_INDICATOR}, (0., 0., 0., 1.,0., 0., 0., 0., 0., 0.), id="OneHot MIN_ABS_INDICATOR Neg"),
pytest.param(pnl.OneHot, [test_var, test_prob], {'mode':kw.PROB}, (0., 0., 0., 0.08976636599379373, 0., 0., 0., 0., 0., 0.), id="OneHot PROB"),
pytest.param(pnl.OneHot, [test_var, test_prob], {'mode':kw.PROB_INDICATOR}, (0., 0., 0., 1., 0., 0., 0., 0., 0., 0.), id="OneHot PROB_INDICATOR"),
pytest.param(pnl.OneHot, [test_var, test_philox], {'mode':kw.PROB}, expected_philox_prob, id="OneHot PROB Philox"),
Expand Down

0 comments on commit af4a9d0

Please sign in to comment.