From 3213e5784f9749b0495b96f28685b2202573931a Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Sat, 30 Nov 2024 20:22:21 +0200
Subject: [PATCH] Added decay to loss scaling range

---
 utils/loss.py    | 26 ++++++++++++++------------
 utils/trainer.py |  2 ++
 utils/utils.py   |  1 -
 3 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/utils/loss.py b/utils/loss.py
index 49e01f7..f7bea57 100644
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -73,18 +73,15 @@ def __init__(self, loss, reducer):
         self.loss = loss
         self.reducer = reducer
         self.loss_scaling_range = float(os.getenv("loss_scaling_range", 0.25))
+        self.patience = int(os.getenv("loss_scaling_patience", 3000))
+        self.last_epoch = 0
         get_logger().log_both(
             f"Using Loss scaling with scaling range: {self.loss_scaling_range}"
         )
 
-    def get_weights(
-        self, shape: torch.Size, dtype: torch.dtype, device: torch.device
-    ) -> Tensor:
-        return self._get_weights_impl(shape, dtype, device)
-
     @abstractmethod
-    def _get_weights_impl(
-        self, shape: torch.Size, dtype: torch.dtype, device: torch.device
+    def get_weights(
+        self, shape: torch.Size, dtype: torch.dtype, device: torch.device
     ) -> Tensor:
         pass
 
@@ -94,10 +91,16 @@ def forward(self, outputs, targets):
             loss * self.get_weights(loss.shape, loss.dtype, loss.device)
         )
 
+    def step(self):
+        self.last_epoch += 1
+        if self.last_epoch == self.patience:
+            self.last_epoch = 0
+            self.loss_scaling_range /= 2
+
 class NormalScalingLoss(LossScaler):
-    def _get_weights_impl(
-        self, shape: torch.Size, dtype: torch.dtype, device: torch.device
+    def get_weights(
+        self, shape: torch.Size, dtype: torch.dtype, device: torch.device
     ) -> Tensor:
         return torch.normal(
             1.0, self.loss_scaling_range, shape, dtype=dtype, device=device
         )
@@ -105,8 +108,8 @@ class UniformScalingLoss(LossScaler):
-    def _get_weights_impl(
-        self, shape: torch.Size, dtype: torch.dtype, device: torch.device
+    def get_weights(
+        self, shape: torch.Size, dtype: torch.dtype, device: torch.device
     ) -> Tensor:
         return torch.zeros(shape, dtype=dtype, device=device).uniform_(
             1 - self.loss_scaling_range, 1 + self.loss_scaling_range
         )
@@ -129,7 +132,6 @@ def init_criterion(args):
     else:
         loss = loss(reduction=args.reduction if args.loss_scaling is None else "none")
 
-    # TODO: Add decay to loss scaling
     if args.loss_scaling is None:
         return loss
     elif args.loss_scaling == "normal-scaling":
diff --git a/utils/trainer.py b/utils/trainer.py
index 89754f4..0710926 100644
--- a/utils/trainer.py
+++ b/utils/trainer.py
@@ -261,6 +261,8 @@ def post_epoch(self, metrics: dict):
             self.batch_transforms_cpu.step()
         if self.batch_transforms_device is not None:
             self.batch_transforms_device.step()
+        if hasattr(self.criterion, "step"):
+            self.criterion.step()
 
     def epoch_description(self, metrics):
         train_acc = round(metrics["Train/Accuracy"], 2)
diff --git a/utils/utils.py b/utils/utils.py
index 86cd6a6..9e47daa 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -62,5 +62,4 @@ def __call__(self, *args, **kwargs):
         ret, elapsed = timed(return_time=True, stdout=False)(self.fn)(*args, **kwargs)
         self.total += elapsed
         self.calls += 1
-        print(self.total / self.calls)
         return ret
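
Reviewer note (not part of the patch): below is a minimal standalone sketch of the decay schedule this patch introduces, runnable without the rest of the repo. DecayingLossScale is a hypothetical stub standing in for LossScaler; the defaults mirror the os.getenv fallbacks in utils/loss.py. Since the per-element weights are drawn from N(1.0, loss_scaling_range) or U(1 - range, 1 + range), halving the range every `patience` steps anneals the random scaling toward a plain unweighted loss.

    import os

    class DecayingLossScale:
        """Hypothetical stub reproducing only the decay logic from LossScaler."""

        def __init__(self):
            # Defaults mirror the os.getenv fallbacks in utils/loss.py.
            self.loss_scaling_range = float(os.getenv("loss_scaling_range", 0.25))
            self.patience = int(os.getenv("loss_scaling_patience", 3000))
            self.last_epoch = 0

        def step(self):
            # Invoked once per epoch (see Trainer.post_epoch in the diff);
            # halves the scaling range every `patience` epochs.
            self.last_epoch += 1
            if self.last_epoch == self.patience:
                self.last_epoch = 0
                self.loss_scaling_range /= 2

    scale = DecayingLossScale()
    for _ in range(9000):
        scale.step()
    # Three halvings (at steps 3000, 6000, 9000): 0.25 -> 0.125 -> 0.0625 -> 0.03125
    print(scale.loss_scaling_range)  # 0.03125

The trainer-side hookup is duck-typed: post_epoch calls criterion.step() only when the criterion defines it, so plain losses without scaling are unaffected.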