diff --git a/README.md b/README.md index cba592c294f..8817d730785 100644 --- a/README.md +++ b/README.md @@ -119,12 +119,29 @@ TorchGeo includes a number of [*benchmark datasets*](https://torchgeo.readthedoc If you've used [torchvision](https://pytorch.org/vision) before, these datasets should seem very familiar. In this example, we'll create a dataset for the Northwestern Polytechnical University (NWPU) very-high-resolution ten-class ([VHR-10](https://github.com/chaozhong2010/VHR-10_dataset_coco)) geospatial object detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision. ```python +from torch.utils.data import DataLoader + +from torchgeo.datamodules.utils import collate_fn_detection +from torchgeo.datasets import VHR10 + +# Initialize the dataset dataset = VHR10(root="...", download=True, checksum=True) -dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4) +# Initialize the dataloader with the custom collate function +dataloader = DataLoader( + dataset, + batch_size=128, + shuffle=True, + num_workers=4, + collate_fn=collate_fn_detection, +) + +# Training loop for batch in dataloader: - image = batch["image"] - label = batch["label"] + images = batch["image"] # list of images + boxes = batch["boxes"] # list of boxes + labels = batch["labels"] # list of labels + masks = batch["masks"] # list of masks # train a model, or make predictions using a pre-trained model ```