However, I think there is a better way to do this: if we change the loss function to loss_fns: list[torch.nn.Module], then using list[dict] can provide more flexibility for users.
For example, in the model step, I do:
def model_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Perform a single model step on a batch of data.

    :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.
    :return: A tuple containing (in order):
        - A dict mapping each loss tag to its value, plus the weighted sum under "total_loss".
        - A tensor of predictions.
        - A tensor of target labels.
    """
    x, y = batch
    logits = self.forward(x)
    preds = torch.argmax(logits, dim=1)
    losses = {}  # a dict of {loss_fn_name: loss_value}
    losses["total_loss"] = 0.0
    for loss_fn in self.loss_fns:
        # Compute each loss on the logits, as the original self.criterion(logits, y) does;
        # the argmax predictions are class indices and are not differentiable.
        losses[loss_fn.tag] = loss_fn(logits, y)
        losses["total_loss"] += losses[loss_fn.tag] * loss_fn.weight
    return losses, preds, y
This revised approach retains the functionality you designed but allows more loss functions to be included. Users simply need to add their custom loss function to src/models/components/loss_fn.py, and the rest is taken care of.
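As a rough illustration of what each entry in loss_fns would need to provide, here is a minimal sketch of the weighted aggregation on its own. The names WeightedLoss and combine_losses are hypothetical and not part of the template. The sketch also assumes each loss exposes a tag and a weight attribute, as the model step above expects. In the real module each fn would be a torch.nn.Module such as torch.nn.CrossEntropyLoss() and the values would be tensors; plain floats stand in here so the sketch runs without PyTorch.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class WeightedLoss:
    """Hypothetical wrapper pairing a loss callable with a tag and a weight."""

    tag: str
    weight: float
    fn: Callable[[float, float], float]

    def __call__(self, output, target):
        return self.fn(output, target)


def combine_losses(loss_fns: List[WeightedLoss], output, target) -> Dict[str, float]:
    """Return every individual loss plus their weighted sum under "total_loss"."""
    losses = {"total_loss": 0.0}
    for loss_fn in loss_fns:
        losses[loss_fn.tag] = loss_fn(output, target)
        losses["total_loss"] += losses[loss_fn.tag] * loss_fn.weight
    return losses


# Toy scalar losses, used only to make the sketch runnable.
loss_fns = [
    WeightedLoss(tag="abs_error", weight=1.0, fn=lambda o, t: abs(o - t)),
    WeightedLoss(tag="sq_error", weight=0.5, fn=lambda o, t: (o - t) ** 2),
]
losses = combine_losses(loss_fns, output=3.0, target=1.0)
```

Keeping the individual values in the dict alongside total_loss means each term can be logged separately while only total_loss is used for the backward pass.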
This is just an idea. I am a huge fan of the Hydra template, so I decided to make my first contribution.
If there are any issues, please feel free to tell me, since I want to help ✌🏽
If you think this is a good idea, I will fix the pytest part.
I believe it is advantageous for us to separate the loss and weight addition in ./src/models/mnist_module.py. The original code uses self.criterion = torch.nn.CrossEntropyLoss() as the only loss function, in loss = self.criterion(logits, y).