diff --git a/rastervision_core/rastervision/core/data/utils/misc.py b/rastervision_core/rastervision/core/data/utils/misc.py index a26328d32..914d18633 100644 --- a/rastervision_core/rastervision/core/data/utils/misc.py +++ b/rastervision_core/rastervision/core/data/utils/misc.py @@ -13,8 +13,8 @@ log = logging.getLogger(__name__) -def color_to_triple( - color: str | Sequence | None = None) -> tuple[int, int, int]: +def color_to_triple(color: str | Sequence | None = None + ) -> list[str] | tuple[int, int, int]: """Given a PIL ImageColor string, return a triple of integers representing the red, green, and blue values. diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py index f9762f83b..ce370a92c 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py @@ -1181,8 +1181,9 @@ def build_dataset(self, ds = self.cfg.data.build_dataset(split=split, tmp_dir=self.tmp_dir) return ds - def build_dataloaders(self, distributed: bool | None = None - ) -> tuple[DataLoader, DataLoader, DataLoader]: + def build_dataloaders( + self, distributed: bool | None = None + ) -> tuple[DataLoader, DataLoader, DataLoader | None]: """Build DataLoaders for train, validation, and test splits.""" if distributed is None: distributed = self.distributed diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py index c9326ad1e..1ed445586 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py @@ -680,7 +680,7 @@ def num_classes(self): @field_validator('augmentors') @classmethod - def validate_augmentors(cls, v: list[str]) -> str: + def validate_augmentors(cls, v: list[str]) -> list[str]: for aug_name in v: if aug_name not in augmentors: raise ConfigError(f'Unsupported augmentor "{aug_name}"') @@ -862,8 +862,7 @@ def validate_group_uris(self) -> Self: def _build_dataset(self, dirs: Iterable[str], - tf: A.BasicTransform | None = None - ) -> tuple[Dataset, Dataset, Dataset]: + tf: A.BasicTransform | None = None) -> Dataset: """Make datasets for a single split. Args: @@ -1224,7 +1223,7 @@ def _build_dataset(self, split: Literal['train', 'valid', 'test'], tf: A.BasicTransform | None = None, tmp_dir: str | None = None, - **kwargs) -> tuple[Dataset, Dataset, Dataset]: + **kwargs) -> Dataset: """Make training, validation, and test datasets. Args: