Skip to content

Commit

Permalink
Update info on LightningCLI (#1628)
Browse files Browse the repository at this point in the history
* Add info

* Address comments
  • Loading branch information
robmarkcole authored Oct 6, 2023
1 parent 7faf0bb commit c83181f
Showing 1 changed file with 56 additions and 2 deletions.
58 changes: 56 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,66 @@ trainer.fit(model=task, datamodule=datamodule)

<img src="https://raw.githubusercontent.com/microsoft/torchgeo/main/images/inria.png" alt="Building segmentations produced by a U-Net model trained on the Inria Aerial Image Labeling dataset"/>

In our GitHub repo, we provide `train.py` and `evaluate.py` scripts to train and evaluate the performance of models using these datamodules and trainers. These scripts are configurable via the command line and/or via YAML configuration files. See the [conf](https://github.com/microsoft/torchgeo/blob/main/conf) directory for example configuration files that can be customized for different training runs.
TorchGeo also supports command-line interface training using [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html). It can be invoked in two ways:

```console
$ python train.py config_file=conf/landcoverai.yaml
# If torchgeo has been installed
torchgeo
# If torchgeo has been installed, or if it has been cloned to the current directory
python3 -m torchgeo
```

It supports command-line configuration or YAML/JSON config files. Valid options can be found from the help messages:

```console
# See valid stages
torchgeo --help
# See valid trainer options
torchgeo fit --help
# See valid model options
torchgeo fit --model.help ClassificationTask
# See valid data options
torchgeo fit --data.help EuroSAT100DataModule
```

Using the following config file:
```yaml
trainer:
max_epochs: 20
model:
class_path: ClassificationTask
init_args:
model: "resnet18"
in_channels: 13
num_classes: 10
data:
class_path: EuroSAT100DataModule
init_args:
batch_size: 8
dict_kwargs:
download: true
```
we can see the script in action:
```console
# Train and validate a model
torchgeo fit --config config.yaml
# Validate-only
torchgeo validate --config config.yaml
# Calculate and report test accuracy
torchgeo test --config config.yaml --trainer.ckpt_path=...
```

It can also be imported and used in a Python script if you need to extend it to add new features:

```python
from torchgeo.main import main

main(["fit", "--config", "config.yaml"])
```

See the [Lightning documentation](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) for more details.

## Citation

If you use this software in your work, please cite our [paper](https://dl.acm.org/doi/10.1145/3557915.3560953):
Expand Down

0 comments on commit c83181f

Please sign in to comment.