Skip to content

Commit

Permalink
Added boundary padding when spatial_pad > 0 also.
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Nov 25, 2024
1 parent 3284f97 commit 3173823
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 121 deletions.
140 changes: 62 additions & 78 deletions sup3r/pipeline/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,15 @@ def s1_hr_crop_slices(self):
if self._s1_hr_crop_slices is None:
self._s1_hr_crop_slices = self.get_hr_cropped_slices(
unpadded_slices=self.s1_lr_slices,
padded_slices=self.s1_lr_pad_slices,
enhancement=self.s_enhance,
padding=self.spatial_pad,
)

self._s1_hr_crop_slices = self.check_boundary_slice(
unpadded_slices=self.s1_lr_slices,
cropped_slices=self._s1_hr_crop_slices,
dim=0,
)
return self._s1_hr_crop_slices

@property
Expand All @@ -258,10 +263,14 @@ def s2_hr_crop_slices(self):
if self._s2_hr_crop_slices is None:
self._s2_hr_crop_slices = self.get_hr_cropped_slices(
unpadded_slices=self.s2_lr_slices,
padded_slices=self.s2_lr_pad_slices,
enhancement=self.s_enhance,
padding=self.spatial_pad,
)
self._s2_hr_crop_slices = self.check_boundary_slice(
unpadded_slices=self.s2_lr_slices,
cropped_slices=self._s2_hr_crop_slices,
dim=1,
)
return self._s2_hr_crop_slices

@property
Expand Down Expand Up @@ -296,9 +305,20 @@ def s_lr_crop_slices(self):
s1_crop_slices = self.get_cropped_slices(
self.s1_lr_slices, self.s1_lr_pad_slices, 1
)

s1_crop_slices = self.check_boundary_slice(
unpadded_slices=self.s1_lr_slices,
cropped_slices=s1_crop_slices,
dim=0,
)
s2_crop_slices = self.get_cropped_slices(
self.s2_lr_slices, self.s2_lr_pad_slices, 1
)
s2_crop_slices = self.check_boundary_slice(
unpadded_slices=self.s2_lr_slices,
cropped_slices=s2_crop_slices,
dim=1,
)
self._s_lr_crop_slices = list(
it.product(s1_crop_slices, s2_crop_slices)
)
Expand Down Expand Up @@ -343,52 +363,6 @@ def hr_crop_slices(self):
self._hr_crop_slices.append(node_slices)
return self._hr_crop_slices

def check_boundary_slice(self, slices, dim):
"""Check boundary slice for minimum shape.
When spatial padding is used data is always padded to have at least 2 *
spatial_pad + 1 elements. When spatial padding is not used it's
possible for the forward pass chunk shape to divide the grid size such
that the last slice does not meet the minimum number of elements.
(Padding layers in the generator typically require a minimum shape of
4). So, when spatial padding is not used so we add extra padding to
meet the minimum shape requirement, otherwise we raise an error if the
minimum shape is not met."""

end_slice = slices[-1]
err_msg = (
'The final spatial slice for dimension #%s is too small (%s). '
'Adjust the forward pass chunk shape (%s) and / or spatial '
'padding (%s) so that 2 * spatial_pad + '
'modulo(grid_shape, fwp_chunk_shape) > 3'
)
warn_msg = (
'The final spatial slice for dimension #%s is too small (%s). '
'The start of this slice will be reduced to try to meet the '
'minimum slice length.'
)

if end_slice.stop - end_slice.start < 4:
if self.spatial_pad == 0:
logger.warning(warn_msg, dim + 1, end_slice)
warn(warn_msg % (dim + 1, end_slice))
new_start = np.max([0, end_slice.stop - self.chunk_shape[dim]])
end_slice = slice(new_start, end_slice.stop, end_slice.step)
slices[-1] = end_slice
if 2 * self.spatial_pad + (end_slice.stop - end_slice.start) < 4:
logger.error(
err_msg,
dim + 1,
end_slice,
self.chunk_shape,
self.spatial_pad,
)
raise ValueError(
err_msg
% (dim + 1, end_slice, self.chunk_shape, self.spatial_pad)
)
return slices

@property
def s1_lr_pad_slices(self):
"""List of low resolution spatial slices with padding for first
Expand All @@ -400,9 +374,6 @@ def s1_lr_pad_slices(self):
enhancement=1,
padding=self.spatial_pad,
)
self._s1_lr_pad_slices = self.check_boundary_slice(
slices=self._s1_lr_pad_slices, dim=0
)
return self._s1_lr_pad_slices

@property
Expand All @@ -416,9 +387,6 @@ def s2_lr_pad_slices(self):
enhancement=1,
padding=self.spatial_pad,
)
self._s2_lr_pad_slices = self.check_boundary_slice(
slices=self._s2_lr_pad_slices, dim=1
)
return self._s2_lr_pad_slices

@property
Expand Down Expand Up @@ -561,6 +529,42 @@ def get_padded_slices(slices, shape, enhancement, padding, step=None):
pad_slices.append(slice(start, end, step))
return pad_slices

def check_boundary_slice(self, unpadded_slices, cropped_slices, dim):
"""Check cropped slice at the right boundary for minimum shape.
It is possible for the forward pass chunk shape to divide the grid size
such that the last slice (right boundary) does not meet the minimum
number of elements. (Padding layers in the generator typically require
a minimum shape of 4). When this minimum shape is not met we apply
extra padding in ``ForwardPassStrategy._get_pad_width``. Cropped slices
have to be adjusted to account for this here."""

warn_msg = (
'The final spatial slice for dimension #%s is too small '
'(slice=slice(%s, %s), padding=%s). The start of this slice will '
'be reduced to try to meet the minimum slice length.'
)

lr_slice_start = unpadded_slices[-1].start or 0
lr_slice_stop = unpadded_slices[-1].stop or self.coarse_shape[dim]

# last slice adjustment
if 2 * self.spatial_pad + (lr_slice_stop - lr_slice_start) < 4:
logger.warning(
warn_msg,
dim + 1,
lr_slice_start,
lr_slice_stop,
self.spatial_pad,
)
warn(
warn_msg
% (dim + 1, lr_slice_start, lr_slice_stop, self.spatial_pad)
)
cropped_slices[-1] = slice(2 * self.s_enhance, -2 * self.s_enhance)

return cropped_slices

@staticmethod
def get_cropped_slices(unpadded_slices, padded_slices, enhancement):
"""Get cropped slices to cut off padded output
Expand Down Expand Up @@ -593,23 +597,12 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement):
if stop is not None and stop >= 0:
stop = None
cropped_slices.append(slice(start, stop))

return cropped_slices

@classmethod
def get_hr_cropped_slices(
cls, unpadded_slices, padded_slices, padding, enhancement
):
"""Get high res cropped slices
Note
----
It's possible to get a boundary slice that is too small for generator
input (padding layers typically need at least 4 elements) if the
forward pass chunk shape does not evenly divide the grid shape. We add
extra padding in the low res slices to account for this (with
:meth:`check_boundary_slice`) and need to adjust the high res cropped
slices accordingly.
"""
def get_hr_cropped_slices(cls, unpadded_slices, padding, enhancement):
"""Get high res cropped slices"""

hr_crop_start = None
hr_crop_stop = None
Expand All @@ -618,13 +611,4 @@ def get_hr_cropped_slices(
hr_crop_start = enhancement * padding
hr_crop_stop = -hr_crop_start

slices = [slice(hr_crop_start, hr_crop_stop)] * len(unpadded_slices)

if padding == 0:
end_slice = cls.get_cropped_slices(
unpadded_slices[-1:],
padded_slices[-1:],
enhancement,
)
slices[-1] = slice(end_slice[0].start, None)
return slices
return [slice(hr_crop_start, hr_crop_stop)] * len(unpadded_slices)
49 changes: 40 additions & 9 deletions sup3r/pipeline/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import pathlib
import pprint
import warnings
from dataclasses import dataclass
from functools import cached_property
from typing import Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -336,6 +335,18 @@ def preflight(self):
out = self.fwp_slicer.get_time_slices()
self.ti_slices, self.ti_pad_slices = out

fwp_s1_steps = self.fwp_chunk_shape[0] + 2 * self.spatial_pad
fwp_s2_steps = self.fwp_chunk_shape[1] + 2 * self.spatial_pad
msg = (
'The padding layers in the generator typically require at least 4 '
'elements per spatial dimension. The padded chunk shape (%s, %s) '
'is smaller than this.'
)

if fwp_s1_steps < 4 or fwp_s2_steps < 4:
logger.warning(msg, fwp_s1_steps, fwp_s2_steps)
warn(msg % (fwp_s1_steps, fwp_s2_steps))

fwp_tsteps = self.fwp_chunk_shape[2] + 2 * self.temporal_pad
tsteps = len(self.input_handler.time_index[self.time_slice])
msg = (
Expand All @@ -345,7 +356,7 @@ def preflight(self):
)
if fwp_tsteps > tsteps:
logger.warning(msg)
warnings.warn(msg)
warn(msg)
out = self.fwp_slicer.get_spatial_slices()
self.lr_slices, self.lr_pad_slices, self.hr_slices = out

Expand Down Expand Up @@ -400,7 +411,7 @@ def out_files(self):
return out_file_list

@staticmethod
def _get_pad_width(window, max_steps, max_pad):
def _get_pad_width(window, max_steps, max_pad, check_boundary=False):
"""
Parameters
----------
Expand All @@ -410,16 +421,30 @@ def _get_pad_width(window, max_steps, max_pad):
Maximum number of steps available. Padding cannot extend past this
max_pad : int
Maximum amount of padding to apply.
check_bounary : bool
Whether to check the final slice for minimum size requirement
Returns
-------
tuple
Tuple of pad width for the given window.
"""
start = window.start or 0
stop = window.stop or max_steps
start = int(np.maximum(0, (max_pad - start)))
stop = int(np.maximum(0, max_pad + stop - max_steps))
win_start = window.start or 0
win_stop = window.stop or max_steps
start = int(np.maximum(0, (max_pad - win_start)))
stop = int(np.maximum(0, max_pad + win_stop - max_steps))

# We add minimum padding to the last slice if the padded window is
# too small for the generator. This can happen if 2 * spatial_pad +
# modulo(grid_size, fwp_chunk_shape) < 4
if (
check_boundary
and win_stop == max_steps
and (win_stop - win_start) < 4
):
stop = np.max([2, max_pad])
start = np.max([2, max_pad])

return (start, stop)

def get_pad_width(self, chunk_index):
Expand All @@ -438,10 +463,16 @@ def get_pad_width(self, chunk_index):

return (
self._get_pad_width(
lr_slice[0], self.input_handler.grid_shape[0], self.spatial_pad
lr_slice[0],
self.input_handler.grid_shape[0],
self.spatial_pad,
check_boundary=True,
),
self._get_pad_width(
lr_slice[1], self.input_handler.grid_shape[1], self.spatial_pad
lr_slice[1],
self.input_handler.grid_shape[1],
self.spatial_pad,
check_boundary=True,
),
self._get_pad_width(
ti_slice, len(self.input_handler.time_index), self.temporal_pad
Expand Down
Loading

0 comments on commit 3173823

Please sign in to comment.