Skip to content

Commit

Permalink
feat: add Rotate90 transform
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandre-Delplanque committed Mar 14, 2024
1 parent 39e317a commit c67404b
Showing 1 changed file with 40 additions and 3 deletions.
43 changes: 40 additions & 3 deletions animaloc/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Please contact the author Alexandre Delplanque ([email protected]) for any questions.
Last modification: March 29, 2023
Last modification: March 14, 2024
"""
__author__ = "Alexandre Delplanque"
__license__ = "CC BY-NC-SA 4.0"
Expand All @@ -21,7 +21,7 @@
import torchvision
import scipy

from typing import Dict, Optional, Union, Tuple, List
from typing import Dict, Optional, Union, Tuple, List, Any

from ..utils.registry import Registry

Expand Down Expand Up @@ -528,4 +528,41 @@ def _onehot(self, image: torch.Tensor, target: torch.Tensor):
mask = self._gaussian_map(mask)
gauss_map[i] = mask

return gauss_map
return gauss_map

@TRANSFORMS.register()
class Rotate90:
''' Rotate the image by 90 degrees '''

def __init__(
self,
k: int = 1
) -> None:
'''
Args:
k (int, optional): number of times to rotate by 90 degrees. Defaults to 1.
'''

self.k = k

def __call__(
self,
image: Union[PIL.Image.Image, torch.Tensor],
target: Any,
) -> Tuple[torch.Tensor,dict]:
'''
Args:
image (PIL.Image.Image or torch.Tensor): image to transform [C,H,W]
target (Any): corresponding target
Returns:
Tuple[torch.Tensor, Any]:
the transormed image and target
'''

if isinstance(image, PIL.Image.Image):
image = torchvision.transforms.ToTensor()(image)

image = image.rot90(k=self.k, dims=(1,2))

return image, target

0 comments on commit c67404b

Please sign in to comment.