Refactor model class hierarchy #49
@joeloskarsson This looks great! It introduces even more freedom, with clearly structured and understandable classes. I am no expert in this, so I'm not adding too much here. I really do think, though, that relieving …
I have some time to give this a proper think-through now @joeloskarsson, so here come my thoughts :)
Yes, agreed! Thank you for making that diagram too. It is very helpful in understanding the current structure.
This sounds good. My only suggestion would be to not include the plotting as a method of the `ForecasterModule`.

I think the structure you have proposed with …

Maybe when you/we start coding this we could start by including the additions you foresee making (ensemble prediction classes, CNN, ViT) as simply empty placeholder classes? Then we can be sure the data structures we pass around will be general enough to apply in future too?

In general this is really fantastic though.
Thanks for some very valuable input @leifdenby!
Yes, that is what I have in mind as well. In the same way it is handled today, where …
There is a logic to that, yes. I am still thinking about how to do this in practice. I am in a way not particularly happy to send the target (ground truth) tensor further down the hierarchy than the `ForecasterModule`.
Yes, this is a good point. I have given this some thought. One should note that having 1 (flattened) spatial dimension or 2 (x, y) should not majorly impact anything before the `StepPredictor`.

Here is an example of a dummy CNN class I have lying around that handles this with reshapes (using the old class hierarchy):

```python
import torch

from neural_lam import constants
from neural_lam.models.ar_model import ARModel


class NewModel(ARModel):
    """
    A new auto-regressive weather forecasting model
    """

    def __init__(self, args):
        super().__init__(args)

        # Some dimensionalities that can be useful to have stored
        self.input_dim = (
            2 * constants.grid_state_dim
            + constants.grid_forcing_dim
            + constants.batch_static_feature_dim
        )
        self.output_dim = constants.grid_state_dim

        # TODO: Define modules as members here that will be used in predict_step
        self.layer = torch.nn.Conv2d(self.input_dim, self.output_dim, 1)  # Dummy layer

    def predict_step(
        self, prev_state, prev_prev_state, batch_static_features, forcing
    ):
        """
        Predict weather state one time step ahead
        X_{t-1}, X_t -> X_{t+1}

        prev_state: (B, N_grid, d_state), weather state X_t at time t
        prev_prev_state: (B, N_grid, d_state), weather state X_{t-1} at time t-1
        batch_static_features: (B, N_grid, batch_static_feature_dim), static forcing
        forcing: (B, N_grid, forcing_dim), dynamic forcing

        Returns:
        next_state: (B, N_grid, d_state), predicted weather state X_{t+1} at time t+1
        pred_std: None or (B, N_grid, d_state), predicted standard-deviations
            (pred_std can be ignored by just returning None)
        """
        # Reshape 1d grid to 2d image
        input_flat = torch.cat(
            (prev_state, prev_prev_state, batch_static_features, forcing), dim=-1
        )  # (B, N_grid, d_input)
        input_grid = torch.reshape(
            input_flat, (-1, *constants.grid_shape, input_flat.shape[2])
        )  # (B, N_x, N_y, d_input)

        # Most computer-vision methods in torch want the channel dimension first
        input_grid = input_grid.permute((0, 3, 1, 2)).contiguous()  # (B, d_input, N_x, N_y)

        # TODO: Feed input_grid through some model to predict output_grid
        output_grid = self.layer(input_grid)  # (B, d_state, N_x, N_y)

        # Reshape back from 2d image to flattened grid dimension
        output_grid = output_grid.permute((0, 2, 3, 1))  # (B, N_x, N_y, d_state)
        next_state = output_grid.flatten(1, 2)  # (B, N_grid, d_state)

        return next_state, None
```

A good reason to keep it as is, flattening in the Dataset class, is that once we start moving to more refined boundary setups, on different gridding, there will no longer be a 2d grid representation of the input data. Grid cells will be on irregular grids, meaning that it is not trivial to apply CNN models to them. Keeping things flattened all the way to the CNN model means that if you want to define a CNN on this you have to decide how to handle this irregularity. We don't have to make such decisions before the …
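As a concrete illustration of why the flattened layout is the more general one, here is a hypothetical step predictor (not from the codebase; the class name and hidden dimension are made up for illustration) that is agnostic to the grid's spatial arrangement, because it only ever operates pointwise on the flattened (B, N_grid, d) tensors:

```python
import torch


class PointwisePredictor(torch.nn.Module):
    """
    Hypothetical step predictor that treats each grid cell independently.
    Because it never assumes any spatial arrangement, it works on any
    (B, N_grid, d) tensor, whether cells come from a regular or irregular grid.
    """

    def __init__(self, input_dim, state_dim, hidden_dim=64):
        super().__init__()
        # torch.nn.Linear acts on the last dimension, i.e. per grid cell
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden_dim, state_dim),
        )

    def forward(self, prev_state, prev_prev_state, forcing):
        # Concatenate along the feature dimension; no spatial structure assumed
        inputs = torch.cat((prev_state, prev_prev_state, forcing), dim=-1)
        return self.mlp(inputs)  # (B, N_grid, d_state)
```

A model like this works unchanged on regular and irregular grids alike; only models that need 2d spatial structure, like the CNN above, have to commit to a particular reshaping.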
Yes, that is a good idea. Or even …
This could be smart, yes. For the computer-vision models (CNN, ViT) I don't think that there will be much to keep in mind (see discussion above), but for the ensemble model this could make a lot of sense. Especially since I will anyhow populate that with the code from the probabilistic model branches later. My plan is to do this refactoring first, to make the merging of that easier and nicer.

Overall it looks like there is support for this idea, so I can start writing the code for it. Then we can discuss more details in upcoming PRs. I am happy to do the work on this, but I have a hard time giving a timeline, as it is not directly crucial to progress in ongoing research projects. Anyhow, I think this could potentially fit in v0.3.0 in the roadmap?
Background
The different models that can be trained in Neural-LAM are currently all subclasses of `pytorch_lightning.LightningModule`. In particular, much of the functionality sits in the first subclass, `ARModel`. The current hierarchy looks like this, sketched in code below. (I am making these rather than some more fancy UML diagrams since I think this should be enough for the level of detail we need to discuss here.)
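Since the original diagram is an image that does not reproduce here, a rough stand-in in code form follows; the class names match `neural_lam/models/`, but the exact tree is my reading of the code, so treat it as an assumption:

```python
import pytorch_lightning as pl


class ARModel(pl.LightningModule):
    """Auto-regressive unrolling + train/val/test steps, plotting, logging."""


class BaseGraphModel(ARModel):
    """Shared logic for all graph-based (GNN) models."""


class GraphLAM(BaseGraphModel):
    """Model using a flat (single-level) mesh graph."""


class BaseHiGraphModel(BaseGraphModel):
    """Shared logic for models on hierarchical mesh graphs."""


class HiLAM(BaseHiGraphModel):
    """Hierarchical model, sequential processing through graph levels."""


class HiLAMParallel(BaseHiGraphModel):
    """Hierarchical model, processing all graph levels in parallel."""
```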
The problem
In the current hierarchy everything is a subclass of `ARModel`. This has a number of drawbacks; among others, one cannot make use of standard `forward` calls, but rather must resort to our own similar construction (e.g. `predict_step`).
Proposed new hierarchy
I propose to split up the current class hierarchy into classes that have clear responsibilities. These should not just all inherit `ARModel`, but rather be members of each other as suitable. A first idea for this is shown below, including also potential future classes for new models (to show how this is more extendable). The core components are (I here count static features as part of forcing; see the sketch after this list):

- `ForecasterModule`: Takes over much of the responsibility of the old `ARModel`. Handles things not directly related to the neural-network components, such as plotting, logging and moving batches to the right device. This inherits `pytorch_lightning.LightningModule` and has the different train/val/test steps. In each step (train/val/test) it unpacks the batch of tensors and uses a `Forecaster` to produce a full forecast. Also responsible for computing the loss based on a produced forecast (could also be in `Forecaster`, not entirely sure about this).
- `Forecaster`: A generic forecaster capable of mapping from a set of initial states, forcing and boundary forcing to a full forecast of the requested length.
- `ARForecaster`: Subclass of `Forecaster` that uses an auto-regressive strategy to unroll a forecast. Makes use of a `StepPredictor` at each AR step.
- `StepPredictor`: A model mapping from the two previous time steps + forcing + boundary forcing to a prediction of the next state. Corresponds to the `predict_step` of the current models.

In the figure above we can also see how new kinds of models (e.g. ensemble forecasters and CNN- or ViT-based step predictors) could fit into this hierarchy.
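Since the proposal diagram is an image, here is a minimal code sketch of how these classes could relate. The class names follow the proposal, but all method signatures, the batch layout and the loss placement are illustrative assumptions, not a settled design:

```python
import pytorch_lightning as pl
import torch


class StepPredictor(torch.nn.Module):
    """Maps the two previous states + forcing to a prediction of the next state."""

    def forward(self, prev_state, prev_prev_state, forcing):
        # (B, N_grid, d_state) x 2 + (B, N_grid, d_forcing) -> (B, N_grid, d_state)
        raise NotImplementedError


class Forecaster(torch.nn.Module):
    """Maps initial states + forcing to a full forecast of the requested length."""

    def forward(self, init_states, forcing, num_steps):
        raise NotImplementedError


class ARForecaster(Forecaster):
    """Unrolls a forecast auto-regressively, using a StepPredictor at each step."""

    def __init__(self, step_predictor):
        super().__init__()
        self.step_predictor = step_predictor

    def forward(self, init_states, forcing, num_steps):
        # Assumes forcing is shaped (B, num_steps, N_grid, d_forcing)
        prev_prev_state, prev_state = init_states
        predictions = []
        for step_i in range(num_steps):
            next_state = self.step_predictor(
                prev_state, prev_prev_state, forcing[:, step_i]
            )
            predictions.append(next_state)
            prev_prev_state, prev_state = prev_state, next_state
        return torch.stack(predictions, dim=1)  # (B, num_steps, N_grid, d_state)


class ForecasterModule(pl.LightningModule):
    """Train/val/test steps, logging, plotting; owns a Forecaster (composition)."""

    def __init__(self, forecaster):
        super().__init__()
        self.forecaster = forecaster

    def training_step(self, batch, batch_idx):
        # Assumed batch layout, for illustration only
        init_states, target_states, forcing = batch
        forecast = self.forecaster(init_states, forcing, target_states.shape[1])
        loss = torch.nn.functional.mse_loss(forecast, target_states)
        self.log("train_loss", loss)
        return loss
```

Under this sketch, an ensemble model could be another `Forecaster` subclass wrapping one or more `StepPredictor`s, while CNN or ViT models would simply be new `StepPredictor` implementations.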
This is supposed to be a starting point for discussion, and there will likely be things I have not thought about. Some parts of this will have to be hammered out when actually writing these classes, but I'd rather have the discussion of whether this is a good direction to take things before starting to do too much work. Tagging @leifdenby and @sadamov for visibility.