diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 6dc8500b..1edb7931 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -129,12 +129,24 @@ def __init__( persistent=False, ) + # Compute dimensionalities (e.g. to instantiate MLPs) ( self.num_boundary_nodes, boundary_static_dim, ) = self.boundary_static_features.shape - # TODO Compute boundary input dim separately - self.boundary_dim = self.grid_dim + + # Compute boundary input dim separately + num_boundary_forcing_vars = datastore_boundary.get_num_data_vars( + category="forcing" + ) + num_past_boundary_steps = args.num_past_boundary_steps + num_future_boundary_steps = args.num_future_boundary_steps + self.boundary_dim = ( + boundary_static_dim + # Temporal Embedding counts as one additional forcing_feature + + (num_boundary_forcing_vars + 1) + * (num_past_boundary_steps + num_future_boundary_steps + 1) + ) # Instantiate loss function self.loss = metrics.get_metric(args.loss)