during-training valid loss is wrong #444

Closed
bernstei opened this issue Jun 5, 2024 · 11 comments

Comments

@bernstei
Collaborator

bernstei commented Jun 5, 2024

Looks to me like the validation loss in the log during fitting is actually the sum over all heads so far:

mace/mace/tools/train.py

Lines 217 to 233 in e4ac498

for valid_loader_name, valid_loader in valid_loaders.items():
    valid_loss_head, eval_metrics = evaluate(
        model=model_to_evaluate,
        loss_fn=loss_fn,
        data_loader=valid_loader,
        output_args=output_args,
        device=device,
    )
    valid_loss += valid_loss_head
    valid_err_log(
        valid_loss,
        eval_metrics,
        logger,
        log_errors,
        epoch,
        valid_loader_name,
    )

Is the solution just that the quantity passed to valid_err_log should be valid_loss_head?
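
If that's the right reading, a minimal sketch of the fix (same loop and same names as the snippet above, nothing new assumed) would pass the per-head loss to the logger while still accumulating the total:

for valid_loader_name, valid_loader in valid_loaders.items():
    valid_loss_head, eval_metrics = evaluate(
        model=model_to_evaluate,
        loss_fn=loss_fn,
        data_loader=valid_loader,
        output_args=output_args,
        device=device,
    )
    valid_loss += valid_loss_head  # running total across heads, still available elsewhere
    valid_err_log(
        valid_loss_head,  # log the per-head loss, not the running sum
        eval_metrics,
        logger,
        log_errors,
        epoch,
        valid_loader_name,
    )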

@bernstei
Collaborator Author

bernstei commented Jun 5, 2024

Pretty sure that this is indeed a bug, but also validation batches in the multi-head-interface branch are not deterministic.

@LarsSchaaf
Collaborator

LarsSchaaf commented Jun 6, 2024

What would you say the best behaviour should be? Should the user be able to supply a weight for each head, defaulting to 0 for the mp head and 1 for the Default head?

Given that the main impact this has is which checkpoint (and model) to save:

  1. If you're fine-tuning, you want the situation above.
  2. If you're training a model on multiple reference data sets, you want to be able to select which method you care most about (hence the ability to supply weights?).

@gabor1
Collaborator

gabor1 commented Jun 6, 2024

I think the validation losses and rmses should be printed separately for each head

@bernstei
Collaborator Author

bernstei commented Jun 6, 2024

I think the validation losses and rmses should be printed separately for each head

They already are

@bernstei
Collaborator Author

bernstei commented Jun 6, 2024

What would you say the best behaviour should be? Should the user be able to supply a weight for each head, defaulting to 0 for the mp head and 1 for the Default head?

Given that the main impact this has is which checkpoint (and model) to save:

  1. If you're fine-tuning, you want the situation above.
  2. If you're training a model on multiple reference data sets, you want to be able to select which method you care most about (hence the ability to supply weights?).

This has nothing to do with weights; that's a separate issue. The code claims to print a loss for each head, but actually prints the cumulative sum it accumulates while looping over heads to compute the total loss. That's all. I'll do a PR for this issue, now that it seems pretty clear (from the Slack discussion) that the validation loss not being deterministic is a separate bug.

@LarsSchaaf
Collaborator

LarsSchaaf commented Jun 6, 2024

The point is that checkpoints only get saved if the loss decreases. Now that we have multiple heads, and therefore multiple validation losses, how do we decide when to save a checkpoint? My suggestion was to have a main_loss that is a combination of the per-head losses. Which combination is best depends on the use case, hence the user should be able to supply a weighting over the head losses. If the main_loss decreases, a new checkpoint is saved.
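
Roughly something like this, as an illustration only (main_valid_loss, valid_loss_per_head and head_weights are made-up names for the sketch, not existing MACE options):

def main_valid_loss(valid_loss_per_head: dict, head_weights: dict) -> float:
    # combine per-head validation losses into one scalar used for checkpointing;
    # heads without an explicit weight count fully (hypothetical helper)
    return sum(
        head_weights.get(name, 1.0) * head_loss
        for name, head_loss in valid_loss_per_head.items()
    )

# e.g. fine-tuning: ignore the mp head, checkpoint only on the Default head
# main_valid_loss(valid_loss_per_head, {"mp": 0.0, "Default": 1.0})

A new checkpoint would then be saved only when this combined loss improves on the best value seen so far.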

@bernstei
Collaborator Author

bernstei commented Jun 6, 2024

The point is that checkpoints only get saved if the loss decreases. Now that we have multiple heads, and therefore multiple validation losses, how do we decide when to save a checkpoint? My suggestion was to have a main_loss that is a combination of the per-head losses. Which combination is best depends on the use case, hence the user should be able to supply a weighting over the head losses. If the main_loss decreases, a new checkpoint is saved.

A fine suggestion, but independent of this issue. I agree that the way the "total" loss, which is used to save checkpoints, is calculated could use further thought. And I don't even mind making that part of the PR I created for this issue. But the issue was really only about how valid_err_log needs to get the head-specific loss, rather than the (currently naively computed) partial sum that's constructed to calculate a total loss.

[edited] @LarsSchaaf I think you should perhaps open a new issue (an enhancement request) to make the logic for saving checkpoints based on loss less naive.

@ilyes319
Contributor

ilyes319 commented Jun 6, 2024

Just to be clear, the user can already provide weights for the different heads.

@bernstei
Collaborator Author

bernstei commented Jun 6, 2024

Just to be clear, the user can already provide weights for the different heads.

Are those used when printing the validation loss? Or only when computing the gradient of the training loss w.r.t. shared parameters?

@ilyes319
Contributor

ilyes319 commented Jun 6, 2024

Currently they would also be used for printing.

@bernstei
Collaborator Author

bernstei commented Jun 6, 2024

This issue was supposed to be closed by #449, but seems to still be open. Do we want to continue this discussion here, or close it and open a new one having to do with weights?

ilyes319 closed this as completed Jun 6, 2024