Commit

test mask-rcnn
jgrss committed May 17, 2024
1 parent db626a6 commit 46c6a29
Showing 7 changed files with 249 additions and 601 deletions.
1 change: 0 additions & 1 deletion src/cultionet/augment/augmenter_utils.py
@@ -4,7 +4,6 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from scipy.ndimage.measurements import label as nd_label
 from tsaug import AddNoise, Drift, TimeWarp
 
 from ..data.data import Data
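Context for the deletion above: the removed import came from `scipy.ndimage.measurements`, a namespace SciPy has deprecated in favor of the top-level `scipy.ndimage` module. A minimal sketch of the still-supported equivalent, should connected-component labeling be needed elsewhere (the toy mask is illustrative, not from the repository):

```python
import numpy as np
from scipy.ndimage import label  # replaces the deprecated scipy.ndimage.measurements.label

# Toy binary mask with two disjoint blobs.
mask = np.array(
    [
        [1, 1, 0, 0],
        [1, 0, 0, 1],
        [0, 0, 1, 1],
    ]
)

# Each connected component is assigned a distinct positive integer.
labeled, num_features = label(mask)
print(num_features)  # -> 2
```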
9 changes: 4 additions & 5 deletions src/cultionet/augment/augmenters.py
@@ -8,9 +8,8 @@
 import joblib
 import numpy as np
 import torch
-from torchvision.transforms import InterpolationMode
-from torchvision.transforms import functional as TF
-from torchvision.transforms import v2
+from torchvision.transforms import InterpolationMode, v2
+from torchvision.transforms.v2 import functional as VF
 from tsaug import AddNoise, Drift, TimeWarp
 
 from ..data.data import Data
@@ -208,9 +207,9 @@ def forward(
         x = einops.rearrange(cdata.x, '1 c t h w -> 1 t c h w')
 
         if self.direction == 'fliplr':
-            flip_transform = TF.hflip
+            flip_transform = VF.hflip
         elif self.direction == 'flipud':
-            flip_transform = TF.vflip
+            flip_transform = VF.vflip
         else:
             raise NameError("The direction is not supported.")
 
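The hunks above swap `torchvision.transforms.functional` (aliased `TF`) for the newer `torchvision.transforms.v2.functional` (aliased `VF`). Both expose `hflip`/`vflip` with the same semantics: they flip the trailing (..., H, W) dimensions, so leading batch, time, and channel axes pass through untouched. A minimal sketch, assuming torchvision >= 0.15 (where the v2 API landed) and a random tensor shaped like the `'1 t c h w'` layout used above:

```python
import torch
from torchvision.transforms.v2 import functional as VF

# Random stand-in for a rearranged sample: (batch, time, channels, height, width).
x = torch.rand(1, 4, 3, 8, 8)

# v2 functional flips operate on the last two (H, W) dimensions only.
flipped_lr = VF.hflip(x)  # left-right flip
flipped_ud = VF.vflip(x)  # up-down flip

assert torch.equal(flipped_lr, x.flip(-1))
assert torch.equal(flipped_ud, x.flip(-2))
```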
168 changes: 0 additions & 168 deletions src/cultionet/model.py
@@ -41,7 +41,6 @@
 from .models.lightning import (
     CultionetLitModel,
     CultionetLitTransferModel,
-    MaskRCNNLitModel,
     RefineLitModel,
 )
 from .utils.logging import set_color_logger

@@ -205,173 +204,6 @@ def get_trainer_params(self) -> dict:
         )
 
 
-def fit_maskrcnn(
-    dataset: EdgeDataset,
-    ckpt_file: T.Union[str, Path],
-    test_dataset: T.Optional[EdgeDataset] = None,
-    val_frac: T.Optional[float] = 0.2,
-    batch_size: T.Optional[int] = 4,
-    accumulate_grad_batches: T.Optional[int] = 1,
-    filters: T.Optional[int] = 64,
-    num_classes: T.Optional[int] = 2,
-    learning_rate: T.Optional[float] = 0.001,
-    epochs: T.Optional[int] = 30,
-    save_top_k: T.Optional[int] = 1,
-    early_stopping_patience: T.Optional[int] = 7,
-    early_stopping_min_delta: T.Optional[float] = 0.01,
-    gradient_clip_val: T.Optional[float] = 1.0,
-    reset_model: T.Optional[bool] = False,
-    auto_lr_find: T.Optional[bool] = False,
-    device: T.Optional[str] = "gpu",
-    devices: T.Optional[int] = 1,
-    weight_decay: T.Optional[float] = 1e-5,
-    precision: T.Optional[int] = 32,
-    stochastic_weight_averaging: T.Optional[bool] = False,
-    stochastic_weight_averaging_lr: T.Optional[float] = 0.05,
-    stochastic_weight_averaging_start: T.Optional[float] = 0.8,
-    model_pruning: T.Optional[bool] = False,
-    resize_height: T.Optional[int] = 201,
-    resize_width: T.Optional[int] = 201,
-    min_image_size: T.Optional[int] = 100,
-    max_image_size: T.Optional[int] = 600,
-    trainable_backbone_layers: T.Optional[int] = 3,
-):
-    """Fits a Mask R-CNN instance model.
-
-    Args:
-        dataset (EdgeDataset): The dataset to fit on.
-        ckpt_file (str | Path): The checkpoint file path.
-        test_dataset (Optional[EdgeDataset]): A test dataset to evaluate on. If given, early
-            stopping will switch from the validation dataset to the test dataset.
-        val_frac (Optional[float]): The fraction of data to use for model validation.
-        batch_size (Optional[int]): The data batch size.
-        filters (Optional[int]): The number of initial model filters.
-        learning_rate (Optional[float]): The model learning rate.
-        epochs (Optional[int]): The number of epochs.
-        save_top_k (Optional[int]): The number of top-k model checkpoints to save.
-        early_stopping_patience (Optional[int]): The patience (epochs) before early stopping.
-        early_stopping_min_delta (Optional[float]): The minimum change threshold before early stopping.
-        gradient_clip_val (Optional[float]): A gradient clip limit.
-        reset_model (Optional[bool]): Whether to reset an existing model. Otherwise, pick up
-            from the last epoch of an existing model.
-        auto_lr_find (Optional[bool]): Whether to search for an optimized learning rate.
-        device (Optional[str]): The device to train on. Choices are ['cpu', 'gpu'].
-        devices (Optional[int]): The number of GPU devices to use.
-        weight_decay (Optional[float]): The weight decay passed to the optimizer. Default is 1e-5.
-        precision (Optional[int]): The data precision. Default is 32.
-        stochastic_weight_averaging (Optional[bool]): Whether to use stochastic weight averaging.
-            Default is False.
-        stochastic_weight_averaging_lr (Optional[float]): The stochastic weight averaging learning rate.
-            Default is 0.05.
-        stochastic_weight_averaging_start (Optional[float]): The stochastic weight averaging epoch start.
-            Default is 0.8.
-        model_pruning (Optional[bool]): Whether to prune the model. Default is False.
-    """
-    ckpt_file = Path(ckpt_file)
-
-    # Split the dataset into train/validation
-    train_ds, val_ds = dataset.split_train_val(val_frac=val_frac)
-
-    # Setup the data module
-    data_module = EdgeDataModule(
-        train_ds=train_ds,
-        val_ds=val_ds,
-        test_ds=test_dataset,
-        batch_size=batch_size,
-        num_workers=0,
-        shuffle=True,
-    )
-    lit_model = MaskRCNNLitModel(
-        cultionet_model_file=ckpt_file.parent / "cultionet.pt",
-        cultionet_num_features=train_ds.num_features,
-        cultionet_num_time_features=train_ds.num_time_features,
-        cultionet_filters=filters,
-        cultionet_num_classes=num_classes,
-        learning_rate=learning_rate,
-        weight_decay=weight_decay,
-        resize_height=resize_height,
-        resize_width=resize_width,
-        min_image_size=min_image_size,
-        max_image_size=max_image_size,
-        trainable_backbone_layers=trainable_backbone_layers,
-    )
-
-    if reset_model:
-        if ckpt_file.is_file():
-            ckpt_file.unlink()
-        model_file = ckpt_file.parent / "maskrcnn.pt"
-        if model_file.is_file():
-            model_file.unlink()
-
-    # Checkpoint
-    cb_train_loss = ModelCheckpoint(
-        dirpath=ckpt_file.parent,
-        filename=ckpt_file.stem,
-        save_last=True,
-        save_top_k=save_top_k,
-        mode="min",
-        monitor="loss",
-        every_n_train_steps=0,
-        every_n_epochs=1,
-    )
-    # Validation and test loss
-    cb_val_loss = ModelCheckpoint(monitor="val_loss")
-    # Early stopping
-    early_stop_callback = EarlyStopping(
-        monitor="val_loss",
-        min_delta=early_stopping_min_delta,
-        patience=early_stopping_patience,
-        mode="min",
-        check_on_train_epoch_end=False,
-    )
-    # Learning rate
-    lr_monitor = LearningRateMonitor(logging_interval="step")
-    callbacks = [lr_monitor, cb_train_loss, cb_val_loss, early_stop_callback]
-    if stochastic_weight_averaging:
-        callbacks.append(
-            StochasticWeightAveraging(
-                swa_lrs=stochastic_weight_averaging_lr,
-                swa_epoch_start=stochastic_weight_averaging_start,
-            )
-        )
-    if 0 < model_pruning <= 1:
-        callbacks.append(ModelPruning("l1_unstructured", amount=model_pruning))
-
-    trainer = L.Trainer(
-        default_root_dir=str(ckpt_file.parent),
-        callbacks=callbacks,
-        enable_checkpointing=True,
-        accumulate_grad_batches=accumulate_grad_batches,
-        gradient_clip_val=gradient_clip_val,
-        gradient_clip_algorithm="value",
-        check_val_every_n_epoch=1,
-        min_epochs=5 if epochs >= 5 else epochs,
-        max_epochs=epochs,
-        precision=precision,
-        devices=devices,
-        accelerator=device,
-        log_every_n_steps=50,
-        profiler=None,
-        deterministic=False,
-        benchmark=False,
-    )
-
-    if auto_lr_find:
-        trainer.tune(model=lit_model, datamodule=data_module)
-    else:
-        trainer.fit(
-            model=lit_model,
-            datamodule=data_module,
-            ckpt_path=ckpt_file if ckpt_file.is_file() else None,
-        )
-        if test_dataset is not None:
-            trainer.test(
-                model=lit_model,
-                dataloaders=data_module.test_dataloader(),
-                ckpt_path="last",
-            )
 
 
 def get_data_module(
     dataset: EdgeDataset,
     test_dataset: T.Optional[EdgeDataset] = None,
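For reference, a sketch of how the removed entry point would have been invoked before this commit. This is a hypothetical call site, not code from the repository: `train_dataset` stands in for an `EdgeDataset` built elsewhere, and the keyword values simply echo the defaults in the deleted signature.

```python
from pathlib import Path

# Hypothetical invocation of the now-removed Mask R-CNN trainer;
# `train_dataset` is an EdgeDataset constructed elsewhere.
fit_maskrcnn(
    dataset=train_dataset,
    ckpt_file=Path("ckpt") / "maskrcnn" / "last.ckpt",
    val_frac=0.2,
    batch_size=4,
    num_classes=2,
    learning_rate=0.001,
    epochs=30,
    device="gpu",
    devices=1,
)
```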
(Diffs for the remaining 4 changed files were not loaded.)
