Skip to content

Commit

Permalink
Update type hints (#2192)
Browse files Browse the repository at this point in the history
* Replace deprecated aliases (Callable, Dict, List, Set, Tuple, Type) with the recommended replacements
* Replace Optional[x] with x | None
* Replace Union[x, y, ...] with x | y | ...
* Add signature to Callable type hints where feasible
* Use Self where applicable
* Remove argument type hints from docstrings (non-exhaustive)
* Fix/improve docstrings
* Fix docs build
  • Loading branch information
AdeelH authored Jul 8, 2024
1 parent 8f9f6d4 commit 116e788
Show file tree
Hide file tree
Showing 181 changed files with 2,056 additions and 2,185 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import List

from rastervision.pipeline.pipeline import Pipeline


class TestPipeline(Pipeline):
commands: List[str] = ['print_msg']
commands: list[str] = ['print_msg']

def print_msg(self):
print(self.config.message)
8 changes: 4 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# full list see the documentation:
# http://www.sphinx-doc.org/en/stable/config

from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING
import sys
from unittest.mock import MagicMock

Expand All @@ -21,13 +21,13 @@ def __getattr__(cls, name):
return MagicMock()


MOCK_MODULES = ['pyproj', 'h5py', 'osgeo']
MOCK_MODULES = ['h5py', 'osgeo']
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)

# -- Allow Jinja templates in non-template .rst files -------------------------


def rstjinja(app: 'Sphinx', docname: str, source: List[str]) -> None:
def rstjinja(app: 'Sphinx', docname: str, source: list[str]) -> None:
"""Allow use of jinja templating in all doc pages.
Adapted from:
Expand Down Expand Up @@ -124,7 +124,7 @@ def setup(app: 'Sphinx') -> None:
autodoc_typehints = 'both'
autodoc_class_signature = 'separated'
autodoc_member_order = 'groupwise'
autodoc_mock_imports = ['torch', 'torchvision', 'pycocotools', 'geopandas']
autodoc_mock_imports = ['pycocotools']
#########################

#########################
Expand Down
13 changes: 6 additions & 7 deletions integration_tests/integration_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python

from typing import List
from os.path import join, dirname, abspath, isfile
import math
import traceback
Expand Down Expand Up @@ -108,7 +107,7 @@ def get_actual_eval_path(test_id: str, tmp_dir: str) -> str:


def check_eval_item(test_id: str, test_cfg: dict, expected_item: dict,
actual_item: dict) -> List[TestError]:
actual_item: dict) -> list[TestError]:
errors = []
f1_threshold = 0.05
class_name = expected_item['class_name']
Expand All @@ -125,7 +124,7 @@ def check_eval_item(test_id: str, test_cfg: dict, expected_item: dict,
return errors


def check_eval(test_id: str, test_cfg: dict, tmp_dir: str) -> List[TestError]:
def check_eval(test_id: str, test_cfg: dict, tmp_dir: str) -> list[TestError]:
errors = []

actual_eval_path = get_actual_eval_path(test_id, tmp_dir)
Expand All @@ -152,7 +151,7 @@ def check_eval(test_id: str, test_cfg: dict, tmp_dir: str) -> List[TestError]:

def test_model_bundle_validation(pipeline, test_id: str, test_cfg: dict,
tmp_dir: str,
image_uri: str) -> List[TestError]:
image_uri: str) -> list[TestError]:
console_info('Checking predict command validation...')
errors = []
model_bundle_uri = pipeline.get_model_bundle_uri()
Expand All @@ -171,7 +170,7 @@ def test_model_bundle_validation(pipeline, test_id: str, test_cfg: dict,

def test_model_bundle_results(pipeline, test_id: str, test_cfg: dict,
tmp_dir: str, scenes: list,
scenes_to_uris: dict) -> List[TestError]:
scenes_to_uris: dict) -> list[TestError]:
console_info('Checking model bundle produces same results...')
errors = []
model_bundle_uri = pipeline.get_model_bundle_uri()
Expand Down Expand Up @@ -213,7 +212,7 @@ def test_model_bundle(pipeline,
test_id: str,
test_cfg: dict,
tmp_dir: str,
check_channel_order: bool = False) -> List[TestError]:
check_channel_order: bool = False) -> list[TestError]:
# Check the model bundle.
# This will only work with raster_sources that
# have a single URI.
Expand Down Expand Up @@ -250,7 +249,7 @@ def test_model_bundle(pipeline,
return errors


def run_test(test_id: str, test_cfg: dict, tmp_dir: str) -> List[TestError]:
def run_test(test_id: str, test_cfg: dict, tmp_dir: str) -> list[TestError]:
msg = f'\nRunning test: {test_id}'
console_info(msg, bold=True)
console_info('With params:')
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any
import logging
import os
import uuid
Expand Down Expand Up @@ -33,7 +33,7 @@ class AWSBatchRunner(Runner):
def run(self,
cfg_json_uri: str,
pipeline: 'Pipeline',
commands: List[str],
commands: list[str],
num_splits: int = 1,
pipeline_run_name: str = 'raster-vision'): # pragma: no cover
parent_job_ids = []
Expand Down Expand Up @@ -65,7 +65,7 @@ def build_cmd(self,
pipeline: 'Pipeline',
num_splits: int = 1,
pipeline_run_name: str = 'raster-vision'
) -> Tuple[List[str], Dict[str, Any]]:
) -> tuple[list[str], dict[str, Any]]:

verbosity = rv_config.get_verbosity_cli_opt()

Expand Down Expand Up @@ -105,15 +105,15 @@ def get_split_ind(self) -> int:
return int(os.environ.get('AWS_BATCH_JOB_ARRAY_INDEX', 0))

def run_command(self,
cmd: List[str],
job_name: Optional[str] = None,
cmd: list[str],
job_name: str | None = None,
debug: bool = False,
attempts: int = 1,
parent_job_ids: Optional[List[str]] = None,
num_array_jobs: Optional[int] = None,
parent_job_ids: list[str] | None = None,
num_array_jobs: int | None = None,
use_gpu: bool = False,
job_queue: Optional[str] = None,
job_def: Optional[str] = None,
job_queue: str | None = None,
job_def: str | None = None,
**kwargs) -> str: # pragma: no cover
"""Submit a command as a job to AWS Batch.
Expand Down
4 changes: 2 additions & 2 deletions rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterator, Tuple
from typing import Any, Iterator
import io
import os
import subprocess
Expand Down Expand Up @@ -132,7 +132,7 @@ def matches_uri(uri: str, mode: str) -> bool:
return parsed_uri.scheme == 's3'

@staticmethod
def parse_uri(uri: str) -> Tuple[str, str]:
def parse_uri(uri: str) -> tuple[str, str]:
"""Parse bucket name and key from an S3 URI."""
parsed_uri = urlparse(uri)
bucket, key = parsed_uri.netloc, parsed_uri.path[1:]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING
from os.path import join, basename
import logging
from pprint import pprint
Expand Down Expand Up @@ -67,9 +67,9 @@ class AWSSageMakerRunner(Runner):
def run(self,
cfg_json_uri: str,
pipeline: 'Pipeline',
commands: List[str],
commands: list[str],
num_splits: int = 1,
cmd_prefix: List[str] = [
cmd_prefix: list[str] = [
'python', '-m', 'rastervision.pipeline.cli'
],
pipeline_run_name: str = 'rv'):
Expand All @@ -95,9 +95,9 @@ def run(self,
def build_pipeline(self,
cfg_json_uri: str,
pipeline: 'Pipeline',
commands: List[str],
commands: list[str],
num_splits: int = 1,
cmd_prefix: List[str] = [
cmd_prefix: list[str] = [
'python', '-m', 'rastervision.pipeline.cli'
],
pipeline_run_name: str = 'rv') -> 'SageMakerPipeline':
Expand Down Expand Up @@ -213,7 +213,7 @@ def build_step(self,
pipeline: 'RVPipeline',
step_name: str,
job_name: str,
cmd: List[str],
cmd: list[str],
role: str,
image_uri: str,
instance_type: str,
Expand All @@ -222,7 +222,7 @@ def build_step(self,
instance_count: int = 1,
max_wait: int = DEFAULT_MAX_RUN_TIME,
max_run: int = DEFAULT_MAX_RUN_TIME,
**kwargs) -> Union['TrainingStep', 'ProcessingStep']:
**kwargs) -> 'TrainingStep | ProcessingStep':
"""Build appropriate SageMaker pipeline step.
If ``step_name=='train'``, builds a :class:`.TrainingStep`. Otherwise,
Expand All @@ -247,8 +247,7 @@ def build_step(self,
max_run=max_run,
**kwargs,
)
step_args: Optional['_JobStepArguments'] = estimator.fit(
wait=False)
step_args: '_JobStepArguments | None' = estimator.fit(wait=False)
step = TrainingStep(job_name, step_args=step_args)
else:
from sagemaker.processing import Processor
Expand All @@ -263,38 +262,38 @@ def build_step(self,
entrypoint=cmd,
**kwargs,
)
step_args: Optional['_JobStepArguments'] = step_processor.run(
step_args: '_JobStepArguments | None' = step_processor.run(
wait=False)
step = ProcessingStep(job_name, step_args=step_args)

return step

def run_command(self,
cmd: List[str],
cmd: list[str],
use_gpu: bool = False,
image_uri: Optional[str] = None,
instance_type: Optional[str] = None,
role: Optional[str] = None,
job_name: Optional[str] = None,
sagemaker_session: Optional['Session'] = None) -> None:
image_uri: str | None = None,
instance_type: str | None = None,
role: str | None = None,
job_name: str | None = None,
sagemaker_session: 'Session | None' = None) -> None:
"""Run a single command as a SageMaker processing job.
Args:
cmd (List[str]): The command to run.
cmd (list[str]): The command to run.
use_gpu (bool): Use the GPU instance type and image from the
Everett config. This is ignored if image_uri and instance_type
are provided. Defaults to False.
image_uri (Optional[str]): URI of docker image to use. If not
image_uri (str | None): URI of docker image to use. If not
provided, will be picked up from Everett config.
Defaults to None.
instance_type (Optional[str]): AWS instance type to use. If not
instance_type (str | None): AWS instance type to use. If not
provided, will be picked up from Everett config.
Defaults to None.
role (Optional[str]): AWS IAM role with SageMaker permissions. If
role (str | None): AWS IAM role with SageMaker permissions. If
not provided, will be picked up from Everett config.
Defaults to None.
job_name (Optional[str]): Optional job name. Defaults to None.
sagemaker_session (Optional[Session]): SageMaker session.
job_name (str | None): Optional job name. Defaults to None.
sagemaker_session (Session | None): SageMaker session.
Defaults to None.
"""
from sagemaker.processing import Processor
Expand Down Expand Up @@ -331,8 +330,8 @@ def _build_pytorch_estimator(self,
sagemaker_session: 'PipelineSession',
use_spot_instances: bool = False,
instance_count: int = 1,
distribution: Optional[dict] = None,
job_name: Optional[str] = None,
distribution: dict | None = None,
job_name: str | None = None,
**kwargs):
from sagemaker.pytorch import PyTorch
from rastervision.aws_s3.s3_file_system import S3FileSystem
Expand Down
3 changes: 1 addition & 2 deletions rastervision_core/rastervision/core/analyzer/analyzer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import List
from abc import (ABC, abstractmethod)

from rastervision.core.data import Scene
Expand All @@ -11,5 +10,5 @@ class Analyzer(ABC):
"""

@abstractmethod
def process(self, scenes: List[Scene], tmp_dir: str):
def process(self, scenes: list[Scene], tmp_dir: str):
"""Process scenes and save result."""
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List, Iterable, Optional, Tuple
from typing import TYPE_CHECKING, Iterable

from rastervision.pipeline.config import register_config, Config

Expand All @@ -10,11 +10,11 @@
class AnalyzerConfig(Config):
"""Configure an :class:`.Analyzer`."""

def build(self, scene_group: Optional[Tuple[str, Iterable[str]]] = None
def build(self, scene_group: tuple[str, Iterable[str]] | None = None
) -> 'Analyzer':
pass

def get_bundle_filenames(self) -> List[str]:
def get_bundle_filenames(self) -> list[str]:
"""Returns the names of files that should be included in a model bundle.
The files are assumed to be in the analyze/ directory generated by the analyze
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Optional
from typing import Iterable

from rastervision.core.analyzer import Analyzer
from rastervision.core.raster_stats import RasterStats
Expand All @@ -9,10 +9,10 @@ class StatsAnalyzer(Analyzer):
"""Compute imagery statistics of scenes."""

def __init__(self,
stats_uri: Optional[str] = None,
stats_uri: str | None = None,
sample_prob: float = 0.1,
chip_sz: int = 300,
nodata_value: Optional[float] = 0):
nodata_value: float | None = 0):
self.stats_uri = stats_uri
self.sample_prob = sample_prob
self.chip_sz = chip_sz
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from typing import TYPE_CHECKING, Iterable
from os.path import join

from rastervision.pipeline.config import register_config, ConfigError, Field
Expand All @@ -16,13 +16,13 @@ class StatsAnalyzerConfig(AnalyzerConfig):
be used to normalize chips read from them.
"""

output_uri: Optional[str] = Field(
output_uri: str | None = Field(
None,
description='URI of directory where stats will be saved. '
'Stats for a scene-group will be save in a JSON file at '
'<output_uri>/<scene-group-name>/stats.json. If None, and this Config '
'is part of an RVPipeline, this field will be auto-generated.')
sample_prob: Optional[float] = Field(
sample_prob: float | None = Field(
0.1,
description=(
'The probability of using a random window for computing statistics. '
Expand All @@ -31,20 +31,20 @@ class StatsAnalyzerConfig(AnalyzerConfig):
300,
description='Chip size to use when sampling chips to compute stats '
'from.')
nodata_value: Optional[float] = Field(
nodata_value: float | None = Field(
0,
description='NODATA value. If set, these pixels will be ignored when '
'computing stats.')

def update(self, pipeline: Optional['RVPipelineConfig'] = None) -> None:
def update(self, pipeline: 'RVPipelineConfig | None' = None) -> None:
if pipeline is not None and self.output_uri is None:
self.output_uri = join(pipeline.analyze_uri, 'stats')

def validate_config(self):
if self.sample_prob > 1 or self.sample_prob <= 0:
raise ConfigError('sample_prob must be <= 1 and > 0')

def build(self, scene_group: Optional[Tuple[str, Iterable[str]]] = None
def build(self, scene_group: tuple[str, Iterable[str]] | None = None
) -> StatsAnalyzer:
if scene_group is None:
output_uri = join(self.output_uri, f'stats.json')
Expand Down
Loading

0 comments on commit 116e788

Please sign in to comment.