Suggestion with Solution about loss function #605

Open
Mai0313 opened this issue Oct 2, 2023 · 1 comment


Mai0313 commented Oct 2, 2023

I think it would be beneficial to separate the individual loss terms from their weighted sum in ./src/models/mnist_module.py.

The original code uses self.criterion = torch.nn.CrossEntropyLoss() as its only loss function, called as loss = self.criterion(logits, y).

However, I think there is a better way: if we change the loss function to loss_fns: list[torch.nn.Module], then using a list[dict] gives users more flexibility.
For example, in model_step I do:

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.

        :return: A tuple containing (in order):
            - A dict mapping each loss tag to its value, plus the weighted "total_loss".
            - A tensor of predictions.
            - A tensor of target labels.
        """
        x, y = batch
        logits = self.forward(x)
        preds = torch.argmax(logits, dim=1)
        losses = {}  # a dict of {loss_fn_name: loss_value}
        losses["total_loss"] = 0.0
        for loss_fn in self.loss_fns:
            # Compute each loss on the logits: argmax predictions are not differentiable.
            losses[loss_fn.tag] = loss_fn(logits, y)
            losses["total_loss"] += losses[loss_fn.tag] * loss_fn.weight
        return losses, preds, y

This revised approach retains the functionality of your original design while allowing several loss functions to be combined. Users only need to add their custom loss function to src/models/components/loss_fn.py, and the rest is handled for them.
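
To make the idea more concrete, here is a minimal sketch of what an entry in src/models/components/loss_fn.py could look like. The names WeightedLoss and combine_losses are hypothetical and only for illustration; the tag and weight attributes are the ones model_step above relies on. It assumes every loss takes (logits, target) and returns a scalar tensor.

```python
from typing import Dict, List

import torch
from torch import nn


class WeightedLoss(nn.Module):
    """Hypothetical wrapper that attaches a `tag` and a `weight` to any loss module."""

    def __init__(self, loss_fn: nn.Module, tag: str, weight: float = 1.0) -> None:
        super().__init__()
        self.loss_fn = loss_fn
        self.tag = tag        # key under which the loss value is reported
        self.weight = weight  # factor applied when summing into the total

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Delegate to the wrapped loss; the weight is applied by the caller.
        return self.loss_fn(logits, target)


def combine_losses(
    loss_fns: List[WeightedLoss], logits: torch.Tensor, target: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """Evaluate every loss and add a weighted sum under the key "total_loss"."""
    losses: Dict[str, torch.Tensor] = {}
    total = logits.new_zeros(())  # scalar zero on the same device and dtype as logits
    for loss_fn in loss_fns:
        value = loss_fn(logits, target)
        losses[loss_fn.tag] = value
        total = total + loss_fn.weight * value
    losses["total_loss"] = total
    return losses


if __name__ == "__main__":
    # Illustrative inputs: 4 samples, 10 classes, as in MNIST.
    logits = torch.randn(4, 10)
    target = torch.tensor([1, 2, 3, 4])
    loss_fns = [
        WeightedLoss(nn.CrossEntropyLoss(), tag="ce", weight=1.0),
        WeightedLoss(nn.CrossEntropyLoss(label_smoothing=0.1), tag="ce_smooth", weight=0.5),
    ]
    print(combine_losses(loss_fns, logits, target))
```

With something like this, the training step would call backward on losses["total_loss"] and could log the remaining entries individually. A loss tagged "total_loss" would clash with the summary key, so that name should probably be reserved.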


Mai0313 commented Oct 2, 2023

This is just an idea. I am a huge fan of the hydra template, so I decided to make my first contribution.
If there are any issues, please feel free to tell me, since I would like to help ✌🏽

If you think this is a good idea, I will fix the pytest part.
