
Lightning 1.5 auto adds "fit" key to top level of CLI config breaking jsonargparse configuration check #10460

Closed
rbracco opened this issue Nov 10, 2021 · 4 comments

rbracco (Contributor) commented Nov 10, 2021

🐛 Bug

Since CLI subcommands were added in 1.5, my config files break after a run because a new top-level "fit" key is added. I run python train_scripts/trainer.py fit --config config.yaml as recommended in the docs and the run itself works, but config.yaml is overwritten and goes from the format

trainer:
  gpus: 1
data:
  num_workers: 8

to having it all nested under the 'fit' key as follows:

fit:
  trainer:
    gpus: 1
  data:
    num_workers: 8

Then when I run python train_scripts/trainer.py fit --config config.yaml again, it fails with the jsonargparse error message trainer.py: error: Configuration check failed :: Key "data.val_manifest" is required but not included in config object or its value is None., presumably because that value is now found under "fit.data.val_manifest".
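As a temporary workaround I can un-nest the file by hand before re-running. A minimal sketch, assuming PyYAML is installed and that the only change to the file was the added "fit" level (the standalone script and the in-place rewrite of config.yaml are just for illustration, not anything from Lightning):

import yaml  # PyYAML

# Load the config that the CLI rewrote.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

# If everything got nested under the subcommand, lift it back to the top level
# so `fit --config config.yaml` passes the jsonargparse configuration check again.
if isinstance(cfg, dict) and set(cfg) == {"fit"}:
    cfg = cfg["fit"]

with open("config.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)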

To Reproduce

Expected behavior

Running python train_scripts/trainer.py fit --config config.yaml should not clobber the config file. The expected behavior per the docs is that config files passed after the subcommand should not need a top-level subcommand key.

[Screenshot of the Lightning CLI docs showing that config files passed after the subcommand do not include the subcommand key]

Environment

  • PyTorch Lightning Version (e.g., 1.5.1):
  • PyTorch Version (e.g., 1.10.0):
  • Python version: 3.8.10
  • OS (e.g., Linux): Linux
  • CUDA/cuDNN version: 11.3
  • How you installed PyTorch (conda, pip, source): pip
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information: I updated both pytorch-lightning[extras] and jsonargparse
rbracco added the bug and help wanted labels on Nov 10, 2021
carmocca (Contributor) commented

Mind sharing your LightningCLI implementation and how you instantiate it?

carmocca added the argparse label on Nov 11, 2021
carmocca self-assigned this on Nov 11, 2021
rbracco (Contributor, Author) commented Nov 11, 2021

Sure, I've included it below. Thanks for taking a look.

class FinetuneCLI(LightningCLI):
    def add_arguments_to_parser(self, parser) -> None:
        parser.add_argument("--loss_function")
        parser.add_argument("--decoder")
        parser.add_lightning_class_args(FinetuneEncoderDecoder, "finetuner")
        parser.set_defaults(
            {
                "trainer.max_epochs": 10,
                "trainer.gpus": -1,
                "trainer.log_every_n_steps": 10,
                "trainer.flush_logs_every_n_steps": 50,
            }
        )
        return super().add_arguments_to_parser(parser)

    def before_fit(self):
        self.model.loss_function = getattr(loss_functions, self.config["loss_function"])
        DecoderClass = getattr(blocks, self.config["decoder"])
        num_classes = len(self.model.text_transform.vocab)
        model_arch = str(type(self.model))
        self.model.decoder = DecoderClass(1024, num_classes)


def get_quartznet(
    checkpoint: str = "QuartzNet15x5Base_En",
    learning_rate: float = 3e-4,
    labels: list = None,
) -> QuartznetModule:
    checkpoint = QuartznetCheckpoint.from_string(checkpoint)
    module = QNPronounceModule.load_from_nemo(checkpoint)
    module.hparams.optim_cfg.learning_rate = learning_rate
    text_config = TextTransformConfig(labels)
    module.change_vocab(text_config)
    return module


FinetuneCLI(
    model_class=get_quartznet,
    datamodule_class=PronounceDatamodule,
    save_config_overwrite=True,
)

carmocca (Contributor) commented

So you are saying that running

python script.py fit --print_config > config.yaml

saves the correct config, which can be used later with

python script.py fit --config config.yaml

but at some point config.yaml gets overwritten with the config format that includes the subcommand.

If that's the case, I haven't been able to reproduce it with this script:

import torch
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)


class FinetuneCLI(LightningCLI):
    def add_arguments_to_parser(self, parser) -> None:
        parser.set_defaults({"trainer.max_epochs": 10})

    def before_fit(self):
        ...


FinetuneCLI(
    model_class=BoringModel,
    save_config_overwrite=True,
)
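For reference, the exact commands I ran against it (assuming the snippet above is saved as script.py) are the two from earlier in the thread:

python script.py fit --print_config > config.yaml
python script.py fit --config config.yaml

and on my end config.yaml keeps its original structure, with no top-level fit key, after both runs.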

rbracco (Contributor, Author) commented Nov 14, 2021

> So you are saying that running
>
> python script.py fit --print_config > config.yaml
>
> saves the correct config, which can be used later with
>
> python script.py fit --config config.yaml
>
> but at some point config.yaml gets overwritten with the config format that includes the subcommand.

Looking at this, I think the issue may be that I was using a config file I had maintained from a previous version of Lightning, and the rules changed in 1.5. I already rolled back and have everything working, but eventually I will upgrade again and will either reopen if the issue persists or update here confirming that the problem was the initial config formatting. Thanks for looking into this.
