
Commit

Merge branch 'feature' into main
Alexandre-Delplanque authored Mar 25, 2024
2 parents 506e1e1 + d6b3ffb commit 0fb5a68
Showing 5 changed files with 76 additions and 17 deletions.
41 changes: 39 additions & 2 deletions animaloc/data/transforms.py
@@ -20,7 +20,7 @@
import torchvision
import scipy

from typing import Dict, Optional, Union, Tuple, List
from typing import Dict, Optional, Union, Tuple, List, Any

from ..utils.registry import Registry

@@ -527,4 +527,41 @@ def _onehot(self, image: torch.Tensor, target: torch.Tensor):
mask = self._gaussian_map(mask)
gauss_map[i] = mask

return gauss_map
return gauss_map

@TRANSFORMS.register()
class Rotate90:
''' Rotate the image by k * 90 degrees '''

def __init__(
self,
k: int = 1
) -> None:
'''
Args:
k (int, optional): number of times to rotate by 90 degrees. Defaults to 1.
'''

self.k = k

def __call__(
self,
image: Union[PIL.Image.Image, torch.Tensor],
target: Any,
) -> Tuple[torch.Tensor, Any]:
'''
Args:
image (PIL.Image.Image or torch.Tensor): image to transform [C,H,W]
target (Any): corresponding target
Returns:
Tuple[torch.Tensor, Any]:
the transformed image and target
'''

if isinstance(image, PIL.Image.Image):
image = torchvision.transforms.ToTensor()(image)

image = image.rot90(k=self.k, dims=(1,2))

return image, target
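
For readers unfamiliar with the transform API, a minimal usage sketch of the new Rotate90 class (illustration only, not part of this commit; the dummy image size and the point-style target are assumptions):

# Hedged usage sketch of Rotate90 (not part of the diff)
import PIL.Image
from animaloc.data.transforms import Rotate90

transform = Rotate90(k=1)                         # one 90-degree rotation
image = PIL.Image.new('RGB', (640, 480))          # dummy image: width 640, height 480
target = {'points': [], 'labels': []}             # hypothetical point-annotation target
out_image, out_target = transform(image, target)  # PIL input is converted to a [C,H,W] tensor first
assert tuple(out_image.shape) == (3, 640, 480)    # H and W swap for an odd k
assert out_target is target                       # the target is passed through unchanged
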
3 changes: 1 addition & 2 deletions animaloc/train/losses/ssim.py
@@ -75,12 +75,11 @@ def _ssim_loss(
assert weights.shape[0] == channel, \
'Number of weights must match the number of channels, ' \
f'got {channel} channels and {weights.shape[0]} weights'

weights = weights.to(output.device)

ssim_map = _ssim(target, output, window, window_size, channel)

if weights is not None:
weights = weights.to(output.device)
ssim_list = 1. - ssim_map.mean(3).mean(2)
loss = weights * ssim_list
else:
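
For context, the reordering above matters because `weights` may be omitted; a small illustrative sketch of the failure mode the new placement avoids (not part of the commit):

# Illustration only: why the .to() call belongs inside the None check
import torch

weights = None                     # _ssim_loss allows weights to be left out
device = torch.device('cpu')
try:
    weights.to(device)             # old placement: runs unconditionally and fails on None
except AttributeError as err:
    print(err)                     # 'NoneType' object has no attribute 'to'

weights = torch.ones(3)            # with per-channel weights given...
if weights is not None:
    weights = weights.to(device)   # ...the new placement moves them to the output device
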
26 changes: 16 additions & 10 deletions animaloc/train/trainers.py
@@ -220,7 +220,7 @@ def start(
warmup_iters: Optional[int] = None,
checkpoints: str = 'best',
select: str = 'min',
validate_on: str = 'recall',
validate_on: str = 'all',
wandb_flag: bool = False
) -> torch.nn.Module:
''' Start training from epoch 1
@@ -241,10 +241,11 @@ def start(
- 'min' (default), for selecting the epoch that yields a minimum validation value,
- 'max', for selecting the epoch that yields a maximum validation value.
Defaults to 'min'.
validate_on (str, optional): metrics used for validation (i.e. best model and auto-lr) when
custom evaluator is specified. Possible values are: 'recall', 'precision', 'f1_score',
'mse', 'mae', and 'rmse'.
Defauts to 'recall'
validate_on (str, optional): metric or loss used for validation (i.e. best model selection and auto-lr).
For validation on losses, possible values are the loss names returned by the model, or 'all'
to use the sum of all losses (default). When a custom evaluator is specified, possible values
are: 'recall', 'precision', 'f1_score', 'mse', 'mae', 'rmse', 'accuracy' or 'mAP'.
Defaults to 'all'.
wandb_flag (bool, optional): set to True to log on Weights & Biases. Defaults to False.
Returns:
@@ -290,7 +291,7 @@ def start(

elif self.val_dataloader is not None:
val_flag = True
val_output = self.evaluate(epoch, wandb_flag=wandb_flag)
val_output = self.evaluate(epoch, wandb_flag=wandb_flag, returns=validate_on)
if wandb_flag:
wandb.log({'val_loss': val_output, 'epoch': epoch})

@@ -407,7 +408,7 @@ def resume(

elif self.val_dataloader is not None:
val_flag = True
val_output = self.evaluate(epoch, wandb_flag=wandb_flag)
val_output = self.evaluate(epoch, wandb_flag=wandb_flag, returns=validate_on)
if wandb_flag:
wandb.log({'val_loss': val_output, 'epoch': epoch})

@@ -423,7 +424,10 @@
# scheduler
if lr_scheduler is not None:
if self.auto_lr_flag:
lr_scheduler.step(val_output)
if 'val_output' in locals():
lr_scheduler.step(val_output)
else:
lr_scheduler.step(self.best_val)
else:
lr_scheduler.step()

@@ -442,7 +446,7 @@
return self.model

@torch.no_grad()
def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False) -> float:
def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False, returns: str = 'all') -> float:

self.model.eval()

@@ -457,6 +461,8 @@ def evaluate(self, epoch: int, reduction: str = 'mean', wandb_flag: bool = False
output, loss_dict = self.model(images, targets)

losses = sum(loss for loss in loss_dict.values())
if returns != 'all':
losses = loss_dict[returns]

loss_dict_reduced = reduce_dict(loss_dict)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
@@ -517,7 +523,7 @@ def _train(
wandb.log(loss_dict)

self.losses = sum(loss for loss in loss_dict.values())
batches_losses.append(self.losses)
batches_losses.append(self.losses.detach())

loss_dict_reduced = reduce_dict(loss_dict)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
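
To make the new `validate_on` / `returns` behaviour concrete, a standalone sketch of the selection logic added to `evaluate()` (the loss names here are made up for illustration):

# Standalone sketch mirroring the lines added to evaluate() (hypothetical loss names)
loss_dict = {'focal_loss': 0.5, 'ssim_loss': 0.25}

def select_validation_value(loss_dict: dict, returns: str = 'all') -> float:
    # 'all' (the new default) validates on the sum of every loss;
    # any other value picks a single loss returned by the model.
    losses = sum(loss for loss in loss_dict.values())
    if returns != 'all':
        losses = loss_dict[returns]
    return losses

print(select_validation_value(loss_dict))                 # 0.75
print(select_validation_value(loss_dict, 'focal_loss'))   # 0.5
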
19 changes: 16 additions & 3 deletions tools/infer.py
@@ -27,7 +27,7 @@
from torch.utils.data import DataLoader
from PIL import Image

from animaloc.data.transforms import DownSample
from animaloc.data.transforms import DownSample, Rotate90
from animaloc.models import LossWrapper, HerdNet
from animaloc.eval import HerdNetStitcher, HerdNetEvaluator
from animaloc.eval.metrics import PointsMetrics
@@ -59,6 +59,8 @@
help='thumbnail size. Defaults to 256.')
parser.add_argument('-pf', type=int, default=10,
help='print frequency. Defaults to 10.')
parser.add_argument('-rot', type=int, default=0,
help='number of times to rotate by 90 degrees. Defaults to 0.')

args = parser.parse_args()

@@ -85,11 +87,19 @@ def main():
if i.endswith(('.JPG','.jpg','.JPEG','.jpeg'))]
n = len(img_names)
df = pandas.DataFrame(data={'images': img_names, 'x': [0]*n, 'y': [0]*n, 'labels': [1]*n})

end_transforms = []
if args.rot != 0:
end_transforms.append(Rotate90(k=args.rot))
end_transforms.append(DownSample(down_ratio = 2, anno_type = 'point'))

albu_transforms = [A.Normalize(mean=img_mean, std=img_std)]

dataset = CSVDataset(
csv_file = df,
root_dir = args.root,
albu_transforms = [A.Normalize(mean=img_mean, std=img_std)],
end_transforms = [DownSample(down_ratio = 2, anno_type = 'point')]
albu_transforms = albu_transforms,
end_transforms = end_transforms
)

dataloader = DataLoader(dataset, batch_size=1, shuffle=False,
@@ -146,6 +156,9 @@ def main():
img_names = numpy.unique(detections['images'].values).tolist()
for img_name in img_names:
img = Image.open(os.path.join(args.root, img_name))
if args.rot != 0:
rot = args.rot * 90
img = img.rotate(rot, expand=True)
img_cpy = img.copy()
pts = list(detections[detections['images']==img_name][['y','x']].to_records(index=False))
pts = [(y, x) for y, x in pts]
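
A hedged sketch of how the two rotations in infer.py relate: the dataset pipeline rotates the [C,H,W] tensor with Rotate90, and the plotting code rotates the original PIL image by the same k * 90 degrees so that the detections are drawn on the rotated frame (the image size below is an assumption):

# Illustration only: the tensor-side and PIL-side rotations produce the same frame size
import PIL.Image
import torchvision

k = 1                                               # value passed through the new -rot option
img = PIL.Image.new('RGB', (640, 480))              # dummy image: width 640, height 480
tensor = torchvision.transforms.ToTensor()(img)     # [3, 480, 640]
tensor = tensor.rot90(k=k, dims=(1, 2))             # what Rotate90 applies in the dataset pipeline
img_rot = img.rotate(k * 90, expand=True)           # what the plotting code applies before drawing points
assert tensor.shape[1:] == (img_rot.height, img_rot.width)   # both end up 640 high by 480 wide
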
4 changes: 4 additions & 0 deletions tools/patcher.py
@@ -20,6 +20,7 @@
import torchvision
import numpy
import cv2
import pandas

from albumentations import PadIfNeeded

@@ -56,6 +57,9 @@ def main():
patches_buffer = PatchesBuffer(args.csv, args.root, (args.height, args.width), overlap=args.overlap, min_visibility=args.min).buffer
patches_buffer.drop(columns='limits').to_csv(os.path.join(args.dest, 'gt.csv'), index=False)

if not args.all:
images_paths = [os.path.join(args.root, x) for x in pandas.read_csv(args.csv)['images'].unique()]

for img_path in tqdm(images_paths, desc='Exporting patches'):
pil_img = PIL.Image.open(img_path)
img_tensor = torchvision.transforms.ToTensor()(pil_img)
