get param set param error #60

Open
gnodabb opened this issue Mar 1, 2022 · 0 comments

gnodabb commented Mar 1, 2022

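There is an issue in the `BaseNeuralNetwork` class: its `get_params` and `set_params` methods do not get and set all of the constructor parameters. This breaks basic scikit-learn functions such as `cross_validate`, which clone the estimator from the output of `get_params`. Any parameter that is not reported therefore falls back to its default value in the clone.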
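The failure mode can be reproduced without scikit-learn. The sketch below uses a hypothetical `ToyNet` class and a simplified `simple_clone` helper. `simple_clone` is not scikit-learn's actual `clone`, which performs additional checks, but it rebuilds the estimator from `get_params(deep=False)` in the same basic way. Any constructor argument missing from `get_params` silently reverts to its default:

```python
class ToyNet:
    """Hypothetical estimator whose get_params() omits 'learning_rate'."""

    def __init__(self, hidden_nodes=None, learning_rate=0.1):
        self.hidden_nodes = hidden_nodes
        self.learning_rate = learning_rate

    def get_params(self, deep=False):
        # The bug being illustrated: 'learning_rate' is not reported.
        return {'hidden_nodes': self.hidden_nodes}


def simple_clone(estimator):
    """Simplified stand-in for sklearn.base.clone: build a fresh instance
    from the parameters the estimator reports via get_params(deep=False)."""
    return type(estimator)(**estimator.get_params(deep=False))


original = ToyNet(hidden_nodes=[10], learning_rate=0.5)
cloned = simple_clone(original)
print(cloned.hidden_nodes)   # [10]
print(cloned.learning_rate)  # 0.1 -- the user's 0.5 was lost
```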

Edit: It looks like this may already be fixed in PR #55.

To fix it, update the two methods as follows:

    def get_params(self, deep=False):
        """Get parameters for this estimator.

        Returns
        -------
        params : dictionary
            Parameter names mapped to their values.
        """
        params = {'activation': self.activation,
                  'algorithm': self.algorithm,
                  'hidden_nodes': self.hidden_nodes,
                  'max_iters': self.max_iters,
                  'bias': self.bias,
                  'is_classifier': self.is_classifier,
                  'learning_rate': self.learning_rate,
                  'early_stopping': self.early_stopping,
                  'clip_max': self.clip_max,
                  'restarts': self.restarts,
                  'schedule': self.schedule,
                  'pop_size': self.pop_size,
                  'mutation_prob': self.mutation_prob,
                  'max_attempts': self.max_attempts,
                  'random_state': self.random_state,
                  'curve': self.curve}

        return params

    def set_params(self, **in_params):
        """Set the parameters of this estimator.

        Parameters
        ----------
        in_params: dictionary
            Dictionary of parameters to be set and the value to be set to.

        Returns
        -------
        self
            The estimator itself. scikit-learn relies on this, e.g. when
            hyperparameter search calls ``estimator.set_params(**params)``
            and uses the returned object.
        """
        # Only update attributes that get_params() reports; unknown keys are
        # silently ignored, matching the original chain of if-statements.
        valid_params = self.get_params()
        for key, value in in_params.items():
            if key in valid_params:
                setattr(self, key, value)
        return self
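
One way to guard against this class of bug is a small test that compares the keys returned by `get_params` with the arguments of `__init__`, and checks that `set_params` returns `self`. This is a sketch under stated assumptions, not part of mlrose: `check_params_roundtrip` and `ToyNet` are hypothetical, and the check assumes every constructor argument has a default value. It uses `inspect.signature`, which is also how scikit-learn's `BaseEstimator` discovers an estimator's parameter names.

```python
import inspect


def check_params_roundtrip(estimator_cls):
    """Return the __init__ arguments missing from get_params() (sorted);
    raise AssertionError if set_params() does not return the estimator."""
    init_args = [
        name
        for name, p in inspect.signature(estimator_cls.__init__).parameters.items()
        if name != 'self' and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    est = estimator_cls()  # assumes all constructor arguments have defaults
    missing = sorted(set(init_args) - set(est.get_params()))
    if est.set_params() is not est:
        raise AssertionError('set_params() must return self')
    return missing


class ToyNet:
    """Hypothetical estimator that still omits 'curve' from get_params()."""

    def __init__(self, hidden_nodes=None, learning_rate=0.1, curve=False):
        self.hidden_nodes = hidden_nodes
        self.learning_rate = learning_rate
        self.curve = curve

    def get_params(self, deep=False):
        return {'hidden_nodes': self.hidden_nodes,
                'learning_rate': self.learning_rate}

    def set_params(self, **in_params):
        valid = self.get_params()
        for key, value in in_params.items():
            if key in valid:
                setattr(self, key, value)
        return self


print(check_params_roundtrip(ToyNet))  # ['curve']
```

Running the same check against the real network class after applying the fix should return an empty list.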