diff --git a/animaloc/data/transforms.py b/animaloc/data/transforms.py index 287a43a..50f756b 100644 --- a/animaloc/data/transforms.py +++ b/animaloc/data/transforms.py @@ -8,7 +8,7 @@ Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions. - Last modification: March 29, 2023 + Last modification: March 14, 2024 """ __author__ = "Alexandre Delplanque" __license__ = "CC BY-NC-SA 4.0" @@ -21,7 +21,7 @@ import torchvision import scipy -from typing import Dict, Optional, Union, Tuple, List +from typing import Dict, Optional, Union, Tuple, List, Any from ..utils.registry import Registry @@ -528,4 +528,41 @@ def _onehot(self, image: torch.Tensor, target: torch.Tensor): mask = self._gaussian_map(mask) gauss_map[i] = mask - return gauss_map \ No newline at end of file + return gauss_map + +@TRANSFORMS.register() +class Rotate90: + ''' Rotate the image by 90 degrees ''' + + def __init__( + self, + k: int = 1 + ) -> None: + ''' + Args: + k (int, optional): number of times to rotate by 90 degrees. Defaults to 1. + ''' + + self.k = k + + def __call__( + self, + image: Union[PIL.Image.Image, torch.Tensor], + target: Any, + ) -> Tuple[torch.Tensor,dict]: + ''' + Args: + image (PIL.Image.Image or torch.Tensor): image to transform [C,H,W] + target (Any): corresponding target + + Returns: + Tuple[torch.Tensor, Any]: + the transormed image and target + ''' + + if isinstance(image, PIL.Image.Image): + image = torchvision.transforms.ToTensor()(image) + + image = image.rot90(k=self.k, dims=(1,2)) + + return image, target \ No newline at end of file