Skip to content

Commit

Permalink
llvm/OneHot: Implement all modes (#3124)
Browse files Browse the repository at this point in the history
Refactor Python OneHot implementation.
Add tests for OneHot(mode=DETERMINISTIC).

Refactor compiled OneHot to isolate PROB and PROB_INDICATOR.
Convert recursive traversal of arrays from generator to context manager.

Implement compiled random integer generation for bounded integers.
Implement all OneHot modes in compiled mode.
  • Loading branch information
jvesely authored Nov 21, 2024
2 parents 8822de0 + 06bff2b commit c0f73e2
Show file tree
Hide file tree
Showing 11 changed files with 739 additions and 296 deletions.
18 changes: 14 additions & 4 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,14 +528,14 @@
Context, ContextError, ContextFlags, INITIALIZATION_STATUS_FLAGS, _get_time, handle_external_context
from psyneulink.core.globals.mdf import MDFSerializable
from psyneulink.core.globals.keywords import \
CONTEXT, CONTROL_PROJECTION, DEFERRED_INITIALIZATION, EXECUTE_UNTIL_FINISHED, \
CONTEXT, CONTROL_PROJECTION, DEFERRED_INITIALIZATION, DETERMINISTIC, EXECUTE_UNTIL_FINISHED, \
FUNCTION, FUNCTION_PARAMS, INIT_FULL_EXECUTE_METHOD, INPUT_PORTS, \
LEARNING, LEARNING_PROJECTION, MATRIX, MAX_EXECUTIONS_BEFORE_FINISHED, \
MODEL_SPEC_ID_PSYNEULINK, MODEL_SPEC_ID_METADATA, \
MODEL_SPEC_ID_INPUT_PORTS, MODEL_SPEC_ID_OUTPUT_PORTS, \
MODEL_SPEC_ID_MDF_VARIABLE, \
MODULATORY_SPEC_KEYWORDS, NAME, OUTPUT_PORTS, OWNER, PARAMS, PREFS_ARG, \
RESET_STATEFUL_FUNCTION_WHEN, INPUT_SHAPES, VALUE, VARIABLE, SHARED_COMPONENT_TYPES
RANDOM, RESET_STATEFUL_FUNCTION_WHEN, INPUT_SHAPES, VALUE, VARIABLE, SHARED_COMPONENT_TYPES
from psyneulink.core.globals.log import LogCondition
from psyneulink.core.globals.parameters import \
Defaults, SharedParameter, Parameter, ParameterAlias, ParameterError, ParametersBase, check_user_specified, copy_parameter_value, is_array_like
Expand Down Expand Up @@ -1391,6 +1391,9 @@ def _get_compilation_state(self):
if cost_functions.DURATION not in cost_functions:
blacklist.add('duration_cost_fct')

if getattr(self, "mode", None) == DETERMINISTIC and getattr(self, "tie", None) != RANDOM:
whitelist.remove('random_state')

# Drop previous_value from MemoryFunctions
if hasattr(self.parameters, 'duplicate_keys'):
blacklist.add("previous_value")
Expand Down Expand Up @@ -1508,13 +1511,20 @@ def _get_compilation_params(self):
"retain_torch_trained_outputs", "retain_torch_targets", "retain_torch_losses"
"torch_trained_outputs", "torch_targets", "torch_losses",
# should be added to relevant _gen_llvm_function... when aug:
# OneHot:
'abs_val', 'indicator',
# SoftMax:
'mask_threshold', 'adapt_scale', 'adapt_base', 'adapt_entropy_weighting',
# LCAMechanism
"mask"
}

# OneHot:
# * runtime abs_val and indicator are only used in deterministic mode.
# * random_state and seed are only used in RANDOM tie resolution.
if getattr(self, "mode", None) != DETERMINISTIC:
blacklist.update(['abs_val', 'indicator'])
elif getattr(self, "tie", None) != RANDOM:
blacklist.add("seed")

# Mechanism's need few extra entries:
# * matrix -- is never used directly, and is flatened below
# * integration_rate -- shape mismatch with param port input
Expand Down
Loading

0 comments on commit c0f73e2

Please sign in to comment.