From bd10d08c04cd3c6abed76c4968cf627e5bad0f82 Mon Sep 17 00:00:00 2001
From: jdcpni
Date: Thu, 8 Mar 2018 23:18:08 -0500
Subject: [PATCH] Clean up/function/softmax one hot (#713)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* • ControlMechanism
  assign_as_controller: fixed bug (assignment of self.allocation)

* -

---
 psyneulink/components/functions/function.py | 37 ++++++---------------
 1 file changed, 11 insertions(+), 26 deletions(-)

diff --git a/psyneulink/components/functions/function.py b/psyneulink/components/functions/function.py
index 5dce11ff54b..7716cb11c57 100644
--- a/psyneulink/components/functions/function.py
+++ b/psyneulink/components/functions/function.py
@@ -3474,15 +3474,10 @@ def _instantiate_function(self, context=None):
 
         self.one_hot_function = None
         output_type = self.get_current_function_param(OUTPUT_TYPE)
+        bounds = None
 
-        if output_type is MAX_VAL:
-            bounds = None
-            self.one_hot_function = OneHot(default_variable=self.variable,
-                                           mode=MAX_VAL).function
-        elif output_type is MAX_INDICATOR:
-            bounds = None
-            self.one_hot_function = OneHot(default_variable=self.variable,
-                                           mode=MAX_INDICATOR).function
+        if not output_type is ALL:
+            self.one_hot_function = OneHot(mode=output_type).function
 
         super()._instantiate_function(context=context)
 
@@ -3521,6 +3516,8 @@ def function(self,
         output_type = self.get_current_function_param(OUTPUT_TYPE)
         gain = self.get_current_function_param(GAIN)
 
+        # Compute softmax and assign to sm
+
         # Modulate variable by gain
         v = gain * variable
         # Shift by max to avoid extreme values:
@@ -3530,26 +3527,14 @@ def function(self,
         # Normalize (to sum to 1)
         sm = v / np.sum(v, axis=0)
 
-        # For the element that is max of softmax, set it's value to its softmax value or 1, set others to zero
-        if output_type in {MAX_VAL, MAX_INDICATOR}:
-            sm = self.one_hot_function(sm)
-
-        # # For the element that is max of softmax, set its value to 1, set others to zero
-        # elif output_type is MAX_INDICATOR:
-        #     # sm = np.where(sm == np.max(sm), 1, 0)
-        #     max_value = np.max(sm)
-        #     sm = np.where(sm == max_value, 1, 0)
+        # Generate one-hot encoding based on selected output_type
 
-        # Choose a single element probabilistically based on softmax of their values;
-        # leave that element's value intact, set others to zero
+        if output_type in {MAX_VAL, MAX_INDICATOR}:
+            return self.one_hot_function(sm)
         elif output_type is PROB:
-            cum_sum = np.cumsum(sm)
-            random_value = np.random.uniform()
-            chosen_item = next(element for element in cum_sum if element > random_value)
-            chosen_in_cum_sum = np.where(cum_sum == chosen_item, 1, 0)
-            sm = variable * chosen_in_cum_sum
-
-        return sm
+            return self.one_hot_function([variable, sm])
+        else:
+            return sm
 
     def derivative(self, output, input=None):
         """
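
Note (editor): for context on the branches this patch consolidates, below is a minimal NumPy sketch of a gain-modulated softmax with the output modes referenced above. The function name softmax_sketch, the plain string mode labels, and the use of numpy.random.default_rng are illustrative assumptions, not PsyNeuLink's API; in the patched code these modes are delegated to the OneHot function.

    # Illustrative sketch only; not PsyNeuLink code. The string labels stand in
    # for the MAX_VAL / MAX_INDICATOR / PROB / ALL keywords used in the diff.
    import numpy as np

    def softmax_sketch(variable, gain=1.0, output_type='ALL', rng=None):
        rng = np.random.default_rng() if rng is None else rng
        variable = np.asarray(variable, dtype=float)
        v = gain * variable                     # modulate by gain
        v = np.exp(v - np.max(v))               # shift by max to avoid overflow
        sm = v / np.sum(v)                      # normalize to sum to 1

        if output_type == 'MAX_VAL':            # softmax value at the max element, zeros elsewhere
            return np.where(sm == np.max(sm), sm, 0.0)
        if output_type == 'MAX_INDICATOR':      # 1 at the max element, zeros elsewhere
            return np.where(sm == np.max(sm), 1.0, 0.0)
        if output_type == 'PROB':               # sample one element with probability sm;
            idx = rng.choice(len(sm), p=sm)     # keep its original value, zero the rest
            out = np.zeros_like(sm)
            out[idx] = variable[idx]
            return out
        return sm                               # 'ALL': full softmax distribution

The PROB branch mirrors the behavior the patch hands to one_hot_function([variable, sm]): the softmax values are used only to choose an element, and the chosen element retains its original input value.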