TypeError: nesterov_momentum() got an unexpected keyword argument 'logger' #63

Open
xiao751 opened this issue May 13, 2020 · 1 comment


xiao751 commented May 13, 2020

When I get to this line, I get an error. Why is that?
"metrics = model.train(train_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn)"
[INFO] Training CoxMLP

TypeError                                 Traceback (most recent call last)
in <module>
18
19 # If you have validation data, you can add it as the second parameter to the function
---> 20 metrics = model.train(train_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn)

D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in train(self, train_data, valid_data, n_epochs, validation_frequency, patience, improvement_threshold, patience_increase, verbose, update_fn, **kwargs)
366 reached, looks at validation improvement to increase patience or
367 early stop.
--> 368 improvement_threshold: percentage of improvement needed to increase
369 patience.
370 patience_increase: multiplier to patience if threshold is reached.

D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in _get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, **kwargs)
208 updates = update_fn(
209 scaled_grads, self.params, **kwargs
--> 210 )
211 else:
212 updates = update_fn(

D:\anaconda\lib\site-packages\deepsurv\deep_surv.py in _get_loss_updates(self, L1_reg, L2_reg, update_fn, max_norm, deterministic, **kwargs)
179 Returns Theano expressions for the network's loss function and parameter
180 updates.
--> 181
182 Parameters:
183 L1_reg: float for L1 weight regularization coefficient.

TypeError: nesterov_momentum() got an unexpected keyword argument 'logger'

@jaredleekatzman (Owner) commented

Are you still having this issue? It looks like the logger is being passed from the .train() function to _get_train_valid_fn(), which then passes it on to update_fn through **kwargs.

Adding a logger=None parameter to the function signature of _get_train_valid_fn might fix the problem.

For example:

_get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, **kwargs)

Would become:

_get_train_valid_fn(self, L1_reg, L2_reg, learning_rate, logger=None, **kwargs)
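
To illustrate the mechanism outside of deepsurv, here is a small, hypothetical sketch (the helper names and shared variables are made up for the example; only lasagne.updates.nesterov_momentum is the real API). An explicit logger=None parameter consumes the keyword argument, so **kwargs no longer carries it into the Lasagne update function:

    import numpy as np
    import theano
    import lasagne

    # Hypothetical stand-ins for the network parameters and their gradients.
    W = theano.shared(np.zeros((3, 2), dtype=theano.config.floatX))
    grad_W = theano.shared(np.ones((3, 2), dtype=theano.config.floatX))

    def get_updates_broken(update_fn, **kwargs):
        # A logger=... keyword passed by the caller lands in kwargs and is
        # forwarded to nesterov_momentum(), which rejects it with the
        # TypeError shown in the traceback above.
        return update_fn([grad_W], [W], learning_rate=1e-3, **kwargs)

    def get_updates_fixed(update_fn, logger=None, **kwargs):
        # logger is captured by the explicit parameter and never forwarded.
        return update_fn([grad_W], [W], learning_rate=1e-3, **kwargs)

    # get_updates_broken(lasagne.updates.nesterov_momentum, logger=object())  # raises TypeError
    # get_updates_fixed(lasagne.updates.nesterov_momentum, logger=object())   # returns an updates dict

With the suggested signature change, the same thing happens inside _get_train_valid_fn, so the existing call model.train(train_data, n_epochs=n_epochs, logger=logger, update_fn=update_fn) should then work unchanged.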
