Classification refactor #1001
Comments
I like that we are thinking of simplifying this for the users. The classification module has grown organically - many more options have been added over time. I agree a refactor is appropriate here. I like alternative approach 1, where we keep the current basic classes as light wrappers. These are easy to remember and frequently used (Accuracy, F1, Precision ...) by a large user base. If we go with the wrapper class and do validation inside it, perhaps we can recommend the specific classes BinaryMetricName, MultiClassMetricName, MultiLabelMetricName depending on the error and use case.
Sounds good, I agree with @awaelchli that it would be nice to keep the base
While it is not directly related to the discussion, in theory this issue may also be relevant, since it highlights one more aspect of accuracy-like metrics that may be taken into account during the refactoring.
It seems like this plan was aborted at some point? Can someone point me to the discussion where this happened? I want to clarify whether the old metrics are discouraged or whether the plan is to support both old and new for the indefinite future.
@adamjstewart The PR that deprecated things was #1195. This was 2 years ago, the deprecation messages are in there, and the classes marked for removal have since been removed. I guess maybe you are referring to the class flavors like Accuracy, MultiClassAccuracy, BinaryAccuracy, etc.? I am not aware of any plans to deprecate there. For example,
Ah, gotcha. I interpreted this as deprecating
This is correct; in the past we had many confused users about which type was used, especially with some edge cases within the first batch...
@lantiga @SkafteNicki are we considering dropping the old metrics in the future?
I also see references like "Legacy Example" in https://lightning.ai/docs/torchmetrics/stable/classification/accuracy.html that suggest that the previous style is discouraged. I personally think the old style is very convenient, especially for generic LightningModules like torchgeo.trainers.SemanticSegmentationTask, which could support binary/multiclass/multilabel with only minor changes.
I totally agree, but with the luxury of not needing to think about the task also came the frustration that results were not correct... For example, you have a classification task and the shuffle is bad: the first few batches come with label 0, so the metric is set up as binary, but the task is really multi-label and other labels come later...
Oh, I mean auto-detection is bad, but asking the user to specify a
The classification package is long overdue for a refactor, as we are seeing a rising number of issues that either request new features that are hard to implement in the current codebase or reflect a disagreement between what users expect the metrics to do and what they actually do.
A full list of issues that should be taken care of with the refactor can be found here.
The refactor hopes to address the following problems:
- Much of the code dates back to the old pl.metrics package and was contributed by a contributor. We should maybe have been more thorough in the review phase, because the code has been hard to maintain. This refactor should hopefully help address this by lowering code complexity.
- num_classes=2 sometimes means doing binary classification and sometimes means multiclass classification (which differ in their definitions), depending on what metric you are using.
- There are differences from how sklearn handles some cases. This refactor will address these differences.

Proposed solution
The proposed solution is to split each metric into three separate metric classes:
BinaryMetricName
MultiClassMetricName
MultiLabelMetricName
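As a rough functional sketch of what the split buys, here are minimal plain-Python versions of the three variants. The function names mirror the proposed class names and are illustrative only, not torchmetrics' actual API; note how a threshold only appears for binary/multilabel and top_k only for multiclass.

```python
# Illustrative sketch only: hypothetical functional counterparts of the
# proposed BinaryAccuracy / MultiClassAccuracy / MultiLabelAccuracy split.

def binary_accuracy(preds, target, threshold=0.5):
    """`threshold` makes sense here (and for multilabel), not for multiclass."""
    hard = [1 if p >= threshold else 0 for p in preds]
    return sum(h == t for h, t in zip(hard, target)) / len(target)

def multiclass_accuracy(preds, target, top_k=1):
    """`top_k` only makes sense for multiclass score vectors."""
    correct = 0
    for scores, t in zip(preds, target):
        # indices of the top_k highest scores
        topk = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:top_k]
        correct += t in topk
    return correct / len(target)

def multilabel_accuracy(preds, target, threshold=0.5):
    """Exact-match (subset) accuracy over rows of per-label probabilities."""
    hard = [[1 if p >= threshold else 0 for p in row] for row in preds]
    return sum(h == t for h, t in zip(hard, target)) / len(target)

print(binary_accuracy([0.9, 0.2, 0.6], [1, 0, 0]))  # 2 of 3 correct
print(multiclass_accuracy([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], [1, 0]))
```

Each task gets exactly the arguments that apply to it, so there is nothing to validate away at runtime.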
For example, Accuracy will be split into BinaryAccuracy, MultiClassAccuracy, MultiLabelAccuracy. This solution directly solves a number of problems:
- Instead of large if-else statements based on what task we are trying to solve, each metric will have a much clearer computational path.
- Take threshold and topk as examples, which are current arguments to the Accuracy, F1, etc. metrics. threshold should only be set for binary and multilabel, and topk should only be specified for multiclass. Dividing into separate metrics helps communicate which arguments have an influence on the computations going on.

Alternatives
We keep everything in one class but introduce a new (required) argument:
This alternative is not directly in opposition to the main proposed solution. If requested by users we could still provide a single class that just wraps the three individual metric classes into 1.
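A wrapper of this kind might look like the following sketch. The `task` argument name, the stub metric classes, and the dispatch logic are assumptions for illustration, not part of the proposal text:

```python
# Hypothetical sketch: a single `Accuracy` entry point that validates a
# `task` argument and returns the matching task-specific metric instance.

class BinaryAccuracy:
    def __init__(self, threshold=0.5):
        self.threshold = threshold

class MultiClassAccuracy:
    def __init__(self, num_classes=None, top_k=1):
        self.num_classes = num_classes
        self.top_k = top_k

class MultiLabelAccuracy:
    def __init__(self, num_labels=None, threshold=0.5):
        self.num_labels = num_labels
        self.threshold = threshold

class Accuracy:
    """Wrapper that dispatches to one of the three specific classes."""

    _tasks = {
        "binary": BinaryAccuracy,
        "multiclass": MultiClassAccuracy,
        "multilabel": MultiLabelAccuracy,
    }

    def __new__(cls, task, **kwargs):
        # Validation lives here, so error messages can point users at the
        # specific class for their use case, as suggested in the comments.
        if task not in cls._tasks:
            raise ValueError(
                f"`task` must be one of {sorted(cls._tasks)}, got {task!r}"
            )
        return cls._tasks[task](**kwargs)

metric = Accuracy(task="multiclass", num_classes=10, top_k=5)
print(type(metric).__name__)  # MultiClassAccuracy
```

Because `__new__` returns an instance of a different class, the wrapper is purely a convenience entry point; users still end up holding one of the three specific metrics.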
We keep the outer API the exact same and try to clean up the internals. This will most likely only address some of the current problems.
Integration
The main integration point is the torchmetrics.Metric class, which should not need to be touched during the refactor. Some examples may need to be updated. The Task should already contain all information necessary to determine what class should be used. cc: @ethanwharris

Deprecation
The goal is to have the whole classification package refactored/cleaned up as the major work in 0.10. All current classification metrics will be given a deprecation warning, and users will have until 0.11 to refactor their code to use the new classes.
While we are developing the new package we will have a freeze on new metrics in the classification package. We will still happily accept new metrics for other domains.
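The deprecation path described above could be sketched as follows; the class name and message are illustrative, not the actual torchmetrics implementation:

```python
# Sketch of the deprecation period: the old class keeps working through
# 0.10 but warns users to migrate before 0.11. Illustrative only.
import warnings

class Accuracy:
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "`Accuracy` will be removed in v0.11; use `BinaryAccuracy`, "
            "`MultiClassAccuracy` or `MultiLabelAccuracy` instead.",
            DeprecationWarning,
            stacklevel=2,
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Accuracy()
print(caught[0].category.__name__)  # DeprecationWarning
```

Using `stacklevel=2` makes the warning point at the user's call site rather than the metric internals.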
Documentation impact
Up until v0.8, this change would have made our documentation very annoying to scroll through, as everything was on one central page. This change would essentially make the classification documentation 3 times harder to navigate.
However, from v0.8 we changed it to have one page per metric. For this refactor we would keep one page per core metric, e.g. Accuracy, Precision, Recall, etc., and each page would then list every version of the metric.

Development
The development can essentially be divided into 3 phases:
- Implement the StatScore and ConfusionMatrix class for all three modes. Many classification metrics can be calculated from these statistics.

The main part of the refactor will be done by @SkafteNicki and @justusschock, with support from the rest of the core metrics team. We may be open to contributions for step 2, as it should be fairly simple subclassing and copy-paste work. Development should start within 2 weeks' time.
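To illustrate why stat scores make a good foundation: once tp/fp/tn/fn are computed, many metrics reduce to simple ratios. This is a plain-Python sketch for the binary case only, with hypothetical names, not the actual StatScore/ConfusionMatrix design:

```python
# Sketch: many classification metrics derive from the same four counts.

def stat_scores(preds, target):
    """Return (tp, fp, tn, fn) for hard binary predictions."""
    tp = sum(p == 1 and t == 1 for p, t in zip(preds, target))
    fp = sum(p == 1 and t == 0 for p, t in zip(preds, target))
    tn = sum(p == 0 and t == 0 for p, t in zip(preds, target))
    fn = sum(p == 0 and t == 1 for p, t in zip(preds, target))
    return tp, fp, tn, fn

def accuracy(tp, fp, tn, fn):
    return (tp + tn) / (tp + fp + tn + fn)

def precision(tp, fp, tn, fn):
    return tp / (tp + fp) if tp + fp else 0.0

def recall(tp, fp, tn, fn):
    return tp / (tp + fn) if tp + fn else 0.0

scores = stat_scores([1, 0, 1, 1], [1, 0, 0, 1])
print(accuracy(*scores), precision(*scores), recall(*scores))
```

Implementing the counting once per mode and deriving the rest is what keeps the per-metric code in phase 2 down to simple subclassing.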
Any feedback is appreciated :)