Faster forward
Highlights
TorchMetrics v0.9 is now out, and it brings significant changes to how the forward method works. This blog post goes over these improvements and how they affect both users of TorchMetrics and users who implement custom metrics. TorchMetrics v0.9 also includes several new metrics and bug fixes.
Blog: TorchMetrics v0.9 — Faster forward
The Story of the Forward Method
Since the beginning of TorchMetrics, forward has served the dual purpose of calculating the metric on the current batch and accumulating it into a global state. Internally, this was achieved by calling update twice, once for each purpose, which meant repeating the same computation. However, for many metrics the second call is unnecessary: the global state is a simple reduction (for example, a sum) of the per-batch states, so the batch statistics only need to be computed once and can then be merged into the global state.
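To make the cost of the old approach concrete, here is a minimal plain-Python sketch of a forward that calls update twice. It is not TorchMetrics code: the class ToyMeanDoubleUpdate, its update_calls counter, and the stash/restore logic are hypothetical stand-ins chosen so the example runs without PyTorch.

```python
# Hypothetical toy metric (not TorchMetrics code) mimicking the pre-v0.9
# forward: update() runs twice per batch.


class ToyMeanDoubleUpdate:
    def __init__(self):
        self.update_calls = 0  # how many times update() has run
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        self.update_calls += 1
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count

    def forward(self, values):
        self.update(values)                      # 1st call: accumulate globally
        global_state = (self.total, self.count)  # stash the global state
        self.reset()
        self.update(values)                      # 2nd call: same work, batch only
        batch_value = self.compute()
        self.total, self.count = global_state    # restore the global state
        return batch_value


metric = ToyMeanDoubleUpdate()
print(metric.forward([1.0, 2.0, 3.0]))  # batch mean: 2.0
print(metric.forward([5.0]))            # batch mean: 5.0
print(metric.compute())                 # global mean over all 4 values: 2.75
print(metric.update_calls)              # 4 update() calls for 2 batches
```

The results are correct, but every batch pays for update twice.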
In v0.9, we have finally implemented logic that takes advantage of this: forward calls update only once and then performs a simple reduction. As you can see in the figure below, this can make a single call of forward up to 2x faster in v0.9 than in v0.8 for the same metric.
Figure: With the improvements to forward, many metrics have become significantly faster (up to 2x).
It should be noted that this change mainly benefits metrics where calling update is expensive (for example, ConfusionMatrix).
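As a rough plain-Python illustration of the v0.9 idea (a hypothetical toy, not the library's implementation), the forward below stashes the global state, runs update once on an empty state to get the batch value, and then merges the batch state back in with a simple reduction. Because update runs once per batch instead of twice, the saving grows with the cost of update, which is why metrics with expensive updates gain the most. The class ToyMeanSingleUpdate and its update_calls counter are assumptions made for the example.

```python
# Hypothetical toy metric (not TorchMetrics code) sketching the v0.9 idea:
# forward() runs update() once and merges the result with a simple reduction.


class ToyMeanSingleUpdate:
    def __init__(self):
        self.update_calls = 0  # how many times update() has run
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        self.update_calls += 1
        self.total += sum(values)
        self.count += len(values)

    def compute(self):
        return self.total / self.count

    def forward(self, values):
        old_total, old_count = self.total, self.count  # stash the global state
        self.reset()
        self.update(values)           # the only update() call for this batch
        batch_value = self.compute()  # metric on this batch alone
        # Reduction: the global total and count are the old values plus the
        # batch values, so no second update() is required.
        self.total += old_total
        self.count += old_count
        return batch_value


metric = ToyMeanSingleUpdate()
print(metric.forward([1.0, 2.0, 3.0]))  # batch mean: 2.0
print(metric.forward([5.0]))            # batch mean: 5.0
print(metric.compute())                 # global mean over all 4 values: 2.75
print(metric.update_calls)              # 2 update() calls for 2 batches
```

This shortcut is only valid because a mean's global state is exactly the sum of its batch states.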
We went through all existing metrics in TorchMetrics and enabled this feature for all appropriate metrics, which was almost 95% of them. We want to stress that if you are using metrics from TorchMetrics, the API has not changed and no code changes are necessary.
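For authors of custom metrics, this release adds a class property, full_state_update, that determines whether forward calls update once or twice. The plain-Python toy below sketches how such a class-level flag can select between the two strategies; every other name in it (ToyMetric, ToySum, get_state, set_state, merge_states) is hypothetical, and it does not reproduce TorchMetrics internals. A metric whose global state is a simple reduction of batch states can set the flag to False, while True keeps the two-call path.

```python
# Hypothetical toy (not TorchMetrics internals): a class-level
# full_state_update flag chooses how forward() combines batch and global state.


class ToyMetric:
    # Subclasses provide reset/update/compute/get_state/set_state/merge_states.
    full_state_update = True  # toy default: the two-call path, valid for any metric

    def __init__(self):
        self.update_calls = 0
        self.reset()

    def forward(self, *args):
        if self.full_state_update:
            return self._forward_two_updates(*args)
        return self._forward_one_update(*args)

    def _forward_two_updates(self, *args):
        self.update(*args)            # accumulate into the global state
        saved = self.get_state()
        self.reset()
        self.update(*args)            # repeat on an empty state for the batch value
        batch_value = self.compute()
        self.set_state(saved)
        return batch_value

    def _forward_one_update(self, *args):
        saved = self.get_state()
        self.reset()
        self.update(*args)            # the only update() call
        batch_value = self.compute()
        # Merge the batch state into the stashed global state.
        self.set_state(self.merge_states(saved, self.get_state()))
        return batch_value


class ToySum(ToyMetric):
    full_state_update = False  # a global sum is just the sum of batch sums

    def reset(self):
        self.total = 0.0

    def update(self, values):
        self.update_calls += 1
        self.total += sum(values)

    def compute(self):
        return self.total

    def get_state(self):
        return self.total

    def set_state(self, state):
        self.total = state

    def merge_states(self, global_state, batch_state):
        return global_state + batch_state


class ToySumTwoUpdates(ToySum):
    full_state_update = True  # same metric, forced onto the two-call path


fast, slow = ToySum(), ToySumTwoUpdates()
for batch in ([1.0, 2.0], [4.0]):
    assert fast.forward(batch) == slow.forward(batch)  # identical batch values
print(fast.compute(), slow.compute())        # 7.0 7.0
print(fast.update_calls, slow.update_calls)  # 2 4
```

Both variants produce the same batch and global values; only the number of update calls differs, which is the trade-off the flag controls.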
[0.9.0] - 2022-05-31
Added
- Added RetrievalPrecisionRecallCurve and RetrievalRecallAtFixedPrecision to the retrieval package (#951)
- Added class property full_state_update that determines whether forward should call update once or twice (#984, #1033)
- Added support for nested metric collections (#1003)
- Added Dice to the classification package (#1021)
- Added support for the segmentation type segm as IoU type for mean average precision (#822)
Changed
- Renamed reduction argument to average in Jaccard score and added additional options (#874)
Removed
- Removed deprecated compute_on_step argument (#962, #967, #979, #990, #991, #993, #1005, #1004, #1007)
Fixed
- Fixed non-empty state dict for a few metrics (#1012)
- Fixed a bug when comparing states while finding compute groups (#1022)
- Fixed torch.double support in stat score metrics (#1023)
- Fixed FID calculation for real and fake inputs of unequal size (#1028)
- Fixed case where KLDivergence could output NaN (#1030)
- Fixed deterministic for PyTorch<1.8 (#1035)
- Fixed default value for mdmc_average in Accuracy (#1036)
- Fixed missing copy of property when using compute groups in MetricCollection (#1052)
Contributors
@Borda, @burglarhobbit, @charlielito, @gianscarpe, @MrShevan, @phaseolud, @razmikmelikbekyan, @SkafteNicki, @tanmoyio, @vumichien
If we forgot someone because their commit email does not match their GitHub account, let us know :]