
Commit 67ed8e5

    modified:   pytorch_forecasting/models/base_model.py
    modified:   pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
    modified:   tests/test_models/test_temporal_fusion_transformer.py
Luke-Chesley committed Feb 17, 2024
1 parent ac2fe54 commit 67ed8e5
Showing 3 changed files with 13 additions and 10 deletions.
12 changes: 7 additions & 5 deletions pytorch_forecasting/models/base_model.py
@@ -405,7 +405,7 @@ def __init__(
     reduce_on_plateau_min_lr: float = 1e-5,
     weight_decay: float = 0.0,
     optimizer_params: Dict[str, Any] = None,
-    monotone_constaints: Dict[str, int] = {},
+    monotone_constraints: Dict[str, int] = {},
     output_transformer: Callable = None,
     optimizer="Ranger",
 ):
@@ -430,7 +430,7 @@ def __init__(
     Defaults to 1e-5
     weight_decay (float): weight decay. Defaults to 0.0.
     optimizer_params (Dict[str, Any]): additional parameters for the optimizer. Defaults to {}.
-    monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder
+    monotone_constraints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder
     variables mapping
     position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive,
     larger numbers add more weight to the constraint vs. the loss but are usually not necessary).
@@ -726,7 +726,7 @@ def step(
         y[1],
     )
 
-    if self.training and len(self.hparams.monotone_constaints) > 0:
+    if self.training and len(self.hparams.monotone_constraints) > 0:
         # calculate gradient with respect to continous decoder features
         x["decoder_cont"].requires_grad_(True)
         assert not torch._C._get_cudnn_enabled(), (
@@ -754,10 +754,12 @@ def step(
 
     # select relevant features
     indices = torch.tensor(
-        [self.hparams.x_reals.index(name) for name in self.hparams.monotone_constaints.keys()]
+        [self.hparams.x_reals.index(name) for name in self.hparams.monotone_constraints.keys()]
     )
     monotonicity = torch.tensor(
-        [val for val in self.hparams.monotone_constaints.values()], dtype=gradient.dtype, device=gradient.device
+        [val for val in self.hparams.monotone_constraints.values()],
+        dtype=gradient.dtype,
+        device=gradient.device,
     )
     # add additionl loss if gradient points in wrong direction
     gradient = gradient[..., indices] * monotonicity[None, None]
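The hunk above sits inside the monotonicity penalty in BaseModel.step: the surrounding code takes the gradient of the predictions with respect to x["decoder_cont"], selects the constrained features, and flips the sign so that a violation shows up as a negative value. Below is a minimal, self-contained sketch of that idea; the function name, the assumed tensor shape (batch x decoder time x number of reals), and the final clamp-and-sum reduction are illustrative choices, not the library's exact implementation.

from typing import Dict, List

import torch


def monotonicity_penalty(
    gradient: torch.Tensor,  # d(prediction)/d(decoder_cont), assumed shape [batch, time, n_reals]
    x_reals: List[str],  # names of the continuous inputs, in feature order
    monotone_constraints: Dict[str, int],  # e.g. {"discount_in_percent": +1}
) -> torch.Tensor:
    # positions of the constrained features among the continuous inputs
    indices = torch.tensor([x_reals.index(name) for name in monotone_constraints.keys()])
    monotonicity = torch.tensor(
        [val for val in monotone_constraints.values()],
        dtype=gradient.dtype,
        device=gradient.device,
    )
    # flip sign so that a gradient pointing in the allowed direction is positive
    directed = gradient[..., indices] * monotonicity[None, None]
    # penalise only the violating (negative) parts; this reduction is an assumption
    return directed.clamp(max=0).abs().sum()


# toy usage with random tensors
grad = torch.randn(4, 6, 3)
print(monotonicity_penalty(grad, ["price", "discount_in_percent", "volume"], {"discount_in_percent": +1}))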
5 changes: 3 additions & 2 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -27,6 +27,7 @@
 
 
 class TemporalFusionTransformer(BaseModelWithCovariates):
+
     def __init__(
         self,
         hidden_size: int = 16,
@@ -55,7 +56,7 @@ def __init__(
     log_val_interval: Union[int, float] = None,
     log_gradient_flow: bool = False,
     reduce_on_plateau_patience: int = 1000,
-    monotone_constaints: Dict[str, int] = {},
+    monotone_constraints: Dict[str, int] = {},
     share_single_variable_networks: bool = False,
     causal_attention: bool = True,
     logging_metrics: nn.ModuleList = None,
@@ -120,7 +121,7 @@ def __init__(
     log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training
     failures
     reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
-    monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder
+    monotone_constraints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder
     variables mapping
     position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive,
     larger numbers add more weight to the constraint vs. the loss but are usually not necessary).
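To show how the renamed keyword is used from the outside (the test file below exercises the same path), here is a hypothetical end-to-end construction. The toy DataFrame, column names, and hyperparameters are invented for illustration; only the monotone_constraints spelling reflects this commit, and the snippet assumes the usual pytorch_forecasting TimeSeriesDataSet / from_dataset API.

import numpy as np
import pandas as pd
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

# toy single-series data with one continuous feature we want to constrain
data = pd.DataFrame(
    {
        "time_idx": np.arange(60),
        "series": "a",
        "discount_in_percent": np.linspace(0.0, 30.0, 60),
        "volume": np.random.rand(60) * 100,
    }
)

training = TimeSeriesDataSet(
    data,
    time_idx="time_idx",
    target="volume",
    group_ids=["series"],
    max_encoder_length=12,
    max_prediction_length=6,
    time_varying_known_reals=["discount_in_percent"],
    time_varying_unknown_reals=["volume"],
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    hidden_size=8,
    # keyword as spelled after this commit; +1 asks for predictions that
    # do not decrease as "discount_in_percent" increases
    monotone_constraints={"discount_in_percent": +1},
)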
6 changes: 3 additions & 3 deletions tests/test_models/test_temporal_fusion_transformer.py
@@ -117,10 +117,10 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
     )
     # test monotone constraints automatically
     if "discount_in_percent" in train_dataloader.dataset.reals:
-        monotone_constaints = {"discount_in_percent": +1}
+        monotone_constraints = {"discount_in_percent": +1}
         cuda_context = torch.backends.cudnn.flags(enabled=False)
     else:
-        monotone_constaints = {}
+        monotone_constraints = {}
         cuda_context = nullcontext()
 
     kwargs.setdefault("learning_rate", 0.15)
@@ -149,7 +149,7 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
         log_interval=5,
         log_val_interval=1,
         log_gradient_flow=True,
-        monotone_constaints=monotone_constaints,
+        monotone_constraints=monotone_constraints,
         **kwargs
     )
     net.size()
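The context selection in this test pairs with the assert shown in base_model.py above: when monotone constraints are active, the penalty requires a second backward pass through the gradient computation, and the base model guards this by asserting cuDNN is off; otherwise nullcontext() leaves everything unchanged. A minimal sketch of that pattern, with the trainer call only indicated in a comment:

from contextlib import nullcontext

import torch

# {} means "no constraints"; a non-empty dict activates the penalty in BaseModel.step
monotone_constraints = {"discount_in_percent": +1}

# disable cuDNN only when the constraint penalty is active, as the test does
cuda_context = (
    torch.backends.cudnn.flags(enabled=False) if monotone_constraints else nullcontext()
)

with cuda_context:
    # trainer.fit(net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
    pass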
