Skip to content

Commit

Permalink
fix: change how images are rotated
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandre-Delplanque committed Mar 14, 2024
1 parent c67404b commit d6b3ffb
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions tools/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Please contact the author Alexandre Delplanque ([email protected]) for any questions.
Last modification: April 28, 2023
Last modification: March 14, 2024
"""
__author__ = "Alexandre Delplanque"
__license__ = "CC BY-NC-SA 4.0"
Expand All @@ -28,7 +28,7 @@
from torch.utils.data import DataLoader
from PIL import Image

from animaloc.data.transforms import DownSample
from animaloc.data.transforms import DownSample, Rotate90
from animaloc.models import LossWrapper, HerdNet
from animaloc.eval import HerdNetStitcher, HerdNetEvaluator
from animaloc.eval.metrics import PointsMetrics
Expand Down Expand Up @@ -61,7 +61,7 @@
parser.add_argument('-pf', type=int, default=10,
help='print frequence. Defaults to 10.')
parser.add_argument('-rot', type=int, default=0,
help='number of degrees to rotate the images (counter clockwise). Defaults to 0.')
help='number of times to rotate by 90 degrees. Defaults to 0.')

args = parser.parse_args()

Expand Down Expand Up @@ -89,16 +89,18 @@ def main():
n = len(img_names)
df = pandas.DataFrame(data={'images': img_names, 'x': [0]*n, 'y': [0]*n, 'labels': [1]*n})

albu_transforms = []
end_transforms = []
if args.rot != 0:
albu_transforms.append(A.Rotate(limit=(args.rot,args.rot), p=1))
albu_transforms.append(A.Normalize(mean=img_mean, std=img_std))
end_transforms.append(Rotate90(k=args.rot))
end_transforms.append(DownSample(down_ratio = 2, anno_type = 'point'))

albu_transforms = [A.Normalize(mean=img_mean, std=img_std)]

dataset = CSVDataset(
csv_file = df,
root_dir = args.root,
albu_transforms = albu_transforms,
end_transforms = [DownSample(down_ratio = 2, anno_type = 'point')]
end_transforms = end_transforms
)

dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
Expand Down Expand Up @@ -156,7 +158,8 @@ def main():
for img_name in img_names:
img = Image.open(os.path.join(args.root, img_name))
if args.rot != 0:
img = img.rotate(args.rot, expand=True)
rot = args.rot * 90
img = img.rotate(rot, expand=True)
img_cpy = img.copy()
pts = list(detections[detections['images']==img_name][['y','x']].to_records(index=False))
pts = [(y, x) for y, x in pts]
Expand Down

0 comments on commit d6b3ffb

Please sign in to comment.