Skip to content

Commit

Permalink
Typing update
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jun 25, 2023
1 parent 321fed2 commit 4147430
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torchgeo/datamodules/vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""NWPU VHR-10 datamodule."""

from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Union

import kornia.augmentation as K
import torch
Expand All @@ -24,7 +24,7 @@ class _AugPipe(Module):
"""Pipeline for applying augmentations sequentially on select data keys."""

def __init__(
self, augs: Callable[[Dict[str, Any]], Dict[str, Any]], batch_size: int
self, augs: Callable[[dict[str, Any]], dict[str, Any]], batch_size: int
) -> None:
"""Initialize a new _AugPipe instance.
Expand All @@ -36,7 +36,7 @@ def __init__(
self.augs = augs
self.batch_size = batch_size

def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Apply the augmentation.
Args:
Expand Down Expand Up @@ -67,7 +67,7 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
return batch


def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
def collate_fn(batch: list[dict[str, Tensor]]) -> dict[str, Any]:
"""Custom object detection collate fn to handle variable boxes.
Args:
Expand All @@ -76,7 +76,7 @@ def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]:
Returns:
batch dict output
"""
output: Dict[str, Any] = {}
output: dict[str, Any] = {}
output["image"] = [sample["image"] for sample in batch]
output["boxes"] = [sample["boxes"] for sample in batch]
output["labels"] = [sample["labels"] for sample in batch]
Expand All @@ -93,7 +93,7 @@ class VHR10DataModule(NonGeoDataModule):
def __init__(
self,
batch_size: int = 64,
patch_size: Union[Tuple[int, int], int] = 512,
patch_size: Union[tuple[int, int], int] = 512,
num_workers: int = 0,
val_split_pct: float = 0.2,
test_split_pct: float = 0.2,
Expand Down

0 comments on commit 4147430

Please sign in to comment.