-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Classification metrics overhaul: input formatting standardization (1/n) #4837
Classification metrics overhaul: input formatting standardization (1/n) #4837
Conversation
Thanks for splitting off the PR! Reviewing now |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks like good changes, just a few small things to fix.
Codecov Report
@@ Coverage Diff @@
## master #4837 +/- ##
=======================================
Coverage 93% 93%
=======================================
Files 129 130 +1
Lines 9397 9527 +130
=======================================
+ Hits 8713 8843 +130
Misses 684 684 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall very good :)
Co-authored-by: Nicki Skafte <[email protected]>
Is there anything else that needs to be done before this PR can be merged? |
@tadejsv, thanks for the further description of the |
@tadejsv mind resolve conflicts :] probably after #4549 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just would be nice to have tests also for all the helper functions raising some kind of exception...
…orch-lightning into cls_metrics_input_formatting
Alright, merge conflicts resolved, ready for final review. @SkafteNicki please double check that docs are ok (git diff not useful there). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, docs looks fine :]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
High-level review, looks good, nice docs
This PR is a spin-off from #4835. It should be merged before any other spin offs, as it provides a base for all of them
What does this PR do?
General (fundamental) changes
I have created a new
_input_format_classification
function (inmetrics/classification/utils
). The job of this function is to a) validate, and b) transform the inputs into a common format. This common format is a binary label indicator array: either(N, C)
, or(N, C, X)
(only for multi-dimensional multi-class inputs).I believe that having such a "central" function is crucial, as it gets rid of code duplication (which was present in PL metrics before), and enables metric developers to focus on developing the metrics themselves, and not on standardizing and validating inputs.
The validation performed on the inputs basically makes sure that they fall into one of the possible input type cases, that the values are consistent with both the type of the inputs and the additional parameters set (e.g. that there is no label higher than
num_classes
in target). The docstrings (and the new "Input types" section in the documentation) give all the details about how the standardization and validation are performed.Here I'll list the parameters of this function (many of which are also present on some metrics), and why I decided to use them:
threshold
: The probability threshold for binarizing binary and multi-label inputs.num_classes
: number of classes. Used to either decide theC
dimension of inputs, or, if this is already implicitly given, to ensure consistency between inputs and number of classes the user specified when creating the metric (thus ignoring either having to chech this manually inupdate
for each metric, or raising error when updating the state, which may not be very clearto the user).
top_k
: for (multi-dimensional) multi-class, if predictions are given as probabilities, selects the top k highest probabilities per sample. It's a generalization of the usual procedure, withk=1
. This will be used by theAccuracy
metric in subsequent PRs.is_multiclass
: used for transforming binary or multi-label input to 2-class multi-class and 2-class multi-dimensional multi-class, respectively. And vice versa.Why? This is similar to
multilabel
argument that was (is?) present on some metrics. I believe this is a better name for it, as it also deals with transforming to/from binary. But why is it needed? There are cases where it is not clear what the inputs are: for example, say that both preds and target are of the form [0,1,0,1]. This actually appears to be multi-class (could be the case that is simply happened in this batch that there were only 0s and 1s), so an explicit instruction is needed to tell the metrics that this is in fact binary. On the other hand, sometimes we would like to treat binary inputs as two class inputs - this is the case used in confusion matrix.I also experiemented with using
num_classes
to determine this. Besides this being a very confusing approach, requiring several paragraphs to explain clearly, it also does not resolve all ambiguities (is settingnum_classes=1
with 2 class probability predictions a request to treat the data as binary, or an inconsitency of inputs that should raise an error?). So I thinkis_multiclass
is the best approach here.Documentation
Instead of metrics being organized into "Class Metrics" and "Functional Metrics", they are now organized by topics (Classification, Regression, ...), and within topics split into class and functional, if necessary. This allows to add special topic-related sections - in this case I have added a section on what type of inputs are used for classification metric - a section that metrics can link to, in order to not repeat the same thing 100 times, and to keep docstrings short and to the point.
A second half of the Input types section with examples from StatScores metric will be added in the metric's PR.