BaseTask: fix load_from_checkpoint, ignore 'ignore' #2317

Merged: 2 commits into microsoft:main from trainers/base on Sep 28, 2024

Conversation

@adamjstewart (Collaborator) commented Sep 24, 2024

Alternative to #2314. @calebrob6 can you see if this also fixes your issue?

I think this makes more sense, as ignore is more of a class attribute than an instance parameter. It also makes it possible to add subclasses with additional ignores. The downside is that it's technically backwards incompatible, so it will have to wait until 0.7.0.
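
To illustrate the pattern (a minimal sketch, assuming a Lightning-style save_hyperparameters call in the base class; the class names and ignored-argument names here are illustrative, not the actual torchgeo source):

from collections.abc import Sequence

from lightning.pytorch import LightningModule

class BaseTask(LightningModule):
    # A class attribute rather than an __init__ parameter, so 'ignore' is
    # never recorded as a hyperparameter and never has to round-trip
    # through load_from_checkpoint.
    ignore: Sequence[str] | str | None = 'weights'

    def __init__(self) -> None:
        super().__init__()
        self.save_hyperparameters(ignore=self.ignore)

class MyTask(BaseTask):
    # A subclass extends the set of ignored arguments simply by
    # overriding the attribute (names here are hypothetical).
    ignore = ('weights', 'aux_loss')

Because the attribute never appears in any __init__ signature, it is never saved to the checkpoint, which is why this sidesteps the load_from_checkpoint problem.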

@adamjstewart added the backwards-incompatible label Sep 24, 2024
@adamjstewart added this to the 0.7.0 milestone Sep 24, 2024
@github-actions bot added the trainers label Sep 24, 2024
@adamjstewart (Collaborator, Author) commented:

I haven't yet found a simple way to test this, so I may skip the tests.

@adamjstewart marked this pull request as ready for review September 24, 2024 16:40
@calebrob6 (Member) commented:

🫢

@calebrob6 (Member) commented:

Yep, it works! Here is the code I used to test, in case that's helpful:

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset

from torchgeo.trainers import SemanticSegmentationTask

class CustomSemanticSegmentationTask(SemanticSegmentationTask):
    # tmax and eta_min are extra hyperparameters; save_hyperparameters()
    # (called in the base class) records them in the checkpoint, and they
    # are exactly what load_from_checkpoint must restore.
    def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def configure_optimizers(self):
        # Override the default optimizer/scheduler with plain Adam.
        return {
            'optimizer': Adam(self.parameters(), lr=self.hparams['lr']),
        }

class TestDataset(Dataset):
    """Random images and binary masks, just enough to drive one epoch."""

    def __init__(self, n_samples=10):
        self.n_samples = n_samples
        self.data = torch.rand((n_samples, 3, 256, 256))
        self.targets = torch.randint(0, 2, (n_samples, 256, 256))

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return {'image': self.data[idx], 'mask': self.targets[idx]}


# tmax=100 is a custom hyperparameter that must survive the checkpoint
# save/load round trip.
task = CustomSemanticSegmentationTask(model='unet', tmax=100)

dataset = TestDataset(10)
train_loader = DataLoader(dataset, batch_size=5)

checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    dirpath='./tmp/',
    filename='my_model-{epoch:02d}',
    save_top_k=1,
    mode='min'
)

trainer = Trainer(
    max_epochs=1,
    callbacks=[checkpoint_callback],
    num_sanity_val_steps=0,
)

trainer.fit(task, train_loader)

# Reload from the checkpoint written above; this is the call that
# previously failed and that this PR fixes.
task = CustomSemanticSegmentationTask.load_from_checkpoint('tmp/my_model-epoch=00.ckpt')

@adamjstewart merged commit 69f91a2 into microsoft:main Sep 28, 2024
19 checks passed
@adamjstewart deleted the trainers/base branch September 28, 2024 13:26