Clean up/function/softmax one hot (#713)
* • ControlMechanism
  assign_as_controller:  fixed bug (assignment of self.allocation)

* -
jdcpni authored Mar 9, 2018
1 parent 0241e93 commit bd10d08
Showing 1 changed file with 11 additions and 26 deletions.
37 changes: 11 additions & 26 deletions psyneulink/components/functions/function.py
@@ -3474,15 +3474,10 @@ def _instantiate_function(self, context=None):

         self.one_hot_function = None
         output_type = self.get_current_function_param(OUTPUT_TYPE)
+        bounds = None

-        if output_type is MAX_VAL:
-            bounds = None
-            self.one_hot_function = OneHot(default_variable=self.variable,
-                                           mode=MAX_VAL).function
-        elif output_type is MAX_INDICATOR:
-            bounds = None
-            self.one_hot_function = OneHot(default_variable=self.variable,
-                                           mode=MAX_INDICATOR).function
+        if not output_type is ALL:
+            self.one_hot_function = OneHot(mode=output_type).function

         super()._instantiate_function(context=context)

@@ -3521,6 +3516,8 @@ def function(self,
         output_type = self.get_current_function_param(OUTPUT_TYPE)
         gain = self.get_current_function_param(GAIN)

+        # Compute softmax and assign to sm
+
         # Modulate variable by gain
         v = gain * variable
         # Shift by max to avoid extreme values:
@@ -3530,26 +3527,14 @@
         # Normalize (to sum to 1)
         sm = v / np.sum(v, axis=0)

-        # For the element that is max of softmax, set it's value to its softmax value or 1, set others to zero
-        if output_type in {MAX_VAL, MAX_INDICATOR}:
-            sm = self.one_hot_function(sm)
-
-        # # For the element that is max of softmax, set its value to 1, set others to zero
-        # elif output_type is MAX_INDICATOR:
-        #     # sm = np.where(sm == np.max(sm), 1, 0)
-        #     max_value = np.max(sm)
-        #     sm = np.where(sm == max_value, 1, 0)
+        # Generate one-hot encoding based on selected output_type

-        # Choose a single element probabilistically based on softmax of their values;
-        # leave that element's value intact, set others to zero
+        if output_type in {MAX_VAL, MAX_INDICATOR}:
+            return self.one_hot_function(sm)
         elif output_type is PROB:
-            cum_sum = np.cumsum(sm)
-            random_value = np.random.uniform()
-            chosen_item = next(element for element in cum_sum if element > random_value)
-            chosen_in_cum_sum = np.where(cum_sum == chosen_item, 1, 0)
-            sm = variable * chosen_in_cum_sum
-
-        return sm
+            return self.one_hot_function([variable, sm])
+        else:
+            return sm

     def derivative(self, output, input=None):
         """
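For reference, below is a minimal standalone sketch (not PsyNeuLink code; the constant names and the softmax_one_hot helper are illustrative) of the behavior the cleaned-up function method implements: compute the gain-modulated, max-shifted softmax, then apply the selected output type to the result. After this commit the MAX_VAL, MAX_INDICATOR, and PROB cases are delegated to a OneHot function created in _instantiate_function; the sketch inlines equivalent NumPy logic instead.

import numpy as np

# Illustrative output-type constants (stand-ins for the keywords used in the diff above)
ALL, MAX_VAL, MAX_INDICATOR, PROB = 'ALL', 'MAX_VAL', 'MAX_INDICATOR', 'PROB'

def softmax_one_hot(variable, gain=1.0, output_type=ALL):
    """Sketch of SoftMax with one-hot style output; names here are hypothetical."""
    variable = np.asarray(variable, dtype=float)

    # Modulate by gain, shift by max to avoid extreme values, exponentiate,
    # then normalize so the result sums to 1
    v = gain * variable
    v = np.exp(v - np.max(v))
    sm = v / np.sum(v)

    if output_type == ALL:
        return sm
    if output_type == MAX_VAL:
        # Keep the softmax value of the max element, zero the others
        return np.where(sm == np.max(sm), sm, 0.0)
    if output_type == MAX_INDICATOR:
        # Set the max element to 1, zero the others
        return np.where(sm == np.max(sm), 1.0, 0.0)
    if output_type == PROB:
        # Choose one element with probability given by its softmax value;
        # keep that element's original value, zero the others
        # (mirrors the logic that OneHot receives [variable, sm] for in the commit)
        cum_sum = np.cumsum(sm)
        r = np.random.uniform()
        chosen = min(np.searchsorted(cum_sum, r, side='right'), len(sm) - 1)
        mask = np.zeros_like(sm)
        mask[chosen] = 1.0
        return variable * mask
    raise ValueError(f"unrecognized output_type: {output_type}")

print(softmax_one_hot([1.0, 2.0, 3.0], gain=2.0, output_type=MAX_INDICATOR))
# -> [0. 0. 1.]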
