Skip to content

Commit

Permalink
Use smarter default padding value for sliding windows (#2190)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH authored Jul 5, 2024
1 parent adac1a2 commit 8f9f6d4
Show file tree
Hide file tree
Showing 10 changed files with 426 additions and 80 deletions.
155 changes: 142 additions & 13 deletions docs/usage/tutorials/sampling_training_data.ipynb

Large diffs are not rendered by default.

66 changes: 33 additions & 33 deletions rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rasterio.windows import Window as RioWindow

from rastervision.pipeline.utils import repr_with_args
from rastervision.core.utils.misc import (calculate_required_padding,
ensure_tuple)

if TYPE_CHECKING:
from shapely.geometry import MultiPolygon
Expand Down Expand Up @@ -355,54 +357,52 @@ def pad(self, ymin: int, xmin: int, ymax: int, xmax: int) -> 'Box':
def copy(self) -> 'Box':
return Box(*self)

def get_windows(self,
size: Union[PosInt, Tuple[PosInt, PosInt]],
stride: Union[PosInt, Tuple[PosInt, PosInt]],
padding: Optional[Union[NonNegInt, Tuple[
NonNegInt, NonNegInt]]] = None,
pad_direction: Literal['both', 'start', 'end'] = 'end'
) -> List['Box']:
"""Returns a list of boxes representing windows generated using a
sliding window traversal with the specified size, stride, and
padding.
def get_windows(
self,
size: PosInt | tuple[PosInt, PosInt],
stride: PosInt | tuple[PosInt, PosInt],
padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
pad_direction: Literal['both', 'start', 'end'] = 'end'
) -> list['Box']:
"""Return sliding windows for given size, stride, and padding.
Each of size, stride, and padding can be either a positive int or
a tuple `(vertical-component, horizontal-component)` of positive ints.
a tuple ``(vertical-component, horizontal-component)`` of positive
ints.
Padding currently only applies to the right and bottom edges.
If ``padding`` is not specified and ``stride <= size``, it will be
automatically calculated such that the windows cover the entire extent.
Args:
size (Union[PosInt, Tuple[PosInt, PosInt]]): Size (h, w) of the
windows.
stride (Union[PosInt, Tuple[PosInt, PosInt]]): Step size between
windows. Can be 2-tuple (h_step, w_step) or positive int.
padding (Optional[Union[PosInt, Tuple[PosInt, PosInt]]], optional):
Optional padding to accommodate windows that overflow the
size: Size (h, w) of the windows.
stride: Step size between windows. Can be 2-tuple (h_step, w_step)
or positive int.
padding: Optional padding to accommodate windows that overflow the
extent. Can be 2-tuple (h_pad, w_pad) or non-negative int.
If None, will be set to (size[0]//2, size[1]//2).
Defaults to None.
pad_direction (Literal['both', 'start', 'end']): If 'end', only pad
ymax and xmax (bottom and right). If 'start', only pad ymin and
xmin (top and left). If 'both', pad all sides. Has no effect if
padding is zero. Defaults to 'end'.
If None, will be automatically calculated such that the windows
cover the entire extent. Defaults to ``None``.
pad_direction: If ``'end'``, only pad ymax and xmax (bottom and
right). If ``'start'``, only pad ymin and xmin (top and left).
If ``'both'``, pad all sides. If ``'both'`` pad all sides. Has
no effect if padding is zero. Defaults to ``'end'``.
Returns:
List[Box]: List of Box objects.
List of Box objects.
"""
if not isinstance(size, tuple):
size = (size, size)

if not isinstance(stride, tuple):
stride = (stride, stride)
size: tuple[PosInt, PosInt] = ensure_tuple(size)
stride: tuple[PosInt, PosInt] = ensure_tuple(stride)

if size[0] <= 0 or size[1] <= 0 or stride[0] <= 0 or stride[1] <= 0:
raise ValueError('size and stride must be positive.')

if padding is None:
padding = (size[0] // 2, size[1] // 2)
if size[0] < stride[0] or size[1] < stride[1]:
padding = (0, 0)
else:
padding = calculate_required_padding(self.size, size, stride,
pad_direction)

if not isinstance(padding, tuple):
padding = (padding, padding)
padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)

if padding[0] < 0 or padding[1] < 0:
raise ValueError('padding must be non-negative.')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class WindowSamplingConfig(Config):
'end',
description='If "end", only pad ymax and xmax (bottom and right). '
'If "start", only pad ymin and xmin (top and left). If "both", '
'pad all sides. Has no effect if paddiong is zero. Defaults to "end".')
'pad all sides. Has no effect if padding is zero. Defaults to "end".')
size_lims: Optional[Tuple[PosInt, PosInt]] = Field(
None,
description='[min, max) interval from which window sizes will be '
Expand Down
1 change: 1 addition & 0 deletions rastervision_core/rastervision/core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

from rastervision.core.utils.stac import *
from rastervision.core.utils.types import *
from rastervision.core.utils.misc import *
89 changes: 89 additions & 0 deletions rastervision_core/rastervision/core/utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Literal, TypeVar
import math

from pydantic.types import NonNegativeInt as NonNegInt, PositiveInt as PosInt

T = TypeVar('T')


def ensure_tuple(x: T, n: int = 2) -> tuple[T, ...]:
"""Convert to n-tuple if not already an n-tuple."""
if isinstance(x, tuple):
if len(x) != n:
raise ValueError()
return x
return tuple([x] * n)


def calculate_required_padding(
extent_sz: PosInt | tuple[PosInt, PosInt],
chip_sz: PosInt | tuple[PosInt, PosInt],
stride: PosInt | tuple[PosInt, PosInt],
pad_direction: Literal['start', 'end', 'both'],
crop_sz: NonNegInt | None = None) -> tuple[NonNegInt, NonNegInt]:
"""Calculate min padding to ensure sliding windows cover all pixels.
Args:
extent_sz: Extent size as (h, w) tuple.
chip_sz: Chip size as (h, w) tuple.
stride: Stride size as (h_step, w_step) tuple.
pad_direction: One of: 'start', 'end', 'both'.
crop_sz: When cropping out window edges during semantic segmentation
prediction, pixels at the edges of the scene can be left with no
prediction if there is not enough padding. When ``crop_sz`` is
specified, the calculated padding takes this into account. Has no
effect if zero. Defaults to ``None``.
Returns:
Padding as (h_pad, w_pad) tuple.
"""
extent_sz: tuple[PosInt, PosInt] = ensure_tuple(extent_sz)
chip_sz: tuple[PosInt, PosInt] = ensure_tuple(chip_sz)
stride: tuple[PosInt, PosInt] = ensure_tuple(stride)

img_h, img_w = extent_sz
chip_h, chip_w = chip_sz
stride_h, stride_w = stride

if chip_h < stride_h or chip_w < stride_w:
raise ValueError(
f'chip_sz ({chip_sz}) cannot be less than stride ({stride}).')

if crop_sz is not None and crop_sz > 0:
if pad_direction != 'both':
raise ValueError(
'crop_sz is only supported with pad_direction="both"')
cropped_chip_h = chip_h - 2 * crop_sz
cropped_chip_w = chip_w - 2 * crop_sz
if cropped_chip_h < stride_h or cropped_chip_w < stride_w:
raise ValueError(
f'Cropped chip size ({(cropped_chip_h, cropped_chip_w)}) '
f'cannot be less than stride ({stride}).')
h_padding, w_padding = calculate_required_padding(
extent_sz,
(cropped_chip_h, cropped_chip_w),
stride,
pad_direction=pad_direction,
crop_sz=None,
)
h_padding += 2 * crop_sz
w_padding += 2 * crop_sz
else:
if img_h > chip_h:
num_strides = math.ceil((img_h - chip_h) / stride_h)
max_val = chip_h + num_strides * stride_h
h_padding = max_val - img_h
else:
h_padding = chip_h - img_h
if img_w > chip_w:
num_strides = math.ceil((img_w - chip_w) / stride_w)
max_val = chip_w + num_strides * stride_w
w_padding = max_val - img_w
else:
w_padding = chip_w - img_w
if pad_direction == 'both':
h_padding = math.ceil(h_padding / 2)
w_padding = math.ceil(w_padding / 2)

padding = (h_padding, w_padding)
return padding
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple, Union
import logging

import numpy as np
Expand All @@ -8,6 +8,7 @@
from shapely.ops import unary_union

from rastervision.core.box import Box
from rastervision.core.utils import ensure_tuple
from rastervision.core.data import Scene
from rastervision.core.data.utils import AoiSampler
from rastervision.pytorch_learner.learner_config import PosInt, NonNegInt
Expand All @@ -19,17 +20,6 @@

log = logging.getLogger(__name__)

T = TypeVar('T')


def _to_tuple(x: T, n: int = 2) -> Tuple[T, ...]:
"""Convert to n-tuple if not already an n-tuple."""
if isinstance(x, tuple):
if len(x) != n:
raise ValueError()
return x
return tuple([x] * n)


class AlbumentationsDataset(Dataset):
"""An adapter to use arbitrary datasets with albumentations transforms."""
Expand Down Expand Up @@ -160,7 +150,7 @@ def __init__(
self.out_size = None

if out_size is not None:
self.out_size = _to_tuple(out_size)
self.out_size: tuple[PosInt, PosInt] = ensure_tuple(out_size)
transform = self.append_resize_transform(transform, self.out_size)

super().__init__(
Expand Down Expand Up @@ -224,7 +214,7 @@ def __init__(
pad_direction (Literal['both', 'start', 'end']): If 'end', only pad
ymax and xmax (bottom and right). If 'start', only pad ymin and
xmin (top and left). If 'both', pad all sides. Has no effect if
paddiong is zero. Defaults to 'end'.
padding is zero. Defaults to 'end'.
within_aoi: If True and if the scene has an AOI, only sample
windows that lie fully within the AOI. If False, windows only
partially intersecting the AOI will also be allowed.
Expand Down Expand Up @@ -256,8 +246,8 @@ def __init__(
normalize=normalize,
to_pytorch=to_pytorch,
return_window=return_window)
self.size = _to_tuple(size)
self.stride = _to_tuple(stride)
self.size: tuple[PosInt, PosInt] = ensure_tuple(size)
self.stride: tuple[PosInt, PosInt] = ensure_tuple(stride)
self.padding = padding
self.pad_direction = pad_direction
self.init_windows()
Expand Down Expand Up @@ -402,7 +392,7 @@ def __init__(self,
else:
max_h, max_w = h_lims[1], w_lims[1]
padding = (max_h // 2, max_w // 2)
padding = _to_tuple(padding)
padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)

if max_windows is None:
max_windows = np.iinfo('int').max
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from rastervision.core.data import (ChipClassificationLabels,
ObjectDetectionLabels,
SemanticSegmentationLabels)
from rastervision.core.utils import calculate_required_padding

if TYPE_CHECKING:
import numpy as np
Expand Down Expand Up @@ -98,13 +99,25 @@ def predict_scene_ss(learner: 'SemanticSegmentationLearner', scene: 'Scene',
raw_out = label_store.smooth_output

base_tf, _ = learner.cfg.data.get_data_transforms()
pad_direction = 'end' if crop_sz is None else 'both'
ds = SemanticSegmentationSlidingWindowGeoDataset(
scene,
size=chip_sz,
stride=stride,
pad_direction=pad_direction,
transform=base_tf)
if crop_sz is None:
ds = SemanticSegmentationSlidingWindowGeoDataset(
scene, size=chip_sz, stride=stride, transform=base_tf)
else:
padding = calculate_required_padding(
extent_sz=scene.extent.size,
chip_sz=(chip_sz, chip_sz),
stride=(stride, stride),
pad_direction='both',
crop_sz=crop_sz,
)
ds = SemanticSegmentationSlidingWindowGeoDataset(
scene,
size=chip_sz,
stride=stride,
padding=padding,
pad_direction='both',
transform=base_tf,
)

predictions: Iterator['np.ndarray'] = learner.predict_dataset(
ds,
Expand Down
10 changes: 9 additions & 1 deletion tests/core/test_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def test_get_windows(self):

extent = Box(0, 0, 100, 100)
windows = extent.get_windows(size=10, stride=5)
self.assertEqual(len(windows), 20 * 20)
self.assertEqual(len(windows), 19 * 19)

extent = Box(0, 0, 20, 20)
windows = set(extent.get_windows(size=10, stride=10))
Expand Down Expand Up @@ -365,6 +365,14 @@ def test_get_windows(self):
msg = f'{extent!r}.get_windows({arg_str})'
self.assertSetEqual(windows, expected_windows, msg=msg)

# default padding = (0, 0) if stride > size
extent = Box(0, 0, 10, 10)
args = dict(size=5, stride=6, pad_direction='end')
windows = extent.get_windows(**args)
arg_str = ', '.join(f'{k}={v!r}' for k, v in args.items())
msg = f'{extent!r}.get_windows({arg_str})'
self.assertEqual(len(windows), 1, msg=msg)

args = dict(size=5, stride=3, padding=2, pad_direction='invalid')
self.assertRaises(ValueError, lambda: extent.get_windows(**args))

Expand Down
Loading

0 comments on commit 8f9f6d4

Please sign in to comment.