freeze_backbone and freeze_decoder in Trainers #1290

Merged
10 commits merged on Apr 26, 2023
14 changes: 14 additions & 0 deletions tests/trainers/test_classification.py
@@ -223,6 +223,20 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None
)
trainer.predict(model=model, datamodule=datamodule)

@pytest.mark.parametrize(
"model_name", ["resnet18", "efficientnetv2_s", "vit_base_patch16_384"]
)
def test_freeze_backbone(
self, model_name: str, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["freeze_backbone"] = True
model_kwargs["model"] = model_name
model = ClassificationTask(**model_kwargs)
assert not all([param.requires_grad for param in model.model.parameters()])
assert all(
[param.requires_grad for param in model.model.get_classifier().parameters()]
)
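
The `not all(...)` assertion above only guarantees that at least one backbone parameter is frozen. A stricter standalone check (a sketch outside this diff, assuming a plain timm model) would apply the same freeze pattern and count trainable parameters directly:

```python
import timm

# Apply the trainer's freeze pattern to a bare timm classifier.
model = timm.create_model("resnet18", num_classes=10)
for param in model.parameters():
    param.requires_grad = False
for param in model.get_classifier().parameters():
    param.requires_grad = True

# Only the classifier head should remain trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable}/{total}")
```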


class TestMultiLabelClassificationTask:
@pytest.mark.parametrize(
9 changes: 9 additions & 0 deletions tests/trainers/test_detection.py
@@ -146,3 +146,12 @@ def test_predict(self, model_kwargs: dict[Any, Any], fast_dev_run: bool) -> None
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)

@pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"])
def test_freeze_backbone(
self, model_name: str, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["freeze_backbone"] = True
model_kwargs["model"] = model_name
model = ObjectDetectionTask(**model_kwargs)
assert not all([param.requires_grad for param in model.model.parameters()])
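
For the detection trainers only the backbone is asserted frozen, since the head modules differ across architectures. A name-based sanity check (illustrative only, not part of this PR) could confirm that every frozen parameter lives under the backbone module:

```python
# All three torchvision detection models expose the backbone as `backbone`,
# and freeze_backbone only touches backbone.parameters().
frozen = [
    name
    for name, param in model.model.named_parameters()
    if not param.requires_grad
]
assert all(name.startswith("backbone") for name in frozen)
```
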
54 changes: 54 additions & 0 deletions tests/trainers/test_regression.py
@@ -206,6 +206,20 @@ def test_invalid_loss(self, model_kwargs: dict[str, Any]) -> None
with pytest.raises(ValueError, match=match):
RegressionTask(**model_kwargs)

@pytest.mark.parametrize(
"model_name", ["resnet18", "efficientnetv2_s", "vit_base_patch16_384"]
)
def test_freeze_backbone(
self, model_name: str, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["freeze_backbone"] = True
model_kwargs["model"] = model_name
model = RegressionTask(**model_kwargs)
assert not all([param.requires_grad for param in model.model.parameters()])
assert all(
[param.requires_grad for param in model.model.get_classifier().parameters()]
)
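
RegressionTask reuses the timm classification head as its regression output layer, which is why `get_classifier()` is the module being checked here. A standalone sketch (assuming timm's usual resnet18 layout):

```python
import timm

# With num_classes=1 the "classifier" is a single-output linear layer,
# i.e. a regression head; get_classifier() returns exactly that module.
model = timm.create_model("resnet18", num_classes=1)
print(model.get_classifier())  # Linear(in_features=512, out_features=1, bias=True)
```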


class TestPixelwiseRegressionTask:
@pytest.mark.parametrize(
@@ -282,3 +296,43 @@ def model_kwargs(self) -> dict[str, Any]:
"learning_rate": 1e-3,
"learning_rate_schedule_patience": 6,
}

@pytest.mark.parametrize(
"backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"]
)
@pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
def test_freeze_backbone(
self, backbone: str, model_name: str, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["freeze_backbone"] = True
model_kwargs["model"] = model_name
model_kwargs["backbone"] = backbone
model = PixelwiseRegressionTask(**model_kwargs)
assert all(
[param.requires_grad is False for param in model.model.encoder.parameters()]
)
assert all([param.requires_grad for param in model.model.decoder.parameters()])
assert all(
[
param.requires_grad
for param in model.model.segmentation_head.parameters()
]
)

@pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
def test_freeze_decoder(
self, model_name: str, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["freeze_decoder"] = True
model_kwargs["model"] = model_name
model = PixelwiseRegressionTask(**model_kwargs)
assert all(
[param.requires_grad is False for param in model.model.decoder.parameters()]
)
assert all([param.requires_grad for param in model.model.encoder.parameters()])
assert all(
[
param.requires_grad
for param in model.model.segmentation_head.parameters()
]
)
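
The two flags are independent, so setting both should leave only the segmentation head trainable. That combination is not covered above; a hypothetical test in the same style might look like:

```python
def test_freeze_backbone_and_decoder(
    self, model_kwargs: dict[Any, Any]
) -> None:
    # Hypothetical: both flags together freeze everything but the head.
    model_kwargs["freeze_backbone"] = True
    model_kwargs["freeze_decoder"] = True
    model_kwargs["model"] = "unet"
    model = PixelwiseRegressionTask(**model_kwargs)
    assert all(not p.requires_grad for p in model.model.encoder.parameters())
    assert all(not p.requires_grad for p in model.model.decoder.parameters())
    assert all(
        p.requires_grad for p in model.model.segmentation_head.parameters()
    )
```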
40 changes: 40 additions & 0 deletions tests/trainers/test_segmentation.py
@@ -152,3 +152,43 @@ def test_no_rgb
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

@pytest.mark.parametrize(
"backbone", ["resnet18", "mobilenet_v2", "efficientnet-b0"]
)
@pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
def test_freeze_backbone(
self, backbone: str, model_name: str, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["freeze_backbone"] = True
model_kwargs["model"] = model_name
model_kwargs["backbone"] = backbone
model = SemanticSegmentationTask(**model_kwargs)
assert all(
[param.requires_grad is False for param in model.model.encoder.parameters()]
)
assert all([param.requires_grad for param in model.model.decoder.parameters()])
assert all(
[
param.requires_grad
for param in model.model.segmentation_head.parameters()
]
)

@pytest.mark.parametrize("model_name", ["unet", "deeplabv3+"])
def test_freeze_decoder(
self, model_name: str, model_kwargs: dict[Any, Any]
) -> None:
model_kwargs["freeze_decoder"] = True
model_kwargs["model"] = model_name
model = SemanticSegmentationTask(**model_kwargs)
assert all(
[param.requires_grad is False for param in model.model.decoder.parameters()]
)
assert all([param.requires_grad for param in model.model.encoder.parameters()])
assert all(
[
param.requires_grad
for param in model.model.segmentation_head.parameters()
]
)
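
These assertions rely on the `encoder`, `decoder`, and `segmentation_head` attributes that segmentation-models-pytorch exposes on its model classes. A standalone sketch of that layout:

```python
import segmentation_models_pytorch as smp

# Unet and DeepLabV3Plus share the same three top-level submodules,
# which is what the freeze logic in the trainers hooks into.
net = smp.Unet(encoder_name="resnet18", in_channels=3, classes=2)
print(type(net.encoder).__name__)            # e.g. ResNetEncoder
print(type(net.decoder).__name__)            # e.g. UnetDecoder
print(type(net.segmentation_head).__name__)  # e.g. SegmentationHead
```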
11 changes: 11 additions & 0 deletions torchgeo/trainers/classification.py
@@ -64,6 +64,13 @@ def config_model(self) -> None:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)

# Freeze backbone and unfreeze classifier head
if self.hyperparams.get("freeze_backbone", False):
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True

def config_task(self) -> None:
"""Configures the task based on kwargs parameters passed to the constructor."""
self.config_model()
@@ -90,9 +97,13 @@ def __init__(self, **kwargs: Any) -> None:
in_channels: Number of input channels to model
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler
freeze_backbone: Freeze the backbone network to fine-tune the classification head

.. versionchanged:: 0.4
The *classification_model* parameter was renamed to *model*.

.. versionchanged:: 0.5
Added *freeze_backbone* parameter.
"""
super().__init__()
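
From the user's side the feature is a single kwarg on the task. A minimal sketch (the other kwargs mirror the test fixtures and may need adjusting for a real datamodule):

```python
from torchgeo.trainers import ClassificationTask

# Freeze the timm backbone and train only the classifier head.
task = ClassificationTask(
    model="resnet18",
    in_channels=3,
    num_classes=10,
    loss="ce",
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
    freeze_backbone=True,
)
```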

17 changes: 17 additions & 0 deletions torchgeo/trainers/detection.py
@@ -95,6 +95,11 @@ def config_task(self) -> None:
roi_pooler = MultiScaleRoIAlign(
featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
)

if self.hyperparams.get("freeze_backbone", False):
for param in backbone.parameters():
param.requires_grad = False

self.model = torchvision.models.detection.FasterRCNN(
backbone,
num_classes,
@@ -113,6 +118,10 @@ def config_task(self) -> None:
aspect_ratios=((1.0,), (1.0,), (1.0,), (1.0,), (1.0,), (1.0,)),
)

if self.hyperparams.get("freeze_backbone", False):
for param in backbone.parameters():
param.requires_grad = False

self.model = torchvision.models.detection.FCOS(
backbone, num_classes, anchor_generator=anchor_generator
)
@@ -140,6 +149,10 @@ def config_task(self) -> None:
norm_layer=partial(torch.nn.GroupNorm, 32),
)

if self.hyperparams.get("freeze_backbone", False):
for param in backbone.parameters():
param.requires_grad = False

self.model = torchvision.models.detection.RetinaNet(
backbone, num_classes, anchor_generator=anchor_generator, head=head
)
@@ -156,12 +169,16 @@ def __init__(self, **kwargs: Any) -> None:
num_classes: Number of semantic classes to predict
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler
freeze_backbone: Freeze the backbone network to fine-tune the detection head

Raises:
ValueError: if kwargs arguments are invalid

.. versionchanged:: 0.4
The *detection_model* parameter was renamed to *model*.

.. versionchanged:: 0.5
Added *freeze_backbone* parameter.
"""
super().__init__()
# Creates `self.hparams` from kwargs
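
A usage sketch for the detection task (kwargs are illustrative; any of "faster-rcnn", "fcos", or "retinanet" accept the new flag):

```python
from torchgeo.trainers import ObjectDetectionTask

# Freeze the ResNet backbone; only the detection head trains.
task = ObjectDetectionTask(
    model="faster-rcnn",
    backbone="resnet18",
    num_classes=2,
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
    freeze_backbone=True,
)
```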
25 changes: 25 additions & 0 deletions torchgeo/trainers/regression.py
@@ -59,6 +59,13 @@ def config_model(self) -> None:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model = utils.load_state_dict(self.model, state_dict)

# Freeze backbone and unfreeze classifier head
if self.hyperparams.get("freeze_backbone", False):
for param in self.model.parameters():
param.requires_grad = False
for param in self.model.get_classifier().parameters():
param.requires_grad = True

def config_task(self) -> None:
"""Configures the task based on kwargs parameters."""
self.config_model()
@@ -86,9 +93,13 @@ def __init__(self, **kwargs: Any) -> None:
in_channels: Number of input channels to model
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler
freeze_backbone: Freeze the backbone network to fine-tune the regression head

.. versionchanged:: 0.4
Change regression model support from torchvision.models to timm

.. versionchanged:: 0.5
Added *freeze_backbone* parameter.
"""
super().__init__()

@@ -294,3 +305,17 @@ def config_model(self) -> None:
f"Model type '{self.hyperparams['model']}' is not valid. "
f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)

# Freeze backbone
if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
"model"
] in ["unet", "deeplabv3+"]:
for param in self.model.encoder.parameters():
param.requires_grad = False

# Freeze decoder
if self.hyperparams.get("freeze_decoder", False) and self.hyperparams[
"model"
] in ["unet", "deeplabv3+"]:
for param in self.model.decoder.parameters():
param.requires_grad = False
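
Note the guard: the flags only take effect for "unet" and "deeplabv3+", since the FCN baseline has no encoder/decoder split and is silently left untouched. A usage sketch (kwargs abbreviated and illustrative):

```python
from torchgeo.trainers import PixelwiseRegressionTask

# Freeze the encoder; decoder and segmentation head keep training.
task = PixelwiseRegressionTask(
    model="deeplabv3+",
    backbone="resnet18",
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
    freeze_backbone=True,
)
```
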
18 changes: 17 additions & 1 deletion torchgeo/trainers/segmentation.py
@@ -80,6 +80,20 @@ def config_task(self) -> None:
f"Currently, supports 'ce', 'jaccard' or 'focal' loss."
)

# Freeze backbone
if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
"model"
] in ["unet", "deeplabv3+"]:
for param in self.model.encoder.parameters():
param.requires_grad = False

# Freeze decoder
if self.hyperparams.get("freeze_decoder", False) and self.hyperparams[
"model"
] in ["unet", "deeplabv3+"]:
for param in self.model.decoder.parameters():
param.requires_grad = False

def __init__(self, **kwargs: Any) -> None:
"""Initialize the LightningModule with a model and loss function.

@@ -110,7 +124,9 @@ class and used with 'ce' loss
*encoder_weights* to *weights*.

.. versionadded:: 0.5
   The *class_weights*, *freeze_backbone*,
   and *freeze_decoder* parameters.
"""
super().__init__()
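
A usage sketch for segmentation (kwargs mirror the test fixtures and may need adjusting; shown here with freeze_decoder):

```python
from torchgeo.trainers import SemanticSegmentationTask

# Freeze the decoder while the encoder and segmentation head keep training.
task = SemanticSegmentationTask(
    model="unet",
    backbone="resnet18",
    in_channels=3,
    num_classes=5,
    loss="ce",
    learning_rate=1e-3,
    learning_rate_schedule_patience=6,
    freeze_decoder=True,
)
```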
