Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update info on LightningCLI #1628

Merged
merged 2 commits into from
Oct 6, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,66 @@ trainer.fit(model=task, datamodule=datamodule)

<img src="https://raw.githubusercontent.com/microsoft/torchgeo/main/images/inria.png" alt="Building segmentations produced by a U-Net model trained on the Inria Aerial Image Labeling dataset"/>

In our GitHub repo, we provide `train.py` and `evaluate.py` scripts to train and evaluate the performance of models using these datamodules and trainers. These scripts are configurable via the command line and/or via YAML configuration files. See the [conf](https://github.com/microsoft/torchgeo/blob/main/conf) directory for example configuration files that can be customized for different training runs.
TorchGeo also supports command-line interface training using the [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html). It can be invoked in two ways:
robmarkcole marked this conversation as resolved.
Show resolved Hide resolved

```console
$ python train.py config_file=conf/landcoverai.yaml
# If torchgeo has been installed
torchgeo
# If torchgeo has been installed, or if it has been cloned to the current directory
python3 -m torchgeo
```

It supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages:

```console
# See valid stages
torchgeo --help
# See valid trainer options
torchgeo fit --help
# See valid model options
torchgeo fit --model.help ClassificationTask
# See valid data options
torchgeo fit --data.help EuroSAT100DataModule
```

Using the following config file:
```yaml
trainer:
max_epochs: 20
model:
class_path: ClassificationTask
init_args:
model: "resnet18"
in_channels: 13
num_classes: 10
data:
class_path: EuroSAT100DataModule
init_args:
batch_size: 8
dict_kwargs:
download: true
```

We can see the script in action:
robmarkcole marked this conversation as resolved.
Show resolved Hide resolved
```console
# Train and validate a model
torchgeo fit --config config.yaml
# Validate-only
torchgeo validate --config config.yaml
# Calculate and report test accuracy
torchgeo test --config config.yaml --trainer.ckpt_path=...
```

It can also be imported and used in a Python script if you need to extend it to add new features:

```python
from torchgeo.main import main

main(["fit", "--config", "config.yaml"])
```

See the [Lightning documentation](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) for more details.

## Citation

If you use this software in your work, please cite our [paper](https://dl.acm.org/doi/10.1145/3557915.3560953):
Expand Down
Loading