Warn different size inputs in loss functions #1522
Say I have a model that outputs a scalar, and the batch size is 5. `m(x)` gives me an array of size `(1, 5)`, while `y` is of size `(5,)`. If I pass them to `Flux.Losses.mae`, the value is different from what I expect due to broadcasting. In PyTorch, I get a warning message. It would be nice to have a warning message in Flux so I know I'm doing something wrong.
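A minimal sketch of that mismatch (illustrative values; `Flux.Losses.mae` reduces `abs.(ŷ .- y)` with `mean` by default, so plain broadcasting reproduces the effect without loading Flux):

```julia
using Statistics  # for mean

ŷ = reshape([1.0, 2.0, 3.0, 4.0, 5.0], 1, 5)  # model output, size (1, 5)
y = [1.0, 2.0, 3.0, 4.0, 5.0]                 # targets, size (5,)

# Broadcasting treats the vector as if it had a trailing singleton dimension,
# so the subtraction yields a (5, 5) outer difference, not five element-wise ones.
size(ŷ .- y)        # (5, 5)
mean(abs.(ŷ .- y))  # 1.6, where the element-wise answer would be 0.0
```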
Comments

It's a good idea, but we'd have to define what range of sizes and dimensionalities is valid for each loss function. I can't think of any scenario where […]
Flux models (conv, pooling, batchnorm, etc.) can already take variable-sized inputs. The loss functions shouldn't alter that. The dimension mismatch error that would occur in the more common cases is probably the most sensible outcome, and it's maybe not worth injecting these messages into Flux for more specific cases. Re this MWE: outer products can be handled by broadcasting, but keeping things general would be a bigger design win, especially for non-array uses.
@collodi is referring to the case where a dimension mismatch error doesn't occur and Julia silently promotes the rank of one argument to match the other. There are certainly valid use cases for that, but all of the loss examples we have in the docs use inputs with a uniform number of dimensions. The question is whether we can get away with defining loss functions like […]
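The truncated definition above isn't recoverable, but one hypothetical way to require a uniform number of dimensions is to tie both arguments to the same `N` in the method signature (an illustrative sketch, not Flux's actual definition):

```julia
using Statistics  # for mean

# Hypothetical strict variant: both arguments must share dimensionality N, so a
# (1, 5) matrix paired with a (5,) vector raises a MethodError instead of
# silently broadcasting to a (5, 5) outer difference.
mae_strict(ŷ::AbstractArray{<:Number,N}, y::AbstractArray{<:Number,N}) where {N} =
    mean(abs.(ŷ .- y))
```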
I addressed that in the […]
What non-array uses do we foresee, though? I would argue scalar-scalar loss calculations are wholly out of scope, so maybe array-scalar? Do we have any strategy other than documentation to prevent the silently incorrect behaviour demonstrated up top?
Would it be a bad idea to simply warn the user, instead of restricting the loss function inputs? That way, people who know what they're doing can simply ignore the message.
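A sketch of what such a warning could look like (`_warn_sizes` is a hypothetical helper, not existing Flux API):

```julia
# Hypothetical guard: call at the top of a loss function to flag mismatched
# shapes before broadcasting silently changes the result.
function _warn_sizes(ŷ, y)
    if size(ŷ) != size(y)
        @warn "Loss inputs have mismatched sizes; broadcasting may silently change the result." ŷsize=size(ŷ) ysize=size(y) maxlog=1
    end
    return nothing
end
```

With `maxlog=1` the message is emitted once rather than on every training step, so users who know what they're doing can read it and move on.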
The linked PR should address what you want.