-
Notifications
You must be signed in to change notification settings - Fork 487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[major] lagged regressor with interaction modeling (shared NN) #903
Conversation
Hi @karl-richter it looks good to me in terms of model architecture changes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice one, I think this is a nice improvement!
A few things from me:
num_hidden_layers
- Maybe we can rename this tonum_hidden_layers_lagged_regressors
or, if want to share with ar_net,num_hidden_layers_lagged
. For example, Im building a shared network for future_regressors and I should be able to use a different configuration.- Out of curiosity, how come you use that function/formula to create
d_hidden
if it is not passed by user? - Would it make sense to pass the neural network architecture as an array? Something like
covar_net_layers_array
which would replace currentn_hidden
andnum_hidden_layers
Hey @karl-richter @ourownstory The commit I did was about removing Have a look, because it implies a few changes:
If you agree on the new changes there are two things missing that I will do as soon as you give thumbs up
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent work @karl-richter and @alfonsogarciadecorral !!
oskar reviewed it once more
Status quo
For each covariate (aka lagged regressor), a seperate net is trained.
Change
All lagged regressors share a network. In plot_parameters, the attribution of each covariate on the forecast is derived using the Captum model attribution method for deep networks.