Skip to content

Commit

Permalink
imports: deprecate from pkg root [2/n] Detection (#1694)
Browse files Browse the repository at this point in the history
* imports: detection
* functionals
* imports
* fix
* all
  • Loading branch information
Borda authored Apr 11, 2023
1 parent a8bafa2 commit 60e4177
Show file tree
Hide file tree
Showing 13 changed files with 402 additions and 327 deletions.
3 changes: 2 additions & 1 deletion src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality # noqa: E402
from torchmetrics.detection._deprecated import _ModifiedPanopticQuality as ModifiedPanopticQuality # noqa: E402
from torchmetrics.detection._deprecated import _PanopticQuality as PanopticQuality # noqa: E402
from torchmetrics.image import ( # noqa: E402
ErrorRelativeGlobalDimensionlessSynthesis,
MultiScaleStructuralSimilarityIndexMeasure,
Expand Down
6 changes: 4 additions & 2 deletions src/torchmetrics/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.detection.panoptic_qualities import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_8

__all__ = ["ModifiedPanopticQuality", "PanopticQuality"]

if _TORCHVISION_GREATER_EQUAL_0_8:
from torchmetrics.detection.mean_ap import MeanAveragePrecision # noqa: F401

from torchmetrics.detection.modified_panoptic_quality import ModifiedPanopticQuality # noqa: F401
from torchmetrics.detection.panoptic_quality import PanopticQuality # noqa: F401
__all__.append("MeanAveragePrecision")
60 changes: 60 additions & 0 deletions src/torchmetrics/detection/_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Any, Collection

from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.prints import _deprecated_root_import_class


class _ModifiedPanopticQuality(ModifiedPanopticQuality):
"""Wrapper for deprecated import.
>>> from torch import tensor
>>> preds = tensor([[[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]])
>>> target = tensor([[[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]]])
>>> pq_modified = _ModifiedPanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> pq_modified(preds, target)
tensor(0.7667, dtype=torch.float64)
"""

def __init__(
self,
things: Collection[int],
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
**kwargs: Any,
) -> None:
_deprecated_root_import_class("ModifiedPanopticQuality", "detection")
return super().__init__(
things=things, stuffs=stuffs, allow_unknown_preds_category=allow_unknown_preds_category, **kwargs
)


class _PanopticQuality(PanopticQuality):
"""Wrapper for deprecated import.
>>> from torch import tensor
>>> preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [0, 0], [6, 0], [0, 1]],
... [[0, 0], [7, 0], [6, 0], [1, 0]],
... [[0, 0], [7, 0], [7, 0], [7, 0]]]])
>>> target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [0, 1]],
... [[0, 1], [0, 1], [6, 0], [1, 0]],
... [[0, 1], [7, 0], [1, 0], [1, 0]],
... [[0, 1], [7, 0], [7, 0], [7, 0]]]])
>>> panoptic_quality = _PanopticQuality(things = {0, 1}, stuffs = {6, 7})
>>> panoptic_quality(preds, target)
tensor(0.5463, dtype=torch.float64)
"""

def __init__(
self,
things: Collection[int],
stuffs: Collection[int],
allow_unknown_preds_category: bool = False,
**kwargs: Any,
) -> None:
_deprecated_root_import_class("PanopticQuality", "detection")
return super().__init__(
things=things, stuffs=stuffs, allow_unknown_preds_category=allow_unknown_preds_category, **kwargs
)
212 changes: 0 additions & 212 deletions src/torchmetrics/detection/modified_panoptic_quality.py

This file was deleted.

Loading

0 comments on commit 60e4177

Please sign in to comment.